Skip to content

Commit 3b1888b

Browse files
committed
chore: resolve git comments
1 parent 2635dd1 commit 3b1888b

File tree

5 files changed

+295
-105
lines changed

5 files changed

+295
-105
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,16 +153,18 @@ def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = Non
153153

154154
@staticmethod
155155
def get_manifest(
156-
cache_kwargs: Dict[str, Any] = None, region: Optional[str] = None
156+
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
157157
) -> List[JumpStartModelHeader]:
158158
"""Return entire JumpStart models manifest.
159159
160160
Raises:
161161
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
162162
163163
Args:
164-
cache_kwargs (str): cache kwargs to use.
165-
region (str): The region to use for the cache.
164+
cache_kwargs (Dict[str, Any]): Optional. Cache kwargs to use.
165+
(Default: None).
166+
region (str): Optional. The region to use for the cache.
167+
(Default: None).
166168
"""
167169
cache_kwargs_dict: Dict[str, Any] = {} if cache_kwargs is None else cache_kwargs
168170
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)

src/sagemaker/jumpstart/filters.py

Lines changed: 145 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@
1414
from __future__ import absolute_import
1515
from ast import literal_eval
1616
from enum import Enum
17-
from typing import List, Union, Any
17+
from typing import Dict, List, Union, Any
1818

1919
from sagemaker.jumpstart.types import JumpStartDataHolderType
2020

2121

2222
class BooleanValues(str, Enum):
23-
"""Enum class for boolean values."""
23+
"""Enum class for boolean values.
24+
25+
This is a status value that an ``Operand`` can resolve to.
26+
"""
2427

2528
TRUE = "true"
2629
FALSE = "false"
@@ -37,6 +40,14 @@ class FilterOperators(str, Enum):
3740
NOT_IN = "not_in"
3841

3942

43+
class SpecialSupportedFilterKeys(str, Enum):
44+
"""Enum class for special supported filter keys."""
45+
46+
TASK = "task"
47+
FRAMEWORK = "framework"
48+
SUPPORTED_MODEL = "supported_model"
49+
50+
4051
FILTER_OPERATOR_STRING_MAPPINGS = {
4152
FilterOperators.EQUALS: ["===", "==", "equals", "is"],
4253
FilterOperators.NOT_EQUALS: ["!==", "!=", "not equals", "is not"],
@@ -45,14 +56,37 @@ class FilterOperators(str, Enum):
4556
}
4657

4758

59+
_PAD_ALPHABETIC_OPERATOR = (
60+
lambda operator: f" {operator} "
61+
if any(character.isalpha() for character in operator)
62+
else operator
63+
)
64+
65+
ACCEPTABLE_OPERATORS_IN_PARSE_ORDER = (
66+
list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]))
67+
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN]))
68+
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS]))
69+
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]))
70+
)
71+
72+
73+
SPECIAL_SUPPORTED_FILTER_KEYS = set(
74+
[
75+
SpecialSupportedFilterKeys.TASK,
76+
SpecialSupportedFilterKeys.FRAMEWORK,
77+
SpecialSupportedFilterKeys.SUPPORTED_MODEL,
78+
]
79+
)
80+
81+
4882
class Operand:
4983
"""Operand class for filtering JumpStart content."""
5084

5185
def __init__(
5286
self, unresolved_value: Any, resolved_value: BooleanValues = BooleanValues.UNEVALUATED
5387
):
54-
self.unresolved_value = unresolved_value
55-
self.resolved_value = resolved_value
88+
self.unresolved_value: Any = unresolved_value
89+
self._resolved_value: BooleanValues = resolved_value
5690

5791
def __iter__(self) -> Any:
5892
"""Returns an iterator."""
@@ -62,9 +96,32 @@ def eval(self) -> None:
6296
"""Evaluates operand."""
6397
return
6498

99+
@property
100+
def resolved_value(self):
101+
"""Getter method for resolved_value."""
102+
return self._resolved_value
103+
104+
@resolved_value.setter
105+
def resolved_value(self, new_resolved_value: Any):
106+
"""Setter method for resolved_value. Resolved_value must be of type ``BooleanValues``."""
107+
if isinstance(new_resolved_value, BooleanValues):
108+
self._resolved_value = new_resolved_value
109+
return
110+
raise RuntimeError(
111+
"Resolved value must be of type BooleanValues, "
112+
f"but got type {type(new_resolved_value)}."
113+
)
114+
65115
@staticmethod
66116
def validate_operand(operand: Any) -> Any:
67-
"""Validate operand and return ``Operand`` object."""
117+
"""Validate operand and return ``Operand`` object.
118+
119+
Args:
120+
operand (Any): The operand to validate.
121+
122+
Raises:
123+
RuntimeError: If the operand is not of ``Operand`` or ``str`` type.
124+
"""
68125
if isinstance(operand, str):
69126
if operand.lower() == BooleanValues.TRUE.lower():
70127
operand = Operand(operand, resolved_value=BooleanValues.TRUE)
@@ -75,12 +132,18 @@ def validate_operand(operand: Any) -> Any:
75132
else:
76133
operand = Operand(parse_filter_string(operand))
77134
elif not issubclass(type(operand), Operand):
78-
raise RuntimeError()
135+
raise RuntimeError(f"Operand '{operand}' is not supported.")
79136
return operand
80137

