Merge pull request #1 from Joshsora/messaging

Implements network and DML protocols
This commit is contained in:
Joshua Scott 2018-04-26 18:56:02 +01:00 committed by GitHub
commit cb4d89dad9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
36 changed files with 2877 additions and 0 deletions

View File

@ -20,6 +20,7 @@ target_include_directories(${PROJECT_NAME}
target_link_libraries(${PROJECT_NAME} RapidXML)
add_subdirectory("src/dml")
add_subdirectory("src/protocol")
option(KI_BUILD_EXAMPLES "Determines whether to build examples." ON)
if (KI_BUILD_EXAMPLES)

View File

@ -0,0 +1,64 @@
#include <ki/protocol/dml/MessageManager.h>
#include <ki/protocol/exception.h>
#include <iostream>
using namespace ki::protocol;
int main(int argc, char **argv)
{
// Get command-line arguments
if (argc < 3)
{
std::cout << "usage: example-dml-module.exe <module_file> <message_name>" << std::endl;
std::cout << "Prints out information for specified message." << std::endl;
return 1;
}
// Create a manager to load modules into
auto *message_manager = new dml::MessageManager();
const dml::MessageModule *message_module;
// Load the message module file
const std::string filepath = argv[1];
try
{
message_module = message_manager->load_module(filepath);
}
catch (value_error &e)
{
std::cout << "Failed to load message module.";
return 1;
}
// Print some information about the module itself
std::cout << "Service ID: " << (uint16_t)message_module->get_service_id() << std::endl;
std::cout << "Protocol Type: " << message_module->get_protocol_type() << std::endl;
// Get the message template from the module we just loaded
const std::string message_name = argv[2];
auto *message_template = message_module->get_message_template(message_name);
if (message_template)
{
std::cout << "Message Name: " << message_template->get_name() << std::endl;
std::cout << "Mesasge Type: " << (uint16_t)message_template->get_type() << std::endl;
// Print out the fields in the template record
std::cout << std::endl;
auto &record = message_template->get_record();
for (auto it = record.fields_begin();
it != record.fields_end(); ++it)
{
auto *field = *it;
if (field->is_transferable())
std::cout << field->get_type_name() << " " << field->get_name() << ";" << std::endl;
}
}
else
{
std::cout << "Could not find message with name: " << message_name << std::endl;
return 1;
}
// Exit successfully
return 0;
}

View File

@ -0,0 +1,38 @@
#pragma once
#include "../../util/Serializable.h"
#include <cstdint>
#include <iostream>
namespace ki
{
namespace protocol
{
namespace control
{
class ClientKeepAlive final : public util::Serializable
{
public:
ClientKeepAlive(uint16_t session_id = 0,
uint16_t milliseconds = 0, uint16_t minutes = 0);
virtual ~ClientKeepAlive() = default;
uint16_t get_session_id() const;
void set_session_id(uint16_t session_id);
uint16_t get_milliseconds() const;
void set_milliseconds(uint16_t milliseconds);
uint16_t get_minutes() const;
void set_minutes(uint16_t minutes);
void write_to(std::ostream &ostream) const override final;
void read_from(std::istream &istream) override final;
size_t get_size() const override final;
private:
uint16_t m_session_id;
uint16_t m_milliseconds;
uint16_t m_minutes;
};
}
}
}

View File

@ -0,0 +1,21 @@
#pragma once
#include <cstdint>
namespace ki
{
namespace protocol
{
namespace control
{
enum class Opcode : uint8_t
{
NONE = 0,
SESSION_OFFER = 0,
UDP_HELLO = 1,
KEEP_ALIVE = 3,
KEEP_ALIVE_RSP = 4,
SESSION_ACCEPT = 5
};
}
}
}

View File

@ -0,0 +1,29 @@
#pragma once
#include "../../util/Serializable.h"
#include <cstdint>
#include <iostream>
namespace ki
{
namespace protocol
{
namespace control
{
class ServerKeepAlive final : public util::Serializable
{
public:
ServerKeepAlive(uint32_t timestamp = 0);
virtual ~ServerKeepAlive() = default;
uint32_t get_timestamp() const;
void set_timestamp(uint32_t timestamp);
void write_to(std::ostream &ostream) const override final;
void read_from(std::istream &istream) override final;
size_t get_size() const override final;
private:
uint32_t m_timestamp;
};
}
}
}

View File

@ -0,0 +1,38 @@
#pragma once
#include "../../util/Serializable.h"
#include <iostream>
#include <cstdint>
namespace ki
{
namespace protocol
{
namespace control
{
class SessionAccept final : public util::Serializable
{
public:
SessionAccept(uint16_t session_id = 0,
int32_t timestamp = 0, uint32_t milliseconds = 0);
virtual ~SessionAccept() = default;
uint16_t get_session_id() const;
void set_session_id(uint16_t session_id);
int32_t get_timestamp() const;
void set_timestamp(int32_t timestamp);
uint32_t get_milliseconds() const;
void set_milliseconds(uint32_t milliseconds);
void write_to(std::ostream &ostream) const override final;
void read_from(std::istream &istream) override final;
size_t get_size() const override final;
private:
uint16_t m_session_id;
int32_t m_timestamp;
uint32_t m_milliseconds;
};
}
}
}

View File

@ -0,0 +1,38 @@
#pragma once
#include "../../util/Serializable.h"
#include <iostream>
#include <cstdint>
namespace ki
{
namespace protocol
{
namespace control
{
class SessionOffer final : public util::Serializable
{
public:
SessionOffer(uint16_t session_id = 0,
int32_t timestamp = 0, uint32_t milliseconds = 0);
virtual ~SessionOffer() = default;
uint16_t get_session_id() const;
void set_session_id(uint16_t session_id);
int32_t get_timestamp() const;
void set_timestamp(int32_t timestamp);
uint32_t get_milliseconds() const;
void set_milliseconds(uint32_t milliseconds);
void write_to(std::ostream &ostream) const override final;
void read_from(std::istream &istream) override final;
size_t get_size() const override final;
private:
uint16_t m_session_id;
int32_t m_timestamp;
uint32_t m_milliseconds;
};
}
}
}

View File

@ -0,0 +1,50 @@
#pragma once
#include "MessageHeader.h"
#include "../../util/Serializable.h"
#include "../../dml/Record.h"
#include <iostream>
namespace ki
{
namespace protocol
{
namespace dml
{
class MessageTemplate;
class Message final : public util::Serializable
{
public:
Message(const MessageTemplate *message_template = nullptr);
virtual ~Message();
const MessageTemplate *get_template() const;
void set_template(const MessageTemplate *message_template);
ki::dml::Record *get_record();
const ki::dml::Record *get_record() const;
ki::dml::FieldBase *get_field(std::string name);
const ki::dml::FieldBase *get_field(std::string name) const;
uint8_t get_service_id() const;
uint8_t get_type() const;
uint16_t get_message_size() const;
std::string get_handler() const;
uint8_t get_access_level() const;
void write_to(std::ostream &ostream) const override final;
void read_from(std::istream &istream) override final;
size_t get_size() const override final;
private:
const MessageTemplate *m_template;
ki::dml::Record *m_record;
// This is used to store raw data when a Message is
// constructed without a MessageTemplate.
MessageHeader m_header;
std::vector<char> m_raw_data;
};
}
}
}

View File

@ -0,0 +1,37 @@
#pragma once
#include "../../util/Serializable.h"
#include <iostream>
namespace ki
{
namespace protocol
{
namespace dml
{
class MessageHeader : public util::Serializable
{
public:
MessageHeader(uint8_t service_id = 0,
uint8_t type = 0, uint16_t size = 0);
virtual ~MessageHeader() = default;
uint8_t get_service_id() const;
void set_service_id(uint8_t service_id);
uint8_t get_type() const;
void set_type(uint8_t type);
uint16_t get_message_size() const;
void set_message_size(uint16_t size);
void write_to(std::ostream &ostream) const override final;
void read_from(std::istream &istream) override final;
size_t get_size() const override final;
private:
uint8_t m_service_id;
uint8_t m_type;
uint16_t m_size;
};
}
}
}

View File

@ -0,0 +1,44 @@
#pragma once
#include "Message.h"
#include "MessageModule.h"
#include "../../dml/Record.h"
#include <string>
namespace ki
{
namespace protocol
{
namespace dml
{
class MessageManager
{
public:
MessageManager() = default;
~MessageManager();
const MessageModule *load_module(std::string filepath);
const MessageModule *get_module(uint8_t service_id) const;
const MessageModule *get_module(const std::string &protocol_type) const;
Message *create_message(uint8_t service_id, uint8_t message_type) const;
Message *create_message(uint8_t service_id, const std::string &message_name) const;
Message *create_message(const std::string &protocol_type, uint8_t message_type) const;
Message *create_message(const std::string &protocol_type, const std::string &message_name) const;
/**
* If the DML message header cannot be read, then a nullptr
* is returned; otherwise, a valid Message pointer is always returned.
* However, that does not mean that the message itself is valid.
*
* To verify if the record was completely parsed, get_record
* should return a valid Record pointer, rather than nullptr.
*/
const Message *message_from_binary(std::istream &istream) const;
private:
MessageModuleList m_modules;
MessageModuleServiceIdMap m_service_id_map;
MessageModuleProtocolTypeMap m_protocol_type_map;
};
}
}
}

View File

@ -0,0 +1,55 @@
#pragma once
#include "Message.h"
#include "MessageTemplate.h"
#include <cstdint>
#include <string>
#include <vector>
#include <map>
namespace ki
{
namespace protocol
{
namespace dml
{
class MessageModule
{
public:
MessageModule(uint8_t service_id = 0, std::string protocol_type = "");
~MessageModule();
uint8_t get_service_id() const;
void set_service_id(uint8_t service_id);
std::string get_protocol_type() const;
void set_protocol_type(std::string protocol_type);
std::string get_protocol_desription() const;
void set_protocol_description(std::string protocol_description);
const MessageTemplate *add_message_template(std::string name,
ki::dml::Record *record, bool auto_sort = true);
const MessageTemplate *get_message_template(uint8_t type) const;
const MessageTemplate *get_message_template(std::string name) const;
void sort_lookup();
Message *create_message(uint8_t message_type) const;
Message *create_message(std::string message_name) const;
private:
uint8_t m_service_id;
std::string m_protocol_type;
std::string m_protocol_description;
uint8_t m_last_message_type;
std::vector<MessageTemplate *> m_templates;
std::map<uint8_t, MessageTemplate *> m_message_type_map;
std::map<std::string, MessageTemplate *> m_message_name_map;
};
typedef std::vector<MessageModule *> MessageModuleList;
typedef std::map<uint8_t, MessageModule *> MessageModuleServiceIdMap;
typedef std::map<std::string, MessageModule *> MessageModuleProtocolTypeMap;
}
}
}

View File

@ -0,0 +1,46 @@
#pragma once
#include "../../dml/Record.h"
#include "Message.h"
#include <string>
namespace ki
{
namespace protocol
{
namespace dml
{
class MessageTemplate
{
public:
MessageTemplate(std::string name, uint8_t type,
uint8_t service_id, ki::dml::Record *record);
~MessageTemplate();
std::string get_name() const;
void set_name(std::string name);
uint8_t get_type() const;
void set_type(uint8_t type);
uint8_t get_service_id() const;
void set_service_id(uint8_t service_id);
std::string get_handler() const;
void set_handler(std::string handler);
uint8_t get_access_level() const;
void set_access_level(uint8_t access_level);
const ki::dml::Record &get_record() const;
void set_record(ki::dml::Record *record);
Message *create_message() const;
private:
std::string m_name;
uint8_t m_type;
uint8_t m_service_id;
ki::dml::Record *m_record;
};
}
}
}

View File

@ -0,0 +1,64 @@
#pragma once
#include <stdexcept>
namespace ki
{
namespace protocol
{
class runtime_error : public std::runtime_error
{
public:
explicit runtime_error(std::string message) : std::runtime_error(message) {}
};
class parse_error : public runtime_error
{
public:
enum code
{
NONE,
INVALID_XML_DATA,
INVALID_HEADER_DATA,
INSUFFICIENT_MESSAGE_DATA,
INVALID_MESSAGE_DATA
};
explicit parse_error(std::string message, code error = code::NONE)
: runtime_error(message)
{
m_code = error;
}
code get_error_code() const { return m_code; }
private:
code m_code;
};
class value_error : public runtime_error
{
public:
enum code
{
NONE,
MISSING_FILE,
OVERWRITES_LOOKUP,
EXCEEDS_LIMIT,
DML_INVALID_SERVICE,
DML_INVALID_PROTOCOL_TYPE,
DML_INVALID_MESSAGE_TYPE,
DML_INVALID_MESSAGE_NAME
};
explicit value_error(std::string message, code error = code::NONE)
: runtime_error(message)
{
m_code = error;
}
code get_error_code() const { return m_code; }
private:
code m_code;
};
}
}

View File

@ -0,0 +1,35 @@
#pragma once
#include "ClientSession.h"
#include "DMLSession.h"
// Disable inheritance via dominance warning
#if _MSC_VER
#pragma warning(disable: 4250)
#endif
namespace ki
{
namespace protocol
{
namespace net
{
class ClientDMLSession : public ClientSession, public DMLSession
{
// Explicitly specify that we are intentionally inheritting
// via dominance.
using DMLSession::on_application_message;
using ClientSession::on_control_message;
using ClientSession::is_alive;
public:
ClientDMLSession(const uint16_t id, const dml::MessageManager &manager)
: Session(id), ClientSession(id), DMLSession(id, manager) {}
virtual ~ClientDMLSession() = default;
};
}
}
}
// Re-enable inheritance via dominance warning
#if _MSC_VER
#pragma warning(default: 4250)
#endif

View File

@ -0,0 +1,34 @@
#pragma once
#include "Session.h"
#define KI_SERVER_HEARTBEAT 60
namespace ki
{
namespace protocol
{
namespace net
{
/**
* Implements client-sided session logic.
*/
class ClientSession : public virtual Session
{
public:
explicit ClientSession(uint16_t id);
virtual ~ClientSession() = default;
void send_keep_alive();
bool is_alive() const override;
protected:
void on_connected();
virtual void on_established() {}
void on_control_message(const PacketHeader& header) override;
private:
void on_session_offer();
void on_keep_alive();
void on_keep_alive_response();
};
}
}
}

View File

@ -0,0 +1,44 @@
#pragma once
#include "Session.h"
#include "../dml/MessageManager.h"
namespace ki
{
namespace protocol
{
namespace net
{
enum class InvalidDMLMessageErrorCode
{
NONE,
UNKNOWN,
INVALID_HEADER_DATA,
INVALID_MESSAGE_DATA,
INVALID_SERVICE,
INVALID_MESSAGE_TYPE,
INSUFFICIENT_ACCESS
};
/**
* Implements an application protocol that uses the DML
* message system (as seen in Wizard101 and Pirate101).
*/
class DMLSession : public virtual Session
{
public:
DMLSession(uint16_t id, const dml::MessageManager &manager);
virtual ~DMLSession() = default;
const dml::MessageManager &get_manager() const;
void send_message(const dml::Message &message);
protected:
void on_application_message(const PacketHeader& header) override;
virtual void on_message(const dml::Message *message) {}
virtual void on_invalid_message(InvalidDMLMessageErrorCode error) {}
private:
const dml::MessageManager &m_manager;
};
}
}
}

View File

@ -0,0 +1,33 @@
#pragma once
#include "../../util/Serializable.h"
#include <cstdint>
#include <iostream>
namespace ki
{
namespace protocol
{
namespace net
{
class PacketHeader final : public util::Serializable
{
public:
PacketHeader(bool control = false, uint8_t opcode = 0);
virtual ~PacketHeader() = default;
bool is_control() const;
void set_control(bool control);
uint8_t get_opcode() const;
void set_opcode(uint8_t opcode);
void write_to(std::ostream &ostream) const override final;
void read_from(std::istream &istream) override final;
size_t get_size() const override final;
private:
bool m_control;
uint8_t m_opcode;
};
}
}
}

View File

@ -0,0 +1,35 @@
#pragma once
#include "ServerSession.h"
#include "DMLSession.h"
// Disable inheritance via dominance warning
#if _MSC_VER
#pragma warning(disable: 4250)
#endif
namespace ki
{
namespace protocol
{
namespace net
{
class ServerDMLSession : public ServerSession, public DMLSession
{
// Explicitly specify that we are intentionally inheritting
// via dominance.
using DMLSession::on_application_message;
using ServerSession::on_control_message;
using ServerSession::is_alive;
public:
ServerDMLSession(const uint16_t id, const dml::MessageManager &manager)
: Session(id), ServerSession(id), DMLSession(id, manager) {}
virtual ~ServerDMLSession() = default;
};
}
}
}
// Re-enable inheritance via dominance warning
#if _MSC_VER
#pragma warning(default: 4250)
#endif

View File

@ -0,0 +1,34 @@
#pragma once
#include "Session.h"
#define KI_CLIENT_HEARTBEAT 10
namespace ki
{
namespace protocol
{
namespace net
{
/**
* Implements server-sided session logic.
*/
class ServerSession : public virtual Session
{
public:
explicit ServerSession(uint16_t id);
virtual ~ServerSession() = default;
void send_keep_alive(uint32_t milliseconds_since_startup);
bool is_alive() const override;
protected:
void on_connected();
virtual void on_established() {}
void on_control_message(const PacketHeader& header) override;
private:
void on_session_accept();
void on_keep_alive();
void on_keep_alive_response();
};
}
}
}

View File

@ -0,0 +1,138 @@
#pragma once
#include "PacketHeader.h"
#include "../control/Opcode.h"
#include "../../util/Serializable.h"
#include <cstdint>
#include <sstream>
#include <chrono>
#include <type_traits>
#define KI_DEFAULT_MAXIMUM_RECEIVE_SIZE 0x2000
#define KI_START_SIGNAL 0xF00D
#define KI_CONNECTION_TIMEOUT 3
namespace ki
{
namespace protocol
{
namespace net
{
enum class SessionCloseErrorCode
{
NONE,
APPLICATION_ERROR,
INVALID_FRAMING_START_SIGNAL,
INVALID_FRAMING_SIZE_EXCEEDS_MAXIMUM,
UNHANDLED_CONTROL_MESSAGE,
UNHANDLED_APPLICATION_MESSAGE,
INVALID_MESSAGE,
SESSION_OFFER_TIMED_OUT,
SESSION_DIED
};
enum class ReceiveState
{
// Waiting for the 0xF00D start signal.
WAITING_FOR_START_SIGNAL,
// Waiting for the 2-byte length.
WAITING_FOR_LENGTH,
// Waiting for the packet data.
WAITING_FOR_PACKET
};
/**
* This class implements session and packet framing logic
* when sending and receiving data to/from an external
* source.
*/
class Session
{
public:
explicit Session(uint16_t id = 0);
virtual ~Session() = default;
uint16_t get_maximum_packet_size() const;
void set_maximum_packet_size(uint16_t maximum_packet_size);
uint16_t get_id() const;
bool is_established() const;
uint8_t get_access_level() const;
void set_access_level(uint8_t access_level);
uint16_t get_latency() const;
virtual bool is_alive() const = 0;
void send_packet(bool is_control, uint8_t opcode,
const util::Serializable &data);
protected:
/* Higher-level session members */
uint16_t m_id;
bool m_established;
uint8_t m_access_level;
/* Timing members */
std::chrono::steady_clock::time_point m_creation_time;
std::chrono::steady_clock::time_point m_connection_time;
std::chrono::steady_clock::time_point m_establish_time;
std::chrono::steady_clock::time_point m_last_received_heartbeat_time;
std::chrono::steady_clock::time_point m_last_sent_heartbeat_time;
bool m_waiting_for_keep_alive_response;
uint16_t m_latency;
// The packet data stream
std::stringstream m_data_stream;
/**
* Reads a serializable structure from the data stream.
*/
template <typename DataT>
DataT read_data()
{
static_assert(std::is_base_of<util::Serializable, DataT>::value,
"DataT must inherit Serializable.");
DataT data = DataT();
data.read_from(m_data_stream);
return data;
}
/**
* Frames raw data into a Packet, and transmits it.
*/
void send_data(const char *data, size_t size);
/**
* Process incoming raw data into Packets.
* Once a packet is read into the internal data
* stream, handle_packet_available is called.
*/
void process_data(const char *data, size_t size);
/* Event handlers */
virtual void on_invalid_packet() {}
virtual void on_control_message(const PacketHeader &header) {}
virtual void on_application_message(const PacketHeader &header) {}
/* Low-level socket methods */
virtual void send_packet_data(const char *data, const size_t size) = 0;
virtual void close(SessionCloseErrorCode error) = 0;
private:
/* Low-level networking members */
uint16_t m_maximum_packet_size;
ReceiveState m_receive_state;
uint16_t m_start_signal;
uint16_t m_incoming_packet_size;
uint8_t m_shift;
void on_packet_available();
};
}
}
}

View File

@ -0,0 +1,17 @@
target_sources(${PROJECT_NAME}
PRIVATE
${PROJECT_SOURCE_DIR}/src/protocol/control/ClientKeepAlive.cpp
${PROJECT_SOURCE_DIR}/src/protocol/control/ServerKeepAlive.cpp
${PROJECT_SOURCE_DIR}/src/protocol/control/SessionAccept.cpp
${PROJECT_SOURCE_DIR}/src/protocol/control/SessionOffer.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/Message.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageHeader.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageManager.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageModule.cpp
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageTemplate.cpp
${PROJECT_SOURCE_DIR}/src/protocol/net/ClientSession.cpp
${PROJECT_SOURCE_DIR}/src/protocol/net/DMLSession.cpp
${PROJECT_SOURCE_DIR}/src/protocol/net/PacketHeader.cpp
${PROJECT_SOURCE_DIR}/src/protocol/net/ServerSession.cpp
${PROJECT_SOURCE_DIR}/src/protocol/net/Session.cpp
)

View File

@ -0,0 +1,87 @@
#include "ki/protocol/control/ClientKeepAlive.h"
#include "ki/dml/Record.h"
#include "ki/protocol/exception.h"
namespace ki
{
namespace protocol
{
namespace control
{
ClientKeepAlive::ClientKeepAlive(const uint16_t session_id, const uint16_t milliseconds,
const uint16_t minutes)
{
m_session_id = session_id;
m_milliseconds = milliseconds;
m_minutes = minutes;
}
uint16_t ClientKeepAlive::get_session_id() const
{
return m_session_id;
}
void ClientKeepAlive::set_session_id(const uint16_t session_id)
{
m_session_id = session_id;
}
uint16_t ClientKeepAlive::get_milliseconds() const
{
return m_milliseconds;
}
void ClientKeepAlive::set_milliseconds(const uint16_t milliseconds)
{
m_milliseconds = milliseconds;
}
uint16_t ClientKeepAlive::get_minutes() const
{
return m_minutes;
}
void ClientKeepAlive::set_minutes(const uint16_t minutes)
{
m_minutes = minutes;
}
void ClientKeepAlive::write_to(std::ostream &ostream) const
{
dml::Record record;
record.add_field<dml::USHRT>("m_session_id")->set_value(m_session_id);
record.add_field<dml::USHRT>("m_milliseconds")->set_value(m_milliseconds);
record.add_field<dml::USHRT>("m_minutes")->set_value(m_minutes);
record.write_to(ostream);
}
void ClientKeepAlive::read_from(std::istream &istream)
{
dml::Record record;
auto *session_id = record.add_field<dml::USHRT>("m_session_id");
auto *milliseconds = record.add_field<dml::USHRT>("m_milliseconds");
auto *minutes = record.add_field<dml::USHRT>("m_minutes");
try
{
record.read_from(istream);
}
catch (dml::parse_error &e)
{
std::ostringstream oss;
oss << "Error reading ClientKeepAlive payload: " << e.what();
throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA);
}
m_session_id = session_id->get_value();
m_milliseconds = milliseconds->get_value();
m_minutes = minutes->get_value();
}
size_t ClientKeepAlive::get_size() const
{
return sizeof(dml::USHRT) + sizeof(dml::USHRT) +
sizeof(dml::USHRT);
}
}
}
}

