Skip to content

Commit 5538a92

Browse files
Using generics types + enabling ruff
1 parent 63f3d22 commit 5538a92

File tree

5 files changed

+30
-36
lines changed

5 files changed

+30
-36
lines changed

tests/e2e/parser/handlers/handler_with_basic_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@ class BasicModel(BaseModel):
1515
version: str
1616

1717

18-
@event_parser
18+
@event_parser(model=BasicModel)
1919
def lambda_handler(event: BasicModel, context: LambdaContext):
2020
return {"product": event.product}

tests/e2e/parser/handlers/handler_with_dataclass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@ class BasicDataclass:
1515
version: str
1616

1717

18-
@event_parser
18+
@event_parser(model=BasicDataclass)
1919
def lambda_handler(event: BasicDataclass, context: LambdaContext):
2020
return {"product": event.product}

tests/e2e/parser/handlers/handler_with_union_tag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,6 @@ class PartialFailureCallback(BaseModel):
3030
OrderCallback = Annotated[Union[SuccessCallback, ErrorCallback, PartialFailureCallback], Field(discriminator="status")]
3131

3232

33-
@event_parser
33+
@event_parser(model=OrderCallback)
3434
def lambda_handler(event: OrderCallback, context: LambdaContext):
3535
return {"error_msg": event.error_msg}

tests/e2e/utils/data_fetcher/common.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
from __future__ import annotations
2-
31
import functools
42
import time
53
from concurrent.futures import Future, ThreadPoolExecutor
64
from datetime import datetime
5+
from typing import List, Optional, Tuple
76

87
import boto3
98
import requests
@@ -14,13 +13,13 @@
1413
from requests.exceptions import RequestException
1514
from retry import retry
1615

17-
GetLambdaResponse = tuple[InvocationResponseTypeDef, datetime]
16+
GetLambdaResponse = Tuple[InvocationResponseTypeDef, datetime]
1817

1918

2019
class GetLambdaResponseOptions(BaseModel):
2120
lambda_arn: str
22-
payload: str | None = None
23-
client: LambdaClient | None = None
21+
payload: Optional[str] = None
22+
client: Optional[LambdaClient] = None
2423
raise_on_error: bool = True
2524

2625
model_config = ConfigDict(
@@ -30,8 +29,8 @@ class GetLambdaResponseOptions(BaseModel):
3029

3130
def get_lambda_response(
3231
lambda_arn: str,
33-
payload: str | None = None,
34-
client: LambdaClient | None = None,
32+
payload: Optional[str] = None,
33+
client: Optional[LambdaClient] = None,
3534
raise_on_error: bool = True,
3635
) -> GetLambdaResponse:
3736
"""Invoke function synchronously
@@ -83,8 +82,8 @@ def get_http_response(request: Request) -> Response:
8382

8483

8584
def get_lambda_response_in_parallel(
86-
get_lambda_response_options: list[GetLambdaResponseOptions],
87-
) -> list[GetLambdaResponse]:
85+
get_lambda_response_options: List[GetLambdaResponseOptions],
86+
) -> List[GetLambdaResponse]:
8887
"""Invoke functions in parallel
8988
9089
Parameters
@@ -99,7 +98,7 @@ def get_lambda_response_in_parallel(
9998
"""
10099
result_list = []
101100
with ThreadPoolExecutor() as executor:
102-
running_tasks: list[Future] = []
101+
running_tasks: List[Future] = []
103102
for options in get_lambda_response_options:
104103
# Sleep 0.5, 1, 1.5, ... seconds between each invocation. This way
105104
# we can guarantee that lambdas are executed in parallel, but they are

tests/e2e/utils/data_fetcher/logs.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,34 @@
1-
from __future__ import annotations
2-
31
import json
4-
from typing import TYPE_CHECKING
2+
from datetime import datetime
3+
from typing import List, Optional, Union
54

65
import boto3
6+
from mypy_boto3_logs.client import CloudWatchLogsClient
77
from pydantic import BaseModel
88
from retry import retry
99

10-
if TYPE_CHECKING:
11-
from datetime import datetime
12-
13-
from mypy_boto3_logs.client import CloudWatchLogsClient
14-
1510

1611
class Log(BaseModel, extra="allow"):
1712
level: str
1813
location: str
19-
message: dict | str
14+
message: Union[dict, str]
2015
timestamp: str
2116
service: str
22-
cold_start: bool | None = None
23-
function_name: str | None = None
24-
function_memory_size: str | None = None
25-
function_arn: str | None = None
26-
function_request_id: str | None = None
27-
xray_trace_id: str | None = None
17+
cold_start: Optional[bool] = None
18+
function_name: Optional[str] = None
19+
function_memory_size: Optional[str] = None
20+
function_arn: Optional[str] = None
21+
function_request_id: Optional[str] = None
22+
xray_trace_id: Optional[str] = None
2823

2924

3025
class LogFetcher:
3126
def __init__(
3227
self,
3328
function_name: str,
3429
start_time: datetime,
35-
log_client: CloudWatchLogsClient | None = None,
36-
filter_expression: str | None = None,
30+
log_client: Optional[CloudWatchLogsClient] = None,
31+
filter_expression: Optional[str] = None,
3732
minimum_log_entries: int = 1,
3833
):
3934
"""Fetch and expose Powertools for AWS Lambda (Python) Logger logs from CloudWatch Logs
@@ -57,9 +52,9 @@ def __init__(
5752
self.filter_expression = filter_expression or "message" # Logger message key
5853
self.log_group = f"/aws/lambda/{self.function_name}"
5954
self.minimum_log_entries = minimum_log_entries
60-
self.logs: list[Log] = self._get_logs()
55+
self.logs: List[Log] = self._get_logs()
6156

62-
def get_log(self, key: str, value: any | None = None) -> list[Log]:
57+
def get_log(self, key: str, value: Optional[any] = None) -> List[Log]:
6358
"""Get logs based on key or key and value
6459
6560
Parameters
@@ -83,7 +78,7 @@ def get_log(self, key: str, value: any | None = None) -> list[Log]:
8378
logs.append(log)
8479
return logs
8580

86-
def get_cold_start_log(self) -> list[Log]:
81+
def get_cold_start_log(self) -> List[Log]:
8782
"""Get logs where cold start was true
8883
8984
Returns
@@ -103,7 +98,7 @@ def have_keys(self, *keys) -> bool:
10398
"""
10499
return all(hasattr(log, key) for log in self.logs for key in keys)
105100

106-
def _get_logs(self) -> list[Log]:
101+
def _get_logs(self) -> List[Log]:
107102
ret = self.log_client.filter_log_events(
108103
logGroupName=self.log_group,
109104
startTime=self.start_time,
@@ -137,8 +132,8 @@ def get_logs(
137132
function_name: str,
138133
start_time: datetime,
139134
minimum_log_entries: int = 1,
140-
filter_expression: str | None = None,
141-
log_client: CloudWatchLogsClient | None = None,
135+
filter_expression: Optional[str] = None,
136+
log_client: Optional[CloudWatchLogsClient] = None,
142137
) -> LogFetcher:
143138
"""_summary_
144139

0 commit comments

Comments
 (0)