import asyncio
import sys
from abc import ABC, abstractmethod
from typing import (
Dict,
Any,
Optional,
List,
Union
)
import aiomysql
import aiopg
import aiosqlite
from motor import motor_asyncio
if sys.version_info >= (3, 8) and sys.platform.lower().startswith("win"):
# Aiopg requires the event loop policy to be WindowsSelectorEventLoop, if it is not, aiopg raises an error.
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
[docs]async def create_mysql(host, port, user, password, dbname):
# Created this function to make sure the user has autocommit enabled.
# we must make sure autocommit is enabled because manual commits are not working on aiomysql :)
return await aiomysql.create_pool(host=host, port=port, user=user, password=password, db=dbname, autocommit=True)
[docs]class UnsupportedDatabase(Exception):
"""Raises error when the user tries to use an unsupported database."""
[docs]class Database(ABC):
def __init__(self, database):
self.database = database
[docs] @abstractmethod
async def close(self):
pass
[docs] @abstractmethod
async def insertifnotexists(self, table_name: str, data: Dict[str, Any], checks: Dict[str, Any]):
pass
[docs] @abstractmethod
async def insert(self, table_name: str, data: Dict[str, Any]):
pass
[docs] @abstractmethod
async def create_table(self,
table_name: str,
columns: Optional[Dict[str, str]] = None,
exists: Optional[bool] = False):
pass
[docs] @abstractmethod
async def update(self, table_name: str, data: Dict[str, Any], checks: Dict[str, Any]):
pass
[docs] @abstractmethod
async def updateorinsert(self,
table_name: str,
data: Dict[str, Any],
checks: Dict[str, Any],
insert_data: Dict[str, Any]):
pass
[docs] @abstractmethod
async def delete(self, table_name: str, checks: Dict[str, Any]):
pass
[docs] @abstractmethod
async def select(self,
table_name: str,
keys: List[str],
checks: Optional[Dict[str, Any]] = None,
fetchall: Optional[bool] = False):
pass
[docs] @abstractmethod
async def execute(self,
sql_query: str,
values: List[Any],
fetchall: bool = True) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
pass
class _MongoDatabase(Database):
def __str__(self):
return f"<{self.__class__.__name__} '{self.name}'>"
@property
def name(self):
return self.database.name
async def close(self):
self.database.client.close()
async def insertifnotexists(self, table_name, data, checks):
response = await self.select(table_name, [], checks, True)
if not response:
return await self.insert(table_name, data)
async def insert(self, table_name, data):
return await self.database[table_name].insert_one(data)
async def create_table(self, table_name, _=None, exists=False):
# create_table has an unused positional parameter to make the methods consistent between database types.
if exists and table_name in await self.database.list_collection_names():
return
return await self.database.create_collection(table_name)
async def update(self, table_name, data, checks):
return await self.database[table_name].update_one(checks, {"$set": data})
async def updateorinsert(self, table_name, data, checks, insert_data):
response = await self.select(table_name, [], checks, True)
if len(response) == 1:
return await self.update(table_name, data, checks)
return await self.insert(table_name, insert_data)
async def delete(self, table_name, checks=None):
return await self.database[table_name].delete_one({} if checks is None else checks)
async def select(self, table_name, keys, checks=None, fetchall=False):
checks = {} if checks is None else checks
if fetchall:
fetch = self.database[table_name].find(checks)
result = []
async for doc in fetch:
current_doc = {}
for key, value in doc.items():
if not keys or key in keys:
current_doc[key] = value
result.append(current_doc)
else:
fetch = await self.database[table_name].find_one(checks)
result = {}
if fetch is not None:
for key, value in fetch.items():
if not keys or key in keys:
result[key] = value
else:
result = None
return result
async def execute(self,
sql_query: str,
values: List[Any],
fetchall: bool = True) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
raise NotImplementedError("NoSQL databases cannot execute sql queries.")
class _SqlDatabase(Database):
def __str__(self):
return f"<{self.__class__.__name__}>"
def with_commit(func):
async def inner(self, *args, **kwargs):
resp = await func(self, *args, **kwargs)
if self.commit_needed:
await self.commit()
return resp
return inner
def with_cursor(func):
async def inner(self, *args, **kwargs):
database = await self.database.acquire() if self.pool else self.database
if self.cursor_context:
async with database.cursor() as cursor:
resp = await func(self, cursor, *args, **kwargs)
else:
cursor = await database.cursor()
resp = await func(self, cursor, *args, **kwargs)
await cursor.close()
if self.pool:
self.database.release(database)
return resp
return inner
def __init__(self, database):
super().__init__(database)
self.place_holder = DATABASE_TYPES[type(database)]["placeholder"]
self.cursor_context = DATABASE_TYPES[type(database)]['cursorcontext']
self.commit_needed = DATABASE_TYPES[type(database)]['commit']
self.quote = DATABASE_TYPES[type(database)]['quotes']
self.pool = DATABASE_TYPES[type(database)]['pool']
async def commit(self):
if not self.pool:
await self.database.commit()
async def close(self):
await self.database.close()
async def insertifnotexists(self, table_name, data, checks):
response = await self.select(table_name, [], checks, True)
if not response:
return await self.insert(table_name, data)
@with_cursor
@with_commit
async def insert(self, cursor, table_name, data):
query = f"INSERT INTO {table_name} ({', '.join(data.keys())}) VALUES ({', '.join([self.place_holder] * len(data.values()))})"
await cursor.execute(query, list(data.values()))
@with_cursor
@with_commit
async def create_table(self, cursor, table_name, columns=None, exists=False):
query = f'CREATE TABLE {"IF NOT EXISTS" if exists else ""} {self.quote}{table_name}{self.quote} ('
for column in [] if columns is None else columns:
query += f"\n{self.quote}{column}{self.quote} {columns[column]},"
query = query[:-1]
query += "\n);"
await cursor.execute(query)
@with_cursor
@with_commit
async def update(self, cursor, table_name, data, checks):
query = f"UPDATE {table_name} SET "
if data:
for key in data:
query += f"{key} = {self.place_holder}, "
query = query[:-2]
if checks:
query += " WHERE "
for check in checks:
query += f"{check} = {self.place_holder} AND "
query = query[:-4]
await cursor.execute(query, list(data.values()) + list(checks.values()))
async def updateorinsert(self, table_name, data, checks, insert_data):
response = await self.select(table_name, [], checks, True)
if len(response) == 1:
return await self.update(table_name, data, checks)
return await self.insert(table_name, insert_data)
@with_cursor
@with_commit
async def delete(self, cursor, table_name, checks=None):
checks = {} if checks is None else checks
query = f"DELETE FROM {table_name} "
if checks:
query += "WHERE "
for check in checks:
query += f"{check} = {self.place_holder} AND "
query = query[:-4]
await cursor.execute(query, list(checks.values()))
@with_cursor
async def select(self, cursor, table_name, keys, checks=None, fetchall=False):
checks = {} if checks is None else checks
keys = '*' if not keys else keys
query = f"SELECT {','.join(keys)} FROM {table_name} "
if checks:
query += "WHERE "
for check in checks:
query += f"{check} = {self.place_holder} AND "
query = query[:-4]
await cursor.execute(query, list(checks.values()))
columns = [x[0] for x in cursor.description]
result = await cursor.fetchall() if fetchall else await cursor.fetchone()
if not result:
return result
return [dict(zip(columns, x)) for x in result] if fetchall else dict(zip(columns, result))
@with_cursor
@with_commit
async def execute(self,
cursor,
sql_query: str,
values: List[Any] = None,
fetchall: bool = True) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
await cursor.execute(sql_query, values if values is not None else [])
result = await cursor.fetchall() if fetchall else await cursor.fetchone()
columns = [x[0] for x in cursor.description]
if not result:
return result
return [dict(zip(columns, x)) for x in result] if fetchall else dict(zip(columns, result))
DATABASE_TYPES: Dict[Any, Dict[str, Any]] = {
motor_asyncio.AsyncIOMotorDatabase: {"class": _MongoDatabase,
"placeholder": None},
aiosqlite.core.Connection: {"class": _SqlDatabase,
"placeholder": '?',
'cursorcontext': True,
'commit': True,
'quotes': '"',
'pool': False},
aiopg.pool.Pool: {"class": _SqlDatabase,
"placeholder": '%s',
'cursorcontext': True,
'commit': True,
'quotes': '"',
'pool': True},
aiomysql.pool.Pool: {"class": _SqlDatabase,
"placeholder": '%s',
'cursorcontext': True,
'commit': False,
'quotes': '`',
'pool': True}
}
DATABASES: List = [_SqlDatabase, _MongoDatabase]
[docs]class DatabaseManager:
[docs] @staticmethod
def connect(database):
if type(database) not in DATABASE_TYPES:
raise UnsupportedDatabase(f"Database of type {type(database)} is not supported by the database manager.")
return DATABASE_TYPES[type(database)]["class"](database)