Skip to content

AWS X-Ray Remote Sampler Part 2 - Add Rules Caching and Rules Matching Logic #47

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class _AwsXRaySamplingClient:
def __init__(self, endpoint=None, log_level=None):
def __init__(self, endpoint: str = None, log_level: str = None):
# Override default log level
if log_level is not None:
_logger.setLevel(log_level)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Sequence

from opentelemetry.context import Context
from opentelemetry.sdk.trace.sampling import ALWAYS_ON, Sampler, SamplingResult, TraceIdRatioBased
from opentelemetry.trace import Link, SpanKind
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes


class _FallbackSampler(Sampler):
def __init__(self):
# TODO: Add Reservoir sampler
# pylint: disable=unused-private-member
self.__fixed_rate_sampler = TraceIdRatioBased(0.05)

# pylint: disable=no-self-use
def should_sample(
self,
parent_context: Optional[Context],
trace_id: int,
name: str,
kind: SpanKind = None,
attributes: Attributes = None,
links: Sequence[Link] = None,
trace_state: TraceState = None,
) -> SamplingResult:
# TODO: add reservoir + fixed rate sampling
return ALWAYS_ON.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)

# pylint: disable=no-self-use
def get_description(self) -> str:
description = (
"FallbackSampler{fallback sampling with sampling config of 1 req/sec and 5% of additional requests}"
)
return description
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import re

from opentelemetry.semconv.resource import CloudPlatformValues
from opentelemetry.util.types import Attributes

cloud_platform_mapping = {
CloudPlatformValues.AWS_LAMBDA.value: "AWS::Lambda::Function",
CloudPlatformValues.AWS_ELASTIC_BEANSTALK.value: "AWS::ElasticBeanstalk::Environment",
CloudPlatformValues.AWS_EC2.value: "AWS::EC2::Instance",
CloudPlatformValues.AWS_ECS.value: "AWS::ECS::Container",
CloudPlatformValues.AWS_EKS.value: "AWS::EKS::Container",
}


class _Matcher:
@staticmethod
def wild_card_match(text: str = None, pattern: str = None) -> bool:
if pattern == "*":
return True
if text is None or pattern is None:
return False
if len(pattern) == 0:
return len(text) == 0
for char in pattern:
if char in ("*", "?"):
return re.fullmatch(_Matcher.to_regex_pattern(pattern), text) is not None
return pattern == text

@staticmethod
def to_regex_pattern(rule_pattern: str) -> str:
token_start = -1
regex_pattern = ""
for index, char in enumerate(rule_pattern):
char = rule_pattern[index]
if char in ("*", "?"):
if token_start != -1:
regex_pattern += re.escape(rule_pattern[token_start:index])
token_start = -1
if char == "*":
regex_pattern += ".*"
else:
regex_pattern += "."
else:
if token_start == -1:
token_start = index
if token_start != -1:
regex_pattern += re.escape(rule_pattern[token_start:])
return regex_pattern

@staticmethod
def attribute_match(attributes: Attributes = None, rule_attributes: dict = None) -> bool:
if rule_attributes is None or len(rule_attributes) == 0:
return True
if attributes is None or len(attributes) == 0 or len(rule_attributes) > len(attributes):
return False

matched_count = 0
for key, val in attributes.items():
text_to_match = val
pattern = rule_attributes.get(key, None)
if pattern is None:
continue
if _Matcher.wild_card_match(text_to_match, pattern):
matched_count += 1
return matched_count == len(rule_attributes)
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime
from logging import getLogger
from threading import Lock
from typing import Optional, Sequence

from amazon.opentelemetry.distro.sampler._fallback_sampler import _FallbackSampler
from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
from amazon.opentelemetry.distro.sampler._sampling_rule_applier import _SamplingRuleApplier
from opentelemetry.context import Context
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.sampling import SamplingResult
from opentelemetry.trace import Link, SpanKind
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes

_logger = getLogger(__name__)

CACHE_TTL_SECONDS = 3600


