Skip to content

Commit b81f2f9

Browse files
XinRanZhAWSADOT Patch workflowthpierce
authored
Implement Contract Test for Psychopg2 (#72)
*Description of changes:* Implement Contract Test for Psychopg2 By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. --------- Co-authored-by: ADOT Patch workflow <[email protected]> Co-authored-by: Thomas Pierce <[email protected]>
1 parent 7d255d1 commit b81f2f9

File tree

10 files changed

+362
-60
lines changed

10 files changed

+362
-60
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Meant to be run from aws-otel-python-instrumentation/contract-tests.
2+
# Assumes existence of dist/aws_opentelemetry_distro-<pkg_version>-py3-none-any.whl.
3+
# Assumes filename of aws_opentelemetry_distro-<pkg_version>-py3-none-any.whl is passed in as "DISTRO" arg.
4+
FROM python:3.10
5+
WORKDIR /psycopg2
6+
COPY ./dist/$DISTRO /psycopg2
7+
COPY ./contract-tests/images/applications/psycopg2 /psycopg2
8+
9+
ENV PIP_ROOT_USER_ACTION=ignore
10+
ARG DISTRO
11+
RUN pip install --upgrade pip && pip install -r requirements.txt && pip install ${DISTRO} --force-reinstall
12+
RUN opentelemetry-bootstrap -a install
13+
14+
# Without `-u`, logs will be buffered and `wait_for_logs` will never return.
15+
CMD ["opentelemetry-instrument", "python", "-u", "./psycopg2_server.py"]
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
import atexit
4+
import os
5+
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
6+
from threading import Thread
7+
from typing import Tuple
8+
9+
import psycopg2
10+
from typing_extensions import override
11+
12+
_PORT: int = 8080
13+
_SUCCESS: str = "success"
14+
_ERROR: str = "error"
15+
_FAULT: str = "fault"
16+
17+
_DB_HOST = os.getenv("DB_HOST")
18+
_DB_USER = os.getenv("DB_USER")
19+
_DB_PASS = os.getenv("DB_PASS")
20+
_DB_NAME = os.getenv("DB_NAME")
21+
22+
23+
def prepare_database() -> None:
24+
conn = psycopg2.connect(dbname=_DB_NAME, user=_DB_USER, password=_DB_PASS, host=_DB_HOST)
25+
cur = conn.cursor()
26+
cur.execute("DROP TABLE IF EXISTS test_table")
27+
cur.execute(
28+
"""
29+
CREATE TABLE test_table (
30+
id SERIAL PRIMARY KEY,
31+
name TEXT NOT NULL
32+
)
33+
"""
34+
)
35+
36+
cur.execute("INSERT INTO test_table (name) VALUES (%s)", ("Alice",))
37+
cur.execute("INSERT INTO test_table (name) VALUES (%s)", ("Bob",))
38+
39+
conn.commit()
40+
41+
cur.close()
42+
conn.close()
43+
44+
45+
class RequestHandler(BaseHTTPRequestHandler):
46+
@override
47+
# pylint: disable=invalid-name
48+
def do_GET(self):
49+
status_code: int = 200
50+
conn = psycopg2.connect(dbname=_DB_NAME, user=_DB_USER, password=_DB_PASS, host=_DB_HOST)
51+
if self.in_path(_SUCCESS):
52+
cur = conn.cursor()
53+
cur.execute("SELECT id, name FROM test_table")
54+
rows = cur.fetchall()
55+
cur.close()
56+
if len(rows) == 2:
57+
status_code = 200
58+
else:
59+
status_code = 400
60+
elif self.in_path(_FAULT):
61+
cur = conn.cursor()
62+
try:
63+
cur.execute("SELECT DISTINCT id, name FROM invalid_table")
64+
except psycopg2.ProgrammingError as exception:
65+
print("Expected Exception with Invalid SQL occurred:", exception)
66+
status_code = 500
67+
except Exception as exception: # pylint: disable=broad-except
68+
print("Exception Occurred:", exception)
69+
else:
70+
status_code = 200
71+
finally:
72+
cur.close()
73+
else:
74+
status_code = 404
75+
conn.close()
76+
self.send_response_only(status_code)
77+
self.end_headers()
78+
79+
def in_path(self, sub_path: str):
80+
return sub_path in self.path
81+
82+
83+
def main() -> None:
84+
prepare_database()
85+
server_address: Tuple[str, int] = ("0.0.0.0", _PORT)
86+
request_handler_class: type = RequestHandler
87+
requests_server: ThreadingHTTPServer = ThreadingHTTPServer(server_address, request_handler_class)
88+
atexit.register(requests_server.shutdown)
89+
server_thread: Thread = Thread(target=requests_server.serve_forever)
90+
server_thread.start()
91+
print("Ready")
92+
server_thread.join()
93+
94+
95+
if __name__ == "__main__":
96+
main()
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
[project]
2+
name = "psycopg2-server"
3+
description = "Simple server that relies on psycopg2 library"
4+
version = "1.0.0"
5+
license = "Apache-2.0"
6+
requires-python = ">=3.8"
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
opentelemetry-distro==0.43b0
2+
opentelemetry-exporter-otlp-proto-grpc==1.22.0
3+
typing-extensions==4.9.0
4+
psycopg2==2.9.9

contract-tests/images/applications/requests/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Meant to be run from aws-otel-python-instrumentation/contract-tests.
22
# Assumes existence of dist/aws_opentelemetry_distro-<pkg_version>-py3-none-any.whl.
33
# Assumes filename of aws_opentelemetry_distro-<pkg_version>-py3-none-any.whl is passed in as "DISTRO" arg.
4-
FROM public.ecr.aws/docker/library/python:3.11-slim
4+
FROM python:3.10
55
WORKDIR /requests
66
COPY ./dist/$DISTRO /requests
77
COPY ./contract-tests/images/applications/requests /requests

contract-tests/images/mock-collector/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM public.ecr.aws/docker/library/python:3.11-slim
1+
FROM python:3.10
22
WORKDIR /mock-collector
33
COPY . /mock-collector
44

contract-tests/tests/test/amazon/base/contract_test_base.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77
from docker import DockerClient
88
from docker.models.networks import Network, NetworkCollection
99
from docker.types import EndpointConfig
10-
from mock_collector_client import MockCollectorClient
10+
from mock_collector_client import MockCollectorClient, ResourceScopeMetric, ResourceScopeSpan
11+
from requests import Response, request
1112
from testcontainers.core.container import DockerContainer
1213
from testcontainers.core.waiting_utils import wait_for_logs
1314
from typing_extensions import override
1415

16+
from amazon.utils.app_signals_constants import ERROR_METRIC, FAULT_METRIC, LATENCY_METRIC
17+
from opentelemetry.proto.common.v1.common_pb2 import AnyValue
18+
1519
NETWORK_NAME: str = "aws-appsignals-network"
1620

1721
_logger: Logger = getLogger(__name__)
@@ -115,6 +119,45 @@ def tear_down(self) -> None:
115119

116120
self.mock_collector_client.clear_signals()
117121

122+
def do_test_requests(
123+
self, path: str, method: str, status_code: int, expected_error: int, expected_fault: int, **kwargs
124+
) -> None:
125+
address: str = self.application.get_container_host_ip()
126+
port: str = self.application.get_exposed_port(self.get_application_port())
127+
url: str = f"http://{address}:{port}/{path}"
128+
response: Response = request(method, url, timeout=20)
129+
130+
self.assertEqual(status_code, response.status_code)
131+
132+
resource_scope_spans: List[ResourceScopeSpan] = self.mock_collector_client.get_traces()
133+
self._assert_aws_span_attributes(resource_scope_spans, path, **kwargs)
134+
self._assert_semantic_conventions_span_attributes(resource_scope_spans, method, path, status_code, **kwargs)
135+
136+
metrics: List[ResourceScopeMetric] = self.mock_collector_client.get_metrics(
137+
{LATENCY_METRIC, ERROR_METRIC, FAULT_METRIC}
138+
)
139+
self._assert_metric_attributes(metrics, LATENCY_METRIC, 5000, **kwargs)
140+
self._assert_metric_attributes(metrics, ERROR_METRIC, expected_error, **kwargs)
141+
self._assert_metric_attributes(metrics, FAULT_METRIC, expected_fault, **kwargs)
142+
143+
def _assert_str_attribute(self, attributes_dict: Dict[str, AnyValue], key: str, expected_value: str):
144+
self.assertIn(key, attributes_dict)
145+
actual_value: AnyValue = attributes_dict[key]
146+
self.assertIsNotNone(actual_value)
147+
self.assertEqual(expected_value, actual_value.string_value)
148+
149+
def _assert_int_attribute(self, attributes_dict: Dict[str, AnyValue], key: str, expected_value: int) -> None:
150+
self.assertIn(key, attributes_dict)
151+
actual_value: AnyValue = attributes_dict[key]
152+
self.assertIsNotNone(actual_value)
153+
self.assertEqual(expected_value, actual_value.int_value)
154+
155+
def check_sum(self, metric_name: str, actual_sum: float, expected_sum: float) -> None:
156+
if metric_name is LATENCY_METRIC:
157+
self.assertTrue(0 < actual_sum < expected_sum)
158+
else:
159+
self.assertEqual(actual_sum, expected_sum)
160+
118161
# pylint: disable=no-self-use
119162
# Methods that should be overridden in subclasses
120163
@classmethod
@@ -145,3 +188,16 @@ def get_application_otel_service_name(self) -> str:
145188

146189
def get_application_otel_resource_attributes(self) -> str:
147190
return "service.name=" + self.get_application_otel_service_name()
191+
192+
def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs):
193+
self.fail("Tests must implement this function")
194+
195+
def _assert_semantic_conventions_span_attributes(
196+
self, resource_scope_spans: List[ResourceScopeSpan], method: str, path: str, status_code: int, **kwargs
197+
):
198+
self.fail("Tests must implement this function")
199+
200+
def _assert_metric_attributes(
201+
self, resource_scope_metrics: List[ResourceScopeMetric], metric_name: str, expected_sum: int, **kwargs
202+
):
203+
self.fail("Tests must implement this function")
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
from typing import Dict, List
4+
5+
from mock_collector_client import ResourceScopeMetric, ResourceScopeSpan
6+
from testcontainers.postgres import PostgresContainer
7+
from typing_extensions import override
8+
9+
from amazon.base.contract_test_base import NETWORK_NAME, ContractTestBase
10+
from amazon.utils.app_signals_constants import (
11+
AWS_LOCAL_OPERATION,
12+
AWS_LOCAL_SERVICE,
13+
AWS_REMOTE_OPERATION,
14+
AWS_REMOTE_SERVICE,
15+
AWS_SPAN_KIND,
16+
)
17+
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
18+
from opentelemetry.proto.metrics.v1.metrics_pb2 import ExponentialHistogramDataPoint, Metric
19+
from opentelemetry.proto.trace.v1.trace_pb2 import Span
20+
from opentelemetry.trace import StatusCode
21+
22+
23+
class Psycopg2Test(ContractTestBase):
24+
@override
25+
@classmethod
26+
def set_up_dependency_container(cls) -> None:
27+
cls.container = (
28+
PostgresContainer(user="dbuser", password="example", dbname="postgres")
29+
.with_kwargs(network=NETWORK_NAME)
30+
.with_name("mydb")
31+
)
32+
cls.container.start()
33+
34+
@override
35+
@classmethod
36+
def tear_down_dependency_container(cls) -> None:
37+
cls.container.stop()
38+
39+
@override
40+
def get_application_extra_environment_variables(self) -> Dict[str, str]:
41+
return {
42+
"DB_HOST": "mydb",
43+
"DB_USER": "dbuser",
44+
"DB_PASS": "example",
45+
"DB_NAME": "postgres",
46+
}
47+
48+
@override
49+
def get_application_image_name(self) -> str:
50+
return "aws-appsignals-tests-psycopg2-app"
51+
52+
def test_success(self) -> None:
53+
self.mock_collector_client.clear_signals()
54+
self.do_test_requests("success", "GET", 200, 0, 0, sql_command="SELECT")
55+
56+
def test_fault(self) -> None:
57+
self.mock_collector_client.clear_signals()
58+
self.do_test_requests("fault", "GET", 500, 0, 1, sql_command="SELECT DISTINCT")
59+
60+
@override
61+
def _assert_aws_span_attributes(self, resource_scope_spans: List[ResourceScopeSpan], path: str, **kwargs) -> None:
62+
target_spans: List[Span] = []
63+
for resource_scope_span in resource_scope_spans:
64+
# pylint: disable=no-member
65+
if resource_scope_span.span.kind == Span.SPAN_KIND_CLIENT:
66+
target_spans.append(resource_scope_span.span)
67+
68+
self.assertEqual(len(target_spans), 1)
69+
self._assert_aws_attributes(target_spans[0].attributes, **kwargs)
70+
71+
@override
72+
def _assert_aws_attributes(self, attributes_list: List[KeyValue], **kwargs) -> None:
73+
attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(attributes_list)
74+
self._assert_str_attribute(attributes_dict, AWS_LOCAL_SERVICE, self.get_application_otel_service_name())
75+
# InternalOperation as OTEL does not instrument the basic server we are using, so the client span is a local
76+
# root.
77+
self._assert_str_attribute(attributes_dict, AWS_LOCAL_OPERATION, "InternalOperation")
78+
self._assert_str_attribute(attributes_dict, AWS_REMOTE_SERVICE, "postgresql")
79+
command: str = kwargs.get("sql_command")
80+
self._assert_str_attribute(attributes_dict, AWS_REMOTE_OPERATION, f"{command}")
81+
# See comment above AWS_LOCAL_OPERATION
82+
self._assert_str_attribute(attributes_dict, AWS_SPAN_KIND, "LOCAL_ROOT")
83+
84+
def _get_attributes_dict(self, attributes_list: List[KeyValue]) -> Dict[str, AnyValue]:
85+
attributes_dict: Dict[str, AnyValue] = {}
86+
for attribute in attributes_list:
87+
key: str = attribute.key
88+
value: AnyValue = attribute.value
89+
if key in attributes_dict:
90+
old_value: AnyValue = attributes_dict[key]
91+
self.fail(f"Attribute {key} unexpectedly duplicated. Value 1: {old_value} Value 2: {value}")
92+
attributes_dict[key] = value
93+
return attributes_dict
94+
95+
@override
96+
def _assert_semantic_conventions_span_attributes(
97+
self, resource_scope_spans: List[ResourceScopeSpan], method: str, path: str, status_code: int, **kwargs
98+
) -> None:
99+
target_spans: List[Span] = []
100+
for resource_scope_span in resource_scope_spans:
101+
# pylint: disable=no-member
102+
if resource_scope_span.span.kind == Span.SPAN_KIND_CLIENT:
103+
target_spans.append(resource_scope_span.span)
104+
105+
self.assertEqual(target_spans[0].name, kwargs.get("sql_command").split()[0])
106+
if status_code == 200:
107+
self.assertEqual(target_spans[0].status.code, StatusCode.UNSET.value)
108+
else:
109+
self.assertEqual(target_spans[0].status.code, StatusCode.ERROR.value)
110+
111+
self._assert_semantic_conventions_attributes(target_spans[0].attributes, kwargs.get("sql_command"))
112+
113+
def _assert_semantic_conventions_attributes(self, attributes_list: List[KeyValue], command: str) -> None:
114+
attributes_dict: Dict[str, AnyValue] = self._get_attributes_dict(attributes_list)
115+
self.assertTrue(attributes_dict.get("db.statement").string_value.startswith(command))
116+
self._assert_str_attribute(attributes_dict, "db.system", "postgresql")
117+
self._assert_str_attribute(attributes_dict, "db.name", "postgres")
118+
self.assertTrue("db.operation" not in attributes_dict)
119+
120+
@override
121+
def _assert_metric_attributes(
122+
self, resource_scope_metrics: List[ResourceScopeMetric], metric_name: str, expected_sum: int, **kwargs
123+
) -> None:
124+
target_metrics: List[Metric] = []
125+
for resource_scope_metric in resource_scope_metrics:
126+
if resource_scope_metric.metric.name.lower() == metric_name.lower():
127+
target_metrics.append(resource_scope_metric.metric)
128+
129+
self.assertEqual(len(target_metrics), 1)
130+
target_metric: Metric = target_metrics[0]
131+
dp_list: List[ExponentialHistogramDataPoint] = target_metric.exponential_histogram.data_points
132+
133+
self.assertEqual(len(dp_list), 2)
134+
dependency_dp: ExponentialHistogramDataPoint = dp_list[0]
135+
service_dp: ExponentialHistogramDataPoint = dp_list[1]
136+
if len(dp_list[1].attributes) > len(dp_list[0].attributes):
137+
dependency_dp = dp_list[1]
138+
service_dp = dp_list[0]
139+
attribute_dict: Dict[str, AnyValue] = self._get_attributes_dict(dependency_dp.attributes)
140+
self._assert_str_attribute(attribute_dict, AWS_LOCAL_SERVICE, self.get_application_otel_service_name())
141+
# See comment on AWS_LOCAL_OPERATION in _assert_aws_attributes
142+
self._assert_str_attribute(attribute_dict, AWS_LOCAL_OPERATION, "InternalOperation")
143+
self._assert_str_attribute(attribute_dict, AWS_REMOTE_SERVICE, "postgresql")
144+
self._assert_str_attribute(attribute_dict, AWS_REMOTE_OPERATION, kwargs.get("sql_command"))
145+
self._assert_str_attribute(attribute_dict, AWS_SPAN_KIND, "CLIENT")
146+
self.check_sum(metric_name, dependency_dp.sum, expected_sum)
147+
148+
attribute_dict: Dict[str, AnyValue] = self._get_attributes_dict(service_dp.attributes)
149+
# See comment on AWS_LOCAL_OPERATION in _assert_aws_attributes
150+
self._assert_str_attribute(attribute_dict, AWS_LOCAL_OPERATION, "InternalOperation")
151+
self._assert_str_attribute(attribute_dict, AWS_SPAN_KIND, "LOCAL_ROOT")
152+
self.check_sum(metric_name, service_dp.sum, expected_sum)

0 commit comments

Comments
 (0)