Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cirq-core/cirq/protocols/approximate_equality_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ def _approx_eq_iterables(val: Iterable, other: Iterable, *, atol: float) -> bool
types.
"""

if isinstance(val, set):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(val, set):
if isinstance(val, (frozenset, set)):

And the same below.

val = sorted(val)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will throw TypeError for set sa0 = {"a", 0}, but such set should be approx_eq to itself - approx_eq(sa0, sa0) is True.

Please add a test for unsortable sets, ie, that asserts cirq.approx_eq({"a", 0}, {"a", 0}) is true as expected.

Suggested change
val = sorted(val)
try:
val = sorted(val)
except TypeError:
return NotImplemented

and the same for other below.

You will also need to change the caller to correctly handle a NotImplemented value here -

# If the values are iterable, try comparing recursively on items.
if isinstance(val, Iterable) and isinstance(other, Iterable):
return _approx_eq_iterables(val, other, atol=atol)

if isinstance(other, set):
other = sorted(other)

iter1 = iter(val)
iter2 = iter(other)
done = object()
Expand Down
47 changes: 27 additions & 20 deletions cirq-core/cirq/protocols/approximate_equality_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def test_approx_eq_primitives() -> None:
assert not cirq.approx_eq(0.0, 1e-10, atol=1e-11)
assert cirq.approx_eq(complex(1, 1), complex(1.1, 1.2), atol=0.3)
assert not cirq.approx_eq(complex(1, 1), complex(1.1, 1.2), atol=0.1)
assert cirq.approx_eq('ab', 'ab', atol=1e-3)
assert not cirq.approx_eq('ab', 'ac', atol=1e-3)
assert not cirq.approx_eq('1', '2', atol=999)
assert not cirq.approx_eq('test', 1, atol=1e-3)
assert not cirq.approx_eq('1', 1, atol=1e-3)
assert cirq.approx_eq("ab", "ab", atol=1e-3)
assert not cirq.approx_eq("ab", "ac", atol=1e-3)
assert not cirq.approx_eq("1", "2", atol=999)
assert not cirq.approx_eq("test", 1, atol=1e-3)
assert not cirq.approx_eq("1", 1, atol=1e-3)
Comment on lines +36 to +40
Copy link
Collaborator

@pavoljuhas pavoljuhas Dec 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please revert all formatting changes that convert single to double quotes. They make it harder to review the code and also make git-blame less useful.

We may decide to standardize double quotes later, but that would need its own PR (which can be added to .git-blame-ignore-revs).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure thing, I accidentally called ruff on it. I'll revert them in my commit.



def test_approx_eq_mixed_primitives() -> None:
Expand All @@ -62,15 +62,15 @@ def test_numpy_dtype_compatibility() -> None:

f_a, f_b, f_c = 0, 1e-8, 1
f_types = [np.float16, np.float32, np.float64]
if hasattr(np, 'float128'):
if hasattr(np, "float128"):
f_types.append(np.float128)
for f_type in f_types:
assert cirq.approx_eq(f_type(f_a), f_type(f_b), atol=1e-8)
assert not cirq.approx_eq(f_type(f_a), f_type(f_c), atol=1e-8)

c_a, c_b, c_c = 0, 1e-8j, 1j
c_types = [np.complex64, np.complex128]
if hasattr(np, 'complex256'):
if hasattr(np, "complex256"):
c_types.append(np.complex256)
for c_type in c_types:
assert cirq.approx_eq(c_type(c_a), c_type(c_b), atol=1e-8)
Expand All @@ -83,33 +83,33 @@ def test_fractions_compatibility() -> None:


def test_decimal_compatibility() -> None:
assert cirq.approx_eq(Decimal('0'), Decimal('0.0000000001'), atol=1e-9)
assert not cirq.approx_eq(Decimal('0'), Decimal('0.00000001'), atol=1e-9)
assert not cirq.approx_eq(Decimal('NaN'), Decimal('-Infinity'), atol=1e-9)
assert cirq.approx_eq(Decimal("0"), Decimal("0.0000000001"), atol=1e-9)
assert not cirq.approx_eq(Decimal("0"), Decimal("0.00000001"), atol=1e-9)
assert not cirq.approx_eq(Decimal("NaN"), Decimal("-Infinity"), atol=1e-9)


def test_approx_eq_mixed_types() -> None:
assert cirq.approx_eq(np.float32(1), 1.0 + 1e-10, atol=1e-9)
assert cirq.approx_eq(np.float64(1), np.complex64(1 + 1e-8j), atol=1e-4)
assert cirq.approx_eq(np.uint8(1), np.complex64(1 + 1e-8j), atol=1e-4)
if hasattr(np, 'complex256'):
if hasattr(np, "complex256"):
assert cirq.approx_eq(np.complex256(1), complex(1, 1e-8), atol=1e-4)
assert cirq.approx_eq(np.int32(1), 1, atol=1e-9)
assert cirq.approx_eq(complex(0.5, 0), Fraction(1, 2), atol=0.0)
assert cirq.approx_eq(0.5 + 1e-4j, Fraction(1, 2), atol=1e-4)
assert cirq.approx_eq(0, Fraction(1, 100000000), atol=1e-8)
assert cirq.approx_eq(np.uint16(1), Decimal('1'), atol=0.0)
assert cirq.approx_eq(np.float64(1.0), Decimal('1.00000001'), atol=1e-8)
assert not cirq.approx_eq(np.complex64(1e-5j), Decimal('0.001'), atol=1e-4)
assert cirq.approx_eq(np.uint16(1), Decimal("1"), atol=0.0)
assert cirq.approx_eq(np.float64(1.0), Decimal("1.00000001"), atol=1e-8)
assert not cirq.approx_eq(np.complex64(1e-5j), Decimal("0.001"), atol=1e-4)


def test_approx_eq_special_numerics() -> None:
assert not cirq.approx_eq(float('nan'), 0, atol=0.0)
assert not cirq.approx_eq(float('nan'), float('nan'), atol=0.0)
assert not cirq.approx_eq(float('inf'), float('-inf'), atol=0.0)
assert not cirq.approx_eq(float('inf'), 5, atol=0.0)
assert not cirq.approx_eq(float('inf'), 0, atol=0.0)
assert cirq.approx_eq(float('inf'), float('inf'), atol=0.0)
assert not cirq.approx_eq(float("nan"), 0, atol=0.0)
assert not cirq.approx_eq(float("nan"), float("nan"), atol=0.0)
assert not cirq.approx_eq(float("inf"), float("-inf"), atol=0.0)
assert not cirq.approx_eq(float("inf"), 5, atol=0.0)
assert not cirq.approx_eq(float("inf"), 0, atol=0.0)
assert cirq.approx_eq(float("inf"), float("inf"), atol=0.0)


class X(Number):
Expand Down Expand Up @@ -152,6 +152,13 @@ def test_approx_eq_tuple() -> None:
assert not cirq.approx_eq((1.1, 1.2, 1.3), (1, 1, 1), atol=0.2)


def test_approx_eq_frozenset() -> None:
for n in range(10, 20):
assert cirq.approx_eq(
frozenset(cirq.LineQubit.range(n)), frozenset({*cirq.LineQubit.range(n)})
)
Comment on lines +156 to +159
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not test the new code as shown by coverage error. set and frozenset use the same hash and are constructed from the same sequence so this test works without sorting too:

for n in range(10, 20):
    assert list(frozenset(cirq.LineQubit.range(n))) == list(frozenset({*cirq.LineQubit.range(n)}))

To test the new code, you need to find sets for which the iteration sequence changes with insertion order:

found_differently_ordered_sets = False
for i in range(20):
    for j in range(i + 1, 20):
        sij = {cirq.q(i), cirq.q(j)}
        sji = {cirq.q(j), cirq.q(i)}
        if list(sij) != list(sji):
            found_differently_ordered_sets = True
            assert cirq.approx_eq(sij, sji)
            assert cirq.approx_eq(frozenset(sij), sji)

assert found_differently_ordered_sets



def test_approx_eq_list() -> None:
assert cirq.approx_eq([], [], atol=0.0)
assert not cirq.approx_eq([], [[]], atol=0.0)
Expand Down
Loading