Skip to content

AWS X-Ray Remote Sampler Part 3 - rate limiter logic and get sampling targets #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import requests

from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTargetResponse

_logger = getLogger(__name__)

Expand All @@ -19,6 +20,7 @@ def __init__(self, endpoint: str = None, log_level: str = None):
if endpoint is None:
_logger.error("endpoint must be specified")
self.__get_sampling_rules_endpoint = endpoint + "/GetSamplingRules"
self.__get_sampling_targets_endpoint = endpoint + "/SamplingTargets"

def get_sampling_rules(self) -> [_SamplingRule]:
sampling_rules = []
Expand All @@ -30,12 +32,11 @@ def get_sampling_rules(self) -> [_SamplingRule]:
_logger.error("GetSamplingRules response is None")
return []
sampling_rules_response = xray_response.json()
if "SamplingRuleRecords" not in sampling_rules_response:
if sampling_rules_response is None or "SamplingRuleRecords" not in sampling_rules_response:
_logger.error(
"SamplingRuleRecords is missing in getSamplingRules response: %s", sampling_rules_response
)
return []

sampling_rules_records = sampling_rules_response["SamplingRuleRecords"]
for record in sampling_rules_records:
if "SamplingRule" not in record:
Expand All @@ -47,5 +48,43 @@ def get_sampling_rules(self) -> [_SamplingRule]:
_logger.error("Request error occurred: %s", req_err)
except json.JSONDecodeError as json_err:
_logger.error("Error in decoding JSON response: %s", json_err)
# pylint: disable=broad-exception-caught
except Exception as err:
_logger.error("Error occurred when attempting to fetch rules: %s", err)

return sampling_rules

def get_sampling_targets(self, statistics: [dict]) -> _SamplingTargetResponse:
    """POST sampling statistics to the sampling-targets endpoint and parse the reply.

    Args:
        statistics: statistics documents, sent as the
            ``SamplingStatisticsDocuments`` field of the JSON request body.

    Returns:
        A ``_SamplingTargetResponse`` built from the JSON reply when it contains
        both ``SamplingTargetDocuments`` and ``LastRuleModification``; otherwise
        a response whose three fields were all passed as ``None``.

    Request, JSON-decoding and any other errors are logged at debug level and
    never propagate to the caller.
    """
    # Handed back whenever the request fails or the payload cannot be used.
    fallback_response = _SamplingTargetResponse(
        LastRuleModification=None, SamplingTargetDocuments=None, UnprocessedStatistics=None
    )
    request_headers = {"content-type": "application/json"}
    request_body = {"SamplingStatisticsDocuments": statistics}
    try:
        xray_response = requests.post(
            url=self.__get_sampling_targets_endpoint,
            headers=request_headers,
            timeout=20,
            json=request_body,
        )
        if xray_response is None:
            _logger.debug("GetSamplingTargets response is None. Unable to update targets.")
            return fallback_response

        payload = xray_response.json()
        # Both keys are required before the payload is turned into a response object.
        payload_is_usable = (
            payload is not None
            and "SamplingTargetDocuments" in payload
            and "LastRuleModification" in payload
        )
        if payload_is_usable:
            return _SamplingTargetResponse(**payload)

        _logger.debug("getSamplingTargets response is invalid. Unable to update targets.")
    except requests.exceptions.RequestException as req_err:
        _logger.debug("Request error occurred: %s", req_err)
    except json.JSONDecodeError as json_err:
        _logger.debug("Error in decoding JSON response: %s", json_err)
    # pylint: disable=broad-exception-caught
    except Exception as err:
        _logger.debug("Error occurred when attempting to fetch targets: %s", err)

    return fallback_response
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import datetime


class _Clock:
    """Small wrapper around ``datetime`` so callers obtain time values from one object."""

    def __init__(self):
        # Keep a handle on the datetime class; now() is resolved through it.
        self.__datetime = datetime.datetime

    def now(self) -> datetime.datetime:
        """Return the current date and time as reported by ``datetime.now()``."""
        current = self.__datetime.now()
        return current

    # pylint: disable=no-self-use
    def from_timestamp(self, timestamp: float) -> datetime.datetime:
        """Convert a POSIX timestamp into a ``datetime`` via ``fromtimestamp``."""
        converted = datetime.datetime.fromtimestamp(timestamp)
        return converted

    def time_delta(self, seconds: float) -> datetime.timedelta:
        """Return a ``timedelta`` spanning the given number of seconds."""
        span = datetime.timedelta(seconds=seconds)
        return span

    def max(self) -> datetime.datetime:
        """Return the largest representable ``datetime``."""
        latest = datetime.datetime.max
        return latest
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Sequence

