Skip to content

Commit 7743c4a

Browse files
committed
Update tests
1 parent 240205f commit 7743c4a

File tree

4 files changed

+62
-57
lines changed

4 files changed

+62
-57
lines changed

tests/conftest.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,42 @@
11
import httpx
22
import pytest
3-
from azure.identity.aio import DefaultAzureCredential
4-
from kiota_abstractions.authentication import AccessTokenProvider
3+
from kiota_abstractions.authentication import AnonymousAuthenticationProvider
54
from kiota_authentication_azure.azure_identity_access_token_provider import (
65
AzureIdentityAccessTokenProvider,
76
)
87

98
from msgraph.core import APIVersion, NationalClouds
10-
from msgraph.core.middleware import GraphRequest, GraphRequestContext
9+
from msgraph.core.graph_client_factory import GraphClientFactory
10+
from msgraph.core.middleware import GraphRequestContext
1111

1212
BASE_URL = NationalClouds.Global + '/' + APIVersion.v1
1313

1414

15-
class MockAccessTokenProvider(AccessTokenProvider):
15+
class MockAuthenticationProvider(AnonymousAuthenticationProvider):
1616

17-
async def get_authorization_token(self, request: GraphRequest) -> str:
17+
async def get_authorization_token(self, request: httpx.Request) -> str:
1818
"""Returns a string representing a dummy token
1919
Args:
2020
request (GraphRequest): Graph request object
2121
"""
22-
return "Sample token"
23-
24-
def get_allowed_hosts_validator(self) -> None:
25-
pass
22+
request.headers['Authorization'] = 'Sample token'
23+
return
2624

2725

2826
@pytest.fixture
29-
def mock_token_provider():
30-
return MockAccessTokenProvider()
27+
def mock_auth_provider():
28+
return MockAuthenticationProvider()
3129

3230

3331
@pytest.fixture
3432
def mock_transport():
35-
return httpx.AsyncClient()._transport
33+
client = GraphClientFactory.create_with_default_middleware()
34+
return client._transport
3635

3736

3837
@pytest.fixture
3938
def mock_request():
4039
req = httpx.Request('GET', "https://example.org")
41-
req.context = GraphRequestContext({}, req.headers)
4240
return req
4341

4442

tests/unit/test_graph_client_factory.py

Lines changed: 36 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,62 +4,59 @@
44
# ------------------------------------
55
import httpx
66
import pytest
7-
from kiota_http.middleware import AsyncKiotaTransport
7+
from kiota_http.middleware import (
8+
AsyncKiotaTransport,
9+
MiddlewarePipeline,
10+
ParametersNameDecodingHandler,
11+
)
812

913
from msgraph.core import APIVersion, GraphClientFactory, NationalClouds
10-
from msgraph.core._constants import DEFAULT_CONNECTION_TIMEOUT, DEFAULT_REQUEST_TIMEOUT
11-
from msgraph.core.middleware import GraphAuthorizationHandler
12-
from msgraph.core.middleware.middleware import GraphMiddlewarePipeline
13-
from msgraph.core.middleware.redirect import GraphRedirectHandler
14-
from msgraph.core.middleware.retry import GraphRetryHandler
1514
from msgraph.core.middleware.telemetry import GraphTelemetryHandler
1615

1716

18-
def test_create_with_default_middleware_no_auth_provider():
19-
"""Test creation of GraphClient without a token provider does not
20-
add the Authorization middleware"""
21-
client = GraphClientFactory.create_with_default_middleware(client=httpx.AsyncClient())
17+
def test_create_with_default_middleware():
18+
"""Test creation of GraphClient using default middleware"""
19+
client = GraphClientFactory.create_with_default_middleware()
2220

2321
assert isinstance(client, httpx.AsyncClient)
2422
assert isinstance(client._transport, AsyncKiotaTransport)
25-
pipeline = client._transport.middleware
26-
assert isinstance(pipeline, GraphMiddlewarePipeline)
27-
assert not isinstance(pipeline._first_middleware, GraphAuthorizationHandler)
23+
pipeline = client._transport.pipeline
24+
assert isinstance(pipeline, MiddlewarePipeline)
25+
assert isinstance(pipeline._first_middleware, ParametersNameDecodingHandler)
26+
assert isinstance(pipeline._current_middleware, GraphTelemetryHandler)
2827

2928

