Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2160,6 +2160,81 @@ def _can_add_op_at(self, moment_index: int, operation: cirq.Operation) -> bool:

return not self._moments[moment_index].operates_on(operation.qubits)

def _latest_available_moment(self, op: cirq.Operation, *, start_moment_index: int = 0) -> int:
"""Finds the index of the latest (i.e. right most) moment which can accommodate `op`.

Assumes that `start_moment_index` is between 0 and `len(self._moments)`.

Args:
op: Operation for which the latest moment that can accommodate it needs to be found.
start_moment_index: The starting point of the reverse search. Defaults to 0.

Returns:
Index of the latest matching moment. Returns `start_moment_index - 1` if no moment on
the right is available, or `len(self._moments)` if `start_moment_index` equals it.
"""
if start_moment_index == len(self._moments):
return start_moment_index
op_control_keys = protocols.control_keys(op)
op_measurement_keys = protocols.measurement_key_objs(op)
op_qubits = op.qubits
k = start_moment_index
while k < len(self._moments):
moment = self._moments[k]
if moment.operates_on(op_qubits):
return k - 1
moment_measurement_keys = moment._measurement_key_objs_()
if not (
op_measurement_keys.isdisjoint(moment_measurement_keys)
and op_control_keys.isdisjoint(moment_measurement_keys)
and moment._control_keys_().isdisjoint(op_measurement_keys)
):
return k - 1
k += 1
return k - 1

def _insert_latest(self, k: int, batches: list[list[_MOMENT_OR_OP]]) -> int:
"""Inserts batches of moments or operations using LATEST strategy.

Batches are processed in reverse order.
Operations are inserted into the latest possible moment from the starting
index k. Moments are inserted intact at index k.

Args:
k: The index to insert the batches at.
batches: Moments or operations to insert.

Returns:
The insertion index that will place operations just after the
operations that were inserted by this method. Returns k if batches
is empty.
"""
max_latest_index = -1 # Maximum index of a changed moment
for batch in reversed(batches):
for moment_or_op in batch:
if isinstance(moment_or_op, Moment):
self._moments.insert(k, moment_or_op)
# All later moments shift by 1 when the new moment is inserted
max_latest_index = max(k, max_latest_index + 1)
else:
end_moment_index = len(self.moments)
p = self._latest_available_moment(moment_or_op, start_moment_index=k)
if p < k:
self._moments.insert(k, Moment.from_ops(moment_or_op))
max_latest_index = max(k, max_latest_index + 1)
elif p < end_moment_index:
self._moments[p] = self._moments[p].with_operation(moment_or_op)
max_latest_index = max(p, max_latest_index)
else:
assert p == end_moment_index
self._moments.append(Moment.from_ops(moment_or_op))
max_latest_index = end_moment_index
# handle returned position index for empty batch
pos = k if max_latest_index == -1 else max_latest_index + 1
if max_latest_index != -1:
self._mutated(preserve_placement_cache=False)
return pos

