Skip to content

AWS X-Ray Remote Sampler Part 1 - Initial Rules Poller Implementation #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import json
from logging import getLogger

import requests

from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule

# Module-level logger used by _AwsXRaySamplingClient; the client's constructor
# calls setLevel() on it when a `log_level` argument is supplied.
_logger = getLogger(__name__)


class _AwsXRaySamplingClient:
    """HTTP client that fetches sampling rules from an X-Ray sampling proxy.

    Args:
        endpoint: base URL of the sampling proxy. ``"/GetSamplingRules"`` is
            appended to it to form the request URL. When ``None`` an error is
            logged and every call to :meth:`get_sampling_rules` returns ``[]``.
        log_level: optional level applied to this module's logger.
    """

    def __init__(self, endpoint=None, log_level=None):
        # Override default log level
        if log_level is not None:
            _logger.setLevel(log_level)

        if endpoint is None:
            _logger.error("endpoint must be specified")
            # Previously this fell through to `None + "/GetSamplingRules"` and
            # raised TypeError right after logging. Keep the client usable
            # (but inert) instead of crashing the application at start-up.
            self.__get_sampling_rules_endpoint = None
        else:
            self.__get_sampling_rules_endpoint = endpoint + "/GetSamplingRules"

    def get_sampling_rules(self) -> "list[_SamplingRule]":
        """POST to the GetSamplingRules endpoint and parse the response.

        Returns:
            A list of ``_SamplingRule`` objects, one per well-formed
            ``SamplingRuleRecords`` entry. Malformed records are logged and
            skipped. An empty list is returned when no endpoint is configured
            or when the response body lacks ``SamplingRuleRecords``. If the
            request or JSON decoding fails, the error is logged and whatever
            was parsed so far (normally nothing) is returned.
        """
        sampling_rules = []
        if self.__get_sampling_rules_endpoint is None:
            _logger.error("endpoint must be specified")
            return sampling_rules

        headers = {"content-type": "application/json"}

        try:
            xray_response = requests.post(url=self.__get_sampling_rules_endpoint, headers=headers, timeout=20)
            if xray_response is None:
                _logger.error("GetSamplingRules response is None")
                return []
            sampling_rules_response = xray_response.json()
            # A JSON body that is not an object (e.g. `null` or a list) cannot
            # carry SamplingRuleRecords; treat it the same as a missing key.
            if not isinstance(sampling_rules_response, dict) or "SamplingRuleRecords" not in sampling_rules_response:
                _logger.error(
                    "SamplingRuleRecords is missing in getSamplingRules response: %s", sampling_rules_response
                )
                return []

            # `or []` tolerates an explicit JSON null for the record list.
            sampling_rules_records = sampling_rules_response["SamplingRuleRecords"] or []
            for record in sampling_rules_records:
                try:
                    rule_fields = record["SamplingRule"]
                except (KeyError, TypeError):
                    _logger.error("SamplingRule is missing in SamplingRuleRecord")
                    continue
                try:
                    sampling_rules.append(_SamplingRule(**rule_fields))
                except TypeError as type_err:
                    # Raised when the rule is not a mapping or carries a field
                    # _SamplingRule does not accept; skip only this record so
                    # one bad rule does not discard the rest.
                    _logger.error("Unable to parse SamplingRule: %s", type_err)

        # Checked before RequestException so a body that is not valid JSON is
        # reported as a decoding problem where the raised type allows it
        # (NOTE(review): recent `requests` raise a subclass of both types).
        except json.JSONDecodeError as json_err:
            _logger.error("Error in decoding JSON response: %s", json_err)
        except requests.exceptions.RequestException as req_err:
            _logger.error("Request error occurred: %s", req_err)

        return sampling_rules
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0


# Disable snake_case naming style so this class can match the sampling rules response from X-Ray
# pylint: disable=invalid-name
class _SamplingRule:
    """Plain value holder for one sampling rule.

    Constructor keyword names and attribute names use the PascalCase spelling
    of the rule fields so a rule dict can be splatted straight in with ``**``.
    Any field passed as ``None`` (or omitted) is replaced by a neutral default.
    """

    def __init__(
        self,
        Attributes: dict = None,
        FixedRate=None,
        HTTPMethod=None,
        Host=None,
        Priority=None,
        ReservoirSize=None,
        ResourceARN=None,
        RuleARN=None,
        RuleName=None,
        ServiceName=None,
        ServiceType=None,
        URLPath=None,
        Version=None,
    ):
        def _or_default(value, default):
            # Only None triggers the fallback; falsy values such as 0 or "" are kept.
            return default if value is None else value

        # Assignment order is kept stable because it fixes the key order of
        # the instance __dict__.
        self.Attributes = _or_default(Attributes, {})
        self.FixedRate = _or_default(FixedRate, 0.0)
        self.HTTPMethod = _or_default(HTTPMethod, "")
        self.Host = _or_default(Host, "")
        # A rule without a priority ranks below the default rule.
        self.Priority = _or_default(Priority, 10001)
        self.ReservoirSize = _or_default(ReservoirSize, 0)
        self.ResourceARN = _or_default(ResourceARN, "")
        self.RuleARN = _or_default(RuleARN, "")
        self.RuleName = _or_default(RuleName, "")
        self.ServiceName = _or_default(ServiceName, "")
        self.ServiceType = _or_default(ServiceType, "")
        self.URLPath = _or_default(URLPath, "")
        self.Version = _or_default(Version, 0)
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import json
from logging import getLogger
from threading import Timer
from typing import Optional, Sequence

