From 7f3cd452420bee32b9d4dc538fd633beed8edeed Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Tue, 16 Dec 2025 18:17:21 +0200 Subject: [PATCH 1/2] Fix RuntimeError in ClusterPubSub sharded message generator --- redis/cluster.py | 19 ++++++- tests/test_pubsub.py | 116 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+), 2 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index 005206a725..ad9a89ada2 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2199,16 +2199,31 @@ def _get_node_pubsub(self, node): return pubsub def _sharded_message_generator(self): - for _ in range(len(self.node_pubsub_mapping)): + """ + Iterate through pubsubs until a complete cycle is done. + """ + while True: pubsub = next(self._pubsubs_generator) + + # None marks end of cycle + if pubsub is None: + break + message = pubsub.get_message() if message is not None: return message + return None def _pubsubs_generator(self): + """ + Generator that yields pubsubs in round-robin fashion. + Yields None to mark cycle boundaries. + """ while True: - yield from self.node_pubsub_mapping.values() + current_nodes = list(self.node_pubsub_mapping.values()) + yield from current_nodes + yield None # Cycle marker def get_sharded_message( self, ignore_subscribe_messages=False, timeout=0.0, target_node=None diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index db313e2437..5c37b8f38e 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -871,6 +871,122 @@ def test_pubsub_shardnumsub(self, r): channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] assert r.pubsub_shardnumsub("foo", "bar", "baz", target_nodes="all") == channels + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_ssubscribe_multiple_channels_different_nodes(self, r): + """ + Test subscribing to multiple sharded channels on different nodes. + Validates that the generator properly handles multiple node_pubsub_mapping entries. + """ + pubsub = r.pubsub() + channel1 = "test-channel:{0}" + channel2 = "test-channel:{6}" + + # Subscribe to first channel + pubsub.ssubscribe(channel1) + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + assert msg is not None + assert msg["type"] == "ssubscribe" + + # Subscribe to second channel (likely different node) + pubsub.ssubscribe(channel2) + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + assert msg is not None + assert msg["type"] == "ssubscribe" + + # Verify both channels are in shard_channels + assert channel1.encode() in pubsub.shard_channels + assert channel2.encode() in pubsub.shard_channels + + pubsub.close() + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_ssubscribe_multiple_channels_publish_and_read(self, r): + """ + Test publishing to multiple sharded channels and reading messages. + Validates that _sharded_message_generator properly cycles through + multiple node_pubsub_mapping entries. + """ + pubsub = r.pubsub() + channel1 = "test-channel:{0}" + channel2 = "test-channel:{6}" + msg1_data = "message-1" + msg2_data = "message-2" + + # Subscribe to both channels + pubsub.ssubscribe(channel1, channel2) + + # Read subscription confirmations + for _ in range(2): + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + assert msg is not None + assert msg["type"] == "ssubscribe" + + # Publish messages to both channels + r.spublish(channel1, msg1_data) + r.spublish(channel2, msg2_data) + + # Read messages - should get both messages + messages = [] + for _ in range(2): + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + assert msg is not None + assert msg["type"] == "smessage" + messages.append(msg) + + # Verify we got messages from both channels + channels_received = {msg["channel"] for msg in messages} + assert channel1.encode() in channels_received + assert channel2.encode() in channels_received + + pubsub.close() + + @pytest.mark.onlycluster + @skip_if_server_version_lt("7.0.0") + def test_generator_handles_concurrent_mapping_changes(self, r): + """ + Test that the generator properly handles mapping changes during iteration. + This validates the fix for the RuntimeError: dictionary changed size during iteration. + """ + pubsub = r.pubsub() + channel1 = "test-channel:{0}" + channel2 = "test-channel:{6}" + + # Subscribe to first channel + pubsub.ssubscribe(channel1) + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + assert msg is not None + assert msg["type"] == "ssubscribe" + + # Get initial mapping size (if available) + initial_size = 0 + if hasattr(pubsub, "node_pubsub_mapping"): + initial_size = len(pubsub.node_pubsub_mapping) + + # Subscribe to second channel (modifies mapping during potential iteration) + pubsub.ssubscribe(channel2) + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + assert msg is not None + assert msg["type"] == "ssubscribe" + + # Verify mapping was updated (if available) + if hasattr(pubsub, "node_pubsub_mapping"): + assert len(pubsub.node_pubsub_mapping) >= initial_size + + # Publish and read messages - should not raise RuntimeError + r.spublish(channel1, "msg1") + r.spublish(channel2, "msg2") + + messages_received = 0 + for _ in range(2): + msg = wait_for_message(pubsub, timeout=1.0, func=pubsub.get_sharded_message) + if msg and msg["type"] == "smessage": + messages_received += 1 + + assert messages_received == 2 + pubsub.close() + class TestPubSubPings: @skip_if_server_version_lt("3.0.0") From 50f8ab39d2a092cfc7c6e21b522b5a84d9e38c5d Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Wed, 17 Dec 2025 11:35:10 +0200 Subject: [PATCH 2/2] Applying review comments --- redis/cluster.py | 16 +--------------- tests/test_pubsub.py | 12 +++++------- 2 files changed, 6 insertions(+), 22 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index ad9a89ada2..dabac841db 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2199,31 +2199,17 @@ def _get_node_pubsub(self, node): return pubsub def _sharded_message_generator(self): - """ - Iterate through pubsubs until a complete cycle is done. - """ - while True: + for _ in range(len(self.node_pubsub_mapping)): pubsub = next(self._pubsubs_generator) - - # None marks end of cycle - if pubsub is None: - break - message = pubsub.get_message() if message is not None: return message - return None def _pubsubs_generator(self): - """ - Generator that yields pubsubs in round-robin fashion. - Yields None to mark cycle boundaries. - """ while True: current_nodes = list(self.node_pubsub_mapping.values()) yield from current_nodes - yield None # Cycle marker def get_sharded_message( self, ignore_subscribe_messages=False, timeout=0.0, target_node=None diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 5c37b8f38e..ef71d6a30c 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -959,10 +959,9 @@ def test_generator_handles_concurrent_mapping_changes(self, r): assert msg is not None assert msg["type"] == "ssubscribe" - # Get initial mapping size (if available) - initial_size = 0 - if hasattr(pubsub, "node_pubsub_mapping"): - initial_size = len(pubsub.node_pubsub_mapping) + # Get initial mapping size (cluster pubsub only) + assert hasattr(pubsub, "node_pubsub_mapping"), "Test requires ClusterPubSub" + initial_size = len(pubsub.node_pubsub_mapping) # Subscribe to second channel (modifies mapping during potential iteration) pubsub.ssubscribe(channel2) @@ -970,9 +969,8 @@ def test_generator_handles_concurrent_mapping_changes(self, r): assert msg is not None assert msg["type"] == "ssubscribe" - # Verify mapping was updated (if available) - if hasattr(pubsub, "node_pubsub_mapping"): - assert len(pubsub.node_pubsub_mapping) >= initial_size + # Verify mapping was updated + assert len(pubsub.node_pubsub_mapping) >= initial_size # Publish and read messages - should not raise RuntimeError r.spublish(channel1, "msg1")