Skip to content

Commit 601c1aa

Browse files
authored
(tests): Added testing for auth via DefaultAzureCredential (#3544)
* (tests): Added testing for auth via DefaultAzureCredential * Added testing for async * Remove unused import
1 parent b8ba391 commit 601c1aa

File tree

5 files changed

+59
-14
lines changed

5 files changed

+59
-14
lines changed

dev_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ ujson>=4.2.0
1313
uvloop
1414
vulture>=2.3.0
1515
numpy>=1.24.0
16-
redis-entraid==0.3.0b1
16+
redis-entraid==0.4.0b2

tests/entraid_utils.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,16 @@
1919
ServicePrincipalIdentityProviderConfig,
2020
_create_provider_from_managed_identity,
2121
_create_provider_from_service_principal,
22+
DefaultAzureCredentialIdentityProviderConfig,
23+
_create_provider_from_default_azure_credential,
2224
)
2325
from tests.conftest import mock_identity_provider
2426

2527

2628
class AuthType(Enum):
2729
MANAGED_IDENTITY = "managed_identity"
2830
SERVICE_PRINCIPAL = "service_principal"
31+
DEFAULT_AZURE_CREDENTIAL = "default_azure_credential"
2932

3033

3134
def identity_provider(request) -> IdentityProviderInterface:
@@ -37,18 +40,25 @@ def identity_provider(request) -> IdentityProviderInterface:
3740
if request.param.get("mock_idp", None) is not None:
3841
return mock_identity_provider()
3942

40-
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
43+
auth_type = kwargs.get("auth_type", AuthType.SERVICE_PRINCIPAL)
4144
config = get_identity_provider_config(request=request)
4245

43-
if auth_type == "MANAGED_IDENTITY":
46+
if auth_type == AuthType.MANAGED_IDENTITY:
4447
return _create_provider_from_managed_identity(config)
4548

49+
if auth_type == AuthType.DEFAULT_AZURE_CREDENTIAL:
50+
return _create_provider_from_default_azure_credential(config)
51+
4652
return _create_provider_from_service_principal(config)
4753

4854

4955
def get_identity_provider_config(
5056
request,
51-
) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
57+
) -> Union[
58+
ManagedIdentityProviderConfig,
59+
ServicePrincipalIdentityProviderConfig,
60+
DefaultAzureCredentialIdentityProviderConfig,
61+
]:
5262
if hasattr(request, "param"):
5363
kwargs = request.param.get("idp_kwargs", {})
5464
else:
@@ -59,6 +69,9 @@ def get_identity_provider_config(
5969
if auth_type == AuthType.MANAGED_IDENTITY:
6070
return _get_managed_identity_provider_config(request)
6171

72+
if auth_type == AuthType.DEFAULT_AZURE_CREDENTIAL:
73+
return _get_default_azure_credential_provider_config(request)
74+
6275
return _get_service_principal_provider_config(request)
6376

6477

@@ -114,6 +127,26 @@ def _get_service_principal_provider_config(
114127
)
115128

116129

130+
def _get_default_azure_credential_provider_config(
131+
request,
132+
) -> DefaultAzureCredentialIdentityProviderConfig:
133+
scopes = os.getenv("AZURE_REDIS_SCOPES", ())
134+
135+
if hasattr(request, "param"):
136+
kwargs = request.param.get("idp_kwargs", {})
137+
token_kwargs = request.param.get("token_kwargs", {})
138+
else:
139+
kwargs = {}
140+
token_kwargs = {}
141+
142+
if isinstance(scopes, str):
143+
scopes = scopes.split(",")
144+
145+
return DefaultAzureCredentialIdentityProviderConfig(
146+
scopes=scopes, app_kwargs=kwargs, token_kwargs=token_kwargs
147+
)
148+
149+
117150
def get_entra_id_credentials_provider(request, cred_provider_kwargs):
118151
idp = identity_provider(request)
119152
expiration_refresh_ratio = cred_provider_kwargs.get(

tests/test_asyncio/conftest.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import random
22
from contextlib import asynccontextmanager as _asynccontextmanager
3-
from enum import Enum
43
from typing import Union
54

65
import pytest
@@ -18,11 +17,6 @@
1817
from .compat import mock
1918

2019

21-
class AuthType(Enum):
22-
MANAGED_IDENTITY = "managed_identity"
23-
SERVICE_PRINCIPAL = "service_principal"
24-
25-
2620
async def _get_info(redis_url):
2721
client = redis.Redis.from_url(redis_url)
2822
info = await client.info()

tests/test_asyncio/test_credentials.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from redis.exceptions import ConnectionError
1919
from redis.utils import str_if_bytes
2020
from tests.conftest import get_endpoint, skip_if_redis_enterprise
21+
from tests.entraid_utils import AuthType
2122
from tests.test_asyncio.conftest import get_credential_provider
2223

2324
try:
@@ -616,8 +617,12 @@ class TestEntraIdCredentialsProvider:
616617
"cred_provider_class": EntraIdCredentialsProvider,
617618
"cred_provider_kwargs": {"block_for_initial": True},
618619
},
620+
{
621+
"cred_provider_class": EntraIdCredentialsProvider,
622+
"idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL},
623+
},
619624
],
620-
ids=["blocked", "non-blocked"],
625+
ids=["blocked", "non-blocked", "DefaultAzureCredential"],
621626
indirect=True,
622627
)
623628
@pytest.mark.asyncio
@@ -692,8 +697,12 @@ class TestClusterEntraIdCredentialsProvider:
692697
"cred_provider_class": EntraIdCredentialsProvider,
693698
"cred_provider_kwargs": {"block_for_initial": True},
694699
},
700+
{
701+
"cred_provider_class": EntraIdCredentialsProvider,
702+
"idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL},
703+
},
695704
],
696-
ids=["blocked", "non-blocked"],
705+
ids=["blocked", "non-blocked", "DefaultAzureCredential"],
697706
indirect=True,
698707
)
699708
@pytest.mark.asyncio

tests/test_credentials.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
get_endpoint,
2323
skip_if_redis_enterprise,
2424
)
25+
from tests.entraid_utils import AuthType
2526

2627
try:
2728
from redis_entraid.cred_provider import EntraIdCredentialsProvider
@@ -585,8 +586,12 @@ class TestEntraIdCredentialsProvider:
585586
"cred_provider_class": EntraIdCredentialsProvider,
586587
"single_connection_client": True,
587588
},
589+
{
590+
"cred_provider_class": EntraIdCredentialsProvider,
591+
"idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL},
592+
},
588593
],
589-
ids=["pool", "single"],
594+
ids=["pool", "single", "DefaultAzureCredential"],
590595
indirect=True,
591596
)
592597
@pytest.mark.onlynoncluster
@@ -656,8 +661,12 @@ class TestClusterEntraIdCredentialsProvider:
656661
"cred_provider_class": EntraIdCredentialsProvider,
657662
"single_connection_client": True,
658663
},
664+
{
665+
"cred_provider_class": EntraIdCredentialsProvider,
666+
"idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL},
667+
},
659668
],
660-
ids=["pool", "single"],
669+
ids=["pool", "single", "DefaultAzureCredential"],
661670
indirect=True,
662671
)
663672
@pytest.mark.onlycluster

0 commit comments

Comments
 (0)