30-
def test_create_with_default_middleware(mock_token_provider):
31-
"""Test creation of GraphClient using default middleware and passing a token
32-
provider adds Authorization middleware"""
33-
client = GraphClientFactory.create_with_default_middleware(
34-
client=httpx.AsyncClient(), token_provider=mock_token_provider
35-
)
36-
37-
assert isinstance(client, httpx.AsyncClient)
38-
assert isinstance(client._transport, AsyncKiotaTransport)
39-
pipeline = client._transport.middleware
40-
assert isinstance(pipeline, GraphMiddlewarePipeline)
41-
assert isinstance(pipeline._first_middleware, GraphAuthorizationHandler)
42-
43-
44-
def test_create_with_custom_middleware(mock_token_provider):
29+
def test_create_with_custom_middleware():
4530
"""Test creation of HTTP Clients with custom middleware"""
4631
middleware = [
4732
GraphTelemetryHandler(),
4833
]
49-
client = GraphClientFactory.create_with_custom_middleware(
50-
client=httpx.AsyncClient(), middleware=middleware
51-
)
34+
client = GraphClientFactory.create_with_custom_middleware(middleware=middleware)
5235

5336
assert isinstance(client, httpx.AsyncClient)
5437
assert isinstance(client._transport, AsyncKiotaTransport)
55-
pipeline = client._transport.middleware
38+
pipeline = client._transport.pipeline
5639
assert isinstance(pipeline._first_middleware, GraphTelemetryHandler)
5740

5841

59-
def test_get_common_middleware():
60-
middleware = GraphClientFactory._get_common_middleware()
61-
62-
assert len(middleware) == 3
63-
assert isinstance(middleware[0], GraphRedirectHandler)
64-
assert isinstance(middleware[1], GraphRetryHandler)
65-
assert isinstance(middleware[2], GraphTelemetryHandler)
42+
def test_graph_client_factory_with_custom_configuration():
43+
"""
44+
Test creating a graph client with custom url overrides the default
45+
"""
46+
graph_client = GraphClientFactory.create_with_default_middleware(
47+
api_version=APIVersion.beta, host=NationalClouds.China
48+
)
49+
assert isinstance(graph_client, httpx.AsyncClient)
50+
assert str(graph_client.base_url) == f'{NationalClouds.China}/{APIVersion.beta}/'
51+
52+
53+
def test_get_base_url():
54+
"""
55+
Test base url is formed by combining the national cloud endpoint with
56+
Api version
57+
"""
58+
url = GraphClientFactory._get_base_url(
59+
host=NationalClouds.Germany,
60+
api_version=APIVersion.beta,
61+
)
62+
assert url == f'{NationalClouds.Germany}/{APIVersion.beta}'

tests/unit/test_graph_request_adapter.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@
77
)
88

99
from msgraph.core.graph_request_adapter import GraphRequestAdapter
10-
from tests.conftest import mock_token_provider
1110

1211

13-
def test_create_graph_request_adapter(mock_token_provider):
14-
request_adapter = GraphRequestAdapter(mock_token_provider)
15-
assert request_adapter._authentication_provider is mock_token_provider
12+
def test_create_graph_request_adapter(mock_auth_provider):
13+
request_adapter = GraphRequestAdapter(mock_auth_provider)
14+
assert request_adapter._authentication_provider is mock_auth_provider
1615
assert isinstance(request_adapter._parse_node_factory, ParseNodeFactoryRegistry)
1716
assert isinstance(
1817
request_adapter._serialization_writer_factory, SerializationWriterFactoryRegistry

tests/unit/test_graph_telemetry_handler.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,23 @@
99
import httpx
1010
import pytest
1111

12-
from msgraph.core import SDK_VERSION, APIVersion, GraphClient, NationalClouds
12+
from msgraph.core import SDK_VERSION, APIVersion, NationalClouds
13+
from msgraph.core._enums import FeatureUsageFlag
1314
from msgraph.core.middleware import GraphRequestContext, GraphTelemetryHandler
1415

1516
BASE_URL = NationalClouds.Global + '/' + APIVersion.v1
1617

1718

19+
def test_set_request_context_and_feature_usage(mock_request, mock_transport):
20+
telemetry_handler = GraphTelemetryHandler()
21+
telemetry_handler.set_request_context_and_feature_usage(mock_request, mock_transport)
22+
23+
assert hasattr(mock_request, 'context')
24+
assert mock_request.context.feature_usage == hex(
25+
FeatureUsageFlag.RETRY_HANDLER_ENABLED | FeatureUsageFlag.REDIRECT_HANDLER_ENABLED
26+
)
27+
28+
1829
def test_is_graph_url(mock_graph_request):
1930
"""
2031
Test method that checks whether a request url is a graph endpoint

0 commit comments

Comments
 (0)