From 0c7bea6fecbfba642b190eb91697d88c7a5d7305 Mon Sep 17 00:00:00 2001
From: Michael Manganiello <mike@fmanganiello.com.ar>
Date: Thu, 1 Dec 2022 21:41:44 -0300
Subject: [PATCH] feat: Allow override_flag testutil to set different checks

Up until now, the `override_flag` test util was only useful to
completely activate or deactivate a flag. It didn't allow to provide
more customization for all the internal checks it makes, like evaluating
if the request user is authenticated or staff.

This change provides that functionality, without breaking the current
API, to avoid breaking changes. Now, users can provide more granular
arguments to `override_flag`, for a more flexible flag testing.

Closes #439.
---
 docs/testing/automated.rst     |  11 ++
 waffle/tests/test_testutils.py | 252 ++++++++++++++++++++++++++++++++-
 waffle/testutils.py            | 144 ++++++++++++++++++-
 3 files changed, 400 insertions(+), 7 deletions(-)

diff --git a/docs/testing/automated.rst b/docs/testing/automated.rst
index 9da4f532..db62f092 100644
--- a/docs/testing/automated.rst
+++ b/docs/testing/automated.rst
@@ -41,6 +41,17 @@ Or::
         # samples behave normally.
         assert waffle.sample_is_active('sample_name')
 
+``override_flag`` also allows providing values for its different checks.
+For example::
+
+    @override_flag('flag_name', percent=25)
+    def test_with_flag():
+        ...
+
+    @override_flag('flag_name', staff=True, superusers=True)
+    def test_with_flag():
+        ...
+
 All three will restore the relevant flag, sample, or switch to its
 previous state: they will restore the old values and will delete objects
 that did not exist.
diff --git a/waffle/tests/test_testutils.py b/waffle/tests/test_testutils.py
index 3872a041..6b67fca5 100644
--- a/waffle/tests/test_testutils.py
+++ b/waffle/tests/test_testutils.py
@@ -1,11 +1,15 @@
+import contextlib
+import random
 from decimal import Decimal
+from unittest import mock
 
+from django.contrib.auth import get_user_model
 from django.contrib.auth.models import AnonymousUser
 from django.db import transaction
-from django.test import TransactionTestCase, RequestFactory, TestCase
+from django.test import RequestFactory, TestCase, TransactionTestCase
 
 import waffle
-from waffle.testutils import override_switch, override_flag, override_sample
+from waffle.testutils import override_flag, override_sample, override_switch
 
 
 class OverrideSwitchMixin:
@@ -116,6 +120,15 @@ def req():
     return r
 
 
+@contextlib.contextmanager
+def provide_user(**kwargs):
+    user = get_user_model()(**kwargs)
+    user.save()
+    yield user
+    user.delete()
+
+
+
 class OverrideFlagTestsMixin:
     def test_flag_existed_and_was_active(self):
         waffle.get_waffle_flag_model().objects.create(name='foo', everyone=True)
@@ -173,6 +186,241 @@ def test_cache_is_flushed_by_testutils_even_in_transaction(self):
 
         assert waffle.flag_is_active(req(), 'foo')
 
