util: Add BitStream class + tests

This commit is contained in:
Joshua Scott 2018-10-20 00:09:50 +01:00
parent 1c013677ea
commit 6249d6ee32
7 changed files with 481 additions and 0 deletions

View File

@ -21,6 +21,7 @@ target_link_libraries(${PROJECT_NAME} RapidXML)
add_subdirectory("src/dml")
add_subdirectory("src/protocol")
add_subdirectory("src/util")
option(KI_BUILD_EXAMPLES "Determines whether to build examples." ON)
if (KI_BUILD_EXAMPLES)

152
include/ki/util/BitStream.h Normal file
View File

@ -0,0 +1,152 @@
#pragma once
#include <cstdint>
#include <type_traits>
#define KI_BITSTREAM_DEFAULT_BUFFER_SIZE 0x2000
namespace ki
{
/**
*
*/
class BitStream
{
public:
/**
* Represents a position in a BitStream's buffer.
*/
struct stream_pos
{
explicit stream_pos(intmax_t byte = 0, int bit = 0);
stream_pos(const stream_pos &cp);
intmax_t get_byte() const;
uint8_t get_bit() const;
stream_pos operator +(const stream_pos &rhs) const;
stream_pos operator -(const stream_pos &rhs) const;
stream_pos operator +(const int &rhs) const;
stream_pos operator -(const int &rhs) const;
stream_pos &operator +=(stream_pos lhs);
stream_pos &operator -=(stream_pos lhs);
stream_pos &operator +=(int bits);
stream_pos &operator -=(int bits);
stream_pos &operator ++();
stream_pos &operator --();
private:
intmax_t m_byte;
uint8_t m_bit;
void set_bit(int bit);
};
explicit BitStream(size_t buffer_size = KI_BITSTREAM_DEFAULT_BUFFER_SIZE);
~BitStream();
/**
* @return The stream's current position.
*/
stream_pos tell() const;
/**
* Sets the position of the stream.
* @param position The new position of the stream.
*/
void seek(stream_pos position);
/**
* @return A pointer to the start of the internal buffer.
*/
const uint8_t *data() const;
/**
* Reads a value from the buffer given a defined number of bits.
* @param bits The number of bits to read.
*/
template <
typename IntegerT,
typename = std::enable_if<std::is_integral<IntegerT>::value>
>
IntegerT read(const uint8_t bits)
{
IntegerT value = 0;
// Iterate until we've read all of the bits
auto unread_bits = bits;
while (unread_bits > 0)
{
// Calculate how many bits to read from the current byte based on how many bits
// are left and how many bits we still need to read
const uint8_t bits_available = (8 - m_position.get_bit());
const auto bit_count = unread_bits < bits_available ? unread_bits : bits_available;
// Find the bit-mask based on how many bits are being read
const uint8_t bit_mask = ((1 << bit_count) - 1) << m_position.get_bit();
// Read the bits from the current byte and position them on the least-signficant bit
const uint8_t bits_value = (m_buffer[m_position.get_byte()] & bit_mask) >> m_position.get_bit();
// Position the value of the bits we just read based on how many bits of the value
// we've already read
const uint8_t read_bits = bits - unread_bits;
value |= (IntegerT)bits_value << read_bits;
// Remove the bits we just read from the count of unread bits
unread_bits -= bit_count;
// Move forward the number of bits we just read
seek(tell() + bit_count);
}
return value;
}
/**
* Writes a value to the buffer that occupies a defined number of bits.
* @param value The value to write.
* @param bits The number of bits to use.
*/
template <
typename IntegerT,
typename = std::enable_if<std::is_integral<IntegerT>::value>
>
void write(IntegerT value, const uint8_t bits)
{
// Iterate until we've written all of the bits
auto unwritten_bits = bits;
while (unwritten_bits > 0)
{
// Calculate how many bits to write based on how many bits are left in the current byte
// and how many bits from the value we still need to write
const uint8_t bits_available = (8 - m_position.get_bit());
const auto bit_count = unwritten_bits < bits_available ? unwritten_bits : bits_available;
// Find the bit-mask based on how many bits are being written, and how many bits we've
// already written
const uint8_t written_bits = bits - unwritten_bits;
IntegerT bit_mask = (IntegerT)((1 << bit_count) - 1) << written_bits;
// Get the bits from the value and position them at the current bit position
uint8_t value_byte = ((value & bit_mask) >> written_bits) & 0xFF;
value_byte <<= m_position.get_bit();
// Write the bits into the byte we're currently at
m_buffer[m_position.get_byte()] |= value_byte;
unwritten_bits -= bit_count;
// Move forward the number of bits we just wrote
seek(tell() + bit_count);
}
}
private:
uint8_t *m_buffer;
size_t m_buffer_size;
stream_pos m_position;
void expand_buffer();
void validate_buffer();
};
}

