Skip to content

Commit e05ef48

Browse files
noah-chaedvora-hzach-iee
authored
Cherry pick sharded pubsub commit into 4.5 (#9)
* resolve conflict * fix import order --------- Co-authored-by: dvora-h <[email protected]> Co-authored-by: zach.lee <[email protected]>
1 parent 1064caa commit e05ef48

File tree

6 files changed

+559
-37
lines changed

6 files changed

+559
-37
lines changed

redis/client.py

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,7 @@ class AbstractRedis:
808808
"QUIT": bool_ok,
809809
"STRALGO": parse_stralgo,
810810
"PUBSUB NUMSUB": parse_pubsub_numsub,
811+
"PUBSUB SHARDNUMSUB": parse_pubsub_numsub,
811812
"RANDOMKEY": lambda r: r and r or None,
812813
"RESET": str_if_bytes,
813814
"SCAN": parse_scan,
@@ -1376,8 +1377,8 @@ class PubSub:
13761377
will be returned and it's safe to start listening again.
13771378
"""
13781379

1379-
PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
1380-
UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
1380+
PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage")
1381+
UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe")
13811382
HEALTH_CHECK_MESSAGE = "redis-py-health-check"
13821383

13831384
def __init__(
@@ -1425,9 +1426,11 @@ def reset(self):
14251426
self.connection.clear_connect_callbacks()
14261427
self.connection_pool.release(self.connection)
14271428
self.connection = None
1428-
self.channels = {}
14291429
self.health_check_response_counter = 0
1430+
self.channels = {}
14301431
self.pending_unsubscribe_channels = set()
1432+
self.shard_channels = {}
1433+
self.pending_unsubscribe_shard_channels = set()
14311434
self.patterns = {}
14321435
self.pending_unsubscribe_patterns = set()
14331436
self.subscribed_event.clear()
@@ -1442,16 +1445,23 @@ def on_connect(self, connection):
14421445
# before passing them to [p]subscribe.
14431446
self.pending_unsubscribe_channels.clear()
14441447
self.pending_unsubscribe_patterns.clear()
1448+
self.pending_unsubscribe_shard_channels.clear()
14451449
if self.channels:
1446-
channels = {}
1447-
for k, v in self.channels.items():
1448-
channels[self.encoder.decode(k, force=True)] = v
1450+
channels = {
1451+
self.encoder.decode(k, force=True): v for k, v in self.channels.items()
1452+
}
14491453
self.subscribe(**channels)
14501454
if self.patterns:
1451-
patterns = {}
1452-
for k, v in self.patterns.items():
1453-
patterns[self.encoder.decode(k, force=True)] = v
1455+
patterns = {
1456+
self.encoder.decode(k, force=True): v for k, v in self.patterns.items()
1457+
}
14541458
self.psubscribe(**patterns)
1459+
if self.shard_channels:
1460+
shard_channels = {
1461+
self.encoder.decode(k, force=True): v
1462+
for k, v in self.shard_channels.items()
1463+
}
1464+
self.ssubscribe(**shard_channels)
14551465

14561466
@property
14571467
def subscribed(self):
@@ -1658,6 +1668,45 @@ def unsubscribe(self, *args):
16581668
self.pending_unsubscribe_channels.update(channels)
16591669
return self.execute_command("UNSUBSCRIBE", *args)
16601670

1671+
def ssubscribe(self, *args, target_node=None, **kwargs):
1672+
"""
1673+
Subscribes the client to the specified shard channels.
1674+
Channels supplied as keyword arguments expect a channel name as the key
1675+
and a callable as the value. A channel's callable will be invoked automatically
1676+
when a message is received on that channel rather than producing a message via
1677+
``listen()`` or ``get_sharded_message()``.
1678+
"""
1679+
if args:
1680+
args = list_or_args(args[0], args[1:])
1681+
new_s_channels = dict.fromkeys(args)
1682+
new_s_channels.update(kwargs)
1683+
ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys())
1684+
# update the s_channels dict AFTER we send the command. we don't want to
1685+
# subscribe twice to these channels, once for the command and again
1686+
# for the reconnection.
1687+
new_s_channels = self._normalize_keys(new_s_channels)
1688+
self.shard_channels.update(new_s_channels)
1689+
if not self.subscribed:
1690+
# Set the subscribed_event flag to True
1691+
self.subscribed_event.set()
1692+
# Clear the health check counter
1693+
self.health_check_response_counter = 0
1694+
self.pending_unsubscribe_shard_channels.difference_update(new_s_channels)
1695+
return ret_val
1696+
1697+
def sunsubscribe(self, *args, target_node=None):
1698+
"""
1699+
Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
1700+
all shard_channels
1701+
"""
1702+
if args:
1703+
args = list_or_args(args[0], args[1:])
1704+
s_channels = self._normalize_keys(dict.fromkeys(args))
1705+
else:
1706+
s_channels = self.shard_channels
1707+
self.pending_unsubscribe_shard_channels.update(s_channels)
1708+
return self.execute_command("SUNSUBSCRIBE", *args)
1709+
16611710
def listen(self):
16621711
"Listen for messages on channels this client has been subscribed to"
16631712
while self.subscribed:
@@ -1692,6 +1741,8 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
16921741
return self.handle_message(response, ignore_subscribe_messages)
16931742
return None
16941743

1744+
get_sharded_message = get_message
1745+
16951746
def ping(self, message=None):
16961747
"""
16971748
Ping the Redis server
@@ -1737,12 +1788,17 @@ def handle_message(self, response, ignore_subscribe_messages=False):
17371788
if pattern in self.pending_unsubscribe_patterns:
17381789
self.pending_unsubscribe_patterns.remove(pattern)
17391790
self.patterns.pop(pattern, None)
1791+
elif message_type == "sunsubscribe":
1792+
s_channel = response[1]
1793+
if s_channel in self.pending_unsubscribe_shard_channels:
1794+
self.pending_unsubscribe_shard_channels.remove(s_channel)
1795+
self.shard_channels.pop(s_channel, None)
17401796
else:
17411797
channel = response[1]
17421798
if channel in self.pending_unsubscribe_channels:
17431799
self.pending_unsubscribe_channels.remove(channel)
17441800
self.channels.pop(channel, None)
1745-
if not self.channels and not self.patterns:
1801+
if not self.channels and not self.patterns and not self.shard_channels:
17461802
# There are no subscriptions anymore, set subscribed_event flag
17471803
# to false
17481804
self.subscribed_event.clear()
@@ -1751,6 +1807,8 @@ def handle_message(self, response, ignore_subscribe_messages=False):
17511807
# if there's a message handler, invoke it
17521808
if message_type == "pmessage":
17531809
handler = self.patterns.get(message["pattern"], None)
1810+
elif message_type == "smessage":
1811+
handler = self.shard_channels.get(message["channel"], None)
17541812
else:
17551813
handler = self.channels.get(message["channel"], None)
17561814
if handler:
@@ -1771,6 +1829,11 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
17711829
for pattern, handler in self.patterns.items():
17721830
if handler is None:
17731831
raise PubSubError(f"Pattern: '{pattern}' has no handler registered")
1832+
for s_channel, handler in self.shard_channels.items():
1833+
if handler is None:
1834+
raise PubSubError(
1835+
f"Shard Channel: '{s_channel}' has no handler registered"
1836+
)
17741837

17751838
thread = PubSubWorkerThread(
17761839
self, sleep_time, daemon=daemon, exception_handler=exception_handler

redis/cluster.py

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from redis.backoff import default_backoff
1010
from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan
1111
from redis.commands import READ_COMMANDS, CommandsParser, RedisClusterCommands
12+
from redis.commands.helpers import list_or_args
1213
from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url
1314
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
1415
from redis.exceptions import (
@@ -227,6 +228,8 @@ class AbstractRedisCluster:
227228
"PUBSUB CHANNELS",
228229
"PUBSUB NUMPAT",
229230
"PUBSUB NUMSUB",
231+
"PUBSUB SHARDCHANNELS",
232+
"PUBSUB SHARDNUMSUB",
230233
"PING",
231234
"INFO",
232235
"SHUTDOWN",
@@ -352,11 +355,13 @@ class AbstractRedisCluster:
352355
}
353356

354357
RESULT_CALLBACKS = dict_merge(
355-
list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub),
358+
list_keys_to_dict(["PUBSUB NUMSUB", "PUBSUB SHARDNUMSUB"], parse_pubsub_numsub),
356359
list_keys_to_dict(
357360
["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values()))
358361
),
359-
list_keys_to_dict(["KEYS", "PUBSUB CHANNELS"], merge_result),
362+
list_keys_to_dict(
363+
["KEYS", "PUBSUB CHANNELS", "PUBSUB SHARDCHANNELS"], merge_result
364+
),
360365
list_keys_to_dict(
361366
[
362367
"PING",
@@ -1685,6 +1690,8 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs):
16851690
else redis_cluster.get_redis_connection(self.node).connection_pool
16861691
)
16871692
self.cluster = redis_cluster
1693+
self.node_pubsub_mapping = {}
1694+
self._pubsubs_generator = self._pubsubs_generator()
16881695
super().__init__(
16891696
**kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder
16901697
)
@@ -1738,9 +1745,9 @@ def _raise_on_invalid_node(self, redis_cluster, node, host, port):
17381745
f"Node {host}:{port} doesn't exist in the cluster"
17391746
)
17401747

1741-
def execute_command(self, *args, **kwargs):
1748+
def execute_command(self, *args):
17421749
"""
1743-
Execute a publish/subscribe command.
1750+
Execute a subscribe/unsubscribe command.
17441751
17451752
Taken code from redis-py and tweak to make it work within a cluster.
17461753
"""
@@ -1773,13 +1780,103 @@ def execute_command(self, *args, **kwargs):
17731780
connection = self.connection
17741781
self._execute(connection, connection.send_command, *args)
17751782

1783+
def _get_node_pubsub(self, node):
1784+
try:
1785+
return self.node_pubsub_mapping[node.name]
1786+
except KeyError:
1787+
pubsub = node.redis_connection.pubsub()
1788+
self.node_pubsub_mapping[node.name] = pubsub
1789+
return pubsub
1790+
1791+
def _sharded_message_generator(self):
1792+
for _ in range(len(self.node_pubsub_mapping)):
1793+
pubsub = next(self._pubsubs_generator)
1794+
message = pubsub.get_message()
1795+
if message is not None:
1796+
return message
1797+
return None
1798+
1799+
def _pubsubs_generator(self):
1800+
while True:
1801+
for pubsub in self.node_pubsub_mapping.values():
1802+
yield pubsub
1803+
1804+
def get_sharded_message(
1805+
self, ignore_subscribe_messages=False, timeout=0.0, target_node=None
1806+
):
1807+
if target_node:
1808+
message = self.node_pubsub_mapping[target_node.name].get_message(
1809+
ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout
1810+
)
1811+
else:
1812+
message = self._sharded_message_generator()
1813+
if message is None:
1814+
return None
1815+
elif str_if_bytes(message["type"]) == "sunsubscribe":
1816+
if message["channel"] in self.pending_unsubscribe_shard_channels:
1817+
self.pending_unsubscribe_shard_channels.remove(message["channel"])
1818+
self.shard_channels.pop(message["channel"], None)
1819+
node = self.cluster.get_node_from_key(message["channel"])
1820+
if self.node_pubsub_mapping[node.name].subscribed is False:
1821+
self.node_pubsub_mapping.pop(node.name)
1822+
if not self.channels and not self.patterns and not self.shard_channels:
1823+
# There are no subscriptions anymore, set subscribed_event flag
1824+
# to false
1825+
self.subscribed_event.clear()
1826+
if self.ignore_subscribe_messages or ignore_subscribe_messages:
1827+
return None
1828+
return message
1829+
1830+
def ssubscribe(self, *args, **kwargs):
1831+
if args:
1832+
args = list_or_args(args[0], args[1:])
1833+
s_channels = dict.fromkeys(args)
1834+
s_channels.update(kwargs)
1835+
for s_channel, handler in s_channels.items():
1836+
node = self.cluster.get_node_from_key(s_channel)
1837+
pubsub = self._get_node_pubsub(node)
1838+
if handler:
1839+
pubsub.ssubscribe(**{s_channel: handler})
1840+
else:
1841+
pubsub.ssubscribe(s_channel)
1842+
self.shard_channels.update(pubsub.shard_channels)
1843+
self.pending_unsubscribe_shard_channels.difference_update(
1844+
self._normalize_keys({s_channel: None})
1845+
)
1846+
if pubsub.subscribed and not self.subscribed:
1847+
self.subscribed_event.set()
1848+
self.health_check_response_counter = 0
1849+
1850+
def sunsubscribe(self, *args):
1851+
if args:
1852+
args = list_or_args(args[0], args[1:])
1853+
else:
1854+
args = self.shard_channels
1855+
1856+
for s_channel in args:
1857+
node = self.cluster.get_node_from_key(s_channel)
1858+
p = self._get_node_pubsub(node)
1859+
p.sunsubscribe(s_channel)
1860+
self.pending_unsubscribe_shard_channels.update(
1861+
p.pending_unsubscribe_shard_channels
1862+
)
1863+
17761864
def get_redis_connection(self):
17771865
"""
17781866
Get the Redis connection of the pubsub connected node.
17791867
"""
17801868
if self.node is not None:
17811869
return self.node.redis_connection
17821870

1871+
def disconnect(self):
1872+
"""
1873+
Disconnect the pubsub connection.
1874+
"""
1875+
if self.connection:
1876+
self.connection.disconnect()
1877+
for pubsub in self.node_pubsub_mapping.values():
1878+
pubsub.connection.disconnect()
1879+
17831880

17841881
class ClusterPipeline(RedisCluster):
17851882
"""

redis/commands/core.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5123,6 +5123,15 @@ def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT
51235123
"""
51245124
return self.execute_command("PUBLISH", channel, message, **kwargs)
51255125

5126+
def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT:
5127+
"""
5128+
Posts a message to the given shard channel.
5129+
Returns the number of clients that received the message
5130+
5131+
For more information see https://redis.io/commands/spublish
5132+
"""
5133+
return self.execute_command("SPUBLISH", shard_channel, message)
5134+
51265135
def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
51275136
"""
51285137
Return a list of channels that have at least one subscriber
@@ -5131,6 +5140,14 @@ def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
51315140
"""
51325141
return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs)
51335142

5143+
def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
5144+
"""
5145+
Return a list of shard_channels that have at least one subscriber
5146+
5147+
For more information see https://redis.io/commands/pubsub-shardchannels
5148+
"""
5149+
return self.execute_command("PUBSUB SHARDCHANNELS", pattern, **kwargs)
5150+
51345151
def pubsub_numpat(self, **kwargs) -> ResponseT:
51355152
"""
51365153
Returns the number of subscriptions to patterns
@@ -5148,6 +5165,15 @@ def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT:
51485165
"""
51495166
return self.execute_command("PUBSUB NUMSUB", *args, **kwargs)
51505167

5168+
def pubsub_shardnumsub(self, *args: ChannelT, **kwargs) -> ResponseT:
5169+
"""
5170+
Return a list of (shard_channel, number of subscribers) tuples
5171+
for each channel given in ``*args``
5172+
5173+
For more information see https://redis.io/commands/pubsub-shardnumsub
5174+
"""
5175+
return self.execute_command("PUBSUB SHARDNUMSUB", *args, **kwargs)
5176+
51515177

51525178
AsyncPubSubCommands = PubSubCommands
51535179

redis/commands/parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,13 @@ def _get_pubsub_keys(self, *args):
153153
# the second argument is a part of the command name, e.g.
154154
# ['PUBSUB', 'NUMSUB', 'foo'].
155155
pubsub_type = args[1].upper()
156-
if pubsub_type in ["CHANNELS", "NUMSUB"]:
156+
if pubsub_type in ["CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"]:
157157
keys = args[2:]
158158
elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]:
159159
# format example:
160160
# SUBSCRIBE channel [channel ...]
161161
keys = list(args[1:])
162-
elif command == "PUBLISH":
162+
elif command in ["PUBLISH", "SPUBLISH"]:
163163
# format example:
164164
# PUBLISH channel message
165165
keys = [args[1]]

0 commit comments

Comments
 (0)