View File

@ -0,0 +1,61 @@
#include "ki/protocol/control/ServerKeepAlive.h"
#include "ki/dml/Record.h"
#include "ki/protocol/exception.h"
#include <chrono>
namespace ki
{
namespace protocol
{
namespace control
{
ServerKeepAlive::ServerKeepAlive(const uint32_t timestamp)
{
m_timestamp = timestamp;
}
uint32_t ServerKeepAlive::get_timestamp() const
{
return m_timestamp;
}
void ServerKeepAlive::set_timestamp(const uint32_t timestamp)
{
m_timestamp = timestamp;
}
void ServerKeepAlive::write_to(std::ostream& ostream) const
{
dml::Record record;
record.add_field<dml::USHRT>("m_session_id");
record.add_field<dml::INT>("m_timestamp")->set_value(m_timestamp);
record.write_to(ostream);
}
void ServerKeepAlive::read_from(std::istream& istream)
{
dml::Record record;
record.add_field<dml::USHRT>("m_session_id");
auto *timestamp = record.add_field<dml::INT>("m_timestamp");
try
{
record.read_from(istream);
}
catch (dml::parse_error &e)
{
std::ostringstream oss;
oss << "Error reading ServerKeepAlive payload: " << e.what();
throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA);
}
m_timestamp = timestamp->get_value();
}
size_t ServerKeepAlive::get_size() const
{
return sizeof(dml::USHRT) + sizeof(dml::INT);
}
}
}
}

