Skip to content

Commit 46042ce

Browse files
committed
introduce AbstractConnection so that UnixDomainSocketConnection can call super().__init__
1 parent fd7a79d commit 46042ce

File tree

1 file changed

+120
-158
lines changed

1 file changed

+120
-158
lines changed

redis/connection.py

Lines changed: 120 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
import threading
88
import weakref
9+
from abc import abstractmethod
910
from io import SEEK_END
1011
from itertools import chain
1112
from queue import Empty, Full, LifoQueue
@@ -585,20 +586,13 @@ def pack(self, *args):
585586
return output
586587

587588

588-
class Connection:
589-
"Manages TCP communication to and from a Redis server"
589+
class AbstractConnection:
590+
"Manages communication to and from a Redis server"
590591

591592
def __init__(
592593
self,
593-
host="localhost",
594-
port=6379,
595594
db=0,
596595
password=None,
597-
socket_timeout=None,
598-
socket_connect_timeout=None,
599-
socket_keepalive=False,
600-
socket_keepalive_options=None,
601-
socket_type=0,
602596
retry_on_timeout=False,
603597
retry_on_error=SENTINEL,
604598
encoding="utf-8",
@@ -629,18 +623,11 @@ def __init__(
629623
"2. 'credential_provider'"
630624
)
631625
self.pid = os.getpid()
632-
self.host = host
633-
self.port = int(port)
634626
self.db = db
635627
self.client_name = client_name
636628
self.credential_provider = credential_provider
637629
self.password = password
638630
self.username = username
639-
self.socket_timeout = socket_timeout
640-
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
641-
self.socket_keepalive = socket_keepalive
642-
self.socket_keepalive_options = socket_keepalive_options or {}
643-
self.socket_type = socket_type
644631
self.retry_on_timeout = retry_on_timeout
645632
if retry_on_error is SENTINEL:
646633
retry_on_error = []
@@ -673,11 +660,9 @@ def __repr__(self):
673660
repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
674661
return f"{self.__class__.__name__}<{repr_args}>"
675662

663+
@abstractmethod
676664
def repr_pieces(self):
677-
pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
678-
if self.client_name:
679-
pieces.append(("client_name", self.client_name))
680-
return pieces
665+
pass
681666

682667
def __del__(self):
683668
try:
@@ -740,75 +725,17 @@ def connect(self):
740725
if callback:
741726
callback(self)
742727

728+
@abstractmethod
743729
def _connect(self):
744-
"Create a TCP socket connection"
745-
# we want to mimic what socket.create_connection does to support
746-
# ipv4/ipv6, but we want to set options prior to calling
747-
# socket.connect()
748-
err = None
749-
for res in socket.getaddrinfo(
750-
self.host, self.port, self.socket_type, socket.SOCK_STREAM
751-
):
752-
family, socktype, proto, canonname, socket_address = res
753-
sock = None
754-
try:
755-
sock = socket.socket(family, socktype, proto)
756-
# TCP_NODELAY
757-
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
758-
759-
# TCP_KEEPALIVE
760-
if self.socket_keepalive:
761-
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
762-
for k, v in self.socket_keepalive_options.items():
763-
sock.setsockopt(socket.IPPROTO_TCP, k, v)
764-
765-
# set the socket_connect_timeout before we connect
766-
sock.settimeout(self.socket_connect_timeout)
767-
768-
# connect
769-
sock.connect(socket_address)
770-
771-
# set the socket_timeout now that we're connected
772-
sock.settimeout(self.socket_timeout)
773-
return sock
774-
775-
except OSError as _:
776-
err = _
777-
if sock is not None:
778-
sock.close()
779-
780-
if err is not None:
781-
raise err
782-
raise OSError("socket.getaddrinfo returned an empty list")
730+
pass
783731

732+
@abstractmethod
784733
def _host_error(self):
785-
try:
786-
host_error = f"{self.host}:{self.port}"
787-
except AttributeError:
788-
host_error = "connection"
789-
790-
return host_error
734+
pass
791735

736+
@abstractmethod
792737
def _error_message(self, exception):
793-
# args for socket.error can either be (errno, "message")
794-
# or just "message"
795-
796-
host_error = self._host_error()
797-
798-
if len(exception.args) == 1:
799-
try:
800-
return f"Error connecting to {host_error}. \
801-
{exception.args[0]}."
802-
except AttributeError:
803-
return f"Connection Error: {exception.args[0]}"
804-
else:
805-
try:
806-
return (
807-
f"Error {exception.args[0]} connecting to "
808-
f"{host_error}. {exception.args[1]}."
809-
)
810-
except AttributeError:
811-
return f"Connection Error: {exception.args[0]}"
738+
pass
812739

813740
def on_connect(self):
814741
"Initialize the connection, authenticate and select a database"
@@ -992,6 +919,101 @@ def pack_commands(self, commands):
992919
return output
993920

994921

922+
class Connection(AbstractConnection):
923+
"Manages TCP communication to and from a Redis server"
924+
925+
def __init__(
926+
self,
927+
host="localhost",
928+
port=6379,
929+
socket_timeout=None,
930+
socket_connect_timeout=None,
931+
socket_keepalive=False,
932+
socket_keepalive_options=None,
933+
socket_type=0,
934+
**kwargs,
935+
):
936+
self.host = host
937+
self.port = int(port)
938+
self.socket_timeout = socket_timeout
939+
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
940+
self.socket_keepalive = socket_keepalive
941+
self.socket_keepalive_options = socket_keepalive_options or {}
942+
self.socket_type = socket_type
943+
super().__init__(**kwargs)
944+
945+
def repr_pieces(self):
946+
pieces = [("host", self.host), ("port", self.port), ("db", self.db)]
947+
if self.client_name:
948+
pieces.append(("client_name", self.client_name))
949+
return pieces
950+
951+
def _connect(self):
952+
"Create a TCP socket connection"
953+
# we want to mimic what socket.create_connection does to support
954+
# ipv4/ipv6, but we want to set options prior to calling
955+
# socket.connect()
956+
err = None
957+
for res in socket.getaddrinfo(
958+
self.host, self.port, self.socket_type, socket.SOCK_STREAM
959+
):
960+
family, socktype, proto, canonname, socket_address = res
961+
sock = None
962+
try:
963+
sock = socket.socket(family, socktype, proto)
964+
# TCP_NODELAY
965+
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
966+
967+
# TCP_KEEPALIVE
968+
if self.socket_keepalive:
969+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
970+
for k, v in self.socket_keepalive_options.items():
971+
sock.setsockopt(socket.IPPROTO_TCP, k, v)
972+
973+
# set the socket_connect_timeout before we connect
974+
sock.settimeout(self.socket_connect_timeout)
975+
976+
# connect
977+
sock.connect(socket_address)
978+
979+
# set the socket_timeout now that we're connected
980+
sock.settimeout(self.socket_timeout)
981+
return sock
982+
983+
except OSError as _:
984+
err = _
985+
if sock is not None:
986+
sock.close()
987+
988+
if err is not None:
989+
raise err
990+
raise OSError("socket.getaddrinfo returned an empty list")
991+
992+
def _host_error(self):
993+
return f"{self.host}:{self.port}"
994+
995+
def _error_message(self, exception):
996+
# args for socket.error can either be (errno, "message")
997+
# or just "message"
998+
999+
host_error = self._host_error()
1000+
1001+
if len(exception.args) == 1:
1002+
try:
1003+
return f"Error connecting to {host_error}. \
1004+
{exception.args[0]}."
1005+
except AttributeError:
1006+
return f"Connection Error: {exception.args[0]}"
1007+
else:
1008+
try:
1009+
return (
1010+
f"Error {exception.args[0]} connecting to "
1011+
f"{host_error}. {exception.args[1]}."
1012+
)
1013+
except AttributeError:
1014+
return f"Connection Error: {exception.args[0]}"
1015+
1016+
9951017
class SSLConnection(Connection):
9961018
"""Manages SSL connections to and from the Redis server(s).
9971019
This class extends the Connection class, adding SSL functionality, and making
@@ -1037,8 +1059,6 @@ def __init__(
10371059
if not ssl_available:
10381060
raise RedisError("Python wasn't built with SSL support")
10391061

1040-
super().__init__(**kwargs)
1041-
10421062
self.keyfile = ssl_keyfile
10431063
self.certfile = ssl_certfile
10441064
if ssl_cert_reqs is None:
@@ -1064,6 +1084,7 @@ def __init__(
10641084
self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled
10651085
self.ssl_ocsp_context = ssl_ocsp_context
10661086
self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert
1087+
super().__init__(**kwargs)
10671088

10681089
def _connect(self):
10691090
"Wrap the socket with SSL support"
@@ -1133,77 +1154,12 @@ def _connect(self):
11331154
return sslsock
11341155

11351156

1136-
class UnixDomainSocketConnection(Connection):
1137-
def __init__(
1138-
self,
1139-
path="",
1140-
db=0,
1141-
username=None,
1142-
password=None,
1143-
socket_timeout=None,
1144-
encoding="utf-8",
1145-
encoding_errors="strict",
1146-
decode_responses=False,
1147-
retry_on_timeout=False,
1148-
retry_on_error=SENTINEL,
1149-
parser_class=DefaultParser,
1150-
socket_read_size=65536,
1151-
health_check_interval=0,
1152-
client_name=None,
1153-
retry=None,
1154-
redis_connect_func=None,
1155-
credential_provider: Optional[CredentialProvider] = None,
1156-
command_packer=None,
1157-
):
1158-
"""
1159-
Initialize a new UnixDomainSocketConnection.
1160-
To specify a retry policy for specific errors, first set
1161-
`retry_on_error` to a list of the error/s to retry on, then set
1162-
`retry` to a valid `Retry` object.
1163-
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
1164-
"""
1165-
if (username or password) and credential_provider is not None:
1166-
raise DataError(
1167-
"'username' and 'password' cannot be passed along with 'credential_"
1168-
"provider'. Please provide only one of the following arguments: \n"
1169-
"1. 'password' and (optional) 'username'\n"
1170-
"2. 'credential_provider'"
1171-
)
1172-
self.pid = os.getpid()
1157+
class UnixDomainSocketConnection(AbstractConnection):
1158+
"Manages UDS communication to and from a Redis server"
1159+
1160+
def __init__(self, path="", **kwargs):
11731161
self.path = path
1174-
self.db = db
1175-
self.client_name = client_name
1176-
self.credential_provider = credential_provider
1177-
self.password = password
1178-
self.username = username
1179-
self.socket_timeout = socket_timeout
1180-
self.retry_on_timeout = retry_on_timeout
1181-
if retry_on_error is SENTINEL:
1182-
retry_on_error = []
1183-
if retry_on_timeout:
1184-
# Add TimeoutError to the errors list to retry on
1185-
retry_on_error.append(TimeoutError)
1186-
self.retry_on_error = retry_on_error
1187-
if self.retry_on_error:
1188-
if retry is None:
1189-
self.retry = Retry(NoBackoff(), 1)
1190-
else:
1191-
# deep-copy the Retry object as it is mutable
1192-
self.retry = copy.deepcopy(retry)
1193-
# Update the retry's supported errors with the specified errors
1194-
self.retry.update_supported_errors(retry_on_error)
1195-
else:
1196-
self.retry = Retry(NoBackoff(), 0)
1197-
self.health_check_interval = health_check_interval
1198-
self.next_health_check = 0
1199-
self.redis_connect_func = redis_connect_func
1200-
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
1201-
self._sock = None
1202-
self._socket_read_size = socket_read_size
1203-
self.set_parser(parser_class)
1204-
self._connect_callbacks = []
1205-
self._buffer_cutoff = 6000
1206-
self._command_packer = self._construct_command_packer(command_packer)
1162+
super().__init__(**kwargs)
12071163

12081164
def repr_pieces(self):
12091165
pieces = [("path", self.path), ("db", self.db)]
@@ -1218,15 +1174,21 @@ def _connect(self):
12181174
sock.connect(self.path)
12191175
return sock
12201176

1177+
def _host_error(self):
1178+
return self.path
1179+
12211180
def _error_message(self, exception):
12221181
# args for socket.error can either be (errno, "message")
12231182
# or just "message"
1183+
host_error = self._host_error()
12241184
if len(exception.args) == 1:
1225-
return f"Error connecting to unix socket: {self.path}. {exception.args[0]}."
1185+
return (
1186+
f"Error connecting to unix socket: {host_error}. {exception.args[0]}."
1187+
)
12261188
else:
12271189
return (
12281190
f"Error {exception.args[0]} connecting to unix socket: "
1229-
f"{self.path}. {exception.args[1]}."
1191+
f"{host_error}. {exception.args[1]}."
12301192
)
12311193

12321194

0 commit comments

Comments
 (0)