81138

82139
class Operator(Operand):
83-
"""Operator class for filtering JumpStart content."""
140+
"""Operator class for filtering JumpStart content.
141+
142+
An operator in this case corresponds to an operand that is also an operation.
143+
For example, given the expression ``(True or True) and True``,
144+
``(True or True)`` is an operand to an ``And`` expression, but is also itself an
145+
operator. ``(True or True) and True`` would also be considered an operator.
146+
"""
84147

85148
def __init__(
86149
self,
@@ -117,24 +180,34 @@ def __init__(
117180
118181
Args:
119182
operand (Operand): Operand for And-ing.
183+
184+
Raises:
185+
RuntimeError: If the operands cannot be validated.
120186
"""
121187
self.operands: List[Operand] = list(operands) # type: ignore
122188
for i in range(len(self.operands)):
123189
self.operands[i] = Operand.validate_operand(self.operands[i])
124190
super().__init__()
125191

126192
def eval(self) -> None:
127-
"""Evaluates operator."""
193+
"""Evaluates operator.
194+
195+
Raises:
196+
RuntimeError: If the operands remain unevaluated after calling ``eval``,
197+
or if the resolved value isn't a ``BooleanValues`` type.
198+
"""
128199
incomplete_expression = False
129200
for operand in self.operands:
130201
if not issubclass(type(operand), Operand):
131-
raise RuntimeError()
202+
raise RuntimeError(
203+
f"Operand must be subclass of ``Operand``, but got {type(operand)}"
204+
)
132205
if operand.resolved_value == BooleanValues.UNEVALUATED:
133206
operand.eval()
134-
if operand.resolved_value == BooleanValues.UNEVALUATED:
135-
raise RuntimeError()
136-
if not isinstance(operand.resolved_value, BooleanValues):
137-
raise RuntimeError()
207+
if operand.resolved_value == BooleanValues.UNEVALUATED:
208+
raise RuntimeError(
209+
"Operand remains unevaluated after calling ``eval`` function."
210+
)
138211
if operand.resolved_value == BooleanValues.FALSE:
139212
self.resolved_value = BooleanValues.FALSE
140213
return
@@ -162,7 +235,7 @@ def __init__(
162235
"""Instantiates Constant operator object.
163236
164237
Args:
165-
constant (BooleanValues]): Value of constant.
238+
constant (BooleanValues): Value of constant.
166239
"""
167240
super().__init__(constant)
168241

@@ -195,14 +268,21 @@ def __iter__(self) -> Any:
195268
yield self
196269
yield from self.operand
197270

198-
def eval(self) -> Any:
199-
"""Evaluates operator."""
271+
def eval(self) -> None:
272+
"""Evaluates operator.
273+
274+
Raises:
275+
RuntimeError: If the operands remain unevaluated after calling ``eval``,
276+
or if the resolved value isn't a ``BooleanValues`` type.
277+
"""
200278
if not issubclass(type(self.operand), Operand):
201-
raise RuntimeError()
279+
raise RuntimeError(
280+
f"Operand must be subclass of ``Operand``, but got {type(self.operand)}"
281+
)
202282
if self.operand.resolved_value == BooleanValues.UNEVALUATED:
203283
self.operand.eval()
204-
if self.operand.resolved_value == BooleanValues.UNEVALUATED:
205-
raise RuntimeError()
284+
if self.operand.resolved_value == BooleanValues.UNEVALUATED:
285+
raise RuntimeError("Operand remains unevaluated after calling ``eval`` function.")
206286
if not isinstance(self.operand.resolved_value, BooleanValues):
207287
raise RuntimeError(self.operand.resolved_value)
208288
self.resolved_value = self.operand.resolved_value
@@ -219,24 +299,34 @@ def __init__(
219299
220300
Args:
221301
operands (Operand): Operand for Or-ing.
302+
303+
Raises:
304+
RuntimeError: If the operands cannot be validated.
222305
"""
223306
self.operands: List[Operand] = list(operands) # type: ignore
224307
for i in range(len(self.operands)):
225308
self.operands[i] = Operand.validate_operand(self.operands[i])
226309
super().__init__()
227310

228311
def eval(self) -> None:
229-
"""Evaluates operator."""
312+
"""Evaluates operator.
313+
314+
Raises:
315+
RuntimeError: If the operands remain unevaluated after calling ``eval``,
316+
or if the resolved value isn't a ``BooleanValues`` type.
317+
"""
230318
incomplete_expression = False
231319
for operand in self.operands:
232320
if not issubclass(type(operand), Operand):
233-
raise RuntimeError()
321+
raise RuntimeError(
322+
f"Operand must be subclass of ``Operand``, but got {type(operand)}"
323+
)
234324
if operand.resolved_value == BooleanValues.UNEVALUATED:
235325
operand.eval()
236-
if operand.resolved_value == BooleanValues.UNEVALUATED:
237-
raise RuntimeError()
238-
if not isinstance(operand.resolved_value, BooleanValues):
239-
raise RuntimeError()
326+
if operand.resolved_value == BooleanValues.UNEVALUATED:
327+
raise RuntimeError(
328+
"Operand remains unevaluated after calling ``eval`` function."
329+
)
240330
if operand.resolved_value == BooleanValues.TRUE:
241331
self.resolved_value = BooleanValues.TRUE
242332
return
@@ -270,16 +360,21 @@ def __init__(
270360
super().__init__()
271361

272362
def eval(self) -> None:
273-
"""Evaluates operator."""
363+
"""Evaluates operator.
364+
365+
Raises:
366+
RuntimeError: If the operands remain unevaluated after calling ``eval``,
367+
or if the resolved value isn't a ``BooleanValues`` type.
368+
"""
274369

275370
if not issubclass(type(self.operand), Operand):
276-
raise RuntimeError()
371+
raise RuntimeError(
372+
f"Operand must be subclass of ``Operand``, but got {type(self.operand)}"
373+
)
277374
if self.operand.resolved_value == BooleanValues.UNEVALUATED:
278375
self.operand.eval()
279-
if self.operand.resolved_value == BooleanValues.UNEVALUATED:
280-
raise RuntimeError()
281-
if not isinstance(self.operand.resolved_value, BooleanValues):
282-
raise RuntimeError()
376+
if self.operand.resolved_value == BooleanValues.UNEVALUATED:
377+
raise RuntimeError("Operand remains unevaluated after calling ``eval`` function.")
283378
if self.operand.resolved_value == BooleanValues.TRUE:
284379
self.resolved_value = BooleanValues.FALSE
285380
return
@@ -324,37 +419,20 @@ def parse_filter_string(filter_string: str) -> ModelFilter:
324419
filter_string (str): The filter string to be serialized to an object.
325420
"""
326421

327-
pad_alphabetic_operator = (
328-
lambda operator: " " + operator + " "
329-
if any(character.isalpha() for character in operator)
330-
else operator
331-
)
332-
333-
acceptable_operators_in_parse_order = (
334-
list(
335-
map(
336-
pad_alphabetic_operator, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]
337-
)
338-
)
339-
+ list(
340-
map(pad_alphabetic_operator, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN])
341-
)
342-
+ list(
343-
map(pad_alphabetic_operator, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS])
344-
)
345-
+ list(map(pad_alphabetic_operator, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]))
346-
)
347-
for operator in acceptable_operators_in_parse_order:
422+
for operator in ACCEPTABLE_OPERATORS_IN_PARSE_ORDER:
348423
split_filter_string = filter_string.split(operator)
349424
if len(split_filter_string) == 2:
350425
return ModelFilter(
351-
split_filter_string[0].strip(), split_filter_string[1].strip(), operator.strip()
426+
key=split_filter_string[0].strip(),
427+
value=split_filter_string[1].strip(),
428+
operator=operator.strip(),
352429
)
353-
raise RuntimeError(f"Cannot parse filter string: {filter_string}")
430+
raise ValueError(f"Cannot parse filter string: {filter_string}")
354431

355432

356433
def evaluate_filter_expression( # pylint: disable=too-many-return-statements
357-
model_filter: ModelFilter, cached_model_value: Any
434+
model_filter: ModelFilter,
435+
cached_model_value: Union[str, bool, int, float, Dict[str, Any], List[Any]],
358436
) -> BooleanValues:
359437
"""Evaluates model filter with cached model spec value, returns boolean.
360438
@@ -379,11 +457,21 @@ def evaluate_filter_expression( # pylint: disable=too-many-return-statements
379457
return BooleanValues.FALSE
380458
return BooleanValues.TRUE
381459
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]:
382-
if cached_model_value in literal_eval(model_filter.value):
460+
py_obj = literal_eval(model_filter.value)
461+
try:
462+
iter(py_obj)
463+
except TypeError:
464+
return BooleanValues.FALSE
465+
if cached_model_value in py_obj:
383466
return BooleanValues.TRUE
384467
return BooleanValues.FALSE
385468
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN]:
386-
if cached_model_value in literal_eval(model_filter.value):
469+
py_obj = literal_eval(model_filter.value)
470+
try:
471+
iter(py_obj)
472+
except TypeError:
473+
return BooleanValues.TRUE
474+
if cached_model_value in py_obj:
387475
return BooleanValues.FALSE
388476
return BooleanValues.TRUE
389477
raise RuntimeError(f"Bad operator: {model_filter.operator}")

0 commit comments

Comments
 (0)