View File

@ -0,0 +1,92 @@
#include "ki/protocol/control/SessionAccept.h"
#include "ki/dml/Record.h"
#include "ki/protocol/exception.h"
namespace ki
{
namespace protocol
{
namespace control
{
SessionAccept::SessionAccept(const uint16_t session_id,
const int32_t timestamp, const uint32_t milliseconds)
{
m_session_id = session_id;
m_timestamp = timestamp;
m_milliseconds = milliseconds;
}
uint16_t SessionAccept::get_session_id() const
{
return m_session_id;
}
void SessionAccept::set_session_id(const uint16_t session_id)
{
m_session_id = session_id;
}
int32_t SessionAccept::get_timestamp() const
{
return m_timestamp;
}
void SessionAccept::set_timestamp(const int32_t timestamp)
{
m_timestamp = timestamp;
}
uint32_t SessionAccept::get_milliseconds() const
{
return m_milliseconds;
}
void SessionAccept::set_milliseconds(const uint32_t milliseconds)
{
m_milliseconds = milliseconds;
}
void SessionAccept::write_to(std::ostream& ostream) const
{
dml::Record record;
record.add_field<dml::USHRT>("unknown");
record.add_field<dml::UINT>("unknown2");
record.add_field<dml::INT>("m_timestamp")->set_value(m_timestamp);
record.add_field<dml::UINT>("m_milliseconds")->set_value(m_milliseconds);
record.add_field<dml::USHRT>("m_session_id")->set_value(m_session_id);
record.write_to(ostream);
}
void SessionAccept::read_from(std::istream& istream)
{
dml::Record record;
record.add_field<dml::USHRT>("unknown");
record.add_field<dml::UINT>("unknown2");
auto *timestamp = record.add_field<dml::INT>("m_timestamp");
auto *milliseconds = record.add_field<dml::UINT>("m_milliseconds");
auto *session_id = record.add_field<dml::USHRT>("m_session_id");
try
{
record.read_from(istream);
}
catch (dml::parse_error &e)
{
std::ostringstream oss;
oss << "Error reading SessionAccept payload: " << e.what();
throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA);
}
m_timestamp = timestamp->get_value();
m_milliseconds = milliseconds->get_value();
m_session_id = session_id->get_value();
}
size_t SessionAccept::get_size() const
{
return sizeof(dml::USHRT) + sizeof(dml::UINT) +
sizeof(dml::INT) + sizeof(dml::UINT) +
sizeof(dml::USHRT);
}
}
}
}