from amazon.opentelemetry.distro.sampler._clock import _Clock
from amazon.opentelemetry.distro.sampler._rate_limiting_sampler import _RateLimitingSampler
from opentelemetry.context import Context
from opentelemetry.sdk.trace.sampling import ALWAYS_ON, Sampler, SamplingResult, TraceIdRatioBased
from opentelemetry.sdk.trace.sampling import Decision, Sampler, SamplingResult, TraceIdRatioBased
from opentelemetry.trace import Link, SpanKind
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes


class _FallbackSampler(Sampler):
def __init__(self):
# TODO: Add Reservoir sampler
# pylint: disable=unused-private-member
def __init__(self, clock: _Clock):
self.__rate_limiting_sampler = _RateLimitingSampler(1, clock)
self.__fixed_rate_sampler = TraceIdRatioBased(0.05)

# pylint: disable=no-self-use
Expand All @@ -26,8 +27,12 @@ def should_sample(
links: Sequence[Link] = None,
trace_state: TraceState = None,
) -> SamplingResult:
# TODO: add reservoir + fixed rate sampling
return ALWAYS_ON.should_sample(
sampling_result = self.__rate_limiting_sampler.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)
if sampling_result.decision is not Decision.DROP:
return sampling_result
return self.__fixed_rate_sampler.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from decimal import Decimal
from threading import Lock

from amazon.opentelemetry.distro.sampler._clock import _Clock


class _RateLimiter:
    """Rate limiter that accrues spendable balance as clock time passes.

    The balance is tracked in milliseconds as the distance between "now"
    (the wallet ceiling) and a stored wallet floor, capped at
    ``MAX_BALANCE_MILLIS``. Spending moves the floor toward the ceiling.
    """

    def __init__(self, max_balance_in_seconds: int, quota: int, clock: "_Clock"):
        """Create a limiter whose balance starts at zero.

        Args:
            max_balance_in_seconds: cap on the accrued balance, in seconds
                (usually 1).
            quota: number of cost units that one second of balance pays for.
            clock: object whose ``now()`` returns a ``datetime``.
        """
        # pylint: disable=invalid-name
        self.MAX_BALANCE_MILLIS = Decimal(max_balance_in_seconds * 1000.0)
        self._clock = clock

        self._quota = Decimal(quota)
        # Floor starts at "now", so the initial balance (ceiling - floor) is zero.
        self.__wallet_floor_millis = Decimal(self._clock.now().timestamp() * 1000.0)

        self.__lock = Lock()

    def try_spend(self, cost: float) -> bool:
        """Try to deduct ``cost`` units from the accrued balance.

        Args:
            cost: number of quota units to spend.

        Returns:
            True if the balance covered the cost (the wallet is updated),
            False otherwise (the wallet is left untouched). A zero or
            negative quota never grants a spend.
        """
        # A non-positive quota has no budget. Guarding only `== 0` would let a
        # negative quota produce a negative cost, approving every call and
        # growing the balance instead of consuming it.
        if self._quota <= 0:
            return False

        quota_per_millis = self._quota / Decimal(1000.0)

        # Division is safe: quota is strictly positive at this point.
        cost_in_millis = Decimal(cost) / quota_per_millis

        with self.__lock:
            wallet_ceiling_millis = Decimal(self._clock.now().timestamp() * 1000.0)
            # Balance is the elapsed time since the floor, capped at the maximum.
            current_balance_millis = min(
                wallet_ceiling_millis - self.__wallet_floor_millis, self.MAX_BALANCE_MILLIS
            )

            pending_remaining_balance_millis = current_balance_millis - cost_in_millis
            if pending_remaining_balance_millis >= 0:
                # Re-anchor the floor so exactly the remaining balance is left.
                self.__wallet_floor_millis = wallet_ceiling_millis - pending_remaining_balance_millis
                return True
            # No changes to the wallet state
            return False
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Sequence

