From 9331bc40ad1f7ce09d5b03ec1d3c171bc2bf67fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Honza=20Kr=C3=A1l?= Date: Thu, 3 Mar 2022 18:15:05 +0100 Subject: [PATCH] Make CREATE_MISSING_FLAGS logic universal (#400) (#427) Before CREATE_MISSING_FLAGS was only respected when `is_active(request)` method was used, now any check against a flag (calling `get(name)`) will trigger the automatic creation of the `Flag` model. --- waffle/models.py | 17 +++++++---------- waffle/tests/test_waffle.py | 10 ++++++++++ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/waffle/models.py b/waffle/models.py index fa76b6eb..00c7c7b2 100644 --- a/waffle/models.py +++ b/waffle/models.py @@ -56,6 +56,13 @@ def get_from_db(cls, name): objects = cls.objects if get_setting('READ_FROM_WRITE_DB'): objects = objects.using(router.db_for_write(cls)) + + if get_setting('CREATE_MISSING_FLAGS'): + flag, _created = objects.get_or_create( + name=name, defaults={'everyone': get_setting('FLAG_DEFAULT')} + ) + return flag + return objects.get(name=name) @classmethod @@ -247,16 +254,6 @@ def is_active(self, request): log_level = get_setting('LOG_MISSING_FLAGS') if log_level: logger.log(log_level, 'Flag %s not found', self.name) - if get_setting('CREATE_MISSING_FLAGS'): - flag, _created = get_waffle_flag_model().objects.get_or_create( - name=self.name, - defaults={ - 'everyone': get_setting('FLAG_DEFAULT') - } - ) - cache = get_cache() - cache.set(self._cache_key(self.name), flag) - return get_setting('FLAG_DEFAULT') if get_setting('OVERRIDE'): diff --git a/waffle/tests/test_waffle.py b/waffle/tests/test_waffle.py index 1b906eed..684f0809 100644 --- a/waffle/tests/test_waffle.py +++ b/waffle/tests/test_waffle.py @@ -441,6 +441,16 @@ def test_flag_created_dynamically_default_false(self): def test_flag_created_dynamically_default_true(self): self.assert_flag_dynamically_created_with_value(True) + @override_settings(WAFFLE_CREATE_MISSING_FLAGS=True) + @override_settings(WAFFLE_FLAG_DEFAULT=True) + def test_flag_created_dynamically_upon_retrieval(self): + FLAG_NAME = 'myflag' + flag_model = waffle.get_waffle_flag_model() + flag = flag_model.get(FLAG_NAME) + + assert flag.is_active(get()) + assert flag_model.objects.filter(name=FLAG_NAME).exists() + @mock.patch('waffle.models.logger') def test_no_logging_missing_flag_by_default(self, mock_logger): request = get()