diff --git a/CMakeLists.txt b/CMakeLists.txt index 9af9364..45f290e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,7 @@ target_include_directories(${PROJECT_NAME} target_link_libraries(${PROJECT_NAME} RapidXML) add_subdirectory("src/dml") +add_subdirectory("src/protocol") option(KI_BUILD_EXAMPLES "Determines whether to build examples." ON) if (KI_BUILD_EXAMPLES) diff --git a/examples/src/example-dml-module.cpp b/examples/src/example-dml-module.cpp new file mode 100644 index 0000000..108df4f --- /dev/null +++ b/examples/src/example-dml-module.cpp @@ -0,0 +1,64 @@ +#include +#include +#include + +using namespace ki::protocol; + +int main(int argc, char **argv) +{ + // Get command-line arguments + if (argc < 3) + { + std::cout << "usage: example-dml-module.exe " << std::endl; + std::cout << "Prints out information for specified message." << std::endl; + return 1; + } + + // Create a manager to load modules into + auto *message_manager = new dml::MessageManager(); + const dml::MessageModule *message_module; + + // Load the message module file + const std::string filepath = argv[1]; + try + { + message_module = message_manager->load_module(filepath); + } + catch (value_error &e) + { + std::cout << "Failed to load message module."; + return 1; + } + + // Print some information about the module itself + std::cout << "Service ID: " << (uint16_t)message_module->get_service_id() << std::endl; + std::cout << "Protocol Type: " << message_module->get_protocol_type() << std::endl; + + // Get the message template from the module we just loaded + const std::string message_name = argv[2]; + auto *message_template = message_module->get_message_template(message_name); + if (message_template) + { + std::cout << "Message Name: " << message_template->get_name() << std::endl; + std::cout << "Mesasge Type: " << (uint16_t)message_template->get_type() << std::endl; + + // Print out the fields in the template record + std::cout << std::endl; + auto &record = message_template->get_record(); + for (auto it = record.fields_begin(); + it != record.fields_end(); ++it) + { + auto *field = *it; + if (field->is_transferable()) + std::cout << field->get_type_name() << " " << field->get_name() << ";" << std::endl; + } + } + else + { + std::cout << "Could not find message with name: " << message_name << std::endl; + return 1; + } + + // Exit successfully + return 0; +} diff --git a/include/ki/protocol/control/ClientKeepAlive.h b/include/ki/protocol/control/ClientKeepAlive.h new file mode 100644 index 0000000..5b9f9bb --- /dev/null +++ b/include/ki/protocol/control/ClientKeepAlive.h @@ -0,0 +1,38 @@ +#pragma once +#include "../../util/Serializable.h" +#include +#include + +namespace ki +{ +namespace protocol +{ +namespace control +{ + class ClientKeepAlive final : public util::Serializable + { + public: + ClientKeepAlive(uint16_t session_id = 0, + uint16_t milliseconds = 0, uint16_t minutes = 0); + virtual ~ClientKeepAlive() = default; + + uint16_t get_session_id() const; + void set_session_id(uint16_t session_id); + + uint16_t get_milliseconds() const; + void set_milliseconds(uint16_t milliseconds); + + uint16_t get_minutes() const; + void set_minutes(uint16_t minutes); + + void write_to(std::ostream &ostream) const override final; + void read_from(std::istream &istream) override final; + size_t get_size() const override final; + private: + uint16_t m_session_id; + uint16_t m_milliseconds; + uint16_t m_minutes; + }; +} +} +} diff --git a/include/ki/protocol/control/Opcode.h b/include/ki/protocol/control/Opcode.h new file mode 100644 index 0000000..c63b6aa --- /dev/null +++ b/include/ki/protocol/control/Opcode.h @@ -0,0 +1,21 @@ +#pragma once +#include + +namespace ki +{ +namespace protocol +{ +namespace control +{ + enum class Opcode : uint8_t + { + NONE = 0, + SESSION_OFFER = 0, + UDP_HELLO = 1, + KEEP_ALIVE = 3, + KEEP_ALIVE_RSP = 4, + SESSION_ACCEPT = 5 + }; +} +} +} \ No newline at end of file diff --git a/include/ki/protocol/control/ServerKeepAlive.h b/include/ki/protocol/control/ServerKeepAlive.h new file mode 100644 index 0000000..520ff51 --- /dev/null +++ b/include/ki/protocol/control/ServerKeepAlive.h @@ -0,0 +1,29 @@ +#pragma once +#include "../../util/Serializable.h" +#include +#include + +namespace ki +{ +namespace protocol +{ +namespace control +{ + class ServerKeepAlive final : public util::Serializable + { + public: + ServerKeepAlive(uint32_t timestamp = 0); + virtual ~ServerKeepAlive() = default; + + uint32_t get_timestamp() const; + void set_timestamp(uint32_t timestamp); + + void write_to(std::ostream &ostream) const override final; + void read_from(std::istream &istream) override final; + size_t get_size() const override final; + private: + uint32_t m_timestamp; + }; +} +} +} diff --git a/include/ki/protocol/control/SessionAccept.h b/include/ki/protocol/control/SessionAccept.h new file mode 100644 index 0000000..4350e11 --- /dev/null +++ b/include/ki/protocol/control/SessionAccept.h @@ -0,0 +1,38 @@ +#pragma once +#include "../../util/Serializable.h" +#include +#include + +namespace ki +{ +namespace protocol +{ +namespace control +{ + class SessionAccept final : public util::Serializable + { + public: + SessionAccept(uint16_t session_id = 0, + int32_t timestamp = 0, uint32_t milliseconds = 0); + virtual ~SessionAccept() = default; + + uint16_t get_session_id() const; + void set_session_id(uint16_t session_id); + + int32_t get_timestamp() const; + void set_timestamp(int32_t timestamp); + + uint32_t get_milliseconds() const; + void set_milliseconds(uint32_t milliseconds); + + void write_to(std::ostream &ostream) const override final; + void read_from(std::istream &istream) override final; + size_t get_size() const override final; + private: + uint16_t m_session_id; + int32_t m_timestamp; + uint32_t m_milliseconds; + }; +} +} +} diff --git a/include/ki/protocol/control/SessionOffer.h b/include/ki/protocol/control/SessionOffer.h new file mode 100644 index 0000000..ab59fb1 --- /dev/null +++ b/include/ki/protocol/control/SessionOffer.h @@ -0,0 +1,38 @@ +#pragma once +#include "../../util/Serializable.h" +#include +#include + +namespace ki +{ +namespace protocol +{ +namespace control +{ + class SessionOffer final : public util::Serializable + { + public: + SessionOffer(uint16_t session_id = 0, + int32_t timestamp = 0, uint32_t milliseconds = 0); + virtual ~SessionOffer() = default; + + uint16_t get_session_id() const; + void set_session_id(uint16_t session_id); + + int32_t get_timestamp() const; + void set_timestamp(int32_t timestamp); + + uint32_t get_milliseconds() const; + void set_milliseconds(uint32_t milliseconds); + + void write_to(std::ostream &ostream) const override final; + void read_from(std::istream &istream) override final; + size_t get_size() const override final; + private: + uint16_t m_session_id; + int32_t m_timestamp; + uint32_t m_milliseconds; + }; +} +} +} diff --git a/include/ki/protocol/dml/Message.h b/include/ki/protocol/dml/Message.h new file mode 100644 index 0000000..bddbb35 --- /dev/null +++ b/include/ki/protocol/dml/Message.h @@ -0,0 +1,50 @@ +#pragma once +#include "MessageHeader.h" +#include "../../util/Serializable.h" +#include "../../dml/Record.h" +#include + +namespace ki +{ +namespace protocol +{ +namespace dml +{ + class MessageTemplate; + + class Message final : public util::Serializable + { + public: + Message(const MessageTemplate *message_template = nullptr); + virtual ~Message(); + + const MessageTemplate *get_template() const; + void set_template(const MessageTemplate *message_template); + + ki::dml::Record *get_record(); + const ki::dml::Record *get_record() const; + + ki::dml::FieldBase *get_field(std::string name); + const ki::dml::FieldBase *get_field(std::string name) const; + + uint8_t get_service_id() const; + uint8_t get_type() const; + uint16_t get_message_size() const; + std::string get_handler() const; + uint8_t get_access_level() const; + + void write_to(std::ostream &ostream) const override final; + void read_from(std::istream &istream) override final; + size_t get_size() const override final; + private: + const MessageTemplate *m_template; + ki::dml::Record *m_record; + + // This is used to store raw data when a Message is + // constructed without a MessageTemplate. + MessageHeader m_header; + std::vector m_raw_data; + }; +} +} +} diff --git a/include/ki/protocol/dml/MessageHeader.h b/include/ki/protocol/dml/MessageHeader.h new file mode 100644 index 0000000..7bda780 --- /dev/null +++ b/include/ki/protocol/dml/MessageHeader.h @@ -0,0 +1,37 @@ +#pragma once +#include "../../util/Serializable.h" +#include + +namespace ki +{ +namespace protocol +{ +namespace dml +{ + class MessageHeader : public util::Serializable + { + public: + MessageHeader(uint8_t service_id = 0, + uint8_t type = 0, uint16_t size = 0); + virtual ~MessageHeader() = default; + + uint8_t get_service_id() const; + void set_service_id(uint8_t service_id); + + uint8_t get_type() const; + void set_type(uint8_t type); + + uint16_t get_message_size() const; + void set_message_size(uint16_t size); + + void write_to(std::ostream &ostream) const override final; + void read_from(std::istream &istream) override final; + size_t get_size() const override final; + private: + uint8_t m_service_id; + uint8_t m_type; + uint16_t m_size; + }; +} +} +} \ No newline at end of file diff --git a/include/ki/protocol/dml/MessageManager.h b/include/ki/protocol/dml/MessageManager.h new file mode 100644 index 0000000..a051ba2 --- /dev/null +++ b/include/ki/protocol/dml/MessageManager.h @@ -0,0 +1,44 @@ +#pragma once +#include "Message.h" +#include "MessageModule.h" +#include "../../dml/Record.h" +#include + +namespace ki +{ +namespace protocol +{ +namespace dml +{ + class MessageManager + { + public: + MessageManager() = default; + ~MessageManager(); + + const MessageModule *load_module(std::string filepath); + const MessageModule *get_module(uint8_t service_id) const; + const MessageModule *get_module(const std::string &protocol_type) const; + + Message *create_message(uint8_t service_id, uint8_t message_type) const; + Message *create_message(uint8_t service_id, const std::string &message_name) const; + Message *create_message(const std::string &protocol_type, uint8_t message_type) const; + Message *create_message(const std::string &protocol_type, const std::string &message_name) const; + + /** + * If the DML message header cannot be read, then a nullptr + * is returned; otherwise, a valid Message pointer is always returned. + * However, that does not mean that the message itself is valid. + * + * To verify if the record was completely parsed, get_record + * should return a valid Record pointer, rather than nullptr. + */ + const Message *message_from_binary(std::istream &istream) const; + private: + MessageModuleList m_modules; + MessageModuleServiceIdMap m_service_id_map; + MessageModuleProtocolTypeMap m_protocol_type_map; + }; +} +} +} diff --git a/include/ki/protocol/dml/MessageModule.h b/include/ki/protocol/dml/MessageModule.h new file mode 100644 index 0000000..c58e080 --- /dev/null +++ b/include/ki/protocol/dml/MessageModule.h @@ -0,0 +1,55 @@ +#pragma once +#include "Message.h" +#include "MessageTemplate.h" +#include +#include +#include +#include + +namespace ki +{ +namespace protocol +{ +namespace dml +{ + class MessageModule + { + public: + MessageModule(uint8_t service_id = 0, std::string protocol_type = ""); + ~MessageModule(); + + uint8_t get_service_id() const; + void set_service_id(uint8_t service_id); + + std::string get_protocol_type() const; + void set_protocol_type(std::string protocol_type); + + std::string get_protocol_desription() const; + void set_protocol_description(std::string protocol_description); + + const MessageTemplate *add_message_template(std::string name, + ki::dml::Record *record, bool auto_sort = true); + const MessageTemplate *get_message_template(uint8_t type) const; + const MessageTemplate *get_message_template(std::string name) const; + + void sort_lookup(); + + Message *create_message(uint8_t message_type) const; + Message *create_message(std::string message_name) const; + private: + uint8_t m_service_id; + std::string m_protocol_type; + std::string m_protocol_description; + uint8_t m_last_message_type; + + std::vector m_templates; + std::map m_message_type_map; + std::map m_message_name_map; + }; + + typedef std::vector MessageModuleList; + typedef std::map MessageModuleServiceIdMap; + typedef std::map MessageModuleProtocolTypeMap; +} +} +} diff --git a/include/ki/protocol/dml/MessageTemplate.h b/include/ki/protocol/dml/MessageTemplate.h new file mode 100644 index 0000000..3f82ab2 --- /dev/null +++ b/include/ki/protocol/dml/MessageTemplate.h @@ -0,0 +1,46 @@ +#pragma once +#include "../../dml/Record.h" +#include "Message.h" +#include + +namespace ki +{ +namespace protocol +{ +namespace dml +{ + class MessageTemplate + { + public: + MessageTemplate(std::string name, uint8_t type, + uint8_t service_id, ki::dml::Record *record); + ~MessageTemplate(); + + std::string get_name() const; + void set_name(std::string name); + + uint8_t get_type() const; + void set_type(uint8_t type); + + uint8_t get_service_id() const; + void set_service_id(uint8_t service_id); + + std::string get_handler() const; + void set_handler(std::string handler); + + uint8_t get_access_level() const; + void set_access_level(uint8_t access_level); + + const ki::dml::Record &get_record() const; + void set_record(ki::dml::Record *record); + + Message *create_message() const; + private: + std::string m_name; + uint8_t m_type; + uint8_t m_service_id; + ki::dml::Record *m_record; + }; +} +} +} \ No newline at end of file diff --git a/include/ki/protocol/exception.h b/include/ki/protocol/exception.h new file mode 100644 index 0000000..b379ef6 --- /dev/null +++ b/include/ki/protocol/exception.h @@ -0,0 +1,64 @@ +#pragma once +#include + +namespace ki +{ +namespace protocol +{ + class runtime_error : public std::runtime_error + { + public: + explicit runtime_error(std::string message) : std::runtime_error(message) {} + }; + + class parse_error : public runtime_error + { + public: + enum code + { + NONE, + INVALID_XML_DATA, + INVALID_HEADER_DATA, + INSUFFICIENT_MESSAGE_DATA, + INVALID_MESSAGE_DATA + }; + + explicit parse_error(std::string message, code error = code::NONE) + : runtime_error(message) + { + m_code = error; + } + + code get_error_code() const { return m_code; } + private: + code m_code; + }; + + class value_error : public runtime_error + { + public: + enum code + { + NONE, + MISSING_FILE, + OVERWRITES_LOOKUP, + EXCEEDS_LIMIT, + + DML_INVALID_SERVICE, + DML_INVALID_PROTOCOL_TYPE, + DML_INVALID_MESSAGE_TYPE, + DML_INVALID_MESSAGE_NAME + }; + + explicit value_error(std::string message, code error = code::NONE) + : runtime_error(message) + { + m_code = error; + } + + code get_error_code() const { return m_code; } + private: + code m_code; + }; +} +} \ No newline at end of file diff --git a/include/ki/protocol/net/ClientDMLSession.h b/include/ki/protocol/net/ClientDMLSession.h new file mode 100644 index 0000000..a0ced83 --- /dev/null +++ b/include/ki/protocol/net/ClientDMLSession.h @@ -0,0 +1,35 @@ +#pragma once +#include "ClientSession.h" +#include "DMLSession.h" + +// Disable inheritance via dominance warning +#if _MSC_VER +#pragma warning(disable: 4250) +#endif + +namespace ki +{ +namespace protocol +{ +namespace net +{ + class ClientDMLSession : public ClientSession, public DMLSession + { + // Explicitly specify that we are intentionally inheritting + // via dominance. + using DMLSession::on_application_message; + using ClientSession::on_control_message; + using ClientSession::is_alive; + public: + ClientDMLSession(const uint16_t id, const dml::MessageManager &manager) + : Session(id), ClientSession(id), DMLSession(id, manager) {} + virtual ~ClientDMLSession() = default; + }; +} +} +} + +// Re-enable inheritance via dominance warning +#if _MSC_VER +#pragma warning(default: 4250) +#endif diff --git a/include/ki/protocol/net/ClientSession.h b/include/ki/protocol/net/ClientSession.h new file mode 100644 index 0000000..a9b8351 --- /dev/null +++ b/include/ki/protocol/net/ClientSession.h @@ -0,0 +1,34 @@ +#pragma once +#include "Session.h" + +#define KI_SERVER_HEARTBEAT 60 + +namespace ki +{ +namespace protocol +{ +namespace net +{ + /** + * Implements client-sided session logic. + */ + class ClientSession : public virtual Session + { + public: + explicit ClientSession(uint16_t id); + virtual ~ClientSession() = default; + + void send_keep_alive(); + bool is_alive() const override; + protected: + void on_connected(); + virtual void on_established() {} + void on_control_message(const PacketHeader& header) override; + private: + void on_session_offer(); + void on_keep_alive(); + void on_keep_alive_response(); + }; +} +} +} \ No newline at end of file diff --git a/include/ki/protocol/net/DMLSession.h b/include/ki/protocol/net/DMLSession.h new file mode 100644 index 0000000..c86e29e --- /dev/null +++ b/include/ki/protocol/net/DMLSession.h @@ -0,0 +1,44 @@ +#pragma once +#include "Session.h" +#include "../dml/MessageManager.h" + +namespace ki +{ +namespace protocol +{ +namespace net +{ + enum class InvalidDMLMessageErrorCode + { + NONE, + UNKNOWN, + INVALID_HEADER_DATA, + INVALID_MESSAGE_DATA, + INVALID_SERVICE, + INVALID_MESSAGE_TYPE, + INSUFFICIENT_ACCESS + }; + + /** + * Implements an application protocol that uses the DML + * message system (as seen in Wizard101 and Pirate101). + */ + class DMLSession : public virtual Session + { + public: + DMLSession(uint16_t id, const dml::MessageManager &manager); + virtual ~DMLSession() = default; + + const dml::MessageManager &get_manager() const; + + void send_message(const dml::Message &message); + protected: + void on_application_message(const PacketHeader& header) override; + virtual void on_message(const dml::Message *message) {} + virtual void on_invalid_message(InvalidDMLMessageErrorCode error) {} + private: + const dml::MessageManager &m_manager; + }; +} +} +} diff --git a/include/ki/protocol/net/PacketHeader.h b/include/ki/protocol/net/PacketHeader.h new file mode 100644 index 0000000..e516fa9 --- /dev/null +++ b/include/ki/protocol/net/PacketHeader.h @@ -0,0 +1,33 @@ +#pragma once +#include "../../util/Serializable.h" +#include +#include + +namespace ki +{ +namespace protocol +{ +namespace net +{ + class PacketHeader final : public util::Serializable + { + public: + PacketHeader(bool control = false, uint8_t opcode = 0); + virtual ~PacketHeader() = default; + + bool is_control() const; + void set_control(bool control); + + uint8_t get_opcode() const; + void set_opcode(uint8_t opcode); + + void write_to(std::ostream &ostream) const override final; + void read_from(std::istream &istream) override final; + size_t get_size() const override final; + private: + bool m_control; + uint8_t m_opcode; + }; +} +} +} \ No newline at end of file diff --git a/include/ki/protocol/net/ServerDMLSession.h b/include/ki/protocol/net/ServerDMLSession.h new file mode 100644 index 0000000..d450626 --- /dev/null +++ b/include/ki/protocol/net/ServerDMLSession.h @@ -0,0 +1,35 @@ +#pragma once +#include "ServerSession.h" +#include "DMLSession.h" + +// Disable inheritance via dominance warning +#if _MSC_VER +#pragma warning(disable: 4250) +#endif + +namespace ki +{ +namespace protocol +{ +namespace net +{ + class ServerDMLSession : public ServerSession, public DMLSession + { + // Explicitly specify that we are intentionally inheritting + // via dominance. + using DMLSession::on_application_message; + using ServerSession::on_control_message; + using ServerSession::is_alive; + public: + ServerDMLSession(const uint16_t id, const dml::MessageManager &manager) + : Session(id), ServerSession(id), DMLSession(id, manager) {} + virtual ~ServerDMLSession() = default; + }; +} +} +} + +// Re-enable inheritance via dominance warning +#if _MSC_VER +#pragma warning(default: 4250) +#endif diff --git a/include/ki/protocol/net/ServerSession.h b/include/ki/protocol/net/ServerSession.h new file mode 100644 index 0000000..77b4da7 --- /dev/null +++ b/include/ki/protocol/net/ServerSession.h @@ -0,0 +1,34 @@ +#pragma once +#include "Session.h" + +#define KI_CLIENT_HEARTBEAT 10 + +namespace ki +{ +namespace protocol +{ +namespace net +{ + /** + * Implements server-sided session logic. + */ + class ServerSession : public virtual Session + { + public: + explicit ServerSession(uint16_t id); + virtual ~ServerSession() = default; + + void send_keep_alive(uint32_t milliseconds_since_startup); + bool is_alive() const override; + protected: + void on_connected(); + virtual void on_established() {} + void on_control_message(const PacketHeader& header) override; + private: + void on_session_accept(); + void on_keep_alive(); + void on_keep_alive_response(); + }; +} +} +} \ No newline at end of file diff --git a/include/ki/protocol/net/Session.h b/include/ki/protocol/net/Session.h new file mode 100644 index 0000000..e877319 --- /dev/null +++ b/include/ki/protocol/net/Session.h @@ -0,0 +1,138 @@ +#pragma once +#include "PacketHeader.h" +#include "../control/Opcode.h" +#include "../../util/Serializable.h" +#include +#include +#include +#include + +#define KI_DEFAULT_MAXIMUM_RECEIVE_SIZE 0x2000 +#define KI_START_SIGNAL 0xF00D +#define KI_CONNECTION_TIMEOUT 3 + +namespace ki +{ +namespace protocol +{ +namespace net +{ + enum class SessionCloseErrorCode + { + NONE, + APPLICATION_ERROR, + + INVALID_FRAMING_START_SIGNAL, + INVALID_FRAMING_SIZE_EXCEEDS_MAXIMUM, + + UNHANDLED_CONTROL_MESSAGE, + UNHANDLED_APPLICATION_MESSAGE, + INVALID_MESSAGE, + + SESSION_OFFER_TIMED_OUT, + SESSION_DIED + }; + + enum class ReceiveState + { + // Waiting for the 0xF00D start signal. + WAITING_FOR_START_SIGNAL, + + // Waiting for the 2-byte length. + WAITING_FOR_LENGTH, + + // Waiting for the packet data. + WAITING_FOR_PACKET + }; + + /** + * This class implements session and packet framing logic + * when sending and receiving data to/from an external + * source. + */ + class Session + { + public: + explicit Session(uint16_t id = 0); + virtual ~Session() = default; + + uint16_t get_maximum_packet_size() const; + void set_maximum_packet_size(uint16_t maximum_packet_size); + + uint16_t get_id() const; + bool is_established() const; + + uint8_t get_access_level() const; + void set_access_level(uint8_t access_level); + + uint16_t get_latency() const; + + virtual bool is_alive() const = 0; + + void send_packet(bool is_control, uint8_t opcode, + const util::Serializable &data); + protected: + /* Higher-level session members */ + uint16_t m_id; + bool m_established; + uint8_t m_access_level; + + /* Timing members */ + std::chrono::steady_clock::time_point m_creation_time; + std::chrono::steady_clock::time_point m_connection_time; + std::chrono::steady_clock::time_point m_establish_time; + std::chrono::steady_clock::time_point m_last_received_heartbeat_time; + std::chrono::steady_clock::time_point m_last_sent_heartbeat_time; + bool m_waiting_for_keep_alive_response; + uint16_t m_latency; + + // The packet data stream + std::stringstream m_data_stream; + + /** + * Reads a serializable structure from the data stream. + */ + template + DataT read_data() + { + static_assert(std::is_base_of::value, + "DataT must inherit Serializable."); + + DataT data = DataT(); + data.read_from(m_data_stream); + return data; + } + + /** + * Frames raw data into a Packet, and transmits it. + */ + void send_data(const char *data, size_t size); + + /** + * Process incoming raw data into Packets. + * Once a packet is read into the internal data + * stream, handle_packet_available is called. + */ + void process_data(const char *data, size_t size); + + /* Event handlers */ + virtual void on_invalid_packet() {} + virtual void on_control_message(const PacketHeader &header) {} + virtual void on_application_message(const PacketHeader &header) {} + + /* Low-level socket methods */ + virtual void send_packet_data(const char *data, const size_t size) = 0; + virtual void close(SessionCloseErrorCode error) = 0; + private: + /* Low-level networking members */ + uint16_t m_maximum_packet_size; + ReceiveState m_receive_state; + uint16_t m_start_signal; + uint16_t m_incoming_packet_size; + uint8_t m_shift; + + void on_packet_available(); + }; +} +} +} diff --git a/src/protocol/CMakeLists.txt b/src/protocol/CMakeLists.txt new file mode 100644 index 0000000..191b07e --- /dev/null +++ b/src/protocol/CMakeLists.txt @@ -0,0 +1,17 @@ +target_sources(${PROJECT_NAME} + PRIVATE + ${PROJECT_SOURCE_DIR}/src/protocol/control/ClientKeepAlive.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/control/ServerKeepAlive.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/control/SessionAccept.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/control/SessionOffer.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/dml/Message.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageHeader.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageManager.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageModule.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageTemplate.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/net/ClientSession.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/net/DMLSession.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/net/PacketHeader.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/net/ServerSession.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/net/Session.cpp +) \ No newline at end of file diff --git a/src/protocol/control/ClientKeepAlive.cpp b/src/protocol/control/ClientKeepAlive.cpp new file mode 100644 index 0000000..49303e4 --- /dev/null +++ b/src/protocol/control/ClientKeepAlive.cpp @@ -0,0 +1,87 @@ +#include "ki/protocol/control/ClientKeepAlive.h" +#include "ki/dml/Record.h" +#include "ki/protocol/exception.h" + +namespace ki +{ +namespace protocol +{ +namespace control +{ + ClientKeepAlive::ClientKeepAlive(const uint16_t session_id, const uint16_t milliseconds, + const uint16_t minutes) + { + m_session_id = session_id; + m_milliseconds = milliseconds; + m_minutes = minutes; + } + + uint16_t ClientKeepAlive::get_session_id() const + { + return m_session_id; + } + + void ClientKeepAlive::set_session_id(const uint16_t session_id) + { + m_session_id = session_id; + } + + uint16_t ClientKeepAlive::get_milliseconds() const + { + return m_milliseconds; + } + + void ClientKeepAlive::set_milliseconds(const uint16_t milliseconds) + { + m_milliseconds = milliseconds; + } + + uint16_t ClientKeepAlive::get_minutes() const + { + return m_minutes; + } + + void ClientKeepAlive::set_minutes(const uint16_t minutes) + { + m_minutes = minutes; + } + + void ClientKeepAlive::write_to(std::ostream &ostream) const + { + dml::Record record; + record.add_field("m_session_id")->set_value(m_session_id); + record.add_field("m_milliseconds")->set_value(m_milliseconds); + record.add_field("m_minutes")->set_value(m_minutes); + record.write_to(ostream); + } + + void ClientKeepAlive::read_from(std::istream &istream) + { + dml::Record record; + auto *session_id = record.add_field("m_session_id"); + auto *milliseconds = record.add_field("m_milliseconds"); + auto *minutes = record.add_field("m_minutes"); + try + { + record.read_from(istream); + } + catch (dml::parse_error &e) + { + std::ostringstream oss; + oss << "Error reading ClientKeepAlive payload: " << e.what(); + throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA); + } + + m_session_id = session_id->get_value(); + m_milliseconds = milliseconds->get_value(); + m_minutes = minutes->get_value(); + } + + size_t ClientKeepAlive::get_size() const + { + return sizeof(dml::USHRT) + sizeof(dml::USHRT) + + sizeof(dml::USHRT); + } +} +} +} diff --git a/src/protocol/control/ServerKeepAlive.cpp b/src/protocol/control/ServerKeepAlive.cpp new file mode 100644 index 0000000..4b77df5 --- /dev/null +++ b/src/protocol/control/ServerKeepAlive.cpp @@ -0,0 +1,61 @@ +#include "ki/protocol/control/ServerKeepAlive.h" +#include "ki/dml/Record.h" +#include "ki/protocol/exception.h" +#include + +namespace ki +{ +namespace protocol +{ +namespace control +{ + ServerKeepAlive::ServerKeepAlive(const uint32_t timestamp) + { + m_timestamp = timestamp; + } + + uint32_t ServerKeepAlive::get_timestamp() const + { + return m_timestamp; + } + + void ServerKeepAlive::set_timestamp(const uint32_t timestamp) + { + m_timestamp = timestamp; + } + + void ServerKeepAlive::write_to(std::ostream& ostream) const + { + dml::Record record; + record.add_field("m_session_id"); + record.add_field("m_timestamp")->set_value(m_timestamp); + record.write_to(ostream); + } + + void ServerKeepAlive::read_from(std::istream& istream) + { + dml::Record record; + record.add_field("m_session_id"); + auto *timestamp = record.add_field("m_timestamp"); + try + { + record.read_from(istream); + } + catch (dml::parse_error &e) + { + std::ostringstream oss; + oss << "Error reading ServerKeepAlive payload: " << e.what(); + throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA); + } + + m_timestamp = timestamp->get_value(); + } + + size_t ServerKeepAlive::get_size() const + { + return sizeof(dml::USHRT) + sizeof(dml::INT); + } + +} +} +} diff --git a/src/protocol/control/SessionAccept.cpp b/src/protocol/control/SessionAccept.cpp new file mode 100644 index 0000000..20c92ca --- /dev/null +++ b/src/protocol/control/SessionAccept.cpp @@ -0,0 +1,92 @@ +#include "ki/protocol/control/SessionAccept.h" +#include "ki/dml/Record.h" +#include "ki/protocol/exception.h" + +namespace ki +{ +namespace protocol +{ +namespace control +{ + SessionAccept::SessionAccept(const uint16_t session_id, + const int32_t timestamp, const uint32_t milliseconds) + { + m_session_id = session_id; + m_timestamp = timestamp; + m_milliseconds = milliseconds; + } + + uint16_t SessionAccept::get_session_id() const + { + return m_session_id; + } + + void SessionAccept::set_session_id(const uint16_t session_id) + { + m_session_id = session_id; + } + + int32_t SessionAccept::get_timestamp() const + { + return m_timestamp; + } + + void SessionAccept::set_timestamp(const int32_t timestamp) + { + m_timestamp = timestamp; + } + + uint32_t SessionAccept::get_milliseconds() const + { + return m_milliseconds; + } + + void SessionAccept::set_milliseconds(const uint32_t milliseconds) + { + m_milliseconds = milliseconds; + } + + void SessionAccept::write_to(std::ostream& ostream) const + { + dml::Record record; + record.add_field("unknown"); + record.add_field("unknown2"); + record.add_field("m_timestamp")->set_value(m_timestamp); + record.add_field("m_milliseconds")->set_value(m_milliseconds); + record.add_field("m_session_id")->set_value(m_session_id); + record.write_to(ostream); + } + + void SessionAccept::read_from(std::istream& istream) + { + dml::Record record; + record.add_field("unknown"); + record.add_field("unknown2"); + auto *timestamp = record.add_field("m_timestamp"); + auto *milliseconds = record.add_field("m_milliseconds"); + auto *session_id = record.add_field("m_session_id"); + try + { + record.read_from(istream); + } + catch (dml::parse_error &e) + { + std::ostringstream oss; + oss << "Error reading SessionAccept payload: " << e.what(); + throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA); + } + + m_timestamp = timestamp->get_value(); + m_milliseconds = milliseconds->get_value(); + m_session_id = session_id->get_value(); + } + + size_t SessionAccept::get_size() const + { + return sizeof(dml::USHRT) + sizeof(dml::UINT) + + sizeof(dml::INT) + sizeof(dml::UINT) + + sizeof(dml::USHRT); + } +} +} +} diff --git a/src/protocol/control/SessionOffer.cpp b/src/protocol/control/SessionOffer.cpp new file mode 100644 index 0000000..07627eb --- /dev/null +++ b/src/protocol/control/SessionOffer.cpp @@ -0,0 +1,89 @@ +#include "ki/protocol/control/SessionOffer.h" +#include "ki/dml/Record.h" +#include "ki/protocol/exception.h" + +namespace ki +{ +namespace protocol +{ +namespace control +{ + SessionOffer::SessionOffer(const uint16_t session_id, + const int32_t timestamp, const uint32_t milliseconds) + { + m_session_id = session_id; + m_timestamp = timestamp; + m_milliseconds = milliseconds; + } + + uint16_t SessionOffer::get_session_id() const + { + return m_session_id; + } + + void SessionOffer::set_session_id(const uint16_t session_id) + { + m_session_id = session_id; + } + + int32_t SessionOffer::get_timestamp() const + { + return m_timestamp; + } + + void SessionOffer::set_timestamp(const int32_t timestamp) + { + m_timestamp = timestamp; + } + + uint32_t SessionOffer::get_milliseconds() const + { + return m_milliseconds; + } + + void SessionOffer::set_milliseconds(const uint32_t milliseconds) + { + m_milliseconds = milliseconds; + } + + void SessionOffer::write_to(std::ostream& ostream) const + { + dml::Record record; + record.add_field("m_session_id")->set_value(m_session_id); + record.add_field("unknown"); + record.add_field("m_timestamp")->set_value(m_timestamp); + record.add_field("m_milliseconds")->set_value(m_milliseconds); + record.write_to(ostream); + } + + void SessionOffer::read_from(std::istream& istream) + { + dml::Record record; + auto *session_id = record.add_field("m_session_id"); + record.add_field("unknown"); + auto *timestamp = record.add_field("m_timestamp"); + auto *milliseconds = record.add_field("m_milliseconds"); + try + { + record.read_from(istream); + } + catch (dml::parse_error &e) + { + std::ostringstream oss; + oss << "Error reading SessionOffer payload: " << e.what(); + throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA); + } + + m_session_id = session_id->get_value(); + m_timestamp = timestamp->get_value(); + m_milliseconds = milliseconds->get_value(); + } + + size_t SessionOffer::get_size() const + { + return sizeof(dml::USHRT) + sizeof(dml::UINT) + + sizeof(dml::INT) + sizeof(dml::UINT); + } +} +} +} diff --git a/src/protocol/dml/Message.cpp b/src/protocol/dml/Message.cpp new file mode 100644 index 0000000..215a694 --- /dev/null +++ b/src/protocol/dml/Message.cpp @@ -0,0 +1,173 @@ +#include "ki/protocol/dml/Message.h" +#include "ki/protocol/dml/MessageTemplate.h" +#include "ki/protocol/exception.h" + +namespace ki +{ +namespace protocol +{ +namespace dml +{ + Message::Message(const MessageTemplate *message_template) + { + m_template = message_template; + if (m_template) + m_record = new ki::dml::Record(m_template->get_record()); + else + m_record = nullptr; + } + + Message::~Message() + { + delete m_record; + } + + const MessageTemplate *Message::get_template() const + { + return m_template; + } + + void Message::set_template(const MessageTemplate *message_template) + { + m_template = message_template; + if (!m_template) + return; + + m_record = new ki::dml::Record(message_template->get_record()); + if (!m_raw_data.empty()) + { + std::istringstream iss(std::string(m_raw_data.data(), m_raw_data.size())); + try + { + m_record->read_from(iss); + m_raw_data.clear(); + } + catch (ki::dml::parse_error &e) + { + delete m_record; + m_template = nullptr; + m_record = nullptr; + + std::ostringstream oss; + oss << "Error reading DML message payload: " << e.what(); + throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA); + } + } + } + + uint8_t Message::get_service_id() const + { + if (m_template) + return m_template->get_service_id(); + return m_header.get_service_id(); + } + + uint8_t Message::get_type() const + { + if (m_template) + return m_template->get_type(); + return m_header.get_type(); + } + + uint16_t Message::get_message_size() const + { + if (m_record) + return m_record->get_size(); + return m_raw_data.size(); + } + + std::string Message::get_handler() const + { + if (m_template) + return m_template->get_handler(); + return ""; + } + + uint8_t Message::get_access_level() const + { + if (m_template) + return m_template->get_access_level(); + return 0; + } + + ki::dml::Record *Message::get_record() + { + return m_record; + } + + const ki::dml::Record *Message::get_record() const + { + return m_record; + } + + ki::dml::FieldBase* Message::get_field(std::string name) + { + if (m_record) + return m_record->get_field(name); + return nullptr; + } + + const ki::dml::FieldBase* Message::get_field(std::string name) const + { + if (m_record) + return m_record->get_field(name); + return nullptr; + } + + void Message::write_to(std::ostream &ostream) const + { + // Write the header + if (m_template) + { + MessageHeader header( + get_service_id(), get_type(), get_message_size()); + header.write_to(ostream); + } + else + m_header.write_to(ostream); + + // Write the payload + if (m_record) + m_record->write_to(ostream); + else + ostream.write(m_raw_data.data(), m_raw_data.size()); + } + + void Message::read_from(std::istream &istream) + { + m_header.read_from(istream); + if (m_template) + { + // Check for mismatches between the header and template + if (m_header.get_service_id() != m_template->get_service_id()) + throw value_error("ServiceID mismatch between MessageHeader and assigned template.", + value_error::DML_INVALID_SERVICE); + if (m_header.get_type() != m_template->get_type()) + throw value_error("Message Type mismatch between MessageHeader and assigned template.", + value_error::DML_INVALID_MESSAGE_TYPE); + + // Read the payload into the record + m_record->read_from(istream); + } + else + { + // We don't have a template for the record structure, so + // just read the raw data into a buffer. + const auto size = m_header.get_message_size(); + m_raw_data.resize(size); + istream.read(m_raw_data.data(), size); + if (istream.fail()) + throw parse_error("Not enough data was available to read DML message payload.", + parse_error::INSUFFICIENT_MESSAGE_DATA); + } + } + + size_t Message::get_size() const + { + if (m_record) + return m_header.get_size() + m_record->get_size(); + return 4 + m_raw_data.size(); + } +} +} +} \ No newline at end of file diff --git a/src/protocol/dml/MessageHeader.cpp b/src/protocol/dml/MessageHeader.cpp new file mode 100644 index 0000000..0e4d746 --- /dev/null +++ b/src/protocol/dml/MessageHeader.cpp @@ -0,0 +1,88 @@ +#include "ki/protocol/dml/MessageHeader.h" +#include "ki/dml/Record.h" +#include "ki/protocol/exception.h" + +namespace ki +{ +namespace protocol +{ +namespace dml +{ + MessageHeader::MessageHeader(const uint8_t service_id, + const uint8_t type, const uint16_t size) + { + m_service_id = service_id; + m_type = type; + m_size = size; + } + + uint8_t MessageHeader::get_service_id() const + { + return m_service_id; + } + + void MessageHeader::set_service_id(const uint8_t service_id) + { + m_service_id = service_id; + } + + uint8_t MessageHeader::get_type() const + { + return m_type; + } + + void MessageHeader::set_type(const uint8_t type) + { + m_type = type; + } + + uint16_t MessageHeader::get_message_size() const + { + return m_size; + } + + void MessageHeader::set_message_size(const uint16_t size) + { + m_size = size; + } + + void MessageHeader::write_to(std::ostream& ostream) const + { + ki::dml::Record record; + record.add_field("m_service_id")->set_value(m_service_id); + record.add_field("m_type")->set_value(m_type); + record.add_field("m_size")->set_value(m_size + 4); + record.write_to(ostream); + } + + void MessageHeader::read_from(std::istream& istream) + { + ki::dml::Record record; + const auto *service_id = record.add_field("m_service_id"); + const auto *type = record.add_field("m_type"); + const auto size = record.add_field("m_size"); + + try + { + record.read_from(istream); + } + catch (ki::dml::parse_error &e) + { + std::ostringstream oss; + oss << "Error reading MessageHeader: " << e.what(); + throw parse_error(oss.str(), parse_error::INVALID_HEADER_DATA); + } + + m_service_id = service_id->get_value(); + m_type = type->get_value(); + m_size = size->get_value() - 4; + } + + size_t MessageHeader::get_size() const + { + return sizeof(ki::dml::UBYT) + sizeof(ki::dml::UBYT) + + sizeof(ki::dml::USHRT); + } +} +} +} diff --git a/src/protocol/dml/MessageManager.cpp b/src/protocol/dml/MessageManager.cpp new file mode 100644 index 0000000..09b105d --- /dev/null +++ b/src/protocol/dml/MessageManager.cpp @@ -0,0 +1,260 @@ +#include "ki/protocol/dml/MessageManager.h" +#include "ki/protocol/dml/MessageHeader.h" +#include "ki/protocol/exception.h" +#include "ki/dml/Record.h" +#include "ki/util/ValueBytes.h" +#include +#include +#include + +namespace ki +{ +namespace protocol +{ +namespace dml +{ + MessageManager::~MessageManager() + { + for (auto it = m_modules.begin(); + it != m_modules.end(); ++it) + delete *it; + m_modules.clear(); + m_service_id_map.clear(); + m_protocol_type_map.clear(); + } + + const MessageModule *MessageManager::load_module(std::string filepath) + { + // Open the file + std::ifstream ifs(filepath, std::ios::ate); + if (!ifs.is_open()) + { + std::ostringstream oss; + oss << "Could not open file: " << filepath; + throw value_error(oss.str(), value_error::MISSING_FILE); + } + + // Load contents into memory + size_t size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + char *data = new char[size + 1] { 0 }; + ifs.read(data, size); + + // Parse the contents + rapidxml::xml_document<> doc; + try + { + doc.parse<0>(data); + } + catch (rapidxml::parse_error &e) + { + delete[] data; + + std::ostringstream oss; + oss << "Failed to parse: " << filepath; + throw parse_error(oss.str(), parse_error::INVALID_XML_DATA); + } + + // It's safe to allocate the module we're working on now + auto *message_module = new MessageModule(); + + // Get the root node and iterate through children + // Each child is a MessageTemplate + auto *root = doc.first_node(); + for (auto *node = root->first_node(); + node; node = node->next_sibling()) + { + // Parse the record node inside this node + auto *record_node = node->first_node(); + if (!record_node) + continue; + auto *record = new ki::dml::Record(); + record->from_xml(record_node); + + // The message name is initially based on the element name + const std::string message_name = node->name(); + if (message_name == "_ProtocolInfo") + { + auto *service_id_field = record->get_field("ServiceID"); + auto *type_field = record->get_field("ProtocolType"); + auto *description_field = record->get_field("ProtocolDescription"); + + // Set the module metadata from this template + if (service_id_field) + message_module->set_service_id(service_id_field->get_value()); + if (type_field) + message_module->set_protocol_type(type_field->get_value()); + if (description_field) + message_module->set_protocol_description(description_field->get_value()); + } + else + { + // Only do sorting after we've reached the final message + // This only affects modules that aren't ordered with _MsgOrder. + const bool auto_sort = node->next_sibling() == nullptr; + + // The template will use the record itself to figure out name and type; + // we only give the XML data incase the record doesn't have it defined. + auto *message_template = message_module->add_message_template(message_name, record, auto_sort); + if (!message_template) + { + delete[] data; + delete message_module; + delete record; + + std::ostringstream oss; + oss << "Failed to create message template for "; + oss << message_name; + throw value_error(oss.str()); + } + } + } + + // Make sure we aren't overwriting another module + if (m_service_id_map.count(message_module->get_service_id()) == 1) + { + delete[] data; + delete message_module; + + std::ostringstream oss; + oss << "Message Module has already been loaded with Service ID "; + oss << (uint16_t)message_module->get_service_id(); + throw value_error(oss.str(), value_error::OVERWRITES_LOOKUP); + } + + if (m_protocol_type_map.count(message_module->get_protocol_type()) == 1) + { + delete[] data; + delete message_module; + + std::ostringstream oss; + oss << "Message Module has already been loaded with Protocol Type "; + oss << message_module->get_protocol_type(); + throw value_error(oss.str(), value_error::OVERWRITES_LOOKUP); + } + + // Add it to our maps + m_modules.push_back(message_module); + m_service_id_map.insert({ message_module->get_service_id(), message_module }); + m_protocol_type_map.insert({ message_module->get_protocol_type(), message_module }); + + delete[] data; + return message_module; + } + + const MessageModule *MessageManager::get_module(uint8_t service_id) const + { + if (m_service_id_map.count(service_id) == 1) + return m_service_id_map.at(service_id); + return nullptr; + } + + const MessageModule *MessageManager::get_module(const std::string &protocol_type) const + { + if (m_protocol_type_map.count(protocol_type) == 1) + return m_protocol_type_map.at(protocol_type); + return nullptr; + } + + Message *MessageManager::create_message(uint8_t service_id, uint8_t message_type) const + { + auto *message_module = get_module(service_id); + if (!message_module) + { + std::ostringstream oss; + oss << "No service exists with id: " << (uint16_t)service_id; + throw value_error(oss.str(), value_error::DML_INVALID_SERVICE); + } + + return message_module->create_message(message_type); + } + + Message *MessageManager::create_message(uint8_t service_id, const std::string& message_name) const + { + auto *message_module = get_module(service_id); + if (!message_module) + { + std::ostringstream oss; + oss << "No service exists with id: " << (uint16_t)service_id; + throw value_error(oss.str(), value_error::DML_INVALID_SERVICE); + } + + return message_module->create_message(message_name); + } + + Message *MessageManager::create_message(const std::string& protocol_type, uint8_t message_type) const + { + auto *message_module = get_module(protocol_type); + if (!message_module) + { + std::ostringstream oss; + oss << "No service exists with protocol type: " << protocol_type; + throw value_error(oss.str(), value_error::DML_INVALID_PROTOCOL_TYPE); + } + + return message_module->create_message(message_type); + } + + Message *MessageManager::create_message(const std::string& protocol_type, const std::string& message_name) const + { + auto *message_module = get_module(protocol_type); + if (!message_module) + { + std::ostringstream oss; + oss << "No service exists with protocol type: " << protocol_type; + throw value_error(oss.str(), value_error::DML_INVALID_PROTOCOL_TYPE); + } + + return message_module->create_message(message_name); + } + + const Message *MessageManager::message_from_binary(std::istream& istream) const + { + // Read the message header + MessageHeader header; + header.read_from(istream); + + // Get the message module that uses the specified service id + auto *message_module = get_module(header.get_service_id()); + if (!message_module) + { + std::ostringstream oss; + oss << "No service exists with id: " << (uint16_t)header.get_service_id(); + throw value_error(oss.str(), value_error::DML_INVALID_SERVICE); + } + + // Get the message template for this message type + auto *message_template = message_module->get_message_template(header.get_type()); + if (!message_template) + { + std::ostringstream oss; + oss << "No message exists with type: " << (uint16_t)header.get_service_id(); + oss << "(service=" << message_module->get_protocol_type() << ")"; + throw value_error(oss.str(), value_error::DML_INVALID_MESSAGE_TYPE); + } + + // Make sure that the size specified is enough to read this message + if (header.get_message_size() < message_template->get_record().get_size()) + { + std::ostringstream oss; + oss << "No message exists with type: " << (uint16_t)header.get_service_id(); + oss << "(service=" << message_module->get_protocol_type() << ")"; + throw value_error(oss.str(), value_error::DML_INVALID_MESSAGE_TYPE); + } + + // Create a new Message from the template + auto *message = new Message(message_template); + try + { + message->get_record()->read_from(istream); + } + catch (ki::dml::parse_error &e) + { + delete message; + throw parse_error("Failed to read DML message payload.", parse_error::INVALID_MESSAGE_DATA); + } + return message; + } +} +} +} diff --git a/src/protocol/dml/MessageModule.cpp b/src/protocol/dml/MessageModule.cpp new file mode 100644 index 0000000..147c295 --- /dev/null +++ b/src/protocol/dml/MessageModule.cpp @@ -0,0 +1,171 @@ +#include "ki/protocol/dml/MessageModule.h" +#include "ki/protocol/exception.h" +#include + +namespace ki +{ +namespace protocol +{ +namespace dml +{ + MessageModule::MessageModule(uint8_t service_id, std::string protocol_type) + { + m_service_id = service_id; + m_protocol_type = protocol_type; + m_protocol_description = ""; + m_last_message_type = 0; + } + + MessageModule::~MessageModule() + { + for (auto it = m_templates.begin(); + it != m_templates.end(); ++it) + delete *it; + m_message_type_map.clear(); + m_message_name_map.clear(); + } + + uint8_t MessageModule::get_service_id() const + { + return m_service_id; + } + + void MessageModule::set_service_id(uint8_t service_id) + { + m_service_id = service_id; + } + + std::string MessageModule::get_protocol_type() const + { + return m_protocol_type; + } + + void MessageModule::set_protocol_type(std::string protocol_type) + { + m_protocol_type = protocol_type; + } + + std::string MessageModule::get_protocol_desription() const + { + return m_protocol_description; + } + + void MessageModule::set_protocol_description(std::string protocol_description) + { + m_protocol_description = protocol_description; + } + + const MessageTemplate *MessageModule::add_message_template(std::string name, + ki::dml::Record *record, bool auto_sort) + { + if (!record) + return nullptr; + + // If the field exists, get the name from the record rather than the XML + auto *name_field = record->get_field("_MsgName"); + if (name_field) + name = name_field->get_value(); + + // Do we already have a message template with this name? + if (m_message_name_map.count(name) == 1) + return m_message_name_map.at(name); + + // Message type is based on the _MsgOrder field if it's present + // Otherwise it's based on the alphabetical order of template names + uint8_t message_type = 0; + auto *order_field = record->get_field("_MsgOrder"); + if (order_field) + { + message_type = order_field->get_value(); + + // Don't allow message type to be 0 + if (message_type == 0) + return nullptr; + + // Do we already have a template with this type? + if (m_message_type_map.count(message_type) == 1) + return nullptr; + } + + // Create the message template and add it to our lookups + auto *message_template = new MessageTemplate(name, message_type, m_service_id, record); + m_templates.push_back(message_template); + m_message_name_map.insert({ name, message_template }); + + // Is this module ordered? + if (message_type != 0) + m_message_type_map.insert({ message_type, message_template }); + else if (auto_sort) + sort_lookup(); + + return message_template; + } + + const MessageTemplate *MessageModule::get_message_template(uint8_t type) const + { + if (m_message_type_map.count(type) == 1) + return m_message_type_map.at(type); + return nullptr; + } + + const MessageTemplate *MessageModule::get_message_template(std::string name) const + { + if (m_message_name_map.count(name) == 1) + return m_message_name_map.at(name); + return nullptr; + } + + void MessageModule::sort_lookup() + { + uint8_t message_type = 1; + + // First, clear the message type map since we're going to be + // moving everything around + m_message_type_map.clear(); + + // Iterating over a map with std::string as the key + // is guaranteed to be in alphabetical order + for (auto it = m_message_name_map.begin(); + it != m_message_name_map.end(); ++it) + { + auto *message_template = it->second; + message_template->set_type(message_type); + m_message_type_map.insert({ message_type, message_template }); + message_type++; + + // Make sure we haven't overflowed + if (message_type == 0) + throw value_error("Module has more than 254 messages.", value_error::EXCEEDS_LIMIT); + } + } + + Message *MessageModule::create_message(uint8_t message_type) const + { + auto *message_template = get_message_template(message_type); + if (!message_template) + { + std::ostringstream oss; + oss << "No message exists with type: " << message_type; + oss << "(service=" << m_protocol_type << ")"; + throw value_error(oss.str(), value_error::DML_INVALID_MESSAGE_TYPE); + } + + return message_template->create_message(); + } + + Message *MessageModule::create_message(std::string message_name) const + { + auto *message_template = get_message_template(message_name); + if (!message_template) + { + std::ostringstream oss; + oss << "No message exists with name: " << message_name; + oss << "(service=" << m_protocol_type << ")"; + throw value_error(oss.str(), value_error::DML_INVALID_MESSAGE_NAME); + } + + return message_template->create_message(); + } +} +} +} \ No newline at end of file diff --git a/src/protocol/dml/MessageTemplate.cpp b/src/protocol/dml/MessageTemplate.cpp new file mode 100644 index 0000000..9e79929 --- /dev/null +++ b/src/protocol/dml/MessageTemplate.cpp @@ -0,0 +1,95 @@ +#include "ki/protocol/dml/MessageTemplate.h" + +namespace ki +{ +namespace protocol +{ +namespace dml +{ + MessageTemplate::MessageTemplate(std::string name, uint8_t type, + uint8_t service_id, ki::dml::Record* record) + { + m_name = name; + m_type = type; + m_service_id = service_id; + m_record = record; + } + + MessageTemplate::~MessageTemplate() + { + delete m_record; + } + + std::string MessageTemplate::get_name() const + { + return m_name; + } + + void MessageTemplate::set_name(std::string name) + { + m_name = name; + } + + uint8_t MessageTemplate::get_type() const + { + return m_type; + } + + void MessageTemplate::set_type(uint8_t type) + { + m_type = type; + } + + uint8_t MessageTemplate::get_service_id() const + { + return m_service_id; + } + + void MessageTemplate::set_service_id(uint8_t service_id) + { + m_service_id = service_id; + } + + std::string MessageTemplate::get_handler() const + { + const auto field = m_record->get_field("_MsgHandler"); + if (field) + return field->get_value(); + return m_name; + } + + void MessageTemplate::set_handler(std::string handler) + { + m_record->add_field("_MsgHandler")->set_value(handler); + } + + uint8_t MessageTemplate::get_access_level() const + { + const auto field = m_record->get_field("_MsgAccessLvl"); + if (field) + return field->get_value(); + return 0; + } + + void MessageTemplate::set_access_level(uint8_t access_level) + { + m_record->add_field("_MsgAccessLvl")->set_value(access_level); + } + + const ki::dml::Record& MessageTemplate::get_record() const + { + return *m_record; + } + + void MessageTemplate::set_record(ki::dml::Record* record) + { + m_record = record; + } + + Message *MessageTemplate::create_message() const + { + return new Message(this); + } +} +} +} \ No newline at end of file diff --git a/src/protocol/net/ClientSession.cpp b/src/protocol/net/ClientSession.cpp new file mode 100644 index 0000000..b77506d --- /dev/null +++ b/src/protocol/net/ClientSession.cpp @@ -0,0 +1,173 @@ +#include "ki/protocol/net/ClientSession.h" +#include "ki/protocol/control/SessionOffer.h" +#include "ki/protocol/control/SessionAccept.h" +#include "ki/protocol/control/ClientKeepAlive.h" +#include "ki/protocol/control/ServerKeepAlive.h" +#include "ki/protocol/exception.h" + +namespace ki +{ +namespace protocol +{ +namespace net +{ + ClientSession::ClientSession(const uint16_t id) + : Session(id) {} + + void ClientSession::send_keep_alive() + { + // Don't send a keep alive if we're waiting for a response + if (m_waiting_for_keep_alive_response) + return; + m_waiting_for_keep_alive_response = true; + + // Work out how many minutes have been since the establish time, and + // how many milliseconds we are in to the current minute. + const auto time_since_establish = std::chrono::steady_clock::now() - m_establish_time; + const auto minutes = std::chrono::duration_cast(time_since_establish); + const auto milliseconds = std::chrono::duration_cast( + time_since_establish - minutes + ).count(); + + // Send a KEEP_ALIVE packet + control::ClientKeepAlive keep_alive(m_id, milliseconds, minutes.count()); + send_packet(true, (uint8_t)control::Opcode::KEEP_ALIVE, keep_alive); + m_last_sent_heartbeat_time = std::chrono::steady_clock::now(); + } + + bool ClientSession::is_alive() const + { + // If the session isn't established yet, use the time of + // creation to decide whether this session is alive. + if (!m_established) + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - m_creation_time + ).count() <= (KI_CONNECTION_TIMEOUT * 2); + + // Otherwise, use the last time we received a heartbeat. + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - m_last_received_heartbeat_time + ).count() <= (KI_SERVER_HEARTBEAT * 2); + } + + void ClientSession::on_connected() + { + m_connection_time = std::chrono::steady_clock::now(); + } + + void ClientSession::on_control_message(const PacketHeader& header) + { + switch ((control::Opcode)header.get_opcode()) + { + case control::Opcode::SESSION_OFFER: + on_session_offer(); + break; + + case control::Opcode::KEEP_ALIVE: + on_keep_alive(); + break; + + case control::Opcode::KEEP_ALIVE_RSP: + on_keep_alive_response(); + break; + + default: + close(SessionCloseErrorCode::UNHANDLED_CONTROL_MESSAGE); + break; + } + } + + void ClientSession::on_session_offer() + { + // Read the payload data into a structure + control::SessionOffer offer; + try + { + offer = read_data(); + } + catch (parse_error &e) + { + // The SESSION_OFFER wasn't valid... + // Close the session + close(SessionCloseErrorCode::INVALID_MESSAGE); + return; + } + + // Should this session have already timed out? + if (std::chrono::duration_cast( + std::chrono::steady_clock::now() - m_connection_time + ).count() > KI_CONNECTION_TIMEOUT) + { + close(SessionCloseErrorCode::SESSION_OFFER_TIMED_OUT); + return; + } + + // Work out the current timestamp and how many milliseconds + // have elapsed in the current second. + auto now = std::chrono::system_clock::now(); + const auto timestamp = std::chrono::duration_cast( + now.time_since_epoch() + ).count(); + const auto milliseconds = std::chrono::duration_cast( + now.time_since_epoch() + ).count() - (timestamp * 1000); + + // Accept the session + m_id = offer.get_session_id(); + control::SessionAccept accept(m_id, timestamp, milliseconds); + send_packet(true, (uint8_t)control::Opcode::SESSION_ACCEPT, accept); + + // The session is successfully established + m_established = true; + m_establish_time = std::chrono::steady_clock::now(); + m_last_received_heartbeat_time = m_establish_time; + on_established(); + } + + void ClientSession::on_keep_alive() + { + // Read the payload data into a structure + control::ServerKeepAlive keep_alive; + try + { + keep_alive = read_data(); + } + catch (parse_error &e) + { + // The KEEP_ALIVE wasn't valid... + // Close the session + close(SessionCloseErrorCode::INVALID_MESSAGE); + return; + } + + // Send the response + m_last_received_heartbeat_time = std::chrono::steady_clock::now(); + send_packet(true, (uint8_t)control::Opcode::KEEP_ALIVE_RSP, keep_alive); + } + + void ClientSession::on_keep_alive_response() + { + // Read the payload data into a structure + try + { + // We don't actually need the data inside, but + // read it to check if the structure is right. + read_data(); + } + catch (parse_error &e) + { + // The KEEP_ALIVE_RSP wasn't valid... + // Close the session + close(SessionCloseErrorCode::INVALID_MESSAGE); + return; + } + + // Calculate latency and allow for KEEP_ALIVE packets to be sent again + m_latency = std::chrono::duration_cast( + std::chrono::steady_clock::now() - m_last_sent_heartbeat_time + ).count(); + m_waiting_for_keep_alive_response = false; + } +} +} +} diff --git a/src/protocol/net/DMLSession.cpp b/src/protocol/net/DMLSession.cpp new file mode 100644 index 0000000..06d75d7 --- /dev/null +++ b/src/protocol/net/DMLSession.cpp @@ -0,0 +1,77 @@ +#include "ki/protocol/net/DMLSession.h" +#include "ki/protocol/exception.h" + +namespace ki +{ +namespace protocol +{ +namespace net +{ + DMLSession::DMLSession(const uint16_t id, const dml::MessageManager& manager) + : Session(id), m_manager(manager) {} + + const dml::MessageManager& DMLSession::get_manager() const + { + return m_manager; + } + + void DMLSession::send_message(const dml::Message& message) + { + send_packet(false, 0, message); + } + + void DMLSession::on_application_message(const PacketHeader& header) + { + // Attempt to create a Message instance from the data in the stream + auto error_code = InvalidDMLMessageErrorCode::NONE; + const dml::Message *message = nullptr; + try + { + message = m_manager.message_from_binary(m_data_stream); + } + catch (parse_error &e) + { + switch (e.get_error_code()) + { + case parse_error::INVALID_HEADER_DATA: + error_code = InvalidDMLMessageErrorCode::INVALID_HEADER_DATA; + break; + case parse_error::INSUFFICIENT_MESSAGE_DATA: + case parse_error::INVALID_MESSAGE_DATA: + error_code = InvalidDMLMessageErrorCode::INVALID_MESSAGE_DATA; + break; + default: + error_code = InvalidDMLMessageErrorCode::UNKNOWN; + } + } + catch (value_error &e) + { + switch (e.get_error_code()) + { + case value_error::DML_INVALID_SERVICE: + error_code = InvalidDMLMessageErrorCode::INVALID_SERVICE; + break; + case value_error::DML_INVALID_MESSAGE_TYPE: + error_code = InvalidDMLMessageErrorCode::INVALID_MESSAGE_TYPE; + break; + default: + error_code = InvalidDMLMessageErrorCode::UNKNOWN; + } + } + + if (!message) + { + on_invalid_message(error_code); + return; + } + + // Are we sufficiently authenticated to handle this message? + if (get_access_level() >= message->get_access_level()) + on_message(message); + else + on_invalid_message(InvalidDMLMessageErrorCode::INSUFFICIENT_ACCESS); + delete message; + } +} +} +} diff --git a/src/protocol/net/PacketHeader.cpp b/src/protocol/net/PacketHeader.cpp new file mode 100644 index 0000000..986087e --- /dev/null +++ b/src/protocol/net/PacketHeader.cpp @@ -0,0 +1,67 @@ +#include "ki/protocol/net/PacketHeader.h" +#include "ki/protocol/exception.h" +#include + +namespace ki +{ +namespace protocol +{ +namespace net +{ + PacketHeader::PacketHeader(const bool control, const uint8_t opcode) + { + m_control = control; + m_opcode = opcode; + } + + bool PacketHeader::is_control() const + { + return m_control; + } + + void PacketHeader::set_control(const bool control) + { + m_control = control; + } + + uint8_t PacketHeader::get_opcode() const + { + return m_opcode; + } + + void PacketHeader::set_opcode(const uint8_t opcode) + { + m_opcode = opcode; + } + + void PacketHeader::write_to(std::ostream& ostream) const + { + ostream.put(m_control); + ostream.put(m_opcode); + ostream.put(0); + ostream.put(0); + } + + void PacketHeader::read_from(std::istream& istream) + { + m_control = istream.get() >= 1; + if (istream.fail()) + throw parse_error("Not enough data was available to read packet header. (m_control)", + parse_error::INVALID_HEADER_DATA); + m_opcode = istream.get(); + if (istream.fail()) + throw parse_error("Not enough data was available to read packet header. (m_opcode)", + parse_error::INVALID_HEADER_DATA); + istream.ignore(2); + if (istream.eof()) + throw parse_error("Not enough data was available to read packet header. (ignored bytes)", + parse_error::INVALID_HEADER_DATA); + } + + size_t PacketHeader::get_size() const + { + return 4; + } +} +} +} diff --git a/src/protocol/net/ServerSession.cpp b/src/protocol/net/ServerSession.cpp new file mode 100644 index 0000000..dc55b66 --- /dev/null +++ b/src/protocol/net/ServerSession.cpp @@ -0,0 +1,171 @@ +#include "ki/protocol/net/ServerSession.h" +#include "ki/protocol/control/SessionOffer.h" +#include "ki/protocol/control/SessionAccept.h" +#include "ki/protocol/control/ClientKeepAlive.h" +#include "ki/protocol/control/ServerKeepAlive.h" +#include "ki/protocol/exception.h" + +namespace ki +{ +namespace protocol +{ +namespace net +{ + ServerSession::ServerSession(const uint16_t id) + : Session(id) {} + + void ServerSession::send_keep_alive(const uint32_t milliseconds_since_startup) + { + // Don't send a keep alive if we're waiting for a response + if (m_waiting_for_keep_alive_response) + return; + m_waiting_for_keep_alive_response = true; + + // Send a KEEP_ALIVE packet + const control::ServerKeepAlive keep_alive(milliseconds_since_startup); + send_packet(true, (uint8_t)control::Opcode::KEEP_ALIVE, keep_alive); + m_last_sent_heartbeat_time = std::chrono::steady_clock::now(); + } + + bool ServerSession::is_alive() const + { + // If the session isn't established yet, use the time of + // creation to decide whether this session is alive. + if (!m_established) + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - m_creation_time + ).count() <= (KI_CONNECTION_TIMEOUT * 2); + + // Otherwise, use the last time we received a heartbeat. + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - m_last_received_heartbeat_time + ).count() <= (KI_CLIENT_HEARTBEAT * 2); + } + + void ServerSession::on_connected() + { + m_connection_time = std::chrono::steady_clock::now(); + + // Work out the current timestamp and how many milliseconds + // have elapsed in the current second. + auto now = std::chrono::system_clock::now(); + const auto timestamp = std::chrono::duration_cast( + now.time_since_epoch() + ).count(); + const auto milliseconds = std::chrono::duration_cast( + now.time_since_epoch() + ).count() - (timestamp * 1000); + + // Send a SESSION_OFFER packet to the client + const control::SessionOffer offer(m_id, timestamp, milliseconds); + send_packet(true, (uint8_t)control::Opcode::SESSION_OFFER, offer); + } + + void ServerSession::on_control_message(const PacketHeader& header) + { + switch ((control::Opcode)header.get_opcode()) + { + case control::Opcode::SESSION_ACCEPT: + on_session_accept(); + break; + + case control::Opcode::KEEP_ALIVE: + on_keep_alive(); + break; + + case control::Opcode::KEEP_ALIVE_RSP: + on_keep_alive_response(); + break; + + default: + close(SessionCloseErrorCode::UNHANDLED_CONTROL_MESSAGE); + break; + } + } + + void ServerSession::on_session_accept() + { + // Read the payload data into a structure + control::SessionAccept accept; + try + { + accept = read_data(); + } + catch (parse_error &e) + { + // The SESSION_ACCEPT wasn't valid... + // Close the session + close(SessionCloseErrorCode::INVALID_MESSAGE); + return; + } + + // Should this session have already timed out? + if (std::chrono::duration_cast( + std::chrono::steady_clock::now() - m_connection_time + ).count() > KI_CONNECTION_TIMEOUT) + { + close(SessionCloseErrorCode::SESSION_OFFER_TIMED_OUT); + return; + } + + // Make sure they're accepting this session + if (accept.get_session_id() != m_id) + { + close(SessionCloseErrorCode::INVALID_MESSAGE); + return; + } + + // The session is successfully established + m_established = true; + m_establish_time = std::chrono::steady_clock::now(); + m_last_received_heartbeat_time = m_establish_time; + on_established(); + } + + void ServerSession::on_keep_alive() + { + // Read the payload data into a structure + control::ClientKeepAlive keep_alive; + try + { + keep_alive = read_data(); + } + catch (parse_error &e) + { + // The KEEP_ALIVE wasn't valid... + // Close the session + close(SessionCloseErrorCode::INVALID_MESSAGE); + return; + } + + // Send the response + m_last_received_heartbeat_time = std::chrono::steady_clock::now(); + send_packet(true, (uint8_t)control::Opcode::KEEP_ALIVE_RSP, keep_alive); + } + + void ServerSession::on_keep_alive_response() + { + // Read the payload data into a structure + try + { + // We don't actually need the data inside, but + // read it to check if the structure is right. + read_data(); + } + catch (parse_error &e) + { + // The KEEP_ALIVE_RSP wasn't valid... + // Close the session + close(SessionCloseErrorCode::INVALID_MESSAGE); + return; + } + + // Calculate latency and allow for KEEP_ALIVE packets to be sent again + m_latency = std::chrono::duration_cast( + std::chrono::steady_clock::now() - m_last_sent_heartbeat_time + ).count(); + m_waiting_for_keep_alive_response = false; + } +} +} +} diff --git a/src/protocol/net/Session.cpp b/src/protocol/net/Session.cpp new file mode 100644 index 0000000..9807773 --- /dev/null +++ b/src/protocol/net/Session.cpp @@ -0,0 +1,190 @@ +#include "ki/protocol/net/Session.h" +#include "ki/protocol/exception.h" +#include + +namespace ki +{ +namespace protocol +{ +namespace net +{ + Session::Session(const uint16_t id) + { + m_id = id; + m_established = false; + m_access_level = 0; + m_latency = 0; + m_creation_time = std::chrono::steady_clock::now(); + m_waiting_for_keep_alive_response = false; + + m_maximum_packet_size = KI_DEFAULT_MAXIMUM_RECEIVE_SIZE; + m_receive_state = ReceiveState::WAITING_FOR_START_SIGNAL; + m_start_signal = 0; + m_incoming_packet_size = 0; + m_shift = 0; + } + + uint16_t Session::get_maximum_packet_size() const + { + return m_maximum_packet_size; + } + + void Session::set_maximum_packet_size(const uint16_t maximum_packet_size) + { + m_maximum_packet_size = maximum_packet_size; + } + + uint16_t Session::get_id() const + { + return m_id; + } + + bool Session::is_established() const + { + return m_established; + } + + uint8_t Session::get_access_level() const + { + return m_access_level; + } + + void Session::set_access_level(const uint8_t access_level) + { + m_access_level = access_level; + } + + uint16_t Session::get_latency() const + { + return m_latency; + } + + void Session::send_packet(const bool is_control, const uint8_t opcode, + const util::Serializable& data) + { + std::ostringstream ss; + PacketHeader header(is_control, opcode); + header.write_to(ss); + data.write_to(ss); + + const auto buffer = ss.str(); + send_data(buffer.c_str(), buffer.length()); + } + + void Session::send_data(const char* data, const size_t size) + { + // Allocate the entire buffer + char *packet_data = new char[size + 4]; + + // Add the frame header + ((uint16_t *)packet_data)[0] = KI_START_SIGNAL; + ((uint16_t *)packet_data)[1] = size; + + // Copy the payload into the buffer and send it + std::memcpy(&packet_data[4], data, size); + send_packet_data(packet_data, size + 4); + delete[] packet_data; + } + + void Session::process_data(const char *data, const size_t size) + { + size_t position = 0; + while (position < size) + { + switch (m_receive_state) + { + case ReceiveState::WAITING_FOR_START_SIGNAL: + m_start_signal |= ((uint8_t)data[position] << m_shift); + if (m_shift == 0) + m_shift = 8; + else + { + // If the start signal isn't correct, we've either + // gotten out of sync, or they are not framing packets + // correctly. + if (m_start_signal != KI_START_SIGNAL) + { + close(SessionCloseErrorCode::INVALID_FRAMING_START_SIGNAL); + return; + } + + // Reset the shift and incoming packet size + m_shift = 0; + m_incoming_packet_size = 0; + m_receive_state = ReceiveState::WAITING_FOR_LENGTH; + } + position++; + break; + + case ReceiveState::WAITING_FOR_LENGTH: + m_incoming_packet_size |= ((uint8_t)data[position] << m_shift); + if (m_shift == 0) + m_shift = 8; + else + { + // If the incoming packet is larger than we are accepting + // stop processing data. + if (m_incoming_packet_size > m_maximum_packet_size) + { + close(SessionCloseErrorCode::INVALID_FRAMING_SIZE_EXCEEDS_MAXIMUM); + return; + } + + // Reset read and write positions + m_data_stream.seekp(0, std::ios::beg); + m_data_stream.seekg(0, std::ios::beg); + m_receive_state = ReceiveState::WAITING_FOR_PACKET; + } + position++; + break; + + case ReceiveState::WAITING_FOR_PACKET: + // Work out how much data we should read into our stream + const size_t data_available = (size - position); + const size_t read_size = (data_available >= m_incoming_packet_size) ? + m_incoming_packet_size : data_available; + + // Write the data to the data stream + m_data_stream.write(&data[position], read_size); + position += read_size; + m_incoming_packet_size -= read_size; + + // Have we received the entire packet? + if (m_incoming_packet_size == 0) + { + on_packet_available(); + + // Reset the shift and start signal + m_shift = 0; + m_start_signal = 0; + m_receive_state = ReceiveState::WAITING_FOR_START_SIGNAL; + } + break; + } + } + } + + void Session::on_packet_available() + { + // Read the packet header + PacketHeader header; + try + { + header.read_from(m_data_stream); + } + catch (parse_error &e) + { + on_invalid_packet(); + return; + } + + // Hand off to the right handler based on + // whether this is a control packet or not + if (header.is_control()) + on_control_message(header); + else + on_application_message(header); + } +} +} +} diff --git a/test/src/unit-protocol.cpp b/test/src/unit-protocol.cpp new file mode 100644 index 0000000..2c54a0d --- /dev/null +++ b/test/src/unit-protocol.cpp @@ -0,0 +1,188 @@ +#define CATCH_CONFIG_MAIN +#include +#include + +#include +#include +#include +#include + +using namespace ki::protocol; + +TEST_CASE("Control Message Serialization", "[control]") +{ + std::ostringstream oss; + + SECTION("SessionOffer") + { + control::SessionOffer offer(0xABCD, 0xAABBCCDD, 0xAABBCCDD); + offer.write_to(oss); + + const char expected_bytes[] = { + // Session ID + '\xCD', '\xAB', + + // Unknown + '\x00', '\x00', '\x00', '\x00', + + // Timestamp + '\xDD', '\xCC', '\xBB', '\xAA', + + // Milliseconds + '\xDD', '\xCC', '\xBB', '\xAA' + }; + REQUIRE(oss.str() == std::string(expected_bytes, sizeof(expected_bytes))); + } + + SECTION("SessionAccept") + { + control::SessionAccept accept(0xABCD, 0xAABBCCDD, 0xAABBCCDD); + accept.write_to(oss); + + const char expected_bytes[] = { + // Unknown + '\x00', '\x00', + + // Unknown + '\x00', '\x00', '\x00', '\x00', + + // Timestamp + '\xDD', '\xCC', '\xBB', '\xAA', + + // Milliseconds + '\xDD', '\xCC', '\xBB', '\xAA', + + // Session ID + '\xCD', '\xAB' + }; + REQUIRE(oss.str() == std::string(expected_bytes, sizeof(expected_bytes))); + } + + SECTION("ClientKeepAlive") + { + control::ClientKeepAlive keep_alive(0xABCD, 0xABCD, 0xABCD); + keep_alive.write_to(oss); + + const char expected_bytes[] = { + // Session ID + '\xCD', '\xAB', + + // Milliseconds + '\xCD', '\xAB', + + // Minutes + '\xCD', '\xAB' + }; + REQUIRE(oss.str() == std::string(expected_bytes, sizeof(expected_bytes))); + } + + SECTION("ServerKeepAlive") + { + control::ServerKeepAlive keep_alive(0xAABBCCDD); + keep_alive.write_to(oss); + + const char expected_bytes[] = { + // Unknown + '\x00', '\x00', + + // Timestamp + '\xDD', '\xCC', '\xBB', '\xAA' + }; + REQUIRE(oss.str() == std::string(expected_bytes, sizeof(expected_bytes))); + } +} + +TEST_CASE("Control Message Deserialization", "[control]") +{ + SECTION("SessionOffer") + { + const char bytes[] = { + // Session ID + '\xCD', '\xAB', + + // Unknown + '\x00', '\x00', '\x00', '\x00', + + // Timestamp + '\xDD', '\xCC', '\xBB', '\xAA', + + // Milliseconds + '\xDD', '\xCC', '\xBB', '\xAA' + }; + std::istringstream iss(std::string(bytes, sizeof(bytes))); + + control::SessionOffer offer; + offer.read_from(iss); + + REQUIRE(offer.get_session_id() == 0xABCD); + REQUIRE(offer.get_timestamp() == 0xAABBCCDD); + REQUIRE(offer.get_milliseconds() == 0xAABBCCDD); + } + + SECTION("SessionAccept") + { + const char bytes[] = { + // Unknown + '\x00', '\x00', + + // Unknown + '\x00', '\x00', '\x00', '\x00', + + // Timestamp + '\xDD', '\xCC', '\xBB', '\xAA', + + // Milliseconds + '\xDD', '\xCC', '\xBB', '\xAA', + + // Session ID + '\xCD', '\xAB' + }; + std::istringstream iss(std::string(bytes, sizeof(bytes))); + + control::SessionAccept accept; + accept.read_from(iss); + + REQUIRE(accept.get_session_id() == 0xABCD); + REQUIRE(accept.get_timestamp() == 0xAABBCCDD); + REQUIRE(accept.get_milliseconds() == 0xAABBCCDD); + } + + SECTION("ClientKeepAlive") + { + const char bytes[] = { + // Session ID + '\xCD', '\xAB', + + // Milliseconds + '\xCD', '\xAB', + + // Minutes + '\xCD', '\xAB' + }; + std::istringstream iss(std::string(bytes, sizeof(bytes))); + + control::ClientKeepAlive keep_alive; + keep_alive.read_from(iss); + + REQUIRE(keep_alive.get_session_id() == 0xABCD); + REQUIRE(keep_alive.get_milliseconds() == 0xABCD); + REQUIRE(keep_alive.get_minutes() == 0xABCD); + } + + SECTION("ServerKeepAlive") + { + const char bytes[] = { + // Unknown + '\x00', '\x00', + + // Timestamp + '\xDD', '\xCC', '\xBB', '\xAA' + }; + std::istringstream iss(std::string(bytes, sizeof(bytes))); + + control::ServerKeepAlive keep_alive; + keep_alive.read_from(iss); + + REQUIRE(keep_alive.get_timestamp() == 0xAABBCCDD); + } +}