174
src/util/BitStream.cpp Normal file
View File

@ -0,0 +1,174 @@
#include "ki/util/BitStream.h"
#include <limits>
#include <exception>
namespace ki
{
BitStream::stream_pos::stream_pos(const intmax_t byte, const int bit)
{
m_byte = byte;
set_bit(bit);
}
BitStream::stream_pos::stream_pos(const stream_pos& cp)
{
m_byte = cp.m_byte;
set_bit(cp.m_bit);
}
intmax_t BitStream::stream_pos::get_byte() const
{
return m_byte;
}
uint8_t BitStream::stream_pos::get_bit() const
{
return m_bit;
}
void BitStream::stream_pos::set_bit(int bit)
{
if (bit < 0)
{
bit = -bit;
m_byte -= (bit / 8) + 1;
m_bit = 8 - (bit % 8);
}
else if (bit >= 8)
{
m_byte += bit / 8;
m_bit = bit % 8;
}
else
m_bit = bit;
}
BitStream::stream_pos BitStream::stream_pos::operator+(
const stream_pos &rhs) const
{
return stream_pos(
m_byte + rhs.m_byte, m_bit + rhs.m_bit
);
}
BitStream::stream_pos BitStream::stream_pos::operator-(
const stream_pos &rhs) const
{
return stream_pos(
m_byte - rhs.m_byte, m_bit - rhs.m_bit
);
}
BitStream::stream_pos BitStream::stream_pos::operator+(
const int& rhs) const
{
return stream_pos(
m_byte, m_bit + rhs
);
}
BitStream::stream_pos BitStream::stream_pos::operator-(
const int& rhs) const
{
return stream_pos(
m_byte, m_bit - rhs
);
}
BitStream::stream_pos& BitStream::stream_pos::operator+=(
const stream_pos rhs)
{
m_byte += rhs.m_byte;
set_bit(m_bit + rhs.m_bit);
return *this;
}
BitStream::stream_pos& BitStream::stream_pos::operator-=(
const stream_pos rhs)
{
m_byte -= rhs.m_byte;
set_bit(m_bit - rhs.m_bit);
return *this;
}
BitStream::stream_pos& BitStream::stream_pos::operator+=(const int bits)
{
set_bit(m_bit + bits);
return *this;
}
BitStream::stream_pos& BitStream::stream_pos::operator-=(const int bits)
{
set_bit(m_bit - bits);
return *this;
}
BitStream::stream_pos& BitStream::stream_pos::operator++()
{
set_bit(m_bit + 1);
return *this;
}
BitStream::stream_pos& BitStream::stream_pos::operator--()
{
set_bit(m_bit - 1);
return *this;
}
BitStream::BitStream(const size_t buffer_size)
{
m_buffer = new uint8_t[buffer_size] { 0 };
m_buffer_size = buffer_size;
m_position = stream_pos(0, 0);
}
BitStream::~BitStream()
{
delete[] m_buffer;
}
BitStream::stream_pos BitStream::tell() const
{
return m_position;
}
void BitStream::seek(const stream_pos position)
{
m_position = position;
validate_buffer();
}
const uint8_t* BitStream::data() const
{
return m_buffer;
}
void BitStream::expand_buffer()
{
// Work out a new buffer size
auto new_size = (m_buffer_size << 1) + 2;
if (new_size < m_buffer_size)
new_size = std::numeric_limits<size_t>::max();
// Has the buffer reached maximum size?
if (new_size == m_buffer_size)
throw std::exception("Buffer cannot be expanded as it has reached maximum size.");
// Allocate a new buffer, copy everything over, and then delete the old buffer
auto *new_buffer = new uint8_t[new_size] { 0 };
memcpy(new_buffer, m_buffer, m_buffer_size);
delete[] m_buffer;
m_buffer = new_buffer;
}
void BitStream::validate_buffer()
{
// Make sure we haven't underflowed
if (m_position.get_byte() < 0)
throw std::exception("Position of buffer is less than 0!");
// Expand the buffer if we've overflowed
if (m_position.get_byte() >= m_buffer_size)
expand_buffer();
}
}

4
src/util/CMakeLists.txt Normal file
View File

@ -0,0 +1,4 @@
target_sources(${PROJECT_NAME}
PRIVATE
${PROJECT_SOURCE_DIR}/src/util/BitStream.cpp
)

View File

@ -0,0 +1 @@
UUUU

View File

@ -0,0 +1,2 @@



147
test/src/unit-bitstream.cpp Normal file
View File

