diff --git a/src/cogs/dev.py b/src/cogs/dev.py index 564dc98..162646d 100644 --- a/src/cogs/dev.py +++ b/src/cogs/dev.py @@ -1,4 +1,7 @@ import asyncio +import re +import traceback +from io import StringIO from pprint import pprint import discord @@ -10,12 +13,14 @@ from discord.ext.commands import ( Cog, ExtensionNotLoaded, Context, + is_owner, ) from discord.utils import get from ptpython.repl import embed from src.constants import * from src.core import CustomBot +from src.errors import TfjmError from src.utils import fg COGS_SHORTCUTS = { @@ -29,6 +34,10 @@ COGS_SHORTCUTS = { "v": "dev", } +RE_QUERY = re.compile( + r"^! ?e(val)? (`{1,3}py(thon)?\n)?(?P.*?)\n?(`{1,3})?\n?$", re.DOTALL +) + class DevCog(Cog, name="Dev tools"): def __init__(self, bot: CustomBot): @@ -204,6 +213,53 @@ class DevCog(Cog, name="Dev tools"): await channel.delete_messages(to_delete) await ctx.message.delete() + @command(name="eval", aliases=["e"]) + @is_owner() + async def eval_cmd(self, ctx: Context): + """""" + msg: Message = ctx.message + guild: discord.Guild = ctx.guild + + query = re.match(RE_QUERY, msg.content).group("query") + + if not query: + raise TfjmError("No query found.") + + if "\n" in query: + lines = query.splitlines() + if "return" not in lines[-1] and "=" not in lines[-1]: + lines[-1] = f"return {lines[-1]}" + query = "\n ".join(lines) + query = f"def q():\n {query}\nresp = q()" + + try: + if "\n" in query: + q = compile(query, filename="query.py", mode="exec") + globs = {**globals(), **locals()} + locs = {} + exec(query, globs, locs) + resp = locs["resp"] + else: + resp = eval(query, globals(), locals()) + except Exception as e: + tb = StringIO() + traceback.print_tb(e.__traceback__, file=tb) + tb.seek(0) + + embed = discord.Embed(title=str(e), color=discord.Colour.red()) + embed.add_field(name="Query", value=f"```py\n{query}\n```", inline=False) + embed.add_field( + name="Traceback", value=f"```py\n{tb.read()}```", inline=False + ) + else: + out = StringIO() + pprint(resp, out) + out.seek(0) + embed = discord.Embed(title="Result", color=discord.Colour.green()) + embed.add_field(name="Query", value=f"```py\n{query}```", inline=False) + embed.add_field(name="Value", value=f"```py\n{out.read()}```", inline=False) + await ctx.send(embed=embed) + @Cog.listener() async def on_message(self, msg: Message): ch: TextChannel = msg.channel