Skip to content

Commit 51d76a0

Browse files
committed
add method return typing, sampling_rule defaults, update tests
1 parent 1b9e3b5 commit 51d76a0

File tree

4 files changed

+90
-22
lines changed

4 files changed

+90
-22
lines changed

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(self, endpoint=None, log_level=None):
2020
_logger.error("endpoint must be specified")
2121
self.__get_sampling_rules_endpoint = endpoint + "/GetSamplingRules"
2222

23-
def get_sampling_rules(self):
23+
def get_sampling_rules(self) -> [_SamplingRule]:
2424
sampling_rules = []
2525
headers = {"content-type": "application/json"}
2626

@@ -38,7 +38,10 @@ def get_sampling_rules(self):
3838

3939
sampling_rules_records = sampling_rules_response["SamplingRuleRecords"]
4040
for record in sampling_rules_records:
41-
sampling_rules.append(_SamplingRule(**record["SamplingRule"]))
41+
if "SamplingRule" not in record:
42+
_logger.error("SamplingRule is missing in SamplingRuleRecord")
43+
else:
44+
sampling_rules.append(_SamplingRule(**record["SamplingRule"]))
4245

4346
except requests.exceptions.RequestException as req_err:
4447
_logger.error("Request error occurred: %s", req_err)

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
class _SamplingRule:
88
def __init__(
99
self,
10-
Attributes=None,
10+
Attributes: dict = None,
1111
FixedRate=None,
1212
HTTPMethod=None,
1313
Host=None,
@@ -21,16 +21,16 @@ def __init__(
2121
URLPath=None,
2222
Version=None,
2323
):
24-
self.Attributes = Attributes
25-
self.FixedRate = FixedRate
26-
self.HTTPMethod = HTTPMethod
27-
self.Host = Host
28-
self.Priority = Priority
29-
self.ReservoirSize = ReservoirSize
30-
self.ResourceARN = ResourceARN
31-
self.RuleARN = RuleARN
32-
self.RuleName = RuleName
33-
self.ServiceName = ServiceName
34-
self.ServiceType = ServiceType
35-
self.URLPath = URLPath
36-
self.Version = Version
24+
self.Attributes = Attributes if Attributes is not None else {}
25+
self.FixedRate = FixedRate if FixedRate is not None else ""
26+
self.HTTPMethod = HTTPMethod if HTTPMethod is not None else ""
27+
self.Host = Host if Host is not None else ""
28+
self.Priority = Priority if Priority is not None else ""
29+
self.ReservoirSize = ReservoirSize if ReservoirSize is not None else ""
30+
self.ResourceARN = ResourceARN if ResourceARN is not None else ""
31+
self.RuleARN = RuleARN if RuleARN is not None else ""
32+
self.RuleName = RuleName if RuleName is not None else ""
33+
self.ServiceName = ServiceName if ServiceName is not None else ""
34+
self.ServiceType = ServiceType if ServiceType is not None else ""
35+
self.URLPath = URLPath if URLPath is not None else ""
36+
self.Version = Version if Version is not None else ""

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ def __init__(
5959
self.__resource = Resource.get_empty()
6060

6161
# Schedule the next rule poll now
62+
# Python Timers only run once, so they need to be recreated for every poll
6263
self._timer = Timer(0, self.__start_sampling_rule_poller)
63-
self._timer.daemon = True
64+
self._timer.daemon = True # Ensures that when the main thread exits, the Timer threads are killed
6465
self._timer.start()
6566

6667
# pylint: disable=no-self-use

aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,82 @@ def test_get_no_sampling_rules(self, mock_post=None):
2424
self.assertTrue(len(sampling_rules) == 0)
2525

2626
@patch("requests.post")
27-
def test_get_invalid_response(self, mock_post=None):
27+
def test_get_invalid_responses(self, mock_post=None):
2828
mock_post.return_value.configure_mock(**{"json.return_value": {}})
2929
client = _AwsXRaySamplingClient("http://127.0.0.1:2000")
3030
with self.assertLogs(_logger, level="ERROR"):
3131
sampling_rules = client.get_sampling_rules()
32-
self.assertTrue(len(sampling_rules) == 0)
32+
self.assertTrue(len(sampling_rules) == 0)
33+
34+
@patch("requests.post")
35+
def test_get_sampling_rule_missing_in_records(self, mock_post=None):
36+
mock_post.return_value.configure_mock(**{"json.return_value": {"SamplingRuleRecords": [{}]}})
37+
client = _AwsXRaySamplingClient("http://127.0.0.1:2000")
38+
with self.assertLogs(_logger, level="ERROR"):
39+
sampling_rules = client.get_sampling_rules()
40+
self.assertTrue(len(sampling_rules) == 0)
3341

3442
@patch("requests.post")
35-
def test_get_two_sampling_rules(self, mock_post=None):
43+
def test_default_values_used_when_missing_properties_in_sampling_rule(self, mock_post=None):
44+
mock_post.return_value.configure_mock(**{"json.return_value": {"SamplingRuleRecords": [{"SamplingRule": {}}]}})
45+
client = _AwsXRaySamplingClient("http://127.0.0.1:2000")
46+
sampling_rules = client.get_sampling_rules()
47+
self.assertTrue(len(sampling_rules) == 1)
48+
49+
sampling_rule = sampling_rules[0]
50+
self.assertEqual(sampling_rule.Attributes, {})
51+
self.assertEqual(sampling_rule.FixedRate, "")
52+
self.assertEqual(sampling_rule.HTTPMethod, "")
53+
self.assertEqual(sampling_rule.Host, "")
54+
self.assertEqual(sampling_rule.Priority, "")
55+
self.assertEqual(sampling_rule.ReservoirSize, "")
56+
self.assertEqual(sampling_rule.ResourceARN, "")
57+
self.assertEqual(sampling_rule.RuleARN, "")
58+
self.assertEqual(sampling_rule.RuleName, "")
59+
self.assertEqual(sampling_rule.ServiceName, "")
60+
self.assertEqual(sampling_rule.ServiceType, "")
61+
self.assertEqual(sampling_rule.URLPath, "")
62+
self.assertEqual(sampling_rule.Version, "")
63+
64+
@patch("requests.post")
65+
def test_get_three_sampling_rules(self, mock_post=None):
66+
sampling_records = []
3667
with open(f"{DATA_DIR}/get-sampling-rules-response-sample.json", encoding="UTF-8") as file:
37-
mock_post.return_value.configure_mock(**{"json.return_value": json.load(file)})
68+
sample_response = json.load(file)
69+
sampling_records = sample_response["SamplingRuleRecords"]
70+
mock_post.return_value.configure_mock(**{"json.return_value": sample_response})
3871
file.close()
3972
client = _AwsXRaySamplingClient("http://127.0.0.1:2000")
4073
sampling_rules = client.get_sampling_rules()
41-
self.assertTrue(len(sampling_rules) == 3)
74+
self.assertEqual(len(sampling_rules), 3)
75+
self.assertEqual(len(sampling_rules), len(sampling_records))
76+
self.validate_match_sampling_rules_properties_with_records(sampling_rules, sampling_records)
77+
78+
def validate_match_sampling_rules_properties_with_records(self, sampling_rules, sampling_records):
79+
for _, (sampling_rule, sampling_record) in enumerate(zip(sampling_rules, sampling_records)):
80+
self.assertIsNotNone(sampling_rule.Attributes)
81+
self.assertEqual(sampling_rule.Attributes, sampling_record["SamplingRule"]["Attributes"])
82+
self.assertIsNotNone(sampling_rule.FixedRate)
83+
self.assertEqual(sampling_rule.FixedRate, sampling_record["SamplingRule"]["FixedRate"])
84+
self.assertIsNotNone(sampling_rule.HTTPMethod)
85+
self.assertEqual(sampling_rule.HTTPMethod, sampling_record["SamplingRule"]["HTTPMethod"])
86+
self.assertIsNotNone(sampling_rule.Host)
87+
self.assertEqual(sampling_rule.Host, sampling_record["SamplingRule"]["Host"])
88+
self.assertIsNotNone(sampling_rule.Priority)
89+
self.assertEqual(sampling_rule.Priority, sampling_record["SamplingRule"]["Priority"])
90+
self.assertIsNotNone(sampling_rule.ReservoirSize)
91+
self.assertEqual(sampling_rule.ReservoirSize, sampling_record["SamplingRule"]["ReservoirSize"])
92+
self.assertIsNotNone(sampling_rule.ResourceARN)
93+
self.assertEqual(sampling_rule.ResourceARN, sampling_record["SamplingRule"]["ResourceARN"])
94+
self.assertIsNotNone(sampling_rule.RuleARN)
95+
self.assertEqual(sampling_rule.RuleARN, sampling_record["SamplingRule"]["RuleARN"])
96+
self.assertIsNotNone(sampling_rule.RuleName)
97+
self.assertEqual(sampling_rule.RuleName, sampling_record["SamplingRule"]["RuleName"])
98+
self.assertIsNotNone(sampling_rule.ServiceName)
99+
self.assertEqual(sampling_rule.ServiceName, sampling_record["SamplingRule"]["ServiceName"])
100+
self.assertIsNotNone(sampling_rule.ServiceType)
101+
self.assertEqual(sampling_rule.ServiceType, sampling_record["SamplingRule"]["ServiceType"])
102+
self.assertIsNotNone(sampling_rule.URLPath)
103+
self.assertEqual(sampling_rule.URLPath, sampling_record["SamplingRule"]["URLPath"])
104+
self.assertIsNotNone(sampling_rule.Version)
105+
self.assertEqual(sampling_rule.Version, sampling_record["SamplingRule"]["Version"])

0 commit comments

Comments
 (0)