View File

@ -0,0 +1,89 @@
#include "ki/protocol/control/SessionOffer.h"
#include "ki/dml/Record.h"
#include "ki/protocol/exception.h"
namespace ki
{
namespace protocol
{
namespace control
{
SessionOffer::SessionOffer(const uint16_t session_id,
const int32_t timestamp, const uint32_t milliseconds)
{
m_session_id = session_id;
m_timestamp = timestamp;
m_milliseconds = milliseconds;
}
uint16_t SessionOffer::get_session_id() const
{
return m_session_id;
}
void SessionOffer::set_session_id(const uint16_t session_id)
{
m_session_id = session_id;
}
int32_t SessionOffer::get_timestamp() const
{
return m_timestamp;
}
void SessionOffer::set_timestamp(const int32_t timestamp)
{
m_timestamp = timestamp;
}
uint32_t SessionOffer::get_milliseconds() const
{
return m_milliseconds;
}
void SessionOffer::set_milliseconds(const uint32_t milliseconds)
{
m_milliseconds = milliseconds;
}
void SessionOffer::write_to(std::ostream& ostream) const
{
dml::Record record;
record.add_field<dml::USHRT>("m_session_id")->set_value(m_session_id);
record.add_field<dml::UINT>("unknown");
record.add_field<dml::INT>("m_timestamp")->set_value(m_timestamp);
record.add_field<dml::UINT>("m_milliseconds")->set_value(m_milliseconds);
record.write_to(ostream);
}
void SessionOffer::read_from(std::istream& istream)
{
dml::Record record;
auto *session_id = record.add_field<dml::USHRT>("m_session_id");
record.add_field<dml::UINT>("unknown");
auto *timestamp = record.add_field<dml::INT>("m_timestamp");
auto *milliseconds = record.add_field<dml::UINT>("m_milliseconds");
try
{
record.read_from(istream);
}
catch (dml::parse_error &e)
{
std::ostringstream oss;
oss << "Error reading SessionOffer payload: " << e.what();
throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA);
}
m_session_id = session_id->get_value();
m_timestamp = timestamp->get_value();
m_milliseconds = milliseconds->get_value();
}
size_t SessionOffer::get_size() const
{
return sizeof(dml::USHRT) + sizeof(dml::UINT) +
sizeof(dml::INT) + sizeof(dml::UINT);
}
}
}
}

View File

@ -0,0 +1,173 @@
#include "ki/protocol/dml/Message.h"
#include "ki/protocol/dml/MessageTemplate.h"
#include "ki/protocol/exception.h"
namespace ki
{
namespace protocol
{
namespace dml
{
Message::Message(const MessageTemplate *message_template)
{
m_template = message_template;
if (m_template)
m_record = new ki::dml::Record(m_template->get_record());
else
m_record = nullptr;
}
Message::~Message()
{
delete m_record;
}
const MessageTemplate *Message::get_template() const
{
return m_template;
}
void Message::set_template(const MessageTemplate *message_template)
{
m_template = message_template;
if (!m_template)
return;
m_record = new ki::dml::Record(message_template->get_record());
if (!m_raw_data.empty())
{
std::istringstream iss(std::string(m_raw_data.data(), m_raw_data.size()));
try
{
m_record->read_from(iss);
m_raw_data.clear();
}
catch (ki::dml::parse_error &e)
{
delete m_record;
m_template = nullptr;
m_record = nullptr;
std::ostringstream oss;
oss << "Error reading DML message payload: " << e.what();
throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA);
}
}
}
uint8_t Message::get_service_id() const
{
if (m_template)
return m_template->get_service_id();
return m_header.get_service_id();
}
uint8_t Message::get_type() const
{
if (m_template)
return m_template->get_type();
return m_header.get_type();
}
uint16_t Message::get_message_size() const
{
if (m_record)
return m_record->get_size();
return m_raw_data.size();
}
std::string Message::get_handler() const
{
if (m_template)
return m_template->get_handler();
return "";
}
uint8_t Message::get_access_level() const
{
if (m_template)
return m_template->get_access_level();
return 0;
}
ki::dml::Record *Message::get_record()
{
return m_record;
}
const ki::dml::Record *Message::get_record() const
{
return m_record;
}
ki::dml::FieldBase* Message::get_field(std::string name)
{
if (m_record)
return m_record->get_field(name);
return nullptr;
}
const ki::dml::FieldBase* Message::get_field(std::string name) const
{
if (m_record)
return m_record->get_field(name);
return nullptr;
}
void Message::write_to(std::ostream &ostream) const
{
// Write the header
if (m_template)
{
MessageHeader header(
get_service_id(), get_type(), get_message_size());
header.write_to(ostream);
}
else
m_header.write_to(ostream);
// Write the payload
if (m_record)
m_record->write_to(ostream);
else
ostream.write(m_raw_data.data(), m_raw_data.size());
}
void Message::read_from(std::istream &istream)
{
m_header.read_from(istream);
if (m_template)
{
// Check for mismatches between the header and template
if (m_header.get_service_id() != m_template->get_service_id())
throw value_error("ServiceID mismatch between MessageHeader and assigned template.",
value_error::DML_INVALID_SERVICE);
if (m_header.get_type() != m_template->get_type())
throw value_error("Message Type mismatch between MessageHeader and assigned template.",
value_error::DML_INVALID_MESSAGE_TYPE);
// Read the payload into the record
m_record->read_from(istream);
}
else
{
// We don't have a template for the record structure, so
// just read the raw data into a buffer.
const auto size = m_header.get_message_size();
m_raw_data.resize(size);
istream.read(m_raw_data.data(), size);
if (istream.fail())
throw parse_error("Not enough data was available to read DML message payload.",
parse_error::INSUFFICIENT_MESSAGE_DATA);
}
}
size_t Message::get_size() const
{
if (m_record)
return m_header.get_size() + m_record->get_size();
return 4 + m_raw_data.size();
}
}
}
}

View File

@ -0,0 +1,88 @@
#include "ki/protocol/dml/MessageHeader.h"
#include "ki/dml/Record.h"
#include "ki/protocol/exception.h"
namespace ki
{
namespace protocol
{
namespace dml
{
MessageHeader::MessageHeader(const uint8_t service_id,
const uint8_t type, const uint16_t size)
{
m_service_id = service_id;
m_type = type;
m_size = size;
}
uint8_t MessageHeader::get_service_id() const
{
return m_service_id;
}
void MessageHeader::set_service_id(const uint8_t service_id)
{
m_service_id = service_id;
}
uint8_t MessageHeader::get_type() const
{
return m_type;
}
void MessageHeader::set_type(const uint8_t type)
{
m_type = type;
}
uint16_t MessageHeader::get_message_size() const
{
return m_size;
}
void MessageHeader::set_message_size(const uint16_t size)
{
m_size = size;
}
void MessageHeader::write_to(std::ostream& ostream) const
{
ki::dml::Record record;
record.add_field<ki::dml::UBYT>("m_service_id")->set_value(m_service_id);
record.add_field<ki::dml::UBYT>("m_type")->set_value(m_type);
record.add_field<ki::dml::USHRT>("m_size")->set_value(m_size + 4);
record.write_to(ostream);
}
void MessageHeader::read_from(std::istream& istream)
{
ki::dml::Record record;
const auto *service_id = record.add_field<ki::dml::UBYT>("m_service_id");
const auto *type = record.add_field<ki::dml::UBYT>("m_type");
const auto size = record.add_field<ki::dml::USHRT>("m_size");
try
{
record.read_from(istream);
}
catch (ki::dml::parse_error &e)
{
std::ostringstream oss;
oss << "Error reading MessageHeader: " << e.what();
throw parse_error(oss.str(), parse_error::INVALID_HEADER_DATA);
}
m_service_id = service_id->get_value();
m_type = type->get_value();
m_size = size->get_value() - 4;
}
size_t MessageHeader::get_size() const
{
return sizeof(ki::dml::UBYT) + sizeof(ki::dml::UBYT) +
sizeof(ki::dml::USHRT);
}
}
}
}