def insert(
self,
index: int,
Expand Down Expand Up @@ -2195,6 +2270,10 @@ def insert(
batches = [[mop] for mop in mops] # Each op goes into its own moment.
else:
batches = list(_group_into_moment_compatible(mops))

if strategy is InsertStrategy.LATEST:
return self._insert_latest(k, batches)

for batch in batches:
# Insert a moment if inline/earliest and _any_ op in the batch requires it.
if (
Expand Down Expand Up @@ -2222,8 +2301,10 @@ def insert(
p = k
elif strategy is InsertStrategy.INLINE:
p = k - 1
else: # InsertStrategy.EARLIEST:
elif strategy is InsertStrategy.EARLIEST:
p = self.earliest_available_moment(moment_or_op, end_moment_index=k)
else:
raise ValueError('Unknown insertion strategy')
# Place
if isinstance(moment_or_op, Moment):
self._moments.insert(p, moment_or_op)
Expand Down
76 changes: 76 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4987,3 +4987,79 @@ def test_tagged_circuits() -> None:
assert (
circuit2.concat_ragged(tagged_circuit).tags == ()
) # We only preserve the tags for the first one


def test_latest_available_moment() -> None:
q = cirq.LineQubit.range(3)
c = cirq.Circuit(
cirq.Moment(cirq.measure(q[0], key="m")),
cirq.Moment(cirq.X(q[1]).with_classical_controls("m")),
)
assert c._latest_available_moment(cirq.Y(q[0])) == -1
assert c._latest_available_moment(cirq.Y(q[1])) == 0
assert c._latest_available_moment(cirq.Y(q[2])) == 1
assert c._latest_available_moment(cirq.Y(q[2]).with_classical_controls("m")) == -1
assert (
c._latest_available_moment(cirq.Y(q[2]).with_classical_controls("m"), start_moment_index=1)
== 1
)
# Defaults to len(moments) if start_moment_index == len(moments)
assert c._latest_available_moment(cirq.Y(q[1]), start_moment_index=2) == 2
# Y(q[1]) can be in the same moment as X(q[1])
assert c._latest_available_moment(cirq.Y(q[1]), start_moment_index=1) == 0
# A measurement on q[0] with different key can be in moment 1
assert c._latest_available_moment(cirq.measure(q[0], key="n"), start_moment_index=1) == 1
# A measurement on q[0] with the same key can't be in moment 1
assert c._latest_available_moment(cirq.measure(q[0], key="m"), start_moment_index=1) == 0
assert c._latest_available_moment(cirq.measure(q[0], key="m"), start_moment_index=0) == -1


def test_insert_op_tree_latest() -> None:
q = cirq.LineQubit.range(3)

op_tree_list = [
(0, [0, 1], [cirq.X(q[0]), cirq.X(q[1])], [q[0], q[1]], 2),
(0, [2], [cirq.X(q[2])], [q[2]], 3),
(2, [2], [cirq.Y(q[1])], [q[1]], 3),
(
1,
[1, 1, 3],
[cirq.measure(q[0], key="m"), cirq.Y(q[1]), cirq.Z(q[2])],
[q[0], q[1], q[2]],
4,
),
(0, [0], [cirq.measure(q[0], key="n")], [q[0]], 1),
(1, [2], [cirq.X(q[2]).with_classical_controls("m")], [q[2]], 3),
(0, [0, 1, 2], [cirq.X(q[2]), cirq.Y(q[2]), cirq.H(q[2])], [q[2], q[2], q[2]], 3),
(3, [3], [cirq.H(q[1])], [q[1]], 4),
]

for insert_index, result_indices, op_list, qubits, index_after in op_tree_list:
c = cirq.Circuit(
cirq.Moment(cirq.measure(q[0], key="m")),
cirq.Moment(cirq.X(q[1]).with_classical_controls("m")),
cirq.Moment([cirq.H(q[1])]),
)
assert c.insert(insert_index, op_list, cirq.InsertStrategy.LATEST) == index_after
for i in range(len(op_list)):
assert c.operation_at(qubits[i], result_indices[i]) == op_list[i]


def test_insert_moments_and_ops_latest() -> None:
q = cirq.LineQubit.range(3)

moments_and_ops_list = [
(0, [cirq.Moment(cirq.H(q[2])), cirq.Moment(cirq.X(q[2]))], 2),
(0, [cirq.Moment(cirq.H(q[2])), cirq.X(q[2])], 4),
(1, [cirq.X(q[0]), cirq.Moment(cirq.Y(q[1]))], 4),
(1, [cirq.Y(q[1]), cirq.Moment(cirq.Y(q[2]))], 2),
(1, [], 1),
]

for insert_index, moments_and_ops, index_after in moments_and_ops_list:
c = cirq.Circuit(
cirq.Moment(cirq.measure(q[0], key="m")),
cirq.Moment(cirq.X(q[1]).with_classical_controls("m")),
cirq.Moment([cirq.H(q[1])]),
)
assert c.insert(insert_index, moments_and_ops, cirq.InsertStrategy.LATEST) == index_after
18 changes: 18 additions & 0 deletions cirq-core/cirq/circuits/insert_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class InsertStrategy:
NEW_THEN_INLINE: InsertStrategy
INLINE: InsertStrategy
EARLIEST: InsertStrategy
LATEST: InsertStrategy

def __new__(cls, name: str, doc: str) -> InsertStrategy:
inst = getattr(cls, name, None)
Expand Down Expand Up @@ -93,3 +94,20 @@ def __repr__(self) -> str:
at the desired location.
""",
)

InsertStrategy.LATEST = InsertStrategy(
'LATEST',
"""
Scans forward from the insert location until a moment with operations
touching qubits affected by the operation to insert is found. The operation
is added into the moment just before that conflicting location.

If the scan reaches the end of the circuit without finding any conflicting
operations, the operation is added into the last moment of the circuit
if possible, otherwise in a new moment at the end.

The operation is never added into moments before the initial insert
location. If the moment at the initial insert location has conflicting
operations, the operation is added into a new moment before it.
""",
)
1 change: 1 addition & 0 deletions cirq-core/cirq/circuits/insert_strategy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_repr() -> None:
cirq.InsertStrategy.NEW_THEN_INLINE,
cirq.InsertStrategy.INLINE,
cirq.InsertStrategy.EARLIEST,
cirq.InsertStrategy.LATEST,
],
ids=lambda strategy: strategy.name,
)
Expand Down