@ -0,0 +1,147 @@
#define CATCH_CONFIG_MAIN
#include <catch.hpp>
#include <ki/util/BitStream.h>
using namespace ki;
TEST_CASE("Write bits", "[bit-stream]")
{
auto *bit_stream = new BitStream();
// Write an alternating pattern of bits
bit_stream->write(0b1, 1);
bit_stream->write(0b10, 2);
bit_stream->write(0b010, 3);
bit_stream->write(0b0101, 4);
bit_stream->write(0b10101, 5);
bit_stream->write(0b101010, 6);
bit_stream->write(0b0101010, 7);
bit_stream->write(0b0101, 4);
// Make sure tell is reporting the right position
auto position = bit_stream->tell();
if (position.get_byte() != 4 || position.get_bit() != 0)
FAIL();
const auto size = position.get_byte();
// Validate what we've got here with a hand-written sample
std::ifstream sample("samples/bitstream1.bin", std::ios::binary);
if (!sample.is_open())
FAIL();
// Load the sample data and compare
auto *sample_data = new char[size + 1] { 0 };
sample.read(sample_data, size);
if (strcmp(sample_data, (char *)bit_stream->data()) != 0)
FAIL();
// Free resources
delete bit_stream;
delete[] sample_data;
}
TEST_CASE("Write bytes", "[bit-stream]")
{
auto *bit_stream = new BitStream();
// Write an alternating pattern of bits
bit_stream->write<uint8_t>(0x01, 8);
bit_stream->write<uint16_t>(0x0302, 16);
bit_stream->write<uint32_t>(0x060504, 24);
bit_stream->write<uint32_t>(0x0A090807, 32);
bit_stream->write<uint64_t>(0x1211100F0E0D0C0BL, 64);
// Make sure tell is reporting the right position
auto position = bit_stream->tell();
if (position.get_byte() != 18 || position.get_bit() != 0)
FAIL();
const auto size = position.get_byte();
// Validate what we've got here with a hand-written sample
std::ifstream sample("samples/bitstream2.bin", std::ios::binary);
if (!sample.is_open())
FAIL();
// Load the sample data and compare
auto *sample_data = new char[size + 1]{ 0 };
sample.read(sample_data, size);
if (strcmp(sample_data, (char *)bit_stream->data()) != 0)
FAIL();
// Free resources
delete bit_stream;
delete[] sample_data;
}
TEST_CASE("Read bits", "[bit-stream]")
{
auto *bit_stream = new BitStream();
// Validate what we've got here with a hand-written sample
std::ifstream sample("samples/bitstream1.bin", std::ios::binary);
if (!sample.is_open())
FAIL();
// Load the sample data into the bit stream
const auto begin = sample.tellg();
sample.seekg(0, std::ios::end);
const auto end = sample.tellg();
const size_t size = end - begin;
sample.seekg(std::ios::beg);
sample.read((char *)bit_stream->data(), size);
// Read the values and check they are what we are expecting
if (bit_stream->read<uint8_t>(1) != 0b1)
FAIL();
if (bit_stream->read<uint8_t>(2) != 0b10)
FAIL();
if (bit_stream->read<uint8_t>(3) != 0b010)
FAIL();
if (bit_stream->read<uint8_t>(4) != 0b0101)
FAIL();
if (bit_stream->read<uint8_t>(5) != 0b10101)
FAIL();
if (bit_stream->read<uint8_t>(6) != 0b101010)
FAIL();
if (bit_stream->read<uint8_t>(7) != 0b0101010)
FAIL();
if (bit_stream->read<uint8_t>(4) != 0b0101)
FAIL();
// Free resources
delete bit_stream;
}
TEST_CASE("Read bytes", "[bit-stream]")
{
auto *bit_stream = new BitStream();
// Validate what we've got here with a hand-written sample
std::ifstream sample("samples/bitstream2.bin", std::ios::binary);
if (!sample.is_open())
FAIL();
// Load the sample data into the bit stream
const auto begin = sample.tellg();
sample.seekg(0, std::ios::end);
const auto end = sample.tellg();
const size_t size = end - begin;
sample.seekg(std::ios::beg);
sample.read((char *)bit_stream->data(), size);
// Read the values and check they are what we are expecting
if (bit_stream->read<uint8_t>(8) != 0x01)
FAIL();
if (bit_stream->read<uint16_t>(16) != 0x0302)
FAIL();
if (bit_stream->read<uint32_t>(24) != 0x060504)
FAIL();
if (bit_stream->read<uint32_t>(32) != 0x0A090807)
FAIL();
if (bit_stream->read<uint64_t>(64) != 0x1211100F0E0D0C0BU)
FAIL();
// Free resources
delete bit_stream;
}