View File

@ -0,0 +1,260 @@
#include "ki/protocol/dml/MessageManager.h"
#include "ki/protocol/dml/MessageHeader.h"
#include "ki/protocol/exception.h"
#include "ki/dml/Record.h"
#include "ki/util/ValueBytes.h"
#include <fstream>
#include <sstream>
#include <rapidxml.hpp>
namespace ki
{
namespace protocol
{
namespace dml
{
MessageManager::~MessageManager()
{
for (auto it = m_modules.begin();
it != m_modules.end(); ++it)
delete *it;
m_modules.clear();
m_service_id_map.clear();
m_protocol_type_map.clear();
}
const MessageModule *MessageManager::load_module(std::string filepath)
{
// Open the file
std::ifstream ifs(filepath, std::ios::ate);
if (!ifs.is_open())
{
std::ostringstream oss;
oss << "Could not open file: " << filepath;
throw value_error(oss.str(), value_error::MISSING_FILE);
}
// Load contents into memory
size_t size = ifs.tellg();
ifs.seekg(0, std::ios::beg);
char *data = new char[size + 1] { 0 };
ifs.read(data, size);
// Parse the contents
rapidxml::xml_document<> doc;
try
{
doc.parse<0>(data);
}
catch (rapidxml::parse_error &e)
{
delete[] data;
std::ostringstream oss;
oss << "Failed to parse: " << filepath;
throw parse_error(oss.str(), parse_error::INVALID_XML_DATA);
}
// It's safe to allocate the module we're working on now
auto *message_module = new MessageModule();
// Get the root node and iterate through children
// Each child is a MessageTemplate
auto *root = doc.first_node();
for (auto *node = root->first_node();
node; node = node->next_sibling())
{
// Parse the record node inside this node
auto *record_node = node->first_node();
if (!record_node)
continue;
auto *record = new ki::dml::Record();
record->from_xml(record_node);
// The message name is initially based on the element name
const std::string message_name = node->name();
if (message_name == "_ProtocolInfo")
{
auto *service_id_field = record->get_field<ki::dml::UBYT>("ServiceID");
auto *type_field = record->get_field<ki::dml::STR>("ProtocolType");
auto *description_field = record->get_field<ki::dml::STR>("ProtocolDescription");
// Set the module metadata from this template
if (service_id_field)
message_module->set_service_id(service_id_field->get_value());
if (type_field)
message_module->set_protocol_type(type_field->get_value());
if (description_field)
message_module->set_protocol_description(description_field->get_value());
}
else
{
// Only do sorting after we've reached the final message
// This only affects modules that aren't ordered with _MsgOrder.
const bool auto_sort = node->next_sibling() == nullptr;
// The template will use the record itself to figure out name and type;
// we only give the XML data incase the record doesn't have it defined.
auto *message_template = message_module->add_message_template(message_name, record, auto_sort);
if (!message_template)
{
delete[] data;
delete message_module;
delete record;
std::ostringstream oss;
oss << "Failed to create message template for ";
oss << message_name;
throw value_error(oss.str());
}
}
}
// Make sure we aren't overwriting another module
if (m_service_id_map.count(message_module->get_service_id()) == 1)
{
delete[] data;
delete message_module;
std::ostringstream oss;
oss << "Message Module has already been loaded with Service ID ";
oss << (uint16_t)message_module->get_service_id();
throw value_error(oss.str(), value_error::OVERWRITES_LOOKUP);
}
if (m_protocol_type_map.count(message_module->get_protocol_type()) == 1)
{
delete[] data;
delete message_module;
std::ostringstream oss;
oss << "Message Module has already been loaded with Protocol Type ";
oss << message_module->get_protocol_type();
throw value_error(oss.str(), value_error::OVERWRITES_LOOKUP);
}
// Add it to our maps
m_modules.push_back(message_module);
m_service_id_map.insert({ message_module->get_service_id(), message_module });
m_protocol_type_map.insert({ message_module->get_protocol_type(), message_module });
delete[] data;
return message_module;
}
const MessageModule *MessageManager::get_module(uint8_t service_id) const
{
if (m_service_id_map.count(service_id) == 1)
return m_service_id_map.at(service_id);
return nullptr;
}
const MessageModule *MessageManager::get_module(const std::string &protocol_type) const
{
if (m_protocol_type_map.count(protocol_type) == 1)
return m_protocol_type_map.at(protocol_type);
return nullptr;
}
Message *MessageManager::create_message(uint8_t service_id, uint8_t message_type) const
{
auto *message_module = get_module(service_id);
if (!message_module)
{
std::ostringstream oss;
oss << "No service exists with id: " << (uint16_t)service_id;
throw value_error(oss.str(), value_error::DML_INVALID_SERVICE);
}
return message_module->create_message(message_type);
}
Message *MessageManager::create_message(uint8_t service_id, const std::string& message_name) const
{
auto *message_module = get_module(service_id);
if (!message_module)
{
std::ostringstream oss;
oss << "No service exists with id: " << (uint16_t)service_id;
throw value_error(oss.str(), value_error::DML_INVALID_SERVICE);
}
return message_module->create_message(message_name);
}
Message *MessageManager::create_message(const std::string& protocol_type, uint8_t message_type) const
{
auto *message_module = get_module(protocol_type);
if (!message_module)
{
std::ostringstream oss;
oss << "No service exists with protocol type: " << protocol_type;
throw value_error(oss.str(), value_error::DML_INVALID_PROTOCOL_TYPE);
}
return message_module->create_message(message_type);
}
Message *MessageManager::create_message(const std::string& protocol_type, const std::string& message_name) const
{
auto *message_module = get_module(protocol_type);
if (!message_module)
{
std::ostringstream oss;
oss << "No service exists with protocol type: " << protocol_type;
throw value_error(oss.str(), value_error::DML_INVALID_PROTOCOL_TYPE);
}
return message_module->create_message(message_name);
}
const Message *MessageManager::message_from_binary(std::istream& istream) const
{
// Read the message header
MessageHeader header;
header.read_from(istream);
// Get the message module that uses the specified service id
auto *message_module = get_module(header.get_service_id());
if (!message_module)
{
std::ostringstream oss;
oss << "No service exists with id: " << (uint16_t)header.get_service_id();
throw value_error(oss.str(), value_error::DML_INVALID_SERVICE);
}
// Get the message template for this message type
auto *message_template = message_module->get_message_template(header.get_type());
if (!message_template)
{
std::ostringstream oss;
oss << "No message exists with type: " << (uint16_t)header.get_service_id();
oss << "(service=" << message_module->get_protocol_type() << ")";
throw value_error(oss.str(), value_error::DML_INVALID_MESSAGE_TYPE);
}
// Make sure that the size specified is enough to read this message
if (header.get_message_size() < message_template->get_record().get_size())
{
std::ostringstream oss;
oss << "No message exists with type: " << (uint16_t)header.get_service_id();
oss << "(service=" << message_module->get_protocol_type() << ")";
throw value_error(oss.str(), value_error::DML_INVALID_MESSAGE_TYPE);
}
// Create a new Message from the template
auto *message = new Message(message_template);
try
{
message->get_record()->read_from(istream);
}
catch (ki::dml::parse_error &e)
{
delete message;
throw parse_error("Failed to read DML message payload.", parse_error::INVALID_MESSAGE_DATA);
}
return message;
}
}
}
}

View File

