aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordetransport.cpp
blob: 65eaea47735dde93b2df7a7d43b5df20b50a1fbf (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Copyright Epic Games, Inc. All Rights Reserved.

#include "hordetransport.h"

#include <zencore/logging.h>
#include <zencore/trace.h>

ZEN_THIRD_PARTY_INCLUDES_START
#include <asio.hpp>
ZEN_THIRD_PARTY_INCLUDES_END

namespace zen::horde {

// --- AsyncTcpComputeTransport ---

// Private state of AsyncTcpComputeTransport (pimpl), keeping Asio types out of the header.
struct AsyncTcpComputeTransport::Impl
{
	// Non-owning reference to the io_context that runs every asynchronous operation
	// and posted completion of this transport; it must outlive the transport.
	asio::io_context& IoContext;

	// The TCP socket used for the connection; bound to IoContext at construction.
	asio::ip::tcp::socket Socket;

	explicit Impl(asio::io_context& Context) : IoContext{Context}, Socket{Context} {}
};

AsyncTcpComputeTransport::AsyncTcpComputeTransport(asio::io_context& IoContext)
: m_Impl(std::make_unique<Impl>(IoContext))
, m_Log(zen::logging::Get("horde.transport.async"))
{
}

// Shuts down and closes the socket if that has not already happened.
// NOTE(review): completion handlers started by AsyncConnect capture `this`; closing the
// socket aborts pending operations, but their handlers still run later on the io_context.
// Callers presumably must ensure no operations are outstanding at destruction — confirm.
AsyncTcpComputeTransport::~AsyncTcpComputeTransport()
{
	// Qualified call: inside a destructor this is exactly the call that would be dispatched anyway.
	AsyncTcpComputeTransport::Close();
}

void
AsyncTcpComputeTransport::AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler)
{
	ZEN_TRACE_CPU("AsyncTcpComputeTransport::AsyncConnect");

	asio::error_code Ec;

	const asio::ip::address Address = asio::ip::make_address(Info.GetConnectionAddress(), Ec);
	if (Ec)
	{
		ZEN_WARN("invalid address '{}': {}", Info.GetConnectionAddress(), Ec.message());
		m_HasErrors = true;
		asio::post(m_Impl->IoContext, [Handler = std::move(Handler), Ec] { Handler(Ec); });
		return;
	}

	const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort());

	// Copy the nonce so it survives past this scope into the async callback
	auto NonceBuf = std::make_shared<std::vector<uint8_t>>(Info.Nonce, Info.Nonce + NonceSize);

	m_Impl->Socket.async_connect(Endpoint, [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec) mutable {
		if (Ec)
		{
			ZEN_WARN("async connect failed: {}", Ec.message());
			m_HasErrors = true;
			Handler(Ec);
			return;
		}

		asio::error_code SetOptEc;
		m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), SetOptEc);

		// Send the 64-byte nonce as the first thing on the wire
		asio::async_write(m_Impl->Socket,
						  asio::buffer(*NonceBuf),
						  [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec, size_t /*BytesWritten*/) {
							  if (Ec)
							  {
								  ZEN_WARN("nonce write failed: {}", Ec.message());
								  m_HasErrors = true;
							  }
							  Handler(Ec);
						  });
	});
}

// Reports whether the transport can currently be used for I/O: the implementation
// exists, the socket is open, no error has been recorded, and Close() has not run.
bool
AsyncTcpComputeTransport::IsValid() const
{
	if (!m_Impl)
	{
		return false;
	}
	if (m_HasErrors || m_IsClosed)
	{
		return false;
	}
	return m_Impl->Socket.is_open();
}

void
AsyncTcpComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler)
{
	if (!IsValid())
	{
		asio::post(m_Impl->IoContext,
				   [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
		return;
	}

	asio::async_write(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler));
}

void
AsyncTcpComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler)
{
	if (!IsValid())
	{
		asio::post(m_Impl->IoContext,
				   [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); });
		return;
	}

	asio::async_read(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler));
}

void
AsyncTcpComputeTransport::Close()
{
	if (!m_IsClosed && m_Impl && m_Impl->Socket.is_open())
	{
		asio::error_code Ec;
		m_Impl->Socket.shutdown(asio::ip::tcp::socket::shutdown_both, Ec);
		m_Impl->Socket.close(Ec);
	}
	m_IsClosed = true;
}

}  // namespace zen::horde