aboutsummaryrefslogtreecommitdiff
path: root/client/src
diff options
context:
space:
mode:
authoralpine <[email protected]>2020-06-24 13:05:48 +0200
committeralpine <[email protected]>2020-06-24 13:05:48 +0200
commit8f0130c8f74482a7d54f9bfb8763f4c6d705765c (patch)
treee9090817a498da9ae18a58adee92f4f8ba2db03b /client/src
parentChanged clang format style. (diff)
downloadloader-8f0130c8f74482a7d54f9bfb8763f4c6d705765c.tar.xz
loader-8f0130c8f74482a7d54f9bfb8763f4c6d705765c.zip
Added client version control.
Reverted back to google formatting.
Diffstat (limited to 'client/src')
-rw-r--r--client/src/client/client.cpp127
-rw-r--r--client/src/client/client.h125
-rw-r--r--client/src/client/packet.h100
-rw-r--r--client/src/main.cpp56
-rw-r--r--client/src/util/events.h31
-rw-r--r--client/src/util/io.cpp10
-rw-r--r--client/src/util/io.h6
7 files changed, 230 insertions, 225 deletions
diff --git a/client/src/client/client.cpp b/client/src/client/client.cpp
index 3dd55fc..30d3c9e 100644
--- a/client/src/client/client.cpp
+++ b/client/src/client/client.cpp
@@ -1,93 +1,92 @@
#include "../include.h"
#include "client.h"
-bool tcp::client::start(const std::string_view server_ip, const uint16_t port)
-{
- SSL_library_init();
+void tcp::client::start(const std::string_view server_ip, const uint16_t port) {
+ SSL_library_init();
- m_ssl_ctx = SSL_CTX_new(TLS_client_method());
+ m_ssl_ctx = SSL_CTX_new(TLS_client_method());
- m_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
- if(m_socket == -1) {
- io::logger->error("failed to create socket.");
- return false;
- }
+ m_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+ if (m_socket == -1) {
+ io::logger->error("failed to create socket.");
+ return;
+ }
- sockaddr_in server_addr;
+ sockaddr_in server_addr;
- server_addr.sin_family = AF_INET;
- server_addr.sin_addr.s_addr = inet_addr(server_ip.data());
- server_addr.sin_port = htons(port);
+ server_addr.sin_family = AF_INET;
+ server_addr.sin_addr.s_addr = inet_addr(server_ip.data());
+ server_addr.sin_port = htons(port);
- int ret =
- connect(m_socket, reinterpret_cast<sockaddr*>(&server_addr), sizeof(server_addr));
- if(ret < 0) {
- io::logger->error("failed to connect to server.");
- return false;
- }
+ int ret = connect(m_socket, reinterpret_cast<sockaddr*>(&server_addr),
+ sizeof(server_addr));
+ if (ret < 0) {
+ io::logger->error("failed to connect to server.");
+ return;
+ }
- m_server_ssl = SSL_new(m_ssl_ctx);
- SSL_set_fd(m_server_ssl, m_socket);
+ m_server_ssl = SSL_new(m_ssl_ctx);
+ SSL_set_fd(m_server_ssl, m_socket);
- ret = SSL_connect(m_server_ssl);
+ ret = SSL_connect(m_server_ssl);
- if(ret != 1) {
- ret = SSL_get_error(m_server_ssl, ret);
- io::logger->error("ssl connection failed, code {}", ret);
- return false;
- }
+ if (ret != 1) {
+ ret = SSL_get_error(m_server_ssl, ret);
+ io::logger->error("ssl connection failed, code {}", ret);
+ return;
+ }
- return true;
-}
+ m_active = true;
-int tcp::client::read_stream(std::vector<char>& out)
-{
- size_t size;
- read(&size, sizeof(size));
+ io::logger->info("connected.");
+}
- size = ntohl(size);
- out.resize(size);
+int tcp::client::read_stream(std::vector<char>& out) {
+ size_t size;
+ read(&size, sizeof(size));
- constexpr size_t chunk_size = 4096;
- size_t total = 0;
+ size = ntohl(size);
+ out.resize(size);
- while(size > 0) {
- auto to_read = std::min(size, chunk_size);
+ constexpr size_t chunk_size = 4096;
+ size_t total = 0;
- int ret = read(&out[total], to_read);
- if(ret <= 0) {
- break;
- }
+ while (size > 0) {
+ auto to_read = std::min(size, chunk_size);
- size -= ret;
- total += ret;
+ int ret = read(&out[total], to_read);
+ if (ret <= 0) {
+ break;
}
- return total;
-}
+ size -= ret;
+ total += ret;
+ }
-int tcp::client::stream(std::vector<char>& data)
-{
- auto size = data.size();
+ return total;
+}
- auto networked_size = htonl(size);
- write(&networked_size, sizeof(networked_size));
+int tcp::client::stream(std::vector<char>& data) {
+ auto size = data.size();
- // with 4kb chunk size, speed peaks at 90mb/s
- constexpr size_t chunk_size = 4096;
- size_t sent = 0;
+ auto networked_size = htonl(size);
+ write(&networked_size, sizeof(networked_size));
- while(size > 0) {
- auto to_send = std::min(size, chunk_size);
+ // with 4kb chunk size, speed peaks at 90mb/s
+ constexpr size_t chunk_size = 4096;
+ size_t sent = 0;
- int ret = write(&data[sent], to_send);
- if(ret <= 0) {
- break;
- }
+ while (size > 0) {
+ auto to_send = std::min(size, chunk_size);
- sent += ret;
- size -= ret;
+ int ret = write(&data[sent], to_send);
+ if (ret <= 0) {
+ break;
}
- return sent;
+ sent += ret;
+ size -= ret;
+ }
+
+ return sent;
}
diff --git a/client/src/client/client.h b/client/src/client/client.h
index 2a3eee2..34eee84 100644
--- a/client/src/client/client.h
+++ b/client/src/client/client.h
@@ -5,60 +5,71 @@
namespace tcp {
- enum client_state : uint8_t { idle = 0, active, standby };
-
- class client {
- int m_socket;
- std::atomic<uint8_t> m_state;
-
- SSL* m_server_ssl;
- SSL_CTX* m_ssl_ctx;
-
- public:
- static constexpr int version = 0;
- std::string session_id;
- event<packet_t&> receive_event;
-
- client() : m_socket{ -1 }, m_state{ 0 } {}
-
- bool start(const std::string_view server_ip, const uint16_t port);
-
- int write(const packet_t& packet)
- {
- if(!packet)
- return 0;
- return SSL_write(m_server_ssl, packet.message.data(), packet.message.size());
- }
-
- int write(void* data, size_t size) { return SSL_write(m_server_ssl, data, size); }
-
- int read(void* data, size_t size) { return SSL_read(m_server_ssl, data, size); }
-
- int read_stream(std::vector<char>& out);
- int stream(std::vector<char>& data);
-
- int get_socket() { return m_socket; }
- void set_state(const uint8_t state) { m_state = state; }
-
- operator bool() const { return m_state == client_state::active; }
-
- static void monitor(client& client)
- {
- while(!client)
- std::this_thread::sleep_for(std::chrono::milliseconds(100));
-
- std::array<char, message_len> buf;
- while(client) {
- int ret = client.read(&buf[0], buf.size());
- if(ret <= 0) {
- io::logger->error("connection lost.");
- break;
- }
- std::string msg(buf.data(), ret);
- packet_t packet(msg, packet_type::read);
-
- client.receive_event.call(packet);
- }
- }
- };
-} // namespace tcp
+struct version_t {
+ uint8_t major = 0;
+ uint8_t minor = 1;
+ uint8_t patch = 0;
+};
+
+class client {
+ int m_socket;
+ std::atomic<bool> m_active;
+
+ SSL* m_server_ssl;
+ SSL_CTX* m_ssl_ctx;
+
+ public:
+ std::string session_id;
+ event<packet_t&> receive_event;
+
+ client() : m_socket{-1}, m_active{false} {}
+
+ void start(const std::string_view server_ip, const uint16_t port);
+
+ int write(const packet_t& packet) {
+ if (!packet) return 0;
+ return SSL_write(m_server_ssl, packet.message.data(),
+ packet.message.size());
+ }
+
+ int write(void* data, size_t size) {
+ return SSL_write(m_server_ssl, data, size);
+ }
+
+ int read(void* data, size_t size) {
+ return SSL_read(m_server_ssl, data, size);
+ }
+
+ int read_stream(std::vector<char>& out);
+ int stream(std::vector<char>& data);
+
+ int get_socket() { return m_socket; }
+
+ operator bool() const { return m_active; }
+
+ void shutdown() {
+ close(m_socket);
+ SSL_shutdown(m_server_ssl);
+ SSL_free(m_server_ssl);
+
+ m_active = false;
+ }
+
+ static void monitor(client& client) {
+ while (!client) std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+ std::array<char, message_len> buf;
+ while (client) {
+ int ret = client.read(&buf[0], buf.size());
+ if (ret <= 0) {
+ io::logger->error("connection lost.");
+ break;
+ }
+ std::string msg(buf.data(), ret);
+ packet_t packet(msg, packet_type::read);
+
+ client.receive_event.call(packet);
+ }
+ }
+};
+} // namespace tcp
diff --git a/client/src/client/packet.h b/client/src/client/packet.h
index 05bdee2..df44041 100644
--- a/client/src/client/packet.h
+++ b/client/src/client/packet.h
@@ -1,55 +1,51 @@
#pragma once
namespace tcp {
- constexpr size_t session_id_len = 10;
- constexpr size_t message_len = 256 + session_id_len;
-
- enum packet_type : int { write = 0, read };
-
- struct packet_t {
- std::string message;
- char action;
- std::string session_id;
- int id;
-
- packet_t() {}
- packet_t(const std::string_view msg,
- const packet_type& type,
- std::string_view session = "")
- {
- if(type == read) {
- ++id;
-
- if(msg.size() < session_id_len) {
- io::logger->error("packet message invalid!");
- return;
- }
-
- session_id = msg.substr(0, session_id_len);
-
- action = msg[session_id_len];
- message = msg.substr(session_id_len);
- }
- else {
- session_id = session;
-
- message = fmt::format("{}{}", session_id, msg);
-
- if(msg.size() > message_len) {
- io::logger->error("packet message exceeds limit");
- message.clear();
- session_id.clear();
- return;
- }
- }
- }
-
- ~packet_t()
- {
- message.clear();
- session_id.clear();
- }
-
- operator bool() const { return !message.empty() && !session_id.empty(); }
- };
-}; // namespace tcp
+constexpr size_t session_id_len = 10;
+constexpr size_t message_len = 256 + session_id_len;
+
+enum packet_type : int { write = 0, read };
+
+struct packet_t {
+ std::string message;
+ char action;
+ std::string session_id;
+ int id;
+
+ packet_t() {}
+ packet_t(const std::string_view msg, const packet_type& type,
+ std::string_view session = "") {
+ if (type == read) {
+ ++id;
+
+ if (msg.size() < session_id_len) {
+ io::logger->error("packet message invalid!");
+ return;
+ }
+
+ session_id = msg.substr(0, session_id_len);
+
+ action = msg[session_id_len];
+ message = msg.substr(session_id_len);
+ } else {
+ session_id = session;
+
+ message = fmt::format("{}{}", session_id, msg);
+
+ if (msg.size() > message_len) {
+ io::logger->error("packet message exceeds limit");
+ message.clear();
+ session_id.clear();
+ return;
+ }
+ }
+ }
+
+ ~packet_t() {
+ message.clear();
+ session_id.clear();
+ }
+
+ operator bool() const { return !message.empty() && !session_id.empty(); }
+};
+}; // namespace tcp
diff --git a/client/src/main.cpp b/client/src/main.cpp
index 3cc3c77..fef9fca 100644
--- a/client/src/main.cpp
+++ b/client/src/main.cpp
@@ -2,40 +2,42 @@
#include "util/io.h"
#include "client/client.h"
-int main(int argc, char* argv[])
-{
- io::init();
+int main(int argc, char* argv[]) {
+ io::init();
- tcp::client client;
+ tcp::client client;
- std::thread t{ tcp::client::monitor, std::ref(client) };
- t.detach();
+ std::thread t{tcp::client::monitor, std::ref(client)};
+ t.detach();
- if(client.start("127.0.0.1", 6666)) {
- io::logger->info("connected.");
- client.set_state(tcp::client_state::active);
- }
+ client.start("127.0.0.1", 6666);
- client.receive_event.add([&](tcp::packet_t& packet) {
- if(!packet)
- return;
+ client.receive_event.add([&](tcp::packet_t& packet) {
+ if (!packet) return;
- // first packet is the session id and current version
- if(packet.id == 1) {
- client.session_id = packet.session_id;
- }
+ // first packet is the session id and current version
+ if (packet.id == 1) {
+ client.session_id = packet.session_id;
+ tcp::version_t v;
+ auto version = fmt::format("{}.{}.{}", v.major, v.minor, v.patch);
+ if(version != packet.message) {
+ io::logger->error("please update your client");
+ client.shutdown();
+ return;
+ }
+ }
- io::logger->info("{}:{}->{}", packet.id, packet.session_id, packet.message);
- });
+ io::logger->info("{}:{}->{}", packet.id, packet.session_id, packet.message);
+ });
- while(client) {
- std::string p;
- getline(std::cin, p);
+ while (client) {
+ std::string p;
+ getline(std::cin, p);
- int ret =
- client.write(tcp::packet_t(p, tcp::packet_type::write, client.session_id));
- if(ret <= 0) {
- break;
- }
+ int ret = client.write(
+ tcp::packet_t(p, tcp::packet_type::write, client.session_id));
+ if (ret <= 0) {
+ break;
}
+ }
}
diff --git a/client/src/util/events.h b/client/src/util/events.h
index 04ad251..b8d7781 100644
--- a/client/src/util/events.h
+++ b/client/src/util/events.h
@@ -1,27 +1,24 @@
#pragma once
-template<typename... Args>
+template <typename... Args>
class event {
- using func_type = std::function<void(Args...)>;
+ using func_type = std::function<void(Args...)>;
- std::mutex event_lock;
- std::list<func_type> m_funcs;
+ std::mutex event_lock;
+ std::list<func_type> m_funcs;
-public:
- void add(const func_type& func)
- {
- std::lock_guard<std::mutex> lock(event_lock);
+ public:
+ void add(const func_type& func) {
+ std::lock_guard<std::mutex> lock(event_lock);
- m_funcs.push_back(std::move(func));
- }
+ m_funcs.push_back(std::move(func));
+ }
- void call(Args... params)
- {
- std::lock_guard<std::mutex> lock(event_lock);
+ void call(Args... params) {
+ std::lock_guard<std::mutex> lock(event_lock);
- for(auto& func : m_funcs) {
- if(func)
- func(std::forward<Args>(params)...);
- }
+ for (auto& func : m_funcs) {
+ if (func) func(std::forward<Args>(params)...);
}
+ }
}; \ No newline at end of file
diff --git a/client/src/util/io.cpp b/client/src/util/io.cpp
index 94f5575..06d2b9a 100644
--- a/client/src/util/io.cpp
+++ b/client/src/util/io.cpp
@@ -3,10 +3,10 @@
std::shared_ptr<spdlog::logger> io::logger;
-void io::init()
-{
- spdlog::sink_ptr sink = std::make_shared<spdlog::sinks::stdout_color_sink_mt>();
- sink->set_pattern("%^~>%$ %v");
+void io::init() {
+ spdlog::sink_ptr sink =
+ std::make_shared<spdlog::sinks::stdout_color_sink_mt>();
+ sink->set_pattern("%^~>%$ %v");
- logger = std::make_shared<spdlog::logger>("client", sink);
+ logger = std::make_shared<spdlog::logger>("client", sink);
}
diff --git a/client/src/util/io.h b/client/src/util/io.h
index d5ab3be..8eae321 100644
--- a/client/src/util/io.h
+++ b/client/src/util/io.h
@@ -1,7 +1,7 @@
#pragma once
namespace io {
- extern std::shared_ptr<spdlog::logger> logger;
+extern std::shared_ptr<spdlog::logger> logger;
- void init();
-}; // namespace io
+void init();
+}; // namespace io