From 65b35a0ec9d02b7d064a77f57f58f6de61d1ff52 Mon Sep 17 00:00:00 2001 From: Noureldin Date: Wed, 5 Feb 2025 12:44:56 -0800 Subject: [PATCH] Relax restriction on classical inputs for modular addition to fix nightly CI (#1537) Fix Nightly CI --- qualtran/bloqs/mod_arithmetic/mod_addition.py | 22 +++++++++++-------- .../bloqs/mod_arithmetic/mod_addition_test.py | 5 ++--- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/qualtran/bloqs/mod_arithmetic/mod_addition.py b/qualtran/bloqs/mod_arithmetic/mod_addition.py index db4fe7dba..b165a3946 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_addition.py +++ b/qualtran/bloqs/mod_arithmetic/mod_addition.py @@ -87,13 +87,15 @@ def signature(self) -> 'Signature': def on_classical_vals( self, x: 'ClassicalValT', y: 'ClassicalValT' ) -> Dict[str, 'ClassicalValT']: - if not (0 <= x < self.mod): + # The construction still works when at most one of inputs equals `mod`. + special_case = (x == self.mod) ^ (y == self.mod) + if not (0 <= x < self.mod or special_case): raise ValueError( - f'{x=} is outside the valid interval for modular addition [0, {self.mod})' + f'{x=} is outside the valid interval for modular addition [0, {self.mod}]' ) - if not (0 <= y < self.mod): + if not (0 <= y < self.mod or special_case): raise ValueError( - f'{y=} is outside the valid interval for modular addition [0, {self.mod})' + f'{y=} is outside the valid interval for modular addition [0, {self.mod}]' ) y = (x + y) % self.mod @@ -320,7 +322,7 @@ def on_classical_vals( if not (0 <= x < self.mod): raise ValueError( - f'{x=} is outside the valid interval for modular addition [0, {self.mod})' + f'{x=} is outside the valid interval for modular addition [0, {self.mod}]' ) x = (x + self.k) % self.mod @@ -508,13 +510,15 @@ def on_classical_vals( if ctrl != self.cv: return {'ctrl': ctrl, 'x': x, 'y': y} - if not (0 <= x < self.mod): + # The construction still works when at most one of inputs equals `mod`. + special_case = (x == self.mod) ^ (y == self.mod) + if not (0 <= x < self.mod or special_case): raise ValueError( - f'{x=} is outside the valid interval for modular addition [0, {self.mod})' + f'{x=} is outside the valid interval for modular addition [0, {self.mod}]' ) - if not (0 <= y < self.mod): + if not (0 <= y < self.mod or special_case): raise ValueError( - f'{y=} is outside the valid interval for modular addition [0, {self.mod})' + f'{y=} is outside the valid interval for modular addition [0, {self.mod}]' ) y = (x + y) % self.mod diff --git a/qualtran/bloqs/mod_arithmetic/mod_addition_test.py b/qualtran/bloqs/mod_arithmetic/mod_addition_test.py index bd5a11f4f..7786a0ce6 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_addition_test.py +++ b/qualtran/bloqs/mod_arithmetic/mod_addition_test.py @@ -131,9 +131,8 @@ def test_classical_action_mod_add(prime, bitsize): def test_classical_action_cmodadd(control, prime, dtype, bitsize): b = CModAdd(dtype(bitsize), mod=prime, cv=control) cb = b.decompose_bloq() - valid_range = range(prime) for c in range(2): - for x, y in itertools.product(valid_range, repeat=2): + for x, y in itertools.product(range(prime + 1), range(prime)): assert b.call_classically(ctrl=c, x=x, y=y) == cb.call_classically(ctrl=c, x=x, y=y) @@ -207,7 +206,7 @@ def test_cmod_add_complexity_vs_ref(): @pytest.mark.parametrize(['prime', 'bitsize'], [(p, bitsize) for p in [5, 7] for bitsize in (5, 6)]) def test_mod_add_classical_action(bitsize, prime): b = ModAdd(bitsize, prime) - assert_consistent_classical_action(b, x=range(prime), y=range(prime)) + assert_consistent_classical_action(b, x=range(prime + 1), y=range(prime)) def test_cmodadd_tensor():