Skip to content

Commit 3186a64

Browse files
committed
unit tests for initial remote sampler and lint
1 parent 2a68194 commit 3186a64

File tree

6 files changed

+185
-24
lines changed

6 files changed

+185
-24
lines changed

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
DEFAULT_TARGET_POLLING_INTERVAL = 10
2222
DEFAULT_SAMPLING_PROXY_ENDPOINT = "http://127.0.0.1:2000"
2323

24+
2425
class AwsXRayRemoteSampler(Sampler):
2526
"""
2627
Remote Sampler for OpenTelemetry that gets sampling configurations from AWS X-Ray
@@ -32,11 +33,17 @@ class AwsXRayRemoteSampler(Sampler):
3233
log_level: custom log level configuration for remote sampler (Optional)
3334
"""
3435

35-
__resource : Resource
36-
__polling_interval : int
37-
__xray_client : AwsXRaySamplingClient
36+
__resource: Resource
37+
__polling_interval: int
38+
__xray_client: AwsXRaySamplingClient
3839

39-
def __init__(self, resource=None, endpoint=DEFAULT_SAMPLING_PROXY_ENDPOINT, polling_interval=DEFAULT_RULES_POLLING_INTERVAL, log_level = None):
40+
def __init__(
41+
self,
42+
resource=None,
43+
endpoint=DEFAULT_SAMPLING_PROXY_ENDPOINT,
44+
polling_interval=DEFAULT_RULES_POLLING_INTERVAL,
45+
log_level=None,
46+
):
4047
# Override default log level
4148
if log_level is not None:
4249
_logger.setLevel(log_level)
@@ -72,8 +79,7 @@ def __get_and_update_sampling_rules(self):
7279

7380
def __start_sampling_rule_poller(self):
7481
self.__get_and_update_sampling_rules()
75-
# Schedule the next sampling rule poll
82+
# Schedule the next sampling rule poll
7683
self._timer = Timer(self.__polling_interval, self.__start_sampling_rule_poller)
7784
self._timer.daemon = True
7885
self._timer.start()
79-

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_sampling_client.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
_logger = getLogger(__name__)
1111

12+
1213
class AwsXRaySamplingClient:
1314
def __init__(self, endpoint=None, log_level=None):
1415
# Override default log level
@@ -21,15 +22,17 @@ def __init__(self, endpoint=None, log_level=None):
2122

2223
def get_sampling_rules(self):
2324
sampling_rules = []
24-
headers = {'content-type': 'application/json'}
25+
headers = {"content-type": "application/json"}
2526

2627
try:
2728
r = requests.post(url=self.__getSamplingRulesEndpoint, headers=headers)
2829
if r is None:
2930
raise Exception("GetSamplingRules response is None")
3031
sampling_rules_response = r.json()
3132
if "SamplingRuleRecords" not in sampling_rules_response:
32-
raise Exception(f"SamplingRuleRecords is missing in getSamplingRules response:{sampling_rules_response}")
33+
raise Exception(
34+
f"SamplingRuleRecords is missing in getSamplingRules response:{sampling_rules_response}"
35+
)
3336

3437
sampling_rules_records = sampling_rules_response["SamplingRuleRecords"]
3538
for record in sampling_rules_records:
@@ -42,4 +45,4 @@ def get_sampling_rules(self):
4245
except Exception as ex:
4346
_logger.exception(f"Exception occurred: {ex}")
4447

