Skip to content

Commit c91cad3

Browse files
authored
AWS X-Ray Remote Sampler Part 2 - Add Rules Caching and Rules Matching Logic (#47)
*Issue #, if available:* Second PR of 3 parts for adding the X-Ray remote sampling support for OTel Python SDK. [See Part 1](#33) *Description of changes:* - Sampling `RuleCache` - Caches a list of `Rule`s, ordered by rule priority then rule name. Each rule corresponds to the Sampling Rule from GetSamplingRules. Each call to GetSamplingRules will only update the `Rule`s that have changed properties, to preserve the state of unchanged rules. This means when Reservoir and Statistics are implemented in the `Rule`s, they will persist for unchanged rules. - The RuleCache will determine which Rule a set of {ResourceAttributes,SpanAttributes} matches with that has highest priority. - `Rule` - Corresponds to a `SamplingRule` and has logic to match with provided set of ResourceAttribute and SpanAttribute using the `Matcher` class. - Will determine the final sampling decision - `Matcher` class with methods to perform: - Convert X-Ray sampling rule options to regex patterns - Wild card and attribute matching - Initial class for `FallbackSampler` Testing: Unit tests By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent 6287561 commit c91cad3

File tree

13 files changed

+964
-28
lines changed

13 files changed

+964
-28
lines changed

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
class _AwsXRaySamplingClient:
14-
def __init__(self, endpoint=None, log_level=None):
14+
def __init__(self, endpoint: str = None, log_level: str = None):
1515
# Override default log level
1616
if log_level is not None:
1717
_logger.setLevel(log_level)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
from typing import Optional, Sequence
4+
5+
from opentelemetry.context import Context
6+
from opentelemetry.sdk.trace.sampling import ALWAYS_ON, Sampler, SamplingResult, TraceIdRatioBased
7+
from opentelemetry.trace import Link, SpanKind
8+
from opentelemetry.trace.span import TraceState
9+
from opentelemetry.util.types import Attributes
10+
11+
12+
class _FallbackSampler(Sampler):
13+
def __init__(self):
14+
# TODO: Add Reservoir sampler
15+
# pylint: disable=unused-private-member
16+
self.__fixed_rate_sampler = TraceIdRatioBased(0.05)
17+
18+
# pylint: disable=no-self-use
19+
def should_sample(
20+
self,
21+
parent_context: Optional[Context],
22+
trace_id: int,
23+
name: str,
24+
kind: SpanKind = None,
25+
attributes: Attributes = None,
26+
links: Sequence[Link] = None,
27+
trace_state: TraceState = None,
28+
) -> SamplingResult:
29+
# TODO: add reservoir + fixed rate sampling
30+
return ALWAYS_ON.should_sample(
31+
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
32+
)
33+
34+
# pylint: disable=no-self-use
35+
def get_description(self) -> str:
36+
description = (
37+
"FallbackSampler{fallback sampling with sampling config of 1 req/sec and 5% of additional requests}"
38+
)
39+
return description
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
import re
4+
5+
from opentelemetry.semconv.resource import CloudPlatformValues
6+
from opentelemetry.util.types import Attributes
7+
8+
cloud_platform_mapping = {
9+
CloudPlatformValues.AWS_LAMBDA.value: "AWS::Lambda::Function",
10+
CloudPlatformValues.AWS_ELASTIC_BEANSTALK.value: "AWS::ElasticBeanstalk::Environment",
11+
CloudPlatformValues.AWS_EC2.value: "AWS::EC2::Instance",
12+
CloudPlatformValues.AWS_ECS.value: "AWS::ECS::Container",
13+
CloudPlatformValues.AWS_EKS.value: "AWS::EKS::Container",
14+
}
15+
16+
17+
class _Matcher:
18+
@staticmethod
19+
def wild_card_match(text: str = None, pattern: str = None) -> bool:
20+
if pattern == "*":
21+
return True
22+
if text is None or pattern is None:
23+
return False
24+
if len(pattern) == 0:
25+
return len(text) == 0
26+
for char in pattern:
27+
if char in ("*", "?"):
28+
return re.fullmatch(_Matcher.to_regex_pattern(pattern), text) is not None
29+
return pattern == text
30+
31+
@staticmethod
32+
def to_regex_pattern(rule_pattern: str) -> str:
33+
token_start = -1
34+
regex_pattern = ""
35+
for index, char in enumerate(rule_pattern):
36+
char = rule_pattern[index]
37+
if char in ("*", "?"):
38+
if token_start != -1:
39+
regex_pattern += re.escape(rule_pattern[token_start:index])
40+
token_start = -1
41+
if char == "*":
42+
regex_pattern += ".*"
43+
else:
44+
regex_pattern += "."
45+
else:
46+
if token_start == -1:
47+
token_start = index
48+
if token_start != -1:
49+
regex_pattern += re.escape(rule_pattern[token_start:])
50+
return regex_pattern
51+
52+
@staticmethod
53+
def attribute_match(attributes: Attributes = None, rule_attributes: dict = None) -> bool:
54+
if rule_attributes is None or len(rule_attributes) == 0:
55+
return True
56+
if attributes is None or len(attributes) == 0 or len(rule_attributes) > len(attributes):
57+
return False
58+
59+
matched_count = 0
60+
for key, val in attributes.items():
61+
text_to_match = val
62+
pattern = rule_attributes.get(key, None)
63+
if pattern is None:
64+
continue
65+
if _Matcher.wild_card_match(text_to_match, pattern):
66+
matched_count += 1
67+
return matched_count == len(rule_attributes)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
import datetime
4+
from logging import getLogger
5+
from threading import Lock
6+
from typing import Optional, Sequence
7+
8+
from amazon.opentelemetry.distro.sampler._fallback_sampler import _FallbackSampler
9+
from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
10+
from amazon.opentelemetry.distro.sampler._sampling_rule_applier import _SamplingRuleApplier
11+
from opentelemetry.context import Context
12+
from opentelemetry.sdk.resources import Resource
13+
from opentelemetry.sdk.trace.sampling import SamplingResult
14+
from opentelemetry.trace import Link, SpanKind
15+
from opentelemetry.trace.span import TraceState
16+
from opentelemetry.util.types import Attributes
17+
18+
_logger = getLogger(__name__)
19+
20+
CACHE_TTL_SECONDS = 3600
21+
22+
23+
class _RuleCache:
24+
def __init__(self, resource: Resource, fallback_sampler: _FallbackSampler, date_time: datetime, lock: Lock):
25+
self.__rule_appliers: [_SamplingRuleApplier] = []
26+
self.__cache_lock = lock
27+
self.__resource = resource
28+
self._fallback_sampler = fallback_sampler
29+
self._date_time = date_time
30+
self._last_modified = self._date_time.datetime.now()
31+
32+
def should_sample(
33+
self,
34+
parent_context: Optional[Context],
35+
trace_id: int,
36+
name: str,
37+
kind: SpanKind = None,
38+
attributes: Attributes = None,
39+
links: Sequence[Link] = None,
40+
trace_state: TraceState = None,
41+
) -> SamplingResult:
42+
for rule_applier in self.__rule_appliers:
43+
if rule_applier.matches(self.__resource, attributes):
44+
return rule_applier.should_sample(
45+
parent_context,
46+
trace_id,
47+
name,
48+
kind=kind,
49+
attributes=attributes,
50+
links=links,
51+
trace_state=trace_state,
52+
)
53+
54+
return self._fallback_sampler.should_sample(
55+
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
56+
)
57+
58+
def update_sampling_rules(self, new_sampling_rules: [_SamplingRule]) -> None:
59+
new_sampling_rules.sort()
60+
temp_rule_appliers = []
61+
for sampling_rule in new_sampling_rules:
62+
if sampling_rule.RuleName == "":
63+
_logger.debug("sampling rule without rule name is not supported")
64+
continue
65+
if sampling_rule.Version != 1:
66+
_logger.debug("sampling rule without Version 1 is not supported: RuleName: %s", sampling_rule.RuleName)
67+
continue
68+
temp_rule_appliers.append(_SamplingRuleApplier(sampling_rule))
69+
70+
self.__cache_lock.acquire()
71+
72+
# map list of rule appliers by each applier's sampling_rule name
73+
rule_applier_map = {rule.sampling_rule.RuleName: rule for rule in self.__rule_appliers}
74+
75+
# If a sampling rule has not changed, keep its respective applier in the cache.
76+
for index, new_applier in enumerate(temp_rule_appliers):
77+
rule_name_to_check = new_applier.sampling_rule.RuleName
78+
if rule_name_to_check in rule_applier_map:
79+
old_applier = rule_applier_map[rule_name_to_check]
80+
if new_applier.sampling_rule == old_applier.sampling_rule:
81+
temp_rule_appliers[index] = old_applier
82+
self.__rule_appliers = temp_rule_appliers
83+
self._last_modified = datetime.datetime.now()
84+
85+
self.__cache_lock.release()
86+
87+
def expired(self) -> bool:
88+
self.__cache_lock.acquire()
89+
try:
90+
return datetime.datetime.now() > self._last_modified + datetime.timedelta(seconds=CACHE_TTL_SECONDS)
91+
finally:
92+
self.__cache_lock.release()

aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,18 @@ class _SamplingRule:
88
def __init__(
99
self,
1010
Attributes: dict = None,
11-
FixedRate=None,
12-
HTTPMethod=None,
13-
Host=None,
14-
Priority=None,
15-
ReservoirSize=None,
16-
ResourceARN=None,
17-
RuleARN=None,
18-
RuleName=None,
19-
ServiceName=None,
20-
ServiceType=None,
21-
URLPath=None,
22-
Version=None,
11+
FixedRate: float = None,
12+
HTTPMethod: str = None,
13+
Host: str = None,
14+
Priority: int = None,
15+
ReservoirSize: int = None,
16+
ResourceARN: str = None,
17+
RuleARN: str = None,
18+
RuleName: str = None,
19+
ServiceName: str = None,
20+
ServiceType: str = None,
21+
URLPath: str = None,
22+
Version: int = None,
2323
):
2424
self.Attributes = Attributes if Attributes is not None else {}
2525
self.FixedRate = FixedRate if FixedRate is not None else 0.0
@@ -35,3 +35,29 @@ def __init__(
3535
self.ServiceType = ServiceType if ServiceType is not None else ""
3636
self.URLPath = URLPath if URLPath is not None else ""
3737
self.Version = Version if Version is not None else 0
38+
39+
def __lt__(self, other) -> bool:
40+
if self.Priority == other.Priority:
41+
# String order priority example:
42+
# "A","Abc","a","ab","abc","abcdef"
43+
return self.RuleName < other.RuleName
44+
return self.Priority < other.Priority
45+
46+
def __eq__(self, other: object) -> bool:
47+
if not isinstance(other, _SamplingRule):
48+
return False
49+
return (
50+
self.FixedRate == other.FixedRate
51+
and self.HTTPMethod == other.HTTPMethod
52+
and self.Host == other.Host
53+
and self.Priority == other.Priority
54+
and self.ReservoirSize == other.ReservoirSize
55+
and self.ResourceARN == other.ResourceARN
56+
and self.RuleARN == other.RuleARN
57+
and self.RuleName == other.RuleName
58+
and self.ServiceName == other.ServiceName
59+
and self.ServiceType == other.ServiceType
60+
and self.URLPath == other.URLPath
61+
and self.Version == other.Version
62+
and self.Attributes == other.Attributes
63+
)
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
from typing import Optional, Sequence
4+
from urllib.parse import urlparse
5+
6+
from amazon.opentelemetry.distro.sampler._matcher import _Matcher, cloud_platform_mapping
7+
from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
8+
from opentelemetry.context import Context
9+
from opentelemetry.sdk.resources import Resource
10+
from opentelemetry.sdk.trace.sampling import ALWAYS_ON, SamplingResult
11+
from opentelemetry.semconv.resource import CloudPlatformValues, ResourceAttributes
12+
from opentelemetry.semconv.trace import SpanAttributes
13+
from opentelemetry.trace import Link, SpanKind
14+
from opentelemetry.trace.span import TraceState
15+
from opentelemetry.util.types import Attributes
16+
17+
18+
class _SamplingRuleApplier:
19+
def __init__(self, sampling_rule: _SamplingRule):
20+
self.sampling_rule = sampling_rule
21+
# TODO add self.next_target_fetch_time from maybe time.process_time() or cache's datetime object
22+
# TODO add statistics
23+
# TODO change to rate limiter given rate, add fixed rate sampler
24+
self.reservoir_sampler = ALWAYS_ON
25+
# self.fixed_rate_sampler = None
26+
# TODO add clientId
27+
28+
def should_sample(
29+
self,
30+
parent_context: Optional[Context],
31+
trace_id: int,
32+
name: str,
33+
kind: SpanKind = None,
34+
attributes: Attributes = None,
35+
links: Sequence[Link] = None,
36+
trace_state: TraceState = None,
37+
) -> SamplingResult:
38+
return self.reservoir_sampler.should_sample(
39+
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
40+
)
41+
42+
def matches(self, resource: Resource, attributes: Attributes) -> bool:
43+
url_path = None
44+
url_full = None
45+
http_request_method = None
46+
server_address = None
47+
service_name = None
48+
49+
if attributes is not None:
50+
url_path = attributes.get(SpanAttributes.URL_PATH, None)
51+
url_full = attributes.get(SpanAttributes.URL_FULL, None)
52+
http_request_method = attributes.get(SpanAttributes.HTTP_REQUEST_METHOD, None)
53+
server_address = attributes.get(SpanAttributes.SERVER_ADDRESS, None)
54+
55+
# Resource shouldn't be none as it should default to empty resource
56+
if resource is not None:
57+
service_name = resource.attributes.get(ResourceAttributes.SERVICE_NAME, "")
58+
59+
# target may be in url
60+
if url_path is None and url_full is not None:
61+
scheme_end_index = url_full.find("://")
62+
# For network calls, URL usually has `scheme://host[:port][path][?query][#fragment]` format
63+
# Per spec, url.full is always populated with scheme://host/target.
64+
# If scheme doesn't match, assume it's bad instrumentation and ignore.
65+
if scheme_end_index > -1:
66+
# urlparse("scheme://netloc/path;parameters?query#fragment")
67+
url_path = urlparse(url_full).path
68+
if url_path == "":
69+
url_path = "/"
70+
elif url_path is None and url_full is None:
71+
# When missing, the URL Path is assumed to be /
72+
url_path = "/"
73+
74+
return (
75+
_Matcher.attribute_match(attributes, self.sampling_rule.Attributes)
76+
and _Matcher.wild_card_match(url_path, self.sampling_rule.URLPath)
77+
and _Matcher.wild_card_match(http_request_method, self.sampling_rule.HTTPMethod)
78+
and _Matcher.wild_card_match(server_address, self.sampling_rule.Host)
79+
and _Matcher.wild_card_match(service_name, self.sampling_rule.ServiceName)
80+
and _Matcher.wild_card_match(self.__get_service_type(resource), self.sampling_rule.ServiceType)
81+
and _Matcher.wild_card_match(self.__get_arn(resource, attributes), self.sampling_rule.ResourceARN)
82+
)
83+
84+
# pylint: disable=no-self-use
85+
def __get_service_type(self, resource: Resource) -> str:
86+
if resource is None:
87+
return ""
88+
89+
cloud_platform = resource.attributes.get(ResourceAttributes.CLOUD_PLATFORM, None)
90+
if cloud_platform is None:
91+
return ""
92+
93+
return cloud_platform_mapping.get(cloud_platform, "")
94+
95+
# pylint: disable=no-self-use
96+
def __get_arn(self, resource: Resource, attributes: Attributes) -> str:
97+
if resource is not None:
98+
arn = resource.attributes.get(ResourceAttributes.AWS_ECS_CONTAINER_ARN, None)
99+
if arn is not None:
100+
return arn
101+
if attributes is not None and self.__get_service_type(resource=resource) == cloud_platform_mapping.get(
102+
CloudPlatformValues.AWS_LAMBDA.value
103+
):
104+
arn = attributes.get(SpanAttributes.CLOUD_RESOURCE_ID, None)
105+
if arn is not None:
106+
return arn
107+
return ""

0 commit comments

Comments
 (0)