protocol: Messages must now be created from a MessageTemplate

It shouldn't have been possible to create a Message manually.
This commit is contained in:
Joshua Scott 2018-04-20 01:35:05 +01:00
parent 2632ef563c
commit d1118a937b
13 changed files with 327 additions and 208 deletions

View File

@ -1,4 +1,5 @@
#pragma once
#include "MessageHeader.h"
#include "../../util/Serializable.h"
#include "../../dml/Record.h"
#include <iostream>
@ -9,48 +10,40 @@ namespace protocol
{
namespace dml
{
class MessageTemplate;
class Message final : public util::Serializable
{
public:
Message(uint8_t service_id = 0, uint8_t type = 0);
Message(const MessageTemplate *message_template = nullptr);
virtual ~Message();
uint8_t get_service_id() const;
void set_service_id(uint8_t service_id);
uint8_t get_type() const;
void set_type(uint8_t type);
const MessageTemplate *get_template() const;
void set_template(const MessageTemplate *message_template);
ki::dml::Record *get_record();
const ki::dml::Record *get_record() const;
/**
* Sets the record to a copy of the specified record.
*/
void set_record(const ki::dml::Record &record);
ki::dml::FieldBase *get_field(std::string name);
const ki::dml::FieldBase *get_field(std::string name) const;
/**
* If raw data is present, then this uses the specified record
* to parse the raw DML message payload into a new Record.
* If raw data is not present, this is equivalent to set_record.
*
* If the raw data is parsed successfully, the internal raw
* data is cleared, and calls to get_record will return a valid
* Record pointer.
*
* However, if the raw data is not parsed successfully, then
* calls to get_record will still return nullptr.
*/
void use_template_record(const ki::dml::Record &record);
uint8_t get_service_id() const;
uint8_t get_type() const;
uint16_t get_message_size() const;
std::string get_handler() const;
uint8_t get_access_level() const;
void write_to(std::ostream &ostream) const override final;
void read_from(std::istream &istream) override final;
size_t get_size() const override final;
private:
uint8_t m_service_id;
uint8_t m_type;
std::vector<char> m_raw_data;
const MessageTemplate *m_template;
ki::dml::Record *m_record;
// This is used to store raw data when a Message is
// constructed without a MessageTemplate.
MessageHeader m_header;
std::vector<char> m_raw_data;
};
}
}

View File

@ -1,42 +0,0 @@
#pragma once
#include "Message.h"
#include "ki/protocol/exception.h"
#include <string>
#include <sstream>
namespace ki
{
namespace protocol
{
namespace dml
{
class MessageBuilder
{
public:
MessageBuilder(uint8_t service_id = 0, uint8_t type = 0);
MessageBuilder &set_service_id(uint8_t service_id);
MessageBuilder &set_message_type(uint8_t type);
MessageBuilder &use_template_record(const ki::dml::Record &record);
template <typename ValueT>
MessageBuilder &set_field_value(std::string name, ValueT value)
{
auto *field = m_message->get_record()->get_field<ValueT>(name);
if (!field)
{
std::ostringstream oss;
oss << "No field with name " << name << " exists with specified type.";
throw value_error(oss.str());
}
field->set_value(value);
return *this;
}
Message *get_message() const;
private:
Message *m_message;
};
}
}
}

View File

@ -0,0 +1,37 @@
#pragma once
#include "../../util/Serializable.h"
#include <iostream>
namespace ki
{
namespace protocol
{
namespace dml
{
class MessageHeader : public util::Serializable
{
public:
MessageHeader(uint8_t service_id = 0,
uint8_t type = 0, uint16_t size = 0);
virtual ~MessageHeader() = default;
uint8_t get_service_id() const;
void set_service_id(uint8_t service_id);
uint8_t get_type() const;
void set_type(uint8_t type);
uint16_t get_message_size() const;
void set_message_size(uint16_t size);
void write_to(std::ostream &ostream) const override final;
void read_from(std::istream &istream) override final;
size_t get_size() const override final;
private:
uint8_t m_service_id;
uint8_t m_type;
uint16_t m_size;
};
}
}
}