45-
return sampling_rules
48+
return sampling_rules

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/sampling_rule.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33
class SamplingRule:
44
def __init__(
5-
self,
5+
self,
66
Attributes={},
77
FixedRate=None,
88
HTTPMethod=None,
@@ -15,18 +15,18 @@ def __init__(
1515
ServiceName=None,
1616
ServiceType=None,
1717
URLPath=None,
18-
Version=None
18+
Version=None,
1919
):
20-
self.Attributes=Attributes
21-
self.FixedRate=FixedRate
22-
self.HTTPMethod=HTTPMethod
23-
self.Host=Host
24-
self.Priority=Priority
25-
self.ReservoirSize=ReservoirSize
26-
self.ResourceARN=ResourceARN
27-
self.RuleARN=RuleARN
28-
self.RuleName=RuleName
29-
self.ServiceName=ServiceName
30-
self.ServiceType=ServiceType
31-
self.URLPath=URLPath
32-
self.Version=Version
20+
self.Attributes = Attributes
21+
self.FixedRate = FixedRate
22+
self.HTTPMethod = HTTPMethod
23+
self.Host = Host
24+
self.Priority = Priority
25+
self.ReservoirSize = ReservoirSize
26+
self.ResourceARN = ResourceARN
27+
self.RuleARN = RuleARN
28+
self.RuleName = RuleName
29+
self.ServiceName = ServiceName
30+
self.ServiceType = ServiceType
31+
self.URLPath = URLPath
32+
self.Version = Version
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
{
2+
"NextToken": null,
3+
"SamplingRuleRecords": [
4+
{
5+
"CreatedAt": 1.67799933E9,
6+
"ModifiedAt": 1.67799933E9,
7+
"SamplingRule": {
8+
"Attributes": {
9+
"foo": "bar",
10+
"doo": "baz"
11+
},
12+
"FixedRate": 0.05,
13+
"HTTPMethod": "*",
14+
"Host": "*",
15+
"Priority": 1000,
16+
"ReservoirSize": 10,
17+
"ResourceARN": "*",
18+
"RuleARN": "arn:aws:xray:us-west-2:123456789000:sampling-rule/Rule1",
19+
"RuleName": "Rule1",
20+
"ServiceName": "*",
21+
"ServiceType": "AWS::Foo::Bar",
22+
"URLPath": "*",
23+
"Version": 1
24+
}
25+
},
26+
{
27+
"CreatedAt": 0.0,
28+
"ModifiedAt": 1.611564245E9,
29+
"SamplingRule": {
30+
"Attributes": {},
31+
"FixedRate": 0.05,
32+
"HTTPMethod": "*",
33+
"Host": "*",
34+
"Priority": 10000,
35+
"ReservoirSize": 1,
36+
"ResourceARN": "*",
37+
"RuleARN": "arn:aws:xray:us-west-2:123456789000:sampling-rule/Default",
38+
"RuleName": "Default",
39+
"ServiceName": "*",
40+
"ServiceType": "*",
41+
"URLPath": "*",
42+
"Version": 1
43+
}
44+
},
45+
{
46+
"CreatedAt": 1.676038494E9,
47+
"ModifiedAt": 1.676038494E9,
48+
"SamplingRule": {
49+
"Attributes": {},
50+
"FixedRate": 0.2,
51+
"HTTPMethod": "GET",
52+
"Host": "*",
53+
"Priority": 1,
54+
"ReservoirSize": 10,
55+
"ResourceARN": "*",
56+
"RuleARN": "arn:aws:xray:us-west-2:123456789000:sampling-rule/Rule2",
57+
"RuleName": "Rule2",
58+
"ServiceName": "FooBar",
59+
"ServiceType": "*",
60+
"URLPath": "/foo/bar",
61+
"Version": 1
62+
}
63+
}
64+
]
65+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
from logging import DEBUG
4+
from unittest import TestCase
5+
6+
from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler
7+
from opentelemetry.sdk.resources import Resource
8+
9+
10+
class AwsXRayRemoteSamplerTest(TestCase):
11+
def test_create_remote_sampler_with_empty_resource(self):
12+
rs = AwsXRayRemoteSampler(resource=Resource.get_empty())
13+
self.assertIsNotNone(rs._timer)
14+
self.assertEqual(rs._AwsXRayRemoteSampler__polling_interval, 300)
15+
self.assertIsNotNone(rs._AwsXRayRemoteSampler__xray_client)
16+
self.assertIsNotNone(rs._AwsXRayRemoteSampler__resource)
17+
18+
def test_create_remote_sampler_with_populated_resource(self):
19+
rs = AwsXRayRemoteSampler(
20+
resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"})
21+
)
22+
self.assertIsNotNone(rs._timer)
23+
self.assertEqual(rs._AwsXRayRemoteSampler__polling_interval, 300)
24+
self.assertIsNotNone(rs._AwsXRayRemoteSampler__xray_client)
25+
self.assertIsNotNone(rs._AwsXRayRemoteSampler__resource)
26+
self.assertEqual(rs._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name")
27+
self.assertEqual(rs._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform")
28+
29+
def test_create_remote_sampler_with_all_fields_populated(self):
30+
rs = AwsXRayRemoteSampler(
31+
resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"}),
32+
endpoint="http://abc.com",
33+
polling_interval=120,
34+
log_level=DEBUG,
35+
)
36+
self.assertIsNotNone(rs._timer)
37+
self.assertEqual(rs._AwsXRayRemoteSampler__polling_interval, 120)
38+
self.assertIsNotNone(rs._AwsXRayRemoteSampler__xray_client)
39+
self.assertIsNotNone(rs._AwsXRayRemoteSampler__resource)
40+
self.assertEqual(
41+
rs._AwsXRayRemoteSampler__xray_client._AwsXRaySamplingClient__getSamplingRulesEndpoint,
42+
"http://abc.com/GetSamplingRules",
43+
) # "http://127.0.0.1:2000"
44+
self.assertEqual(rs._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name")
45+
self.assertEqual(rs._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform")
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
import json
4+
import os
5+
from logging import getLogger
6+
from unittest import TestCase
7+
from unittest.mock import patch
8+
9+
from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler
10+
from amazon.opentelemetry.distro.sampler.aws_xray_sampling_client import AwsXRaySamplingClient
11+
12+
SAMPLING_CLIENT_LOGGER_NAME = "amazon.opentelemetry.distro.sampler.aws_xray_sampling_client"
13+
_logger = getLogger(SAMPLING_CLIENT_LOGGER_NAME)
14+
15+
TEST_DIR = os.path.dirname(os.path.realpath(__file__))
16+
DATA_DIR = os.path.join(TEST_DIR, "data")
17+
18+
19+
class AwsXRaySamplingClientTest(TestCase):
20+
@patch("requests.post")
21+
def test_get_no_sampling_rules(self, mock_post=None, prom_rw=None):
22+
mock_post.return_value.configure_mock(**{"json.return_value": {"SamplingRuleRecords": []}})
23+
client = AwsXRaySamplingClient("http://127.0.0.1:2000")
24+
sampling_rules = client.get_sampling_rules()
25+
self.assertTrue(len(sampling_rules) == 0)
26+
27+
@patch("requests.post")
28+
def test_get_invalid_response(self, mock_post=None, prom_rw=None):
29+
mock_post.return_value.configure_mock(**{"json.return_value": {}})
30+
client = AwsXRaySamplingClient("http://127.0.0.1:2000")
31+
with self.assertLogs(_logger, level="ERROR") as cm:
32+
sampling_rules = client.get_sampling_rules()
33+
self.assertTrue(len(sampling_rules) == 0)
34+
35+
@patch("requests.post")
36+
def test_get_two_sampling_rules(self, mock_post=None, prom_rw=None):
37+
with open(f"{DATA_DIR}/get-sampling-rules-response-sample.json") as f:
38+
mock_post.return_value.configure_mock(**{"json.return_value": json.load(f)})
39+
f.close()
40+
client = AwsXRaySamplingClient("http://127.0.0.1:2000")
41+
sampling_rules = client.get_sampling_rules()
42+
self.assertTrue(len(sampling_rules) == 3)

0 commit comments

Comments
 (0)