class _RuleCache:
def __init__(self, resource: Resource, fallback_sampler: _FallbackSampler, date_time: datetime, lock: Lock):
self.__rule_appliers: [_SamplingRuleApplier] = []
self.__cache_lock = lock
self.__resource = resource
self._fallback_sampler = fallback_sampler
self._date_time = date_time
self._last_modified = self._date_time.datetime.now()

def should_sample(
self,
parent_context: Optional[Context],
trace_id: int,
name: str,
kind: SpanKind = None,
attributes: Attributes = None,
links: Sequence[Link] = None,
trace_state: TraceState = None,
) -> SamplingResult:
for rule_applier in self.__rule_appliers:
if rule_applier.matches(self.__resource, attributes):
return rule_applier.should_sample(
parent_context,
trace_id,
name,
kind=kind,
attributes=attributes,
links=links,
trace_state=trace_state,
)

return self._fallback_sampler.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)

def update_sampling_rules(self, new_sampling_rules: [_SamplingRule]) -> None:
new_sampling_rules.sort()
temp_rule_appliers = []
for sampling_rule in new_sampling_rules:
if sampling_rule.RuleName == "":
_logger.debug("sampling rule without rule name is not supported")
continue
if sampling_rule.Version != 1:
_logger.debug("sampling rule without Version 1 is not supported: RuleName: %s", sampling_rule.RuleName)
continue
temp_rule_appliers.append(_SamplingRuleApplier(sampling_rule))

self.__cache_lock.acquire()

# map list of rule appliers by each applier's sampling_rule name
rule_applier_map = {rule.sampling_rule.RuleName: rule for rule in self.__rule_appliers}

# If a sampling rule has not changed, keep its respective applier in the cache.
for index, new_applier in enumerate(temp_rule_appliers):
rule_name_to_check = new_applier.sampling_rule.RuleName
if rule_name_to_check in rule_applier_map:
old_applier = rule_applier_map[rule_name_to_check]
if new_applier.sampling_rule == old_applier.sampling_rule:
temp_rule_appliers[index] = old_applier
self.__rule_appliers = temp_rule_appliers
self._last_modified = datetime.datetime.now()

self.__cache_lock.release()

