Skip to content

Commit 1f75b91

Browse files
noah-chaedvora-h
andauthored
resolve conflict (#10)
Co-authored-by: dvora-h <[email protected]>
1 parent e85e3f7 commit 1f75b91

File tree

6 files changed

+559
-37
lines changed

6 files changed

+559
-37
lines changed

redis/client.py

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,7 @@ class AbstractRedis:
801801
"QUIT": bool_ok,
802802
"STRALGO": parse_stralgo,
803803
"PUBSUB NUMSUB": parse_pubsub_numsub,
804+
"PUBSUB SHARDNUMSUB": parse_pubsub_numsub,
804805
"RANDOMKEY": lambda r: r and r or None,
805806
"RESET": str_if_bytes,
806807
"SCAN": parse_scan,
@@ -1365,8 +1366,8 @@ class PubSub:
13651366
will be returned and it's safe to start listening again.
13661367
"""
13671368

1368-
PUBLISH_MESSAGE_TYPES = ("message", "pmessage")
1369-
UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe")
1369+
PUBLISH_MESSAGE_TYPES = ("message", "pmessage", "smessage")
1370+
UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe", "sunsubscribe")
13701371
HEALTH_CHECK_MESSAGE = "redis-py-health-check"
13711372

13721373
def __init__(
@@ -1414,9 +1415,11 @@ def reset(self):
14141415
self.connection.clear_connect_callbacks()
14151416
self.connection_pool.release(self.connection)
14161417
self.connection = None
1417-
self.channels = {}
14181418
self.health_check_response_counter = 0
1419+
self.channels = {}
14191420
self.pending_unsubscribe_channels = set()
1421+
self.shard_channels = {}
1422+
self.pending_unsubscribe_shard_channels = set()
14201423
self.patterns = {}
14211424
self.pending_unsubscribe_patterns = set()
14221425
self.subscribed_event.clear()
@@ -1431,16 +1434,23 @@ def on_connect(self, connection):
14311434
# before passing them to [p]subscribe.
14321435
self.pending_unsubscribe_channels.clear()
14331436
self.pending_unsubscribe_patterns.clear()
1437+
self.pending_unsubscribe_shard_channels.clear()
14341438
if self.channels:
1435-
channels = {}
1436-
for k, v in self.channels.items():
1437-
channels[self.encoder.decode(k, force=True)] = v
1439+
channels = {
1440+
self.encoder.decode(k, force=True): v for k, v in self.channels.items()
1441+
}
14381442
self.subscribe(**channels)
14391443
if self.patterns:
1440-
patterns = {}
1441-
for k, v in self.patterns.items():
1442-
patterns[self.encoder.decode(k, force=True)] = v
1444+
patterns = {
1445+
self.encoder.decode(k, force=True): v for k, v in self.patterns.items()
1446+
}
14431447
self.psubscribe(**patterns)
1448+
if self.shard_channels:
1449+
shard_channels = {
1450+
self.encoder.decode(k, force=True): v
1451+
for k, v in self.shard_channels.items()
1452+
}
1453+
self.ssubscribe(**shard_channels)
14441454

14451455
@property
14461456
def subscribed(self):
@@ -1647,6 +1657,45 @@ def unsubscribe(self, *args):
16471657
self.pending_unsubscribe_channels.update(channels)
16481658
return self.execute_command("UNSUBSCRIBE", *args)
16491659

1660+
def ssubscribe(self, *args, target_node=None, **kwargs):
1661+
"""
1662+
Subscribes the client to the specified shard channels.
1663+
Channels supplied as keyword arguments expect a channel name as the key
1664+
and a callable as the value. A channel's callable will be invoked automatically
1665+
when a message is received on that channel rather than producing a message via
1666+
``listen()`` or ``get_sharded_message()``.
1667+
"""
1668+
if args:
1669+
args = list_or_args(args[0], args[1:])
1670+
new_s_channels = dict.fromkeys(args)
1671+
new_s_channels.update(kwargs)
1672+
ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys())
1673+
# update the s_channels dict AFTER we send the command. we don't want to
1674+
# subscribe twice to these channels, once for the command and again
1675+
# for the reconnection.
1676+
new_s_channels = self._normalize_keys(new_s_channels)
1677+
self.shard_channels.update(new_s_channels)
1678+
if not self.subscribed:
1679+
# Set the subscribed_event flag to True
1680+
self.subscribed_event.set()
1681+
# Clear the health check counter
1682+
self.health_check_response_counter = 0
1683+
self.pending_unsubscribe_shard_channels.difference_update(new_s_channels)
1684+
return ret_val
1685+
1686+
def sunsubscribe(self, *args, target_node=None):
1687+
"""
1688+
Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
1689+
all shard_channels
1690+
"""
1691+
if args:
1692+
args = list_or_args(args[0], args[1:])
1693+
s_channels = self._normalize_keys(dict.fromkeys(args))
1694+
else:
1695+
s_channels = self.shard_channels
1696+
self.pending_unsubscribe_shard_channels.update(s_channels)
1697+
return self.execute_command("SUNSUBSCRIBE", *args)
1698+
16501699
def listen(self):
16511700
"Listen for messages on channels this client has been subscribed to"
16521701
while self.subscribed:
@@ -1681,6 +1730,8 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0.0):
16811730
return self.handle_message(response, ignore_subscribe_messages)
16821731
return None
16831732

1733+
get_sharded_message = get_message
1734+
16841735
def ping(self, message=None):
16851736
"""
16861737
Ping the Redis server
@@ -1726,12 +1777,17 @@ def handle_message(self, response, ignore_subscribe_messages=False):
17261777
if pattern in self.pending_unsubscribe_patterns:
17271778
self.pending_unsubscribe_patterns.remove(pattern)
17281779
self.patterns.pop(pattern, None)
1780+
elif message_type == "sunsubscribe":
1781+
s_channel = response[1]
1782+
if s_channel in self.pending_unsubscribe_shard_channels:
1783+
self.pending_unsubscribe_shard_channels.remove(s_channel)
1784+
self.shard_channels.pop(s_channel, None)
17291785
else:
17301786
channel = response[1]
17311787
if channel in self.pending_unsubscribe_channels:
17321788
self.pending_unsubscribe_channels.remove(channel)
17331789
self.channels.pop(channel, None)
1734-
if not self.channels and not self.patterns:
1790+
if not self.channels and not self.patterns and not self.shard_channels:
17351791
# There are no subscriptions anymore, set subscribed_event flag
17361792
# to false
17371793
self.subscribed_event.clear()
@@ -1740,6 +1796,8 @@ def handle_message(self, response, ignore_subscribe_messages=False):
17401796
# if there's a message handler, invoke it
17411797
if message_type == "pmessage":
17421798
handler = self.patterns.get(message["pattern"], None)
1799+
elif message_type == "smessage":
1800+
handler = self.shard_channels.get(message["channel"], None)
17431801
else:
17441802
handler = self.channels.get(message["channel"], None)
17451803
if handler:
@@ -1760,6 +1818,11 @@ def run_in_thread(self, sleep_time=0, daemon=False, exception_handler=None):
17601818
for pattern, handler in self.patterns.items():
17611819
if handler is None:
17621820
raise PubSubError(f"Pattern: '{pattern}' has no handler registered")
1821+
for s_channel, handler in self.shard_channels.items():
1822+
if handler is None:
1823+
raise PubSubError(
1824+
f"Shard Channel: '{s_channel}' has no handler registered"
1825+
)
17631826

17641827
thread = PubSubWorkerThread(
17651828
self, sleep_time, daemon=daemon, exception_handler=exception_handler

redis/cluster.py

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan
1111
from redis.commands import READ_COMMANDS, CommandsParser, RedisClusterCommands
1212
from redis.connection import ConnectionPool, DefaultParser, Encoder, parse_url
13+
from redis.commands.helpers import list_or_args
1314
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
1415
from redis.exceptions import (
1516
AskError,
@@ -219,6 +220,8 @@ class AbstractRedisCluster:
219220
"PUBSUB CHANNELS",
220221
"PUBSUB NUMPAT",
221222
"PUBSUB NUMSUB",
223+
"PUBSUB SHARDCHANNELS",
224+
"PUBSUB SHARDNUMSUB",
222225
"PING",
223226
"INFO",
224227
"SHUTDOWN",
@@ -343,11 +346,13 @@ class AbstractRedisCluster:
343346
}
344347

345348
RESULT_CALLBACKS = dict_merge(
346-
list_keys_to_dict(["PUBSUB NUMSUB"], parse_pubsub_numsub),
349+
list_keys_to_dict(["PUBSUB NUMSUB", "PUBSUB SHARDNUMSUB"], parse_pubsub_numsub),
347350
list_keys_to_dict(
348351
["PUBSUB NUMPAT"], lambda command, res: sum(list(res.values()))
349352
),
350-
list_keys_to_dict(["KEYS", "PUBSUB CHANNELS"], merge_result),
353+
list_keys_to_dict(
354+
["KEYS", "PUBSUB CHANNELS", "PUBSUB SHARDCHANNELS"], merge_result
355+
),
351356
list_keys_to_dict(
352357
[
353358
"PING",
@@ -1655,6 +1660,8 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, **kwargs):
16551660
else redis_cluster.get_redis_connection(self.node).connection_pool
16561661
)
16571662
self.cluster = redis_cluster
1663+
self.node_pubsub_mapping = {}
1664+
self._pubsubs_generator = self._pubsubs_generator()
16581665
super().__init__(
16591666
**kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder
16601667
)
@@ -1708,9 +1715,9 @@ def _raise_on_invalid_node(self, redis_cluster, node, host, port):
17081715
f"Node {host}:{port} doesn't exist in the cluster"
17091716
)
17101717

1711-
def execute_command(self, *args, **kwargs):
1718+
def execute_command(self, *args):
17121719
"""
1713-
Execute a publish/subscribe command.
1720+
Execute a subscribe/unsubscribe command.
17141721
17151722
Taken code from redis-py and tweak to make it work within a cluster.
17161723
"""
@@ -1743,13 +1750,103 @@ def execute_command(self, *args, **kwargs):
17431750
connection = self.connection
17441751
self._execute(connection, connection.send_command, *args)
17451752

1753+
def _get_node_pubsub(self, node):
1754+
try:
1755+
return self.node_pubsub_mapping[node.name]
1756+
except KeyError:
1757+
pubsub = node.redis_connection.pubsub()
1758+
self.node_pubsub_mapping[node.name] = pubsub
1759+
return pubsub
1760+
1761+
def _sharded_message_generator(self):
1762+
for _ in range(len(self.node_pubsub_mapping)):
1763+
pubsub = next(self._pubsubs_generator)
1764+
message = pubsub.get_message()
1765+
if message is not None:
1766+
return message
1767+
return None
1768+
1769+
def _pubsubs_generator(self):
1770+
while True:
1771+
for pubsub in self.node_pubsub_mapping.values():
1772+
yield pubsub
1773+
1774+
def get_sharded_message(
1775+
self, ignore_subscribe_messages=False, timeout=0.0, target_node=None
1776+
):
1777+
if target_node:
1778+
message = self.node_pubsub_mapping[target_node.name].get_message(
1779+
ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout
1780+
)
1781+
else:
1782+
message = self._sharded_message_generator()
1783+
if message is None:
1784+
return None
1785+
elif str_if_bytes(message["type"]) == "sunsubscribe":
1786+
if message["channel"] in self.pending_unsubscribe_shard_channels:
1787+
self.pending_unsubscribe_shard_channels.remove(message["channel"])
1788+
self.shard_channels.pop(message["channel"], None)
1789+
node = self.cluster.get_node_from_key(message["channel"])
1790+
if self.node_pubsub_mapping[node.name].subscribed is False:
1791+
self.node_pubsub_mapping.pop(node.name)
1792+
if not self.channels and not self.patterns and not self.shard_channels:
1793+
# There are no subscriptions anymore, set subscribed_event flag
1794+
# to false
1795+
self.subscribed_event.clear()
1796+
if self.ignore_subscribe_messages or ignore_subscribe_messages:
1797+
return None
1798+
return message
1799+
1800+
def ssubscribe(self, *args, **kwargs):
1801+
if args:
1802+
args = list_or_args(args[0], args[1:])
1803+
s_channels = dict.fromkeys(args)
1804+
s_channels.update(kwargs)
1805+
for s_channel, handler in s_channels.items():
1806+
node = self.cluster.get_node_from_key(s_channel)
1807+
pubsub = self._get_node_pubsub(node)
1808+
if handler:
1809+
pubsub.ssubscribe(**{s_channel: handler})
1810+
else:
1811+
pubsub.ssubscribe(s_channel)
1812+
self.shard_channels.update(pubsub.shard_channels)
1813+
self.pending_unsubscribe_shard_channels.difference_update(
1814+
self._normalize_keys({s_channel: None})
1815+
)
1816+
if pubsub.subscribed and not self.subscribed:
1817+
self.subscribed_event.set()
1818+
self.health_check_response_counter = 0
1819+
1820+
def sunsubscribe(self, *args):
1821+
if args:
1822+
args = list_or_args(args[0], args[1:])
1823+
else:
1824+
args = self.shard_channels
1825+
1826+
for s_channel in args:
1827+
node = self.cluster.get_node_from_key(s_channel)
1828+
p = self._get_node_pubsub(node)
1829+
p.sunsubscribe(s_channel)
1830+
self.pending_unsubscribe_shard_channels.update(
1831+
p.pending_unsubscribe_shard_channels
1832+
)
1833+
17461834
def get_redis_connection(self):
17471835
"""
17481836
Get the Redis connection of the pubsub connected node.
17491837
"""
17501838
if self.node is not None:
17511839
return self.node.redis_connection
17521840

1841+
def disconnect(self):
1842+
"""
1843+
Disconnect the pubsub connection.
1844+
"""
1845+
if self.connection:
1846+
self.connection.disconnect()
1847+
for pubsub in self.node_pubsub_mapping.values():
1848+
pubsub.connection.disconnect()
1849+
17531850

17541851
class ClusterPipeline(RedisCluster):
17551852
"""

redis/commands/core.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5090,6 +5090,15 @@ def publish(self, channel: ChannelT, message: EncodableT, **kwargs) -> ResponseT
50905090
"""
50915091
return self.execute_command("PUBLISH", channel, message, **kwargs)
50925092

5093+
def spublish(self, shard_channel: ChannelT, message: EncodableT) -> ResponseT:
5094+
"""
5095+
Posts a message to the given shard channel.
5096+
Returns the number of clients that received the message
5097+
5098+
For more information see https://redis.io/commands/spublish
5099+
"""
5100+
return self.execute_command("SPUBLISH", shard_channel, message)
5101+
50935102
def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
50945103
"""
50955104
Return a list of channels that have at least one subscriber
@@ -5098,6 +5107,14 @@ def pubsub_channels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
50985107
"""
50995108
return self.execute_command("PUBSUB CHANNELS", pattern, **kwargs)
51005109

5110+
def pubsub_shardchannels(self, pattern: PatternT = "*", **kwargs) -> ResponseT:
5111+
"""
5112+
Return a list of shard_channels that have at least one subscriber
5113+
5114+
For more information see https://redis.io/commands/pubsub-shardchannels
5115+
"""
5116+
return self.execute_command("PUBSUB SHARDCHANNELS", pattern, **kwargs)
5117+
51015118
def pubsub_numpat(self, **kwargs) -> ResponseT:
51025119
"""
51035120
Returns the number of subscriptions to patterns
@@ -5115,6 +5132,15 @@ def pubsub_numsub(self, *args: ChannelT, **kwargs) -> ResponseT:
51155132
"""
51165133
return self.execute_command("PUBSUB NUMSUB", *args, **kwargs)
51175134

5135+
def pubsub_shardnumsub(self, *args: ChannelT, **kwargs) -> ResponseT:
5136+
"""
5137+
Return a list of (shard_channel, number of subscribers) tuples
5138+
for each channel given in ``*args``
5139+
5140+
For more information see https://redis.io/commands/pubsub-shardnumsub
5141+
"""
5142+
return self.execute_command("PUBSUB SHARDNUMSUB", *args, **kwargs)
5143+
51185144

51195145
AsyncPubSubCommands = PubSubCommands
51205146

redis/commands/parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,13 @@ def _get_pubsub_keys(self, *args):
153153
# the second argument is a part of the command name, e.g.
154154
# ['PUBSUB', 'NUMSUB', 'foo'].
155155
pubsub_type = args[1].upper()
156-
if pubsub_type in ["CHANNELS", "NUMSUB"]:
156+
if pubsub_type in ["CHANNELS", "NUMSUB", "SHARDCHANNELS", "SHARDNUMSUB"]:
157157
keys = args[2:]
158158
elif command in ["SUBSCRIBE", "PSUBSCRIBE", "UNSUBSCRIBE", "PUNSUBSCRIBE"]:
159159
# format example:
160160
# SUBSCRIBE channel [channel ...]
161161
keys = list(args[1:])
162-
elif command == "PUBLISH":
162+
elif command in ["PUBLISH", "SPUBLISH"]:
163163
# format example:
164164
# PUBLISH channel message
165165
keys = [args[1]]

0 commit comments

Comments
 (0)