from typing_extensions import override

from amazon.opentelemetry.distro.sampler._aws_xray_sampling_client import _AwsXRaySamplingClient
from opentelemetry.context import Context
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.sampling import ALWAYS_OFF, Sampler, SamplingResult
from opentelemetry.trace import Link, SpanKind
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes

# Module-level logger; AwsXRayRemoteSampler.__init__ calls setLevel() on it
# when a `log_level` argument is supplied.
_logger = getLogger(__name__)

# Default value of the sampler's `polling_interval` argument, i.e. the delay
# (in seconds) handed to the Timer that schedules each sampling-rules poll.
DEFAULT_RULES_POLLING_INTERVAL_SECONDS = 300
# Not referenced by the code visible in this module yet.
# NOTE(review): presumably reserved for a future sampling-targets poller - confirm.
DEFAULT_TARGET_POLLING_INTERVAL_SECONDS = 10
# Default value of the sampler's `endpoint` argument.
DEFAULT_SAMPLING_PROXY_ENDPOINT = "http://127.0.0.1:2000"


class AwsXRayRemoteSampler(Sampler):
    """
    Remote Sampler for OpenTelemetry that gets sampling configurations from AWS X-Ray

    Constructing the sampler immediately starts a background daemon ``Timer``
    that polls the sampling rules, and then re-arms itself every
    ``polling_interval`` seconds. Sampling decisions are not implemented yet:
    ``should_sample`` currently delegates to ``ALWAYS_OFF``.

    Args:
        resource: OpenTelemetry Resource (Required)
        endpoint: proxy endpoint for AWS X-Ray Sampling (Optional)
        polling_interval: Polling interval for getSamplingRules call (Optional)
        log_level: custom log level configuration for remote sampler (Optional)
    """

    __resource: Resource
    __polling_interval: int
    __xray_client: _AwsXRaySamplingClient

    def __init__(
        self,
        resource: Resource,
        endpoint=DEFAULT_SAMPLING_PROXY_ENDPOINT,
        polling_interval=DEFAULT_RULES_POLLING_INTERVAL_SECONDS,
        log_level=None,
    ):
        # Override default log level
        if log_level is not None:
            _logger.setLevel(log_level)

        self.__xray_client = _AwsXRaySamplingClient(endpoint, log_level=log_level)
        self.__polling_interval = polling_interval

        # pylint: disable=unused-private-member
        if resource is not None:
            self.__resource = resource
        else:
            _logger.warning("OTel Resource provided is `None`. Defaulting to empty resource")
            self.__resource = Resource.get_empty()

        # Schedule the next rule poll now
        # Python Timers only run once, so they need to be recreated for every poll
        self._timer = Timer(0, self.__start_sampling_rule_poller)
        self._timer.daemon = True  # Ensures that when the main thread exits, the Timer threads are killed
        self._timer.start()

    # pylint: disable=no-self-use
    @override
    def should_sample(
        self,
        parent_context: Optional["Context"],
        trace_id: int,
        name: str,
        kind: SpanKind = None,
        attributes: Attributes = None,
        links: Sequence["Link"] = None,
        trace_state: "TraceState" = None,
    ) -> SamplingResult:
        """Return the sampling decision for a span (placeholder: never samples)."""
        # TODO: add sampling functionality
        return ALWAYS_OFF.should_sample(
            parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
        )

    # pylint: disable=no-self-use
    @override
    def get_description(self) -> str:
        """Return the human-readable description of this sampler."""
        description = "AwsXRayRemoteSampler{remote sampling with AWS X-Ray}"
        return description

    def __get_and_update_sampling_rules(self):
        """Fetch the current sampling rules from the client and log them."""
        sampling_rules = self.__xray_client.get_sampling_rules()

        # TODO: Update sampling rules cache
        # The JSON string is passed directly as the log argument. It used to be
        # wrapped in `{...}`, which built a one-element set and logged its
        # repr (`{'[...]'}`) instead of the JSON itself.
        _logger.info("Got Sampling Rules: %s", json.dumps([ob.__dict__ for ob in sampling_rules]))

    def __start_sampling_rule_poller(self):
        """Run one poll, then arm a new Timer for the next one."""
        try:
            self.__get_and_update_sampling_rules()
        except Exception:  # pylint: disable=broad-except
            # This is the top of the timer thread. Without this guard a single
            # unexpected error would end the thread before the next Timer is
            # created, silently stopping rule polling for the process lifetime.
            _logger.exception("Unexpected error while polling sampling rules")

        # Schedule the next sampling rule poll
        self._timer = Timer(self.__polling_interval, self.__start_sampling_rule_poller)
        self._timer.daemon = True
        self._timer.start()
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
{
"NextToken": null,
"SamplingRuleRecords": [
{
"CreatedAt": 1.67799933E9,
"ModifiedAt": 1.67799933E9,
"SamplingRule": {
"Attributes": {
"foo": "bar",
"doo": "baz"
},
"FixedRate": 0.05,
"HTTPMethod": "*",
"Host": "*",
"Priority": 1000,
"ReservoirSize": 10,
"ResourceARN": "*",
"RuleARN": "arn:aws:xray:us-west-2:123456789000:sampling-rule/Rule1",
"RuleName": "Rule1",
"ServiceName": "*",
"ServiceType": "AWS::Foo::Bar",
"URLPath": "*",
"Version": 1
}
},
{
"CreatedAt": 0.0,
"ModifiedAt": 1.611564245E9,
"SamplingRule": {
"Attributes": {},
"FixedRate": 0.05,
"HTTPMethod": "*",
"Host": "*",
"Priority": 10000,
"ReservoirSize": 1,
"ResourceARN": "*",
"RuleARN": "arn:aws:xray:us-west-2:123456789000:sampling-rule/Default",
"RuleName": "Default",
"ServiceName": "*",
"ServiceType": "*",
"URLPath": "*",
"Version": 1
}
},
{
"CreatedAt": 1.676038494E9,
"ModifiedAt": 1.676038494E9,
"SamplingRule": {
"Attributes": {},
"FixedRate": 0.2,
"HTTPMethod": "GET",
"Host": "*",
"Priority": 1,
"ReservoirSize": 10,
"ResourceARN": "*",
"RuleARN": "arn:aws:xray:us-west-2:123456789000:sampling-rule/Rule2",
"RuleName": "Rule2",
"ServiceName": "FooBar",
"ServiceType": "*",
"URLPath": "/foo/bar",
"Version": 1
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from logging import DEBUG
from unittest import TestCase

from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler
from opentelemetry.sdk.resources import Resource


class TestAwsXRayRemoteSampler(TestCase):
    """Checks the state AwsXRayRemoteSampler holds right after construction."""

    def _assert_initialised(self, sampler, expected_interval):
        # Assertions shared by every construction scenario.
        self.assertIsNotNone(sampler._timer)
        self.assertEqual(sampler._AwsXRayRemoteSampler__polling_interval, expected_interval)
        self.assertIsNotNone(sampler._AwsXRayRemoteSampler__xray_client)
        self.assertIsNotNone(sampler._AwsXRayRemoteSampler__resource)

    def _assert_resource_attributes(self, sampler):
        # The resource handed to the constructor must be stored unchanged.
        stored = sampler._AwsXRayRemoteSampler__resource.attributes
        self.assertEqual(stored["service.name"], "test-service-name")
        self.assertEqual(stored["cloud.platform"], "test-cloud-platform")

    def test_create_remote_sampler_with_empty_resource(self):
        sampler = AwsXRayRemoteSampler(resource=Resource.get_empty())
        self._assert_initialised(sampler, 300)

    def test_create_remote_sampler_with_populated_resource(self):
        populated = Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"})
        sampler = AwsXRayRemoteSampler(resource=populated)
        self._assert_initialised(sampler, 300)
        self._assert_resource_attributes(sampler)

    def test_create_remote_sampler_with_all_fields_populated(self):
        populated = Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"})
        sampler = AwsXRayRemoteSampler(
            resource=populated,
            endpoint="http://abc.com",
            polling_interval=120,
            log_level=DEBUG,
        )
        self._assert_initialised(sampler, 120)
        client = sampler._AwsXRayRemoteSampler__xray_client
        self.assertEqual(
            client._AwsXRaySamplingClient__get_sampling_rules_endpoint,
            "http://abc.com/GetSamplingRules",
        )
        self._assert_resource_attributes(sampler)
Loading