diff --git a/README.md b/README.md
index cdd5ab6..275f76d 100644
--- a/README.md
+++ b/README.md
@@ -166,15 +166,16 @@ All fields are required unless one of the following is set:
 The following column types are supported.
 See TypeSystem for [type-specific validation keyword arguments][typesystem-fields].
 
-* `orm.String(max_length)`
-* `orm.Text()`
 * `orm.BigInteger()`
 * `orm.Boolean()`
-* `orm.Integer()`
-* `orm.Float()`
 * `orm.Date()`
-* `orm.Time()`
 * `orm.DateTime()`
+* `orm.Enum(enum)`
+* `orm.Float()`
+* `orm.Integer()`
 * `orm.JSON()`
+* `orm.String(max_length)`
+* `orm.Text()`
+* `orm.Time()`
 
 [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/
diff --git a/orm/__init__.py b/orm/__init__.py
index 67b42c6..a595c26 100644
--- a/orm/__init__.py
+++ b/orm/__init__.py
@@ -5,6 +5,7 @@
     Boolean,
     Date,
     DateTime,
+    Enum,
     Float,
     ForeignKey,
     Integer,
@@ -22,6 +23,7 @@
     "Boolean",
     "Date",
     "DateTime",
+    "Enum",
     "Float",
     "Integer",
     "String",
diff --git a/orm/fields.py b/orm/fields.py
index 66f1b38..3defa59 100644
--- a/orm/fields.py
+++ b/orm/fields.py
@@ -117,3 +117,34 @@ def expand_relationship(self, value):
         if isinstance(value, self.to):
             return value
         return self.to({self.to.__pkname__: value})
+
+
+class Enum(ModelField, typesystem.Any):
+    """
+    A column whose values are members of a Python `enum.Enum` class.
+
+    The column is stored with `sqlalchemy.Enum`, which persists each
+    member's *name* (e.g. "DRAFT"), not its value.
+    """
+
+    errors = {"null": "May not be null.", "choice": "Not a valid choice."}
+
+    def __init__(self, enum, **kwargs):
+        super().__init__(**kwargs)
+        self.enum = enum
+
+    def validate(self, value, strict=False):
+        # `typesystem.Any` would accept anything, letting unknown values
+        # reach the database. Coerce to a member instead: `self.enum(value)`
+        # returns members unchanged and maps raw values ("Draft") to them.
+        if value is None:
+            if self.allow_null:
+                return None
+            raise self.validation_error("null")
+        try:
+            return self.enum(value)
+        except ValueError:
+            raise self.validation_error("choice") from None
+
+    def get_column_type(self):
+        return sqlalchemy.Enum(self.enum)
diff --git a/tests/test_columns.py b/tests/test_columns.py
index e0e0dfb..1dcb008 100644
--- a/tests/test_columns.py
+++ b/tests/test_columns.py
@@ -1,6 +1,7 @@
 import asyncio
 import datetime
 import functools
+from enum import Enum
 
 import databases
 import pytest
@@ -17,6 +18,11 @@ def time():
     return datetime.datetime.now().time()
 
 
+class StatusEnum(Enum):
+    DRAFT = "Draft"
+    RELEASED = "Released"
+
+
 class Example(orm.Model):
     __tablename__ = "example"
     __metadata__ = metadata
@@ -30,6 +36,7 @@ class Example(orm.Model):
     description = orm.Text(allow_blank=True)
     value = orm.Float(allow_null=True)
     data = orm.JSON(default={})
+    status = orm.Enum(StatusEnum, default=StatusEnum.DRAFT)
 
 
 @pytest.fixture(autouse=True, scope="module")
@@ -66,8 +73,12 @@ async def test_model_crud():
     assert example.description == ""
     assert example.value is None
     assert example.data == {}
+    assert example.status == StatusEnum.DRAFT
 
-    await example.update(data={"foo": 123}, value=123.456)
+    await example.update(
+        data={"foo": 123}, value=123.456, status=StatusEnum.RELEASED
+    )
     example = await Example.objects.get()
     assert example.value == 123.456
     assert example.data == {"foo": 123}
+    assert example.status == StatusEnum.RELEASED