From 7c9b809078b5cd53e3d54c0004c683da2ec679af Mon Sep 17 00:00:00 2001 From: Adelyn Breedlove Date: Mon, 11 Feb 2019 17:23:59 +0000 Subject: Add a cache --- lib/gateway/dispatch.ml | 36 +++++ lib/gateway/dispatch.mli | 120 +++++++++++++++ lib/gateway/event.ml | 176 +++++++++++++++++++++ lib/gateway/event.mli | 49 ++++++ lib/gateway/opcode.ml | 54 +++++++ lib/gateway/opcode.mli | 29 ++++ lib/gateway/sharder.ml | 393 +++++++++++++++++++++++++++++++++++++++++++++++ lib/gateway/sharder.mli | 102 ++++++++++++ 8 files changed, 959 insertions(+) create mode 100644 lib/gateway/dispatch.ml create mode 100644 lib/gateway/dispatch.mli create mode 100644 lib/gateway/event.ml create mode 100644 lib/gateway/event.mli create mode 100644 lib/gateway/opcode.ml create mode 100644 lib/gateway/opcode.mli create mode 100644 lib/gateway/sharder.ml create mode 100644 lib/gateway/sharder.mli (limited to 'lib/gateway') diff --git a/lib/gateway/dispatch.ml b/lib/gateway/dispatch.ml new file mode 100644 index 0000000..b4fc9d2 --- /dev/null +++ b/lib/gateway/dispatch.ml @@ -0,0 +1,36 @@ +open Event_models + +let ready = ref (fun (_:Ready.t) -> ()) +let resumed = ref (fun (_:Resumed.t) -> ()) +let channel_create = ref (fun (_:ChannelCreate.t) -> ()) +let channel_update = ref (fun (_:ChannelUpdate.t) -> ()) +let channel_delete = ref (fun (_:ChannelDelete.t) -> ()) +let channel_pins_update = ref (fun (_:ChannelPinsUpdate.t) -> ()) +let guild_create = ref (fun (_:GuildCreate.t) -> ()) +let guild_update = ref (fun (_:GuildUpdate.t) -> ()) +let guild_delete = ref (fun (_:GuildDelete.t) -> ()) +let member_ban = ref (fun (_:GuildBanAdd.t) -> ()) +let member_unban = ref (fun (_:GuildBanRemove.t) -> ()) +let guild_emojis_update = ref (fun (_:GuildEmojisUpdate.t) -> ()) +(* let integrations_update = ref (fun (_:Yojson.Safe.t) -> ()) *) +let member_join = ref (fun (_:GuildMemberAdd.t) -> ()) +let member_leave = ref (fun (_:GuildMemberRemove.t) -> ()) +let member_update = ref (fun (_:GuildMemberUpdate.t) -> ()) +let members_chunk = ref (fun (_:GuildMembersChunk.t) -> ()) +let role_create = ref (fun (_:GuildRoleCreate.t) -> ()) +let role_update = ref (fun (_:GuildRoleUpdate.t) -> ()) +let role_delete = ref (fun (_:GuildRoleDelete.t) -> ()) +let message_create = ref (fun (_:MessageCreate.t) -> ()) +let message_update = ref (fun (_:MessageUpdate.t) -> ()) +let message_delete = ref (fun (_:MessageDelete.t) -> ()) +let message_delete_bulk = ref (fun (_:MessageDeleteBulk.t) -> ()) +let reaction_add = ref (fun (_:ReactionAdd.t) -> ()) +let reaction_remove = ref (fun (_:ReactionRemove.t) -> ()) +let reaction_remove_all = ref (fun (_:ReactionRemoveAll.t) -> ()) +let presence_update = ref (fun (_:PresenceUpdate.t) -> ()) +let typing_start = ref (fun (_:TypingStart.t) -> ()) +let user_update = ref (fun (_:UserUpdate.t) -> ()) +(* let voice_state_update = ref (fun (_:Yojson.Safe.t) -> ()) *) +(* let voice_server_update = ref (fun (_:Yojson.Safe.t) -> ()) *) +let webhook_update = ref (fun (_:WebhookUpdate.t) -> ()) +let unknown = ref (fun (_:Unknown.t) -> ()) \ No newline at end of file diff --git a/lib/gateway/dispatch.mli b/lib/gateway/dispatch.mli new file mode 100644 index 0000000..18b9261 --- /dev/null +++ b/lib/gateway/dispatch.mli @@ -0,0 +1,120 @@ +(** Used to store dispatch callbacks. Each event can only have one callback registered at a time. + These should be accessed through their re-export in {!Client}. + {3 Examples} + [Client.ready := (fun _ -> print_endline "Shard is Ready!")] + + [Client.guild_create := (fun guild -> print_endline guild.name)] + + {[ + open Core + open Disml + + let check_command (msg : Message.t) = + if String.is_prefix ~prefix:"!ping" msg.content then + Message.reply msg "Pong!" >>> ignore + + Client.message_create := check_command + ]} +*) + +open Event_models + +(** Dispatched when each shard receives READY from discord after identifying on the gateway. Other event dispatch is received after this. *) +val ready : (Ready.t -> unit) ref + +(** Dispatched when successfully reconnecting to the gateway. *) +val resumed : (Resumed.t -> unit) ref + +(** Dispatched when a channel is created which is visible to the bot. *) +val channel_create : (ChannelCreate.t -> unit) ref + +(** Dispatched when a channel visible to the bot is changed. *) +val channel_update : (ChannelUpdate.t -> unit) ref + +(** Dispatched when a channel visible to the bot is deleted. *) +val channel_delete : (ChannelDelete.t -> unit) ref + +(** Dispatched when messages are pinned or unpinned from a a channel. *) +val channel_pins_update : (ChannelPinsUpdate.t -> unit) ref + +(** Dispatched when the bot joins a guild, and during startup. *) +val guild_create : (GuildCreate.t -> unit) ref + +(** Dispatched when a guild the bot is in is edited. *) +val guild_update : (GuildUpdate.t -> unit) ref + +(** Dispatched when the bot is removed from a guild. *) +val guild_delete : (GuildDelete.t -> unit) ref + +(** Dispatched when a member is banned. *) +val member_ban : (GuildBanAdd.t -> unit) ref + +(** Dispatched when a member is unbanned. *) +val member_unban : (GuildBanRemove.t -> unit) ref + +(** Dispatched when emojis are added or removed from a guild. *) +val guild_emojis_update : (GuildEmojisUpdate.t -> unit) ref + +(** Dispatched when a guild's integrations are updated. *) +(* val integrations_update : (Yojson.Safe.t -> unit) ref *) + +(** Dispatched when a member joins a guild. *) +val member_join : (GuildMemberAdd.t -> unit) ref + +(** Dispatched when a member leaves a guild. Is Dispatched alongside {!Client.member_ban} when a user is banned. *) +val member_leave : (GuildMemberRemove.t -> unit) ref + +(** Dispatched when a member object is updated. *) +val member_update : (GuildMemberUpdate.t -> unit) ref + +(** Dispatched when requesting guild members through {!Client.request_guild_members} *) +val members_chunk : (GuildMembersChunk.t -> unit) ref + +(** Dispatched when a role is created. *) +val role_create : (GuildRoleCreate.t -> unit) ref + +(** Dispatched when a role is edited. *) +val role_update : (GuildRoleUpdate.t -> unit) ref + +(** Dispatched when a role is deleted. *) +val role_delete : (GuildRoleDelete.t -> unit) ref + +(** Dispatched when a message is sent. *) +val message_create : (MessageCreate.t -> unit) ref + +(** Dispatched when a message is edited. This does not necessarily mean the content changed. *) +val message_update : (MessageUpdate.t -> unit) ref + +(** Dispatched when a message is deleted. *) +val message_delete : (MessageDelete.t -> unit) ref + +(** Dispatched when messages are bulk deleted. *) +val message_delete_bulk : (MessageDeleteBulk.t -> unit) ref + +(** Dispatched when a rection is added to a message. *) +val reaction_add : (ReactionAdd.t -> unit) ref + +(** Dispatched when a reaction is removed from a message. *) +val reaction_remove : (ReactionRemove.t -> unit) ref + +(** Dispatched when all reactions are cleared from a message. *) +val reaction_remove_all : (ReactionRemoveAll.t -> unit) ref + +(** Dispatched when a user updates their presence. *) +val presence_update : (PresenceUpdate.t -> unit) ref + +(** Dispatched when a typing indicator is displayed. *) +val typing_start : (TypingStart.t -> unit) ref + +(** Dispatched when the current user is updated. You most likely want {!Client.member_update} or {!Client.presence_update} instead. *) +val user_update : (UserUpdate.t -> unit) ref + +(** Dispatched when a webhook is updated. *) +val webhook_update : (WebhookUpdate.t -> unit) ref + +(** Dispatched as a fallback for unknown events. *) +val unknown : (Unknown.t -> unit) ref + +(**/**) +(* val voice_state_update : (Yojson.Safe.t -> unit) ref *) +(* val voice_server_update : (Yojson.Safe.t -> unit) ref *) \ No newline at end of file diff --git a/lib/gateway/event.ml b/lib/gateway/event.ml new file mode 100644 index 0000000..88dd50d --- /dev/null +++ b/lib/gateway/event.ml @@ -0,0 +1,176 @@ +open Async +open Core +open Event_models + +type t = +| READY of Ready.t +| RESUMED of Resumed.t +| CHANNEL_CREATE of ChannelCreate.t +| CHANNEL_UPDATE of ChannelUpdate.t +| CHANNEL_DELETE of ChannelDelete.t +| CHANNEL_PINS_UPDATE of ChannelPinsUpdate.t +| GUILD_CREATE of GuildCreate.t +| GUILD_UPDATE of GuildUpdate.t +| GUILD_DELETE of GuildDelete.t +| GUILD_BAN_ADD of GuildBanAdd.t +| GUILD_BAN_REMOVE of GuildBanRemove.t +| GUILD_EMOJIS_UPDATE of GuildEmojisUpdate.t +(* | GUILD_INTEGRATIONS_UPDATE of Yojson.Safe.t *) +| GUILD_MEMBER_ADD of GuildMemberAdd.t +| GUILD_MEMBER_REMOVE of GuildMemberRemove.t +| GUILD_MEMBER_UPDATE of GuildMemberUpdate.t +| GUILD_MEMBERS_CHUNK of GuildMembersChunk.t +| GUILD_ROLE_CREATE of GuildRoleCreate.t +| GUILD_ROLE_UPDATE of GuildRoleUpdate.t +| GUILD_ROLE_DELETE of GuildRoleDelete.t +| MESSAGE_CREATE of MessageCreate.t +| MESSAGE_UPDATE of MessageUpdate.t +| MESSAGE_DELETE of MessageDelete.t +| MESSAGE_DELETE_BULK of MessageDeleteBulk.t +| REACTION_ADD of ReactionAdd.t +| REACTION_REMOVE of ReactionRemove.t +| REACTION_REMOVE_ALL of ReactionRemoveAll.t +| PRESENCE_UPDATE of PresenceUpdate.t +| TYPING_START of TypingStart.t +| USER_UPDATE of UserUpdate.t +(* | VOICE_STATE_UPDATE of Yojson.Safe.t *) +(* | VOICE_SERVER_UPDATE of Yojson.Safe.t *) +| WEBHOOK_UPDATE of WebhookUpdate.t +| UNKNOWN of Unknown.t + +let event_of_yojson ~contents = function + | "READY" -> READY Ready.(deserialize contents) + | "RESUMED" -> RESUMED Resumed.(deserialize contents) + | "CHANNEL_CREATE" -> CHANNEL_CREATE ChannelCreate.(deserialize contents) + | "CHANNEL_UPDATE" -> CHANNEL_UPDATE ChannelUpdate.(deserialize contents) + | "CHANNEL_DELETE" -> CHANNEL_DELETE ChannelDelete.(deserialize contents) + | "CHANNEL_PINS_UPDATE" -> CHANNEL_PINS_UPDATE ChannelPinsUpdate.(deserialize contents) + | "GUILD_CREATE" -> GUILD_CREATE GuildCreate.(deserialize contents) + | "GUILD_UPDATE" -> GUILD_UPDATE GuildUpdate.(deserialize contents) + | "GUILD_DELETE" -> GUILD_DELETE GuildDelete.(deserialize contents) + | "GUILD_BAN_ADD" -> GUILD_BAN_ADD GuildBanAdd.(deserialize contents) + | "GUILD_BAN_REMOVE" -> GUILD_BAN_REMOVE GuildBanRemove.(deserialize contents) + | "GUILD_EMOJIS_UPDATE" -> GUILD_EMOJIS_UPDATE GuildEmojisUpdate.(deserialize contents) + (* | "GUILD_INTEGRATIONS_UPDATE" -> GUILD_INTEGRATIONS_UPDATE contents *) + | "GUILD_MEMBER_ADD" -> GUILD_MEMBER_ADD GuildMemberAdd.(deserialize contents) + | "GUILD_MEMBER_REMOVE" -> GUILD_MEMBER_REMOVE GuildMemberRemove.(deserialize contents) + | "GUILD_MEMBER_UPDATE" -> GUILD_MEMBER_UPDATE GuildMemberUpdate.(deserialize contents) + | "GUILD_MEMBERS_CHUNK" -> GUILD_MEMBERS_CHUNK GuildMembersChunk.(deserialize contents) + | "GUILD_ROLE_CREATE" -> GUILD_ROLE_CREATE GuildRoleCreate.(deserialize contents) + | "GUILD_ROLE_UPDATE" -> GUILD_ROLE_UPDATE GuildRoleUpdate.(deserialize contents) + | "GUILD_ROLE_DELETE" -> GUILD_ROLE_DELETE GuildRoleDelete.(deserialize contents) + | "MESSAGE_CREATE" -> MESSAGE_CREATE MessageCreate.(deserialize contents) + | "MESSAGE_UPDATE" -> MESSAGE_UPDATE MessageUpdate.(deserialize contents) + | "MESSAGE_DELETE" -> MESSAGE_DELETE MessageDelete.(deserialize contents) + | "MESSAGE_DELETE_BULK" -> MESSAGE_DELETE_BULK MessageDeleteBulk.(deserialize contents) + | "MESSAGE_REACTION_ADD" -> REACTION_ADD ReactionAdd.(deserialize contents) + | "MESSAGE_REACTION_REMOVE" -> REACTION_REMOVE ReactionRemove.(deserialize contents) + | "MESSAGE_REACTION_REMOVE_ALL" -> REACTION_REMOVE_ALL ReactionRemoveAll.(deserialize contents) + | "PRESENCE_UPDATE" -> PRESENCE_UPDATE PresenceUpdate.(deserialize contents) + | "TYPING_START" -> TYPING_START TypingStart.(deserialize contents) + | "USER_UPDATE" -> USER_UPDATE UserUpdate.(deserialize contents) + (* | "VOICE_STATE_UPDATE" -> VOICE_STATE_UPDATE contents *) + (* | "VOICE_SERVER_UPDATE" -> VOICE_SERVER_UPDATE contents *) + | "WEBHOOK_UPDATE" -> WEBHOOK_UPDATE WebhookUpdate.(deserialize contents) + | s -> UNKNOWN Unknown.(deserialize s contents) + +let dispatch ev = + match ev with + | READY d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> Ready.update_cache cache d); + !Dispatch.ready d + | RESUMED d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> Resumed.update_cache cache d); + !Dispatch.resumed d + | CHANNEL_CREATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> ChannelCreate.update_cache cache d); + !Dispatch.channel_create d + | CHANNEL_UPDATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> ChannelUpdate.update_cache cache d); + !Dispatch.channel_update d + | CHANNEL_DELETE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> ChannelDelete.update_cache cache d); + !Dispatch.channel_delete d + | CHANNEL_PINS_UPDATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> ChannelPinsUpdate.update_cache cache d); + !Dispatch.channel_pins_update d + | GUILD_CREATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> GuildCreate.update_cache cache d); + !Dispatch.guild_create d + | GUILD_UPDATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> GuildUpdate.update_cache cache d); + !Dispatch.guild_update d + | GUILD_DELETE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> GuildDelete.update_cache cache d); + !Dispatch.guild_delete d + | GUILD_BAN_ADD d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> GuildBanAdd.update_cache cache d); + !Dispatch.member_ban d + | GUILD_BAN_REMOVE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> GuildBanRemove.update_cache cache d); + !Dispatch.member_unban d + | GUILD_EMOJIS_UPDATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> GuildEmojisUpdate.update_cache cache d); + !Dispatch.guild_emojis_update d + (* | GUILD_INTEGRATIONS_UPDATE d -> !Dispatch.integrations_update d *) + | GUILD_MEMBER_ADD d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> GuildMemberAdd.update_cache cache d); + !Dispatch.member_join d + | GUILD_MEMBER_REMOVE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> GuildMemberRemove.update_cache cache d); + !Dispatch.member_leave d + | GUILD_MEMBER_UPDATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> GuildMemberUpdate.update_cache cache d); + !Dispatch.member_update d + | GUILD_MEMBERS_CHUNK d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> GuildMembersChunk.update_cache cache d); + !Dispatch.members_chunk d + | GUILD_ROLE_CREATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> GuildRoleCreate.update_cache cache d); + !Dispatch.role_create d + | GUILD_ROLE_UPDATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> GuildRoleUpdate.update_cache cache d); + !Dispatch.role_update d + | GUILD_ROLE_DELETE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> GuildRoleDelete.update_cache cache d); + !Dispatch.role_delete d + | MESSAGE_CREATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> MessageCreate.update_cache cache d); + !Dispatch.message_create d + | MESSAGE_UPDATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> MessageUpdate.update_cache cache d); + !Dispatch.message_update d + | MESSAGE_DELETE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> MessageDelete.update_cache cache d); + !Dispatch.message_delete d + | MESSAGE_DELETE_BULK d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> MessageDeleteBulk.update_cache cache d); + !Dispatch.message_delete_bulk d + | REACTION_ADD d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> ReactionAdd.update_cache cache d); + !Dispatch.reaction_add d + | REACTION_REMOVE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> ReactionRemove.update_cache cache d); + !Dispatch.reaction_remove d + | REACTION_REMOVE_ALL d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> ReactionRemoveAll.update_cache cache d); + !Dispatch.reaction_remove_all d + | PRESENCE_UPDATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> PresenceUpdate.update_cache cache d); + !Dispatch.presence_update d + | TYPING_START d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> TypingStart.update_cache cache d); + !Dispatch.typing_start d + | USER_UPDATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> UserUpdate.update_cache cache d); + !Dispatch.user_update d + (* | VOICE_STATE_UPDATE d -> !Dispatch.voice_state_update d *) + (* | VOICE_SERVER_UPDATE d -> !Dispatch.voice_server_update d *) + | WEBHOOK_UPDATE d -> + Mvar.update_exn Cache.cache ~f:(fun cache -> WebhookUpdate.update_cache cache d); + !Dispatch.webhook_update d + | UNKNOWN d -> !Dispatch.unknown d + +let handle_event ~ev contents = + event_of_yojson ~contents ev + |> dispatch \ No newline at end of file diff --git a/lib/gateway/event.mli b/lib/gateway/event.mli new file mode 100644 index 0000000..4db3c84 --- /dev/null +++ b/lib/gateway/event.mli @@ -0,0 +1,49 @@ +(** Barebones of event dispatching. Most users will have no reason to look here. *) + +open Event_models + +(** Event dispatch type wrapper. Used internally. *) +type t = +| READY of Ready.t +| RESUMED of Resumed.t +| CHANNEL_CREATE of ChannelCreate.t +| CHANNEL_UPDATE of ChannelUpdate.t +| CHANNEL_DELETE of ChannelDelete.t +| CHANNEL_PINS_UPDATE of ChannelPinsUpdate.t +| GUILD_CREATE of GuildCreate.t +| GUILD_UPDATE of GuildUpdate.t +| GUILD_DELETE of GuildDelete.t +| GUILD_BAN_ADD of GuildBanAdd.t +| GUILD_BAN_REMOVE of GuildBanRemove.t +| GUILD_EMOJIS_UPDATE of GuildEmojisUpdate.t +(* | GUILD_INTEGRATIONS_UPDATE of Yojson.Safe.t *) +| GUILD_MEMBER_ADD of GuildMemberAdd.t +| GUILD_MEMBER_REMOVE of GuildMemberRemove.t +| GUILD_MEMBER_UPDATE of GuildMemberUpdate.t +| GUILD_MEMBERS_CHUNK of GuildMembersChunk.t +| GUILD_ROLE_CREATE of GuildRoleCreate.t +| GUILD_ROLE_UPDATE of GuildRoleUpdate.t +| GUILD_ROLE_DELETE of GuildRoleDelete.t +| MESSAGE_CREATE of MessageCreate.t +| MESSAGE_UPDATE of MessageUpdate.t +| MESSAGE_DELETE of MessageDelete.t +| MESSAGE_DELETE_BULK of MessageDeleteBulk.t +| REACTION_ADD of ReactionAdd.t +| REACTION_REMOVE of ReactionRemove.t +| REACTION_REMOVE_ALL of ReactionRemoveAll.t +| PRESENCE_UPDATE of PresenceUpdate.t +| TYPING_START of TypingStart.t +| USER_UPDATE of UserUpdate.t +(* | VOICE_STATE_UPDATE of Yojson.Safe.t *) +(* | VOICE_SERVER_UPDATE of Yojson.Safe.t *) +| WEBHOOK_UPDATE of WebhookUpdate.t +| UNKNOWN of Unknown.t + +(** Used to convert an event string and payload into a t wrapper type. *) +val event_of_yojson : contents:Yojson.Safe.t -> string -> t + +(** Sends the event to the registered handler. *) +val dispatch : t -> unit + +(** Wrapper to other functions. This is called from the shards. *) +val handle_event : ev:string -> Yojson.Safe.t -> unit \ No newline at end of file diff --git a/lib/gateway/opcode.ml b/lib/gateway/opcode.ml new file mode 100644 index 0000000..32ab5b4 --- /dev/null +++ b/lib/gateway/opcode.ml @@ -0,0 +1,54 @@ +type t = + | DISPATCH + | HEARTBEAT + | IDENTIFY + | STATUS_UPDATE + | VOICE_STATE_UPDATE + | RESUME + | RECONNECT + | REQUEST_GUILD_MEMBERS + | INVALID_SESSION + | HELLO + | HEARTBEAT_ACK + +exception Invalid_Opcode of int + +let to_int = function + | DISPATCH -> 0 + | HEARTBEAT -> 1 + | IDENTIFY -> 2 + | STATUS_UPDATE -> 3 + | VOICE_STATE_UPDATE -> 4 + | RESUME -> 6 + | RECONNECT -> 7 + | REQUEST_GUILD_MEMBERS -> 8 + | INVALID_SESSION -> 9 + | HELLO -> 10 + | HEARTBEAT_ACK -> 11 + +let from_int = function + | 0 -> DISPATCH + | 1 -> HEARTBEAT + | 2 -> IDENTIFY + | 3 -> STATUS_UPDATE + | 4 -> VOICE_STATE_UPDATE + | 6 -> RESUME + | 7 -> RECONNECT + | 8 -> REQUEST_GUILD_MEMBERS + | 9 -> INVALID_SESSION + | 10 -> HELLO + | 11 -> HEARTBEAT_ACK + | op -> raise (Invalid_Opcode op) + +let to_string = function + | DISPATCH -> "DISPATCH" + | HEARTBEAT -> "HEARTBEAT" + | IDENTIFY -> "IDENTIFY" + | STATUS_UPDATE -> "STATUS_UPDATE" + | VOICE_STATE_UPDATE -> "VOICE_STATE_UPDATE" + | RESUME -> "RESUME" + | RECONNECT -> "RECONNECT" + | REQUEST_GUILD_MEMBERS -> "REQUEST_GUILD_MEMBER" + | INVALID_SESSION -> "INVALID_SESSION" + | HELLO -> "HELLO" + | HEARTBEAT_ACK -> "HEARTBEAT_ACK" \ No newline at end of file diff --git a/lib/gateway/opcode.mli b/lib/gateway/opcode.mli new file mode 100644 index 0000000..9fa5b96 --- /dev/null +++ b/lib/gateway/opcode.mli @@ -0,0 +1,29 @@ +(** Internal Opcode abstractions. *) + +(** Type of known opcodes. *) +type t = +| DISPATCH +| HEARTBEAT +| IDENTIFY +| STATUS_UPDATE +| VOICE_STATE_UPDATE +| RESUME +| RECONNECT +| REQUEST_GUILD_MEMBERS +| INVALID_SESSION +| HELLO +| HEARTBEAT_ACK + +(** Raised when receiving an invalid opcode. This should never occur. *) +exception Invalid_Opcode of int + +(** Converts an opcode to its integer form for outgoing frames. *) +val to_int : t -> int + +(** Converts an integer to an opcode for incoming frames. + Raise {!Invalid_Opcode} Raised when an unkown opcode is received. +*) +val from_int : int -> t + +(** Converts and opcode to a human-readable string. Used for logging purposes. *) +val to_string : t -> string \ No newline at end of file diff --git a/lib/gateway/sharder.ml b/lib/gateway/sharder.ml new file mode 100644 index 0000000..9fcb10d --- /dev/null +++ b/lib/gateway/sharder.ml @@ -0,0 +1,393 @@ +open Async +open Core +open Decompress +open Websocket_async + +exception Invalid_Payload +exception Failure_to_Establish_Heartbeat +exception Inflate_error of Zlib_inflate.error + +let window = Window.create ~witness:B.bytes + +let decompress src = + let in_buf = Bytes.create 0xFFFF in + let out_buf = Bytes.create 0xFFFF in + let window = Window.reset window in + let pos = ref 0 in + let src_len = String.length src in + let res = Buffer.create (src_len) in + Zlib_inflate.bytes in_buf out_buf + (fun dst -> + let len = min 0xFFFF (src_len - !pos) in + Caml.Bytes.blit_string src !pos dst 0 len; + pos := !pos + len; + len) + (fun obuf len -> + Buffer.add_subbytes res obuf 0 len; 0xFFFF) + (Zlib_inflate.default ~witness:B.bytes window) + |> function + | Ok _ -> Buffer.contents res + | Error exn -> raise (Inflate_error exn) + +module Shard = struct + type shard = + { compress: bool + ; id: int * int + ; hb_interval: Time.Span.t Ivar.t + ; hb_stopper: unit Ivar.t + ; large_threshold: int + ; pipe: Frame.t Pipe.Reader.t * Frame.t Pipe.Writer.t + ; ready: unit Ivar.t + ; seq: int + ; session: string option + ; url: string + ; _internal: Reader.t * Writer.t + } + + type 'a t = + { mutable state: 'a + ; mutable stopped: bool + ; mutable can_resume: bool + } + + let identify_lock = Mvar.create () + let _ = Mvar.set identify_lock () + + let parse ~compress (frame:[`Ok of Frame.t | `Eof]) = + match frame with + | `Ok s -> begin + let open Frame.Opcode in + match s.opcode with + | Text -> `Ok (Yojson.Safe.from_string s.content) + | Binary -> + if compress then `Ok (decompress s.content |> Yojson.Safe.from_string) + else `Error "Failed to decompress" + | Close -> `Close s.extension + | op -> + let op = Frame.Opcode.to_string op in + `Error ("Unexpected opcode " ^ op) + end + | `Eof -> `Eof + + let push_frame ?payload ~ev shard = + let content = match payload with + | None -> "" + | Some p -> + Yojson.Safe.to_string @@ `Assoc [ + "op", `Int (Opcode.to_int ev); + "d", p; + ] + in + let (_, write) = shard.pipe in + Pipe.write_if_open write @@ Frame.create ~content () + >>| fun () -> + shard + + let heartbeat shard = + match shard.seq with + | 0 -> return shard + | i -> + Logs.debug (fun m -> m "Heartbeating - Shard: [%d, %d] - Seq: %d" (fst shard.id) (snd shard.id) (shard.seq)); + push_frame ~payload:(`Int i) ~ev:HEARTBEAT shard + + let dispatch ~payload shard = + let module J = Yojson.Safe.Util in + let seq = J.(member "s" payload |> to_int) in + let t = J.(member "t" payload |> to_string) in + let data = J.member "d" payload in + let session = if t = "READY" then begin + Ivar.fill_if_empty shard.ready (); + Clock.after (Core.Time.Span.create ~sec:5 ()) + >>> (fun _ -> Mvar.put identify_lock () >>> ignore); + J.(member "session_id" data |> to_string_option) + end else None in + Event.handle_event ~ev:t data; + return + { shard with seq = seq + ; session = session + } + + let set_status ~(status:Yojson.Safe.t) shard = + let payload = match status with + | `Assoc ["name", `String name; "type", `Int t] + | `Assoc ["type", `Int t; "name", `String name] -> + `Assoc [ + "status", `String "online"; + "afk", `Bool false; + "since", `Null; + "game", `Assoc [ + "name", `String name; + "type", `Int t; + ] + ] + | `String name -> + `Assoc [ + "status", `String "online"; + "afk", `Bool false; + "since", `Null; + "game", `Assoc [ + "name", `String name; + "type", `Int 0 + ] + ] + | _ -> raise Invalid_Payload + in + Ivar.read shard.ready >>= fun _ -> + push_frame ~payload ~ev:STATUS_UPDATE shard + + let request_guild_members ?(query="") ?(limit=0) ~guild shard = + let payload = `Assoc [ + "guild_id", `String (Int.to_string guild); + "query", `String query; + "limit", `Int limit; + ] in + Ivar.read shard.ready >>= fun _ -> + push_frame ~payload ~ev:REQUEST_GUILD_MEMBERS shard + + let initialize ?data shard = + let module J = Yojson.Safe.Util in + let _ = match data with + | Some data -> Ivar.fill_if_empty shard.hb_interval (Time.Span.create ~ms:J.(member "heartbeat_interval" data |> to_int) ()) + | None -> raise Failure_to_Establish_Heartbeat + in + let shards = [`Int (fst shard.id); `Int (snd shard.id)] in + match shard.session with + | None -> begin + Mvar.take identify_lock >>= fun () -> + Logs.debug (fun m -> m "Identifying shard [%d, %d]" (fst shard.id) (snd shard.id)); + let payload = `Assoc + [ "token", `String !Client_options.token + ; "properties", `Assoc + [ "$os", `String Sys.os_type + ; "$device", `String "dis.ml" + ; "$browser", `String "dis.ml" + ] + ; "compress", `Bool shard.compress + ; "large_threshold", `Int shard.large_threshold + ; "shard", `List shards + ] + in + push_frame ~payload ~ev:IDENTIFY shard + >>| fun s -> s + end + | Some s -> + let payload = `Assoc + [ "token", `String !Client_options.token + ; "session_id", `String s + ; "seq", `Int shard.seq + ] + in + push_frame ~payload ~ev:RESUME shard + + let handle_frame ~f shard = + let module J = Yojson.Safe.Util in + let op = J.(member "op" f |> to_int) |> Opcode.from_int in + match op with + | DISPATCH -> dispatch ~payload:f shard + | HEARTBEAT -> heartbeat shard + | INVALID_SESSION -> begin + Logs.err (fun m -> m "Invalid Session on Shard [%d, %d]: %s" (fst shard.id) (snd shard.id) (Yojson.Safe.pretty_to_string f)); + if J.(member "d" f |> to_bool) then + initialize shard + else begin + initialize { shard with session = None; } + end + end + | RECONNECT -> initialize shard + | HELLO -> initialize ~data:(J.member "d" f) shard + | HEARTBEAT_ACK -> return shard + | opcode -> + Logs.warn (fun m -> m "Invalid Opcode: %s" (Opcode.to_string opcode)); + return shard + + let rec make_client + ~initialized + ~extra_headers + ~app_to_ws + ~ws_to_app + ~net_to_ws + ~ws_to_net + ?(ms=500) + uri = + client + ~initialized + ~extra_headers + ~app_to_ws + ~ws_to_app + ~net_to_ws + ~ws_to_net + uri + >>> fun res -> + match res with + | Ok () -> () + | Error _ -> + let backoff = Time.Span.create ~ms () in + Clock.after backoff >>> (fun () -> + make_client + ~initialized + ~extra_headers + ~app_to_ws + ~ws_to_app + ~net_to_ws + ~ws_to_net + ~ms:(min 60_000 (ms * 2)) + uri) + + + let create ~url ~shards ?(compress=true) ?(large_threshold=100) () = + let open Core in + let uri = (url ^ "?v=6&encoding=json") |> Uri.of_string in + let extra_headers = Http.Base.process_request_headers () in + let host = Option.value_exn ~message:"no host in uri" Uri.(host uri) in + let port = + match Uri.port uri, Uri_services.tcp_port_of_uri uri with + | Some p, _ -> p + | None, Some p -> p + | _ -> 443 in + let scheme = Option.value_exn ~message:"no scheme in uri" Uri.(scheme uri) in + let tcp_fun (net_to_ws, ws_to_net) = + let (app_to_ws, write) = Pipe.create () in + let (read, ws_to_app) = Pipe.create () in + let initialized = Ivar.create () in + make_client + ~initialized + ~extra_headers + ~app_to_ws + ~ws_to_app + ~net_to_ws + ~ws_to_net + uri; + Ivar.read initialized >>| fun () -> + { pipe = (read, write) + ; ready = Ivar.create () + ; hb_interval = Ivar.create () + ; hb_stopper = Ivar.create () + ; seq = 0 + ; id = shards + ; session = None + ; url + ; large_threshold + ; compress + ; _internal = (net_to_ws, ws_to_net) + } + in + match Unix.getaddrinfo host (string_of_int port) [] with + | [] -> failwithf "DNS resolution failed for %s" host () + | { ai_addr; _ } :: _ -> + let addr = + match scheme, ai_addr with + | _, ADDR_UNIX path -> `Unix_domain_socket path + | "https", ADDR_INET (h, p) + | "wss", ADDR_INET (h, p) -> + let h = Ipaddr_unix.of_inet_addr h in + `OpenSSL (h, p, Conduit_async.V2.Ssl.Config.create ()) + | _, ADDR_INET (h, p) -> + let h = Ipaddr_unix.of_inet_addr h in + `TCP (h, p) + in + Conduit_async.V2.connect addr >>= tcp_fun + + let shutdown ?(clean=false) ?(restart=true) t = + let _ = clean in + t.can_resume <- restart; + t.stopped <- true; + Logs.debug (fun m -> m "Performing shutdown. Shard [%d, %d]" (fst t.state.id) (snd t.state.id)); + Pipe.write_if_open (snd t.state.pipe) (Frame.close 1001) + >>= fun () -> + Ivar.fill_if_empty t.state.hb_stopper (); + Pipe.close_read (fst t.state.pipe); + Writer.close (snd t.state._internal) +end + +type t = { shards: (Shard.shard Shard.t) list } + +let start ?count ?compress ?large_threshold () = + let module J = Yojson.Safe.Util in + Http.get_gateway_bot () >>= fun data -> + let data = match data with + | Ok d -> d + | Error e -> Error.raise e + in + let url = J.(member "url" data |> to_string) in + let count = match count with + | Some c -> c + | None -> J.(member "shards" data |> to_int) + in + let shard_list = (0, count) in + Logs.info (fun m -> m "Connecting to %s" url); + let rec ev_loop (t:Shard.shard Shard.t) = + let step (t:Shard.shard Shard.t) = + Pipe.read (fst t.state.pipe) >>= fun frame -> + begin match Shard.parse ~compress:t.state.compress frame with + | `Ok f -> + Shard.handle_frame ~f t.state >>| fun s -> + t.state <- s + | `Close c -> + Logs.warn (fun m -> m "Close frame received. Code: %d" c); + Shard.shutdown t + | `Error e -> + Logs.warn (fun m -> m "Websocket soft error: %s" e); + return () + | `Eof -> + Logs.warn (fun m -> m "Websocket closed unexpectedly"); + Shard.shutdown t + end >>| fun () -> t + in + if t.stopped then return () + else step t >>= ev_loop + in + let rec gen_shards l a = + match l with + | (id, total) when id >= total -> return a + | (id, total) -> + let wrap ?(reuse:Shard.shard Shard.t option) state = match reuse with + | Some t -> + t.state <- state; + t.stopped <- false; + return t + | None -> + return Shard.{ state + ; stopped = false + ; can_resume = true + } + in + let create () = + Shard.create ~url ~shards:(id, total) ?compress ?large_threshold () + in + let rec bind (t:Shard.shard Shard.t) = + let _ = Ivar.read t.state.hb_interval >>> fun hb -> + Clock.every' + ~stop:(Ivar.read t.state.hb_stopper) + ~continue_on_error:true + hb (fun () -> Shard.heartbeat t.state >>| ignore) in + ev_loop t >>> (fun () -> Logs.debug (fun m -> m "Event loop stopped.")); + Pipe.closed (fst t.state.pipe) >>> (fun () -> if t.can_resume then + create () >>= wrap ~reuse:t >>= bind >>> ignore); + return t + in + create () >>= wrap >>= bind >>= fun t -> + gen_shards (id+1, total) (t :: a) + in + gen_shards shard_list [] + >>| fun shards -> + { shards } + +let set_status ~status sharder = + Deferred.all @@ List.map ~f:(fun t -> + Shard.set_status ~status t.state + ) sharder.shards + +let set_status_with ~f sharder = + Deferred.all @@ List.map ~f:(fun t -> + Shard.set_status ~status:(f t.state) t.state + ) sharder.shards + +let request_guild_members ?query ?limit ~guild sharder = + Deferred.all @@ List.map ~f:(fun t -> + Shard.request_guild_members ~guild ?query ?limit t.state + ) sharder.shards + +let shutdown_all ?restart sharder = + Deferred.all @@ List.map ~f:(fun t -> + Shard.shutdown ~clean:true ?restart t + ) sharder.shards \ No newline at end of file diff --git a/lib/gateway/sharder.mli b/lib/gateway/sharder.mli new file mode 100644 index 0000000..a5f18e6 --- /dev/null +++ b/lib/gateway/sharder.mli @@ -0,0 +1,102 @@ +(** Internal sharding manager. Most of this is accessed through {!Client}. *) + +open Core +open Async +open Websocket_async + +exception Invalid_Payload +exception Failure_to_Establish_Heartbeat + +type t + +(** Start the Sharder. This is called by {!Client.start}. *) +val start : + ?count:int -> + ?compress:bool -> + ?large_threshold:int -> + unit -> + t Deferred.t + +(** Module representing a single shard. *) +module Shard : sig + (** Representation of the state of a shard. *) + type shard = { + compress: bool; (** Whether to compress payloads. *) + id: int * int; (** A tuple as expected by Discord. First element is the current shard index, second element is the total shard count. *) + hb_interval: Time.Span.t Ivar.t; (** Time span between heartbeats, wrapped in an Ivar. *) + hb_stopper: unit Ivar.t; (** Stops the heartbeat sequencer when filled. *) + large_threshold: int; (** Minimum number of members needed for a guild to be considered large. *) + pipe: Frame.t Pipe.Reader.t * Frame.t Pipe.Writer.t; (** Raw frame IO pipe used for websocket communications. *) + ready: unit Ivar.t; (** A simple Ivar indicating if the shard has received READY. *) + seq: int; (** Current sequence number *) + session: string option; (** Session id, if one exists. *) + url: string; (** The websocket URL in use. *) + _internal: Reader.t * Writer.t; + } + + (** Wrapper around an internal state, used to wrap {!shard}. *) + type 'a t = { + mutable state: 'a; + mutable stopped: bool; + mutable can_resume: bool; + } + + (** Send a heartbeat to Discord. This is handled automatically. *) + val heartbeat : + shard -> + shard Deferred.t + + (** Set the status of the shard. *) + val set_status : + status:Yojson.Safe.t -> + shard -> + shard Deferred.t + + (** Request guild members for the shard's guild. Causes dispatch of multiple {{!Dispatch.members_chunk}member chunk} events. *) + val request_guild_members : + ?query:string -> + ?limit:int -> + guild:Snowflake.t -> + shard -> + shard Deferred.t + + (** Create a new shard *) + val create : + url:string -> + shards:int * int -> + ?compress:bool -> + ?large_threshold:int -> + unit -> + shard Deferred.t + + val shutdown : + ?clean:bool -> + ?restart:bool -> + shard t -> + unit Deferred.t +end + +(** Calls {!Shard.set_status} for each shard registered with the sharder. *) +val set_status : + status:Yojson.Safe.t -> + t -> + Shard.shard list Deferred.t + +(** Like {!set_status} but takes a function with a {{!Shard.shard}shard} as its parameter and {{!Yojson.Safe.t}json} for its return. *) +val set_status_with : + f:(Shard.shard -> Yojson.Safe.t) -> + t -> + Shard.shard list Deferred.t + +(** Calls {!Shard.request_guild_members} for each shard registered with the sharder. *) +val request_guild_members : + ?query:string -> + ?limit:int -> + guild:Snowflake.t -> + t -> + Shard.shard list Deferred.t + +val shutdown_all : + ?restart:bool -> + t -> + unit list Deferred.t -- cgit v1.2.3