Skip to content

Commit abe6137

Browse files
committed
Refactored CredentialProvider class
1 parent c37e0f1 commit abe6137

File tree

6 files changed

+107
-142
lines changed

6 files changed

+107
-142
lines changed

docs/examples/connection_examples.ipynb

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@
100100
{
101101
"cell_type": "markdown",
102102
"source": [
103-
"## Connecting to a redis instance with static credential provider"
103+
"## Connecting to a redis instance with username and password credential provider"
104104
],
105105
"metadata": {}
106106
},
@@ -111,7 +111,7 @@
111111
"source": [
112112
"import redis\n",
113113
"\n",
114-
"creds_provider = redis.StaticCredentialProvider(\"username\", \"password\")\n",
114+
"creds_provider = redis.UsernamePasswordCredentialProvider(\"username\", \"password\")\n",
115115
"user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n",
116116
"user_connection.ping()"
117117
],
@@ -131,11 +131,19 @@
131131
"execution_count": null,
132132
"outputs": [],
133133
"source": [
134+
"from typing import Tuple\n",
134135
"import redis\n",
135136
"\n",
136137
"creds_map = {\"user_1\": \"pass_1\",\n",
137138
" \"user_2\": \"pass_2\"}\n",
138139
"\n",
140+
"class UserMapCredentialProvider(redis.CredentialProvider):\n",
141+
" def __init__(self, username: str):\n",
142+
" self.username = username\n",
143+
"\n",
144+
" def get_credentials(self) -> Tuple[str, str]:\n",
145+
" return self.username, creds_map.get(self.username)\n",
146+
"\n",
139147
"# Create a default connection to set the ACL user\n",
140148
"default_connection = redis.Redis(host=\"localhost\", port=6379)\n",
141149
"default_connection.acl_setuser(\n",
@@ -146,11 +154,8 @@
146154
" commands=[\"+ping\", \"+command\", \"+info\", \"+select\", \"+flushdb\"],\n",
147155
")\n",
148156
"\n",
149-
"def creds_provider(self):\n",
150-
" return self.username, creds_map.get(self.username)\n",
151-
"\n",
152-
"# Create a CredentialProvider instance for user_1\n",
153-
"creds_provider = redis.CredentialProvider(username=\"user_1\", supplier=creds_provider)\n",
157+
"# Create a UserMapCredentialProvider instance for user_1\n",
158+
"creds_provider = UserMapCredentialProvider(\"user_1\")\n",
154159
"# Initiate user connection with the credential provider\n",
155160
"user_connection = redis.Redis(host=\"localhost\", port=6379,\n",
156161
" credential_provider=creds_provider)\n",
@@ -172,24 +177,27 @@
172177
"execution_count": null,
173178
"outputs": [],
174179
"source": [
180+
"from typing import Union\n",
175181
"import redis\n",
176182
"\n",
177-
"def call_external_supplier():\n",
178-
" # Call to an external credential supplier\n",
179-
" raise NotImplementedError\n",
183+
"class InitCredsSetCredentialProvider(redis.CredentialProvider):\n",
184+
" def __init__(self, username, password):\n",
185+
" self.username = username\n",
186+
" self.password = password\n",
187+
" self.call_supplier = False\n",
188+
"\n",
189+
" def call_external_supplier(self) -> Union[Tuple[str], Tuple[str, str]]:\n",
190+
" # Call to an external credential supplier\n",
191+
" raise NotImplementedError\n",
180192
"\n",
181-
"def creds_supplier(self):\n",
182-
" call_supplier = self.supplier_kwargs.get(\"call_supplier\", True)\n",
183-
" if call_supplier:\n",
184-
" return call_external_supplier()\n",
185-
" # Use the init set only for the first time\n",
186-
" self.kwargs.update({\"call_supplier\": True})\n",
187-
" return self.username, self.password\n",
193+
" def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n",
194+
" if self.call_supplier:\n",
195+
" return self.call_external_supplier()\n",
196+
" # Use the init set only for the first time\n",
197+
" self.call_supplier = True\n",
198+
" return self.username, self.password\n",
188199
"\n",
189-
"cred_provider = redis.CredentialProvider(username=\"init_user\",\n",
190-
" password=\"init_pass\",\n",
191-
" call_supplier=False,\n",
192-
" supplier=creds_supplier)"
200+
"cred_provider = InitCredsSetCredentialProvider(username=\"init_user\", password=\"init_pass\")"
193201
],
194202
"metadata": {}
195203
}

