From 26e9b5bfac840c95744bbdd0ab356842e9efa72f Mon Sep 17 00:00:00 2001 From: Rapptz Date: Tue, 19 Mar 2019 06:21:39 -0400 Subject: [commands] Add Bot.reload_extension for atomic loading. Also do atomic loading in Bot.load_extension --- discord/ext/commands/bot.py | 149 +++++++++++++++++++++++++++++++------------- 1 file changed, 106 insertions(+), 43 deletions(-) (limited to 'discord') diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 157fbe21..a3301bd8 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -523,6 +523,65 @@ class BotBase(GroupMixin): # extensions + def _remove_module_references(self, name): + # find all references to the module + # remove the cogs registered from the module + for cogname, cog in self._cogs.copy().items(): + if _is_submodule(name, cog.__module__): + self.remove_cog(cogname) + + # remove all the commands from the module + for cmd in self.all_commands.copy().values(): + if cmd.module is not None and _is_submodule(name, cmd.module): + if isinstance(cmd, GroupMixin): + cmd.recursively_remove_all_commands() + self.remove_command(cmd.name) + + # remove all the listeners from the module + for event_list in self.extra_events.copy().values(): + remove = [] + for index, event in enumerate(event_list): + if event.__module__ is not None and _is_submodule(name, event.__module__): + remove.append(index) + + for index in reversed(remove): + del event_list[index] + + def _call_module_finalizers(self, lib, key): + try: + func = getattr(lib, 'teardown') + except AttributeError: + pass + else: + try: + func(self) + except Exception: + pass + finally: + self._extensions.pop(key, None) + sys.modules.pop(key, None) + name = lib.__name__ + for module in list(sys.modules.keys()): + if _is_submodule(name, module): + del sys.modules[module] + + def _load_from_module_spec(self, lib, key): + # precondition: key not in self._extensions + try: + setup = getattr(lib, 'setup') + except AttributeError: + del sys.modules[key] + raise discord.ClientException('extension {!r} ({!r}) does not have a setup function.'.format(key, lib)) + + try: + setup(self) + except Exception: + self._remove_module_references(lib.__name__) + self._call_module_finalizers(lib, key) + raise + else: + self._extensions[key] = lib + def load_extension(self, name): """Loads an extension. @@ -546,19 +605,16 @@ class BotBase(GroupMixin): The extension does not have a setup function. ImportError The extension could not be imported. + Exception + Any other exception raised by the extension will be raised back + to the caller. """ if name in self._extensions: return lib = importlib.import_module(name) - if not hasattr(lib, 'setup'): - del lib - del sys.modules[name] - raise discord.ClientException('extension does not have a setup function') - - lib.setup(self) - self._extensions[name] = lib + self._load_from_module_spec(lib, name) def unload_extension(self, name): """Unloads an extension. @@ -583,49 +639,56 @@ class BotBase(GroupMixin): if lib is None: return - lib_name = lib.__name__ + self._remove_module_references(lib.__name__) + self._call_module_finalizers(lib, name) - # find all references to the module + def reload_extension(self, name): + """Atomically reloads an extension. - # remove the cogs registered from the module - for cogname, cog in self._cogs.copy().items(): - if _is_submodule(lib_name, cog.__module__): - self.remove_cog(cogname) + This replaces the extension with the same extension, only refreshed. This is + equivalent to a :meth:`unload_extension` followed by a :meth:`load_extension` + except done in an atomic way. That is, if an operation fails mid-reload then + the bot will roll-back to the prior working state. - # remove all the commands from the module - for cmd in self.all_commands.copy().values(): - if cmd.module is not None and _is_submodule(lib_name, cmd.module): - if isinstance(cmd, GroupMixin): - cmd.recursively_remove_all_commands() - self.remove_command(cmd.name) + Parameters + ------------ + name: :class:`str` + The extension name to reload. It must be dot separated like + regular Python imports if accessing a sub-module. e.g. + ``foo.test`` if you want to import ``foo/test.py``. - # remove all the listeners from the module - for event_list in self.extra_events.copy().values(): - remove = [] - for index, event in enumerate(event_list): - if event.__module__ is not None and _is_submodule(lib_name, event.__module__): - remove.append(index) + Raises + ------- + Exception + Any exception raised by the extension will be raised back + to the caller. + """ - for index in reversed(remove): - del event_list[index] + lib = self._extensions.get(name) + if lib is None: + return + + # get the previous module states from sys modules + modules = { + name: module + for name, module in sys.modules.items() + if _is_submodule(lib.__name__, name) + } try: - func = getattr(lib, 'teardown') - except AttributeError: - pass - else: - try: - func(self) - except Exception: - pass - finally: - # finally remove the import.. - del lib - del self._extensions[name] - del sys.modules[name] - for module in list(sys.modules.keys()): - if _is_submodule(lib_name, module): - del sys.modules[module] + # Unload and then load the module... + self._remove_module_references(lib.__name__) + self._call_module_finalizers(lib, name) + self.load_extension(name) + except Exception as e: + # if the load failed, the remnants should have been + # cleaned from the load_extension function call + # so let's load it from our old compiled library. + self._load_from_module_spec(lib, name) + + # revert sys.modules back to normal and raise back to caller + sys.modules.update(modules) + raise @property def extensions(self): -- cgit v1.2.3