-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix bugs in ECAdd bloq #1489
Open
fpapa250
wants to merge
13
commits into
quantumlib:main
Choose a base branch
from
fpapa250:ecc-fix-add
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Fix bugs in ECAdd bloq #1489
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
af46d88
fix som ebugs
fpapa250 1755ef4
remove one dirty statement
fpapa250 b97ac02
Merge branch 'main' into ecc-fix-add
fpapa250 1ddc3ee
Correct the rest of the ECAdd bugs from the paper
fpapa250 9a83d8e
Fix symbolic cvs for step 5
fpapa250 c5668d5
Merge branch 'main' into ecc-fix-add
fpapa250 2f52006
Merge branch 'main' into ecc-fix-add
fpapa250 6da2b0e
Fix final test with Equals.controlled working
fpapa250 64e3989
Update qualtran/bloqs/factoring/ecc/ec_add.py
fpapa250 15fa7c5
formatting
fpapa250 3694769
Merge branch 'main' into ecc-fix-add
fpapa250 500f8b2
Merge branch 'main' into ecc-fix-add
fpapa250 af3563c
Solve nit
fpapa250 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,6 +23,7 @@ | |
bloq_example, | ||
BloqBuilder, | ||
BloqDocSpec, | ||
CtrlSpec, | ||
DecomposeTypeError, | ||
QBit, | ||
QMontgomeryUInt, | ||
|
@@ -33,7 +34,7 @@ | |
SoquetT, | ||
) | ||
from qualtran.bloqs.arithmetic.comparison import Equals | ||
from qualtran.bloqs.basic_gates import CNOT, IntState, Toffoli, ZeroState | ||
from qualtran.bloqs.basic_gates import CNOT, IntState, Toffoli, XGate, ZeroState | ||
from qualtran.bloqs.bookkeeping import Free | ||
from qualtran.bloqs.mcmt import MultiAnd, MultiControlX, MultiTargetCNOT | ||
from qualtran.bloqs.mod_arithmetic import ( | ||
|
@@ -253,10 +254,6 @@ def on_classical_vals( | |
lam = QMontgomeryUInt(self.n, self.mod).montgomery_product( | ||
int(y), QMontgomeryUInt(self.n, self.mod).montgomery_inverse(int(x)) | ||
) | ||
# TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit | ||
# which flips f1 when lam and lam_r are equal. | ||
if lam == lam_r: | ||
f1 = (f1 + 1) % 2 | ||
else: | ||
lam = 0 | ||
return {'f1': f1, 'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam': lam, 'lam_r': lam_r} | ||
|
@@ -296,6 +293,12 @@ def build_composite_bloq( | |
y=y, | ||
) | ||
|
||
# Allocate an ancilla qubit that acts as a flag for the rare condition that the | ||
# pre-computed lambda_r is equal to the calculated lambda. This ancilla is used to properly | ||
# clear the f1 qubit when lambda is set to lambda_r. | ||
ancilla = bb.allocate() | ||
z4, lam_r, ancilla = bb.add(Equals(QMontgomeryUInt(self.n)), x=z4, y=lam_r, target=ancilla) | ||
|
||
# If ctrl = 1 and x != a: lam = (y - b) / (x - a) % p. | ||
z4_split = bb.split(z4) | ||
lam_split = bb.split(lam) | ||
|
@@ -323,7 +326,18 @@ def build_composite_bloq( | |
lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n)) | ||
|
||
# If lam = lam_r: return f1 = 0. (If not we will flip f1 to 0 at the end iff x_r = y_r = 0). | ||
lam, lam_r, f1 = bb.add(Equals(QMontgomeryUInt(self.n)), x=lam, y=lam_r, target=f1) | ||
# Only flip when lam is set to lam_r. | ||
ancilla, lam, lam_r, f1 = bb.add( | ||
Equals(QMontgomeryUInt(self.n)).controlled(ctrl_spec=CtrlSpec(cvs=0)), | ||
ctrl=ancilla, | ||
x=lam, | ||
y=lam_r, | ||
target=f1, | ||
) | ||
|
||
# Clear the ancilla bit and free it. | ||
z4, lam_r, ancilla = bb.add(Equals(QMontgomeryUInt(self.n)), x=z4, y=lam_r, target=ancilla) | ||
bb.free(ancilla) | ||
|
||
# Uncompute the modular multiplication then the modular inversion. | ||
x, y = bb.add( | ||
|
@@ -343,7 +357,8 @@ def build_composite_bloq( | |
|
||
def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: | ||
return { | ||
Equals(QMontgomeryUInt(self.n)): 1, | ||
Equals(QMontgomeryUInt(self.n)): 2, | ||
Equals(QMontgomeryUInt(self.n)).controlled(ctrl_spec=CtrlSpec(cvs=0)): 1, | ||
ModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, | ||
CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, | ||
KaliskiModInverse(bitsize=self.n, mod=self.mod): 1, | ||
|
@@ -652,6 +667,7 @@ class _ECAddStepFive(Bloq): | |
will contain the x component of the resultant curve point. | ||
y: The y component of the second input elliptic curve point of bitsize `n` in montgomery form, which | ||
will contain the y component of the resultant curve point. | ||
lam_r: The precomputed lambda slope used in the addition operation if (a, b) = (x, y) in montgomery form. | ||
lam: The lambda slope used in the addition operation. | ||
|
||
References: | ||
|
@@ -672,6 +688,7 @@ def signature(self) -> 'Signature': | |
Register('b', QMontgomeryUInt(self.n)), | ||
Register('x', QMontgomeryUInt(self.n)), | ||
Register('y', QMontgomeryUInt(self.n)), | ||
Register('lam_r', QMontgomeryUInt(self.n)), | ||
Register('lam', QMontgomeryUInt(self.n), side=Side.LEFT), | ||
] | ||
) | ||
|
@@ -683,14 +700,15 @@ def on_classical_vals( | |
b: 'ClassicalValT', | ||
x: 'ClassicalValT', | ||
y: 'ClassicalValT', | ||
lam_r: 'ClassicalValT', | ||
lam: 'ClassicalValT', | ||
) -> Dict[str, 'ClassicalValT']: | ||
if ctrl == 1: | ||
x = (a - x) % self.mod | ||
y = (y - b) % self.mod | ||
else: | ||
x = (x + a) % self.mod | ||
return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y} | ||
return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam_r': lam_r} | ||
|
||
def build_composite_bloq( | ||
self, | ||
|
@@ -700,6 +718,7 @@ def build_composite_bloq( | |
b: Soquet, | ||
x: Soquet, | ||
y: Soquet, | ||
lam_r: Soquet, | ||
lam: Soquet, | ||
) -> Dict[str, 'SoquetT']: | ||
if is_symbolic(self.n): | ||
|
@@ -729,9 +748,31 @@ def build_composite_bloq( | |
z4_split[i] = ctrls[1] | ||
z4 = bb.join(z4_split, dtype=QMontgomeryUInt(self.n)) | ||
lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n)) | ||
# TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit where lambda | ||
# is not set to 0 before being freed. | ||
bb.add(Free(QMontgomeryUInt(self.n), dirty=True), reg=lam) | ||
|
||
# If the denominator of lambda is 0, lam = lam_r so we clear lam with lam_r. | ||
ancilla = bb.allocate() | ||
x_split = bb.split(x) | ||
x_split, ancilla = bb.add( | ||
MultiControlX(cvs=[0] * int(self.n)), controls=x_split, target=ancilla | ||
) | ||
lam_r_split = bb.split(lam_r) | ||
lam_split = bb.split(lam) | ||
for i in range(int(self.n)): | ||
ctrls = [ctrl, ancilla, lam_r_split[i]] | ||
ctrls, lam_split[i] = bb.add( | ||
MultiControlX(cvs=[1, 1, 1]), controls=ctrls, target=lam_split[i] | ||
) | ||
ctrl = ctrls[0] | ||
ancilla = ctrls[1] | ||
lam_r_split[i] = ctrls[2] | ||
lam_r = bb.join(lam_r_split, dtype=QMontgomeryUInt(self.n)) | ||
lam = bb.join(lam_split, dtype=QMontgomeryUInt(self.n)) | ||
x_split, ancilla = bb.add( | ||
MultiControlX(cvs=[0] * int(self.n)), controls=x_split, target=ancilla | ||
) | ||
x = bb.join(x_split, dtype=QMontgomeryUInt(self.n)) | ||
bb.free(ancilla) | ||
bb.add(Free(QMontgomeryUInt(self.n)), reg=lam) | ||
|
||
# Uncompute multiplication and inverse. | ||
x, y = bb.add( | ||
|
@@ -756,9 +797,14 @@ def build_composite_bloq( | |
ctrl, b, y = bb.add(CModSub(QMontgomeryUInt(self.n), mod=self.mod), ctrl=ctrl, x=b, y=y) | ||
|
||
# Return the output registers. | ||
return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y} | ||
return {'ctrl': ctrl, 'a': a, 'b': b, 'x': x, 'y': y, 'lam_r': lam_r} | ||
|
||
def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: | ||
cvs: Union[list[int], HasLength] | ||
if isinstance(self.n, int): | ||
cvs = [0] * self.n | ||
else: | ||
cvs = HasLength(self.n) | ||
return { | ||
CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, | ||
KaliskiModInverse(bitsize=self.n, mod=self.mod): 1, | ||
|
@@ -771,6 +817,8 @@ def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: | |
KaliskiModInverse(bitsize=self.n, mod=self.mod).adjoint(): 1, | ||
ModAdd(self.n, mod=self.mod): 1, | ||
MultiControlX(cvs=[1, 1]): self.n, | ||
MultiControlX(cvs=cvs): 2, | ||
MultiControlX(cvs=[1, 1, 1]): self.n, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can be replaced by the above |
||
CModNeg(QMontgomeryUInt(self.n), mod=self.mod): 1, | ||
} | ||
|
||
|
@@ -863,6 +911,11 @@ def build_composite_bloq( | |
f3 = f_ctrls[1] | ||
f4 = f_ctrls[2] | ||
|
||
# Unset f2 if ((a, b) = (0, 0) AND y = 0) OR ((x, y) = (0, 0) AND b = 0). | ||
mcx = XGate().controlled(CtrlSpec(qdtypes=QMontgomeryUInt(self.n), cvs=[0, 0, 0])) | ||
[a, b, y], f2 = bb.add(mcx, ctrl=[a, b, y], q=f2) | ||
[x, y, b], f2 = bb.add(mcx, ctrl=[x, y, b], q=f2) | ||
|
||
# Set (x, y) to (a, b) if f4 is set. | ||
a_split = bb.split(a) | ||
x_split = bb.split(x) | ||
|
@@ -883,24 +936,6 @@ def build_composite_bloq( | |
b = bb.join(b_split, QMontgomeryUInt(self.n)) | ||
y = bb.join(y_split, QMontgomeryUInt(self.n)) | ||
|
||
# Unset f4 if (x, y) = (a, b). | ||
ab = bb.join(np.concatenate([bb.split(a), bb.split(b)]), dtype=QMontgomeryUInt(2 * self.n)) | ||
xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n)) | ||
ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4) | ||
ab_split = bb.split(ab) | ||
a = bb.join(ab_split[: int(self.n)], dtype=QMontgomeryUInt(self.n)) | ||
b = bb.join(ab_split[int(self.n) :], dtype=QMontgomeryUInt(self.n)) | ||
xy_split = bb.split(xy) | ||
x = bb.join(xy_split[: int(self.n)], dtype=QMontgomeryUInt(self.n)) | ||
y = bb.join(xy_split[int(self.n) :], dtype=QMontgomeryUInt(self.n)) | ||
|
||
# Unset f3 if (a, b) = (0, 0). | ||
ab_arr = np.concatenate([bb.split(a), bb.split(b)]) | ||
ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3) | ||
ab_arr = np.split(ab_arr, 2) | ||
a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n)) | ||
b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n)) | ||
|
||
# If f1 and f2 are set, subtract a from x and add b to y. | ||
ancilla = bb.add(ZeroState()) | ||
toff_ctrl = [f1, f2] | ||
|
@@ -923,6 +958,24 @@ def build_composite_bloq( | |
f2 = toff_ctrl[1] | ||
bb.add(Free(QBit()), reg=ancilla) | ||
|
||
# Unset f4 if (x, y) = (a, b). | ||
ab = bb.join(np.concatenate([bb.split(a), bb.split(b)]), dtype=QMontgomeryUInt(2 * self.n)) | ||
xy = bb.join(np.concatenate([bb.split(x), bb.split(y)]), dtype=QMontgomeryUInt(2 * self.n)) | ||
ab, xy, f4 = bb.add(Equals(QMontgomeryUInt(2 * self.n)), x=ab, y=xy, target=f4) | ||
ab_split = bb.split(ab) | ||
a = bb.join(ab_split[: int(self.n)], dtype=QMontgomeryUInt(self.n)) | ||
b = bb.join(ab_split[int(self.n) :], dtype=QMontgomeryUInt(self.n)) | ||
xy_split = bb.split(xy) | ||
x = bb.join(xy_split[: int(self.n)], dtype=QMontgomeryUInt(self.n)) | ||
y = bb.join(xy_split[int(self.n) :], dtype=QMontgomeryUInt(self.n)) | ||
|
||
# Unset f3 if (a, b) = (0, 0). | ||
ab_arr = np.concatenate([bb.split(a), bb.split(b)]) | ||
ab_arr, f3 = bb.add(MultiControlX(cvs=[0] * 2 * self.n), controls=ab_arr, target=f3) | ||
ab_arr = np.split(ab_arr, 2) | ||
a = bb.join(ab_arr[0], dtype=QMontgomeryUInt(self.n)) | ||
b = bb.join(ab_arr[1], dtype=QMontgomeryUInt(self.n)) | ||
|
||
# Unset f1 and f2 if (x, y) = (0, 0). | ||
xy_arr = np.concatenate([bb.split(x), bb.split(y)]) | ||
xy_arr, junk, out = bb.add(MultiAnd(cvs=[0] * 2 * self.n), ctrl=xy_arr) | ||
|
@@ -939,33 +992,35 @@ def build_composite_bloq( | |
y = bb.join(xy_arr[1], dtype=QMontgomeryUInt(self.n)) | ||
|
||
# Free all ancilla qubits in the zero state. | ||
# TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bugs in circuit where f1, | ||
# f2, and f4 are freed before being set to 0. | ||
bb.add(Free(QBit(), dirty=True), reg=f1) | ||
bb.add(Free(QBit(), dirty=True), reg=f2) | ||
bb.add(Free(QBit()), reg=f1) | ||
bb.add(Free(QBit()), reg=f2) | ||
bb.add(Free(QBit()), reg=f3) | ||
bb.add(Free(QBit(), dirty=True), reg=f4) | ||
bb.add(Free(QBit()), reg=f4) | ||
bb.add(Free(QBit()), reg=ctrl) | ||
|
||
# Return the output registers. | ||
return {'a': a, 'b': b, 'x': x, 'y': y} | ||
|
||
def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT: | ||
cvs: Union[list[int], HasLength] | ||
cvs2: Union[list[int], HasLength] | ||
cvs3: Union[list[int], HasLength] | ||
if isinstance(self.n, int): | ||
cvs = [0] * 2 * self.n | ||
cvs2 = [0] * 2 * self.n | ||
cvs3 = [0] * 3 * self.n | ||
else: | ||
cvs = HasLength(2 * self.n) | ||
cvs2 = HasLength(2 * self.n) | ||
cvs3 = HasLength(3 * self.n) | ||
return { | ||
MultiControlX(cvs=cvs): 1, | ||
MultiControlX(cvs=cvs2): 1, | ||
MultiControlX(cvs=cvs3): 2, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this the newly refactored |
||
MultiControlX(cvs=[0] * 3): 1, | ||
CModSub(QMontgomeryUInt(self.n), mod=self.mod): 1, | ||
CModAdd(QMontgomeryUInt(self.n), mod=self.mod): 1, | ||
Toffoli(): 2 * self.n + 4, | ||
Equals(QMontgomeryUInt(2 * self.n)): 1, | ||
MultiAnd(cvs=cvs): 1, | ||
MultiAnd(cvs=cvs2): 1, | ||
MultiTargetCNOT(2): 1, | ||
MultiAnd(cvs=cvs).adjoint(): 1, | ||
MultiAnd(cvs=cvs2).adjoint(): 1, | ||
} | ||
|
||
|
||
|
@@ -1044,13 +1099,14 @@ def build_composite_bloq( | |
x, y, lam = bb.add( | ||
_ECAddStepFour(n=self.n, mod=self.mod, window_size=self.window_size), x=x, y=y, lam=lam | ||
) | ||
ctrl, a, b, x, y = bb.add( | ||
ctrl, a, b, x, y, lam_r = bb.add( | ||
_ECAddStepFive(n=self.n, mod=self.mod, window_size=self.window_size), | ||
ctrl=ctrl, | ||
a=a, | ||
b=b, | ||
x=x, | ||
y=y, | ||
lam_r=lam_r, | ||
lam=lam, | ||
) | ||
a, b, x, y = bb.add( | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can also factor the
clear_lam
bloq into a class property, so it can also be used in the call graph.