diff options
Diffstat (limited to 'lib/sharder.ml')
| -rw-r--r-- | lib/sharder.ml | 161 |
1 files changed, 99 insertions, 62 deletions
diff --git a/lib/sharder.ml b/lib/sharder.ml index defcfe1..74870bd 100644 --- a/lib/sharder.ml +++ b/lib/sharder.ml @@ -3,25 +3,39 @@ open Core open Websocket_async exception Invalid_Payload +exception Failure_to_Establish_Heartbeat module Shard = struct - type t = { - mutable hb: unit Ivar.t option; - mutable seq: int; - mutable session: string option; - token: string; - shard: int * int; - write: Frame.t Pipe.Writer.t; - read: Frame.t Pipe.Reader.t; + type shard = { + hb: unit Ivar.t option; + seq: int; + session: string option; + pipe: Frame.t Pipe.Reader.t * Frame.t Pipe.Writer.t; ready: unit Ivar.t; + token: string; + url: string; + id: int * int; + } + + type 'a t = { + mutable shard: 'a; + mutable binds: ('a -> unit) list; } let identify_lock = Mutex.create () + let bind ~f t = + t.binds <- f :: t.binds + let parse (frame:[`Ok of Frame.t | `Eof]) = match frame with - | `Ok s -> Yojson.Basic.from_string s.content (* TODO Handler non-text frames *) - | `Eof -> raise Invalid_Payload (* TODO This needs to go into reconnect code, or stop using client_ez and handle frames manually *) + | `Ok s -> begin + let open Frame.Opcode in + match s.opcode with + | Text -> Some (Yojson.Basic.from_string s.content) + | _ -> None + end + | `Eof -> None let push_frame ?payload ~ev shard = print_endline @@ "Pushing frame. OP: " ^ Opcode.to_string @@ ev; @@ -33,33 +47,31 @@ module Shard = struct ("d", p); ] in - Pipe.write shard.write @@ Frame.create ~content () + let (_, write) = shard.pipe in + Pipe.write_if_open write @@ Frame.create ~content () >>| fun () -> shard let heartbeat shard = - let seq = match shard.seq with + let payload = match shard.seq with | 0 -> `Null | i -> `Int i in - let payload = `Assoc [ - ("op", `Int 1); - ("d", seq); - ] in push_frame ~payload ~ev:HEARTBEAT shard let dispatch ~payload shard = let module J = Yojson.Basic.Util in let seq = J.(member "s" payload |> to_int) in - shard.seq <- seq; let t = J.(member "t" payload |> to_string) in let data = J.member "d" payload in + let session = J.(member "session_id" data |> to_string_option) in if t = "READY" then begin Ivar.fill_if_empty shard.ready (); - let session = J.(member "session_id" data |> to_string) in - shard.session <- Some session end; - return shard + return { shard with + seq = seq; + session = session; + } let set_status ~status shard = let payload = match status with @@ -97,27 +109,30 @@ module Shard = struct Ivar.read shard.ready >>= fun _ -> push_frame ~payload ~ev:REQUEST_GUILD_MEMBERS shard - let initialize ~data shard = + let initialize ?data shard = let module J = Yojson.Basic.Util in let hb = match shard.hb with | None -> begin - let hb_interval = J.(member "heartbeat_interval" data |> to_int) in - let finished = Ivar.create () in - Clock.every' - ~continue_on_error:true - ~finished - (Core.Time.Span.create ~ms:hb_interval ()) - (fun () -> heartbeat shard >>= fun _ -> return ()); - finished + match data with + | Some data -> + let hb_interval = J.(member "heartbeat_interval" data |> to_int) in + let finished = Ivar.create () in + Clock.every' + ~continue_on_error:true + ~finished + (Core.Time.Span.create ~ms:hb_interval ()) + (fun () -> heartbeat shard >>= fun _ -> return ()); + finished + | None -> raise Failure_to_Establish_Heartbeat end | Some s -> s in - shard.hb <- Some hb; - Mutex.lock identify_lock; - let (cur, max) = shard.shard in + let shard = { shard with hb = Some hb; } in + let (cur, max) = shard.id in let shards = [`Int cur; `Int max] in match shard.session with - | None -> + | None -> begin + Mutex.lock identify_lock; let payload = `Assoc [ ("token", `String shard.token); ("properties", `Assoc [ @@ -130,6 +145,12 @@ module Shard = struct ("shard", `List shards); ] in push_frame ~payload ~ev:IDENTIFY shard + >>| fun s -> begin + Clock.after (Core.Time.Span.create ~sec:5 ()) + >>> (fun _ -> Mutex.unlock identify_lock); + s + end + end | Some s -> let payload = `Assoc [ ("token", `String shard.token); @@ -137,11 +158,6 @@ module Shard = struct ("seq", `Int shard.seq) ] in push_frame ~payload ~ev:RESUME shard - >>| fun s -> - Clock.after (Core.Time.Span.create ~sec:5 ()) - >>| (fun _ -> Mutex.unlock identify_lock) - |> ignore; - s let handle_frame ~f shard = let module J = Yojson.Basic.Util in @@ -151,15 +167,21 @@ module Shard = struct match op with | DISPATCH -> dispatch ~payload:f shard | HEARTBEAT -> heartbeat shard - | RECONNECT -> print_endline "OP 7"; return shard (* TODO reconnect *) - | INVALID_SESSION -> print_endline "OP 9"; return shard (* TODO invalid session *) + | INVALID_SESSION -> begin + if J.(member "d" f |> to_bool) then + initialize shard + else begin + initialize { shard with session = None; } + end + end + | RECONNECT -> initialize shard | HELLO -> initialize ~data:(J.member "d" f) shard | HEARTBEAT_ACK -> return shard | opcode -> print_endline @@ "Invalid Opcode: " ^ Opcode.to_string opcode; return shard - let create ~url ~shards ~token () = + let rec create ~url ~shards ~token () = let open Core in let uri = (url ^ "?v=6&encoding=json") |> Uri.of_string in let extra_headers = Http.Base.process_request_headers () in @@ -184,26 +206,36 @@ module Shard = struct uri >>> ignore; Ivar.read initialized >>| fun () -> - let rec ev_loop ~reader shard = - Pipe.read reader + let rec ev_loop t = + let (read, _) = t.shard.pipe in + Pipe.read read >>= fun frame -> - handle_frame ~f:(parse frame) shard - >>= fun shard -> - ev_loop ~reader shard + (match parse frame with + | Some f -> begin + handle_frame ~f t.shard + >>| fun shard -> + t.shard <- shard; + t + end + | None -> recreate t.shard) + >>= fun t -> + List.iter ~f:(fun f -> f t.shard) t.binds; + ev_loop t in let shard = { - read; - write; + pipe = (read, write); ready = Ivar.create (); hb = None; seq = 0; - shard = shards; + id = shards; session = None; - token = token; + token; + url; } in - ev_loop ~reader:read shard |> ignore; - shard + let t = { shard; binds = []; } in + ev_loop t >>> ignore; + t in match Unix.getaddrinfo host (string_of_int port) [] with | [] -> failwithf "DNS resolution failed for %s" host () @@ -220,10 +252,17 @@ module Shard = struct `TCP (h, p) in Conduit_async.V2.connect addr >>= tcp_fun + and recreate shard = + print_endline "Reconnecting..."; + (match shard.hb with + | Some hb -> Ivar.fill_if_empty hb () + | None -> () + ); + create ~url:(shard.url) ~shards:(shard.id) ~token:(shard.token) () end type t = { - shards: Shard.t list; + mutable shards: (Shard.shard Shard.t) list; } let start ?count token = @@ -246,21 +285,19 @@ let start ?count token = in gen_shards shard_list [] >>| fun shards -> - { - shards; - } + { shards; } let set_status sharder status = - Deferred.all @@ List.map ~f:(fun shard -> - Shard.set_status ~status shard + Deferred.all @@ List.map ~f:(fun t -> + Shard.set_status ~status t.shard ) sharder.shards let set_status_with sharder f = - Deferred.all @@ List.map ~f:(fun shard -> - Shard.set_status ~status:(f shard) shard + Deferred.all @@ List.map ~f:(fun t -> + Shard.set_status ~status:(f t.shard) t.shard ) sharder.shards let request_guild_members ~guild ?query ?limit sharder = - Deferred.all @@ List.map ~f:(fun shard -> - Shard.request_guild_members ~guild ?query ?limit shard + Deferred.all @@ List.map ~f:(fun t -> + Shard.request_guild_members ~guild ?query ?limit t.shard ) sharder.shards |