14
14
from __future__ import absolute_import
15
15
from ast import literal_eval
16
16
from enum import Enum
17
- from typing import Dict , List , Union , Any
17
+ from typing import Dict , List , Optional , Union , Any
18
18
19
19
from sagemaker .jumpstart .types import JumpStartDataHolderType
20
20
@@ -38,6 +38,10 @@ class FilterOperators(str, Enum):
38
38
NOT_EQUALS = "not_equals"
39
39
IN = "in"
40
40
NOT_IN = "not_in"
41
+ INCLUDES = "includes"
42
+ NOT_INCLUDES = "not_includes"
43
+ BEGINS_WITH = "begins_with"
44
+ ENDS_WITH = "ends_with"
41
45
42
46
43
47
class SpecialSupportedFilterKeys (str , Enum ):
@@ -52,6 +56,10 @@ class SpecialSupportedFilterKeys(str, Enum):
52
56
FilterOperators .NOT_EQUALS : ["!==" , "!=" , "not equals" , "is not" ],
53
57
FilterOperators .IN : ["in" ],
54
58
FilterOperators .NOT_IN : ["not in" ],
59
+ FilterOperators .INCLUDES : ["includes" , "contains" ],
60
+ FilterOperators .NOT_INCLUDES : ["not includes" , "not contains" ],
61
+ FilterOperators .BEGINS_WITH : ["begins with" , "starts with" ],
62
+ FilterOperators .ENDS_WITH : ["ends with" ],
55
63
}
56
64
57
65
@@ -62,7 +70,19 @@ class SpecialSupportedFilterKeys(str, Enum):
62
70
)
63
71
64
72
ACCEPTABLE_OPERATORS_IN_PARSE_ORDER = (
65
- list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_EQUALS ]))
73
+ list (
74
+ map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .BEGINS_WITH ])
75
+ )
76
+ + list (
77
+ map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .ENDS_WITH ])
78
+ )
79
+ + list (
80
+ map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_INCLUDES ])
81
+ )
82
+ + list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .INCLUDES ]))
83
+ + list (
84
+ map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_EQUALS ])
85
+ )
66
86
+ list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_IN ]))
67
87
+ list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .EQUALS ]))
68
88
+ list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .IN ]))
@@ -428,9 +448,90 @@ def parse_filter_string(filter_string: str) -> ModelFilter:
428
448
raise ValueError (f"Cannot parse filter string: { filter_string } " )
429
449
430
450
451
+ def _negate_boolean (boolean : BooleanValues ) -> BooleanValues :
452
+ if boolean == BooleanValues .TRUE :
453
+ return BooleanValues .FALSE
454
+ if boolean == BooleanValues .FALSE :
455
+ return BooleanValues .TRUE
456
+ return boolean
457
+
458
+
459
+ def _evaluate_filter_expression_equals (
460
+ model_filter : ModelFilter ,
461
+ cached_model_value : Optional [Union [str , bool , int , float , Dict [str , Any ], List [Any ]]],
462
+ ) -> BooleanValues :
463
+ if cached_model_value is None :
464
+ return BooleanValues .FALSE
465
+ model_filter_value = model_filter .value
466
+ if isinstance (cached_model_value , bool ):
467
+ cached_model_value = str (cached_model_value ).lower ()
468
+ model_filter_value = model_filter .value .lower ()
469
+ if str (model_filter_value ) == str (cached_model_value ):
470
+ return BooleanValues .TRUE
471
+ return BooleanValues .FALSE
472
+
473
+
474
+ def _evaluate_filter_expression_in (
475
+ model_filter : ModelFilter ,
476
+ cached_model_value : Optional [Union [str , bool , int , float , Dict [str , Any ], List [Any ]]],
477
+ ) -> BooleanValues :
478
+ if cached_model_value is None :
479
+ return BooleanValues .FALSE
480
+ py_obj = model_filter .value
481
+ try :
482
+ py_obj = literal_eval (py_obj )
483
+ try :
484
+ iter (py_obj )
485
+ except TypeError :
486
+ return BooleanValues .FALSE
487
+ except Exception :
488
+ pass
489
+ if isinstance (cached_model_value , list ):
490
+ return BooleanValues .FALSE
491
+ if cached_model_value in py_obj :
492
+ return BooleanValues .TRUE
493
+ return BooleanValues .FALSE
494
+
495
+
496
+ def _evaluate_filter_expression_includes (
497
+ model_filter : ModelFilter ,
498
+ cached_model_value : Optional [Union [str , bool , int , float , Dict [str , Any ], List [Any ]]],
499
+ ) -> BooleanValues :
500
+ if cached_model_value is None :
501
+ return BooleanValues .FALSE
502
+ filter_value = str (model_filter .value )
503
+ if filter_value in cached_model_value :
504
+ return BooleanValues .TRUE
505
+ return BooleanValues .FALSE
506
+
507
+
508
+ def _evaluate_filter_expression_begins_with (
509
+ model_filter : ModelFilter ,
510
+ cached_model_value : Optional [Union [str , bool , int , float , Dict [str , Any ], List [Any ]]],
511
+ ) -> BooleanValues :
512
+ if cached_model_value is None :
513
+ return BooleanValues .FALSE
514
+ filter_value = str (model_filter .value )
515
+ if cached_model_value .startswith (filter_value ):
516
+ return BooleanValues .TRUE
517
+ return BooleanValues .FALSE
518
+
519
+
520
+ def _evaluate_filter_expression_ends_with (
521
+ model_filter : ModelFilter ,
522
+ cached_model_value : Optional [Union [str , bool , int , float , Dict [str , Any ], List [Any ]]],
523
+ ) -> BooleanValues :
524
+ if cached_model_value is None :
525
+ return BooleanValues .FALSE
526
+ filter_value = str (model_filter .value )
527
+ if cached_model_value .endswith (filter_value ):
528
+ return BooleanValues .TRUE
529
+ return BooleanValues .FALSE
530
+
531
+
431
532
def evaluate_filter_expression ( # pylint: disable=too-many-return-statements
432
533
model_filter : ModelFilter ,
433
- cached_model_value : Union [str , bool , int , float , Dict [str , Any ], List [Any ]],
534
+ cached_model_value : Optional [ Union [str , bool , int , float , Dict [str , Any ], List [Any ] ]],
434
535
) -> BooleanValues :
435
536
"""Evaluates model filter with cached model spec value, returns boolean.
436
537
@@ -440,36 +541,29 @@ def evaluate_filter_expression( # pylint: disable=too-many-return-statements
440
541
evaluate the filter.
441
542
"""
442
543
if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .EQUALS ]:
443
- model_filter_value = model_filter .value
444
- if isinstance (cached_model_value , bool ):
445
- cached_model_value = str (cached_model_value ).lower ()
446
- model_filter_value = model_filter .value .lower ()
447
- if str (model_filter_value ) == str (cached_model_value ):
448
- return BooleanValues .TRUE
449
- return BooleanValues .FALSE
544
+ return _evaluate_filter_expression_equals (model_filter , cached_model_value )
545
+
450
546
if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_EQUALS ]:
451
- if isinstance (cached_model_value , bool ):
452
- cached_model_value = str (cached_model_value ).lower ()
453
- model_filter .value = model_filter .value .lower ()
454
- if str (model_filter .value ) == str (cached_model_value ):
455
- return BooleanValues .FALSE
456
- return BooleanValues .TRUE
547
+ return _negate_boolean (_evaluate_filter_expression_equals (model_filter , cached_model_value ))
548
+
457
549
if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .IN ]:
458
- py_obj = literal_eval (model_filter .value )
459
- try :
460
- iter (py_obj )
461
- except TypeError :
462
- return BooleanValues .FALSE
463
- if cached_model_value in py_obj :
464
- return BooleanValues .TRUE
465
- return BooleanValues .FALSE
550
+ return _evaluate_filter_expression_in (model_filter , cached_model_value )
551
+
466
552
if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_IN ]:
467
- py_obj = literal_eval (model_filter .value )
468
- try :
469
- iter (py_obj )
470
- except TypeError :
471
- return BooleanValues .TRUE
472
- if cached_model_value in py_obj :
473
- return BooleanValues .FALSE
474
- return BooleanValues .TRUE
553
+ return _negate_boolean (_evaluate_filter_expression_in (model_filter , cached_model_value ))
554
+
555
+ if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .INCLUDES ]:
556
+ return _evaluate_filter_expression_includes (model_filter , cached_model_value )
557
+
558
+ if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_INCLUDES ]:
559
+ return _negate_boolean (
560
+ _evaluate_filter_expression_includes (model_filter , cached_model_value )
561
+ )
562
+
563
+ if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .BEGINS_WITH ]:
564
+ return _evaluate_filter_expression_begins_with (model_filter , cached_model_value )
565
+
566
+ if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .ENDS_WITH ]:
567
+ return _evaluate_filter_expression_ends_with (model_filter , cached_model_value )
568
+
475
569
raise RuntimeError (f"Bad operator: { model_filter .operator } " )
0 commit comments