diff --git a/include/ki/protocol/net/Participant.h b/include/ki/protocol/net/Participant.h new file mode 100644 index 0000000..36fb876 --- /dev/null +++ b/include/ki/protocol/net/Participant.h @@ -0,0 +1,77 @@ +#pragma once +#include +#include + +#define KI_DEFAULT_MAXIMUM_RECEIVE_SIZE 0x2000 +#define KI_START_SIGNAL 0xF00D + +namespace ki +{ +namespace protocol +{ +namespace net +{ + enum class ReceiveState + { + // Waiting for the 0xF00D start signal. + WAITING_FOR_START_SIGNAL, + + // Waiting for the 2-byte length. + WAITING_FOR_LENGTH, + + // Waiting for the packet data. + WAITING_FOR_PACKET + }; + + enum class ParticipantType + { + SERVER, + CLIENT + }; + + /** + * This class implements the packet framing logic when + * sending and receiving data to/from an external source. + */ + class Participant + { + public: + Participant(ParticipantType type); + virtual ~Participant() = default; + + ParticipantType get_type() const; + void set_type(ParticipantType type); + + uint16_t get_maximum_packet_size() const; + void set_maximum_packet_size(uint16_t maximum_packet_size); + protected: + std::stringstream m_data_stream; + + /** + * Frames raw data into a Packet, and transmits it. + */ + void send_data(const char *data, size_t size); + + /** + * Process incoming raw data into Packets. + * Once a packet is read into the internal data + * stream, handle_packet_available is called. + */ + void process_data(const char *data, size_t size); + + virtual void close() = 0; + private: + ParticipantType m_type; + uint16_t m_maximum_packet_size; + + ReceiveState m_receive_state; + uint16_t m_start_signal; + uint16_t m_incoming_packet_size; + uint8_t m_shift; + + virtual void send_packet_data(const char *data, const size_t size) = 0; + virtual void on_packet_available() {}; + }; +} +} +} diff --git a/include/ki/protocol/net/Session.h b/include/ki/protocol/net/Session.h new file mode 100644 index 0000000..4b43479 --- /dev/null +++ b/include/ki/protocol/net/Session.h @@ -0,0 +1,88 @@ +#pragma once +#include "Participant.h" +#include "PacketHeader.h" +#include "ki/protocol/control/Opcode.h" +#include "../../util/Serializable.h" +#include +#include +#include + +namespace ki +{ +namespace protocol +{ +namespace net +{ + /** + * This class implements session logic on top of the + * low-level Participant class. + */ + class Session : public Participant + { + public: + Session(ParticipantType type, uint16_t id); + + uint16_t get_id() const; + bool is_established() const; + + uint8_t get_access_level() const; + void set_access_level(uint8_t access_level); + + uint16_t get_latency() const; + + bool is_alive() const; + protected: + template + void send_packet(const bool is_control, const control::Opcode opcode, + const DataT &data) + { + static_assert(std::is_base_of::value, + "DataT must inherit Serializable."); + + std::ostringstream ss; + PacketHeader header(is_control, (uint8_t)opcode); + header.write_to(ss); + data.write_to(ss); + + const auto buffer = ss.str(); + send_data(buffer.c_str(), buffer.length()); + } + + template + DataT read_data() + { + static_assert(std::is_base_of::value, + "DataT must inherit Serializable."); + + DataT data = DataT(); + data.read_from(m_data_stream); + return data; + } + + void on_connected(); + virtual void on_established() {}; + virtual void on_application_message(const PacketHeader &header) {}; + virtual void on_invalid_packet() {}; + private: + uint16_t m_id; + bool m_established; + uint8_t m_access_level; + uint16_t m_latency; + + std::chrono::steady_clock::time_point m_creation_time; + std::chrono::steady_clock::time_point m_establish_time; + std::chrono::steady_clock::time_point m_last_heartbeat; + + void on_packet_available() override final; + void on_control_message(const PacketHeader &header); + void on_server_hello(); + void on_client_hello(); + void on_ping(); + void on_ping_response(); + + void on_hello(uint16_t session_id, uint32_t timestamp, + uint16_t milliseconds); + }; +} +} +} diff --git a/src/protocol/CMakeLists.txt b/src/protocol/CMakeLists.txt index 2d6393a..e59aea3 100644 --- a/src/protocol/CMakeLists.txt +++ b/src/protocol/CMakeLists.txt @@ -1,6 +1,5 @@ target_sources(${PROJECT_NAME} PRIVATE - ${PROJECT_SOURCE_DIR}/src/protocol/Packet.cpp ${PROJECT_SOURCE_DIR}/src/protocol/control/ClientHello.cpp ${PROJECT_SOURCE_DIR}/src/protocol/control/ServerHello.cpp ${PROJECT_SOURCE_DIR}/src/protocol/control/Ping.cpp @@ -9,4 +8,7 @@ target_sources(${PROJECT_NAME} ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageManager.cpp ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageModule.cpp ${PROJECT_SOURCE_DIR}/src/protocol/dml/MessageTemplate.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/net/PacketHeader.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/net/Participant.cpp + ${PROJECT_SOURCE_DIR}/src/protocol/net/Session.cpp ) \ No newline at end of file diff --git a/src/protocol/net/Participant.cpp b/src/protocol/net/Participant.cpp new file mode 100644 index 0000000..d5db22b --- /dev/null +++ b/src/protocol/net/Participant.cpp @@ -0,0 +1,137 @@ +#include "ki/protocol/net/Participant.h" +#include "ki/protocol/exception.h" + +namespace ki +{ +namespace protocol +{ +namespace net +{ + Participant::Participant(const ParticipantType type) + { + m_type = type; + m_maximum_packet_size = KI_DEFAULT_MAXIMUM_RECEIVE_SIZE; + + m_receive_state = ReceiveState::WAITING_FOR_START_SIGNAL; + m_start_signal = 0; + m_incoming_packet_size = 0; + m_shift = 0; + } + + ParticipantType Participant::get_type() const + { + return m_type; + } + + void Participant::set_type(const ParticipantType type) + { + m_type = type; + } + + + uint16_t Participant::get_maximum_packet_size() const + { + return m_maximum_packet_size; + } + + void Participant::set_maximum_packet_size(const uint16_t maximum_packet_size) + { + m_maximum_packet_size = maximum_packet_size; + } + + void Participant::send_data(const char* data, const size_t size) + { + // Allocate the entire buffer + char *packet_data = new char[size + 4]; + + // Add the frame header + ((uint16_t *)packet_data)[0] = KI_START_SIGNAL; + ((uint16_t *)packet_data)[1] = size; + + // Copy the payload into the buffer and send it + memcpy(&packet_data[4], data, size); + send_packet_data(packet_data, size + 4); + delete[] packet_data; + } + + void Participant::process_data(const char *data, const size_t size) + { + size_t position = 0; + while (position < size) + { + switch (m_receive_state) + { + case ReceiveState::WAITING_FOR_START_SIGNAL: + m_start_signal |= ((uint8_t)data[position] << m_shift); + if (m_shift == 0) + m_shift = 8; + else + { + // If the start signal isn't correct, we've either + // gotten out of sync, or they are not framing packets + // correctly. + if (m_start_signal != KI_START_SIGNAL) + { + close(); + return; + } + + // Reset the shift and incoming packet size + m_shift = 0; + m_incoming_packet_size = 0; + m_receive_state = ReceiveState::WAITING_FOR_LENGTH; + } + position++; + break; + + case ReceiveState::WAITING_FOR_LENGTH: + m_incoming_packet_size |= ((uint8_t)data[position] << m_shift); + if (m_shift == 0) + m_shift = 8; + else + { + // If the incoming packet is larger than we are accepting + // stop processing data. + if (m_incoming_packet_size > m_maximum_packet_size) + { + close(); + return; + } + + // Reset read and write positions + m_data_stream.seekp(0, std::ios::beg); + m_data_stream.seekg(0, std::ios::beg); + m_receive_state = ReceiveState::WAITING_FOR_PACKET; + } + position++; + break; + + case ReceiveState::WAITING_FOR_PACKET: + // Work out how much data we should read into our stream + const size_t data_available = (size - position); + const size_t read_size = (data_available >= m_incoming_packet_size) ? + m_incoming_packet_size : data_available; + + // Write the data to the data stream + m_data_stream.write(&data[position], read_size); + position += read_size; + m_incoming_packet_size -= read_size; + + // Have we received the entire packet? + if (m_incoming_packet_size == 0) + { + on_packet_available(); + + // Reset the shift and start signal + m_shift = 0; + m_start_signal = 0; + m_receive_state = ReceiveState::WAITING_FOR_START_SIGNAL; + } + break; + } + } + } + +} +} +} diff --git a/src/protocol/net/Session.cpp b/src/protocol/net/Session.cpp new file mode 100644 index 0000000..9238fb4 --- /dev/null +++ b/src/protocol/net/Session.cpp @@ -0,0 +1,273 @@ +#include "ki/protocol/net/Session.h" +#include "ki/protocol/exception.h" +#include "ki/protocol/control/ServerHello.h" +#include "ki/protocol/control/ClientHello.h" +#include "ki/protocol/control/Ping.h" + +namespace ki +{ +namespace protocol +{ +namespace net +{ + Session::Session(const ParticipantType type, const uint16_t id) + : Participant(type) + { + m_id = id; + m_established = false; + m_access_level = 0; + m_latency = 0; + m_creation_time = std::chrono::steady_clock::now(); + } + + uint16_t Session::get_id() const + { + return m_id; + } + + bool Session::is_established() const + { + return m_established; + } + + uint8_t Session::get_access_level() const + { + return m_access_level; + } + + void Session::set_access_level(const uint8_t access_level) + { + m_access_level = access_level; + } + + uint16_t Session::get_latency() const + { + return m_latency; + } + + bool Session::is_alive() const + { + // If the session isn't established yet, use the time of + // creation to decide whether this session is alive. + if (!m_established) + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - m_creation_time + ).count() <= 3; + + // Otherwise, use the last time we received a heartbeat. + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - m_last_heartbeat + ).count() <= 10; + } + + void Session::on_connected() + { + // If this is the server-side of a Session + // we need to send SERVER_HELLO first. + if (get_type() == ParticipantType::SERVER) + { + // Work out the current timestamp and how many milliseconds + // have elapsed in the current second. + auto now = std::chrono::system_clock::now(); + const auto timestamp = std::chrono::duration_cast( + now.time_since_epoch() + ).count(); + const auto milliseconds = std::chrono::duration_cast( + now.time_since_epoch() + ).count() - (timestamp * 1000); + + // Send a SERVER_HELLO packet to the client + const control::ServerHello hello(m_id, timestamp, milliseconds); + send_packet( + true, control::Opcode::SERVER_HELLO, hello); + } + } + + void Session::on_packet_available() + { + // Read the packet header + PacketHeader header; + try + { + header.read_from(m_data_stream); + } + catch (parse_error &e) + { + on_invalid_packet(); + return; + } + + // Hand off to the right handler based on + // whether this is a control packet or not + if (header.is_control()) + on_control_message(header); + else + on_application_message(header); + } + + void Session::on_control_message(const PacketHeader& header) + { + switch ((control::Opcode)header.get_opcode()) + { + case (control::Opcode::SERVER_HELLO): + on_server_hello(); + break; + + case (control::Opcode::CLIENT_HELLO): + on_client_hello(); + break; + + case (control::Opcode::PING): + on_ping(); + break; + + case (control::Opcode::PING_RSP): + on_ping_response(); + break; + + default: + break; + } + } + + void Session::on_server_hello() + { + // If this is the server-side of a Session + // we can't handle a SERVER_HELLO + if (get_type() != ParticipantType::CLIENT) + { + close(); + return; + } + + // Read the payload data into a structure + try + { + // We've been given our id from the server now + const auto server_hello = read_data(); + m_id = server_hello.get_session_id(); + on_hello(m_id, + server_hello.get_timestamp(), + server_hello.get_milliseconds()); + + // Work out the current timestamp and how many milliseconds + // have elapsed in the current second. + auto now = std::chrono::system_clock::now(); + const auto timestamp = std::chrono::duration_cast( + now.time_since_epoch() + ).count(); + const auto milliseconds = std::chrono::duration_cast( + now.time_since_epoch() + ).count() - (timestamp * 1000); + + // Send a CLIENT_HELLO packet to the server + const control::ClientHello hello(m_id, timestamp, milliseconds); + send_packet( + true, control::Opcode::CLIENT_HELLO, hello); + } + catch (parse_error &e) + { + // The CLIENT_HELLO wasn't valid... + // Close the session + close(); + } + } + + void Session::on_client_hello() + { + // If this is the client-side of a Session + // we can't handle a CLIENT_HELLO + if (get_type() != ParticipantType::SERVER) + { + close(); + return; + } + + // Read the payload data into a structure + try + { + // The session is now established! + const auto client_hello = read_data(); + on_hello(client_hello.get_session_id(), + client_hello.get_timestamp(), + client_hello.get_milliseconds()); + } + catch (parse_error &e) + { + // The CLIENT_HELLO wasn't valid... + // Close the session + close(); + } + } + + void Session::on_ping() + { + // Read the payload data into a structure + try + { + const auto ping = read_data(); + if (get_type() == ParticipantType::SERVER) + { + // Calculate latency + const auto send_time = m_establish_time + + std::chrono::milliseconds(ping.get_milliseconds()) + + std::chrono::minutes(ping.get_minutes()); + m_latency = std::chrono::duration_cast( + std::chrono::steady_clock::now() - send_time + ).count(); + } + + // Send the response + send_packet( + true, control::Opcode::PING_RSP, ping); + } + catch (parse_error &e) + { + // The CLIENT_HELLO wasn't valid... + // Close the session + close(); + } + } + + void Session::on_ping_response() + { + // Read the payload data into a structure + try + { + const auto ping = read_data(); + } + catch (parse_error &e) + { + // The CLIENT_HELLO wasn't valid... + // Close the session + close(); + } + } + + void Session::on_hello(const uint16_t session_id, + const uint32_t timestamp, const uint16_t milliseconds) + { + // Make sure they're accepting this session + if (session_id != m_id) + { + close(); + return; + } + + // Calculate initial latency + const std::chrono::system_clock::time_point epoch; + const auto send_time = epoch + (std::chrono::seconds(timestamp) + + std::chrono::milliseconds(milliseconds)); + m_latency = std::chrono::duration_cast( + std::chrono::system_clock::now() - send_time + ).count(); + + // The session is successfully established + m_established = true; + m_establish_time = std::chrono::steady_clock::now(); + m_last_heartbeat = m_establish_time; + on_established(); + } +} +} +}