+    @mock.patch.object(random, 'uniform')
+    def test_flag_existed_and_was_active_for_percent(self, uniform):
+        waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, percent='50')
+
+        uniform.return_value = '75'
+
+        with override_flag('foo', percent=80.0):
+            assert waffle.flag_is_active(req(), 'foo')
+
+        with override_flag('foo', percent=40.0):
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        assert waffle.get_waffle_flag_model().objects.get(name='foo').percent == Decimal('50')
+
+    @mock.patch.object(random, 'uniform')
+    def test_flag_existed_and_was_inactive_for_percent(self, uniform):
+        waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, percent=None)
+
+        uniform.return_value = '75'
+
+        with override_flag('foo', percent=80.0):
+            assert waffle.flag_is_active(req(), 'foo')
+
+        with override_flag('foo', percent=40.0):
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        assert not waffle.get_waffle_flag_model().objects.get(name='foo').percent
+
+
+    def test_flag_existed_and_was_active_for_testing(self):
+        waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, testing=True)
+
+        with override_flag('foo', testing=True):
+            request = req()
+            request.COOKIES['dwft_foo'] = 'True'
+            assert waffle.flag_is_active(request, 'foo')
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        with override_flag('foo', testing=False):
+            request = req()
+            request.COOKIES['dwft_foo'] = 'True'
+            assert not waffle.flag_is_active(request, 'foo')
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        assert waffle.get_waffle_flag_model().objects.get(name='foo').testing
+
+    def test_flag_existed_and_was_inactive_for_testing(self):
+        waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, testing=False)
+
+        with override_flag('foo', testing=True):
+            request = req()
+            request.COOKIES['dwft_foo'] = 'True'
+            assert waffle.flag_is_active(request, 'foo')
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        with override_flag('foo', testing=False):
+            request = req()
+            request.COOKIES['dwft_foo'] = 'True'
+            assert not waffle.flag_is_active(request, 'foo')
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        assert not waffle.get_waffle_flag_model().objects.get(name='foo').testing
+
+    def test_flag_existed_and_was_active_for_superusers(self):
+        waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, superusers=True)
+
+        with override_flag('foo', superusers=True):
+            with provide_user(username='foo', is_superuser=True) as user:
+                request = req()
+                request.user = user
+                assert waffle.flag_is_active(request, 'foo')
+            with provide_user(username='foo', is_superuser=False) as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+
+        with override_flag('foo', superusers=False):
+            with provide_user(username='foo', is_superuser=True) as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+            with provide_user(username='foo', is_superuser=False) as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+
+        assert waffle.get_waffle_flag_model().objects.get(name='foo').superusers
+
+    def test_flag_existed_and_was_inactive_for_superusers(self):
+        waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, superusers=False)
+
+        with override_flag('foo', superusers=True):
+            with provide_user(username='foo', is_superuser=True) as user:
+                request = req()
+                request.user = user
+                assert waffle.flag_is_active(request, 'foo')
+            with provide_user(username='foo', is_superuser=False) as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+
+        with override_flag('foo', superusers=False):
+            with provide_user(username='foo', is_superuser=True) as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+            with provide_user(username='foo', is_superuser=False) as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+
+        assert not waffle.get_waffle_flag_model().objects.get(name='foo').superusers
+
+    def test_flag_existed_and_was_active_for_staff(self):
+        waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, staff=True)
+
+        with override_flag('foo', staff=True):
+            with provide_user(username='foo', is_staff=True) as user:
+                request = req()
+                request.user = user
+                assert waffle.flag_is_active(request, 'foo')
+            with provide_user(username='foo', is_staff=False) as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+
+        with override_flag('foo', staff=False):
+            with provide_user(username='foo', is_staff=True) as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+            with provide_user(username='foo', is_staff=False) as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+
+        assert waffle.get_waffle_flag_model().objects.get(name='foo').staff
+
+    def test_flag_existed_and_was_inactive_for_staff(self):
+        waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, staff=False)
+
+        with override_flag('foo', staff=True):
+            with provide_user(username='foo', is_staff=True) as user:
+                request = req()
+                request.user = user
+                assert waffle.flag_is_active(request, 'foo')
+            with provide_user(username='foo', is_staff=False) as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+
+        with override_flag('foo', staff=False):
+            with provide_user(username='foo', is_staff=True) as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+            with provide_user(username='foo', is_staff=False) as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+
+        assert not waffle.get_waffle_flag_model().objects.get(name='foo').staff
+
+    def test_flag_existed_and_was_active_for_authenticated(self):
+        waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, authenticated=True)
+
+        with override_flag('foo', authenticated=True):
+            with provide_user(username='foo') as user:
+                request = req()
+                request.user = user
+                assert waffle.flag_is_active(request, 'foo')
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        with override_flag('foo', authenticated=False):
+            with provide_user(username='foo') as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        assert waffle.get_waffle_flag_model().objects.get(name='foo').authenticated
+
+    def test_flag_existed_and_was_inactive_for_authenticated(self):
+        waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, authenticated=False)
+
+        with override_flag('foo', authenticated=True):
+            with provide_user(username='foo') as user:
+                request = req()
+                request.user = user
+                assert waffle.flag_is_active(request, 'foo')
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        with override_flag('foo', authenticated=False):
+            with provide_user(username='foo') as user:
+                request = req()
+                request.user = user
+                assert not waffle.flag_is_active(request, 'foo')
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        assert not waffle.get_waffle_flag_model().objects.get(name='foo').authenticated
+
+    def test_flag_existed_and_was_active_for_languages(self):
+        waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, languages="en,es")
+
+        with override_flag('foo', languages="en,es"):
+            request = req()
+            request.LANGUAGE_CODE = "en"
+            assert waffle.flag_is_active(request, 'foo')
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        with override_flag('foo', languages=""):
+            request = req()
+            request.LANGUAGE_CODE = "en"
+            assert not waffle.flag_is_active(request, 'foo')
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        assert waffle.get_waffle_flag_model().objects.get(name='foo').languages == "en,es"
+
+    def test_flag_existed_and_was_inactive_for_languages(self):
+        waffle.get_waffle_flag_model().objects.create(name='foo', everyone=None, languages="")
+
+        with override_flag('foo', languages="en,es"):
+            request = req()
+            request.LANGUAGE_CODE = "en"
+            assert waffle.flag_is_active(request, 'foo')
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        with override_flag('foo', languages=""):
+            request = req()
+            request.LANGUAGE_CODE = "en"
+            assert not waffle.flag_is_active(request, 'foo')
+            assert not waffle.flag_is_active(req(), 'foo')
+
+        assert waffle.get_waffle_flag_model().objects.get(name='foo').languages == ""
+
 
 class OverrideFlagsTestCase(OverrideFlagTestsMixin, TestCase):
     """
diff --git a/waffle/testutils.py b/waffle/testutils.py
index c54b81cc..210871db 100644
--- a/waffle/testutils.py
+++ b/waffle/testutils.py
@@ -1,3 +1,4 @@
+import sys
 from typing import Generic, Optional, TypeVar, Union
 
 from django.test.utils import TestContextDecorator
@@ -10,11 +11,27 @@
 from waffle.models import Switch, Sample
 
 
+if sys.version_info >= (3, 8):
+    from typing import TypedDict
+else:
+    TypedDict = dict
+
+
 __all__ = ['override_flag', 'override_sample', 'override_switch']
 
 _T = TypeVar("_T")
 
 
+class _FlagValues(TypedDict):
+    active: Optional[bool]
+    percent: Optional[float]
+    testing: Optional[bool]
+    superusers: Optional[bool]
+    staff: Optional[bool]
+    authenticated: Optional[bool]
+    languages: Optional[str]
+
+
 class _overrider(TestContextDecorator, Generic[_T]):
     def __init__(self, name: str, active: _T):
         super().__init__()
@@ -44,6 +61,69 @@ def disable(self) -> None:
             self.update(self.old_value)
 
 
+class _flag_overrider(TestContextDecorator):
+    def __init__(
+        self,
+        name: str,
+        active: Optional[bool] = None,
+        percent: Optional[float] = None,
+        testing: Optional[bool] = None,
+        superusers: Optional[bool] = None,
+        staff: Optional[bool] = None,
+        authenticated: Optional[bool] = None,
+        languages: Optional[str] = None,
+    ):
+        super().__init__()
+        self.name = name
+        self.active = active
+        self.percent = percent
+        self.testing = testing
+        self.superusers = superusers
+        self.staff = staff
+        self.authenticated = authenticated
+        self.languages = languages
+
+    def get(self) -> None:
+        self.obj, self.created = self.cls.objects.get_or_create(name=self.name)
+
+    def update(
+        self,
+        active: Optional[bool] = None,
+        percent: Optional[float] = None,
+        testing: Optional[bool] = None,
+        superusers: Optional[bool] = None,
+        staff: Optional[bool] = None,
+        authenticated: Optional[bool] = None,
+        languages: Optional[str] = None,
+    ) -> None:
+        raise NotImplementedError
+
+    def get_values(self) -> _FlagValues:
+        raise NotImplementedError
+
+    def enable(self) -> None:
+        self.get()
+        self.old_values = self.get_values()
+        current_values = _FlagValues(
+            active=self.active,
+            percent=self.percent,
+            testing=self.testing,
+            superusers=self.superusers,
+            staff=self.staff,
+            authenticated=self.authenticated,
+            languages=self.languages,
+        )
+        if self.old_values != current_values:
+            self.update(**current_values)
+
+    def disable(self) -> None:
+        if self.created:
+            self.obj.delete()
+            self.obj.flush()
+        else:
+            self.update(**self.old_values)
+
+
 class override_switch(_overrider[bool]):
     """
     override_switch is a contextmanager for easier testing of switches.
