From 62d0f2ad3124a80b04ba6f4fd7c3791ea2468b88 Mon Sep 17 00:00:00 2001 From: Luca Sbardella Date: Fri, 21 Sep 2018 15:01:45 +0100 Subject: [PATCH] allow to specify the data_field in a db column --- openapi/__init__.py | 2 +- openapi/data/db.py | 48 +++++++++++++++++++++++++++++++++------------ tests/example/db.py | 10 +++++++++- 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/openapi/__init__.py b/openapi/__init__.py index 91a0b28..ca2482f 100644 --- a/openapi/__init__.py +++ b/openapi/__init__.py @@ -1,4 +1,4 @@ """Minimal OpenAPI asynchronous server application """ -__version__ = '0.8.6' +__version__ = '0.8.7' diff --git a/openapi/data/db.py b/openapi/data/db.py index 01ea864..eefb1e0 100644 --- a/openapi/data/db.py +++ b/openapi/data/db.py @@ -43,74 +43,96 @@ def _(f): @converter(sa.Boolean) def bl(col, required): - return (bool, fields.bool_field(**info(col, required))) + data_field = col.info.get('data_field', fields.bool_field) + return ( + bool, + data_field(**info(col, required)) + ) @converter(sa.Integer) def integer(col, required): - return (int, fields.number_field(precision=0, **info(col, required))) + data_field = col.info.get('data_field', fields.number_field) + return ( + int, + data_field(precision=0, **info(col, required)) + ) @converter(sa.Numeric) def number(col, required): + data_field = col.info.get('data_field', fields.decimal_field) return ( - Decimal, fields.decimal_field( - precision=col.type.scale, **info(col, required) - ) + Decimal, + data_field(precision=col.type.scale, **info(col, required)) ) @converter(sa.String, sa.Text, sa.CHAR, sa.VARCHAR) def string(col, required): - return (str, fields.str_field( - max_length=col.type.length or 0, **info(col, required))) + data_field = col.info.get('data_field', fields.str_field) + return ( + str, + data_field(max_length=col.type.length or 0, **info(col, required)) + ) @converter(sa.DateTime) def dt_ti(col, required): - return (datetime, fields.date_time_field(**info(col, required))) + data_field = col.info.get('data_field', fields.date_time_field) + return ( + datetime, + data_field(**info(col, required)) + ) @converter(sa.Date) def dt(col, required): - return (date, fields.date_field(**info(col, required))) + data_field = col.info.get('data_field', fields.date_field) + return ( + date, + data_field(**info(col, required)) + ) @converter(sa.Enum) def en(col, required): + data_field = col.info.get('data_field', fields.enum_field) return ( col.type.enum_class, - fields.enum_field(col.type.enum_class, **info(col, required)) + data_field(col.type.enum_class, **info(col, required)) ) @converter(sa.JSON) def js(col, required): + data_field = col.info.get('data_field', fields.json_field) val = None if col.default: arg = col.default.arg val = arg() if col.default.is_callable else arg return ( JsonTypes.get(type(val), typing.Dict), - fields.json_field(**info(col, required)) + data_field(**info(col, required)) ) @converter(UUIDType) def uuid(col, required): + data_field = col.info.get('data_field', fields.uuid_field) return ( str, - fields.uuid_field(**info(col, required)) + data_field(**info(col, required)) ) def info(col, required): - data = dict( description=col.doc, required=not col.nullable if required is not False else False ) data.update(col.info) + data.pop('data_field', None) return data diff --git a/tests/example/db.py b/tests/example/db.py index ee99fb4..c24a600 100644 --- a/tests/example/db.py +++ b/tests/example/db.py @@ -5,6 +5,8 @@ from sqlalchemy_utils import UUIDType from openapi.db.columns import UUIDColumn +from openapi.data import fields + from .models import TaskType original_init = UUIDType.__init__ @@ -17,6 +19,10 @@ def patch_init(self, binary=True, native=True, **kw): UUIDType.__init__ = patch_init +def title_field(**kwargs): + return fields.str_field(**kwargs) + + def meta(meta=None): """Add task related tables """ @@ -28,7 +34,9 @@ def meta(meta=None): UUIDColumn( 'id', make_default=True, doc='Unique ID'), sa.Column( - 'title', sa.String(64), nullable=False, info=dict(min_length=3)), + 'title', sa.String(64), nullable=False, + info=dict(min_length=3, data_field=title_field) + ), sa.Column('done', sa.DateTime), sa.Column('severity', sa.Integer), sa.Column('type', sa.Enum(TaskType)),