from amazon.opentelemetry.distro.sampler._clock import _Clock
from amazon.opentelemetry.distro.sampler._rate_limiter import _RateLimiter
from opentelemetry.context import Context
from opentelemetry.sdk.trace.sampling import Decision, Sampler, SamplingResult
from opentelemetry.trace import Link, SpanKind
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes


class _RateLimitingSampler(Sampler):
    """Sampler that records a span only while its rate limiter still has budget."""

    def __init__(self, quota: int, clock: _Clock):
        """Create the sampler.

        Args:
            quota: quota handed to the underlying ``_RateLimiter`` and reported
                by ``get_description`` as requests per second.
            clock: clock passed through to the ``_RateLimiter``.
        """
        self.__quota = quota
        # The reservoir is a rate limiter with a maximum balance of 1 second.
        self.__reservoir = _RateLimiter(1, quota, clock)

    # pylint: disable=no-self-use
    def should_sample(
        self,
        parent_context: Optional[Context],
        trace_id: int,
        name: str,
        kind: SpanKind = None,
        attributes: Attributes = None,
        links: Sequence[Link] = None,
        trace_state: TraceState = None,
    ) -> SamplingResult:
        """Return RECORD_AND_SAMPLE if one unit can be spent from the reservoir, else DROP.

        ``attributes`` and ``trace_state`` are passed through to the result
        unchanged; the remaining arguments are not consulted.
        """
        if self.__reservoir.try_spend(1):
            return SamplingResult(decision=Decision.RECORD_AND_SAMPLE, attributes=attributes, trace_state=trace_state)
        return SamplingResult(decision=Decision.DROP, attributes=attributes, trace_state=trace_state)

    # pylint: disable=no-self-use
    def get_description(self) -> str:
        """Return a human-readable description that includes the configured quota."""
        # The quota is an int; it must be converted explicitly because
        # concatenating str + int raises TypeError.
        description = (
            "RateLimitingSampler{rate limiting sampling with sampling config of "
            + str(self.__quota)
            + " req/sec and 0% of additional requests}"
        )
        return description
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime
from logging import getLogger
from threading import Lock
from typing import Optional, Sequence
from typing import Dict, Optional, Sequence

from amazon.opentelemetry.distro.sampler._clock import _Clock
from amazon.opentelemetry.distro.sampler._fallback_sampler import _FallbackSampler
from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
from amazon.opentelemetry.distro.sampler._sampling_rule_applier import _SamplingRuleApplier
from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTarget, _SamplingTargetResponse
from opentelemetry.context import Context
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.sampling import SamplingResult
Expand All @@ -18,16 +19,20 @@
_logger = getLogger(__name__)

CACHE_TTL_SECONDS = 3600
DEFAULT_TARGET_POLLING_INTERVAL_SECONDS = 10


class _RuleCache:
def __init__(self, resource: Resource, fallback_sampler: _FallbackSampler, date_time: datetime, lock: Lock):
def __init__(
self, resource: Resource, fallback_sampler: _FallbackSampler, client_id: str, clock: _Clock, lock: Lock
):
self.__client_id = client_id
self.__rule_appliers: [_SamplingRuleApplier] = []
self.__cache_lock = lock
self.__resource = resource
self._fallback_sampler = fallback_sampler
self._date_time = date_time
self._last_modified = self._date_time.datetime.now()
self._clock = clock
self._last_modified = self._clock.now()