@@ -54,7 +134,7 @@ class override_switch(_overrider[bool]):
         with override_switch('happy_mode', active=True):
             ...
 
-    If `Switch` already existed, it's value would be changed inside the context
+    If `Switch` already existed, its value would be changed inside the context
     block, then restored to the original value. If `Switch` did not exist
     before entering the context, it is created, then removed at the end of the
     block.
@@ -78,17 +158,71 @@ def get_value(self) -> bool:
         return self.obj.active
 
 
-class override_flag(_overrider[Optional[bool]]):
+class override_flag(_flag_overrider):
+    """
+    override_flag is a contextmanager for easier testing of flags.
+
+    It accepts two parameters, name of the switch and it's state. Example
+    usage::
+
+        with override_flag('happy_mode', active=True):
+            ...
+
+        with override_flag('happy_mode', staff=True):
+            ...
+
+    If `Flag` already existed, its values would be changed inside the context
+    block, then restored to its original values. If `Flag` did not exist
+    before entering the context, it is created, then removed at the end of the
+    block.
+
+    It can also act as a decorator::
+
+        @override_flag('happy_mode', active=True)
+        def test_happy_mode_enabled():
+            ...
+
+    """
     cls = get_waffle_flag_model()
 
-    def update(self, active: Optional[bool]) -> None:
+    def update(
+        self,
+        active: Optional[bool] = None,
+        percent: Optional[float] = None,
+        testing: Optional[bool] = None,
+        superusers: Optional[bool] = None,
+        staff: Optional[bool] = None,
+        authenticated: Optional[bool] = None,
+        languages: Optional[str] = None,
+    ) -> None:
         obj = self.cls.objects.get(pk=self.obj.pk)
         obj.everyone = active
+        obj.percent = percent
+
+        if testing is not None:
+            obj.testing = testing
+        if superusers is not None:
+            obj.superusers = superusers
+        if staff is not None:
+            obj.staff = staff
+        if authenticated is not None:
+            obj.authenticated = authenticated
+        if languages is not None:
+            obj.languages = languages
+
         obj.save()
         obj.flush()
 
-    def get_value(self) -> Optional[bool]:
-        return self.obj.everyone
+    def get_values(self) -> _FlagValues:
+        return {
+            "active": self.obj.everyone,
+            "percent": self.obj.percent,
+            "testing": self.obj.testing,
+            "superusers": self.obj.superusers,
+            "staff": self.obj.staff,
+            "authenticated": self.obj.authenticated,
+            "languages": self.obj.languages,
+        }
 
 
 class override_sample(_overrider[Union[bool, float]]):