Skip to content

Commit ef24d8d

Browse files
committed
Moved CredentialsProvider to a separate file, added type hints
1 parent 7b958e0 commit ef24d8d

File tree

5 files changed

+136
-131
lines changed

5 files changed

+136
-131
lines changed

redis/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
BlockingConnectionPool,
77
Connection,
88
ConnectionPool,
9-
CredentialsProvider,
109
SSLConnection,
1110
UnixDomainSocketConnection,
1211
)
12+
from redis.credentials import CredentialsProvider
1313
from redis.exceptions import (
1414
AuthenticationError,
1515
AuthenticationWrongNumberOfArgsError,

redis/connection.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from packaging.version import Version
1414

1515
from redis.backoff import NoBackoff
16+
from redis.credentials import CredentialsProvider
1617
from redis.exceptions import (
1718
AuthenticationError,
1819
AuthenticationWrongNumberOfArgsError,
@@ -500,42 +501,6 @@ def read_response(self, disable_decoding=False):
500501
DefaultParser = PythonParser
501502

502503

503-
class CredentialsProvider:
504-
def __init__(self, username="", password="", supplier=None, *args, **kwargs):
505-
"""
506-
Initialize a new Credentials Provider.
507-
:param supplier: a supplier function that returns the username and password.
508-
def supplier(arg1, arg2, ...) -> (username, password)
509-
For examples see examples/connection_examples.ipynb
510-
:param args: arguments to pass to the supplier function
511-
:param kwargs: keyword arguments to pass to the supplier function
512-
"""
513-
self.username = username
514-
self.password = password
515-
self.supplier = supplier
516-
self.args = args
517-
self.kwargs = kwargs
518-
519-
def get_credentials(self):
520-
if self.supplier:
521-
self.username, self.password = self.supplier(*self.args, **self.kwargs)
522-
if self.username:
523-
auth_args = (self.username, self.password or "")
524-
else:
525-
auth_args = (self.password,)
526-
return auth_args
527-
528-
def get_password(self, call_supplier=True):
529-
if call_supplier and self.supplier:
530-
self.username, self.password = self.supplier(*self.args, **self.kwargs)
531-
return self.password
532-
533-
def get_username(self, call_supplier=True):
534-
if call_supplier and self.supplier:
535-
self.username, self.password = self.supplier(*self.args, **self.kwargs)
536-
return self.username
537-
538-
539504
class Connection:
540505
"Manages TCP communication to and from a Redis server"
541506

redis/credentials.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
class CredentialsProvider:
2+
def __init__(
3+
self,
4+
username: str = "",
5+
password: str = "",
6+
supplier: callable = None,
7+
*args,
8+
**kwargs,
9+
):
10+
"""
11+
Initialize a new Credentials Provider.
12+
:param supplier: a supplier function that returns the username and password.
13+
def supplier(arg1, arg2, ...) -> (username, password)
14+
For examples see examples/connection_examples.ipynb
15+
:param args: arguments to pass to the supplier function
16+
:param kwargs: keyword arguments to pass to the supplier function
17+
"""
18+
self.username = username
19+
self.password = password
20+
self.supplier = supplier
21+
self.args = args
22+
self.kwargs = kwargs
23+
24+
def get_credentials(self):
25+
if self.supplier:
26+
self.username, self.password = self.supplier(*self.args, **self.kwargs)
27+
if self.username:
28+
auth_args = (self.username, self.password or "")
29+
else:
30+
auth_args = (self.password,)
31+
return auth_args
32+
33+
def get_password(self, call_supplier: bool = True):
34+
if call_supplier and self.supplier:
35+
self.username, self.password = self.supplier(*self.args, **self.kwargs)
36+
return self.password
37+
38+
def get_username(self, call_supplier: bool = True):
39+
if call_supplier and self.supplier:
40+
self.username, self.password = self.supplier(*self.args, **self.kwargs)
41+
return self.username

tests/test_connection.py

Lines changed: 3 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,17 @@
1-
import random
21
import socket
3-
import string
42
import types
53
from unittest import mock
64
from unittest.mock import patch
75

86
import pytest
97

10-
import redis
118
from redis.backoff import NoBackoff
12-
from redis.connection import Connection, CredentialsProvider
13-
from redis.exceptions import (
14-
ConnectionError,
15-
InvalidResponse,
16-
ResponseError,
17-
TimeoutError,
18-
)
9+
from redis.connection import Connection
10+
from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
1911
from redis.retry import Retry
2012
from redis.utils import HIREDIS_AVAILABLE
2113

22-
from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt
14+
from .conftest import skip_if_server_version_lt
2315

2416

2517
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
@@ -130,86 +122,3 @@ def test_connect_timeout_error_without_retry(self):
130122
assert conn._connect.call_count == 1
131123
assert str(e.value) == "Timeout connecting to server"
132124
self.clear(conn)
133-
134-
135-
class TestCredentialsProvider:
136-
@skip_if_redis_enterprise()
137-
def test_credentials_provider_without_supplier(self, r, request):
138-
# first, test for default user (`username` is supposed to be optional)
139-
default_username = "default"
140-
temp_pass = "temp_pass"
141-
creds_provider = CredentialsProvider(default_username, temp_pass)
142-
r.config_set("requirepass", temp_pass)
143-
creds = creds_provider.get_credentials()
144-
assert r.auth(creds[1], creds[0]) is True
145-
assert r.auth(creds_provider.get_password()) is True
146-
147-
# test for other users
148-
username = "redis-py-auth"
149-
password = "strong_password"
150-
151-
def teardown():
152-
try:
153-
r.auth(temp_pass)
154-
except ResponseError:
155-
r.auth("default", "")
156-
r.config_set("requirepass", "")
157-
r.acl_deluser(username)
158-
159-
request.addfinalizer(teardown)
160-
161-
assert r.acl_setuser(
162-
username,
163-
enabled=True,
164-
passwords=["+" + password],
165-
keys="~*",
166-
commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"],
167-
)
168-
169-
creds_provider2 = CredentialsProvider(username, password)
170-
r2 = _get_client(
171-
redis.Redis, request, flushdb=False, credentials_provider=creds_provider2
172-
)
173-
174-
assert r2.ping() is True
175-
176-
@skip_if_redis_enterprise()
177-
def test_credentials_provider_with_supplier(self, r, request):
178-
import functools
179-
180-
@functools.lru_cache(maxsize=10)
181-
def auth_supplier(user, endpoint):
182-
def get_random_string(length):
183-
letters = string.ascii_lowercase
184-
result_str = "".join(random.choice(letters) for i in range(length))
185-
return result_str
186-
187-
auth_token = get_random_string(5) + user + "_" + endpoint
188-
return user, auth_token
189-
190-
username = "redis-py-auth"
191-
creds_provider = CredentialsProvider(
192-
supplier=auth_supplier,
193-
user=username,
194-
endpoint="localhost",
195-
)
196-
password = creds_provider.get_password()
197-
198-
assert r.acl_setuser(
199-
username,
200-
enabled=True,
201-
passwords=["+" + password],
202-
keys="~*",
203-
commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"],
204-
)
205-
206-
def teardown():
207-
r.acl_deluser(username)
208-
209-
request.addfinalizer(teardown)
210-
211-
r2 = _get_client(
212-
redis.Redis, request, flushdb=False, credentials_provider=creds_provider
213-
)
214-
215-
assert r2.ping() is True

tests/test_credentials.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import random
2+
import string
3+
4+
import redis
5+
from redis import ResponseError
6+
from redis.credentials import CredentialsProvider
7+
from tests.conftest import _get_client, skip_if_redis_enterprise
8+
9+
10+
class TestCredentialsProvider:
11+
@skip_if_redis_enterprise()
12+
def test_credentials_provider_without_supplier(self, r, request):
13+
# first, test for default user (`username` is supposed to be optional)
14+
default_username = "default"
15+
temp_pass = "temp_pass"
16+
creds_provider = CredentialsProvider(default_username, temp_pass)
17+
r.config_set("requirepass", temp_pass)
18+
creds = creds_provider.get_credentials()
19+
assert r.auth(creds[1], creds[0]) is True
20+
assert r.auth(creds_provider.get_password()) is True
21+
22+
# test for other users
23+
username = "redis-py-auth"
24+
password = "strong_password"
25+
26+
def teardown():
27+
try:
28+
r.auth(temp_pass)
29+
except ResponseError:
30+
r.auth("default", "")
31+
r.config_set("requirepass", "")
32+
r.acl_deluser(username)
33+
34+
request.addfinalizer(teardown)
35+
36+
assert r.acl_setuser(
37+
username,
38+
enabled=True,
39+
passwords=["+" + password],
40+
keys="~*",
41+
commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"],
42+
)
43+
44+
creds_provider2 = CredentialsProvider(username, password)
45+
r2 = _get_client(
46+
redis.Redis, request, flushdb=False, credentials_provider=creds_provider2
47+
)
48+
49+
assert r2.ping() is True
50+
51+
@skip_if_redis_enterprise()
52+
def test_credentials_provider_with_supplier(self, r, request):
53+
import functools
54+
55+
@functools.lru_cache(maxsize=10)
56+
def auth_supplier(user, endpoint):
57+
def get_random_string(length):
58+
letters = string.ascii_lowercase
59+
result_str = "".join(random.choice(letters) for i in range(length))
60+
return result_str
61+
62+
auth_token = get_random_string(5) + user + "_" + endpoint
63+
return user, auth_token
64+
65+
username = "redis-py-auth"
66+
creds_provider = CredentialsProvider(
67+
supplier=auth_supplier,
68+
user=username,
69+
endpoint="localhost",
70+
)
71+
password = creds_provider.get_password()
72+
73+
assert r.acl_setuser(
74+
username,
75+
enabled=True,
76+
passwords=["+" + password],
77+
keys="~*",
78+
commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"],
79+
)
80+
81+
def teardown():
82+
r.acl_deluser(username)
83+
84+
request.addfinalizer(teardown)
85+
86+
r2 = _get_client(
87+
redis.Redis, request, flushdb=False, credentials_provider=creds_provider
88+
)
89+
90+
assert r2.ping() is True

0 commit comments

Comments
 (0)