diff options
| author | alpine <[email protected]> | 2020-06-24 13:05:48 +0200 |
|---|---|---|
| committer | alpine <[email protected]> | 2020-06-24 13:05:48 +0200 |
| commit | 8f0130c8f74482a7d54f9bfb8763f4c6d705765c (patch) | |
| tree | e9090817a498da9ae18a58adee92f4f8ba2db03b /client/src | |
| parent | Changed clang format style. (diff) | |
| download | loader-8f0130c8f74482a7d54f9bfb8763f4c6d705765c.tar.xz loader-8f0130c8f74482a7d54f9bfb8763f4c6d705765c.zip | |
Added client version control.
Reverted back to google formatting.
Diffstat (limited to 'client/src')
| -rw-r--r-- | client/src/client/client.cpp | 127 | ||||
| -rw-r--r-- | client/src/client/client.h | 125 | ||||
| -rw-r--r-- | client/src/client/packet.h | 100 | ||||
| -rw-r--r-- | client/src/main.cpp | 56 | ||||
| -rw-r--r-- | client/src/util/events.h | 31 | ||||
| -rw-r--r-- | client/src/util/io.cpp | 10 | ||||
| -rw-r--r-- | client/src/util/io.h | 6 |
7 files changed, 230 insertions, 225 deletions
diff --git a/client/src/client/client.cpp b/client/src/client/client.cpp index 3dd55fc..30d3c9e 100644 --- a/client/src/client/client.cpp +++ b/client/src/client/client.cpp @@ -1,93 +1,92 @@ #include "../include.h" #include "client.h" -bool tcp::client::start(const std::string_view server_ip, const uint16_t port) -{ - SSL_library_init(); +void tcp::client::start(const std::string_view server_ip, const uint16_t port) { + SSL_library_init(); - m_ssl_ctx = SSL_CTX_new(TLS_client_method()); + m_ssl_ctx = SSL_CTX_new(TLS_client_method()); - m_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - if(m_socket == -1) { - io::logger->error("failed to create socket."); - return false; - } + m_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (m_socket == -1) { + io::logger->error("failed to create socket."); + return; + } - sockaddr_in server_addr; + sockaddr_in server_addr; - server_addr.sin_family = AF_INET; - server_addr.sin_addr.s_addr = inet_addr(server_ip.data()); - server_addr.sin_port = htons(port); + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = inet_addr(server_ip.data()); + server_addr.sin_port = htons(port); - int ret = - connect(m_socket, reinterpret_cast<sockaddr*>(&server_addr), sizeof(server_addr)); - if(ret < 0) { - io::logger->error("failed to connect to server."); - return false; - } + int ret = connect(m_socket, reinterpret_cast<sockaddr*>(&server_addr), + sizeof(server_addr)); + if (ret < 0) { + io::logger->error("failed to connect to server."); + return; + } - m_server_ssl = SSL_new(m_ssl_ctx); - SSL_set_fd(m_server_ssl, m_socket); + m_server_ssl = SSL_new(m_ssl_ctx); + SSL_set_fd(m_server_ssl, m_socket); - ret = SSL_connect(m_server_ssl); + ret = SSL_connect(m_server_ssl); - if(ret != 1) { - ret = SSL_get_error(m_server_ssl, ret); - io::logger->error("ssl connection failed, code {}", ret); - return false; - } + if (ret != 1) { + ret = SSL_get_error(m_server_ssl, ret); + io::logger->error("ssl connection failed, code {}", ret); + return; + } - return true; -} + m_active = true; -int tcp::client::read_stream(std::vector<char>& out) -{ - size_t size; - read(&size, sizeof(size)); + io::logger->info("connected."); +} - size = ntohl(size); - out.resize(size); +int tcp::client::read_stream(std::vector<char>& out) { + size_t size; + read(&size, sizeof(size)); - constexpr size_t chunk_size = 4096; - size_t total = 0; + size = ntohl(size); + out.resize(size); - while(size > 0) { - auto to_read = std::min(size, chunk_size); + constexpr size_t chunk_size = 4096; + size_t total = 0; - int ret = read(&out[total], to_read); - if(ret <= 0) { - break; - } + while (size > 0) { + auto to_read = std::min(size, chunk_size); - size -= ret; - total += ret; + int ret = read(&out[total], to_read); + if (ret <= 0) { + break; } - return total; -} + size -= ret; + total += ret; + } -int tcp::client::stream(std::vector<char>& data) -{ - auto size = data.size(); + return total; +} - auto networked_size = htonl(size); - write(&networked_size, sizeof(networked_size)); +int tcp::client::stream(std::vector<char>& data) { + auto size = data.size(); - // with 4kb chunk size, speed peaks at 90mb/s - constexpr size_t chunk_size = 4096; - size_t sent = 0; + auto networked_size = htonl(size); + write(&networked_size, sizeof(networked_size)); - while(size > 0) { - auto to_send = std::min(size, chunk_size); + // with 4kb chunk size, speed peaks at 90mb/s + constexpr size_t chunk_size = 4096; + size_t sent = 0; - int ret = write(&data[sent], to_send); - if(ret <= 0) { - break; - } + while (size > 0) { + auto to_send = std::min(size, chunk_size); - sent += ret; - size -= ret; + int ret = write(&data[sent], to_send); + if (ret <= 0) { + break; } - return sent; + sent += ret; + size -= ret; + } + + return sent; } diff --git a/client/src/client/client.h b/client/src/client/client.h index 2a3eee2..34eee84 100644 --- a/client/src/client/client.h +++ b/client/src/client/client.h @@ -5,60 +5,71 @@ namespace tcp { - enum client_state : uint8_t { idle = 0, active, standby }; - - class client { - int m_socket; - std::atomic<uint8_t> m_state; - - SSL* m_server_ssl; - SSL_CTX* m_ssl_ctx; - - public: - static constexpr int version = 0; - std::string session_id; - event<packet_t&> receive_event; - - client() : m_socket{ -1 }, m_state{ 0 } {} - - bool start(const std::string_view server_ip, const uint16_t port); - - int write(const packet_t& packet) - { - if(!packet) - return 0; - return SSL_write(m_server_ssl, packet.message.data(), packet.message.size()); - } - - int write(void* data, size_t size) { return SSL_write(m_server_ssl, data, size); } - - int read(void* data, size_t size) { return SSL_read(m_server_ssl, data, size); } - - int read_stream(std::vector<char>& out); - int stream(std::vector<char>& data); - - int get_socket() { return m_socket; } - void set_state(const uint8_t state) { m_state = state; } - - operator bool() const { return m_state == client_state::active; } - - static void monitor(client& client) - { - while(!client) - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - std::array<char, message_len> buf; - while(client) { - int ret = client.read(&buf[0], buf.size()); - if(ret <= 0) { - io::logger->error("connection lost."); - break; - } - std::string msg(buf.data(), ret); - packet_t packet(msg, packet_type::read); - - client.receive_event.call(packet); - } - } - }; -} // namespace tcp +struct version_t { + uint8_t major = 0; + uint8_t minor = 1; + uint8_t patch = 0; +}; + +class client { + int m_socket; + std::atomic<bool> m_active; + + SSL* m_server_ssl; + SSL_CTX* m_ssl_ctx; + + public: + std::string session_id; + event<packet_t&> receive_event; + + client() : m_socket{-1}, m_active{false} {} + + void start(const std::string_view server_ip, const uint16_t port); + + int write(const packet_t& packet) { + if (!packet) return 0; + return SSL_write(m_server_ssl, packet.message.data(), + packet.message.size()); + } + + int write(void* data, size_t size) { + return SSL_write(m_server_ssl, data, size); + } + + int read(void* data, size_t size) { + return SSL_read(m_server_ssl, data, size); + } + + int read_stream(std::vector<char>& out); + int stream(std::vector<char>& data); + + int get_socket() { return m_socket; } + + operator bool() const { return m_active; } + + void shutdown() { + close(m_socket); + SSL_shutdown(m_server_ssl); + SSL_free(m_server_ssl); + + m_active = false; + } + + static void monitor(client& client) { + while (!client) std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + std::array<char, message_len> buf; + while (client) { + int ret = client.read(&buf[0], buf.size()); + if (ret <= 0) { + io::logger->error("connection lost."); + break; + } + std::string msg(buf.data(), ret); + packet_t packet(msg, packet_type::read); + + client.receive_event.call(packet); + } + } +}; +} // namespace tcp diff --git a/client/src/client/packet.h b/client/src/client/packet.h index 05bdee2..df44041 100644 --- a/client/src/client/packet.h +++ b/client/src/client/packet.h @@ -1,55 +1,51 @@ #pragma once namespace tcp { - constexpr size_t session_id_len = 10; - constexpr size_t message_len = 256 + session_id_len; - - enum packet_type : int { write = 0, read }; - - struct packet_t { - std::string message; - char action; - std::string session_id; - int id; - - packet_t() {} - packet_t(const std::string_view msg, - const packet_type& type, - std::string_view session = "") - { - if(type == read) { - ++id; - - if(msg.size() < session_id_len) { - io::logger->error("packet message invalid!"); - return; - } - - session_id = msg.substr(0, session_id_len); - - action = msg[session_id_len]; - message = msg.substr(session_id_len); - } - else { - session_id = session; - - message = fmt::format("{}{}", session_id, msg); - - if(msg.size() > message_len) { - io::logger->error("packet message exceeds limit"); - message.clear(); - session_id.clear(); - return; - } - } - } - - ~packet_t() - { - message.clear(); - session_id.clear(); - } - - operator bool() const { return !message.empty() && !session_id.empty(); } - }; -}; // namespace tcp +constexpr size_t session_id_len = 10; +constexpr size_t message_len = 256 + session_id_len; + +enum packet_type : int { write = 0, read }; + +struct packet_t { + std::string message; + char action; + std::string session_id; + int id; + + packet_t() {} + packet_t(const std::string_view msg, const packet_type& type, + std::string_view session = "") { + if (type == read) { + ++id; + + if (msg.size() < session_id_len) { + io::logger->error("packet message invalid!"); + return; + } + + session_id = msg.substr(0, session_id_len); + + action = msg[session_id_len]; + message = msg.substr(session_id_len); + } else { + session_id = session; + + message = fmt::format("{}{}", session_id, msg); + + if (msg.size() > message_len) { + io::logger->error("packet message exceeds limit"); + message.clear(); + session_id.clear(); + return; + } + } + } + + ~packet_t() { + message.clear(); + session_id.clear(); + } + + operator bool() const { return !message.empty() && !session_id.empty(); } +}; +}; // namespace tcp diff --git a/client/src/main.cpp b/client/src/main.cpp index 3cc3c77..fef9fca 100644 --- a/client/src/main.cpp +++ b/client/src/main.cpp @@ -2,40 +2,42 @@ #include "util/io.h" #include "client/client.h" -int main(int argc, char* argv[]) -{ - io::init(); +int main(int argc, char* argv[]) { + io::init(); - tcp::client client; + tcp::client client; - std::thread t{ tcp::client::monitor, std::ref(client) }; - t.detach(); + std::thread t{tcp::client::monitor, std::ref(client)}; + t.detach(); - if(client.start("127.0.0.1", 6666)) { - io::logger->info("connected."); - client.set_state(tcp::client_state::active); - } + client.start("127.0.0.1", 6666); - client.receive_event.add([&](tcp::packet_t& packet) { - if(!packet) - return; + client.receive_event.add([&](tcp::packet_t& packet) { + if (!packet) return; - // first packet is the session id and current version - if(packet.id == 1) { - client.session_id = packet.session_id; - } + // first packet is the session id and current version + if (packet.id == 1) { + client.session_id = packet.session_id; + tcp::version_t v; + auto version = fmt::format("{}.{}.{}", v.major, v.minor, v.patch); + if(version != packet.message) { + io::logger->error("please update your client"); + client.shutdown(); + return; + } + } - io::logger->info("{}:{}->{}", packet.id, packet.session_id, packet.message); - }); + io::logger->info("{}:{}->{}", packet.id, packet.session_id, packet.message); + }); - while(client) { - std::string p; - getline(std::cin, p); + while (client) { + std::string p; + getline(std::cin, p); - int ret = - client.write(tcp::packet_t(p, tcp::packet_type::write, client.session_id)); - if(ret <= 0) { - break; - } + int ret = client.write( + tcp::packet_t(p, tcp::packet_type::write, client.session_id)); + if (ret <= 0) { + break; } + } } diff --git a/client/src/util/events.h b/client/src/util/events.h index 04ad251..b8d7781 100644 --- a/client/src/util/events.h +++ b/client/src/util/events.h @@ -1,27 +1,24 @@ #pragma once -template<typename... Args> +template <typename... Args> class event { - using func_type = std::function<void(Args...)>; + using func_type = std::function<void(Args...)>; - std::mutex event_lock; - std::list<func_type> m_funcs; + std::mutex event_lock; + std::list<func_type> m_funcs; -public: - void add(const func_type& func) - { - std::lock_guard<std::mutex> lock(event_lock); + public: + void add(const func_type& func) { + std::lock_guard<std::mutex> lock(event_lock); - m_funcs.push_back(std::move(func)); - } + m_funcs.push_back(std::move(func)); + } - void call(Args... params) - { - std::lock_guard<std::mutex> lock(event_lock); + void call(Args... params) { + std::lock_guard<std::mutex> lock(event_lock); - for(auto& func : m_funcs) { - if(func) - func(std::forward<Args>(params)...); - } + for (auto& func : m_funcs) { + if (func) func(std::forward<Args>(params)...); } + } };
\ No newline at end of file diff --git a/client/src/util/io.cpp b/client/src/util/io.cpp index 94f5575..06d2b9a 100644 --- a/client/src/util/io.cpp +++ b/client/src/util/io.cpp @@ -3,10 +3,10 @@ std::shared_ptr<spdlog::logger> io::logger; -void io::init() -{ - spdlog::sink_ptr sink = std::make_shared<spdlog::sinks::stdout_color_sink_mt>(); - sink->set_pattern("%^~>%$ %v"); +void io::init() { + spdlog::sink_ptr sink = + std::make_shared<spdlog::sinks::stdout_color_sink_mt>(); + sink->set_pattern("%^~>%$ %v"); - logger = std::make_shared<spdlog::logger>("client", sink); + logger = std::make_shared<spdlog::logger>("client", sink); } diff --git a/client/src/util/io.h b/client/src/util/io.h index d5ab3be..8eae321 100644 --- a/client/src/util/io.h +++ b/client/src/util/io.h @@ -1,7 +1,7 @@ #pragma once namespace io { - extern std::shared_ptr<spdlog::logger> logger; +extern std::shared_ptr<spdlog::logger> logger; - void init(); -}; // namespace io +void init(); +}; // namespace io |