View File

@ -1,6 +1,6 @@
#pragma once
#include "Message.h"
#include "MessageModule.h"
#include "MessageBuilder.h"
#include "../../dml/Record.h"
#include <string>
@ -20,10 +20,10 @@ namespace dml
const MessageModule *get_module(uint8_t service_id) const;
const MessageModule *get_module(const std::string &protocol_type) const;
MessageBuilder &build_message(uint8_t service_id, uint8_t message_type) const;
MessageBuilder &build_message(uint8_t service_id, const std::string &message_name) const;
MessageBuilder &build_message(const std::string &protocol_type, uint8_t message_type) const;
MessageBuilder &build_message(const std::string &protocol_type, const std::string &message_name) const;
Message *create_message(uint8_t service_id, uint8_t message_type) const;
Message *create_message(uint8_t service_id, const std::string &message_name) const;
Message *create_message(const std::string &protocol_type, uint8_t message_type) const;
Message *create_message(const std::string &protocol_type, const std::string &message_name) const;
/**
* If the DML message header cannot be read, then a nullptr

View File

@ -1,4 +1,5 @@
#pragma once
#include "Message.h"
#include "MessageTemplate.h"
#include <cstdint>
#include <string>
@ -33,8 +34,8 @@ namespace dml
void sort_lookup();
MessageBuilder &build_message(uint8_t message_type) const;
MessageBuilder &build_message(std::string message_name) const;
Message *create_message(uint8_t message_type) const;
Message *create_message(std::string message_name) const;
private:
uint8_t m_service_id;
std::string m_protocol_type;

View File

@ -1,6 +1,6 @@
#pragma once
#include "../../dml/Record.h"
#include "MessageBuilder.h"
#include "Message.h"
#include <string>
namespace ki
@ -25,10 +25,16 @@ namespace dml
uint8_t get_service_id() const;
void set_service_id(uint8_t service_id);
std::string get_handler() const;
void set_handler(std::string handler);
uint8_t get_access_level() const;
void set_access_level(uint8_t access_level);
const ki::dml::Record &get_record() const;
void set_record(ki::dml::Record *record);
MessageBuilder &build_message() const;
Message *create_message() const;
private:
std::string m_name;
uint8_t m_type;

View File

@ -5,7 +5,7 @@ target_sources(${PROJECT_NAME}
${PROJECT_SOURCE_DIR}/src/protocol/control/SessionAccept.cpp
${PROJECT_SOURCE_DIR}/src/protocol/control/SessionOffer.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/Message.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageBuilder.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageHeader.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageManager.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageModule.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageTemplate.cpp

View File

@ -1,4 +1,5 @@
#include "ki/protocol/dml/Message.h"
#include "ki/protocol/dml/MessageTemplate.h"
#include "ki/protocol/exception.h"
namespace ki
@ -7,10 +8,12 @@ namespace protocol
{
namespace dml
{
Message::Message(uint8_t service_id, uint8_t type)
Message::Message(const MessageTemplate *message_template)
{
m_service_id = service_id;
m_type = type;
m_template = message_template;
if (m_template)
m_record = new ki::dml::Record(m_template->get_record());
else
m_record = nullptr;
}
@ -19,24 +22,72 @@ namespace dml
delete m_record;
}
uint8_t Message::get_service_id() const
const MessageTemplate *Message::get_template() const
{
return m_service_id;
return m_template;
}
void Message::set_service_id(uint8_t service_id)
void Message::set_template(const MessageTemplate *message_template)
{
m_service_id = service_id;
m_template = message_template;
if (!m_template)
return;
m_record = new ki::dml::Record(message_template->get_record());
if (!m_raw_data.empty())
{
std::istringstream iss(std::string(m_raw_data.data(), m_raw_data.size()));
try
{
m_record->read_from(iss);
m_raw_data.clear();
}
catch (ki::dml::parse_error &e)
{
delete m_record;
m_template = nullptr;
m_record = nullptr;
std::ostringstream oss;
oss << "Error reading DML message payload: " << e.what();
throw parse_error(oss.str());
}
}
}
uint8_t Message::get_service_id() const
{
if (m_template)
return m_template->get_service_id();
return m_header.get_service_id();
}
uint8_t Message::get_type() const
{
return m_type;
if (m_template)
return m_template->get_type();
return m_header.get_type();
}
void Message::set_type(uint8_t type)
uint16_t Message::get_message_size() const
{
m_type = type;
if (m_record)
return m_record->get_size();
return m_raw_data.size();
}
std::string Message::get_handler() const
{
if (m_template)
return m_template->get_handler();
return "";
}
uint8_t Message::get_access_level() const
{
if (m_template)
return m_template->get_access_level();
return 0;
}
ki::dml::Record *Message::get_record()
@ -49,82 +100,69 @@ namespace dml
return m_record;
}
void Message::set_record(const ki::dml::Record &record)
ki::dml::FieldBase* Message::get_field(std::string name)
{
m_record = new ki::dml::Record(record);
if (m_record)
return m_record->get_field(name);
return nullptr;
}
void Message::use_template_record(const ki::dml::Record &record)
const ki::dml::FieldBase* Message::get_field(std::string name) const
{
set_record(record);
if (!m_raw_data.empty())
{
std::istringstream iss(std::string(m_raw_data.data(), m_raw_data.size()));
try
{
m_record->read_from(iss);
m_raw_data.clear();
}
catch (ki::dml::parse_error &e)
{
delete m_record;
m_record = nullptr;
std::ostringstream oss;
oss << "Error reading DML message payload: " << e.what();
throw parse_error(oss.str());
}
}
if (m_record)
return m_record->get_field(name);
return nullptr;
}
void Message::write_to(std::ostream &ostream) const
{
ki::dml::Record record;
record.add_field<ki::dml::UBYT>("m_service_id")->set_value(m_service_id);
record.add_field<ki::dml::UBYT>("m_type")->set_value(m_type);
auto *size_field = record.add_field<ki::dml::USHRT>("size");
if (m_record)
size_field->set_value(m_record->get_size() + 4);
// Write the header
if (m_template)
{
MessageHeader header(
get_service_id(), get_type(), get_message_size());
header.write_to(ostream);
}
else
size_field->set_value(m_raw_data.size() + 4);
record.write_to(ostream);
m_header.write_to(ostream);
// Write the payload
if (m_record)
record.write_to(ostream);
m_record->write_to(ostream);
else
ostream.write(m_raw_data.data(), m_raw_data.size());
}
void Message::read_from(std::istream &istream)
{
ki::dml::Record record;
auto *service_id_field = record.add_field<ki::dml::UBYT>("ServiceID");
auto *message_type_field = record.add_field<ki::dml::UBYT>("MsgType");
auto *size_field = record.add_field<ki::dml::USHRT>("Length");
try
m_header.read_from(istream);
if (m_template)
{
record.read_from(istream);
}
catch (ki::dml::parse_error &e)
{
std::ostringstream oss;
oss << "Error reading DML message header: " << e.what();
throw parse_error(oss.str());
}
// Check for mismatches between the header and template
if (m_header.get_service_id() != m_template->get_service_id())
throw value_error("ServiceID mismatch between MessageHeader and assigned template.");
if (m_header.get_type() != m_template->get_type())
throw value_error("Message Type mismatch between MessageHeader and assigned template.");
m_service_id = service_id_field->get_value();
m_type = message_type_field->get_value();
const ki::dml::USHRT size = size_field->get_value() - 4;
// Read the payload into the record
m_record->read_from(istream);
}
else
{
// We don't have a template for the record structure, so
// just read the raw data into a buffer.
const auto size = m_header.get_message_size();
m_raw_data.resize(size);
istream.read(m_raw_data.data(), size);
if (istream.fail())
throw parse_error("Not enough data was available to read DML message payload.");
}
}
size_t Message::get_size() const
{
if (m_record)
return 4 + m_record->get_size();
return m_header.get_size() + m_record->get_size();
return 4 + m_raw_data.size();
}
}

View File

@ -1,38 +0,0 @@
#include "ki/protocol/dml/MessageBuilder.h"
namespace ki
{
namespace protocol
{
namespace dml
{
MessageBuilder::MessageBuilder(uint8_t service_id, uint8_t type)
{
m_message = new Message(service_id, type);
}
MessageBuilder &MessageBuilder::set_service_id(uint8_t service_id)
{
m_message->set_service_id(service_id);
return *this;
}
MessageBuilder &MessageBuilder::set_message_type(uint8_t type)
{
m_message->set_type(type);
return *this;
}
MessageBuilder &MessageBuilder::use_template_record(const ki::dml::Record& record)
{
m_message->set_record(record);
return *this;
}
Message *MessageBuilder::get_message() const
{
return m_message;
}
}
}
}

View File

@ -0,0 +1,88 @@
#include "ki/protocol/dml/MessageHeader.h"
#include "ki/dml/Record.h"
#include "ki/protocol/exception.h"
namespace ki
{
namespace protocol
{
namespace dml
{
MessageHeader::MessageHeader(const uint8_t service_id,
const uint8_t type, const uint16_t size)
{
m_service_id = service_id;
m_type = type;
m_size = size;
}
uint8_t MessageHeader::get_service_id() const
{
return m_service_id;
}
void MessageHeader::set_service_id(const uint8_t service_id)
{
m_service_id = service_id;
}
uint8_t MessageHeader::get_type() const
{
return m_type;
}
void MessageHeader::set_type(const uint8_t type)
{
m_type = type;
}
uint16_t MessageHeader::get_message_size() const
{
return m_size;
}
void MessageHeader::set_message_size(const uint16_t size)
{
m_size = size;
}
void MessageHeader::write_to(std::ostream& ostream) const
{
ki::dml::Record record;
record.add_field<ki::dml::UBYT>("m_service_id")->set_value(m_service_id);
record.add_field<ki::dml::UBYT>("m_type")->set_value(m_type);
record.add_field<ki::dml::USHRT>("m_size")->set_value(m_size + 4);
record.write_to(ostream);
}
void MessageHeader::read_from(std::istream& istream)
{
ki::dml::Record record;
const auto *service_id = record.add_field<ki::dml::UBYT>("m_service_id");
const auto *type = record.add_field<ki::dml::UBYT>("m_type");
const auto size = record.add_field<ki::dml::USHRT>("m_size");
try
{
record.read_from(istream);
}
catch (ki::dml::parse_error &e)
{
std::ostringstream oss;
oss << "Error reading MessageHeader: " << e.what();
throw parse_error(oss.str());
}
m_service_id = service_id->get_value();
m_type = type->get_value();
m_size = size->get_value() - 4;
}
size_t MessageHeader::get_size() const
{
return sizeof(ki::dml::UBYT) + sizeof(ki::dml::UBYT) +
sizeof(ki::dml::USHRT);
}
}
}
}

View File

@ -1,4 +1,5 @@
#include "ki/protocol/dml/MessageManager.h"
#include "ki/protocol/dml/MessageHeader.h"
#include "ki/protocol/exception.h"
#include "ki/dml/Record.h"
#include "ki/util/ValueBytes.h"
@ -155,7 +156,7 @@ namespace dml
return nullptr;
}
MessageBuilder &MessageManager::build_message(uint8_t service_id, uint8_t message_type) const
Message *MessageManager::create_message(uint8_t service_id, uint8_t message_type) const
{
auto *message_module = get_module(service_id);
if (!message_module)
@ -165,10 +166,10 @@ namespace dml
throw value_error(oss.str());
}
return message_module->build_message(message_type);
return message_module->create_message(message_type);
}
MessageBuilder& MessageManager::build_message(uint8_t service_id, const std::string& message_name) const
Message *MessageManager::create_message(uint8_t service_id, const std::string& message_name) const
{
auto *message_module = get_module(service_id);
if (!message_module)
@ -178,10 +179,10 @@ namespace dml
throw value_error(oss.str());
}
return message_module->build_message(message_name);
return message_module->create_message(message_name);
}
MessageBuilder& MessageManager::build_message(const std::string& protocol_type, uint8_t message_type) const
Message *MessageManager::create_message(const std::string& protocol_type, uint8_t message_type) const
{
auto *message_module = get_module(protocol_type);
if (!message_module)
@ -191,10 +192,10 @@ namespace dml
throw value_error(oss.str());
}
return message_module->build_message(message_type);
return message_module->create_message(message_type);
}
MessageBuilder& MessageManager::build_message(const std::string& protocol_type, const std::string& message_name) const
Message *MessageManager::create_message(const std::string& protocol_type, const std::string& message_name) const
{
auto *message_module = get_module(protocol_type);
if (!message_module)
@ -204,35 +205,47 @@ namespace dml
throw value_error(oss.str());
}
return message_module->build_message(message_name);
return message_module->create_message(message_name);
}
const Message *MessageManager::message_from_binary(std::istream& istream) const
{
// Read the message header and raw payload
Message *message = new Message();
// Read the message header
MessageHeader header;
try
{
message->read_from(istream);
header.read_from(istream);
}
catch (parse_error &e)
{
delete message;
return nullptr;
}
// Get the message module that uses the specified service id
auto *message_module = get_module(message->get_service_id());
auto *message_module = get_module(header.get_service_id());
if (!message_module)
return message;
return nullptr;
// Get the message template for this message type
auto *message_template = message_module->get_message_template(message->get_type());
auto *message_template = message_module->get_message_template(header.get_type());
if (!message_template)
return message;
return nullptr;
// Parse the raw payload with the template
message->use_template_record(message_template->get_record());
// Make sure that the size specified is enough to read this message
if (header.get_message_size() < message_template->get_record().get_size())
return nullptr;
// Create a new Message from the template
auto *message = new Message(message_template);
try
{
message->get_record()->read_from(istream);
}
catch (ki::dml::parse_error &e)
{
delete message;
return nullptr;
}
return message;
}
}

View File

@ -139,7 +139,7 @@ namespace dml
}
}
MessageBuilder& MessageModule::build_message(uint8_t message_type) const
Message *MessageModule::create_message(uint8_t message_type) const
{
auto *message_template = get_message_template(message_type);
if (!message_template)
@ -150,10 +150,10 @@ namespace dml
throw value_error(oss.str());
}
return message_template->build_message();
return message_template->create_message();
}
MessageBuilder &MessageModule::build_message(std::string message_name) const
Message *MessageModule::create_message(std::string message_name) const
{
auto *message_template = get_message_template(message_name);
if (!message_template)
@ -164,7 +164,7 @@ namespace dml
throw value_error(oss.str());
}
return message_template->build_message();
return message_template->create_message();
}
}
}

View File

@ -50,6 +50,32 @@ namespace dml
m_service_id = service_id;
}
std::string MessageTemplate::get_handler() const
{
const auto field = m_record->get_field<ki::dml::STR>("_MsgHandler");
if (field)
return field->get_value();
return m_name;
}
void MessageTemplate::set_handler(std::string handler)
{
m_record->add_field<ki::dml::STR>("_MsgHandler")->set_value(handler);
}
uint8_t MessageTemplate::get_access_level() const
{
const auto field = m_record->get_field<ki::dml::UBYT>("_MsgAccessLvl");
if (field)
return field->get_value();
return 0;
}
void MessageTemplate::set_access_level(uint8_t access_level)
{
m_record->add_field<ki::dml::UBYT>("_MsgAccessLvl")->set_value(access_level);
}
const ki::dml::Record& MessageTemplate::get_record() const
{
return *m_record;
@ -60,12 +86,9 @@ namespace dml
m_record = record;
}
MessageBuilder &MessageTemplate::build_message() const
Message *MessageTemplate::create_message() const
{
return MessageBuilder()
.set_message_type(m_type)
.set_service_id(m_service_id)
.use_template_record(*m_record);
return new Message(this);
}
}
}