diff --git a/CMakeLists.txt b/CMakeLists.txt index 45f290e..98b8a09 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,6 +21,7 @@ target_link_libraries(${PROJECT_NAME} RapidXML) add_subdirectory("src/dml") add_subdirectory("src/protocol") +add_subdirectory("src/util") option(KI_BUILD_EXAMPLES "Determines whether to build examples." ON) if (KI_BUILD_EXAMPLES) diff --git a/include/ki/util/BitStream.h b/include/ki/util/BitStream.h new file mode 100644 index 0000000..f50bf8e --- /dev/null +++ b/include/ki/util/BitStream.h @@ -0,0 +1,152 @@ +#pragma once +#include +#include + +#define KI_BITSTREAM_DEFAULT_BUFFER_SIZE 0x2000 + +namespace ki +{ + /** + * + */ + class BitStream + { + + public: + /** + * Represents a position in a BitStream's buffer. + */ + struct stream_pos + { + explicit stream_pos(intmax_t byte = 0, int bit = 0); + stream_pos(const stream_pos &cp); + + intmax_t get_byte() const; + uint8_t get_bit() const; + + stream_pos operator +(const stream_pos &rhs) const; + stream_pos operator -(const stream_pos &rhs) const; + stream_pos operator +(const int &rhs) const; + stream_pos operator -(const int &rhs) const; + stream_pos &operator +=(stream_pos lhs); + stream_pos &operator -=(stream_pos lhs); + stream_pos &operator +=(int bits); + stream_pos &operator -=(int bits); + stream_pos &operator ++(); + stream_pos &operator --(); + + private: + intmax_t m_byte; + uint8_t m_bit; + + void set_bit(int bit); + }; + + explicit BitStream(size_t buffer_size = KI_BITSTREAM_DEFAULT_BUFFER_SIZE); + ~BitStream(); + + /** + * @return The stream's current position. + */ + stream_pos tell() const; + + /** + * Sets the position of the stream. + * @param position The new position of the stream. + */ + void seek(stream_pos position); + + /** + * @return A pointer to the start of the internal buffer. + */ + const uint8_t *data() const; + + /** + * Reads a value from the buffer given a defined number of bits. + * @param bits The number of bits to read. + */ + template < + typename IntegerT, + typename = std::enable_if::value> + > + IntegerT read(const uint8_t bits) + { + IntegerT value = 0; + + // Iterate until we've read all of the bits + auto unread_bits = bits; + while (unread_bits > 0) + { + // Calculate how many bits to read from the current byte based on how many bits + // are left and how many bits we still need to read + const uint8_t bits_available = (8 - m_position.get_bit()); + const auto bit_count = unread_bits < bits_available ? unread_bits : bits_available; + + // Find the bit-mask based on how many bits are being read + const uint8_t bit_mask = ((1 << bit_count) - 1) << m_position.get_bit(); + + // Read the bits from the current byte and position them on the least-signficant bit + const uint8_t bits_value = (m_buffer[m_position.get_byte()] & bit_mask) >> m_position.get_bit(); + + // Position the value of the bits we just read based on how many bits of the value + // we've already read + const uint8_t read_bits = bits - unread_bits; + value |= (IntegerT)bits_value << read_bits; + + // Remove the bits we just read from the count of unread bits + unread_bits -= bit_count; + + // Move forward the number of bits we just read + seek(tell() + bit_count); + } + + return value; + } + + /** + * Writes a value to the buffer that occupies a defined number of bits. + * @param value The value to write. + * @param bits The number of bits to use. + */ + template < + typename IntegerT, + typename = std::enable_if::value> + > + void write(IntegerT value, const uint8_t bits) + { + // Iterate until we've written all of the bits + auto unwritten_bits = bits; + while (unwritten_bits > 0) + { + // Calculate how many bits to write based on how many bits are left in the current byte + // and how many bits from the value we still need to write + const uint8_t bits_available = (8 - m_position.get_bit()); + const auto bit_count = unwritten_bits < bits_available ? unwritten_bits : bits_available; + + // Find the bit-mask based on how many bits are being written, and how many bits we've + // already written + const uint8_t written_bits = bits - unwritten_bits; + IntegerT bit_mask = (IntegerT)((1 << bit_count) - 1) << written_bits; + + // Get the bits from the value and position them at the current bit position + uint8_t value_byte = ((value & bit_mask) >> written_bits) & 0xFF; + value_byte <<= m_position.get_bit(); + + // Write the bits into the byte we're currently at + m_buffer[m_position.get_byte()] |= value_byte; + unwritten_bits -= bit_count; + + // Move forward the number of bits we just wrote + seek(tell() + bit_count); + } + } + + private: + uint8_t *m_buffer; + size_t m_buffer_size; + stream_pos m_position; + + void expand_buffer(); + void validate_buffer(); + }; +} diff --git a/src/util/BitStream.cpp b/src/util/BitStream.cpp new file mode 100644 index 0000000..b9308f8 --- /dev/null +++ b/src/util/BitStream.cpp @@ -0,0 +1,174 @@ +#include "ki/util/BitStream.h" +#include +#include + +namespace ki +{ + BitStream::stream_pos::stream_pos(const intmax_t byte, const int bit) + { + m_byte = byte; + set_bit(bit); + } + + BitStream::stream_pos::stream_pos(const stream_pos& cp) + { + m_byte = cp.m_byte; + set_bit(cp.m_bit); + } + + intmax_t BitStream::stream_pos::get_byte() const + { + return m_byte; + } + + uint8_t BitStream::stream_pos::get_bit() const + { + return m_bit; + } + + void BitStream::stream_pos::set_bit(int bit) + { + if (bit < 0) + { + bit = -bit; + m_byte -= (bit / 8) + 1; + m_bit = 8 - (bit % 8); + } + else if (bit >= 8) + { + m_byte += bit / 8; + m_bit = bit % 8; + } + else + m_bit = bit; + } + + BitStream::stream_pos BitStream::stream_pos::operator+( + const stream_pos &rhs) const + { + return stream_pos( + m_byte + rhs.m_byte, m_bit + rhs.m_bit + ); + } + + BitStream::stream_pos BitStream::stream_pos::operator-( + const stream_pos &rhs) const + { + return stream_pos( + m_byte - rhs.m_byte, m_bit - rhs.m_bit + ); + } + + BitStream::stream_pos BitStream::stream_pos::operator+( + const int& rhs) const + { + return stream_pos( + m_byte, m_bit + rhs + ); + } + + BitStream::stream_pos BitStream::stream_pos::operator-( + const int& rhs) const + { + return stream_pos( + m_byte, m_bit - rhs + ); + } + + BitStream::stream_pos& BitStream::stream_pos::operator+=( + const stream_pos rhs) + { + m_byte += rhs.m_byte; + set_bit(m_bit + rhs.m_bit); + return *this; + } + + BitStream::stream_pos& BitStream::stream_pos::operator-=( + const stream_pos rhs) + { + m_byte -= rhs.m_byte; + set_bit(m_bit - rhs.m_bit); + return *this; + } + + BitStream::stream_pos& BitStream::stream_pos::operator+=(const int bits) + { + set_bit(m_bit + bits); + return *this; + } + + BitStream::stream_pos& BitStream::stream_pos::operator-=(const int bits) + { + set_bit(m_bit - bits); + return *this; + } + + BitStream::stream_pos& BitStream::stream_pos::operator++() + { + set_bit(m_bit + 1); + return *this; + } + + BitStream::stream_pos& BitStream::stream_pos::operator--() + { + set_bit(m_bit - 1); + return *this; + } + + BitStream::BitStream(const size_t buffer_size) + { + m_buffer = new uint8_t[buffer_size] { 0 }; + m_buffer_size = buffer_size; + m_position = stream_pos(0, 0); + } + + BitStream::~BitStream() + { + delete[] m_buffer; + } + + BitStream::stream_pos BitStream::tell() const + { + return m_position; + } + + void BitStream::seek(const stream_pos position) + { + m_position = position; + validate_buffer(); + } + + const uint8_t* BitStream::data() const + { + return m_buffer; + } + + void BitStream::expand_buffer() + { + // Work out a new buffer size + auto new_size = (m_buffer_size << 1) + 2; + if (new_size < m_buffer_size) + new_size = std::numeric_limits::max(); + + // Has the buffer reached maximum size? + if (new_size == m_buffer_size) + throw std::exception("Buffer cannot be expanded as it has reached maximum size."); + + // Allocate a new buffer, copy everything over, and then delete the old buffer + auto *new_buffer = new uint8_t[new_size] { 0 }; + memcpy(new_buffer, m_buffer, m_buffer_size); + delete[] m_buffer; + m_buffer = new_buffer; + } + + void BitStream::validate_buffer() + { + // Make sure we haven't underflowed + if (m_position.get_byte() < 0) + throw std::exception("Position of buffer is less than 0!"); + + // Expand the buffer if we've overflowed + if (m_position.get_byte() >= m_buffer_size) + expand_buffer(); + } +} diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt new file mode 100644 index 0000000..e6ee9ce --- /dev/null +++ b/src/util/CMakeLists.txt @@ -0,0 +1,4 @@ +target_sources(${PROJECT_NAME} + PRIVATE + ${PROJECT_SOURCE_DIR}/src/util/BitStream.cpp +) \ No newline at end of file diff --git a/test/samples/bitstream1.bin b/test/samples/bitstream1.bin new file mode 100644 index 0000000..9e6a9ab --- /dev/null +++ b/test/samples/bitstream1.bin @@ -0,0 +1 @@ +UUUU \ No newline at end of file diff --git a/test/samples/bitstream2.bin b/test/samples/bitstream2.bin new file mode 100644 index 0000000..59961e8 --- /dev/null +++ b/test/samples/bitstream2.bin @@ -0,0 +1,2 @@ + +  \ No newline at end of file diff --git a/test/src/unit-bitstream.cpp b/test/src/unit-bitstream.cpp new file mode 100644 index 0000000..cc9930f --- /dev/null +++ b/test/src/unit-bitstream.cpp @@ -0,0 +1,147 @@ +#define CATCH_CONFIG_MAIN +#include + +#include + +using namespace ki; + +TEST_CASE("Write bits", "[bit-stream]") +{ + auto *bit_stream = new BitStream(); + + // Write an alternating pattern of bits + bit_stream->write(0b1, 1); + bit_stream->write(0b10, 2); + bit_stream->write(0b010, 3); + bit_stream->write(0b0101, 4); + bit_stream->write(0b10101, 5); + bit_stream->write(0b101010, 6); + bit_stream->write(0b0101010, 7); + bit_stream->write(0b0101, 4); + + // Make sure tell is reporting the right position + auto position = bit_stream->tell(); + if (position.get_byte() != 4 || position.get_bit() != 0) + FAIL(); + const auto size = position.get_byte(); + + // Validate what we've got here with a hand-written sample + std::ifstream sample("samples/bitstream1.bin", std::ios::binary); + if (!sample.is_open()) + FAIL(); + + // Load the sample data and compare + auto *sample_data = new char[size + 1] { 0 }; + sample.read(sample_data, size); + if (strcmp(sample_data, (char *)bit_stream->data()) != 0) + FAIL(); + + // Free resources + delete bit_stream; + delete[] sample_data; +} + +TEST_CASE("Write bytes", "[bit-stream]") +{ + auto *bit_stream = new BitStream(); + + // Write an alternating pattern of bits + bit_stream->write(0x01, 8); + bit_stream->write(0x0302, 16); + bit_stream->write(0x060504, 24); + bit_stream->write(0x0A090807, 32); + bit_stream->write(0x1211100F0E0D0C0BL, 64); + + // Make sure tell is reporting the right position + auto position = bit_stream->tell(); + if (position.get_byte() != 18 || position.get_bit() != 0) + FAIL(); + const auto size = position.get_byte(); + + // Validate what we've got here with a hand-written sample + std::ifstream sample("samples/bitstream2.bin", std::ios::binary); + if (!sample.is_open()) + FAIL(); + + // Load the sample data and compare + auto *sample_data = new char[size + 1]{ 0 }; + sample.read(sample_data, size); + if (strcmp(sample_data, (char *)bit_stream->data()) != 0) + FAIL(); + + // Free resources + delete bit_stream; + delete[] sample_data; +} + +TEST_CASE("Read bits", "[bit-stream]") +{ + auto *bit_stream = new BitStream(); + + // Validate what we've got here with a hand-written sample + std::ifstream sample("samples/bitstream1.bin", std::ios::binary); + if (!sample.is_open()) + FAIL(); + + // Load the sample data into the bit stream + const auto begin = sample.tellg(); + sample.seekg(0, std::ios::end); + const auto end = sample.tellg(); + const size_t size = end - begin; + sample.seekg(std::ios::beg); + sample.read((char *)bit_stream->data(), size); + + // Read the values and check they are what we are expecting + if (bit_stream->read(1) != 0b1) + FAIL(); + if (bit_stream->read(2) != 0b10) + FAIL(); + if (bit_stream->read(3) != 0b010) + FAIL(); + if (bit_stream->read(4) != 0b0101) + FAIL(); + if (bit_stream->read(5) != 0b10101) + FAIL(); + if (bit_stream->read(6) != 0b101010) + FAIL(); + if (bit_stream->read(7) != 0b0101010) + FAIL(); + if (bit_stream->read(4) != 0b0101) + FAIL(); + + // Free resources + delete bit_stream; +} + +TEST_CASE("Read bytes", "[bit-stream]") +{ + auto *bit_stream = new BitStream(); + + // Validate what we've got here with a hand-written sample + std::ifstream sample("samples/bitstream2.bin", std::ios::binary); + if (!sample.is_open()) + FAIL(); + + // Load the sample data into the bit stream + const auto begin = sample.tellg(); + sample.seekg(0, std::ios::end); + const auto end = sample.tellg(); + const size_t size = end - begin; + sample.seekg(std::ios::beg); + sample.read((char *)bit_stream->data(), size); + + // Read the values and check they are what we are expecting + if (bit_stream->read(8) != 0x01) + FAIL(); + if (bit_stream->read(16) != 0x0302) + FAIL(); + if (bit_stream->read(24) != 0x060504) + FAIL(); + if (bit_stream->read(32) != 0x0A090807) + FAIL(); + if (bit_stream->read(64) != 0x1211100F0E0D0C0BU) + FAIL(); + + // Free resources + delete bit_stream; +}