@ -0,0 +1,171 @@
#include "ki/protocol/dml/MessageModule.h"
#include "ki/protocol/exception.h"
#include <sstream>
namespace ki
{
namespace protocol
{
namespace dml
{
MessageModule::MessageModule(uint8_t service_id, std::string protocol_type)
{
m_service_id = service_id;
m_protocol_type = protocol_type;
m_protocol_description = "";
m_last_message_type = 0;
}
MessageModule::~MessageModule()
{
for (auto it = m_templates.begin();
it != m_templates.end(); ++it)
delete *it;
m_message_type_map.clear();
m_message_name_map.clear();
}
uint8_t MessageModule::get_service_id() const
{
return m_service_id;
}
void MessageModule::set_service_id(uint8_t service_id)
{
m_service_id = service_id;
}
std::string MessageModule::get_protocol_type() const
{
return m_protocol_type;
}
void MessageModule::set_protocol_type(std::string protocol_type)
{
m_protocol_type = protocol_type;
}
std::string MessageModule::get_protocol_desription() const
{
return m_protocol_description;
}
void MessageModule::set_protocol_description(std::string protocol_description)
{
m_protocol_description = protocol_description;
}
const MessageTemplate *MessageModule::add_message_template(std::string name,
ki::dml::Record *record, bool auto_sort)
{
if (!record)
return nullptr;
// If the field exists, get the name from the record rather than the XML
auto *name_field = record->get_field<ki::dml::STR>("_MsgName");
if (name_field)
name = name_field->get_value();
// Do we already have a message template with this name?
if (m_message_name_map.count(name) == 1)
return m_message_name_map.at(name);
// Message type is based on the _MsgOrder field if it's present
// Otherwise it's based on the alphabetical order of template names
uint8_t message_type = 0;
auto *order_field = record->get_field<ki::dml::UBYT>("_MsgOrder");
if (order_field)
{
message_type = order_field->get_value();
// Don't allow message type to be 0
if (message_type == 0)
return nullptr;
// Do we already have a template with this type?
if (m_message_type_map.count(message_type) == 1)
return nullptr;
}
// Create the message template and add it to our lookups
auto *message_template = new MessageTemplate(name, message_type, m_service_id, record);
m_templates.push_back(message_template);
m_message_name_map.insert({ name, message_template });
// Is this module ordered?
if (message_type != 0)
m_message_type_map.insert({ message_type, message_template });
else if (auto_sort)
sort_lookup();
return message_template;
}
const MessageTemplate *MessageModule::get_message_template(uint8_t type) const
{
if (m_message_type_map.count(type) == 1)
return m_message_type_map.at(type);
return nullptr;
}
const MessageTemplate *MessageModule::get_message_template(std::string name) const
{
if (m_message_name_map.count(name) == 1)
return m_message_name_map.at(name);
return nullptr;
}
void MessageModule::sort_lookup()
{
uint8_t message_type = 1;
// First, clear the message type map since we're going to be
// moving everything around
m_message_type_map.clear();
// Iterating over a map with std::string as the key
// is guaranteed to be in alphabetical order
for (auto it = m_message_name_map.begin();
it != m_message_name_map.end(); ++it)
{
auto *message_template = it->second;
message_template->set_type(message_type);
m_message_type_map.insert({ message_type, message_template });
message_type++;
// Make sure we haven't overflowed
if (message_type == 0)
throw value_error("Module has more than 254 messages.", value_error::EXCEEDS_LIMIT);
}
}
Message *MessageModule::create_message(uint8_t message_type) const
{
auto *message_template = get_message_template(message_type);
if (!message_template)
{
std::ostringstream oss;
oss << "No message exists with type: " << message_type;
oss << "(service=" << m_protocol_type << ")";
throw value_error(oss.str(), value_error::DML_INVALID_MESSAGE_TYPE);
}
return message_template->create_message();
}
Message *MessageModule::create_message(std::string message_name) const
{
auto *message_template = get_message_template(message_name);
if (!message_template)
{
std::ostringstream oss;
oss << "No message exists with name: " << message_name;
oss << "(service=" << m_protocol_type << ")";
throw value_error(oss.str(), value_error::DML_INVALID_MESSAGE_NAME);
}
return message_template->create_message();
}
}
}
}

View File

@ -0,0 +1,95 @@
#include "ki/protocol/dml/MessageTemplate.h"
namespace ki
{
namespace protocol
{
namespace dml
{
MessageTemplate::MessageTemplate(std::string name, uint8_t type,
uint8_t service_id, ki::dml::Record* record)
{
m_name = name;
m_type = type;
m_service_id = service_id;
m_record = record;
}
MessageTemplate::~MessageTemplate()
{
delete m_record;
}
std::string MessageTemplate::get_name() const
{
return m_name;
}
void MessageTemplate::set_name(std::string name)
{
m_name = name;
}
uint8_t MessageTemplate::get_type() const
{
return m_type;
}
void MessageTemplate::set_type(uint8_t type)
{
m_type = type;
}
uint8_t MessageTemplate::get_service_id() const
{
return m_service_id;
}
void MessageTemplate::set_service_id(uint8_t service_id)
{
m_service_id = service_id;
}
std::string MessageTemplate::get_handler() const
{
const auto field = m_record->get_field<ki::dml::STR>("_MsgHandler");
if (field)
return field->get_value();
return m_name;
}
void MessageTemplate::set_handler(std::string handler)
{
m_record->add_field<ki::dml::STR>("_MsgHandler")->set_value(handler);
}
uint8_t MessageTemplate::get_access_level() const
{
const auto field = m_record->get_field<ki::dml::UBYT>("_MsgAccessLvl");
if (field)
return field->get_value();
return 0;
}
void MessageTemplate::set_access_level(uint8_t access_level)
{
m_record->add_field<ki::dml::UBYT>("_MsgAccessLvl")->set_value(access_level);
}
const ki::dml::Record& MessageTemplate::get_record() const
{
return *m_record;
}
void MessageTemplate::set_record(ki::dml::Record* record)
{
m_record = record;
}
Message *MessageTemplate::create_message() const
{
return new Message(this);
}
}
}
}

View File

@ -0,0 +1,173 @@
#include "ki/protocol/net/ClientSession.h"
#include "ki/protocol/control/SessionOffer.h"
#include "ki/protocol/control/SessionAccept.h"
#include "ki/protocol/control/ClientKeepAlive.h"
#include "ki/protocol/control/ServerKeepAlive.h"
#include "ki/protocol/exception.h"
namespace ki
{
namespace protocol
{
namespace net
{
ClientSession::ClientSession(const uint16_t id)
: Session(id) {}
void ClientSession::send_keep_alive()
{
// Don't send a keep alive if we're waiting for a response
if (m_waiting_for_keep_alive_response)
return;
m_waiting_for_keep_alive_response = true;
// Work out how many minutes have been since the establish time, and
// how many milliseconds we are in to the current minute.
const auto time_since_establish = std::chrono::steady_clock::now() - m_establish_time;
const auto minutes = std::chrono::duration_cast<std::chrono::minutes>(time_since_establish);
const auto milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(
time_since_establish - minutes
).count();
// Send a KEEP_ALIVE packet
control::ClientKeepAlive keep_alive(m_id, milliseconds, minutes.count());
send_packet(true, (uint8_t)control::Opcode::KEEP_ALIVE, keep_alive);
m_last_sent_heartbeat_time = std::chrono::steady_clock::now();
}
bool ClientSession::is_alive() const
{
// If the session isn't established yet, use the time of
// creation to decide whether this session is alive.
if (!m_established)
return std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - m_creation_time
).count() <= (KI_CONNECTION_TIMEOUT * 2);
// Otherwise, use the last time we received a heartbeat.
return std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - m_last_received_heartbeat_time
).count() <= (KI_SERVER_HEARTBEAT * 2);
}
void ClientSession::on_connected()
{
m_connection_time = std::chrono::steady_clock::now();
}
void ClientSession::on_control_message(const PacketHeader& header)
{
switch ((control::Opcode)header.get_opcode())
{
case control::Opcode::SESSION_OFFER:
on_session_offer();
break;
case control::Opcode::KEEP_ALIVE:
on_keep_alive();
break;
case control::Opcode::KEEP_ALIVE_RSP:
on_keep_alive_response();
break;
default:
close(SessionCloseErrorCode::UNHANDLED_CONTROL_MESSAGE);
break;
}
}
void ClientSession::on_session_offer()
{
// Read the payload data into a structure
control::SessionOffer offer;
try
{
offer = read_data<control::SessionOffer>();
}
catch (parse_error &e)
{
// The SESSION_OFFER wasn't valid...
// Close the session
close(SessionCloseErrorCode::INVALID_MESSAGE);
return;
}
// Should this session have already timed out?
if (std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - m_connection_time
).count() > KI_CONNECTION_TIMEOUT)
{
close(SessionCloseErrorCode::SESSION_OFFER_TIMED_OUT);
return;
}
// Work out the current timestamp and how many milliseconds
// have elapsed in the current second.
auto now = std::chrono::system_clock::now();
const auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(
now.time_since_epoch()
).count();
const auto milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch()
).count() - (timestamp * 1000);
// Accept the session
m_id = offer.get_session_id();
control::SessionAccept accept(m_id, timestamp, milliseconds);
send_packet(true, (uint8_t)control::Opcode::SESSION_ACCEPT, accept);
// The session is successfully established
m_established = true;
m_establish_time = std::chrono::steady_clock::now();
m_last_received_heartbeat_time = m_establish_time;
on_established();
}
void ClientSession::on_keep_alive()
{
// Read the payload data into a structure
control::ServerKeepAlive keep_alive;
try
{
keep_alive = read_data<control::ServerKeepAlive>();
}
catch (parse_error &e)
{
// The KEEP_ALIVE wasn't valid...
// Close the session
close(SessionCloseErrorCode::INVALID_MESSAGE);
return;
}
// Send the response
m_last_received_heartbeat_time = std::chrono::steady_clock::now();
send_packet(true, (uint8_t)control::Opcode::KEEP_ALIVE_RSP, keep_alive);
}
void ClientSession::on_keep_alive_response()
{
// Read the payload data into a structure
try
{
// We don't actually need the data inside, but
// read it to check if the structure is right.
read_data<control::ClientKeepAlive>();
}
catch (parse_error &e)
{
// The KEEP_ALIVE_RSP wasn't valid...
// Close the session
close(SessionCloseErrorCode::INVALID_MESSAGE);
return;
}
// Calculate latency and allow for KEEP_ALIVE packets to be sent again
m_latency = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - m_last_sent_heartbeat_time
).count();
m_waiting_for_keep_alive_response = false;
}
}
}
}

