aboutsummaryrefslogtreecommitdiff
path: root/discord
diff options
context:
space:
mode:
Diffstat (limited to 'discord')
-rw-r--r--discord/ext/commands/bot.py149
1 files changed, 106 insertions, 43 deletions
diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py
index 157fbe21..a3301bd8 100644
--- a/discord/ext/commands/bot.py
+++ b/discord/ext/commands/bot.py
@@ -523,6 +523,65 @@ class BotBase(GroupMixin):
# extensions
+ def _remove_module_references(self, name):
+ # find all references to the module
+ # remove the cogs registered from the module
+ for cogname, cog in self._cogs.copy().items():
+ if _is_submodule(name, cog.__module__):
+ self.remove_cog(cogname)
+
+ # remove all the commands from the module
+ for cmd in self.all_commands.copy().values():
+ if cmd.module is not None and _is_submodule(name, cmd.module):
+ if isinstance(cmd, GroupMixin):
+ cmd.recursively_remove_all_commands()
+ self.remove_command(cmd.name)
+
+ # remove all the listeners from the module
+ for event_list in self.extra_events.copy().values():
+ remove = []
+ for index, event in enumerate(event_list):
+ if event.__module__ is not None and _is_submodule(name, event.__module__):
+ remove.append(index)
+
+ for index in reversed(remove):
+ del event_list[index]
+
+ def _call_module_finalizers(self, lib, key):
+ try:
+ func = getattr(lib, 'teardown')
+ except AttributeError:
+ pass
+ else:
+ try:
+ func(self)
+ except Exception:
+ pass
+ finally:
+ self._extensions.pop(key, None)
+ sys.modules.pop(key, None)
+ name = lib.__name__
+ for module in list(sys.modules.keys()):
+ if _is_submodule(name, module):
+ del sys.modules[module]
+
+ def _load_from_module_spec(self, lib, key):
+ # precondition: key not in self._extensions
+ try:
+ setup = getattr(lib, 'setup')
+ except AttributeError:
+ del sys.modules[key]
+ raise discord.ClientException('extension {!r} ({!r}) does not have a setup function.'.format(key, lib))
+
+ try:
+ setup(self)
+ except Exception:
+ self._remove_module_references(lib.__name__)
+ self._call_module_finalizers(lib, key)
+ raise
+ else:
+ self._extensions[key] = lib
+
def load_extension(self, name):
"""Loads an extension.
@@ -546,19 +605,16 @@ class BotBase(GroupMixin):
The extension does not have a setup function.
ImportError
The extension could not be imported.
+ Exception
+ Any other exception raised by the extension will be raised back
+ to the caller.
"""
if name in self._extensions:
return
lib = importlib.import_module(name)
- if not hasattr(lib, 'setup'):
- del lib
- del sys.modules[name]
- raise discord.ClientException('extension does not have a setup function')
-
- lib.setup(self)
- self._extensions[name] = lib
+ self._load_from_module_spec(lib, name)
def unload_extension(self, name):
"""Unloads an extension.
@@ -583,49 +639,56 @@ class BotBase(GroupMixin):
if lib is None:
return
- lib_name = lib.__name__
+ self._remove_module_references(lib.__name__)
+ self._call_module_finalizers(lib, name)
- # find all references to the module
+ def reload_extension(self, name):
+ """Atomically reloads an extension.
- # remove the cogs registered from the module
- for cogname, cog in self._cogs.copy().items():
- if _is_submodule(lib_name, cog.__module__):
- self.remove_cog(cogname)
+ This replaces the extension with the same extension, only refreshed. This is
+ equivalent to a :meth:`unload_extension` followed by a :meth:`load_extension`
+ except done in an atomic way. That is, if an operation fails mid-reload then
+ the bot will roll-back to the prior working state.
- # remove all the commands from the module
- for cmd in self.all_commands.copy().values():
- if cmd.module is not None and _is_submodule(lib_name, cmd.module):
- if isinstance(cmd, GroupMixin):
- cmd.recursively_remove_all_commands()
- self.remove_command(cmd.name)
+ Parameters
+ ------------
+ name: :class:`str`
+ The extension name to reload. It must be dot separated like
+ regular Python imports if accessing a sub-module. e.g.
+ ``foo.test`` if you want to import ``foo/test.py``.
- # remove all the listeners from the module
- for event_list in self.extra_events.copy().values():
- remove = []
- for index, event in enumerate(event_list):
- if event.__module__ is not None and _is_submodule(lib_name, event.__module__):
- remove.append(index)
+ Raises
+ -------
+ Exception
+ Any exception raised by the extension will be raised back
+ to the caller.
+ """
- for index in reversed(remove):
- del event_list[index]
+ lib = self._extensions.get(name)
+ if lib is None:
+ return
+
+ # get the previous module states from sys modules
+ modules = {
+ name: module
+ for name, module in sys.modules.items()
+ if _is_submodule(lib.__name__, name)
+ }
try:
- func = getattr(lib, 'teardown')
- except AttributeError:
- pass
- else:
- try:
- func(self)
- except Exception:
- pass
- finally:
- # finally remove the import..
- del lib
- del self._extensions[name]
- del sys.modules[name]
- for module in list(sys.modules.keys()):
- if _is_submodule(lib_name, module):
- del sys.modules[module]
+ # Unload and then load the module...
+ self._remove_module_references(lib.__name__)
+ self._call_module_finalizers(lib, name)
+ self.load_extension(name)
+ except Exception as e:
+ # if the load failed, the remnants should have been
+ # cleaned from the load_extension function call
+ # so let's load it from our old compiled library.
+ self._load_from_module_spec(lib, name)
+
+ # revert sys.modules back to normal and raise back to caller
+ sys.modules.update(modules)
+ raise
@property
def extensions(self):