protocol: Messages must now be created from a MessageTemplate

It shouldn't have been possible to create a Message manually.
This commit is contained in:
Joshua Scott 2018-04-20 01:35:05 +01:00
parent 2632ef563c
commit d1118a937b
13 changed files with 327 additions and 208 deletions

View File

@ -1,4 +1,5 @@
#pragma once #pragma once
#include "MessageHeader.h"
#include "../../util/Serializable.h" #include "../../util/Serializable.h"
#include "../../dml/Record.h" #include "../../dml/Record.h"
#include <iostream> #include <iostream>
@ -9,48 +10,40 @@ namespace protocol
{ {
namespace dml namespace dml
{ {
class MessageTemplate;
class Message final : public util::Serializable class Message final : public util::Serializable
{ {
public: public:
Message(uint8_t service_id = 0, uint8_t type = 0); Message(const MessageTemplate *message_template = nullptr);
virtual ~Message(); virtual ~Message();
uint8_t get_service_id() const; const MessageTemplate *get_template() const;
void set_service_id(uint8_t service_id); void set_template(const MessageTemplate *message_template);
uint8_t get_type() const;
void set_type(uint8_t type);
ki::dml::Record *get_record(); ki::dml::Record *get_record();
const ki::dml::Record *get_record() const; const ki::dml::Record *get_record() const;
/** ki::dml::FieldBase *get_field(std::string name);
* Sets the record to a copy of the specified record. const ki::dml::FieldBase *get_field(std::string name) const;
*/
void set_record(const ki::dml::Record &record);
/** uint8_t get_service_id() const;
* If raw data is present, then this uses the specified record uint8_t get_type() const;
* to parse the raw DML message payload into a new Record. uint16_t get_message_size() const;
* If raw data is not present, this is equivalent to set_record. std::string get_handler() const;
* uint8_t get_access_level() const;
* If the raw data is parsed successfully, the internal raw
* data is cleared, and calls to get_record will return a valid
* Record pointer.
*
* However, if the raw data is not parsed successfully, then
* calls to get_record will still return nullptr.
*/
void use_template_record(const ki::dml::Record &record);
void write_to(std::ostream &ostream) const override final; void write_to(std::ostream &ostream) const override final;
void read_from(std::istream &istream) override final; void read_from(std::istream &istream) override final;
size_t get_size() const override final; size_t get_size() const override final;
private: private:
uint8_t m_service_id; const MessageTemplate *m_template;
uint8_t m_type;
std::vector<char> m_raw_data;
ki::dml::Record *m_record; ki::dml::Record *m_record;
// This is used to store raw data when a Message is
// constructed without a MessageTemplate.
MessageHeader m_header;
std::vector<char> m_raw_data;
}; };
} }
} }

View File

@ -1,42 +0,0 @@
#pragma once
#include "Message.h"
#include "ki/protocol/exception.h"
#include <string>
#include <sstream>
namespace ki
{
namespace protocol
{
namespace dml
{
class MessageBuilder
{
public:
MessageBuilder(uint8_t service_id = 0, uint8_t type = 0);
MessageBuilder &set_service_id(uint8_t service_id);
MessageBuilder &set_message_type(uint8_t type);
MessageBuilder &use_template_record(const ki::dml::Record &record);
template <typename ValueT>
MessageBuilder &set_field_value(std::string name, ValueT value)
{
auto *field = m_message->get_record()->get_field<ValueT>(name);
if (!field)
{
std::ostringstream oss;
oss << "No field with name " << name << " exists with specified type.";
throw value_error(oss.str());
}
field->set_value(value);
return *this;
}
Message *get_message() const;
private:
Message *m_message;
};
}
}
}

View File