View File

@ -0,0 +1,77 @@
#include "ki/protocol/net/DMLSession.h"
#include "ki/protocol/exception.h"
namespace ki
{
namespace protocol
{
namespace net
{
DMLSession::DMLSession(const uint16_t id, const dml::MessageManager& manager)
: Session(id), m_manager(manager) {}
const dml::MessageManager& DMLSession::get_manager() const
{
return m_manager;
}
void DMLSession::send_message(const dml::Message& message)
{
send_packet(false, 0, message);
}
void DMLSession::on_application_message(const PacketHeader& header)
{
// Attempt to create a Message instance from the data in the stream
auto error_code = InvalidDMLMessageErrorCode::NONE;
const dml::Message *message = nullptr;
try
{
message = m_manager.message_from_binary(m_data_stream);
}
catch (parse_error &e)
{
switch (e.get_error_code())
{
case parse_error::INVALID_HEADER_DATA:
error_code = InvalidDMLMessageErrorCode::INVALID_HEADER_DATA;
break;
case parse_error::INSUFFICIENT_MESSAGE_DATA:
case parse_error::INVALID_MESSAGE_DATA:
error_code = InvalidDMLMessageErrorCode::INVALID_MESSAGE_DATA;
break;
default:
error_code = InvalidDMLMessageErrorCode::UNKNOWN;
}
}
catch (value_error &e)
{
switch (e.get_error_code())
{
case value_error::DML_INVALID_SERVICE:
error_code = InvalidDMLMessageErrorCode::INVALID_SERVICE;
break;
case value_error::DML_INVALID_MESSAGE_TYPE:
error_code = InvalidDMLMessageErrorCode::INVALID_MESSAGE_TYPE;
break;
default:
error_code = InvalidDMLMessageErrorCode::UNKNOWN;
}
}
if (!message)
{
on_invalid_message(error_code);
return;
}
// Are we sufficiently authenticated to handle this message?
if (get_access_level() >= message->get_access_level())
on_message(message);
else
on_invalid_message(InvalidDMLMessageErrorCode::INSUFFICIENT_ACCESS);
delete message;
}
}
}
}

View File

@ -0,0 +1,67 @@
#include "ki/protocol/net/PacketHeader.h"
#include "ki/protocol/exception.h"
#include <sstream>
namespace ki
{
namespace protocol
{
namespace net
{
PacketHeader::PacketHeader(const bool control, const uint8_t opcode)
{
m_control = control;
m_opcode = opcode;
}
bool PacketHeader::is_control() const
{
return m_control;
}
void PacketHeader::set_control(const bool control)
{
m_control = control;
}
uint8_t PacketHeader::get_opcode() const
{
return m_opcode;
}
void PacketHeader::set_opcode(const uint8_t opcode)
{
m_opcode = opcode;
}
void PacketHeader::write_to(std::ostream& ostream) const
{
ostream.put(m_control);
ostream.put(m_opcode);
ostream.put(0);
ostream.put(0);
}
void PacketHeader::read_from(std::istream& istream)
{
m_control = istream.get() >= 1;
if (istream.fail())
throw parse_error("Not enough data was available to read packet header. (m_control)",
parse_error::INVALID_HEADER_DATA);
m_opcode = istream.get();
if (istream.fail())
throw parse_error("Not enough data was available to read packet header. (m_opcode)",
parse_error::INVALID_HEADER_DATA);
istream.ignore(2);
if (istream.eof())
throw parse_error("Not enough data was available to read packet header. (ignored bytes)",
parse_error::INVALID_HEADER_DATA);
}
size_t PacketHeader::get_size() const
{
return 4;
}
}
}
}

View File

@ -0,0 +1,171 @@
#include "ki/protocol/net/ServerSession.h"
#include "ki/protocol/control/SessionOffer.h"
#include "ki/protocol/control/SessionAccept.h"
#include "ki/protocol/control/ClientKeepAlive.h"
#include "ki/protocol/control/ServerKeepAlive.h"
#include "ki/protocol/exception.h"
namespace ki
{
namespace protocol
{
namespace net
{
ServerSession::ServerSession(const uint16_t id)
: Session(id) {}
void ServerSession::send_keep_alive(const uint32_t milliseconds_since_startup)
{
// Don't send a keep alive if we're waiting for a response
if (m_waiting_for_keep_alive_response)
return;
m_waiting_for_keep_alive_response = true;
// Send a KEEP_ALIVE packet
const control::ServerKeepAlive keep_alive(milliseconds_since_startup);
send_packet(true, (uint8_t)control::Opcode::KEEP_ALIVE, keep_alive);
m_last_sent_heartbeat_time = std::chrono::steady_clock::now();
}
bool ServerSession::is_alive() const
{
// If the session isn't established yet, use the time of
// creation to decide whether this session is alive.
if (!m_established)
return std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - m_creation_time
).count() <= (KI_CONNECTION_TIMEOUT * 2);
// Otherwise, use the last time we received a heartbeat.
return std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - m_last_received_heartbeat_time
).count() <= (KI_CLIENT_HEARTBEAT * 2);
}
void ServerSession::on_connected()
{
m_connection_time = std::chrono::steady_clock::now();
// Work out the current timestamp and how many milliseconds
// have elapsed in the current second.
auto now = std::chrono::system_clock::now();
const auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(
now.time_since_epoch()
).count();
const auto milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(
now.time_since_epoch()
).count() - (timestamp * 1000);
// Send a SESSION_OFFER packet to the client
const control::SessionOffer offer(m_id, timestamp, milliseconds);
send_packet(true, (uint8_t)control::Opcode::SESSION_OFFER, offer);
}
void ServerSession::on_control_message(const PacketHeader& header)
{
switch ((control::Opcode)header.get_opcode())
{
case control::Opcode::SESSION_ACCEPT:
on_session_accept();
break;
case control::Opcode::KEEP_ALIVE:
on_keep_alive();
break;
case control::Opcode::KEEP_ALIVE_RSP:
on_keep_alive_response();
break;
default:
close(SessionCloseErrorCode::UNHANDLED_CONTROL_MESSAGE);
break;
}
}
void ServerSession::on_session_accept()
{
// Read the payload data into a structure
control::SessionAccept accept;
try
{
accept = read_data<control::SessionAccept>();
}
catch (parse_error &e)
{
// The SESSION_ACCEPT wasn't valid...
// Close the session
close(SessionCloseErrorCode::INVALID_MESSAGE);
return;
}
// Should this session have already timed out?
if (std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - m_connection_time
).count() > KI_CONNECTION_TIMEOUT)
{
close(SessionCloseErrorCode::SESSION_OFFER_TIMED_OUT);
return;
}
// Make sure they're accepting this session
if (accept.get_session_id() != m_id)
{
close(SessionCloseErrorCode::INVALID_MESSAGE);
return;
}
// The session is successfully established
m_established = true;
m_establish_time = std::chrono::steady_clock::now();
m_last_received_heartbeat_time = m_establish_time;
on_established();
}
void ServerSession::on_keep_alive()
{
// Read the payload data into a structure
control::ClientKeepAlive keep_alive;
try
{
keep_alive = read_data<control::ClientKeepAlive>();
}
catch (parse_error &e)
{
// The KEEP_ALIVE wasn't valid...
// Close the session
close(SessionCloseErrorCode::INVALID_MESSAGE);
return;
}
// Send the response
m_last_received_heartbeat_time = std::chrono::steady_clock::now();
send_packet(true, (uint8_t)control::Opcode::KEEP_ALIVE_RSP, keep_alive);
}
void ServerSession::on_keep_alive_response()
{
// Read the payload data into a structure
try
{
// We don't actually need the data inside, but
// read it to check if the structure is right.
read_data<control::ServerKeepAlive>();
}
catch (parse_error &e)
{
// The KEEP_ALIVE_RSP wasn't valid...
// Close the session
close(SessionCloseErrorCode::INVALID_MESSAGE);
return;
}
// Calculate latency and allow for KEEP_ALIVE packets to be sent again
m_latency = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::steady_clock::now() - m_last_sent_heartbeat_time
).count();
m_waiting_for_keep_alive_response = false;
}
}
}
}

View File

