From 3b9f4aae95ac38d9ca25a22cf7628a0e7d2e483a Mon Sep 17 00:00:00 2001 From: ddorn Date: Thu, 30 Apr 2020 17:26:33 +0200 Subject: [PATCH] :sparkles: CustomBot with full reload --- src/cogs/dev.py | 20 ++++++++++-- src/cogs/errors.py | 23 +++++++++---- src/cogs/misc.py | 9 ++++-- src/cogs/teams.py | 3 +- src/cogs/tirages.py | 19 ++++++++--- src/constants.py | 7 ++++ src/core.py | 72 +++++++++++++++++++++++++++++++++++++++++ src/tfjm_discord_bot.py | 8 +++-- src/utils.py | 22 ++++++------- 9 files changed, 149 insertions(+), 34 deletions(-) create mode 100644 src/core.py diff --git a/src/cogs/dev.py b/src/cogs/dev.py index 262df0f..3a7494d 100644 --- a/src/cogs/dev.py +++ b/src/cogs/dev.py @@ -1,4 +1,6 @@ import code +import sys +from importlib import reload from pprint import pprint import discord @@ -8,13 +10,20 @@ from discord.ext.commands import Cog from discord.utils import get from src.constants import * +from src.core import CustomBot - -COGS_SHORTCUTS = {"d": "dev", "ts": "teams", "t": "tirages", "m": "misc", "e": "errors"} +COGS_SHORTCUTS = { + "d": "tirages", + "e": "errors", + "m": "misc", + "t": "teams", + "u": "src.utils", + "v": "dev", +} class DevCog(Cog, name="Dev tools"): - def __init__(self, bot: Bot): + def __init__(self, bot: CustomBot): self.bot = bot @command(name="interrupt") @@ -69,6 +78,11 @@ class DevCog(Cog, name="Dev tools"): possibles: `teams`, `tirages`, `dev`. """ + if name is None: + self.bot.reload() + await ctx.send(":tada: The bot was reloaded !") + return + names = [name] if name else list(COGS_SHORTCUTS.values()) for name in names: diff --git a/src/cogs/errors.py b/src/cogs/errors.py index fa17d12..174cd43 100644 --- a/src/cogs/errors.py +++ b/src/cogs/errors.py @@ -5,8 +5,8 @@ import discord from discord.ext.commands import * from discord.utils import maybe_coroutine -from src.errors import UnwantedCommand - +from src.core import CustomBot +from src.errors import UnwantedCommand, TfjmError # Global variable and function because I'm too lazy to make a metaclass handlers = {} @@ -29,6 +29,9 @@ def handles(error_type): class ErrorsCog(Cog): """This cog defines all the handles for errors.""" + def __init__(self, bot: CustomBot): + self.bot = bot + @Cog.listener() async def on_command_error(self, ctx: Context, error: CommandError): print(repr(error), file=sys.stderr) @@ -47,10 +50,11 @@ class ErrorsCog(Cog): msg = await maybe_coroutine(handler, self, ctx, error) if msg: - await ctx.send(msg) + message = await ctx.send(msg) + await self.bot.wait_for_bin(ctx.message.author, message) @handles(UnwantedCommand) - async def on_unwanted_command(self, ctx, error): + async def on_unwanted_command(self, ctx, error: UnwantedCommand): await ctx.message.delete() author: discord.Message await ctx.author.send( @@ -59,14 +63,19 @@ class ErrorsCog(Cog): + "\nC'est pas grave, c'est juste pour ne pas encombrer " "le chat lors du tirage." ) - await ctx.author.send("Raison: " + error.original.msg) + await ctx.author.send("Raison: " + error.msg) + + @handles(TfjmError) + async def on_tfjm_error(self, ctx: Context, error: TfjmError): + msg = await ctx.send(error.msg) + await self.bot.wait_for_bin(ctx.author, msg) @handles(CommandInvokeError) async def on_command_invoke_error(self, ctx, error): specific_handler = handlers.get(type(error.original)) if specific_handler: - return await specific_handler(self, ctx, error) + return await specific_handler(self, ctx, error.original) traceback.print_tb(error.original.__traceback__, file=sys.stderr) return ( @@ -91,4 +100,4 @@ class ErrorsCog(Cog): def setup(bot): - bot.add_cog(ErrorsCog()) + bot.add_cog(ErrorsCog(bot)) diff --git a/src/cogs/misc.py b/src/cogs/misc.py index 60d8c71..d883f47 100644 --- a/src/cogs/misc.py +++ b/src/cogs/misc.py @@ -17,13 +17,15 @@ from discord.ext.commands import ( Group, ) +from src import utils from src.constants import * from src.constants import Emoji -from src.utils import has_role +from src.core import CustomBot +from src.utils import has_role, start_time class MiscCog(Cog, name="Divers"): - def __init__(self, bot: Bot): + def __init__(self, bot: CustomBot): self.bot = bot self.show_hidden = False self.verify_checks = True @@ -55,6 +57,7 @@ class MiscCog(Cog, name="Divers"): await message.add_reaction(Emoji.JOY) await message.add_reaction(Emoji.SOB) + await self.bot.wait_for_bin(ctx.message.author, message) @command(name="status") @commands.has_role(Role.CNO) @@ -65,7 +68,7 @@ class MiscCog(Cog, name="Divers"): benevoles = [g for g in guild.members if has_role(g, Role.BENEVOLE)] participants = [g for g in guild.members if has_role(g, Role.PARTICIPANT)] no_role = [g for g in guild.members if g.top_role == guild.default_role] - uptime = datetime.timedelta(seconds=round(time() - START_TIME)) + uptime = datetime.timedelta(seconds=round(time() - start_time())) infos = { "Bénévoles": len(benevoles), diff --git a/src/cogs/teams.py b/src/cogs/teams.py index c551944..e417ce0 100644 --- a/src/cogs/teams.py +++ b/src/cogs/teams.py @@ -7,13 +7,14 @@ from discord.ext.commands import Cog, Bot, group, Context from discord.utils import get, find from src.constants import * +from src.core import CustomBot from src.utils import has_role Team = namedtuple("Team", ["name", "trigram", "tournoi", "secret", "status"]) class TeamsCog(Cog, name="Teams"): - def __init__(self, bot: Bot): + def __init__(self, bot: CustomBot): self.bot = bot self.teams = self.load_teams() diff --git a/src/cogs/tirages.py b/src/cogs/tirages.py index 0ef7cf2..63a9b63 100644 --- a/src/cogs/tirages.py +++ b/src/cogs/tirages.py @@ -14,6 +14,7 @@ from discord.ext.commands import group, Cog, Context from discord.utils import get from src.constants import * +from src.core import CustomBot from src.errors import TfjmError, UnwantedCommand __all__ = ["Tirage", "TirageCog"] @@ -606,7 +607,7 @@ class TirageOrderPhase(OrderPhase): class TirageCog(Cog, name="Tirages"): def __init__(self, bot): - self.bot: commands.Bot = bot + self.bot: CustomBot = bot # We retrieve the global variable. # We don't want tirages to be ust an attribute @@ -627,13 +628,21 @@ class TirageCog(Cog, name="Tirages"): if channel in self.tirages: await self.tirages[channel].dice(ctx, n) else: + if n == 0: + raise TfjmError(f"Un dé sans faces ? Le concept m'intéresse...") if n < 1: - raise TfjmError(f"Je ne peux pas lancer un dé à {n} faces, désolé.") + raise TfjmError( + f"Je ne peux pas lancer un dé avec un " + f"nombre négatif faces, désolé." + ) + if len(str(n)) > 1900: + raise TfjmError( + "Oulà... Je sais que la taille ça ne compte pas, " + "mais là il est vraiment gros ton dé !" + ) dice = random.randint(1, n) - await ctx.send( - f"Le dé à {n} face{'s' * (n > 1)} s'est arrêté sur... **{dice}**" - ) + await ctx.send(f"{ctx.author.mention} : {Emoji.DICE} {dice}") @commands.command( name="random-problem", diff --git a/src/constants.py b/src/constants.py index 3c8aba5..c81268f 100644 --- a/src/constants.py +++ b/src/constants.py @@ -59,6 +59,8 @@ class Role: class Emoji: JOY = "😂" SOB = "😭" + BIN = "🗑️" + DICE = "🎲" class File: @@ -71,3 +73,8 @@ class File: with open(File.TOP_LEVEL / "data" / "problems") as f: PROBLEMS = f.read().splitlines() MAX_REFUSE = len(PROBLEMS) - 4 # -5 usually but not in 2020 because of covid-19 + + +def setup(bot): + # Just so we can reload the constants + pass diff --git a/src/core.py b/src/core.py new file mode 100644 index 0000000..a9894c4 --- /dev/null +++ b/src/core.py @@ -0,0 +1,72 @@ +import asyncio +import sys +from importlib import reload + +import psutil +from discord import User, Message, Reaction +from discord.ext.commands import Bot + + +__all__ = ["CustomBot"] + +from discord.utils import get + +from src.constants import Emoji + + +class CustomBot(Bot): + """ + This is the same as a discord bot except + for class reloading and it provides hints + for the type checker about the modules + that are added by extensions. + """ + + def __str__(self): + return f"{self.__class__.__name__}:{hex(id(self.__class__))} obj at {hex(id(self))}" + + def reload(self): + cls = self.__class__ + module_name = cls.__module__ + old_module = sys.modules[module_name] + + print("Trying to reload the bot.") + try: + # del sys.modules[module_name] + module = reload(old_module) + self.__class__ = getattr(module, cls.__name__, cls) + except: + print("Could not reload the bot :/") + raise + print("The bot has reloaded !") + + async def wait_for_bin(bot: Bot, user: User, *msgs: Message, timeout=300): + """Wait for timeout seconds for `user` to delete the messages.""" + + msgs = list(msgs) + + assert msgs, "No messages in wait_for_bin" + + for m in msgs: + await m.add_reaction(Emoji.BIN) + + def check(reaction: Reaction, u): + return ( + user == u + and any(m.id == reaction.message.id for m in msgs) + and str(reaction.emoji) == Emoji.BIN + ) + + try: + while msgs: + reaction, u = await bot.wait_for( + "reaction_add", check=check, timeout=timeout + ) + the_msg = get(msgs, id=reaction.message.id) + await the_msg.delete() + msgs.remove(the_msg) + except asyncio.TimeoutError: + pass + + for m in msgs: + await m.clear_reaction(Emoji.BIN) diff --git a/src/tfjm_discord_bot.py b/src/tfjm_discord_bot.py index 85a79a5..c8a11a2 100644 --- a/src/tfjm_discord_bot.py +++ b/src/tfjm_discord_bot.py @@ -1,12 +1,13 @@ #!/bin/python -from discord.ext import commands - from src.constants import * + # We allow "! " to catch people that put a space in their commands. # It must be in first otherwise "!" always match first and the space is not recognised -bot = commands.Bot(("! ", "!")) +from src.core import CustomBot + +bot = CustomBot(("! ", "!")) # Global variable to hold the tirages. # We *want* it to be global so we can reload the tirages cog without @@ -25,6 +26,7 @@ bot.load_extension("src.cogs.errors") bot.load_extension("src.cogs.misc") bot.load_extension("src.cogs.teams") bot.load_extension("src.cogs.tirages") +bot.load_extension("src.utils") if __name__ == "__main__": diff --git a/src/utils.py b/src/utils.py index bd7fb52..8d489ed 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,6 +1,12 @@ +import asyncio +from typing import Sequence + import psutil -from discord import Message +from discord import Message, Member, User, Reaction from discord.ext.commands import Context, Bot +from discord.utils import get + +from src.constants import Emoji def has_role(member, role: str): @@ -9,17 +15,9 @@ def has_role(member, role: str): return any(r.name == role for r in member.roles) -async def send_and_bin(bot: Bot, ctx: Context, msg=None, *, embed=None): - """Send a message and wait 5min for the author to delete it.""" - - message: Message = await ctx.send(msg, embed=embed) - - await msg - - -def start_time(): +def start_time(self): return psutil.Process().create_time() -def setup(bot): - bot.send_and_bin = send_and_bin +def setup(bot: Bot): + pass