@ -0,0 +1,37 @@
#pragma once
#include "../../util/Serializable.h"
#include <iostream>
namespace ki
{
namespace protocol
{
namespace dml
{
class MessageHeader : public util::Serializable
{
public:
MessageHeader(uint8_t service_id = 0,
uint8_t type = 0, uint16_t size = 0);
virtual ~MessageHeader() = default;
uint8_t get_service_id() const;
void set_service_id(uint8_t service_id);
uint8_t get_type() const;
void set_type(uint8_t type);
uint16_t get_message_size() const;
void set_message_size(uint16_t size);
void write_to(std::ostream &ostream) const override final;
void read_from(std::istream &istream) override final;
size_t get_size() const override final;
private:
uint8_t m_service_id;
uint8_t m_type;
uint16_t m_size;
};
}
}
}

View File

@ -1,6 +1,6 @@
#pragma once #pragma once
#include "Message.h"
#include "MessageModule.h" #include "MessageModule.h"
#include "MessageBuilder.h"
#include "../../dml/Record.h" #include "../../dml/Record.h"
#include <string> #include <string>
@ -20,10 +20,10 @@ namespace dml
const MessageModule *get_module(uint8_t service_id) const; const MessageModule *get_module(uint8_t service_id) const;
const MessageModule *get_module(const std::string &protocol_type) const; const MessageModule *get_module(const std::string &protocol_type) const;
MessageBuilder &build_message(uint8_t service_id, uint8_t message_type) const; Message *create_message(uint8_t service_id, uint8_t message_type) const;
MessageBuilder &build_message(uint8_t service_id, const std::string &message_name) const; Message *create_message(uint8_t service_id, const std::string &message_name) const;
MessageBuilder &build_message(const std::string &protocol_type, uint8_t message_type) const; Message *create_message(const std::string &protocol_type, uint8_t message_type) const;
MessageBuilder &build_message(const std::string &protocol_type, const std::string &message_name) const; Message *create_message(const std::string &protocol_type, const std::string &message_name) const;
/** /**
* If the DML message header cannot be read, then a nullptr * If the DML message header cannot be read, then a nullptr

View File

@ -1,4 +1,5 @@
#pragma once #pragma once
#include "Message.h"
#include "MessageTemplate.h" #include "MessageTemplate.h"
#include <cstdint> #include <cstdint>
#include <string> #include <string>
@ -33,8 +34,8 @@ namespace dml
void sort_lookup(); void sort_lookup();
MessageBuilder &build_message(uint8_t message_type) const; Message *create_message(uint8_t message_type) const;
MessageBuilder &build_message(std::string message_name) const; Message *create_message(std::string message_name) const;
private: private:
uint8_t m_service_id; uint8_t m_service_id;
std::string m_protocol_type; std::string m_protocol_type;

View File

@ -1,6 +1,6 @@
#pragma once #pragma once
#include "../../dml/Record.h" #include "../../dml/Record.h"
#include "MessageBuilder.h" #include "Message.h"
#include <string> #include <string>
namespace ki namespace ki
@ -25,10 +25,16 @@ namespace dml
uint8_t get_service_id() const; uint8_t get_service_id() const;
void set_service_id(uint8_t service_id); void set_service_id(uint8_t service_id);
std::string get_handler() const;
void set_handler(std::string handler);
uint8_t get_access_level() const;
void set_access_level(uint8_t access_level);
const ki::dml::Record &get_record() const; const ki::dml::Record &get_record() const;
void set_record(ki::dml::Record *record); void set_record(ki::dml::Record *record);
MessageBuilder &build_message() const; Message *create_message() const;
private: private:
std::string m_name; std::string m_name;
uint8_t m_type; uint8_t m_type;

View File

@ -5,7 +5,7 @@ target_sources(${PROJECT_NAME}
${PROJECT_SOURCE_DIR}/src/protocol/control/SessionAccept.cpp ${PROJECT_SOURCE_DIR}/src/protocol/control/SessionAccept.cpp
${PROJECT_SOURCE_DIR}/src/protocol/control/SessionOffer.cpp ${PROJECT_SOURCE_DIR}/src/protocol/control/SessionOffer.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/Message.cpp ${PROJECT_SOURCE_DIR}/src/protocol/dml/Message.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageBuilder.cpp ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageHeader.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageManager.cpp ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageManager.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageModule.cpp ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageModule.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageTemplate.cpp ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageTemplate.cpp

View File

@ -1,4 +1,5 @@
#include "ki/protocol/dml/Message.h" #include "ki/protocol/dml/Message.h"
#include "ki/protocol/dml/MessageTemplate.h"
#include "ki/protocol/exception.h" #include "ki/protocol/exception.h"
namespace ki namespace ki
@ -7,11 +8,13 @@ namespace protocol
{ {
namespace dml namespace dml
{ {
Message::Message(uint8_t service_id, uint8_t type) Message::Message(const MessageTemplate *message_template)
{ {
m_service_id = service_id; m_template = message_template;
m_type = type; if (m_template)
m_record = nullptr; m_record = new ki::dml::Record(m_template->get_record());
else
m_record = nullptr;
} }
Message::~Message() Message::~Message()
@ -19,24 +22,72 @@ namespace dml
delete m_record; delete m_record;
} }
uint8_t Message::get_service_id() const const MessageTemplate *Message::get_template() const
{ {
return m_service_id; return m_template;
} }
void Message::set_service_id(uint8_t service_id) void Message::set_template(const MessageTemplate *message_template)
{ {
m_service_id = service_id; m_template = message_template;
if (!m_template)
return;
m_record = new ki::dml::Record(message_template->get_record());
if (!m_raw_data.empty())
{
std::istringstream iss(std::string(m_raw_data.data(), m_raw_data.size()));
try
{
m_record->read_from(iss);
m_raw_data.clear();
}
catch (ki::dml::parse_error &e)
{
delete m_record;
m_template = nullptr;
m_record = nullptr;
std::ostringstream oss;
oss << "Error reading DML message payload: " << e.what();
throw parse_error(oss.str());
}
}
}
uint8_t Message::get_service_id() const
{
if (m_template)
return m_template->get_service_id();
return m_header.get_service_id();
} }
uint8_t Message::get_type() const uint8_t Message::get_type() const
{ {
return m_type; if (m_template)
return m_template->get_type();
return m_header.get_type();
} }
void Message::set_type(uint8_t type) uint16_t Message::get_message_size() const
{ {
m_type = type; if (m_record)
return m_record->get_size();
return m_raw_data.size();
}
std::string Message::get_handler() const
{
if (m_template)
return m_template->get_handler();
return "";
}
uint8_t Message::get_access_level() const
{
if (m_template)
return m_template->get_access_level();
return 0;
} }
ki::dml::Record *Message::get_record() ki::dml::Record *Message::get_record()
@ -49,82 +100,69 @@ namespace dml
return m_record; return m_record;
} }
void Message::set_record(const ki::dml::Record &record) ki::dml::FieldBase* Message::get_field(std::string name)
{ {
m_record = new ki::dml::Record(record); if (m_record)
return m_record->get_field(name);
return nullptr;
} }
void Message::use_template_record(const ki::dml::Record &record) const ki::dml::FieldBase* Message::get_field(std::string name) const
{ {
set_record(record); if (m_record)
if (!m_raw_data.empty()) return m_record->get_field(name);
{ return nullptr;
std::istringstream iss(std::string(m_raw_data.data(), m_raw_data.size()));
try
{
m_record->read_from(iss);
m_raw_data.clear();
}
catch (ki::dml::parse_error &e)
{
delete m_record;
m_record = nullptr;
std::ostringstream oss;
oss << "Error reading DML message payload: " << e.what();
throw parse_error(oss.str());
}
}
} }
void Message::write_to(std::ostream &ostream) const void Message::write_to(std::ostream &ostream) const
{ {
ki::dml::Record record; // Write the header
record.add_field<ki::dml::UBYT>("m_service_id")->set_value(m_service_id); if (m_template)
record.add_field<ki::dml::UBYT>("m_type")->set_value(m_type); {
auto *size_field = record.add_field<ki::dml::USHRT>("size"); MessageHeader header(
if (m_record) get_service_id(), get_type(), get_message_size());
size_field->set_value(m_record->get_size() + 4); header.write_to(ostream);
}
else else
size_field->set_value(m_raw_data.size() + 4); m_header.write_to(ostream);
record.write_to(ostream);
// Write the payload
if (m_record) if (m_record)
record.write_to(ostream); m_record->write_to(ostream);
else else
ostream.write(m_raw_data.data(), m_raw_data.size()); ostream.write(m_raw_data.data(), m_raw_data.size());
} }
void Message::read_from(std::istream &istream) void Message::read_from(std::istream &istream)
{ {
ki::dml::Record record; m_header.read_from(istream);
auto *service_id_field = record.add_field<ki::dml::UBYT>("ServiceID"); if (m_template)
auto *message_type_field = record.add_field<ki::dml::UBYT>("MsgType");
auto *size_field = record.add_field<ki::dml::USHRT>("Length");
try
{ {
record.read_from(istream); // Check for mismatches between the header and template
} if (m_header.get_service_id() != m_template->get_service_id())
catch (ki::dml::parse_error &e) throw value_error("ServiceID mismatch between MessageHeader and assigned template.");
{ if (m_header.get_type() != m_template->get_type())
std::ostringstream oss; throw value_error("Message Type mismatch between MessageHeader and assigned template.");
oss << "Error reading DML message header: " << e.what();
throw parse_error(oss.str());
}
m_service_id = service_id_field->get_value(); // Read the payload into the record
m_type = message_type_field->get_value(); m_record->read_from(istream);
const ki::dml::USHRT size = size_field->get_value() - 4; }
m_raw_data.resize(size); else
istream.read(m_raw_data.data(), size); {
if (istream.fail()) // We don't have a template for the record structure, so
throw parse_error("Not enough data was available to read DML message payload."); // just read the raw data into a buffer.
const auto size = m_header.get_message_size();
m_raw_data.resize(size);
istream.read(m_raw_data.data(), size);
if (istream.fail())
throw parse_error("Not enough data was available to read DML message payload.");
}
} }
size_t Message::get_size() const size_t Message::get_size() const
{ {
if (m_record) if (m_record)
return 4 + m_record->get_size(); return m_header.get_size() + m_record->get_size();
return 4 + m_raw_data.size(); return 4 + m_raw_data.size();
} }
} }

View File

@ -1,38 +0,0 @@
#include "ki/protocol/dml/MessageBuilder.h"
namespace ki
{
namespace protocol
{
namespace dml
{
MessageBuilder::MessageBuilder(uint8_t service_id, uint8_t type)
{
m_message = new Message(service_id, type);
}
MessageBuilder &MessageBuilder::set_service_id(uint8_t service_id)
{
m_message->set_service_id(service_id);
return *this;
}
MessageBuilder &MessageBuilder::set_message_type(uint8_t type)
{
m_message->set_type(type);
return *this;
}
MessageBuilder &MessageBuilder::use_template_record(const ki::dml::Record& record)
{
m_message->set_record(record);
return *this;
}
Message *MessageBuilder::get_message() const
{
return m_message;
}
}
}
}

View File

@ -0,0 +1,88 @@
#include "ki/protocol/dml/MessageHeader.h"
#include "ki/dml/Record.h"
#include "ki/protocol/exception.h"
namespace ki
{
namespace protocol
{
namespace dml
{
MessageHeader::MessageHeader(const uint8_t service_id,
const uint8_t type, const uint16_t size)
{
m_service_id = service_id;
m_type = type;
m_size = size;
}
uint8_t MessageHeader::get_service_id() const
{
return m_service_id;
}
void MessageHeader::set_service_id(const uint8_t service_id)
{
m_service_id = service_id;
}
uint8_t MessageHeader::get_type() const
{
return m_type;
}
void MessageHeader::set_type(const uint8_t type)
{
m_type = type;
}
uint16_t MessageHeader::get_message_size() const
{
return m_size;
}
void MessageHeader::set_message_size(const uint16_t size)
{
m_size = size;
}
void MessageHeader::write_to(std::ostream& ostream) const
{
ki::dml::Record record;
record.add_field<ki::dml::UBYT>("m_service_id")->set_value(m_service_id);
record.add_field<ki::dml::UBYT>("m_type")->set_value(m_type);
record.add_field<ki::dml::USHRT>("m_size")->set_value(m_size + 4);
record.write_to(ostream);
}
void MessageHeader::read_from(std::istream& istream)
{
ki::dml::Record record;
const auto *service_id = record.add_field<ki::dml::UBYT>("m_service_id");
const auto *type = record.add_field<ki::dml::UBYT>("m_type");
const auto size = record.add_field<ki::dml::USHRT>("m_size");
try
{
record.read_from(istream);
}
catch (ki::dml::parse_error &e)
{
std::ostringstream oss;
oss << "Error reading MessageHeader: " << e.what();
throw parse_error(oss.str());
}
m_service_id = service_id->get_value();
m_type = type->get_value();
m_size = size->get_value() - 4;
}
size_t MessageHeader::get_size() const
{
return sizeof(ki::dml::UBYT) + sizeof(ki::dml::UBYT) +
sizeof(ki::dml::USHRT);
}
}
}
}

View File

@ -1,4 +1,5 @@
#include "ki/protocol/dml/MessageManager.h" #include "ki/protocol/dml/MessageManager.h"
#include "ki/protocol/dml/MessageHeader.h"
#include "ki/protocol/exception.h" #include "ki/protocol/exception.h"
#include "ki/dml/Record.h" #include "ki/dml/Record.h"
#include "ki/util/ValueBytes.h" #include "ki/util/ValueBytes.h"
@ -155,7 +156,7 @@ namespace dml
return nullptr; return nullptr;
} }
MessageBuilder &MessageManager::build_message(uint8_t service_id, uint8_t message_type) const Message *MessageManager::create_message(uint8_t service_id, uint8_t message_type) const
{ {
auto *message_module = get_module(service_id); auto *message_module = get_module(service_id);
if (!message_module) if (!message_module)
@ -165,10 +166,10 @@ namespace dml
throw value_error(oss.str()); throw value_error(oss.str());
} }
return message_module->build_message(message_type); return message_module->create_message(message_type);
} }
MessageBuilder& MessageManager::build_message(uint8_t service_id, const std::string& message_name) const Message *MessageManager::create_message(uint8_t service_id, const std::string& message_name) const
{ {
auto *message_module = get_module(service_id); auto *message_module = get_module(service_id);
if (!message_module) if (!message_module)
@ -178,10 +179,10 @@ namespace dml
throw value_error(oss.str()); throw value_error(oss.str());
} }
return message_module->build_message(message_name); return message_module->create_message(message_name);
} }
MessageBuilder& MessageManager::build_message(const std::string& protocol_type, uint8_t message_type) const Message *MessageManager::create_message(const std::string& protocol_type, uint8_t message_type) const
{ {
auto *message_module = get_module(protocol_type); auto *message_module = get_module(protocol_type);
if (!message_module) if (!message_module)
@ -191,10 +192,10 @@ namespace dml
throw value_error(oss.str()); throw value_error(oss.str());
} }
return message_module->build_message(message_type); return message_module->create_message(message_type);
} }
MessageBuilder& MessageManager::build_message(const std::string& protocol_type, const std::string& message_name) const Message *MessageManager::create_message(const std::string& protocol_type, const std::string& message_name) const
{ {
auto *message_module = get_module(protocol_type); auto *message_module = get_module(protocol_type);
if (!message_module) if (!message_module)
@ -204,35 +205,47 @@ namespace dml
throw value_error(oss.str()); throw value_error(oss.str());
} }
return message_module->build_message(message_name); return message_module->create_message(message_name);
} }
const Message *MessageManager::message_from_binary(std::istream& istream) const const Message *MessageManager::message_from_binary(std::istream& istream) const
{ {
// Read the message header and raw payload // Read the message header
Message *message = new Message(); MessageHeader header;
try try
{ {
message->read_from(istream); header.read_from(istream);
} }
catch (parse_error &e) catch (parse_error &e)
{ {
delete message;
return nullptr; return nullptr;
} }
// Get the message module that uses the specified service id // Get the message module that uses the specified service id
auto *message_module = get_module(message->get_service_id()); auto *message_module = get_module(header.get_service_id());
if (!message_module) if (!message_module)
return message; return nullptr;
// Get the message template for this message type // Get the message template for this message type
auto *message_template = message_module->get_message_template(message->get_type()); auto *message_template = message_module->get_message_template(header.get_type());
if (!message_template) if (!message_template)
return message; return nullptr;
// Parse the raw payload with the template // Make sure that the size specified is enough to read this message
message->use_template_record(message_template->get_record()); if (header.get_message_size() < message_template->get_record().get_size())
return nullptr;
// Create a new Message from the template
auto *message = new Message(message_template);
try
{
message->get_record()->read_from(istream);
}
catch (ki::dml::parse_error &e)
{
delete message;
return nullptr;
}
return message; return message;
} }
} }

View File

@ -139,7 +139,7 @@ namespace dml
} }
} }
MessageBuilder& MessageModule::build_message(uint8_t message_type) const Message *MessageModule::create_message(uint8_t message_type) const
{ {
auto *message_template = get_message_template(message_type); auto *message_template = get_message_template(message_type);
if (!message_template) if (!message_template)
@ -150,10 +150,10 @@ namespace dml
throw value_error(oss.str()); throw value_error(oss.str());
} }
return message_template->build_message(); return message_template->create_message();
} }
MessageBuilder &MessageModule::build_message(std::string message_name) const Message *MessageModule::create_message(std::string message_name) const
{ {
auto *message_template = get_message_template(message_name); auto *message_template = get_message_template(message_name);
if (!message_template) if (!message_template)
@ -164,7 +164,7 @@ namespace dml
throw value_error(oss.str()); throw value_error(oss.str());
} }
return message_template->build_message(); return message_template->create_message();
} }
} }
} }

View File

@ -50,6 +50,32 @@ namespace dml
m_service_id = service_id; m_service_id = service_id;
} }
std::string MessageTemplate::get_handler() const
{
const auto field = m_record->get_field<ki::dml::STR>("_MsgHandler");
if (field)
return field->get_value();
return m_name;
}
void MessageTemplate::set_handler(std::string handler)
{
m_record->add_field<ki::dml::STR>("_MsgHandler")->set_value(handler);
}
uint8_t MessageTemplate::get_access_level() const
{
const auto field = m_record->get_field<ki::dml::UBYT>("_MsgAccessLvl");
if (field)
return field->get_value();
return 0;
}
void MessageTemplate::set_access_level(uint8_t access_level)
{
m_record->add_field<ki::dml::UBYT>("_MsgAccessLvl")->set_value(access_level);
}
const ki::dml::Record& MessageTemplate::get_record() const const ki::dml::Record& MessageTemplate::get_record() const
{ {
return *m_record; return *m_record;
@ -60,12 +86,9 @@ namespace dml
m_record = record; m_record = record;
} }
MessageBuilder &MessageTemplate::build_message() const Message *MessageTemplate::create_message() const
{ {
return MessageBuilder() return new Message(this);
.set_message_type(m_type)
.set_service_id(m_service_id)
.use_template_record(*m_record);
} }
} }
} }