@ -0,0 +1,190 @@
#include "ki/protocol/net/Session.h"
#include "ki/protocol/exception.h"
#include <cstring>
namespace ki
{
namespace protocol
{
namespace net
{
Session::Session(const uint16_t id)
{
m_id = id;
m_established = false;
m_access_level = 0;
m_latency = 0;
m_creation_time = std::chrono::steady_clock::now();
m_waiting_for_keep_alive_response = false;
m_maximum_packet_size = KI_DEFAULT_MAXIMUM_RECEIVE_SIZE;
m_receive_state = ReceiveState::WAITING_FOR_START_SIGNAL;
m_start_signal = 0;
m_incoming_packet_size = 0;
m_shift = 0;
}
uint16_t Session::get_maximum_packet_size() const
{
return m_maximum_packet_size;
}
void Session::set_maximum_packet_size(const uint16_t maximum_packet_size)
{
m_maximum_packet_size = maximum_packet_size;
}
uint16_t Session::get_id() const
{
return m_id;
}
bool Session::is_established() const
{
return m_established;
}
uint8_t Session::get_access_level() const
{
return m_access_level;
}
void Session::set_access_level(const uint8_t access_level)
{
m_access_level = access_level;
}
uint16_t Session::get_latency() const
{
return m_latency;
}
void Session::send_packet(const bool is_control, const uint8_t opcode,
const util::Serializable& data)
{
std::ostringstream ss;
PacketHeader header(is_control, opcode);
header.write_to(ss);
data.write_to(ss);
const auto buffer = ss.str();
send_data(buffer.c_str(), buffer.length());
}
void Session::send_data(const char* data, const size_t size)
{
// Allocate the entire buffer
char *packet_data = new char[size + 4];
// Add the frame header
((uint16_t *)packet_data)[0] = KI_START_SIGNAL;
((uint16_t *)packet_data)[1] = size;
// Copy the payload into the buffer and send it
std::memcpy(&packet_data[4], data, size);
send_packet_data(packet_data, size + 4);
delete[] packet_data;
}
void Session::process_data(const char *data, const size_t size)
{
size_t position = 0;
while (position < size)
{
switch (m_receive_state)
{
case ReceiveState::WAITING_FOR_START_SIGNAL:
m_start_signal |= ((uint8_t)data[position] << m_shift);
if (m_shift == 0)
m_shift = 8;
else
{
// If the start signal isn't correct, we've either
// gotten out of sync, or they are not framing packets
// correctly.
if (m_start_signal != KI_START_SIGNAL)
{
close(SessionCloseErrorCode::INVALID_FRAMING_START_SIGNAL);
return;
}
// Reset the shift and incoming packet size
m_shift = 0;
m_incoming_packet_size = 0;
m_receive_state = ReceiveState::WAITING_FOR_LENGTH;
}
position++;
break;
case ReceiveState::WAITING_FOR_LENGTH:
m_incoming_packet_size |= ((uint8_t)data[position] << m_shift);
if (m_shift == 0)
m_shift = 8;
else
{
// If the incoming packet is larger than we are accepting
// stop processing data.
if (m_incoming_packet_size > m_maximum_packet_size)
{
close(SessionCloseErrorCode::INVALID_FRAMING_SIZE_EXCEEDS_MAXIMUM);
return;
}
// Reset read and write positions
m_data_stream.seekp(0, std::ios::beg);
m_data_stream.seekg(0, std::ios::beg);
m_receive_state = ReceiveState::WAITING_FOR_PACKET;
}
position++;
break;
case ReceiveState::WAITING_FOR_PACKET:
// Work out how much data we should read into our stream
const size_t data_available = (size - position);
const size_t read_size = (data_available >= m_incoming_packet_size) ?
m_incoming_packet_size : data_available;
// Write the data to the data stream
m_data_stream.write(&data[position], read_size);
position += read_size;
m_incoming_packet_size -= read_size;
// Have we received the entire packet?
if (m_incoming_packet_size == 0)
{
on_packet_available();
// Reset the shift and start signal
m_shift = 0;
m_start_signal = 0;
m_receive_state = ReceiveState::WAITING_FOR_START_SIGNAL;
}
break;
}
}
}
void Session::on_packet_available()
{
// Read the packet header
PacketHeader header;
try
{
header.read_from(m_data_stream);
}
catch (parse_error &e)
{
on_invalid_packet();
return;
}
// Hand off to the right handler based on
// whether this is a control packet or not
if (header.is_control())
on_control_message(header);
else
on_application_message(header);
}
}
}
}

188
test/src/unit-protocol.cpp Normal file
View File

@ -0,0 +1,188 @@
#define CATCH_CONFIG_MAIN
#include <catch.hpp>
#include <fstream>
#include <ki/protocol/control/SessionOffer.h>
#include <ki/protocol/control/SessionAccept.h>
#include <ki/protocol/control/ClientKeepAlive.h>
#include <ki/protocol/control/ServerKeepAlive.h>
using namespace ki::protocol;
TEST_CASE("Control Message Serialization", "[control]")
{
std::ostringstream oss;
SECTION("SessionOffer")
{
control::SessionOffer offer(0xABCD, 0xAABBCCDD, 0xAABBCCDD);
offer.write_to(oss);
const char expected_bytes[] = {
// Session ID
'\xCD', '\xAB',
// Unknown
'\x00', '\x00', '\x00', '\x00',
// Timestamp
'\xDD', '\xCC', '\xBB', '\xAA',
// Milliseconds
'\xDD', '\xCC', '\xBB', '\xAA'
};
REQUIRE(oss.str() == std::string(expected_bytes, sizeof(expected_bytes)));
}
SECTION("SessionAccept")
{
control::SessionAccept accept(0xABCD, 0xAABBCCDD, 0xAABBCCDD);
accept.write_to(oss);
const char expected_bytes[] = {
// Unknown
'\x00', '\x00',
// Unknown
'\x00', '\x00', '\x00', '\x00',
// Timestamp
'\xDD', '\xCC', '\xBB', '\xAA',
// Milliseconds
'\xDD', '\xCC', '\xBB', '\xAA',
// Session ID
'\xCD', '\xAB'
};
REQUIRE(oss.str() == std::string(expected_bytes, sizeof(expected_bytes)));
}
SECTION("ClientKeepAlive")
{
control::ClientKeepAlive keep_alive(0xABCD, 0xABCD, 0xABCD);
keep_alive.write_to(oss);
const char expected_bytes[] = {
// Session ID
'\xCD', '\xAB',
// Milliseconds
'\xCD', '\xAB',
// Minutes
'\xCD', '\xAB'
};
REQUIRE(oss.str() == std::string(expected_bytes, sizeof(expected_bytes)));
}
SECTION("ServerKeepAlive")
{
control::ServerKeepAlive keep_alive(0xAABBCCDD);
keep_alive.write_to(oss);
const char expected_bytes[] = {
// Unknown
'\x00', '\x00',
// Timestamp
'\xDD', '\xCC', '\xBB', '\xAA'
};
REQUIRE(oss.str() == std::string(expected_bytes, sizeof(expected_bytes)));
}
}
TEST_CASE("Control Message Deserialization", "[control]")
{
SECTION("SessionOffer")
{
const char bytes[] = {
// Session ID
'\xCD', '\xAB',
// Unknown
'\x00', '\x00', '\x00', '\x00',
// Timestamp
'\xDD', '\xCC', '\xBB', '\xAA',
// Milliseconds
'\xDD', '\xCC', '\xBB', '\xAA'
};
std::istringstream iss(std::string(bytes, sizeof(bytes)));
control::SessionOffer offer;
offer.read_from(iss);
REQUIRE(offer.get_session_id() == 0xABCD);
REQUIRE(offer.get_timestamp() == 0xAABBCCDD);
REQUIRE(offer.get_milliseconds() == 0xAABBCCDD);
}
SECTION("SessionAccept")
{
const char bytes[] = {
// Unknown
'\x00', '\x00',
// Unknown
'\x00', '\x00', '\x00', '\x00',
// Timestamp
'\xDD', '\xCC', '\xBB', '\xAA',
// Milliseconds
'\xDD', '\xCC', '\xBB', '\xAA',
// Session ID
'\xCD', '\xAB'
};
std::istringstream iss(std::string(bytes, sizeof(bytes)));
control::SessionAccept accept;
accept.read_from(iss);
REQUIRE(accept.get_session_id() == 0xABCD);
REQUIRE(accept.get_timestamp() == 0xAABBCCDD);
REQUIRE(accept.get_milliseconds() == 0xAABBCCDD);
}
SECTION("ClientKeepAlive")
{
const char bytes[] = {
// Session ID
'\xCD', '\xAB',
// Milliseconds
'\xCD', '\xAB',
// Minutes
'\xCD', '\xAB'
};
std::istringstream iss(std::string(bytes, sizeof(bytes)));
control::ClientKeepAlive keep_alive;
keep_alive.read_from(iss);
REQUIRE(keep_alive.get_session_id() == 0xABCD);
REQUIRE(keep_alive.get_milliseconds() == 0xABCD);
REQUIRE(keep_alive.get_minutes() == 0xABCD);
}
SECTION("ServerKeepAlive")
{
const char bytes[] = {
// Unknown
'\x00', '\x00',
// Timestamp
'\xDD', '\xCC', '\xBB', '\xAA'
};
std::istringstream iss(std::string(bytes, sizeof(bytes)));
control::ServerKeepAlive keep_alive;
keep_alive.read_from(iss);
REQUIRE(keep_alive.get_timestamp() == 0xAABBCCDD);
}
}