diff --git a/include/ki/protocol/dml/Message.h b/include/ki/protocol/dml/Message.h index 096c96c..bddbb35 100644 --- a/include/ki/protocol/dml/Message.h +++ b/include/ki/protocol/dml/Message.h @@ -1,4 +1,5 @@ #pragma once +#include "MessageHeader.h" #include "../../util/Serializable.h" #include "../../dml/Record.h" #include @@ -9,49 +10,41 @@ namespace protocol { namespace dml { + class MessageTemplate; + class Message final : public util::Serializable { public: - Message(uint8_t service_id = 0, uint8_t type = 0); + Message(const MessageTemplate *message_template = nullptr); virtual ~Message(); - uint8_t get_service_id() const; - void set_service_id(uint8_t service_id); - - uint8_t get_type() const; - void set_type(uint8_t type); + const MessageTemplate *get_template() const; + void set_template(const MessageTemplate *message_template); ki::dml::Record *get_record(); const ki::dml::Record *get_record() const; - /** - * Sets the record to a copy of the specified record. - */ - void set_record(const ki::dml::Record &record); + ki::dml::FieldBase *get_field(std::string name); + const ki::dml::FieldBase *get_field(std::string name) const; - /** - * If raw data is present, then this uses the specified record - * to parse the raw DML message payload into a new Record. - * If raw data is not present, this is equivalent to set_record. - * - * If the raw data is parsed successfully, the internal raw - * data is cleared, and calls to get_record will return a valid - * Record pointer. - * - * However, if the raw data is not parsed successfully, then - * calls to get_record will still return nullptr. - */ - void use_template_record(const ki::dml::Record &record); + uint8_t get_service_id() const; + uint8_t get_type() const; + uint16_t get_message_size() const; + std::string get_handler() const; + uint8_t get_access_level() const; void write_to(std::ostream &ostream) const override final; void read_from(std::istream &istream) override final; size_t get_size() const override final; private: - uint8_t m_service_id; - uint8_t m_type; - std::vector m_raw_data; + const MessageTemplate *m_template; ki::dml::Record *m_record; + + // This is used to store raw data when a Message is + // constructed without a MessageTemplate. + MessageHeader m_header; + std::vector m_raw_data; }; } } -} \ No newline at end of file +} diff --git a/include/ki/protocol/dml/MessageBuilder.h b/include/ki/protocol/dml/MessageBuilder.h deleted file mode 100644 index 33995eb..0000000 --- a/include/ki/protocol/dml/MessageBuilder.h +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once -#include "Message.h" -#include "ki/protocol/exception.h" -#include -#include - -namespace ki -{ -namespace protocol -{ -namespace dml -{ - class MessageBuilder - { - public: - MessageBuilder(uint8_t service_id = 0, uint8_t type = 0); - - MessageBuilder &set_service_id(uint8_t service_id); - MessageBuilder &set_message_type(uint8_t type); - MessageBuilder &use_template_record(const ki::dml::Record &record); - - template - MessageBuilder &set_field_value(std::string name, ValueT value) - { - auto *field = m_message->get_record()->get_field(name); - if (!field) - { - std::ostringstream oss; - oss << "No field with name " << name << " exists with specified type."; - throw value_error(oss.str()); - } - field->set_value(value); - return *this; - } - - Message *get_message() const; - private: - Message *m_message; - }; -} -} -} diff --git a/include/ki/protocol/dml/MessageHeader.h b/include/ki/protocol/dml/MessageHeader.h new file mode 100644 index 0000000..7bda780 --- /dev/null +++ b/include/ki/protocol/dml/MessageHeader.h @@ -0,0 +1,37 @@ +#pragma once +#include "../../util/Serializable.h" +#include + +namespace ki +{ +namespace protocol +{ +namespace dml +{ + class MessageHeader : public util::Serializable + { + public: + MessageHeader(uint8_t service_id = 0, + uint8_t type = 0, uint16_t size = 0); + virtual ~MessageHeader() = default; + + uint8_t get_service_id() const; + void set_service_id(uint8_t service_id); + + uint8_t get_type() const; + void set_type(uint8_t type); + + uint16_t get_message_size() const; + void set_message_size(uint16_t size); + + void write_to(std::ostream &ostream) const override final; + void read_from(std::istream &istream) override final; + size_t get_size() const override final; + private: + uint8_t m_service_id; + uint8_t m_type; + uint16_t m_size; + }; +} +} +} \ No newline at end of file diff --git a/include/ki/protocol/dml/MessageManager.h b/include/ki/protocol/dml/MessageManager.h index 75e9e16..a051ba2 100644 --- a/include/ki/protocol/dml/MessageManager.h +++ b/include/ki/protocol/dml/MessageManager.h @@ -1,6 +1,6 @@ #pragma once +#include "Message.h" #include "MessageModule.h" -#include "MessageBuilder.h" #include "../../dml/Record.h" #include @@ -20,10 +20,10 @@ namespace dml const MessageModule *get_module(uint8_t service_id) const; const MessageModule *get_module(const std::string &protocol_type) const; - MessageBuilder &build_message(uint8_t service_id, uint8_t message_type) const; - MessageBuilder &build_message(uint8_t service_id, const std::string &message_name) const; - MessageBuilder &build_message(const std::string &protocol_type, uint8_t message_type) const; - MessageBuilder &build_message(const std::string &protocol_type, const std::string &message_name) const; + Message *create_message(uint8_t service_id, uint8_t message_type) const; + Message *create_message(uint8_t service_id, const std::string &message_name) const; + Message *create_message(const std::string &protocol_type, uint8_t message_type) const; + Message *create_message(const std::string &protocol_type, const std::string &message_name) const; /** * If the DML message header cannot be read, then a nullptr diff --git a/include/ki/protocol/dml/MessageModule.h b/include/ki/protocol/dml/MessageModule.h index 3b763c3..c58e080 100644 --- a/include/ki/protocol/dml/MessageModule.h +++ b/include/ki/protocol/dml/MessageModule.h @@ -1,4 +1,5 @@ #pragma once +#include "Message.h" #include "MessageTemplate.h" #include #include @@ -33,8 +34,8 @@ namespace dml void sort_lookup(); - MessageBuilder &build_message(uint8_t message_type) const; - MessageBuilder &build_message(std::string message_name) const; + Message *create_message(uint8_t message_type) const; + Message *create_message(std::string message_name) const; private: uint8_t m_service_id; std::string m_protocol_type; diff --git a/include/ki/protocol/dml/MessageTemplate.h b/include/ki/protocol/dml/MessageTemplate.h index 5cf3375..3f82ab2 100644 --- a/include/ki/protocol/dml/MessageTemplate.h +++ b/include/ki/protocol/dml/MessageTemplate.h @@ -1,6 +1,6 @@ #pragma once #include "../../dml/Record.h" -#include "MessageBuilder.h" +#include "Message.h" #include namespace ki @@ -25,10 +25,16 @@ namespace dml uint8_t get_service_id() const; void set_service_id(uint8_t service_id); + std::string get_handler() const; + void set_handler(std::string handler); + + uint8_t get_access_level() const; + void set_access_level(uint8_t access_level); + const ki::dml::Record &get_record() const; void set_record(ki::dml::Record *record); - MessageBuilder &build_message() const; + Message *create_message() const; private: std::string m_name; uint8_t m_type; diff --git a/src/protocol/CMakeLists.txt b/src/protocol/CMakeLists.txt index 79c513c..191b07e 100644 --- a/src/protocol/CMakeLists.txt +++ b/src/protocol/CMakeLists.txt @@ -5,7 +5,7 @@ target_sources(${PROJECT_NAME} ${PROJECT_SOURCE_DIR}/src/protocol/control/SessionAccept.cpp ${PROJECT_SOURCE_DIR}/src/protocol/control/SessionOffer.cpp ${PROJECT_SOURCE_DIR}/src/protocol/dml/Message.cpp - ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageBuilder.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageHeader.cpp ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageManager.cpp ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageModule.cpp ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageTemplate.cpp diff --git a/src/protocol/dml/Message.cpp b/src/protocol/dml/Message.cpp index 107dc56..27d4110 100644 --- a/src/protocol/dml/Message.cpp +++ b/src/protocol/dml/Message.cpp @@ -1,4 +1,5 @@ #include "ki/protocol/dml/Message.h" +#include "ki/protocol/dml/MessageTemplate.h" #include "ki/protocol/exception.h" namespace ki @@ -7,11 +8,13 @@ namespace protocol { namespace dml { - Message::Message(uint8_t service_id, uint8_t type) + Message::Message(const MessageTemplate *message_template) { - m_service_id = service_id; - m_type = type; - m_record = nullptr; + m_template = message_template; + if (m_template) + m_record = new ki::dml::Record(m_template->get_record()); + else + m_record = nullptr; } Message::~Message() @@ -19,24 +22,72 @@ namespace dml delete m_record; } - uint8_t Message::get_service_id() const + const MessageTemplate *Message::get_template() const { - return m_service_id; + return m_template; } - void Message::set_service_id(uint8_t service_id) + void Message::set_template(const MessageTemplate *message_template) { - m_service_id = service_id; + m_template = message_template; + if (!m_template) + return; + + m_record = new ki::dml::Record(message_template->get_record()); + if (!m_raw_data.empty()) + { + std::istringstream iss(std::string(m_raw_data.data(), m_raw_data.size())); + try + { + m_record->read_from(iss); + m_raw_data.clear(); + } + catch (ki::dml::parse_error &e) + { + delete m_record; + m_template = nullptr; + m_record = nullptr; + + std::ostringstream oss; + oss << "Error reading DML message payload: " << e.what(); + throw parse_error(oss.str()); + } + } + } + + uint8_t Message::get_service_id() const + { + if (m_template) + return m_template->get_service_id(); + return m_header.get_service_id(); } uint8_t Message::get_type() const { - return m_type; + if (m_template) + return m_template->get_type(); + return m_header.get_type(); } - void Message::set_type(uint8_t type) + uint16_t Message::get_message_size() const { - m_type = type; + if (m_record) + return m_record->get_size(); + return m_raw_data.size(); + } + + std::string Message::get_handler() const + { + if (m_template) + return m_template->get_handler(); + return ""; + } + + uint8_t Message::get_access_level() const + { + if (m_template) + return m_template->get_access_level(); + return 0; } ki::dml::Record *Message::get_record() @@ -49,82 +100,69 @@ namespace dml return m_record; } - void Message::set_record(const ki::dml::Record &record) + ki::dml::FieldBase* Message::get_field(std::string name) { - m_record = new ki::dml::Record(record); + if (m_record) + return m_record->get_field(name); + return nullptr; } - void Message::use_template_record(const ki::dml::Record &record) + const ki::dml::FieldBase* Message::get_field(std::string name) const { - set_record(record); - if (!m_raw_data.empty()) - { - std::istringstream iss(std::string(m_raw_data.data(), m_raw_data.size())); - try - { - m_record->read_from(iss); - m_raw_data.clear(); - } - catch (ki::dml::parse_error &e) - { - delete m_record; - m_record = nullptr; - - std::ostringstream oss; - oss << "Error reading DML message payload: " << e.what(); - throw parse_error(oss.str()); - } - } + if (m_record) + return m_record->get_field(name); + return nullptr; } void Message::write_to(std::ostream &ostream) const { - ki::dml::Record record; - record.add_field("m_service_id")->set_value(m_service_id); - record.add_field("m_type")->set_value(m_type); - auto *size_field = record.add_field("size"); - if (m_record) - size_field->set_value(m_record->get_size() + 4); + // Write the header + if (m_template) + { + MessageHeader header( + get_service_id(), get_type(), get_message_size()); + header.write_to(ostream); + } else - size_field->set_value(m_raw_data.size() + 4); - record.write_to(ostream); + m_header.write_to(ostream); + // Write the payload if (m_record) - record.write_to(ostream); + m_record->write_to(ostream); else ostream.write(m_raw_data.data(), m_raw_data.size()); } void Message::read_from(std::istream &istream) { - ki::dml::Record record; - auto *service_id_field = record.add_field("ServiceID"); - auto *message_type_field = record.add_field("MsgType"); - auto *size_field = record.add_field("Length"); - try + m_header.read_from(istream); + if (m_template) { - record.read_from(istream); - } - catch (ki::dml::parse_error &e) - { - std::ostringstream oss; - oss << "Error reading DML message header: " << e.what(); - throw parse_error(oss.str()); - } + // Check for mismatches between the header and template + if (m_header.get_service_id() != m_template->get_service_id()) + throw value_error("ServiceID mismatch between MessageHeader and assigned template."); + if (m_header.get_type() != m_template->get_type()) + throw value_error("Message Type mismatch between MessageHeader and assigned template."); - m_service_id = service_id_field->get_value(); - m_type = message_type_field->get_value(); - const ki::dml::USHRT size = size_field->get_value() - 4; - m_raw_data.resize(size); - istream.read(m_raw_data.data(), size); - if (istream.fail()) - throw parse_error("Not enough data was available to read DML message payload."); + // Read the payload into the record + m_record->read_from(istream); + } + else + { + // We don't have a template for the record structure, so + // just read the raw data into a buffer. + const auto size = m_header.get_message_size(); + m_raw_data.resize(size); + istream.read(m_raw_data.data(), size); + if (istream.fail()) + throw parse_error("Not enough data was available to read DML message payload."); + } } size_t Message::get_size() const { if (m_record) - return 4 + m_record->get_size(); + return m_header.get_size() + m_record->get_size(); return 4 + m_raw_data.size(); } } diff --git a/src/protocol/dml/MessageBuilder.cpp b/src/protocol/dml/MessageBuilder.cpp deleted file mode 100644 index aaed66f..0000000 --- a/src/protocol/dml/MessageBuilder.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include "ki/protocol/dml/MessageBuilder.h" - -namespace ki -{ -namespace protocol -{ -namespace dml -{ - MessageBuilder::MessageBuilder(uint8_t service_id, uint8_t type) - { - m_message = new Message(service_id, type); - } - - MessageBuilder &MessageBuilder::set_service_id(uint8_t service_id) - { - m_message->set_service_id(service_id); - return *this; - } - - MessageBuilder &MessageBuilder::set_message_type(uint8_t type) - { - m_message->set_type(type); - return *this; - } - - MessageBuilder &MessageBuilder::use_template_record(const ki::dml::Record& record) - { - m_message->set_record(record); - return *this; - } - - Message *MessageBuilder::get_message() const - { - return m_message; - } -} -} -} \ No newline at end of file diff --git a/src/protocol/dml/MessageHeader.cpp b/src/protocol/dml/MessageHeader.cpp new file mode 100644 index 0000000..46290a3 --- /dev/null +++ b/src/protocol/dml/MessageHeader.cpp @@ -0,0 +1,88 @@ +#include "ki/protocol/dml/MessageHeader.h" +#include "ki/dml/Record.h" +#include "ki/protocol/exception.h" + +namespace ki +{ +namespace protocol +{ +namespace dml +{ + MessageHeader::MessageHeader(const uint8_t service_id, + const uint8_t type, const uint16_t size) + { + m_service_id = service_id; + m_type = type; + m_size = size; + } + + uint8_t MessageHeader::get_service_id() const + { + return m_service_id; + } + + void MessageHeader::set_service_id(const uint8_t service_id) + { + m_service_id = service_id; + } + + uint8_t MessageHeader::get_type() const + { + return m_type; + } + + void MessageHeader::set_type(const uint8_t type) + { + m_type = type; + } + + uint16_t MessageHeader::get_message_size() const + { + return m_size; + } + + void MessageHeader::set_message_size(const uint16_t size) + { + m_size = size; + } + + void MessageHeader::write_to(std::ostream& ostream) const + { + ki::dml::Record record; + record.add_field("m_service_id")->set_value(m_service_id); + record.add_field("m_type")->set_value(m_type); + record.add_field("m_size")->set_value(m_size + 4); + record.write_to(ostream); + } + + void MessageHeader::read_from(std::istream& istream) + { + ki::dml::Record record; + const auto *service_id = record.add_field("m_service_id"); + const auto *type = record.add_field("m_type"); + const auto size = record.add_field("m_size"); + + try + { + record.read_from(istream); + } + catch (ki::dml::parse_error &e) + { + std::ostringstream oss; + oss << "Error reading MessageHeader: " << e.what(); + throw parse_error(oss.str()); + } + + m_service_id = service_id->get_value(); + m_type = type->get_value(); + m_size = size->get_value() - 4; + } + + size_t MessageHeader::get_size() const + { + return sizeof(ki::dml::UBYT) + sizeof(ki::dml::UBYT) + + sizeof(ki::dml::USHRT); + } +} +} +} diff --git a/src/protocol/dml/MessageManager.cpp b/src/protocol/dml/MessageManager.cpp index 90903ae..328e597 100644 --- a/src/protocol/dml/MessageManager.cpp +++ b/src/protocol/dml/MessageManager.cpp @@ -1,4 +1,5 @@ #include "ki/protocol/dml/MessageManager.h" +#include "ki/protocol/dml/MessageHeader.h" #include "ki/protocol/exception.h" #include "ki/dml/Record.h" #include "ki/util/ValueBytes.h" @@ -155,7 +156,7 @@ namespace dml return nullptr; } - MessageBuilder &MessageManager::build_message(uint8_t service_id, uint8_t message_type) const + Message *MessageManager::create_message(uint8_t service_id, uint8_t message_type) const { auto *message_module = get_module(service_id); if (!message_module) @@ -165,10 +166,10 @@ namespace dml throw value_error(oss.str()); } - return message_module->build_message(message_type); + return message_module->create_message(message_type); } - MessageBuilder& MessageManager::build_message(uint8_t service_id, const std::string& message_name) const + Message *MessageManager::create_message(uint8_t service_id, const std::string& message_name) const { auto *message_module = get_module(service_id); if (!message_module) @@ -178,10 +179,10 @@ namespace dml throw value_error(oss.str()); } - return message_module->build_message(message_name); + return message_module->create_message(message_name); } - MessageBuilder& MessageManager::build_message(const std::string& protocol_type, uint8_t message_type) const + Message *MessageManager::create_message(const std::string& protocol_type, uint8_t message_type) const { auto *message_module = get_module(protocol_type); if (!message_module) @@ -191,10 +192,10 @@ namespace dml throw value_error(oss.str()); } - return message_module->build_message(message_type); + return message_module->create_message(message_type); } - MessageBuilder& MessageManager::build_message(const std::string& protocol_type, const std::string& message_name) const + Message *MessageManager::create_message(const std::string& protocol_type, const std::string& message_name) const { auto *message_module = get_module(protocol_type); if (!message_module) @@ -204,37 +205,49 @@ namespace dml throw value_error(oss.str()); } - return message_module->build_message(message_name); + return message_module->create_message(message_name); } const Message *MessageManager::message_from_binary(std::istream& istream) const { - // Read the message header and raw payload - Message *message = new Message(); + // Read the message header + MessageHeader header; try { - message->read_from(istream); + header.read_from(istream); } catch (parse_error &e) { - delete message; return nullptr; } // Get the message module that uses the specified service id - auto *message_module = get_module(message->get_service_id()); + auto *message_module = get_module(header.get_service_id()); if (!message_module) - return message; + return nullptr; // Get the message template for this message type - auto *message_template = message_module->get_message_template(message->get_type()); + auto *message_template = message_module->get_message_template(header.get_type()); if (!message_template) - return message; + return nullptr; - // Parse the raw payload with the template - message->use_template_record(message_template->get_record()); + // Make sure that the size specified is enough to read this message + if (header.get_message_size() < message_template->get_record().get_size()) + return nullptr; + + // Create a new Message from the template + auto *message = new Message(message_template); + try + { + message->get_record()->read_from(istream); + } + catch (ki::dml::parse_error &e) + { + delete message; + return nullptr; + } return message; } } } -} \ No newline at end of file +} diff --git a/src/protocol/dml/MessageModule.cpp b/src/protocol/dml/MessageModule.cpp index 6345507..11a48c0 100644 --- a/src/protocol/dml/MessageModule.cpp +++ b/src/protocol/dml/MessageModule.cpp @@ -139,7 +139,7 @@ namespace dml } } - MessageBuilder& MessageModule::build_message(uint8_t message_type) const + Message *MessageModule::create_message(uint8_t message_type) const { auto *message_template = get_message_template(message_type); if (!message_template) @@ -150,10 +150,10 @@ namespace dml throw value_error(oss.str()); } - return message_template->build_message(); + return message_template->create_message(); } - MessageBuilder &MessageModule::build_message(std::string message_name) const + Message *MessageModule::create_message(std::string message_name) const { auto *message_template = get_message_template(message_name); if (!message_template) @@ -164,7 +164,7 @@ namespace dml throw value_error(oss.str()); } - return message_template->build_message(); + return message_template->create_message(); } } } diff --git a/src/protocol/dml/MessageTemplate.cpp b/src/protocol/dml/MessageTemplate.cpp index e89db88..9e79929 100644 --- a/src/protocol/dml/MessageTemplate.cpp +++ b/src/protocol/dml/MessageTemplate.cpp @@ -50,6 +50,32 @@ namespace dml m_service_id = service_id; } + std::string MessageTemplate::get_handler() const + { + const auto field = m_record->get_field("_MsgHandler"); + if (field) + return field->get_value(); + return m_name; + } + + void MessageTemplate::set_handler(std::string handler) + { + m_record->add_field("_MsgHandler")->set_value(handler); + } + + uint8_t MessageTemplate::get_access_level() const + { + const auto field = m_record->get_field("_MsgAccessLvl"); + if (field) + return field->get_value(); + return 0; + } + + void MessageTemplate::set_access_level(uint8_t access_level) + { + m_record->add_field("_MsgAccessLvl")->set_value(access_level); + } + const ki::dml::Record& MessageTemplate::get_record() const { return *m_record; @@ -60,12 +86,9 @@ namespace dml m_record = record; } - MessageBuilder &MessageTemplate::build_message() const + Message *MessageTemplate::create_message() const { - return MessageBuilder() - .set_message_type(m_type) - .set_service_id(m_service_id) - .use_template_record(*m_record); + return new Message(this); } } }