mirror of https://github.com/SeanOMik/libki.git
Merge pull request #1 from Joshsora/messaging
Implements network and DML protocols
This commit is contained in:
commit
cb4d89dad9
|
@ -20,6 +20,7 @@ target_include_directories(${PROJECT_NAME}
|
|||
target_link_libraries(${PROJECT_NAME} RapidXML)
|
||||
|
||||
add_subdirectory("src/dml")
|
||||
add_subdirectory("src/protocol")
|
||||
|
||||
option(KI_BUILD_EXAMPLES "Determines whether to build examples." ON)
|
||||
if (KI_BUILD_EXAMPLES)
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
#include <ki/protocol/dml/MessageManager.h>
|
||||
#include <ki/protocol/exception.h>
|
||||
#include <iostream>
|
||||
|
||||
using namespace ki::protocol;
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
// Get command-line arguments
|
||||
if (argc < 3)
|
||||
{
|
||||
std::cout << "usage: example-dml-module.exe <module_file> <message_name>" << std::endl;
|
||||
std::cout << "Prints out information for specified message." << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Create a manager to load modules into
|
||||
auto *message_manager = new dml::MessageManager();
|
||||
const dml::MessageModule *message_module;
|
||||
|
||||
// Load the message module file
|
||||
const std::string filepath = argv[1];
|
||||
try
|
||||
{
|
||||
message_module = message_manager->load_module(filepath);
|
||||
}
|
||||
catch (value_error &e)
|
||||
{
|
||||
std::cout << "Failed to load message module.";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Print some information about the module itself
|
||||
std::cout << "Service ID: " << (uint16_t)message_module->get_service_id() << std::endl;
|
||||
std::cout << "Protocol Type: " << message_module->get_protocol_type() << std::endl;
|
||||
|
||||
// Get the message template from the module we just loaded
|
||||
const std::string message_name = argv[2];
|
||||
auto *message_template = message_module->get_message_template(message_name);
|
||||
if (message_template)
|
||||
{
|
||||
std::cout << "Message Name: " << message_template->get_name() << std::endl;
|
||||
std::cout << "Mesasge Type: " << (uint16_t)message_template->get_type() << std::endl;
|
||||
|
||||
// Print out the fields in the template record
|
||||
std::cout << std::endl;
|
||||
auto &record = message_template->get_record();
|
||||
for (auto it = record.fields_begin();
|
||||
it != record.fields_end(); ++it)
|
||||
{
|
||||
auto *field = *it;
|
||||
if (field->is_transferable())
|
||||
std::cout << field->get_type_name() << " " << field->get_name() << ";" << std::endl;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Could not find message with name: " << message_name << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Exit successfully
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
#pragma once
|
||||
#include "../../util/Serializable.h"
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace control
|
||||
{
|
||||
class ClientKeepAlive final : public util::Serializable
|
||||
{
|
||||
public:
|
||||
ClientKeepAlive(uint16_t session_id = 0,
|
||||
uint16_t milliseconds = 0, uint16_t minutes = 0);
|
||||
virtual ~ClientKeepAlive() = default;
|
||||
|
||||
uint16_t get_session_id() const;
|
||||
void set_session_id(uint16_t session_id);
|
||||
|
||||
uint16_t get_milliseconds() const;
|
||||
void set_milliseconds(uint16_t milliseconds);
|
||||
|
||||
uint16_t get_minutes() const;
|
||||
void set_minutes(uint16_t minutes);
|
||||
|
||||
void write_to(std::ostream &ostream) const override final;
|
||||
void read_from(std::istream &istream) override final;
|
||||
size_t get_size() const override final;
|
||||
private:
|
||||
uint16_t m_session_id;
|
||||
uint16_t m_milliseconds;
|
||||
uint16_t m_minutes;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
#pragma once
|
||||
#include <cstdint>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace control
|
||||
{
|
||||
enum class Opcode : uint8_t
|
||||
{
|
||||
NONE = 0,
|
||||
SESSION_OFFER = 0,
|
||||
UDP_HELLO = 1,
|
||||
KEEP_ALIVE = 3,
|
||||
KEEP_ALIVE_RSP = 4,
|
||||
SESSION_ACCEPT = 5
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
#pragma once
|
||||
#include "../../util/Serializable.h"
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace control
|
||||
{
|
||||
class ServerKeepAlive final : public util::Serializable
|
||||
{
|
||||
public:
|
||||
ServerKeepAlive(uint32_t timestamp = 0);
|
||||
virtual ~ServerKeepAlive() = default;
|
||||
|
||||
uint32_t get_timestamp() const;
|
||||
void set_timestamp(uint32_t timestamp);
|
||||
|
||||
void write_to(std::ostream &ostream) const override final;
|
||||
void read_from(std::istream &istream) override final;
|
||||
size_t get_size() const override final;
|
||||
private:
|
||||
uint32_t m_timestamp;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
#pragma once
|
||||
#include "../../util/Serializable.h"
|
||||
#include <iostream>
|
||||
#include <cstdint>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace control
|
||||
{
|
||||
class SessionAccept final : public util::Serializable
|
||||
{
|
||||
public:
|
||||
SessionAccept(uint16_t session_id = 0,
|
||||
int32_t timestamp = 0, uint32_t milliseconds = 0);
|
||||
virtual ~SessionAccept() = default;
|
||||
|
||||
uint16_t get_session_id() const;
|
||||
void set_session_id(uint16_t session_id);
|
||||
|
||||
int32_t get_timestamp() const;
|
||||
void set_timestamp(int32_t timestamp);
|
||||
|
||||
uint32_t get_milliseconds() const;
|
||||
void set_milliseconds(uint32_t milliseconds);
|
||||
|
||||
void write_to(std::ostream &ostream) const override final;
|
||||
void read_from(std::istream &istream) override final;
|
||||
size_t get_size() const override final;
|
||||
private:
|
||||
uint16_t m_session_id;
|
||||
int32_t m_timestamp;
|
||||
uint32_t m_milliseconds;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
#pragma once
|
||||
#include "../../util/Serializable.h"
|
||||
#include <iostream>
|
||||
#include <cstdint>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace control
|
||||
{
|
||||
class SessionOffer final : public util::Serializable
|
||||
{
|
||||
public:
|
||||
SessionOffer(uint16_t session_id = 0,
|
||||
int32_t timestamp = 0, uint32_t milliseconds = 0);
|
||||
virtual ~SessionOffer() = default;
|
||||
|
||||
uint16_t get_session_id() const;
|
||||
void set_session_id(uint16_t session_id);
|
||||
|
||||
int32_t get_timestamp() const;
|
||||
void set_timestamp(int32_t timestamp);
|
||||
|
||||
uint32_t get_milliseconds() const;
|
||||
void set_milliseconds(uint32_t milliseconds);
|
||||
|
||||
void write_to(std::ostream &ostream) const override final;
|
||||
void read_from(std::istream &istream) override final;
|
||||
size_t get_size() const override final;
|
||||
private:
|
||||
uint16_t m_session_id;
|
||||
int32_t m_timestamp;
|
||||
uint32_t m_milliseconds;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
#pragma once
|
||||
#include "MessageHeader.h"
|
||||
#include "../../util/Serializable.h"
|
||||
#include "../../dml/Record.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace dml
|
||||
{
|
||||
class MessageTemplate;
|
||||
|
||||
class Message final : public util::Serializable
|
||||
{
|
||||
public:
|
||||
Message(const MessageTemplate *message_template = nullptr);
|
||||
virtual ~Message();
|
||||
|
||||
const MessageTemplate *get_template() const;
|
||||
void set_template(const MessageTemplate *message_template);
|
||||
|
||||
ki::dml::Record *get_record();
|
||||
const ki::dml::Record *get_record() const;
|
||||
|
||||
ki::dml::FieldBase *get_field(std::string name);
|
||||
const ki::dml::FieldBase *get_field(std::string name) const;
|
||||
|
||||
uint8_t get_service_id() const;
|
||||
uint8_t get_type() const;
|
||||
uint16_t get_message_size() const;
|
||||
std::string get_handler() const;
|
||||
uint8_t get_access_level() const;
|
||||
|
||||
void write_to(std::ostream &ostream) const override final;
|
||||
void read_from(std::istream &istream) override final;
|
||||
size_t get_size() const override final;
|
||||
private:
|
||||
const MessageTemplate *m_template;
|
||||
ki::dml::Record *m_record;
|
||||
|
||||
// This is used to store raw data when a Message is
|
||||
// constructed without a MessageTemplate.
|
||||
MessageHeader m_header;
|
||||
std::vector<char> m_raw_data;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
#pragma once
|
||||
#include "../../util/Serializable.h"
|
||||
#include <iostream>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace dml
|
||||
{
|
||||
class MessageHeader : public util::Serializable
|
||||
{
|
||||
public:
|
||||
MessageHeader(uint8_t service_id = 0,
|
||||
uint8_t type = 0, uint16_t size = 0);
|
||||
virtual ~MessageHeader() = default;
|
||||
|
||||
uint8_t get_service_id() const;
|
||||
void set_service_id(uint8_t service_id);
|
||||
|
||||
uint8_t get_type() const;
|
||||
void set_type(uint8_t type);
|
||||
|
||||
uint16_t get_message_size() const;
|
||||
void set_message_size(uint16_t size);
|
||||
|
||||
void write_to(std::ostream &ostream) const override final;
|
||||
void read_from(std::istream &istream) override final;
|
||||
size_t get_size() const override final;
|
||||
private:
|
||||
uint8_t m_service_id;
|
||||
uint8_t m_type;
|
||||
uint16_t m_size;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
#pragma once
|
||||
#include "Message.h"
|
||||
#include "MessageModule.h"
|
||||
#include "../../dml/Record.h"
|
||||
#include <string>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace dml
|
||||
{
|
||||
class MessageManager
|
||||
{
|
||||
public:
|
||||
MessageManager() = default;
|
||||
~MessageManager();
|
||||
|
||||
const MessageModule *load_module(std::string filepath);
|
||||
const MessageModule *get_module(uint8_t service_id) const;
|
||||
const MessageModule *get_module(const std::string &protocol_type) const;
|
||||
|
||||
Message *create_message(uint8_t service_id, uint8_t message_type) const;
|
||||
Message *create_message(uint8_t service_id, const std::string &message_name) const;
|
||||
Message *create_message(const std::string &protocol_type, uint8_t message_type) const;
|
||||
Message *create_message(const std::string &protocol_type, const std::string &message_name) const;
|
||||
|
||||
/**
|
||||
* If the DML message header cannot be read, then a nullptr
|
||||
* is returned; otherwise, a valid Message pointer is always returned.
|
||||
* However, that does not mean that the message itself is valid.
|
||||
*
|
||||
* To verify if the record was completely parsed, get_record
|
||||
* should return a valid Record pointer, rather than nullptr.
|
||||
*/
|
||||
const Message *message_from_binary(std::istream &istream) const;
|
||||
private:
|
||||
MessageModuleList m_modules;
|
||||
MessageModuleServiceIdMap m_service_id_map;
|
||||
MessageModuleProtocolTypeMap m_protocol_type_map;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
#pragma once
|
||||
#include "Message.h"
|
||||
#include "MessageTemplate.h"
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace dml
|
||||
{
|
||||
class MessageModule
|
||||
{
|
||||
public:
|
||||
MessageModule(uint8_t service_id = 0, std::string protocol_type = "");
|
||||
~MessageModule();
|
||||
|
||||
uint8_t get_service_id() const;
|
||||
void set_service_id(uint8_t service_id);
|
||||
|
||||
std::string get_protocol_type() const;
|
||||
void set_protocol_type(std::string protocol_type);
|
||||
|
||||
std::string get_protocol_desription() const;
|
||||
void set_protocol_description(std::string protocol_description);
|
||||
|
||||
const MessageTemplate *add_message_template(std::string name,
|
||||
ki::dml::Record *record, bool auto_sort = true);
|
||||
const MessageTemplate *get_message_template(uint8_t type) const;
|
||||
const MessageTemplate *get_message_template(std::string name) const;
|
||||
|
||||
void sort_lookup();
|
||||
|
||||
Message *create_message(uint8_t message_type) const;
|
||||
Message *create_message(std::string message_name) const;
|
||||
private:
|
||||
uint8_t m_service_id;
|
||||
std::string m_protocol_type;
|
||||
std::string m_protocol_description;
|
||||
uint8_t m_last_message_type;
|
||||
|
||||
std::vector<MessageTemplate *> m_templates;
|
||||
std::map<uint8_t, MessageTemplate *> m_message_type_map;
|
||||
std::map<std::string, MessageTemplate *> m_message_name_map;
|
||||
};
|
||||
|
||||
typedef std::vector<MessageModule *> MessageModuleList;
|
||||
typedef std::map<uint8_t, MessageModule *> MessageModuleServiceIdMap;
|
||||
typedef std::map<std::string, MessageModule *> MessageModuleProtocolTypeMap;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
#pragma once
|
||||
#include "../../dml/Record.h"
|
||||
#include "Message.h"
|
||||
#include <string>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace dml
|
||||
{
|
||||
class MessageTemplate
|
||||
{
|
||||
public:
|
||||
MessageTemplate(std::string name, uint8_t type,
|
||||
uint8_t service_id, ki::dml::Record *record);
|
||||
~MessageTemplate();
|
||||
|
||||
std::string get_name() const;
|
||||
void set_name(std::string name);
|
||||
|
||||
uint8_t get_type() const;
|
||||
void set_type(uint8_t type);
|
||||
|
||||
uint8_t get_service_id() const;
|
||||
void set_service_id(uint8_t service_id);
|
||||
|
||||
std::string get_handler() const;
|
||||
void set_handler(std::string handler);
|
||||
|
||||
uint8_t get_access_level() const;
|
||||
void set_access_level(uint8_t access_level);
|
||||
|
||||
const ki::dml::Record &get_record() const;
|
||||
void set_record(ki::dml::Record *record);
|
||||
|
||||
Message *create_message() const;
|
||||
private:
|
||||
std::string m_name;
|
||||
uint8_t m_type;
|
||||
uint8_t m_service_id;
|
||||
ki::dml::Record *m_record;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
#pragma once
|
||||
#include <stdexcept>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
class runtime_error : public std::runtime_error
|
||||
{
|
||||
public:
|
||||
explicit runtime_error(std::string message) : std::runtime_error(message) {}
|
||||
};
|
||||
|
||||
class parse_error : public runtime_error
|
||||
{
|
||||
public:
|
||||
enum code
|
||||
{
|
||||
NONE,
|
||||
INVALID_XML_DATA,
|
||||
INVALID_HEADER_DATA,
|
||||
INSUFFICIENT_MESSAGE_DATA,
|
||||
INVALID_MESSAGE_DATA
|
||||
};
|
||||
|
||||
explicit parse_error(std::string message, code error = code::NONE)
|
||||
: runtime_error(message)
|
||||
{
|
||||
m_code = error;
|
||||
}
|
||||
|
||||
code get_error_code() const { return m_code; }
|
||||
private:
|
||||
code m_code;
|
||||
};
|
||||
|
||||
class value_error : public runtime_error
|
||||
{
|
||||
public:
|
||||
enum code
|
||||
{
|
||||
NONE,
|
||||
MISSING_FILE,
|
||||
OVERWRITES_LOOKUP,
|
||||
EXCEEDS_LIMIT,
|
||||
|
||||
DML_INVALID_SERVICE,
|
||||
DML_INVALID_PROTOCOL_TYPE,
|
||||
DML_INVALID_MESSAGE_TYPE,
|
||||
DML_INVALID_MESSAGE_NAME
|
||||
};
|
||||
|
||||
explicit value_error(std::string message, code error = code::NONE)
|
||||
: runtime_error(message)
|
||||
{
|
||||
m_code = error;
|
||||
}
|
||||
|
||||
code get_error_code() const { return m_code; }
|
||||
private:
|
||||
code m_code;
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
#pragma once
|
||||
#include "ClientSession.h"
|
||||
#include "DMLSession.h"
|
||||
|
||||
// Disable inheritance via dominance warning
|
||||
#if _MSC_VER
|
||||
#pragma warning(disable: 4250)
|
||||
#endif
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace net
|
||||
{
|
||||
class ClientDMLSession : public ClientSession, public DMLSession
|
||||
{
|
||||
// Explicitly specify that we are intentionally inheritting
|
||||
// via dominance.
|
||||
using DMLSession::on_application_message;
|
||||
using ClientSession::on_control_message;
|
||||
using ClientSession::is_alive;
|
||||
public:
|
||||
ClientDMLSession(const uint16_t id, const dml::MessageManager &manager)
|
||||
: Session(id), ClientSession(id), DMLSession(id, manager) {}
|
||||
virtual ~ClientDMLSession() = default;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-enable inheritance via dominance warning
|
||||
#if _MSC_VER
|
||||
#pragma warning(default: 4250)
|
||||
#endif
|
|
@ -0,0 +1,34 @@
|
|||
#pragma once
|
||||
#include "Session.h"
|
||||
|
||||
#define KI_SERVER_HEARTBEAT 60
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace net
|
||||
{
|
||||
/**
|
||||
* Implements client-sided session logic.
|
||||
*/
|
||||
class ClientSession : public virtual Session
|
||||
{
|
||||
public:
|
||||
explicit ClientSession(uint16_t id);
|
||||
virtual ~ClientSession() = default;
|
||||
|
||||
void send_keep_alive();
|
||||
bool is_alive() const override;
|
||||
protected:
|
||||
void on_connected();
|
||||
virtual void on_established() {}
|
||||
void on_control_message(const PacketHeader& header) override;
|
||||
private:
|
||||
void on_session_offer();
|
||||
void on_keep_alive();
|
||||
void on_keep_alive_response();
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,44 @@
|
|||
#pragma once
|
||||
#include "Session.h"
|
||||
#include "../dml/MessageManager.h"
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace net
|
||||
{
|
||||
enum class InvalidDMLMessageErrorCode
|
||||
{
|
||||
NONE,
|
||||
UNKNOWN,
|
||||
INVALID_HEADER_DATA,
|
||||
INVALID_MESSAGE_DATA,
|
||||
INVALID_SERVICE,
|
||||
INVALID_MESSAGE_TYPE,
|
||||
INSUFFICIENT_ACCESS
|
||||
};
|
||||
|
||||
/**
|
||||
* Implements an application protocol that uses the DML
|
||||
* message system (as seen in Wizard101 and Pirate101).
|
||||
*/
|
||||
class DMLSession : public virtual Session
|
||||
{
|
||||
public:
|
||||
DMLSession(uint16_t id, const dml::MessageManager &manager);
|
||||
virtual ~DMLSession() = default;
|
||||
|
||||
const dml::MessageManager &get_manager() const;
|
||||
|
||||
void send_message(const dml::Message &message);
|
||||
protected:
|
||||
void on_application_message(const PacketHeader& header) override;
|
||||
virtual void on_message(const dml::Message *message) {}
|
||||
virtual void on_invalid_message(InvalidDMLMessageErrorCode error) {}
|
||||
private:
|
||||
const dml::MessageManager &m_manager;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,33 @@
|
|||
#pragma once
|
||||
#include "../../util/Serializable.h"
|
||||
#include <cstdint>
|
||||
#include <iostream>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace net
|
||||
{
|
||||
class PacketHeader final : public util::Serializable
|
||||
{
|
||||
public:
|
||||
PacketHeader(bool control = false, uint8_t opcode = 0);
|
||||
virtual ~PacketHeader() = default;
|
||||
|
||||
bool is_control() const;
|
||||
void set_control(bool control);
|
||||
|
||||
uint8_t get_opcode() const;
|
||||
void set_opcode(uint8_t opcode);
|
||||
|
||||
void write_to(std::ostream &ostream) const override final;
|
||||
void read_from(std::istream &istream) override final;
|
||||
size_t get_size() const override final;
|
||||
private:
|
||||
bool m_control;
|
||||
uint8_t m_opcode;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
#pragma once
|
||||
#include "ServerSession.h"
|
||||
#include "DMLSession.h"
|
||||
|
||||
// Disable inheritance via dominance warning
|
||||
#if _MSC_VER
|
||||
#pragma warning(disable: 4250)
|
||||
#endif
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace net
|
||||
{
|
||||
class ServerDMLSession : public ServerSession, public DMLSession
|
||||
{
|
||||
// Explicitly specify that we are intentionally inheritting
|
||||
// via dominance.
|
||||
using DMLSession::on_application_message;
|
||||
using ServerSession::on_control_message;
|
||||
using ServerSession::is_alive;
|
||||
public:
|
||||
ServerDMLSession(const uint16_t id, const dml::MessageManager &manager)
|
||||
: Session(id), ServerSession(id), DMLSession(id, manager) {}
|
||||
virtual ~ServerDMLSession() = default;
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-enable inheritance via dominance warning
|
||||
#if _MSC_VER
|
||||
#pragma warning(default: 4250)
|
||||
#endif
|
|
@ -0,0 +1,34 @@
|
|||
#pragma once
|
||||
#include "Session.h"
|
||||
|
||||
#define KI_CLIENT_HEARTBEAT 10
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace net
|
||||
{
|
||||
/**
|
||||
* Implements server-sided session logic.
|
||||
*/
|
||||
class ServerSession : public virtual Session
|
||||
{
|
||||
public:
|
||||
explicit ServerSession(uint16_t id);
|
||||
virtual ~ServerSession() = default;
|
||||
|
||||
void send_keep_alive(uint32_t milliseconds_since_startup);
|
||||
bool is_alive() const override;
|
||||
protected:
|
||||
void on_connected();
|
||||
virtual void on_established() {}
|
||||
void on_control_message(const PacketHeader& header) override;
|
||||
private:
|
||||
void on_session_accept();
|
||||
void on_keep_alive();
|
||||
void on_keep_alive_response();
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,138 @@
|
|||
#pragma once
|
||||
#include "PacketHeader.h"
|
||||
#include "../control/Opcode.h"
|
||||
#include "../../util/Serializable.h"
|
||||
#include <cstdint>
|
||||
#include <sstream>
|
||||
#include <chrono>
|
||||
#include <type_traits>
|
||||
|
||||
#define KI_DEFAULT_MAXIMUM_RECEIVE_SIZE 0x2000
|
||||
#define KI_START_SIGNAL 0xF00D
|
||||
#define KI_CONNECTION_TIMEOUT 3
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace net
|
||||
{
|
||||
enum class SessionCloseErrorCode
|
||||
{
|
||||
NONE,
|
||||
APPLICATION_ERROR,
|
||||
|
||||
INVALID_FRAMING_START_SIGNAL,
|
||||
INVALID_FRAMING_SIZE_EXCEEDS_MAXIMUM,
|
||||
|
||||
UNHANDLED_CONTROL_MESSAGE,
|
||||
UNHANDLED_APPLICATION_MESSAGE,
|
||||
INVALID_MESSAGE,
|
||||
|
||||
SESSION_OFFER_TIMED_OUT,
|
||||
SESSION_DIED
|
||||
};
|
||||
|
||||
enum class ReceiveState
|
||||
{
|
||||
// Waiting for the 0xF00D start signal.
|
||||
WAITING_FOR_START_SIGNAL,
|
||||
|
||||
// Waiting for the 2-byte length.
|
||||
WAITING_FOR_LENGTH,
|
||||
|
||||
// Waiting for the packet data.
|
||||
WAITING_FOR_PACKET
|
||||
};
|
||||
|
||||
/**
|
||||
* This class implements session and packet framing logic
|
||||
* when sending and receiving data to/from an external
|
||||
* source.
|
||||
*/
|
||||
class Session
|
||||
{
|
||||
public:
|
||||
explicit Session(uint16_t id = 0);
|
||||
virtual ~Session() = default;
|
||||
|
||||
uint16_t get_maximum_packet_size() const;
|
||||
void set_maximum_packet_size(uint16_t maximum_packet_size);
|
||||
|
||||
uint16_t get_id() const;
|
||||
bool is_established() const;
|
||||
|
||||
uint8_t get_access_level() const;
|
||||
void set_access_level(uint8_t access_level);
|
||||
|
||||
uint16_t get_latency() const;
|
||||
|
||||
virtual bool is_alive() const = 0;
|
||||
|
||||
void send_packet(bool is_control, uint8_t opcode,
|
||||
const util::Serializable &data);
|
||||
protected:
|
||||
/* Higher-level session members */
|
||||
uint16_t m_id;
|
||||
bool m_established;
|
||||
uint8_t m_access_level;
|
||||
|
||||
/* Timing members */
|
||||
std::chrono::steady_clock::time_point m_creation_time;
|
||||
std::chrono::steady_clock::time_point m_connection_time;
|
||||
std::chrono::steady_clock::time_point m_establish_time;
|
||||
std::chrono::steady_clock::time_point m_last_received_heartbeat_time;
|
||||
std::chrono::steady_clock::time_point m_last_sent_heartbeat_time;
|
||||
bool m_waiting_for_keep_alive_response;
|
||||
uint16_t m_latency;
|
||||
|
||||
// The packet data stream
|
||||
std::stringstream m_data_stream;
|
||||
|
||||
/**
|
||||
* Reads a serializable structure from the data stream.
|
||||
*/
|
||||
template <typename DataT>
|
||||
DataT read_data()
|
||||
{
|
||||
static_assert(std::is_base_of<util::Serializable, DataT>::value,
|
||||
"DataT must inherit Serializable.");
|
||||
|
||||
DataT data = DataT();
|
||||
data.read_from(m_data_stream);
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Frames raw data into a Packet, and transmits it.
|
||||
*/
|
||||
void send_data(const char *data, size_t size);
|
||||
|
||||
/**
|
||||
* Process incoming raw data into Packets.
|
||||
* Once a packet is read into the internal data
|
||||
* stream, handle_packet_available is called.
|
||||
*/
|
||||
void process_data(const char *data, size_t size);
|
||||
|
||||
/* Event handlers */
|
||||
virtual void on_invalid_packet() {}
|
||||
virtual void on_control_message(const PacketHeader &header) {}
|
||||
virtual void on_application_message(const PacketHeader &header) {}
|
||||
|
||||
/* Low-level socket methods */
|
||||
virtual void send_packet_data(const char *data, const size_t size) = 0;
|
||||
virtual void close(SessionCloseErrorCode error) = 0;
|
||||
private:
|
||||
/* Low-level networking members */
|
||||
uint16_t m_maximum_packet_size;
|
||||
ReceiveState m_receive_state;
|
||||
uint16_t m_start_signal;
|
||||
uint16_t m_incoming_packet_size;
|
||||
uint8_t m_shift;
|
||||
|
||||
void on_packet_available();
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
target_sources(${PROJECT_NAME}
|
||||
PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/control/ClientKeepAlive.cpp
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/control/ServerKeepAlive.cpp
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/control/SessionAccept.cpp
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/control/SessionOffer.cpp
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/dml/Message.cpp
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageHeader.cpp
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageManager.cpp
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageModule.cpp
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageTemplate.cpp
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/net/ClientSession.cpp
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/net/DMLSession.cpp
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/net/PacketHeader.cpp
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/net/ServerSession.cpp
|
||||
${PROJECT_SOURCE_DIR}/src/protocol/net/Session.cpp
|
||||
)
|
|
@ -0,0 +1,87 @@
|
|||
#include "ki/protocol/control/ClientKeepAlive.h"
|
||||
#include "ki/dml/Record.h"
|
||||
#include "ki/protocol/exception.h"
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace control
|
||||
{
|
||||
ClientKeepAlive::ClientKeepAlive(const uint16_t session_id, const uint16_t milliseconds,
|
||||
const uint16_t minutes)
|
||||
{
|
||||
m_session_id = session_id;
|
||||
m_milliseconds = milliseconds;
|
||||
m_minutes = minutes;
|
||||
}
|
||||
|
||||
uint16_t ClientKeepAlive::get_session_id() const
|
||||
{
|
||||
return m_session_id;
|
||||
}
|
||||
|
||||
void ClientKeepAlive::set_session_id(const uint16_t session_id)
|
||||
{
|
||||
m_session_id = session_id;
|
||||
}
|
||||
|
||||
uint16_t ClientKeepAlive::get_milliseconds() const
|
||||
{
|
||||
return m_milliseconds;
|
||||
}
|
||||
|
||||
void ClientKeepAlive::set_milliseconds(const uint16_t milliseconds)
|
||||
{
|
||||
m_milliseconds = milliseconds;
|
||||
}
|
||||
|
||||
uint16_t ClientKeepAlive::get_minutes() const
|
||||
{
|
||||
return m_minutes;
|
||||
}
|
||||
|
||||
void ClientKeepAlive::set_minutes(const uint16_t minutes)
|
||||
{
|
||||
m_minutes = minutes;
|
||||
}
|
||||
|
||||
void ClientKeepAlive::write_to(std::ostream &ostream) const
|
||||
{
|
||||
dml::Record record;
|
||||
record.add_field<dml::USHRT>("m_session_id")->set_value(m_session_id);
|
||||
record.add_field<dml::USHRT>("m_milliseconds")->set_value(m_milliseconds);
|
||||
record.add_field<dml::USHRT>("m_minutes")->set_value(m_minutes);
|
||||
record.write_to(ostream);
|
||||
}
|
||||
|
||||
void ClientKeepAlive::read_from(std::istream &istream)
|
||||
{
|
||||
dml::Record record;
|
||||
auto *session_id = record.add_field<dml::USHRT>("m_session_id");
|
||||
auto *milliseconds = record.add_field<dml::USHRT>("m_milliseconds");
|
||||
auto *minutes = record.add_field<dml::USHRT>("m_minutes");
|
||||
try
|
||||
{
|
||||
record.read_from(istream);
|
||||
}
|
||||
catch (dml::parse_error &e)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "Error reading ClientKeepAlive payload: " << e.what();
|
||||
throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA);
|
||||
}
|
||||
|
||||
m_session_id = session_id->get_value();
|
||||
m_milliseconds = milliseconds->get_value();
|
||||
m_minutes = minutes->get_value();
|
||||
}
|
||||
|
||||
size_t ClientKeepAlive::get_size() const
|
||||
{
|
||||
return sizeof(dml::USHRT) + sizeof(dml::USHRT) +
|
||||
sizeof(dml::USHRT);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
#include "ki/protocol/control/ServerKeepAlive.h"
|
||||
#include "ki/dml/Record.h"
|
||||
#include "ki/protocol/exception.h"
|
||||
#include <chrono>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace control
|
||||
{
|
||||
ServerKeepAlive::ServerKeepAlive(const uint32_t timestamp)
|
||||
{
|
||||
m_timestamp = timestamp;
|
||||
}
|
||||
|
||||
uint32_t ServerKeepAlive::get_timestamp() const
|
||||
{
|
||||
return m_timestamp;
|
||||
}
|
||||
|
||||
void ServerKeepAlive::set_timestamp(const uint32_t timestamp)
|
||||
{
|
||||
m_timestamp = timestamp;
|
||||
}
|
||||
|
||||
void ServerKeepAlive::write_to(std::ostream& ostream) const
|
||||
{
|
||||
dml::Record record;
|
||||
record.add_field<dml::USHRT>("m_session_id");
|
||||
record.add_field<dml::INT>("m_timestamp")->set_value(m_timestamp);
|
||||
record.write_to(ostream);
|
||||
}
|
||||
|
||||
void ServerKeepAlive::read_from(std::istream& istream)
|
||||
{
|
||||
dml::Record record;
|
||||
record.add_field<dml::USHRT>("m_session_id");
|
||||
auto *timestamp = record.add_field<dml::INT>("m_timestamp");
|
||||
try
|
||||
{
|
||||
record.read_from(istream);
|
||||
}
|
||||
catch (dml::parse_error &e)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "Error reading ServerKeepAlive payload: " << e.what();
|
||||
throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA);
|
||||
}
|
||||
|
||||
m_timestamp = timestamp->get_value();
|
||||
}
|
||||
|
||||
size_t ServerKeepAlive::get_size() const
|
||||
{
|
||||
return sizeof(dml::USHRT) + sizeof(dml::INT);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,92 @@
|
|||
#include "ki/protocol/control/SessionAccept.h"
|
||||
#include "ki/dml/Record.h"
|
||||
#include "ki/protocol/exception.h"
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace control
|
||||
{
|
||||
SessionAccept::SessionAccept(const uint16_t session_id,
|
||||
const int32_t timestamp, const uint32_t milliseconds)
|
||||
{
|
||||
m_session_id = session_id;
|
||||
m_timestamp = timestamp;
|
||||
m_milliseconds = milliseconds;
|
||||
}
|
||||
|
||||
uint16_t SessionAccept::get_session_id() const
|
||||
{
|
||||
return m_session_id;
|
||||
}
|
||||
|
||||
void SessionAccept::set_session_id(const uint16_t session_id)
|
||||
{
|
||||
m_session_id = session_id;
|
||||
}
|
||||
|
||||
int32_t SessionAccept::get_timestamp() const
|
||||
{
|
||||
return m_timestamp;
|
||||
}
|
||||
|
||||
void SessionAccept::set_timestamp(const int32_t timestamp)
|
||||
{
|
||||
m_timestamp = timestamp;
|
||||
}
|
||||
|
||||
uint32_t SessionAccept::get_milliseconds() const
|
||||
{
|
||||
return m_milliseconds;
|
||||
}
|
||||
|
||||
void SessionAccept::set_milliseconds(const uint32_t milliseconds)
|
||||
{
|
||||
m_milliseconds = milliseconds;
|
||||
}
|
||||
|
||||
void SessionAccept::write_to(std::ostream& ostream) const
|
||||
{
|
||||
dml::Record record;
|
||||
record.add_field<dml::USHRT>("unknown");
|
||||
record.add_field<dml::UINT>("unknown2");
|
||||
record.add_field<dml::INT>("m_timestamp")->set_value(m_timestamp);
|
||||
record.add_field<dml::UINT>("m_milliseconds")->set_value(m_milliseconds);
|
||||
record.add_field<dml::USHRT>("m_session_id")->set_value(m_session_id);
|
||||
record.write_to(ostream);
|
||||
}
|
||||
|
||||
void SessionAccept::read_from(std::istream& istream)
|
||||
{
|
||||
dml::Record record;
|
||||
record.add_field<dml::USHRT>("unknown");
|
||||
record.add_field<dml::UINT>("unknown2");
|
||||
auto *timestamp = record.add_field<dml::INT>("m_timestamp");
|
||||
auto *milliseconds = record.add_field<dml::UINT>("m_milliseconds");
|
||||
auto *session_id = record.add_field<dml::USHRT>("m_session_id");
|
||||
try
|
||||
{
|
||||
record.read_from(istream);
|
||||
}
|
||||
catch (dml::parse_error &e)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "Error reading SessionAccept payload: " << e.what();
|
||||
throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA);
|
||||
}
|
||||
|
||||
m_timestamp = timestamp->get_value();
|
||||
m_milliseconds = milliseconds->get_value();
|
||||
m_session_id = session_id->get_value();
|
||||
}
|
||||
|
||||
size_t SessionAccept::get_size() const
|
||||
{
|
||||
return sizeof(dml::USHRT) + sizeof(dml::UINT) +
|
||||
sizeof(dml::INT) + sizeof(dml::UINT) +
|
||||
sizeof(dml::USHRT);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,89 @@
|
|||
#include "ki/protocol/control/SessionOffer.h"
|
||||
#include "ki/dml/Record.h"
|
||||
#include "ki/protocol/exception.h"
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace control
|
||||
{
|
||||
SessionOffer::SessionOffer(const uint16_t session_id,
|
||||
const int32_t timestamp, const uint32_t milliseconds)
|
||||
{
|
||||
m_session_id = session_id;
|
||||
m_timestamp = timestamp;
|
||||
m_milliseconds = milliseconds;
|
||||
}
|
||||
|
||||
uint16_t SessionOffer::get_session_id() const
|
||||
{
|
||||
return m_session_id;
|
||||
}
|
||||
|
||||
void SessionOffer::set_session_id(const uint16_t session_id)
|
||||
{
|
||||
m_session_id = session_id;
|
||||
}
|
||||
|
||||
int32_t SessionOffer::get_timestamp() const
|
||||
{
|
||||
return m_timestamp;
|
||||
}
|
||||
|
||||
void SessionOffer::set_timestamp(const int32_t timestamp)
|
||||
{
|
||||
m_timestamp = timestamp;
|
||||
}
|
||||
|
||||
uint32_t SessionOffer::get_milliseconds() const
|
||||
{
|
||||
return m_milliseconds;
|
||||
}
|
||||
|
||||
void SessionOffer::set_milliseconds(const uint32_t milliseconds)
|
||||
{
|
||||
m_milliseconds = milliseconds;
|
||||
}
|
||||
|
||||
void SessionOffer::write_to(std::ostream& ostream) const
|
||||
{
|
||||
dml::Record record;
|
||||
record.add_field<dml::USHRT>("m_session_id")->set_value(m_session_id);
|
||||
record.add_field<dml::UINT>("unknown");
|
||||
record.add_field<dml::INT>("m_timestamp")->set_value(m_timestamp);
|
||||
record.add_field<dml::UINT>("m_milliseconds")->set_value(m_milliseconds);
|
||||
record.write_to(ostream);
|
||||
}
|
||||
|
||||
void SessionOffer::read_from(std::istream& istream)
|
||||
{
|
||||
dml::Record record;
|
||||
auto *session_id = record.add_field<dml::USHRT>("m_session_id");
|
||||
record.add_field<dml::UINT>("unknown");
|
||||
auto *timestamp = record.add_field<dml::INT>("m_timestamp");
|
||||
auto *milliseconds = record.add_field<dml::UINT>("m_milliseconds");
|
||||
try
|
||||
{
|
||||
record.read_from(istream);
|
||||
}
|
||||
catch (dml::parse_error &e)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "Error reading SessionOffer payload: " << e.what();
|
||||
throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA);
|
||||
}
|
||||
|
||||
m_session_id = session_id->get_value();
|
||||
m_timestamp = timestamp->get_value();
|
||||
m_milliseconds = milliseconds->get_value();
|
||||
}
|
||||
|
||||
size_t SessionOffer::get_size() const
|
||||
{
|
||||
return sizeof(dml::USHRT) + sizeof(dml::UINT) +
|
||||
sizeof(dml::INT) + sizeof(dml::UINT);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,173 @@
|
|||
#include "ki/protocol/dml/Message.h"
|
||||
#include "ki/protocol/dml/MessageTemplate.h"
|
||||
#include "ki/protocol/exception.h"
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace dml
|
||||
{
|
||||
Message::Message(const MessageTemplate *message_template)
|
||||
{
|
||||
m_template = message_template;
|
||||
if (m_template)
|
||||
m_record = new ki::dml::Record(m_template->get_record());
|
||||
else
|
||||
m_record = nullptr;
|
||||
}
|
||||
|
||||
Message::~Message()
|
||||
{
|
||||
delete m_record;
|
||||
}
|
||||
|
||||
const MessageTemplate *Message::get_template() const
|
||||
{
|
||||
return m_template;
|
||||
}
|
||||
|
||||
void Message::set_template(const MessageTemplate *message_template)
|
||||
{
|
||||
m_template = message_template;
|
||||
if (!m_template)
|
||||
return;
|
||||
|
||||
m_record = new ki::dml::Record(message_template->get_record());
|
||||
if (!m_raw_data.empty())
|
||||
{
|
||||
std::istringstream iss(std::string(m_raw_data.data(), m_raw_data.size()));
|
||||
try
|
||||
{
|
||||
m_record->read_from(iss);
|
||||
m_raw_data.clear();
|
||||
}
|
||||
catch (ki::dml::parse_error &e)
|
||||
{
|
||||
delete m_record;
|
||||
m_template = nullptr;
|
||||
m_record = nullptr;
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << "Error reading DML message payload: " << e.what();
|
||||
throw parse_error(oss.str(), parse_error::INVALID_MESSAGE_DATA);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
uint8_t Message::get_service_id() const
|
||||
{
|
||||
if (m_template)
|
||||
return m_template->get_service_id();
|
||||
return m_header.get_service_id();
|
||||
}
|
||||
|
||||
uint8_t Message::get_type() const
|
||||
{
|
||||
if (m_template)
|
||||
return m_template->get_type();
|
||||
return m_header.get_type();
|
||||
}
|
||||
|
||||
uint16_t Message::get_message_size() const
|
||||
{
|
||||
if (m_record)
|
||||
return m_record->get_size();
|
||||
return m_raw_data.size();
|
||||
}
|
||||
|
||||
std::string Message::get_handler() const
|
||||
{
|
||||
if (m_template)
|
||||
return m_template->get_handler();
|
||||
return "";
|
||||
}
|
||||
|
||||
uint8_t Message::get_access_level() const
|
||||
{
|
||||
if (m_template)
|
||||
return m_template->get_access_level();
|
||||
return 0;
|
||||
}
|
||||
|
||||
ki::dml::Record *Message::get_record()
|
||||
{
|
||||
return m_record;
|
||||
}
|
||||
|
||||
const ki::dml::Record *Message::get_record() const
|
||||
{
|
||||
return m_record;
|
||||
}
|
||||
|
||||
ki::dml::FieldBase* Message::get_field(std::string name)
|
||||
{
|
||||
if (m_record)
|
||||
return m_record->get_field(name);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const ki::dml::FieldBase* Message::get_field(std::string name) const
|
||||
{
|
||||
if (m_record)
|
||||
return m_record->get_field(name);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Message::write_to(std::ostream &ostream) const
|
||||
{
|
||||
// Write the header
|
||||
if (m_template)
|
||||
{
|
||||
MessageHeader header(
|
||||
get_service_id(), get_type(), get_message_size());
|
||||
header.write_to(ostream);
|
||||
}
|
||||
else
|
||||
m_header.write_to(ostream);
|
||||
|
||||
// Write the payload
|
||||
if (m_record)
|
||||
m_record->write_to(ostream);
|
||||
else
|
||||
ostream.write(m_raw_data.data(), m_raw_data.size());
|
||||
}
|
||||
|
||||
void Message::read_from(std::istream &istream)
|
||||
{
|
||||
m_header.read_from(istream);
|
||||
if (m_template)
|
||||
{
|
||||
// Check for mismatches between the header and template
|
||||
if (m_header.get_service_id() != m_template->get_service_id())
|
||||
throw value_error("ServiceID mismatch between MessageHeader and assigned template.",
|
||||
value_error::DML_INVALID_SERVICE);
|
||||
if (m_header.get_type() != m_template->get_type())
|
||||
throw value_error("Message Type mismatch between MessageHeader and assigned template.",
|
||||
value_error::DML_INVALID_MESSAGE_TYPE);
|
||||
|
||||
// Read the payload into the record
|
||||
m_record->read_from(istream);
|
||||
}
|
||||
else
|
||||
{
|
||||
// We don't have a template for the record structure, so
|
||||
// just read the raw data into a buffer.
|
||||
const auto size = m_header.get_message_size();
|
||||
m_raw_data.resize(size);
|
||||
istream.read(m_raw_data.data(), size);
|
||||
if (istream.fail())
|
||||
throw parse_error("Not enough data was available to read DML message payload.",
|
||||
parse_error::INSUFFICIENT_MESSAGE_DATA);
|
||||
}
|
||||
}
|
||||
|
||||
size_t Message::get_size() const
|
||||
{
|
||||
if (m_record)
|
||||
return m_header.get_size() + m_record->get_size();
|
||||
return 4 + m_raw_data.size();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
#include "ki/protocol/dml/MessageHeader.h"
|
||||
#include "ki/dml/Record.h"
|
||||
#include "ki/protocol/exception.h"
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace dml
|
||||
{
|
||||
MessageHeader::MessageHeader(const uint8_t service_id,
|
||||
const uint8_t type, const uint16_t size)
|
||||
{
|
||||
m_service_id = service_id;
|
||||
m_type = type;
|
||||
m_size = size;
|
||||
}
|
||||
|
||||
uint8_t MessageHeader::get_service_id() const
|
||||
{
|
||||
return m_service_id;
|
||||
}
|
||||
|
||||
void MessageHeader::set_service_id(const uint8_t service_id)
|
||||
{
|
||||
m_service_id = service_id;
|
||||
}
|
||||
|
||||
uint8_t MessageHeader::get_type() const
|
||||
{
|
||||
return m_type;
|
||||
}
|
||||
|
||||
void MessageHeader::set_type(const uint8_t type)
|
||||
{
|
||||
m_type = type;
|
||||
}
|
||||
|
||||
uint16_t MessageHeader::get_message_size() const
|
||||
{
|
||||
return m_size;
|
||||
}
|
||||
|
||||
void MessageHeader::set_message_size(const uint16_t size)
|
||||
{
|
||||
m_size = size;
|
||||
}
|
||||
|
||||
void MessageHeader::write_to(std::ostream& ostream) const
|
||||
{
|
||||
ki::dml::Record record;
|
||||
record.add_field<ki::dml::UBYT>("m_service_id")->set_value(m_service_id);
|
||||
record.add_field<ki::dml::UBYT>("m_type")->set_value(m_type);
|
||||
record.add_field<ki::dml::USHRT>("m_size")->set_value(m_size + 4);
|
||||
record.write_to(ostream);
|
||||
}
|
||||
|
||||
void MessageHeader::read_from(std::istream& istream)
|
||||
{
|
||||
ki::dml::Record record;
|
||||
const auto *service_id = record.add_field<ki::dml::UBYT>("m_service_id");
|
||||
const auto *type = record.add_field<ki::dml::UBYT>("m_type");
|
||||
const auto size = record.add_field<ki::dml::USHRT>("m_size");
|
||||
|
||||
try
|
||||
{
|
||||
record.read_from(istream);
|
||||
}
|
||||
catch (ki::dml::parse_error &e)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "Error reading MessageHeader: " << e.what();
|
||||
throw parse_error(oss.str(), parse_error::INVALID_HEADER_DATA);
|
||||
}
|
||||
|
||||
m_service_id = service_id->get_value();
|
||||
m_type = type->get_value();
|
||||
m_size = size->get_value() - 4;
|
||||
}
|
||||
|
||||
size_t MessageHeader::get_size() const
|
||||
{
|
||||
return sizeof(ki::dml::UBYT) + sizeof(ki::dml::UBYT) +
|
||||
sizeof(ki::dml::USHRT);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,260 @@
|
|||
#include "ki/protocol/dml/MessageManager.h"
|
||||
#include "ki/protocol/dml/MessageHeader.h"
|
||||
#include "ki/protocol/exception.h"
|
||||
#include "ki/dml/Record.h"
|
||||
#include "ki/util/ValueBytes.h"
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <rapidxml.hpp>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace dml
|
||||
{
|
||||
MessageManager::~MessageManager()
|
||||
{
|
||||
for (auto it = m_modules.begin();
|
||||
it != m_modules.end(); ++it)
|
||||
delete *it;
|
||||
m_modules.clear();
|
||||
m_service_id_map.clear();
|
||||
m_protocol_type_map.clear();
|
||||
}
|
||||
|
||||
const MessageModule *MessageManager::load_module(std::string filepath)
|
||||
{
|
||||
// Open the file
|
||||
std::ifstream ifs(filepath, std::ios::ate);
|
||||
if (!ifs.is_open())
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "Could not open file: " << filepath;
|
||||
throw value_error(oss.str(), value_error::MISSING_FILE);
|
||||
}
|
||||
|
||||
// Load contents into memory
|
||||
size_t size = ifs.tellg();
|
||||
ifs.seekg(0, std::ios::beg);
|
||||
char *data = new char[size + 1] { 0 };
|
||||
ifs.read(data, size);
|
||||
|
||||
// Parse the contents
|
||||
rapidxml::xml_document<> doc;
|
||||
try
|
||||
{
|
||||
doc.parse<0>(data);
|
||||
}
|
||||
catch (rapidxml::parse_error &e)
|
||||
{
|
||||
delete[] data;
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << "Failed to parse: " << filepath;
|
||||
throw parse_error(oss.str(), parse_error::INVALID_XML_DATA);
|
||||
}
|
||||
|
||||
// It's safe to allocate the module we're working on now
|
||||
auto *message_module = new MessageModule();
|
||||
|
||||
// Get the root node and iterate through children
|
||||
// Each child is a MessageTemplate
|
||||
auto *root = doc.first_node();
|
||||
for (auto *node = root->first_node();
|
||||
node; node = node->next_sibling())
|
||||
{
|
||||
// Parse the record node inside this node
|
||||
auto *record_node = node->first_node();
|
||||
if (!record_node)
|
||||
continue;
|
||||
auto *record = new ki::dml::Record();
|
||||
record->from_xml(record_node);
|
||||
|
||||
// The message name is initially based on the element name
|
||||
const std::string message_name = node->name();
|
||||
if (message_name == "_ProtocolInfo")
|
||||
{
|
||||
auto *service_id_field = record->get_field<ki::dml::UBYT>("ServiceID");
|
||||
auto *type_field = record->get_field<ki::dml::STR>("ProtocolType");
|
||||
auto *description_field = record->get_field<ki::dml::STR>("ProtocolDescription");
|
||||
|
||||
// Set the module metadata from this template
|
||||
if (service_id_field)
|
||||
message_module->set_service_id(service_id_field->get_value());
|
||||
if (type_field)
|
||||
message_module->set_protocol_type(type_field->get_value());
|
||||
if (description_field)
|
||||
message_module->set_protocol_description(description_field->get_value());
|
||||
}
|
||||
else
|
||||
{
|
||||
// Only do sorting after we've reached the final message
|
||||
// This only affects modules that aren't ordered with _MsgOrder.
|
||||
const bool auto_sort = node->next_sibling() == nullptr;
|
||||
|
||||
// The template will use the record itself to figure out name and type;
|
||||
// we only give the XML data incase the record doesn't have it defined.
|
||||
auto *message_template = message_module->add_message_template(message_name, record, auto_sort);
|
||||
if (!message_template)
|
||||
{
|
||||
delete[] data;
|
||||
delete message_module;
|
||||
delete record;
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << "Failed to create message template for ";
|
||||
oss << message_name;
|
||||
throw value_error(oss.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we aren't overwriting another module
|
||||
if (m_service_id_map.count(message_module->get_service_id()) == 1)
|
||||
{
|
||||
delete[] data;
|
||||
delete message_module;
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << "Message Module has already been loaded with Service ID ";
|
||||
oss << (uint16_t)message_module->get_service_id();
|
||||
throw value_error(oss.str(), value_error::OVERWRITES_LOOKUP);
|
||||
}
|
||||
|
||||
if (m_protocol_type_map.count(message_module->get_protocol_type()) == 1)
|
||||
{
|
||||
delete[] data;
|
||||
delete message_module;
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << "Message Module has already been loaded with Protocol Type ";
|
||||
oss << message_module->get_protocol_type();
|
||||
throw value_error(oss.str(), value_error::OVERWRITES_LOOKUP);
|
||||
}
|
||||
|
||||
// Add it to our maps
|
||||
m_modules.push_back(message_module);
|
||||
m_service_id_map.insert({ message_module->get_service_id(), message_module });
|
||||
m_protocol_type_map.insert({ message_module->get_protocol_type(), message_module });
|
||||
|
||||
delete[] data;
|
||||
return message_module;
|
||||
}
|
||||
|
||||
const MessageModule *MessageManager::get_module(uint8_t service_id) const
|
||||
{
|
||||
if (m_service_id_map.count(service_id) == 1)
|
||||
return m_service_id_map.at(service_id);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const MessageModule *MessageManager::get_module(const std::string &protocol_type) const
|
||||
{
|
||||
if (m_protocol_type_map.count(protocol_type) == 1)
|
||||
return m_protocol_type_map.at(protocol_type);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Message *MessageManager::create_message(uint8_t service_id, uint8_t message_type) const
|
||||
{
|
||||
auto *message_module = get_module(service_id);
|
||||
if (!message_module)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "No service exists with id: " << (uint16_t)service_id;
|
||||
throw value_error(oss.str(), value_error::DML_INVALID_SERVICE);
|
||||
}
|
||||
|
||||
return message_module->create_message(message_type);
|
||||
}
|
||||
|
||||
Message *MessageManager::create_message(uint8_t service_id, const std::string& message_name) const
|
||||
{
|
||||
auto *message_module = get_module(service_id);
|
||||
if (!message_module)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "No service exists with id: " << (uint16_t)service_id;
|
||||
throw value_error(oss.str(), value_error::DML_INVALID_SERVICE);
|
||||
}
|
||||
|
||||
return message_module->create_message(message_name);
|
||||
}
|
||||
|
||||
Message *MessageManager::create_message(const std::string& protocol_type, uint8_t message_type) const
|
||||
{
|
||||
auto *message_module = get_module(protocol_type);
|
||||
if (!message_module)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "No service exists with protocol type: " << protocol_type;
|
||||
throw value_error(oss.str(), value_error::DML_INVALID_PROTOCOL_TYPE);
|
||||
}
|
||||
|
||||
return message_module->create_message(message_type);
|
||||
}
|
||||
|
||||
Message *MessageManager::create_message(const std::string& protocol_type, const std::string& message_name) const
|
||||
{
|
||||
auto *message_module = get_module(protocol_type);
|
||||
if (!message_module)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "No service exists with protocol type: " << protocol_type;
|
||||
throw value_error(oss.str(), value_error::DML_INVALID_PROTOCOL_TYPE);
|
||||
}
|
||||
|
||||
return message_module->create_message(message_name);
|
||||
}
|
||||
|
||||
const Message *MessageManager::message_from_binary(std::istream& istream) const
|
||||
{
|
||||
// Read the message header
|
||||
MessageHeader header;
|
||||
header.read_from(istream);
|
||||
|
||||
// Get the message module that uses the specified service id
|
||||
auto *message_module = get_module(header.get_service_id());
|
||||
if (!message_module)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "No service exists with id: " << (uint16_t)header.get_service_id();
|
||||
throw value_error(oss.str(), value_error::DML_INVALID_SERVICE);
|
||||
}
|
||||
|
||||
// Get the message template for this message type
|
||||
auto *message_template = message_module->get_message_template(header.get_type());
|
||||
if (!message_template)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "No message exists with type: " << (uint16_t)header.get_service_id();
|
||||
oss << "(service=" << message_module->get_protocol_type() << ")";
|
||||
throw value_error(oss.str(), value_error::DML_INVALID_MESSAGE_TYPE);
|
||||
}
|
||||
|
||||
// Make sure that the size specified is enough to read this message
|
||||
if (header.get_message_size() < message_template->get_record().get_size())
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "No message exists with type: " << (uint16_t)header.get_service_id();
|
||||
oss << "(service=" << message_module->get_protocol_type() << ")";
|
||||
throw value_error(oss.str(), value_error::DML_INVALID_MESSAGE_TYPE);
|
||||
}
|
||||
|
||||
// Create a new Message from the template
|
||||
auto *message = new Message(message_template);
|
||||
try
|
||||
{
|
||||
message->get_record()->read_from(istream);
|
||||
}
|
||||
catch (ki::dml::parse_error &e)
|
||||
{
|
||||
delete message;
|
||||
throw parse_error("Failed to read DML message payload.", parse_error::INVALID_MESSAGE_DATA);
|
||||
}
|
||||
return message;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,171 @@
|
|||
#include "ki/protocol/dml/MessageModule.h"
|
||||
#include "ki/protocol/exception.h"
|
||||
#include <sstream>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace dml
|
||||
{
|
||||
MessageModule::MessageModule(uint8_t service_id, std::string protocol_type)
|
||||
{
|
||||
m_service_id = service_id;
|
||||
m_protocol_type = protocol_type;
|
||||
m_protocol_description = "";
|
||||
m_last_message_type = 0;
|
||||
}
|
||||
|
||||
MessageModule::~MessageModule()
|
||||
{
|
||||
for (auto it = m_templates.begin();
|
||||
it != m_templates.end(); ++it)
|
||||
delete *it;
|
||||
m_message_type_map.clear();
|
||||
m_message_name_map.clear();
|
||||
}
|
||||
|
||||
uint8_t MessageModule::get_service_id() const
|
||||
{
|
||||
return m_service_id;
|
||||
}
|
||||
|
||||
void MessageModule::set_service_id(uint8_t service_id)
|
||||
{
|
||||
m_service_id = service_id;
|
||||
}
|
||||
|
||||
std::string MessageModule::get_protocol_type() const
|
||||
{
|
||||
return m_protocol_type;
|
||||
}
|
||||
|
||||
void MessageModule::set_protocol_type(std::string protocol_type)
|
||||
{
|
||||
m_protocol_type = protocol_type;
|
||||
}
|
||||
|
||||
std::string MessageModule::get_protocol_desription() const
|
||||
{
|
||||
return m_protocol_description;
|
||||
}
|
||||
|
||||
void MessageModule::set_protocol_description(std::string protocol_description)
|
||||
{
|
||||
m_protocol_description = protocol_description;
|
||||
}
|
||||
|
||||
const MessageTemplate *MessageModule::add_message_template(std::string name,
|
||||
ki::dml::Record *record, bool auto_sort)
|
||||
{
|
||||
if (!record)
|
||||
return nullptr;
|
||||
|
||||
// If the field exists, get the name from the record rather than the XML
|
||||
auto *name_field = record->get_field<ki::dml::STR>("_MsgName");
|
||||
if (name_field)
|
||||
name = name_field->get_value();
|
||||
|
||||
// Do we already have a message template with this name?
|
||||
if (m_message_name_map.count(name) == 1)
|
||||
return m_message_name_map.at(name);
|
||||
|
||||
// Message type is based on the _MsgOrder field if it's present
|
||||
// Otherwise it's based on the alphabetical order of template names
|
||||
uint8_t message_type = 0;
|
||||
auto *order_field = record->get_field<ki::dml::UBYT>("_MsgOrder");
|
||||
if (order_field)
|
||||
{
|
||||
message_type = order_field->get_value();
|
||||
|
||||
// Don't allow message type to be 0
|
||||
if (message_type == 0)
|
||||
return nullptr;
|
||||
|
||||
// Do we already have a template with this type?
|
||||
if (m_message_type_map.count(message_type) == 1)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Create the message template and add it to our lookups
|
||||
auto *message_template = new MessageTemplate(name, message_type, m_service_id, record);
|
||||
m_templates.push_back(message_template);
|
||||
m_message_name_map.insert({ name, message_template });
|
||||
|
||||
// Is this module ordered?
|
||||
if (message_type != 0)
|
||||
m_message_type_map.insert({ message_type, message_template });
|
||||
else if (auto_sort)
|
||||
sort_lookup();
|
||||
|
||||
return message_template;
|
||||
}
|
||||
|
||||
const MessageTemplate *MessageModule::get_message_template(uint8_t type) const
|
||||
{
|
||||
if (m_message_type_map.count(type) == 1)
|
||||
return m_message_type_map.at(type);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const MessageTemplate *MessageModule::get_message_template(std::string name) const
|
||||
{
|
||||
if (m_message_name_map.count(name) == 1)
|
||||
return m_message_name_map.at(name);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void MessageModule::sort_lookup()
|
||||
{
|
||||
uint8_t message_type = 1;
|
||||
|
||||
// First, clear the message type map since we're going to be
|
||||
// moving everything around
|
||||
m_message_type_map.clear();
|
||||
|
||||
// Iterating over a map with std::string as the key
|
||||
// is guaranteed to be in alphabetical order
|
||||
for (auto it = m_message_name_map.begin();
|
||||
it != m_message_name_map.end(); ++it)
|
||||
{
|
||||
auto *message_template = it->second;
|
||||
message_template->set_type(message_type);
|
||||
m_message_type_map.insert({ message_type, message_template });
|
||||
message_type++;
|
||||
|
||||
// Make sure we haven't overflowed
|
||||
if (message_type == 0)
|
||||
throw value_error("Module has more than 254 messages.", value_error::EXCEEDS_LIMIT);
|
||||
}
|
||||
}
|
||||
|
||||
Message *MessageModule::create_message(uint8_t message_type) const
|
||||
{
|
||||
auto *message_template = get_message_template(message_type);
|
||||
if (!message_template)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "No message exists with type: " << message_type;
|
||||
oss << "(service=" << m_protocol_type << ")";
|
||||
throw value_error(oss.str(), value_error::DML_INVALID_MESSAGE_TYPE);
|
||||
}
|
||||
|
||||
return message_template->create_message();
|
||||
}
|
||||
|
||||
Message *MessageModule::create_message(std::string message_name) const
|
||||
{
|
||||
auto *message_template = get_message_template(message_name);
|
||||
if (!message_template)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "No message exists with name: " << message_name;
|
||||
oss << "(service=" << m_protocol_type << ")";
|
||||
throw value_error(oss.str(), value_error::DML_INVALID_MESSAGE_NAME);
|
||||
}
|
||||
|
||||
return message_template->create_message();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
#include "ki/protocol/dml/MessageTemplate.h"
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace dml
|
||||
{
|
||||
MessageTemplate::MessageTemplate(std::string name, uint8_t type,
|
||||
uint8_t service_id, ki::dml::Record* record)
|
||||
{
|
||||
m_name = name;
|
||||
m_type = type;
|
||||
m_service_id = service_id;
|
||||
m_record = record;
|
||||
}
|
||||
|
||||
MessageTemplate::~MessageTemplate()
|
||||
{
|
||||
delete m_record;
|
||||
}
|
||||
|
||||
std::string MessageTemplate::get_name() const
|
||||
{
|
||||
return m_name;
|
||||
}
|
||||
|
||||
void MessageTemplate::set_name(std::string name)
|
||||
{
|
||||
m_name = name;
|
||||
}
|
||||
|
||||
uint8_t MessageTemplate::get_type() const
|
||||
{
|
||||
return m_type;
|
||||
}
|
||||
|
||||
void MessageTemplate::set_type(uint8_t type)
|
||||
{
|
||||
m_type = type;
|
||||
}
|
||||
|
||||
uint8_t MessageTemplate::get_service_id() const
|
||||
{
|
||||
return m_service_id;
|
||||
}
|
||||
|
||||
void MessageTemplate::set_service_id(uint8_t service_id)
|
||||
{
|
||||
m_service_id = service_id;
|
||||
}
|
||||
|
||||
std::string MessageTemplate::get_handler() const
|
||||
{
|
||||
const auto field = m_record->get_field<ki::dml::STR>("_MsgHandler");
|
||||
if (field)
|
||||
return field->get_value();
|
||||
return m_name;
|
||||
}
|
||||
|
||||
void MessageTemplate::set_handler(std::string handler)
|
||||
{
|
||||
m_record->add_field<ki::dml::STR>("_MsgHandler")->set_value(handler);
|
||||
}
|
||||
|
||||
uint8_t MessageTemplate::get_access_level() const
|
||||
{
|
||||
const auto field = m_record->get_field<ki::dml::UBYT>("_MsgAccessLvl");
|
||||
if (field)
|
||||
return field->get_value();
|
||||
return 0;
|
||||
}
|
||||
|
||||
void MessageTemplate::set_access_level(uint8_t access_level)
|
||||
{
|
||||
m_record->add_field<ki::dml::UBYT>("_MsgAccessLvl")->set_value(access_level);
|
||||
}
|
||||
|
||||
const ki::dml::Record& MessageTemplate::get_record() const
|
||||
{
|
||||
return *m_record;
|
||||
}
|
||||
|
||||
void MessageTemplate::set_record(ki::dml::Record* record)
|
||||
{
|
||||
m_record = record;
|
||||
}
|
||||
|
||||
Message *MessageTemplate::create_message() const
|
||||
{
|
||||
return new Message(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,173 @@
|
|||
#include "ki/protocol/net/ClientSession.h"
|
||||
#include "ki/protocol/control/SessionOffer.h"
|
||||
#include "ki/protocol/control/SessionAccept.h"
|
||||
#include "ki/protocol/control/ClientKeepAlive.h"
|
||||
#include "ki/protocol/control/ServerKeepAlive.h"
|
||||
#include "ki/protocol/exception.h"
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace net
|
||||
{
|
||||
ClientSession::ClientSession(const uint16_t id)
|
||||
: Session(id) {}
|
||||
|
||||
void ClientSession::send_keep_alive()
|
||||
{
|
||||
// Don't send a keep alive if we're waiting for a response
|
||||
if (m_waiting_for_keep_alive_response)
|
||||
return;
|
||||
m_waiting_for_keep_alive_response = true;
|
||||
|
||||
// Work out how many minutes have been since the establish time, and
|
||||
// how many milliseconds we are in to the current minute.
|
||||
const auto time_since_establish = std::chrono::steady_clock::now() - m_establish_time;
|
||||
const auto minutes = std::chrono::duration_cast<std::chrono::minutes>(time_since_establish);
|
||||
const auto milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
time_since_establish - minutes
|
||||
).count();
|
||||
|
||||
// Send a KEEP_ALIVE packet
|
||||
control::ClientKeepAlive keep_alive(m_id, milliseconds, minutes.count());
|
||||
send_packet(true, (uint8_t)control::Opcode::KEEP_ALIVE, keep_alive);
|
||||
m_last_sent_heartbeat_time = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
bool ClientSession::is_alive() const
|
||||
{
|
||||
// If the session isn't established yet, use the time of
|
||||
// creation to decide whether this session is alive.
|
||||
if (!m_established)
|
||||
return std::chrono::duration_cast<std::chrono::seconds>(
|
||||
std::chrono::steady_clock::now() - m_creation_time
|
||||
).count() <= (KI_CONNECTION_TIMEOUT * 2);
|
||||
|
||||
// Otherwise, use the last time we received a heartbeat.
|
||||
return std::chrono::duration_cast<std::chrono::seconds>(
|
||||
std::chrono::steady_clock::now() - m_last_received_heartbeat_time
|
||||
).count() <= (KI_SERVER_HEARTBEAT * 2);
|
||||
}
|
||||
|
||||
void ClientSession::on_connected()
|
||||
{
|
||||
m_connection_time = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
void ClientSession::on_control_message(const PacketHeader& header)
|
||||
{
|
||||
switch ((control::Opcode)header.get_opcode())
|
||||
{
|
||||
case control::Opcode::SESSION_OFFER:
|
||||
on_session_offer();
|
||||
break;
|
||||
|
||||
case control::Opcode::KEEP_ALIVE:
|
||||
on_keep_alive();
|
||||
break;
|
||||
|
||||
case control::Opcode::KEEP_ALIVE_RSP:
|
||||
on_keep_alive_response();
|
||||
break;
|
||||
|
||||
default:
|
||||
close(SessionCloseErrorCode::UNHANDLED_CONTROL_MESSAGE);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void ClientSession::on_session_offer()
|
||||
{
|
||||
// Read the payload data into a structure
|
||||
control::SessionOffer offer;
|
||||
try
|
||||
{
|
||||
offer = read_data<control::SessionOffer>();
|
||||
}
|
||||
catch (parse_error &e)
|
||||
{
|
||||
// The SESSION_OFFER wasn't valid...
|
||||
// Close the session
|
||||
close(SessionCloseErrorCode::INVALID_MESSAGE);
|
||||
return;
|
||||
}
|
||||
|
||||
// Should this session have already timed out?
|
||||
if (std::chrono::duration_cast<std::chrono::seconds>(
|
||||
std::chrono::steady_clock::now() - m_connection_time
|
||||
).count() > KI_CONNECTION_TIMEOUT)
|
||||
{
|
||||
close(SessionCloseErrorCode::SESSION_OFFER_TIMED_OUT);
|
||||
return;
|
||||
}
|
||||
|
||||
// Work out the current timestamp and how many milliseconds
|
||||
// have elapsed in the current second.
|
||||
auto now = std::chrono::system_clock::now();
|
||||
const auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(
|
||||
now.time_since_epoch()
|
||||
).count();
|
||||
const auto milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
now.time_since_epoch()
|
||||
).count() - (timestamp * 1000);
|
||||
|
||||
// Accept the session
|
||||
m_id = offer.get_session_id();
|
||||
control::SessionAccept accept(m_id, timestamp, milliseconds);
|
||||
send_packet(true, (uint8_t)control::Opcode::SESSION_ACCEPT, accept);
|
||||
|
||||
// The session is successfully established
|
||||
m_established = true;
|
||||
m_establish_time = std::chrono::steady_clock::now();
|
||||
m_last_received_heartbeat_time = m_establish_time;
|
||||
on_established();
|
||||
}
|
||||
|
||||
void ClientSession::on_keep_alive()
|
||||
{
|
||||
// Read the payload data into a structure
|
||||
control::ServerKeepAlive keep_alive;
|
||||
try
|
||||
{
|
||||
keep_alive = read_data<control::ServerKeepAlive>();
|
||||
}
|
||||
catch (parse_error &e)
|
||||
{
|
||||
// The KEEP_ALIVE wasn't valid...
|
||||
// Close the session
|
||||
close(SessionCloseErrorCode::INVALID_MESSAGE);
|
||||
return;
|
||||
}
|
||||
|
||||
// Send the response
|
||||
m_last_received_heartbeat_time = std::chrono::steady_clock::now();
|
||||
send_packet(true, (uint8_t)control::Opcode::KEEP_ALIVE_RSP, keep_alive);
|
||||
}
|
||||
|
||||
void ClientSession::on_keep_alive_response()
|
||||
{
|
||||
// Read the payload data into a structure
|
||||
try
|
||||
{
|
||||
// We don't actually need the data inside, but
|
||||
// read it to check if the structure is right.
|
||||
read_data<control::ClientKeepAlive>();
|
||||
}
|
||||
catch (parse_error &e)
|
||||
{
|
||||
// The KEEP_ALIVE_RSP wasn't valid...
|
||||
// Close the session
|
||||
close(SessionCloseErrorCode::INVALID_MESSAGE);
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate latency and allow for KEEP_ALIVE packets to be sent again
|
||||
m_latency = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
std::chrono::steady_clock::now() - m_last_sent_heartbeat_time
|
||||
).count();
|
||||
m_waiting_for_keep_alive_response = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,77 @@
|
|||
#include "ki/protocol/net/DMLSession.h"
|
||||
#include "ki/protocol/exception.h"
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace net
|
||||
{
|
||||
DMLSession::DMLSession(const uint16_t id, const dml::MessageManager& manager)
|
||||
: Session(id), m_manager(manager) {}
|
||||
|
||||
const dml::MessageManager& DMLSession::get_manager() const
|
||||
{
|
||||
return m_manager;
|
||||
}
|
||||
|
||||
void DMLSession::send_message(const dml::Message& message)
|
||||
{
|
||||
send_packet(false, 0, message);
|
||||
}
|
||||
|
||||
void DMLSession::on_application_message(const PacketHeader& header)
|
||||
{
|
||||
// Attempt to create a Message instance from the data in the stream
|
||||
auto error_code = InvalidDMLMessageErrorCode::NONE;
|
||||
const dml::Message *message = nullptr;
|
||||
try
|
||||
{
|
||||
message = m_manager.message_from_binary(m_data_stream);
|
||||
}
|
||||
catch (parse_error &e)
|
||||
{
|
||||
switch (e.get_error_code())
|
||||
{
|
||||
case parse_error::INVALID_HEADER_DATA:
|
||||
error_code = InvalidDMLMessageErrorCode::INVALID_HEADER_DATA;
|
||||
break;
|
||||
case parse_error::INSUFFICIENT_MESSAGE_DATA:
|
||||
case parse_error::INVALID_MESSAGE_DATA:
|
||||
error_code = InvalidDMLMessageErrorCode::INVALID_MESSAGE_DATA;
|
||||
break;
|
||||
default:
|
||||
error_code = InvalidDMLMessageErrorCode::UNKNOWN;
|
||||
}
|
||||
}
|
||||
catch (value_error &e)
|
||||
{
|
||||
switch (e.get_error_code())
|
||||
{
|
||||
case value_error::DML_INVALID_SERVICE:
|
||||
error_code = InvalidDMLMessageErrorCode::INVALID_SERVICE;
|
||||
break;
|
||||
case value_error::DML_INVALID_MESSAGE_TYPE:
|
||||
error_code = InvalidDMLMessageErrorCode::INVALID_MESSAGE_TYPE;
|
||||
break;
|
||||
default:
|
||||
error_code = InvalidDMLMessageErrorCode::UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
if (!message)
|
||||
{
|
||||
on_invalid_message(error_code);
|
||||
return;
|
||||
}
|
||||
|
||||
// Are we sufficiently authenticated to handle this message?
|
||||
if (get_access_level() >= message->get_access_level())
|
||||
on_message(message);
|
||||
else
|
||||
on_invalid_message(InvalidDMLMessageErrorCode::INSUFFICIENT_ACCESS);
|
||||
delete message;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,67 @@
|
|||
#include "ki/protocol/net/PacketHeader.h"
|
||||
#include "ki/protocol/exception.h"
|
||||
#include <sstream>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace net
|
||||
{
|
||||
PacketHeader::PacketHeader(const bool control, const uint8_t opcode)
|
||||
{
|
||||
m_control = control;
|
||||
m_opcode = opcode;
|
||||
}
|
||||
|
||||
bool PacketHeader::is_control() const
|
||||
{
|
||||
return m_control;
|
||||
}
|
||||
|
||||
void PacketHeader::set_control(const bool control)
|
||||
{
|
||||
m_control = control;
|
||||
}
|
||||
|
||||
uint8_t PacketHeader::get_opcode() const
|
||||
{
|
||||
return m_opcode;
|
||||
}
|
||||
|
||||
void PacketHeader::set_opcode(const uint8_t opcode)
|
||||
{
|
||||
m_opcode = opcode;
|
||||
}
|
||||
|
||||
void PacketHeader::write_to(std::ostream& ostream) const
|
||||
{
|
||||
ostream.put(m_control);
|
||||
ostream.put(m_opcode);
|
||||
ostream.put(0);
|
||||
ostream.put(0);
|
||||
}
|
||||
|
||||
void PacketHeader::read_from(std::istream& istream)
|
||||
{
|
||||
m_control = istream.get() >= 1;
|
||||
if (istream.fail())
|
||||
throw parse_error("Not enough data was available to read packet header. (m_control)",
|
||||
parse_error::INVALID_HEADER_DATA);
|
||||
m_opcode = istream.get();
|
||||
if (istream.fail())
|
||||
throw parse_error("Not enough data was available to read packet header. (m_opcode)",
|
||||
parse_error::INVALID_HEADER_DATA);
|
||||
istream.ignore(2);
|
||||
if (istream.eof())
|
||||
throw parse_error("Not enough data was available to read packet header. (ignored bytes)",
|
||||
parse_error::INVALID_HEADER_DATA);
|
||||
}
|
||||
|
||||
size_t PacketHeader::get_size() const
|
||||
{
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,171 @@
|
|||
#include "ki/protocol/net/ServerSession.h"
|
||||
#include "ki/protocol/control/SessionOffer.h"
|
||||
#include "ki/protocol/control/SessionAccept.h"
|
||||
#include "ki/protocol/control/ClientKeepAlive.h"
|
||||
#include "ki/protocol/control/ServerKeepAlive.h"
|
||||
#include "ki/protocol/exception.h"
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace net
|
||||
{
|
||||
ServerSession::ServerSession(const uint16_t id)
|
||||
: Session(id) {}
|
||||
|
||||
void ServerSession::send_keep_alive(const uint32_t milliseconds_since_startup)
|
||||
{
|
||||
// Don't send a keep alive if we're waiting for a response
|
||||
if (m_waiting_for_keep_alive_response)
|
||||
return;
|
||||
m_waiting_for_keep_alive_response = true;
|
||||
|
||||
// Send a KEEP_ALIVE packet
|
||||
const control::ServerKeepAlive keep_alive(milliseconds_since_startup);
|
||||
send_packet(true, (uint8_t)control::Opcode::KEEP_ALIVE, keep_alive);
|
||||
m_last_sent_heartbeat_time = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
bool ServerSession::is_alive() const
|
||||
{
|
||||
// If the session isn't established yet, use the time of
|
||||
// creation to decide whether this session is alive.
|
||||
if (!m_established)
|
||||
return std::chrono::duration_cast<std::chrono::seconds>(
|
||||
std::chrono::steady_clock::now() - m_creation_time
|
||||
).count() <= (KI_CONNECTION_TIMEOUT * 2);
|
||||
|
||||
// Otherwise, use the last time we received a heartbeat.
|
||||
return std::chrono::duration_cast<std::chrono::seconds>(
|
||||
std::chrono::steady_clock::now() - m_last_received_heartbeat_time
|
||||
).count() <= (KI_CLIENT_HEARTBEAT * 2);
|
||||
}
|
||||
|
||||
void ServerSession::on_connected()
|
||||
{
|
||||
m_connection_time = std::chrono::steady_clock::now();
|
||||
|
||||
// Work out the current timestamp and how many milliseconds
|
||||
// have elapsed in the current second.
|
||||
auto now = std::chrono::system_clock::now();
|
||||
const auto timestamp = std::chrono::duration_cast<std::chrono::seconds>(
|
||||
now.time_since_epoch()
|
||||
).count();
|
||||
const auto milliseconds = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
now.time_since_epoch()
|
||||
).count() - (timestamp * 1000);
|
||||
|
||||
// Send a SESSION_OFFER packet to the client
|
||||
const control::SessionOffer offer(m_id, timestamp, milliseconds);
|
||||
send_packet(true, (uint8_t)control::Opcode::SESSION_OFFER, offer);
|
||||
}
|
||||
|
||||
void ServerSession::on_control_message(const PacketHeader& header)
|
||||
{
|
||||
switch ((control::Opcode)header.get_opcode())
|
||||
{
|
||||
case control::Opcode::SESSION_ACCEPT:
|
||||
on_session_accept();
|
||||
break;
|
||||
|
||||
case control::Opcode::KEEP_ALIVE:
|
||||
on_keep_alive();
|
||||
break;
|
||||
|
||||
case control::Opcode::KEEP_ALIVE_RSP:
|
||||
on_keep_alive_response();
|
||||
break;
|
||||
|
||||
default:
|
||||
close(SessionCloseErrorCode::UNHANDLED_CONTROL_MESSAGE);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void ServerSession::on_session_accept()
|
||||
{
|
||||
// Read the payload data into a structure
|
||||
control::SessionAccept accept;
|
||||
try
|
||||
{
|
||||
accept = read_data<control::SessionAccept>();
|
||||
}
|
||||
catch (parse_error &e)
|
||||
{
|
||||
// The SESSION_ACCEPT wasn't valid...
|
||||
// Close the session
|
||||
close(SessionCloseErrorCode::INVALID_MESSAGE);
|
||||
return;
|
||||
}
|
||||
|
||||
// Should this session have already timed out?
|
||||
if (std::chrono::duration_cast<std::chrono::seconds>(
|
||||
std::chrono::steady_clock::now() - m_connection_time
|
||||
).count() > KI_CONNECTION_TIMEOUT)
|
||||
{
|
||||
close(SessionCloseErrorCode::SESSION_OFFER_TIMED_OUT);
|
||||
return;
|
||||
}
|
||||
|
||||
// Make sure they're accepting this session
|
||||
if (accept.get_session_id() != m_id)
|
||||
{
|
||||
close(SessionCloseErrorCode::INVALID_MESSAGE);
|
||||
return;
|
||||
}
|
||||
|
||||
// The session is successfully established
|
||||
m_established = true;
|
||||
m_establish_time = std::chrono::steady_clock::now();
|
||||
m_last_received_heartbeat_time = m_establish_time;
|
||||
on_established();
|
||||
}
|
||||
|
||||
void ServerSession::on_keep_alive()
|
||||
{
|
||||
// Read the payload data into a structure
|
||||
control::ClientKeepAlive keep_alive;
|
||||
try
|
||||
{
|
||||
keep_alive = read_data<control::ClientKeepAlive>();
|
||||
}
|
||||
catch (parse_error &e)
|
||||
{
|
||||
// The KEEP_ALIVE wasn't valid...
|
||||
// Close the session
|
||||
close(SessionCloseErrorCode::INVALID_MESSAGE);
|
||||
return;
|
||||
}
|
||||
|
||||
// Send the response
|
||||
m_last_received_heartbeat_time = std::chrono::steady_clock::now();
|
||||
send_packet(true, (uint8_t)control::Opcode::KEEP_ALIVE_RSP, keep_alive);
|
||||
}
|
||||
|
||||
void ServerSession::on_keep_alive_response()
|
||||
{
|
||||
// Read the payload data into a structure
|
||||
try
|
||||
{
|
||||
// We don't actually need the data inside, but
|
||||
// read it to check if the structure is right.
|
||||
read_data<control::ServerKeepAlive>();
|
||||
}
|
||||
catch (parse_error &e)
|
||||
{
|
||||
// The KEEP_ALIVE_RSP wasn't valid...
|
||||
// Close the session
|
||||
close(SessionCloseErrorCode::INVALID_MESSAGE);
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate latency and allow for KEEP_ALIVE packets to be sent again
|
||||
m_latency = std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
std::chrono::steady_clock::now() - m_last_sent_heartbeat_time
|
||||
).count();
|
||||
m_waiting_for_keep_alive_response = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,190 @@
|
|||
#include "ki/protocol/net/Session.h"
|
||||
#include "ki/protocol/exception.h"
|
||||
#include <cstring>
|
||||
|
||||
namespace ki
|
||||
{
|
||||
namespace protocol
|
||||
{
|
||||
namespace net
|
||||
{
|
||||
Session::Session(const uint16_t id)
|
||||
{
|
||||
m_id = id;
|
||||
m_established = false;
|
||||
m_access_level = 0;
|
||||
m_latency = 0;
|
||||
m_creation_time = std::chrono::steady_clock::now();
|
||||
m_waiting_for_keep_alive_response = false;
|
||||
|
||||
m_maximum_packet_size = KI_DEFAULT_MAXIMUM_RECEIVE_SIZE;
|
||||
m_receive_state = ReceiveState::WAITING_FOR_START_SIGNAL;
|
||||
m_start_signal = 0;
|
||||
m_incoming_packet_size = 0;
|
||||
m_shift = 0;
|
||||
}
|
||||
|
||||
uint16_t Session::get_maximum_packet_size() const
|
||||
{
|
||||
return m_maximum_packet_size;
|
||||
}
|
||||
|
||||
void Session::set_maximum_packet_size(const uint16_t maximum_packet_size)
|
||||
{
|
||||
m_maximum_packet_size = maximum_packet_size;
|
||||
}
|
||||
|
||||
uint16_t Session::get_id() const
|
||||
{
|
||||
return m_id;
|
||||
}
|
||||
|
||||
bool Session::is_established() const
|
||||
{
|
||||
return m_established;
|
||||
}
|
||||
|
||||
uint8_t Session::get_access_level() const
|
||||
{
|
||||
return m_access_level;
|
||||
}
|
||||
|
||||
void Session::set_access_level(const uint8_t access_level)
|
||||
{
|
||||
m_access_level = access_level;
|
||||
}
|
||||
|
||||
uint16_t Session::get_latency() const
|
||||
{
|
||||
return m_latency;
|
||||
}
|
||||
|
||||
void Session::send_packet(const bool is_control, const uint8_t opcode,
|
||||
const util::Serializable& data)
|
||||
{
|
||||
std::ostringstream ss;
|
||||
PacketHeader header(is_control, opcode);
|
||||
header.write_to(ss);
|
||||
data.write_to(ss);
|
||||
|
||||
const auto buffer = ss.str();
|
||||
send_data(buffer.c_str(), buffer.length());
|
||||
}
|
||||
|
||||
void Session::send_data(const char* data, const size_t size)
|
||||
{
|
||||
// Allocate the entire buffer
|
||||
char *packet_data = new char[size + 4];
|
||||
|
||||
// Add the frame header
|
||||
((uint16_t *)packet_data)[0] = KI_START_SIGNAL;
|
||||
((uint16_t *)packet_data)[1] = size;
|
||||
|
||||
// Copy the payload into the buffer and send it
|
||||
std::memcpy(&packet_data[4], data, size);
|
||||
send_packet_data(packet_data, size + 4);
|
||||
delete[] packet_data;
|
||||
}
|
||||
|
||||
void Session::process_data(const char *data, const size_t size)
|
||||
{
|
||||
size_t position = 0;
|
||||
while (position < size)
|
||||
{
|
||||
switch (m_receive_state)
|
||||
{
|
||||
case ReceiveState::WAITING_FOR_START_SIGNAL:
|
||||
m_start_signal |= ((uint8_t)data[position] << m_shift);
|
||||
if (m_shift == 0)
|
||||
m_shift = 8;
|
||||
else
|
||||
{
|
||||
// If the start signal isn't correct, we've either
|
||||
// gotten out of sync, or they are not framing packets
|
||||
// correctly.
|
||||
if (m_start_signal != KI_START_SIGNAL)
|
||||
{
|
||||
close(SessionCloseErrorCode::INVALID_FRAMING_START_SIGNAL);
|
||||
return;
|
||||
}
|
||||
|
||||
// Reset the shift and incoming packet size
|
||||
m_shift = 0;
|
||||
m_incoming_packet_size = 0;
|
||||
m_receive_state = ReceiveState::WAITING_FOR_LENGTH;
|
||||
}
|
||||
position++;
|
||||
break;
|
||||
|
||||
case ReceiveState::WAITING_FOR_LENGTH:
|
||||
m_incoming_packet_size |= ((uint8_t)data[position] << m_shift);
|
||||
if (m_shift == 0)
|
||||
m_shift = 8;
|
||||
else
|
||||
{
|
||||
// If the incoming packet is larger than we are accepting
|
||||
// stop processing data.
|
||||
if (m_incoming_packet_size > m_maximum_packet_size)
|
||||
{
|
||||
close(SessionCloseErrorCode::INVALID_FRAMING_SIZE_EXCEEDS_MAXIMUM);
|
||||
return;
|
||||
}
|
||||
|
||||
// Reset read and write positions
|
||||
m_data_stream.seekp(0, std::ios::beg);
|
||||
m_data_stream.seekg(0, std::ios::beg);
|
||||
m_receive_state = ReceiveState::WAITING_FOR_PACKET;
|
||||
}
|
||||
position++;
|
||||
break;
|
||||
|
||||
case ReceiveState::WAITING_FOR_PACKET:
|
||||
// Work out how much data we should read into our stream
|
||||
const size_t data_available = (size - position);
|
||||
const size_t read_size = (data_available >= m_incoming_packet_size) ?
|
||||
m_incoming_packet_size : data_available;
|
||||
|
||||
// Write the data to the data stream
|
||||
m_data_stream.write(&data[position], read_size);
|
||||
position += read_size;
|
||||
m_incoming_packet_size -= read_size;
|
||||
|
||||
// Have we received the entire packet?
|
||||
if (m_incoming_packet_size == 0)
|
||||
{
|
||||
on_packet_available();
|
||||
|
||||
// Reset the shift and start signal
|
||||
m_shift = 0;
|
||||
m_start_signal = 0;
|
||||
m_receive_state = ReceiveState::WAITING_FOR_START_SIGNAL;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Session::on_packet_available()
|
||||
{
|
||||
// Read the packet header
|
||||
PacketHeader header;
|
||||
try
|
||||
{
|
||||
header.read_from(m_data_stream);
|
||||
}
|
||||
catch (parse_error &e)
|
||||
{
|
||||
on_invalid_packet();
|
||||
return;
|
||||
}
|
||||
|
||||
// Hand off to the right handler based on
|
||||
// whether this is a control packet or not
|
||||
if (header.is_control())
|
||||
on_control_message(header);
|
||||
else
|
||||
on_application_message(header);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,188 @@
|
|||
#define CATCH_CONFIG_MAIN
|
||||
#include <catch.hpp>
|
||||
#include <fstream>
|
||||
|
||||
#include <ki/protocol/control/SessionOffer.h>
|
||||
#include <ki/protocol/control/SessionAccept.h>
|
||||
#include <ki/protocol/control/ClientKeepAlive.h>
|
||||
#include <ki/protocol/control/ServerKeepAlive.h>
|
||||
|
||||
using namespace ki::protocol;
|
||||
|
||||
TEST_CASE("Control Message Serialization", "[control]")
|
||||
{
|
||||
std::ostringstream oss;
|
||||
|
||||
SECTION("SessionOffer")
|
||||
{
|
||||
control::SessionOffer offer(0xABCD, 0xAABBCCDD, 0xAABBCCDD);
|
||||
offer.write_to(oss);
|
||||
|
||||
const char expected_bytes[] = {
|
||||
// Session ID
|
||||
'\xCD', '\xAB',
|
||||
|
||||
// Unknown
|
||||
'\x00', '\x00', '\x00', '\x00',
|
||||
|
||||
// Timestamp
|
||||
'\xDD', '\xCC', '\xBB', '\xAA',
|
||||
|
||||
// Milliseconds
|
||||
'\xDD', '\xCC', '\xBB', '\xAA'
|
||||
};
|
||||
REQUIRE(oss.str() == std::string(expected_bytes, sizeof(expected_bytes)));
|
||||
}
|
||||
|
||||
SECTION("SessionAccept")
|
||||
{
|
||||
control::SessionAccept accept(0xABCD, 0xAABBCCDD, 0xAABBCCDD);
|
||||
accept.write_to(oss);
|
||||
|
||||
const char expected_bytes[] = {
|
||||
// Unknown
|
||||
'\x00', '\x00',
|
||||
|
||||
// Unknown
|
||||
'\x00', '\x00', '\x00', '\x00',
|
||||
|
||||
// Timestamp
|
||||
'\xDD', '\xCC', '\xBB', '\xAA',
|
||||
|
||||
// Milliseconds
|
||||
'\xDD', '\xCC', '\xBB', '\xAA',
|
||||
|
||||
// Session ID
|
||||
'\xCD', '\xAB'
|
||||
};
|
||||
REQUIRE(oss.str() == std::string(expected_bytes, sizeof(expected_bytes)));
|
||||
}
|
||||
|
||||
SECTION("ClientKeepAlive")
|
||||
{
|
||||
control::ClientKeepAlive keep_alive(0xABCD, 0xABCD, 0xABCD);
|
||||
keep_alive.write_to(oss);
|
||||
|
||||
const char expected_bytes[] = {
|
||||
// Session ID
|
||||
'\xCD', '\xAB',
|
||||
|
||||
// Milliseconds
|
||||
'\xCD', '\xAB',
|
||||
|
||||
// Minutes
|
||||
'\xCD', '\xAB'
|
||||
};
|
||||
REQUIRE(oss.str() == std::string(expected_bytes, sizeof(expected_bytes)));
|
||||
}
|
||||
|
||||
SECTION("ServerKeepAlive")
|
||||
{
|
||||
control::ServerKeepAlive keep_alive(0xAABBCCDD);
|
||||
keep_alive.write_to(oss);
|
||||
|
||||
const char expected_bytes[] = {
|
||||
// Unknown
|
||||
'\x00', '\x00',
|
||||
|
||||
// Timestamp
|
||||
'\xDD', '\xCC', '\xBB', '\xAA'
|
||||
};
|
||||
REQUIRE(oss.str() == std::string(expected_bytes, sizeof(expected_bytes)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("Control Message Deserialization", "[control]")
|
||||
{
|
||||
SECTION("SessionOffer")
|
||||
{
|
||||
const char bytes[] = {
|
||||
// Session ID
|
||||
'\xCD', '\xAB',
|
||||
|
||||
// Unknown
|
||||
'\x00', '\x00', '\x00', '\x00',
|
||||
|
||||
// Timestamp
|
||||
'\xDD', '\xCC', '\xBB', '\xAA',
|
||||
|
||||
// Milliseconds
|
||||
'\xDD', '\xCC', '\xBB', '\xAA'
|
||||
};
|
||||
std::istringstream iss(std::string(bytes, sizeof(bytes)));
|
||||
|
||||
control::SessionOffer offer;
|
||||
offer.read_from(iss);
|
||||
|
||||
REQUIRE(offer.get_session_id() == 0xABCD);
|
||||
REQUIRE(offer.get_timestamp() == 0xAABBCCDD);
|
||||
REQUIRE(offer.get_milliseconds() == 0xAABBCCDD);
|
||||
}
|
||||
|
||||
SECTION("SessionAccept")
|
||||
{
|
||||
const char bytes[] = {
|
||||
// Unknown
|
||||
'\x00', '\x00',
|
||||
|
||||
// Unknown
|
||||
'\x00', '\x00', '\x00', '\x00',
|
||||
|
||||
// Timestamp
|
||||
'\xDD', '\xCC', '\xBB', '\xAA',
|
||||
|
||||
// Milliseconds
|
||||
'\xDD', '\xCC', '\xBB', '\xAA',
|
||||
|
||||
// Session ID
|
||||
'\xCD', '\xAB'
|
||||
};
|
||||
std::istringstream iss(std::string(bytes, sizeof(bytes)));
|
||||
|
||||
control::SessionAccept accept;
|
||||
accept.read_from(iss);
|
||||
|
||||
REQUIRE(accept.get_session_id() == 0xABCD);
|
||||
REQUIRE(accept.get_timestamp() == 0xAABBCCDD);
|
||||
REQUIRE(accept.get_milliseconds() == 0xAABBCCDD);
|
||||
}
|
||||
|
||||
SECTION("ClientKeepAlive")
|
||||
{
|
||||
const char bytes[] = {
|
||||
// Session ID
|
||||
'\xCD', '\xAB',
|
||||
|
||||
// Milliseconds
|
||||
'\xCD', '\xAB',
|
||||
|
||||
// Minutes
|
||||
'\xCD', '\xAB'
|
||||
};
|
||||
std::istringstream iss(std::string(bytes, sizeof(bytes)));
|
||||
|
||||
control::ClientKeepAlive keep_alive;
|
||||
keep_alive.read_from(iss);
|
||||
|
||||
REQUIRE(keep_alive.get_session_id() == 0xABCD);
|
||||
REQUIRE(keep_alive.get_milliseconds() == 0xABCD);
|
||||
REQUIRE(keep_alive.get_minutes() == 0xABCD);
|
||||
}
|
||||
|
||||
SECTION("ServerKeepAlive")
|
||||
{
|
||||
const char bytes[] = {
|
||||
// Unknown
|
||||
'\x00', '\x00',
|
||||
|
||||
// Timestamp
|
||||
'\xDD', '\xCC', '\xBB', '\xAA'
|
||||
};
|
||||
std::istringstream iss(std::string(bytes, sizeof(bytes)));
|
||||
|
||||
control::ServerKeepAlive keep_alive;
|
||||
keep_alive.read_from(iss);
|
||||
|
||||
REQUIRE(keep_alive.get_timestamp() == 0xAABBCCDD);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue