Skip to content

Commit 0f50da5

Browse files
committed
feat: add support for includes, begins with, ends with
1 parent d573f27 commit 0f50da5

File tree

2 files changed

+172
-32
lines changed

2 files changed

+172
-32
lines changed

src/sagemaker/jumpstart/filters.py

Lines changed: 126 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515
from ast import literal_eval
1616
from enum import Enum
17-
from typing import Dict, List, Union, Any
17+
from typing import Dict, List, Optional, Union, Any
1818

1919
from sagemaker.jumpstart.types import JumpStartDataHolderType
2020

@@ -38,6 +38,10 @@ class FilterOperators(str, Enum):
3838
NOT_EQUALS = "not_equals"
3939
IN = "in"
4040
NOT_IN = "not_in"
41+
INCLUDES = "includes"
42+
NOT_INCLUDES = "not_includes"
43+
BEGINS_WITH = "begins_with"
44+
ENDS_WITH = "ends_with"
4145

4246

4347
class SpecialSupportedFilterKeys(str, Enum):
@@ -52,6 +56,10 @@ class SpecialSupportedFilterKeys(str, Enum):
5256
FilterOperators.NOT_EQUALS: ["!==", "!=", "not equals", "is not"],
5357
FilterOperators.IN: ["in"],
5458
FilterOperators.NOT_IN: ["not in"],
59+
FilterOperators.INCLUDES: ["includes", "contains"],
60+
FilterOperators.NOT_INCLUDES: ["not includes", "not contains"],
61+
FilterOperators.BEGINS_WITH: ["begins with", "starts with"],
62+
FilterOperators.ENDS_WITH: ["ends with"],
5563
}
5664

5765

@@ -62,7 +70,19 @@ class SpecialSupportedFilterKeys(str, Enum):
6270
)
6371

6472
ACCEPTABLE_OPERATORS_IN_PARSE_ORDER = (
65-
list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]))
73+
list(
74+
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.BEGINS_WITH])
75+
)
76+
+ list(
77+
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.ENDS_WITH])
78+
)
79+
+ list(
80+
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_INCLUDES])
81+
)
82+
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.INCLUDES]))
83+
+ list(
84+
map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS])
85+
)
6686
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN]))
6787
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS]))
6888
+ list(map(_PAD_ALPHABETIC_OPERATOR, FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]))
@@ -428,9 +448,90 @@ def parse_filter_string(filter_string: str) -> ModelFilter:
428448
raise ValueError(f"Cannot parse filter string: {filter_string}")
429449

430450

451+
def _negate_boolean(boolean: BooleanValues) -> BooleanValues:
    """Return the logical opposite of a ``BooleanValues`` member.

    ``TRUE`` and ``FALSE`` are swapped; any other value is passed through
    unchanged.
    """
    if boolean == BooleanValues.TRUE:
        flipped = BooleanValues.FALSE
    elif boolean == BooleanValues.FALSE:
        flipped = BooleanValues.TRUE
    else:
        # Neither TRUE nor FALSE: negation leaves the value as it is.
        flipped = boolean
    return flipped
457+
458+
459+
def _evaluate_filter_expression_equals(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluate an equality filter against a cached model spec value.

    Args:
        model_filter: Filter whose ``value`` is compared with the cached value.
        cached_model_value: Value read from the cached model spec, or ``None``.

    Returns:
        ``BooleanValues.TRUE`` when the string forms of both values are equal,
        ``BooleanValues.FALSE`` otherwise (including when the cached value is
        ``None``).
    """
    if cached_model_value is None:
        return BooleanValues.FALSE

    expected = model_filter.value
    actual = cached_model_value
    if isinstance(actual, bool):
        # Booleans are compared case-insensitively against their lowercase
        # string form ("true" / "false").
        actual = str(actual).lower()
        expected = model_filter.value.lower()

    matches = str(expected) == str(actual)
    return BooleanValues.TRUE if matches else BooleanValues.FALSE
472+
473+
474+
def _evaluate_filter_expression_in(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluate an ``in`` filter against a cached model spec value.

    The filter value is first parsed as a Python literal (for example
    ``'["mom", "dad"]'``). If it is not a valid literal, the raw filter value
    is used as the container, so a plain string filter value results in a
    substring check.

    Args:
        model_filter: Filter whose ``value`` describes the container to search.
        cached_model_value: Value read from the cached model spec, or ``None``.

    Returns:
        ``BooleanValues.TRUE`` when the cached value is contained in the filter
        value. ``BooleanValues.FALSE`` otherwise, including when the cached
        value is ``None`` or a list, when the parsed filter value is not
        iterable, or when the membership test is not defined for the two
        operand types.
    """
    if cached_model_value is None or isinstance(cached_model_value, list):
        return BooleanValues.FALSE

    container = model_filter.value
    try:
        container = literal_eval(container)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Not a Python literal: deliberately fall back to the raw filter value.
        pass
    else:
        try:
            iter(container)
        except TypeError:
            # A literal such as ``5`` cannot contain anything.
            return BooleanValues.FALSE

    try:
        found = cached_model_value in container
    except TypeError:
        # E.g. ``5 in "daddy"`` or an unhashable value tested against a set:
        # the value cannot be a member, so report no match instead of raising.
        return BooleanValues.FALSE
    return BooleanValues.TRUE if found else BooleanValues.FALSE
494+
495+
496+
def _evaluate_filter_expression_includes(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluate an ``includes`` filter against a cached model spec value.

    Args:
        model_filter: Filter whose ``value`` (converted to ``str``) is searched
            for.
        cached_model_value: Value read from the cached model spec, or ``None``.

    Returns:
        ``BooleanValues.TRUE`` when the filter value is contained in the cached
        value (substring for strings, element for lists, key for dicts).
        ``BooleanValues.FALSE`` otherwise, including when the cached value is
        ``None`` or a scalar that does not support membership tests.
    """
    if cached_model_value is None:
        return BooleanValues.FALSE

    needle = str(model_filter.value)
    try:
        found = needle in cached_model_value
    except TypeError:
        # bool / int / float cached values are not containers; they cannot
        # include anything, so report no match instead of raising.
        return BooleanValues.FALSE
    return BooleanValues.TRUE if found else BooleanValues.FALSE
506+
507+
508+
def _evaluate_filter_expression_begins_with(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluate a ``begins with`` filter against a cached model spec value.

    Args:
        model_filter: Filter whose ``value`` (converted to ``str``) is the
            expected prefix.
        cached_model_value: Value read from the cached model spec, or ``None``.

    Returns:
        ``BooleanValues.TRUE`` when the cached value is a string starting with
        the filter value. ``BooleanValues.FALSE`` otherwise, including when the
        cached value is ``None`` or is not a string.
    """
    # Only strings have a prefix; the signature also admits bool / int / float /
    # dict / list, for which ``.startswith`` would raise AttributeError.
    if not isinstance(cached_model_value, str):
        return BooleanValues.FALSE

    prefix = str(model_filter.value)
    if cached_model_value.startswith(prefix):
        return BooleanValues.TRUE
    return BooleanValues.FALSE
518+
519+
520+
def _evaluate_filter_expression_ends_with(
    model_filter: ModelFilter,
    cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
) -> BooleanValues:
    """Evaluate an ``ends with`` filter against a cached model spec value.

    Args:
        model_filter: Filter whose ``value`` (converted to ``str``) is the
            expected suffix.
        cached_model_value: Value read from the cached model spec, or ``None``.

    Returns:
        ``BooleanValues.TRUE`` when the cached value is a string ending with
        the filter value. ``BooleanValues.FALSE`` otherwise, including when the
        cached value is ``None`` or is not a string.
    """
    # Only strings have a suffix; the signature also admits bool / int / float /
    # dict / list, for which ``.endswith`` would raise AttributeError.
    if not isinstance(cached_model_value, str):
        return BooleanValues.FALSE

    suffix = str(model_filter.value)
    if cached_model_value.endswith(suffix):
        return BooleanValues.TRUE
    return BooleanValues.FALSE
530+
531+
431532
def evaluate_filter_expression( # pylint: disable=too-many-return-statements
432533
model_filter: ModelFilter,
433-
cached_model_value: Union[str, bool, int, float, Dict[str, Any], List[Any]],
534+
cached_model_value: Optional[Union[str, bool, int, float, Dict[str, Any], List[Any]]],
434535
) -> BooleanValues:
435536
"""Evaluates model filter with cached model spec value, returns boolean.
436537
@@ -440,36 +541,29 @@ def evaluate_filter_expression( # pylint: disable=too-many-return-statements
440541
evaluate the filter.
441542
"""
442543
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.EQUALS]:
443-
model_filter_value = model_filter.value
444-
if isinstance(cached_model_value, bool):
445-
cached_model_value = str(cached_model_value).lower()
446-
model_filter_value = model_filter.value.lower()
447-
if str(model_filter_value) == str(cached_model_value):
448-
return BooleanValues.TRUE
449-
return BooleanValues.FALSE
544+
return _evaluate_filter_expression_equals(model_filter, cached_model_value)
545+
450546
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_EQUALS]:
451-
if isinstance(cached_model_value, bool):
452-
cached_model_value = str(cached_model_value).lower()
453-
model_filter.value = model_filter.value.lower()
454-
if str(model_filter.value) == str(cached_model_value):
455-
return BooleanValues.FALSE
456-
return BooleanValues.TRUE
547+
return _negate_boolean(_evaluate_filter_expression_equals(model_filter, cached_model_value))
548+
457549
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.IN]:
458-
py_obj = literal_eval(model_filter.value)
459-
try:
460-
iter(py_obj)
461-
except TypeError:
462-
return BooleanValues.FALSE
463-
if cached_model_value in py_obj:
464-
return BooleanValues.TRUE
465-
return BooleanValues.FALSE
550+
return _evaluate_filter_expression_in(model_filter, cached_model_value)
551+
466552
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_IN]:
467-
py_obj = literal_eval(model_filter.value)
468-
try:
469-
iter(py_obj)
470-
except TypeError:
471-
return BooleanValues.TRUE
472-
if cached_model_value in py_obj:
473-
return BooleanValues.FALSE
474-
return BooleanValues.TRUE
553+
return _negate_boolean(_evaluate_filter_expression_in(model_filter, cached_model_value))
554+
555+
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.INCLUDES]:
556+
return _evaluate_filter_expression_includes(model_filter, cached_model_value)
557+
558+
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.NOT_INCLUDES]:
559+
return _negate_boolean(
560+
_evaluate_filter_expression_includes(model_filter, cached_model_value)
561+
)
562+
563+
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.BEGINS_WITH]:
564+
return _evaluate_filter_expression_begins_with(model_filter, cached_model_value)
565+
566+
if model_filter.operator in FILTER_OPERATOR_STRING_MAPPINGS[FilterOperators.ENDS_WITH]:
567+
return _evaluate_filter_expression_ends_with(model_filter, cached_model_value)
568+
475569
raise RuntimeError(f"Bad operator: {model_filter.operator}")

