Skip to content

Commit

Permalink
Relax restriction on classical inputs for modular addition to fix nig…
Browse files Browse the repository at this point in the history
…htly CI (#1537)

Fix Nightly CI
  • Loading branch information
NoureldinYosri authored Feb 5, 2025
1 parent 676b02a commit 65b35a0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
22 changes: 13 additions & 9 deletions qualtran/bloqs/mod_arithmetic/mod_addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,15 @@ def signature(self) -> 'Signature':
def on_classical_vals(
self, x: 'ClassicalValT', y: 'ClassicalValT'
) -> Dict[str, 'ClassicalValT']:
if not (0 <= x < self.mod):
# The construction still works when at most one of inputs equals `mod`.
special_case = (x == self.mod) ^ (y == self.mod)
if not (0 <= x < self.mod or special_case):
raise ValueError(
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
f'{x=} is outside the valid interval for modular addition [0, {self.mod}]'
)
if not (0 <= y < self.mod):
if not (0 <= y < self.mod or special_case):
raise ValueError(
f'{y=} is outside the valid interval for modular addition [0, {self.mod})'
f'{y=} is outside the valid interval for modular addition [0, {self.mod}]'
)

y = (x + y) % self.mod
Expand Down Expand Up @@ -320,7 +322,7 @@ def on_classical_vals(

if not (0 <= x < self.mod):
raise ValueError(
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
f'{x=} is outside the valid interval for modular addition [0, {self.mod}]'
)

x = (x + self.k) % self.mod
Expand Down Expand Up @@ -508,13 +510,15 @@ def on_classical_vals(
if ctrl != self.cv:
return {'ctrl': ctrl, 'x': x, 'y': y}

if not (0 <= x < self.mod):
# The construction still works when at most one of inputs equals `mod`.
special_case = (x == self.mod) ^ (y == self.mod)
if not (0 <= x < self.mod or special_case):
raise ValueError(
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
f'{x=} is outside the valid interval for modular addition [0, {self.mod}]'
)
if not (0 <= y < self.mod):
if not (0 <= y < self.mod or special_case):
raise ValueError(
f'{y=} is outside the valid interval for modular addition [0, {self.mod})'
f'{y=} is outside the valid interval for modular addition [0, {self.mod}]'
)

y = (x + y) % self.mod
Expand Down
5 changes: 2 additions & 3 deletions qualtran/bloqs/mod_arithmetic/mod_addition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,8 @@ def test_classical_action_mod_add(prime, bitsize):
def test_classical_action_cmodadd(control, prime, dtype, bitsize):
b = CModAdd(dtype(bitsize), mod=prime, cv=control)
cb = b.decompose_bloq()
valid_range = range(prime)
for c in range(2):
for x, y in itertools.product(valid_range, repeat=2):
for x, y in itertools.product(range(prime + 1), range(prime)):
assert b.call_classically(ctrl=c, x=x, y=y) == cb.call_classically(ctrl=c, x=x, y=y)


Expand Down Expand Up @@ -207,7 +206,7 @@ def test_cmod_add_complexity_vs_ref():
@pytest.mark.parametrize(['prime', 'bitsize'], [(p, bitsize) for p in [5, 7] for bitsize in (5, 6)])
def test_mod_add_classical_action(bitsize, prime):
b = ModAdd(bitsize, prime)
assert_consistent_classical_action(b, x=range(prime), y=range(prime))
assert_consistent_classical_action(b, x=range(prime + 1), y=range(prime))


def test_cmodadd_tensor():
Expand Down

0 comments on commit 65b35a0

Please sign in to comment.