Skip to content

Commit

Permalink
Feature: async default value (#1498)
Browse files Browse the repository at this point in the history
* feat: support field value to be async function

* test: async function default value

* docs: change log for async function default value (#1498)
  • Loading branch information
YAGregor authored Apr 27, 2024
1 parent 25f21bf commit b72c175
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Added
- Enhancement for FastAPI lifespan support (#1371)
- Add __eq__ method to Q to more easily test dynamically-built queries (#1506)
- Added PlainToTsQuery function for postgres (#1347)
- Allow field's default keyword to be async function (#1498)

Fixed
^^^^^
Expand Down
21 changes: 21 additions & 0 deletions tests/test_callable_default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from tests import testmodels
from tortoise.contrib import test


class TestCallableDefault(test.TestCase):
async def test_default_create(self):
model = await testmodels.CallableDefault.create()
self.assertEqual(model.callable_default, "callable_default")
self.assertEqual(model.async_default, "async_callable_default")

async def test_default_by_save(self):
saved_model = testmodels.CallableDefault()
await saved_model.save()
self.assertEqual(saved_model.callable_default, "callable_default")
self.assertEqual(saved_model.async_default, "async_callable_default")

async def test_async_default_change(self):
default_change = testmodels.CallableDefault()
default_change.async_default = "changed"
await default_change.save()
self.assertEqual(default_change.async_default, "changed")
14 changes: 14 additions & 0 deletions tests/testmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,3 +892,17 @@ class PydanticMeta:
alias_generator=camelize_var,
populate_by_name=True,
)


def callable_default() -> str:
return "callable_default"


async def async_callable_default() -> str:
return "async_callable_default"


class CallableDefault(Model):
id = fields.IntField(pk=True)
callable_default = fields.CharField(max_length=32, default=callable_default)
async_default = fields.CharField(max_length=32, default=async_callable_default)
23 changes: 21 additions & 2 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,15 +667,25 @@ def __init__(self, **kwargs: Any) -> None:
self._partial = False
self._saved_in_db = False
self._custom_generated_pk = False
self._await_when_save: Dict[str, Callable[[], Awaitable[Any]]] = {}

# Assign defaults for missing fields
for key in meta.fields.difference(self._set_kwargs(kwargs)):
field_object = meta.fields_map[key]
if callable(field_object.default):
setattr(self, key, field_object.default())
field_default = field_object.default
if inspect.iscoroutinefunction(field_default):
self._await_when_save[key] = field_default
elif callable(field_default):
setattr(self, key, field_default())
else:
setattr(self, key, deepcopy(field_object.default))

def __setattr__(self, key, value):
# set field value override async default function
if hasattr(self, "_await_when_save"):
self._await_when_save.pop(key, None)
super().__setattr__(key, value)

def _set_kwargs(self, kwargs: dict) -> Set[str]:
meta = self._meta

Expand Down Expand Up @@ -719,6 +729,7 @@ def _init_from_db(cls: Type[MODEL], **kwargs: Any) -> MODEL:
self._partial = False
self._saved_in_db = True
self._custom_generated_pk = self._meta.db_pk_column not in self._meta.generated_db_fields
self._await_when_save = {}

meta = self._meta

Expand Down Expand Up @@ -845,6 +856,13 @@ def register_listener(cls, signal: Signals, listener: Callable):
if listener not in cls_listeners:
cls_listeners.append(listener)

async def _set_async_default_field(self) -> None:
"""retrieve value from field's async default value"""
if hasattr(self, "_await_when_save"):
for k, v in self._await_when_save.copy().items():
setattr(self, k, await v())
self._await_when_save = {}

async def _pre_delete(
self,
using_db: Optional[BaseDBAsyncClient] = None,
Expand Down Expand Up @@ -921,6 +939,7 @@ async def save(
:raises IncompleteInstanceError: If the model is partial and the fields are not available for persistence.
:raises IntegrityError: If the model can't be created or updated (specifically if force_create or force_update has been set)
"""
await self._set_async_default_field()
db = using_db or self._choose_db(True)
executor = db.executor_class(model=self.__class__, db=db)
if self._partial:
Expand Down

0 comments on commit b72c175

Please sign in to comment.