tests/unit/sagemaker/jumpstart/test_filters.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ def test_not_equals(self):
143143

144144
def test_in(self):
145145

146+
assert BooleanValues.TRUE == evaluate_filter_expression(
147+
ModelFilter(key="hello", operator="in", value="daddy"), "dad"
148+
)
149+
146150
assert BooleanValues.TRUE == evaluate_filter_expression(
147151
ModelFilter(key="hello", operator="in", value='["mom", "dad"]'), "dad"
148152
)
@@ -169,6 +173,10 @@ def test_in(self):
169173

170174
def test_not_in(self):
171175

176+
assert BooleanValues.FALSE == evaluate_filter_expression(
177+
ModelFilter(key="hello", operator="not in", value="daddy"), "dad"
178+
)
179+
172180
assert BooleanValues.FALSE == evaluate_filter_expression(
173181
ModelFilter(key="hello", operator="not in", value='["mom", "dad"]'), "dad"
174182
)
@@ -193,6 +201,44 @@ def test_not_in(self):
193201
ModelFilter(key="hello", operator="not in", value='["mom", "fsdfdsfsd"]'), False
194202
)
195203

204+
def test_includes(self):
    """``includes`` matches a substring of a string and an element of a list."""
    for cached_value in ("daddy", ["dad"]):
        model_filter = ModelFilter(key="hello", operator="includes", value="dad")
        outcome = evaluate_filter_expression(model_filter, cached_value)
        assert outcome == BooleanValues.TRUE
213+
214+
def test_not_includes(self):
    """``not includes`` is FALSE when the string or list does contain the value."""
    for cached_value in ("daddy", ["dad"]):
        model_filter = ModelFilter(key="hello", operator="not includes", value="dad")
        outcome = evaluate_filter_expression(model_filter, cached_value)
        assert outcome == BooleanValues.FALSE
223+
224+
def test_begins_with(self):
    """``begins with`` is TRUE only when the cached string starts with the value."""
    matching_filter = ModelFilter(key="hello", operator="begins with", value="dad")
    matching_outcome = evaluate_filter_expression(matching_filter, "daddy")
    assert matching_outcome == BooleanValues.TRUE

    # "mm" occurs inside "mommy" but not at its start.
    non_matching_filter = ModelFilter(key="hello", operator="begins with", value="mm")
    non_matching_outcome = evaluate_filter_expression(non_matching_filter, "mommy")
    assert non_matching_outcome == BooleanValues.FALSE
232+
233+
def test_ends_with(self):
    """``ends with`` is TRUE only when the cached string ends with the value."""
    assert BooleanValues.TRUE == evaluate_filter_expression(
        ModelFilter(key="hello", operator="ends with", value="car"), "racecar"
    )

    # "ace" occurs inside "racecar" but not at its end. This case previously
    # used the "begins with" operator, so the negative "ends with" path was
    # never exercised.
    assert BooleanValues.FALSE == evaluate_filter_expression(
        ModelFilter(key="hello", operator="ends with", value="ace"), "racecar"
    )
241+
196242

197243
def test_parse_filter_string():
198244

0 commit comments

Comments
 (0)