def expired(self) -> bool:
self.__cache_lock.acquire()
try:
return datetime.datetime.now() > self._last_modified + datetime.timedelta(seconds=CACHE_TTL_SECONDS)
finally:
self.__cache_lock.release()
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ class _SamplingRule:
def __init__(
self,
Attributes: dict = None,
FixedRate=None,
HTTPMethod=None,
Host=None,
Priority=None,
ReservoirSize=None,
ResourceARN=None,
RuleARN=None,
RuleName=None,
ServiceName=None,
ServiceType=None,
URLPath=None,
Version=None,
FixedRate: float = None,
HTTPMethod: str = None,
Host: str = None,
Priority: int = None,
ReservoirSize: int = None,
ResourceARN: str = None,
RuleARN: str = None,
RuleName: str = None,
ServiceName: str = None,
ServiceType: str = None,
URLPath: str = None,
Version: int = None,
):
self.Attributes = Attributes if Attributes is not None else {}
self.FixedRate = FixedRate if FixedRate is not None else 0.0
Expand All @@ -35,3 +35,29 @@ def __init__(
self.ServiceType = ServiceType if ServiceType is not None else ""
self.URLPath = URLPath if URLPath is not None else ""
self.Version = Version if Version is not None else 0

def __lt__(self, other) -> bool:
if self.Priority == other.Priority:
# String order priority example:
# "A","Abc","a","ab","abc","abcdef"
return self.RuleName < other.RuleName
return self.Priority < other.Priority

def __eq__(self, other: object) -> bool:
if not isinstance(other, _SamplingRule):
return False
return (
self.FixedRate == other.FixedRate
and self.HTTPMethod == other.HTTPMethod
and self.Host == other.Host
and self.Priority == other.Priority
and self.ReservoirSize == other.ReservoirSize
and self.ResourceARN == other.ResourceARN
and self.RuleARN == other.RuleARN
and self.RuleName == other.RuleName
and self.ServiceName == other.ServiceName
and self.ServiceType == other.ServiceType
and self.URLPath == other.URLPath
and self.Version == other.Version
and self.Attributes == other.Attributes
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Sequence
from urllib.parse import urlparse

from amazon.opentelemetry.distro.sampler._matcher import _Matcher, cloud_platform_mapping
from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
from opentelemetry.context import Context
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.sampling import ALWAYS_ON, SamplingResult
from opentelemetry.semconv.resource import CloudPlatformValues, ResourceAttributes
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Link, SpanKind
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes


class _SamplingRuleApplier:
def __init__(self, sampling_rule: _SamplingRule):
self.sampling_rule = sampling_rule
# TODO add self.next_target_fetch_time from maybe time.process_time() or cache's datetime object
# TODO add statistics
# TODO change to rate limiter given rate, add fixed rate sampler
self.reservoir_sampler = ALWAYS_ON
# self.fixed_rate_sampler = None
# TODO add clientId

def should_sample(
self,
parent_context: Optional[Context],
trace_id: int,
name: str,
kind: SpanKind = None,
attributes: Attributes = None,
links: Sequence[Link] = None,
trace_state: TraceState = None,
) -> SamplingResult:
return self.reservoir_sampler.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)

def matches(self, resource: Resource, attributes: Attributes) -> bool:
url_path = None
url_full = None
http_request_method = None
server_address = None
service_name = None

if attributes is not None:
url_path = attributes.get(SpanAttributes.URL_PATH, None)
url_full = attributes.get(SpanAttributes.URL_FULL, None)
http_request_method = attributes.get(SpanAttributes.HTTP_REQUEST_METHOD, None)
server_address = attributes.get(SpanAttributes.SERVER_ADDRESS, None)

# Resource shouldn't be none as it should default to empty resource
if resource is not None:
service_name = resource.attributes.get(ResourceAttributes.SERVICE_NAME, "")

# target may be in url
if url_path is None and url_full is not None:
scheme_end_index = url_full.find("://")
# For network calls, URL usually has `scheme://host[:port][path][?query][#fragment]` format
# Per spec, url.full is always populated with scheme://host/target.
# If scheme doesn't match, assume it's bad instrumentation and ignore.
if scheme_end_index > -1:
# urlparse("scheme://netloc/path;parameters?query#fragment")
url_path = urlparse(url_full).path
if url_path == "":
url_path = "/"
elif url_path is None and url_full is None:
# When missing, the URL Path is assumed to be /
url_path = "/"

return (
_Matcher.attribute_match(attributes, self.sampling_rule.Attributes)
and _Matcher.wild_card_match(url_path, self.sampling_rule.URLPath)
and _Matcher.wild_card_match(http_request_method, self.sampling_rule.HTTPMethod)
and _Matcher.wild_card_match(server_address, self.sampling_rule.Host)
and _Matcher.wild_card_match(service_name, self.sampling_rule.ServiceName)
and _Matcher.wild_card_match(self.__get_service_type(resource), self.sampling_rule.ServiceType)
and _Matcher.wild_card_match(self.__get_arn(resource, attributes), self.sampling_rule.ResourceARN)
)

# pylint: disable=no-self-use
def __get_service_type(self, resource: Resource) -> str:
if resource is None:
return ""

cloud_platform = resource.attributes.get(ResourceAttributes.CLOUD_PLATFORM, None)
if cloud_platform is None:
return ""

return cloud_platform_mapping.get(cloud_platform, "")

# pylint: disable=no-self-use
def __get_arn(self, resource: Resource, attributes: Attributes) -> str:
if resource is not None:
arn = resource.attributes.get(ResourceAttributes.AWS_ECS_CONTAINER_ARN, None)
if arn is not None:
return arn
if attributes is not None and self.__get_service_type(resource=resource) == cloud_platform_mapping.get(
CloudPlatformValues.AWS_LAMBDA.value
):
arn = attributes.get(SpanAttributes.CLOUD_RESOURCE_ID, None)
if arn is not None:
return arn
return ""
Loading