aboutsummaryrefslogtreecommitdiff
path: root/server
diff options
context:
space:
mode:
authoralpine <[email protected]>2020-06-13 22:27:52 +0200
committeralpine <[email protected]>2020-06-13 22:27:52 +0200
commitbad7b4f2d19f95b278fdcb3056be01cae9af1dbb (patch)
tree5bef91f910a2c03d74df9693a077ee33b2fe7886 /server
parentInitial commit (diff)
downloadloader-bad7b4f2d19f95b278fdcb3056be01cae9af1dbb.tar.xz
loader-bad7b4f2d19f95b278fdcb3056be01cae9af1dbb.zip
Client.
Message encryption. Packet handler. Disconnect event handler.
Diffstat (limited to 'server')
-rw-r--r--server/src/client/client.cpp27
-rw-r--r--server/src/client/client.h57
-rw-r--r--server/src/main.cpp34
-rw-r--r--server/src/server/packet.h42
-rw-r--r--server/src/server/server.cpp106
-rw-r--r--server/src/server/server.h46
-rw-r--r--server/src/server/ssl.h59
-rw-r--r--server/src/util/xor.cpp39
-rw-r--r--server/src/util/xor.h12
9 files changed, 403 insertions, 19 deletions
diff --git a/server/src/client/client.cpp b/server/src/client/client.cpp
new file mode 100644
index 0000000..4c458c8
--- /dev/null
+++ b/server/src/client/client.cpp
@@ -0,0 +1,27 @@
+#include "../include.h"
+#include "../util/io.h"
+#include "client.h"
+
+bool tcp::client::init_ssl(SSL_CTX *server_ctx) {
+ m_ssl = SSL_new(server_ctx);
+ if (!m_ssl) {
+ io::logger->error("failed to create ssl on client {}.", m_ip);
+ return false;
+ }
+
+ int ret = SSL_set_fd(m_ssl, m_socket);
+ if (ret <= 0) {
+ io::logger->error("failed to set descriptor on client {}.", m_ip);
+ return false;
+ }
+
+ ret = SSL_accept(m_ssl);
+ if (ret <= 0) {
+ int err = SSL_get_error(m_ssl, ret);
+ io::logger->error("client {} failed to accept ssl, return code {}", m_ip,
+ err);
+ return false;
+ }
+
+ return true;
+} \ No newline at end of file
diff --git a/server/src/client/client.h b/server/src/client/client.h
new file mode 100644
index 0000000..5242ae7
--- /dev/null
+++ b/server/src/client/client.h
@@ -0,0 +1,57 @@
+#pragma once
+#include "../server/packet.h"
+
+namespace tcp {
+constexpr uint8_t client_version = 0;
+
+class client {
+ int m_socket;
+ SSL *m_ssl;
+
+ time_t m_time;
+
+ std::string m_ip;
+ std::array<char, tcp::uid_len> m_uid;
+ public:
+ client() : m_socket{-1} {};
+ client(const int &socket, const std::string_view ip)
+ : m_socket{std::move(socket)}, m_ip{ip}, m_ssl{nullptr} {
+
+ }
+ ~client() = default;
+
+ bool init_ssl(SSL_CTX *server_ctx);
+
+ void cleanup() {
+ close(m_socket);
+ SSL_shutdown(m_ssl);
+ SSL_free(m_ssl);
+ }
+
+ int write(void *data, size_t size) {
+ return SSL_write(m_ssl, data, size);
+ }
+
+ int read(void *data, size_t size) {
+ return SSL_read(m_ssl, data, size);
+ }
+
+ bool set_uid(const std::string_view uid_str) {
+ const size_t uid_str_len = uid_str.size();
+ if (uid_str_len != tcp::uid_len) {
+ io::logger->error("packet uid len mismatch!");
+ return false;
+ }
+
+ for (size_t i = 0; i < uid_len; ++i) {
+ m_uid[i] = uid_str[i];
+ }
+
+ return true;
+ }
+
+ int &get_socket() { return m_socket; }
+ auto &get_ip() { return m_ip; }
+ auto &get_uid() { return m_uid; }
+};
+}; // namespace tcp \ No newline at end of file
diff --git a/server/src/main.cpp b/server/src/main.cpp
index f34c433..9db266f 100644
--- a/server/src/main.cpp
+++ b/server/src/main.cpp
@@ -2,13 +2,39 @@
#include "util/io.h"
#include "util/commands.h"
#include "server/server.h"
+#include "util/xor.h"
int main(int argc, char *argv[]) {
io::init(false);
- tcp::server server;
- server.start("6666");
- server.start("8981");
+ tcp::server server("6666");
- std::cin.get();
+ server.start();
+
+ server.connect_event.add([&](tcp::client &client) {
+ io::logger->info("{} connected.", client.get_ip());
+ });
+
+ server.disconnect_event.add([&](tcp::client &client) {
+ auto it = std::find_if(server.client_stack.begin(), server.client_stack.end(), [&](tcp::client &c) {
+ return client.get_socket() == client.get_socket();
+ });
+
+ server.client_stack.erase(it);
+ client.cleanup();
+
+ io::logger->info("{} disconnected.", client.get_ip());
+ });
+
+ server.receive_event.add([&](tcp::packet_t &packet, tcp::client &client) {
+ if (!packet) return;
+
+ io::logger->info("{} : {}", packet.uid.data(), packet.message);
+
+ tcp::packet_t resp("hello nigga", tcp::packet_type::write, "1234567890");
+ client.write(resp.message.data(), resp.message.size());
+ });
+
+ std::thread t{tcp::server::monitor, std::ref(server)};
+ t.join();
}
diff --git a/server/src/server/packet.h b/server/src/server/packet.h
index 02d90d1..3930243 100644
--- a/server/src/server/packet.h
+++ b/server/src/server/packet.h
@@ -1,10 +1,42 @@
#pragma once
+#include "../util/xor.h"
namespace tcp {
- constexpr uint8_t uid_len = 10;
+constexpr size_t uid_len = 10;
- struct packet_t {
- std::string message;
- std::array<char, uid_len> uid;
+enum packet_type : int { write = 0, read };
+
+struct packet_t {
+ std::string message;
+ char action;
+ std::string uid;
+
+ packet_t() {}
+ packet_t(const std::string msg, const packet_type &type, std::string userid = "") {
+ if (type == read) {
+ std::string decrypted{msg};
+ enc::decrypt_message(decrypted);
+
+ if (decrypted.size() < uid_len) {
+ io::logger->error("client packet message invalid!");
+ return;
+ }
+
+ uid = decrypted.substr(0, uid_len);
+
+ action = decrypted[uid_len];
+ message = decrypted.substr(uid_len);
+ } else {
+ uid = userid;
+
+ message = fmt::format("{}{}", uid, msg);
+
+ enc::encrypt_message(message);
+ }
+ }
+
+ operator bool() const {
+ return !message.empty() && !uid.empty();
}
-}
+};
+}; // namespace tcp
diff --git a/server/src/server/server.cpp b/server/src/server/server.cpp
index 2f684e1..b140665 100644
--- a/server/src/server/server.cpp
+++ b/server/src/server/server.cpp
@@ -2,8 +2,13 @@
#include "../util/io.h"
#include "server.h"
-bool tcp::server::start(const std::string_view port) {
- io::logger->info("starting server on port {}...", port.data());
+bool tcp::server::start() {
+ io::logger->info("starting server on port {}...", m_port.data());
+
+ ssl ctx("ssl/server.crt", "ssl/server.key", "ssl/rootCA.crt");
+ if (!ctx.init()) return false;
+
+ m_ctx = std::move(ctx.get_context());
m_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (m_socket < 0) {
@@ -19,7 +24,7 @@ bool tcp::server::start(const std::string_view port) {
hints.ai_protocol = IPPROTO_TCP;
hints.ai_flags = AI_PASSIVE;
- int ret = getaddrinfo(nullptr, port.data(), &hints, &addrinfo);
+ int ret = getaddrinfo(nullptr, m_port.data(), &hints, &addrinfo);
if (ret != 0) {
io::logger->critical("failed to get address info.");
close(m_socket);
@@ -37,16 +42,103 @@ bool tcp::server::start(const std::string_view port) {
ret = listen(m_socket, SOMAXCONN);
if (ret < 0) {
- io::logger->critical("failed to listen on port {}.", port.data());
+ io::logger->critical("failed to listen on port {}.", m_port.data());
close(m_socket);
return false;
}
- io::logger->info("listening on {}.", port.data());
+ io::logger->info("listening on {}.", m_port.data());
+
+ m_status = true;
return true;
}
+tcp::select_status tcp::server::peek() {
+ FD_ZERO(&m_server_set);
+ FD_SET(m_socket, &m_server_set);
+
+ int maxfd = m_socket;
+
+ for (auto &c : client_stack) {
+ const int s = c.get_socket();
+ FD_SET(s, &m_server_set);
+
+ maxfd = std::max(maxfd, s);
+ }
+
+ struct timeval tv;
+ tv.tv_sec = 1;
+ tv.tv_usec = 0;
+
+ const int ret = select(maxfd + 1, &m_server_set, nullptr, nullptr, &tv);
+ if (ret < 0) {
+ io::logger->error("select error : {}", strerror(errno));
+ return tcp::select_status::error;
+ }
+
+ if (ret == 0) {
+ return tcp::select_status::standby;
+ }
+
+ return tcp::select_status::ready;
+}
+
+void tcp::server::accept_client() {
+ if (!FD_ISSET(m_socket, &m_server_set)) return;
+
+ struct sockaddr_in addr;
+ socklen_t len = sizeof(addr);
+ const int client_socket =
+ accept(m_socket, reinterpret_cast<sockaddr *>(&addr), &len);
+
+ const auto ip = inet_ntoa(addr.sin_addr);
+ if (client_socket < 0) {
+ io::logger->warn("{} failed to accept.", ip);
+ close(client_socket);
+ } else {
+ client cli(client_socket, ip);
+ if (!cli.init_ssl(m_ctx)) {
+ cli.cleanup();
+ return;
+ }
+
+ // check for an existing connection
+ auto it = std::find_if(client_stack.begin(), client_stack.end(),
+ [&](client &c) { return c.get_ip() == ip; });
+ if (it != client_stack.end()) {
+ io::logger->info("{} is already connected, dropping...", ip);
+ cli.cleanup();
+ return;
+ }
+
+ connect_event.call(cli);
+ client_stack.emplace_back(cli);
+ }
+}
+
+void tcp::server::receive() {
+ std::array<char, 4096> buf;
+ for (auto &c : client_stack) {
+ const int socket = c.get_socket();
+
+ if (!FD_ISSET(socket, &m_server_set)) continue;
+
+ buf.fill(0);
+
+ const int read = c.read(&buf[0], buf.size());
+ if (read > 0) {
+ std::string msg(buf.data(), read);
+
+ tcp::packet_t packet(msg, tcp::packet_type::read);
+
+ receive_event.call(packet, c);
+ } else {
+ disconnect_event.call(c);
+ }
+ }
+}
+
void tcp::server::stop() {
- io::logger->info("stopping server on port {}.", m_port);
- close(m_socket);
+ io::logger->info("stopping server on port {}.", m_port);
+ close(m_socket);
} \ No newline at end of file
diff --git a/server/src/server/server.h b/server/src/server/server.h
index f848ae2..39b9580 100644
--- a/server/src/server/server.h
+++ b/server/src/server/server.h
@@ -1,12 +1,54 @@
#pragma once
+#include "../client/client.h"
+#include "../util/events.h"
+#include "ssl.h"
namespace tcp {
+constexpr uint8_t server_version = 0;
+
+enum select_status : int { error = 0, standby, ready };
+
class server {
int m_socket;
-
+ std::string_view m_port;
+
+ fd_set m_server_set;
+ SSL_CTX *m_ctx;
+
+ std::atomic<bool> m_status = false;
+
public:
+ std::vector<tcp::client> client_stack;
+
- bool start(const std::string_view port);
+ event<client &> connect_event;
+ event<packet_t &, client &> receive_event;
+ event<client &> disconnect_event;
+
+ server(const std::string_view port) : m_port{port} {}
+ ~server() = default;
+
+ bool start();
+ select_status peek();
+ void accept_client();
+ void receive();
void stop();
+
+ bool running() { return m_status; }
+
+ static void monitor(server &srv) {
+ while (srv.running()) {
+ auto ret = srv.peek();
+ if (ret == select_status::ready) {
+ srv.accept_client();
+ srv.receive();
+ } else if (ret == select_status::standby) {
+ // check for timeout
+ } else {
+ break;
+ }
+ }
+ }
};
+
}; // namespace tcp
diff --git a/server/src/server/ssl.h b/server/src/server/ssl.h
index 7b9637e..06826be 100644
--- a/server/src/server/ssl.h
+++ b/server/src/server/ssl.h
@@ -1 +1,58 @@
-#pragma once \ No newline at end of file
+#pragma once
+
+class ssl {
+ std::string_view m_cert, m_key, m_ca;
+ std::string m_passphrase;
+ SSL_CTX* m_ctx;
+
+ public:
+ ssl(const std::string_view cert, const std::string_view key,
+ const std::string_view ca = "")
+ : m_cert{cert}, m_key{key}, m_ca{ca}, m_ctx{nullptr} {
+ SSL_library_init();
+ }
+ ~ssl() = default;
+
+ bool init() {
+ m_ctx = SSL_CTX_new(TLS_server_method());
+ if (!m_ctx) {
+ io::logger->error("failed to create ssl context.");
+ return false;
+ }
+
+ int res =
+ SSL_CTX_use_certificate_file(m_ctx, m_cert.data(), SSL_FILETYPE_PEM);
+ if (res != 1) {
+ io::logger->error("failed to load certificate.");
+ return false;
+ }
+
+ if (!m_passphrase.empty())
+ SSL_CTX_set_default_passwd_cb_userdata(m_ctx, m_passphrase.data());
+
+ res = SSL_CTX_use_PrivateKey_file(m_ctx, m_key.data(), SSL_FILETYPE_PEM);
+ if (res != 1) {
+ io::logger->error("failed to load private key.");
+ return false;
+ }
+
+ res = SSL_CTX_check_private_key(m_ctx);
+ if (res != 1) {
+ io::logger->error("failed to verify private key.");
+ return false;
+ }
+
+ res = SSL_CTX_load_verify_locations(m_ctx, m_ca.data(), nullptr);
+ if (res != 1) {
+ io::logger->error("failed to load root ca.");
+ return false;
+ }
+
+ SSL_CTX_set_verify(m_ctx, SSL_VERIFY_PEER, 0);
+
+ return true;
+ }
+
+ void set_passphrase(const std::string_view phrase) { m_passphrase = phrase; }
+ auto &get_context() { return m_ctx; }
+};
diff --git a/server/src/util/xor.cpp b/server/src/util/xor.cpp
new file mode 100644
index 0000000..a00ecc9
--- /dev/null
+++ b/server/src/util/xor.cpp
@@ -0,0 +1,39 @@
+#include "../include.h"
+#include "xor.h"
+
+char enc::gen_key() {
+ std::random_device r;
+
+ std::default_random_engine e1(r());
+ std::uniform_real_distribution<> uniform_dist(0, 255);
+ return static_cast<char>(uniform_dist(e1));
+}
+
+void enc::encrypt_message(std::string &str) {
+ std::array<char, key_len> keys;
+ for (size_t i = 0; i < key_len; i++) {
+ keys[i] = gen_key();
+ str.insert(str.end(), keys[i]);
+ }
+
+ for (auto &key : keys) {
+ for (size_t i = 0; i < str.size() - key_len; i++) {
+ str[i] ^= key;
+ }
+ }
+}
+
+void enc::decrypt_message(std::string &str) {
+ if (str.size() <= key_len) return;
+
+ std::string keys = str.substr(0, key_len);
+ std::reverse(keys.begin(), keys.end());
+
+ for (auto &key : keys) {
+ for (size_t i = key_len; i < str.size(); i++) {
+ str[i] ^= key;
+ }
+ }
+
+ str.erase(str.begin(), str.begin() + key_len);
+} \ No newline at end of file
diff --git a/server/src/util/xor.h b/server/src/util/xor.h
new file mode 100644
index 0000000..7180945
--- /dev/null
+++ b/server/src/util/xor.h
@@ -0,0 +1,12 @@
+#pragma once
+
+namespace enc {
+constexpr size_t key_len = 50;
+
+char gen_key();
+
+void encrypt_message(std::string &str);
+
+void decrypt_message(std::string &str);
+
+} // namespace enc \ No newline at end of file