from __future__ import annotations
import asyncio
import inspect
from typing import List, Any, Iterable, Optional, TYPE_CHECKING, Union, Tuple, Callable, Dict, Awaitable
import aiomysql
import aiopg
import aiosqlite
import discord
from motor import motor_asyncio
if TYPE_CHECKING:
from discord.ext import commands
from .Database import Database
__all__ = (
"COLUMN_TYPES",
"DatabaseNotConnected",
"InvalidGenerator",
"get_generator_response",
"maybe_coroutine",
"generate_column_types",
"questionnaire",
"EventManager",
"CogManager",
"DatabaseChecker"
)
COLUMN_TYPES = {
motor_asyncio.AsyncIOMotorDatabase: None, # mongo does not require any columns
aiosqlite.core.Connection: {"snowflake": "INTEGER", "string": 'TEXT', "number": "INTEGER", "smallnumber": "INTEGER"},
aiopg.pool.Pool: {"snowflake": "bigint", "string": 'character varying', "number": "integer", "smallnumber": "smallint"},
aiomysql.pool.Pool: {"snowflake": "BIGINT", "string": 'TEXT', "number": "INT", "smallnumber": "SMALLINT"}
}
[docs]class DatabaseNotConnected(Exception):
"""Raises an error when the user tries to use a method of a manager without a database connected to it."""
[docs]class InvalidGenerator(Exception):
"""
Raises an exception when the user passes an invalid generator.
"""
__slots__ = ("generator",)
def __init__(self, generator):
self.generator = generator
super().__init__(f"Generator of type {type(self.generator)!r} is not supported.")
[docs]async def maybe_coroutine(function: Callable, *args, **kwargs) -> Any:
"""
|coro|
Returns the coroutine version of the function.
:param function: The function to convert.
:type function: Union[Awaitable, Callable]
:param args: The arguments.
:param kwargs: The key arguments:
:return: The coroutine version of the function.
:rtype: Awaitable
"""
value = function(*args, **kwargs)
if inspect.isawaitable(value):
return await value
return value
[docs]def get_generator_response(generator: Any, generator_type: Any, *args, **kwargs) -> Any:
"""
Returns the generator response with the arguments.
:param generator: The generator to get the response from.
:type generator: Any
:param generator_type: The generator type. (Should be same as the generator type.
:type generator_type: Any
:param args: The arguments of the generator.
:param kwargs: The key arguments of the generator
:return: The generator response.
:rtype: Any
"""
if inspect.isclass(generator) and issubclass(generator, generator_type):
if inspect.ismethod(generator.generate):
return generator.generate(*args, **kwargs)
return generator().generate(*args, **kwargs)
if isinstance(generator, generator_type):
return generator.generate(*args, **kwargs)
raise InvalidGenerator(generator)
[docs]def generate_column_types(types: Iterable[str], database_type: Any) -> Optional[List[str]]:
"""
Generates the column type names that are suitable for the database type.
:param types: The column types.
:type types: Iterable[str]
:param database_type: The database type.
:type database_type: Any
:return: The suitable column types for the database types.
:rtype: Optional[List[str]]
"""
database_type_configuration = COLUMN_TYPES.get(database_type)
if database_type_configuration is None:
return
return [database_type_configuration[x] for x in types]
[docs]async def questionnaire(ctx: commands.Context,
questions: Iterable[Union[str, discord.Embed]],
public: bool = False,
timeout: Union[float, int] = 30,
member: discord.Member = None) -> Tuple[List[str], bool]:
"""
|coro|
Questions the member using a "quiz" and returns the answers.
The questionnaire can be used without a specific member and be public.
If no member was passed and the questionnaire public argument is true, a ValueError will be raised.
:raises: ValueError: The questionnaire is private and no member was provided.
:param ctx: The context (where the questionnaire will ask the questions).
:type ctx: commands.Context
:param questions: The questions the questionnaire will ask.
:type questions: Iterable[Union[str, discord.Embed]]
:param public: A bool indicating if the questionnaire is public.
:type public: bool
:param timeout: The number of seconds until the questionnaire will stop and time out.
:type timeout: Union[float, int]
:param member: The member the questionnaire will get the answers from.
:type member: discord.Member
:return: The answers and a boolean indicating if the questionnaire timed out.
:rtype: Tuple[List[str], bool]
"""
answers = []
timed_out = False
if not public and not member:
raise ValueError("The questionnaire is private and no member was provided.")
def checks(msg):
return msg.channel == ctx.channel if public else msg.channel == ctx.channel and msg.author == member
for question in questions:
if isinstance(question, str):
await ctx.send(question)
elif isinstance(question, discord.Embed):
await ctx.send(embed=question)
else:
raise TypeError("Question must be of type 'str' or 'discord.Embed'.")
try:
message = await ctx.bot.wait_for('message', check=checks, timeout=timeout)
except asyncio.TimeoutError:
timed_out = True
break
answers.append(message.content)
return answers, timed_out
[docs]class EventManager:
"""
An event manager that manages events for managers.
"""
def __init__(self):
self.events = {}
[docs] async def call_event(self, name: str, *args, **kwargs) -> None:
"""
Calls the event name with the arguments
:param name: The event name.
:type name: str
:param args: The arguments.
:param kwargs: The key arguments.
:return: None
:rtype: None
"""
if name in self.events:
for event in self.events[name]:
await event(*args, **kwargs)
[docs] def event(self, name: str = None) -> Callable:
"""
A decorator which adds an event listener.
:param name: The event name.
:type name: str
:return: The inner function.
:rtype: Callable
"""
def inner(func):
self.add_event(func, name)
return func
return inner
[docs] def add_event(self, func: Callable, name: str = None) -> None:
"""
Adds an event to the event dictionary.
:param func: The event callback.
:type func: Callable
:param name: The event name.
:type name: str
:return: None
:rtype: None
:raises: TypeError: The listener isn't async.
"""
name = func.__name__ if not name else name
if not asyncio.iscoroutinefunction(func):
raise TypeError('Listeners must be async.')
if name in self.events:
self.events[name].append(func)
else:
self.events[name] = [func]
[docs] def remove_event(self, func: Callable, name: str = None) -> None:
"""
Removes an event from the event dictionary.
:param func: The event callback.
:type func: Callable
:param name: The event name.
:type name: str
:return: None
:rtype: None
"""
name = func.__name__ if not name else name
if name in self.events:
self.events[name].remove(func)
[docs]class CogManager:
"""
A CogManager which helps the user use the managers inside discord cogs.
"""
[docs] class Cog:
"""
The internal Cog class.
"""
def __init__(self, managers: List = None):
listeners = {}
managers = [] if managers is None else managers
attribute_objects = [getattr(self, attr) for attr in dir(self)]
for attr in attribute_objects:
listener_type = getattr(attr, "_listener_type", None)
if listener_type:
if listener_type in listeners:
listeners[listener_type].append(attr)
else:
listeners[listener_type] = [attr]
managers = managers or [attr for attr in attribute_objects if type(attr) in listeners]
for event_type in listeners:
for manager in managers:
for event in listeners[event_type]:
manager.add_event(event)
[docs] @staticmethod
def event(manager_type: Any) -> Callable:
"""
Adds an event to the Cog event list.
:param manager_type: The manager type of the event.
:type manager_type: Any
:rtype: Callable
:return: The inner function.
:raises: TypeError: The listener isn't async.
"""
def decorator(func):
if not inspect.iscoroutinefunction(func):
raise TypeError('Listeners must be async.')
func._listener_type = manager_type
return func
return decorator
[docs]class DatabaseChecker(EventManager):
"""
A database checker which makes sure the database is connected to a manager and handles the table creation.
"""
def __init__(self,
tables_column_data: List[Dict[str, str]],
table_identifiers: List[str]):
super().__init__()
self.database = None
self.table_identifiers = table_identifiers
self.tables = {}
self.tables_column_data = tables_column_data
def _check_database(self, raise_error: bool = True) -> bool:
"""
A function which checks if the database is connected.
:param raise_error: A bool indicating if the function should raise an error if the database is not connected.
:type raise_error: bool
:rtype: bool
:return: If the database is connected.
:raises: DatabaseNotConnected: The database is not connected.
"""
if not self.database:
if raise_error:
raise DatabaseNotConnected(f"Database not connected."
f" Connect this manager to a database using 'connect_to_database'")
return False
return True
[docs] async def connect_to_database(self, database: Database, tables: List[str]) -> None:
"""
Connects to the database.
Calls on_database_connect when connected.
:param database: The database to connect to.
:type database: Database
:param tables: The tables to create (incase they do not exist).
:type tables: List[str]
:rtype: None
:return: None
"""
for table, table_data, identifier in zip(tables, self.tables_column_data, self.table_identifiers):
types = generate_column_types(table_data.values(), type(database.database))
await database.create_table(table, dict(zip(list(table_data), types)) if types else None, True)
self.database = database
self.tables[identifier] = table
await self.call_event("on_database_connect")