def should_sample(
self,
Expand All @@ -39,6 +44,7 @@ def should_sample(
links: Sequence[Link] = None,
trace_state: TraceState = None,
) -> SamplingResult:
rule_applier: _SamplingRuleApplier
for rule_applier in self.__rule_appliers:
if rule_applier.matches(self.__resource, attributes):
return rule_applier.should_sample(
Expand All @@ -51,6 +57,8 @@ def should_sample(
trace_state=trace_state,
)

_logger.debug("No sampling rules were matched")
# Should not ever reach fallback sampler as default rule is able to match
return self._fallback_sampler.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)
Expand All @@ -65,28 +73,70 @@ def update_sampling_rules(self, new_sampling_rules: [_SamplingRule]) -> None:
if sampling_rule.Version != 1:
_logger.debug("sampling rule without Version 1 is not supported: RuleName: %s", sampling_rule.RuleName)
continue
temp_rule_appliers.append(_SamplingRuleApplier(sampling_rule))
temp_rule_appliers.append(_SamplingRuleApplier(sampling_rule, self.__client_id, self._clock))

self.__cache_lock.acquire()

# map list of rule appliers by each applier's sampling_rule name
rule_applier_map = {rule.sampling_rule.RuleName: rule for rule in self.__rule_appliers}
rule_applier_map: Dict[str, _SamplingRuleApplier] = {
applier.sampling_rule.RuleName: applier for applier in self.__rule_appliers
}

# If a sampling rule has not changed, keep its respective applier in the cache.
new_applier: _SamplingRuleApplier
for index, new_applier in enumerate(temp_rule_appliers):
rule_name_to_check = new_applier.sampling_rule.RuleName
if rule_name_to_check in rule_applier_map:
old_applier = rule_applier_map[rule_name_to_check]
if new_applier.sampling_rule == old_applier.sampling_rule:
temp_rule_appliers[index] = old_applier
self.__rule_appliers = temp_rule_appliers
self._last_modified = datetime.datetime.now()
self._last_modified = self._clock.now()

self.__cache_lock.release()

def update_sampling_targets(self, sampling_targets_response: _SamplingTargetResponse) -> (bool, int):
    """Apply fetched sampling targets to the cached rule appliers.

    Each cached applier whose rule name appears in the response is replaced by
    ``applier.with_target(target)``; appliers with no matching target are kept.

    Args:
        sampling_targets_response: response carrying ``SamplingTargetDocuments``
            and ``LastRuleModification``.

    Returns:
        A tuple ``(refresh_rules, next_polling_interval)``. ``refresh_rules`` is
        True when the response's last rule modification is later than the
        cache's last-modified time. ``next_polling_interval`` is the smallest
        non-None ``Interval`` among the applied targets, or
        ``DEFAULT_TARGET_POLLING_INTERVAL_SECONDS`` when there is none.
    """
    targets: [_SamplingTarget] = sampling_targets_response.SamplingTargetDocuments

    with self.__cache_lock:
        # Index the targets by rule name for direct lookup per applier.
        targets_by_rule: Dict[str, _SamplingTarget] = {}
        for target in targets:
            targets_by_rule[target.RuleName] = target

        reported_intervals = []
        refreshed_appliers = []
        applier: _SamplingRuleApplier
        for applier in self.__rule_appliers:
            rule_name = applier.sampling_rule.RuleName
            if rule_name not in targets_by_rule:
                # No target for this rule: keep the applier as it is.
                refreshed_appliers.append(applier)
                continue

            matched_target = targets_by_rule[rule_name]
            refreshed_appliers.append(applier.with_target(matched_target))
            if matched_target.Interval is not None:
                reported_intervals.append(matched_target.Interval)

        self.__rule_appliers = refreshed_appliers

        # Poll at the shortest interval any applied target asked for.
        if reported_intervals:
            next_polling_interval = min(reported_intervals)
        else:
            next_polling_interval = DEFAULT_TARGET_POLLING_INTERVAL_SECONDS

        last_rule_modification = self._clock.from_timestamp(sampling_targets_response.LastRuleModification)
        refresh_rules = last_rule_modification > self._last_modified

        return (refresh_rules, next_polling_interval)

def get_all_statistics(self) -> [dict]:
    """Return the result of ``get_then_reset_statistics()`` for every cached rule applier, in cache order."""
    return [rule_applier.get_then_reset_statistics() for rule_applier in self.__rule_appliers]

def expired(self) -> bool:
self.__cache_lock.acquire()
try:
return datetime.datetime.now() > self._last_modified + datetime.timedelta(seconds=CACHE_TTL_SECONDS)
return self._clock.now() > self._last_modified + self._clock.time_delta(seconds=CACHE_TTL_SECONDS)
finally:
self.__cache_lock.release()
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
self.URLPath = URLPath if URLPath is not None else ""
self.Version = Version if Version is not None else 0

def __lt__(self, other) -> bool:
def __lt__(self, other: "_SamplingRule") -> bool:
if self.Priority == other.Priority:
# String order priority example:
# "A","Abc","a","ab","abc","abcdef"
Expand Down
Loading