redis/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
SSLConnection,
1010
UnixDomainSocketConnection,
1111
)
12-
from redis.credentials import CredentialProvider, StaticCredentialProvider
12+
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
1313
from redis.exceptions import (
1414
AuthenticationError,
1515
AuthenticationWrongNumberOfArgsError,
@@ -78,7 +78,7 @@ def int_or_str(value):
7878
"SentinelManagedConnection",
7979
"SentinelManagedSSLConnection",
8080
"SSLConnection",
81-
"StaticCredentialProvider",
81+
"UsernamePasswordCredentialProvider",
8282
"StrictRedis",
8383
"TimeoutError",
8484
"UnixDomainSocketConnection",

redis/connection.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from urllib.parse import parse_qs, unquote, urlparse
1212

1313
from redis.backoff import NoBackoff
14-
from redis.credentials import CredentialProvider, StaticCredentialProvider
14+
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
1515
from redis.exceptions import (
1616
AuthenticationError,
1717
AuthenticationWrongNumberOfArgsError,
@@ -690,8 +690,7 @@ def on_connect(self):
690690
if self.credential_provider or (self.username or self.password):
691691
cred_provider = (
692692
self.credential_provider
693-
if self.credential_provider
694-
else StaticCredentialProvider(self.username, self.password)
693+
or UsernamePasswordCredentialProvider(self.username, self.password)
695694
)
696695
auth_args = cred_provider.get_credentials()
697696
# avoid checking health here -- PING will fail if we try
@@ -705,11 +704,7 @@ def on_connect(self):
705704
# server seems to be < 6.0.0 which expects a single password
706705
# arg. retry auth with just the password.
707706
# https://github.com/andymccurdy/redis-py/issues/1274
708-
self.send_command(
709-
"AUTH",
710-
auth_args[-1],
711-
check_health=False,
712-
)
707+
self.send_command("AUTH", auth_args[-1], check_health=False)
713708
auth_response = self.read_response()
714709

715710
if str_if_bytes(auth_response) != "OK":

redis/credentials.py

Lines changed: 15 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,26 @@
1-
from typing import Callable, Optional, Union
1+
from typing import Optional, Tuple, Union
22

33

44
class CredentialProvider:
5-
def __init__(
6-
self,
7-
username: Union[str, None] = "",
8-
password: Union[str, None] = "",
9-
supplier: Optional[Callable] = None,
10-
*supplier_args,
11-
**supplier_kwargs,
12-
):
13-
"""
14-
Initialize a new Credentials Provider.
15-
:param supplier: a supplier function that returns the username and password.
16-
def supplier(self, arg1, arg2, ...) -> (username, password)
17-
See examples/connection_examples.ipynb
18-
:param supplier_args: arguments to pass to the supplier function
19-
:param supplier_kwargs: keyword arguments to pass to the supplier function
20-
"""
21-
self._username_ = username if username is not None else ""
22-
self._password_ = password if password is not None else ""
23-
self.supplier = supplier
24-
self.supplier_args = supplier_args
25-
self.supplier_kwargs = supplier_kwargs
26-
27-
def get_credentials(self):
28-
if self.supplier is not None:
29-
self.username, self.password = self.supplier(
30-
self, *self.supplier_args, **self.supplier_kwargs
31-
)
32-
33-
return (
34-
(self._username_, self._password_)
35-
if self._username_
36-
else (self._password_,)
37-
)
38-
39-
@property
40-
def password(self):
41-
if self.supplier is not None and not self._password_:
42-
self.username, self.password = self.supplier(
43-
self, *self.supplier_args, **self.supplier_kwargs
44-
)
45-
return self._password_
46-
47-
@password.setter
48-
def password(self, value):
49-
self._password_ = value if value is not None else ""
50-
51-
@property
52-
def username(self):
53-
if self.supplier is not None and not self._username_:
54-
self.username, self.password = self.supplier(
55-
self, *self.supplier_args, **self.supplier_kwargs
56-
)
57-
return self._username_
5+
"""
6+
Credentials Provider.
7+
"""
588

59-
@username.setter
60-
def username(self, value):
61-
self._username_ = value if value is not None else ""
9+
def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:
10+
raise NotImplementedError("get_credentials must be implemented")
6211

6312

64-
class StaticCredentialProvider(CredentialProvider):
13+
class UsernamePasswordCredentialProvider(CredentialProvider):
6514
"""
6615
Simple implementation of CredentialProvider that just wraps static
6716
username and password.
6817
"""
6918

70-
def __init__(
71-
self, username: Union[str, None] = "", password: Union[str, None] = ""
72-
):
73-
super().__init__(username=username, password=password)
19+
def __init__(self, username: Optional[str] = None, password: Optional[str] = None):
20+
self.username = username
21+
self.password = password
22+
23+
def get_credentials(self):
24+
if self.username:
25+
return self.username, self.password
26+
return (self.password,)

tests/test_commands.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import redis
1111
from redis import exceptions
1212
from redis.client import parse_info
13-
from redis.credentials import StaticCredentialProvider
1413

1514
from .conftest import (
1615
_get_client,
@@ -96,9 +95,7 @@ def teardown():
9695
# error when switching to the db 9 because we're not authenticated yet
9796
# setting the password on the connection itself triggers the
9897
# authentication in the connection's `on_connect` method
99-
r.connection.credential_provider = StaticCredentialProvider(
100-
password=temp_pass
101-
)
98+
r.connection.password = temp_pass
10299
except AttributeError:
103100
# connection field is not set in Redis Cluster, but that's ok
104101
# because the problem discussed above does not apply to Redis Cluster

0 commit comments

Comments
 (0)