14
14
from __future__ import absolute_import
15
15
from ast import literal_eval
16
16
from enum import Enum
17
- from typing import List , Union , Any
17
+ from typing import Dict , List , Union , Any
18
18
19
19
from sagemaker .jumpstart .types import JumpStartDataHolderType
20
20
21
21
22
22
class BooleanValues (str , Enum ):
23
- """Enum class for boolean values."""
23
+ """Enum class for boolean values.
24
+
25
+ This is a status value that an ``Operand`` can resolve to.
26
+ """
24
27
25
28
TRUE = "true"
26
29
FALSE = "false"
@@ -37,6 +40,14 @@ class FilterOperators(str, Enum):
37
40
NOT_IN = "not_in"
38
41
39
42
43
+ class SpecialSupportedFilterKeys (str , Enum ):
44
+ """Enum class for special supported filter keys."""
45
+
46
+ TASK = "task"
47
+ FRAMEWORK = "framework"
48
+ SUPPORTED_MODEL = "supported_model"
49
+
50
+
40
51
FILTER_OPERATOR_STRING_MAPPINGS = {
41
52
FilterOperators .EQUALS : ["===" , "==" , "equals" , "is" ],
42
53
FilterOperators .NOT_EQUALS : ["!==" , "!=" , "not equals" , "is not" ],
@@ -45,14 +56,37 @@ class FilterOperators(str, Enum):
45
56
}
46
57
47
58
59
+ _PAD_ALPHABETIC_OPERATOR = (
60
+ lambda operator : f" { operator } "
61
+ if any (character .isalpha () for character in operator )
62
+ else operator
63
+ )
64
+
65
+ ACCEPTABLE_OPERATORS_IN_PARSE_ORDER = (
66
+ list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_EQUALS ]))
67
+ + list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_IN ]))
68
+ + list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .EQUALS ]))
69
+ + list (map (_PAD_ALPHABETIC_OPERATOR , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .IN ]))
70
+ )
71
+
72
+
73
+ SPECIAL_SUPPORTED_FILTER_KEYS = set (
74
+ [
75
+ SpecialSupportedFilterKeys .TASK ,
76
+ SpecialSupportedFilterKeys .FRAMEWORK ,
77
+ SpecialSupportedFilterKeys .SUPPORTED_MODEL ,
78
+ ]
79
+ )
80
+
81
+
48
82
class Operand :
49
83
"""Operand class for filtering JumpStart content."""
50
84
51
85
def __init__ (
52
86
self , unresolved_value : Any , resolved_value : BooleanValues = BooleanValues .UNEVALUATED
53
87
):
54
- self .unresolved_value = unresolved_value
55
- self .resolved_value = resolved_value
88
+ self .unresolved_value : Any = unresolved_value
89
+ self ._resolved_value : BooleanValues = resolved_value
56
90
57
91
def __iter__ (self ) -> Any :
58
92
"""Returns an iterator."""
@@ -62,9 +96,32 @@ def eval(self) -> None:
62
96
"""Evaluates operand."""
63
97
return
64
98
99
+ @property
100
+ def resolved_value (self ):
101
+ """Getter method for resolved_value."""
102
+ return self ._resolved_value
103
+
104
+ @resolved_value .setter
105
+ def resolved_value (self , new_resolved_value : Any ):
106
+ """Setter method for resolved_value. Resolved_value must be of type ``BooleanValues``."""
107
+ if isinstance (new_resolved_value , BooleanValues ):
108
+ self ._resolved_value = new_resolved_value
109
+ return
110
+ raise RuntimeError (
111
+ "Resolved value must be of type BooleanValues, "
112
+ f"but got type { type (new_resolved_value )} ."
113
+ )
114
+
65
115
@staticmethod
66
116
def validate_operand (operand : Any ) -> Any :
67
- """Validate operand and return ``Operand`` object."""
117
+ """Validate operand and return ``Operand`` object.
118
+
119
+ Args:
120
+ operand (Any): The operand to validate.
121
+
122
+ Raises:
123
+ RuntimeError: If the operand is not of ``Operand`` or ``str`` type.
124
+ """
68
125
if isinstance (operand , str ):
69
126
if operand .lower () == BooleanValues .TRUE .lower ():
70
127
operand = Operand (operand , resolved_value = BooleanValues .TRUE )
@@ -75,12 +132,18 @@ def validate_operand(operand: Any) -> Any:
75
132
else :
76
133
operand = Operand (parse_filter_string (operand ))
77
134
elif not issubclass (type (operand ), Operand ):
78
- raise RuntimeError ()
135
+ raise RuntimeError (f"Operand ' { operand } ' is not supported." )
79
136
return operand
80
137
81
138
82
139
class Operator (Operand ):
83
- """Operator class for filtering JumpStart content."""
140
+ """Operator class for filtering JumpStart content.
141
+
142
+ An operator in this case corresponds to an operand that is also an operation.
143
+ For example, given the expression ``(True or True) and True``,
144
+ ``(True or True)`` is an operand to an ``And`` expression, but is also itself an
145
+ operator. ``(True or True) and True`` would also be considered an operator.
146
+ """
84
147
85
148
def __init__ (
86
149
self ,
@@ -117,24 +180,34 @@ def __init__(
117
180
118
181
Args:
119
182
operand (Operand): Operand for And-ing.
183
+
184
+ Raises:
185
+ RuntimeError: If the operands cannot be validated.
120
186
"""
121
187
self .operands : List [Operand ] = list (operands ) # type: ignore
122
188
for i in range (len (self .operands )):
123
189
self .operands [i ] = Operand .validate_operand (self .operands [i ])
124
190
super ().__init__ ()
125
191
126
192
def eval (self ) -> None :
127
- """Evaluates operator."""
193
+ """Evaluates operator.
194
+
195
+ Raises:
196
+ RuntimeError: If the operands remain unevaluated after calling ``eval``,
197
+ or if the resolved value isn't a ``BooleanValues`` type.
198
+ """
128
199
incomplete_expression = False
129
200
for operand in self .operands :
130
201
if not issubclass (type (operand ), Operand ):
131
- raise RuntimeError ()
202
+ raise RuntimeError (
203
+ f"Operand must be subclass of ``Operand``, but got { type (operand )} "
204
+ )
132
205
if operand .resolved_value == BooleanValues .UNEVALUATED :
133
206
operand .eval ()
134
- if operand .resolved_value == BooleanValues .UNEVALUATED :
135
- raise RuntimeError ()
136
- if not isinstance ( operand . resolved_value , BooleanValues ):
137
- raise RuntimeError ( )
207
+ if operand .resolved_value == BooleanValues .UNEVALUATED :
208
+ raise RuntimeError (
209
+ "Operand remains unevaluated after calling ``eval`` function."
210
+ )
138
211
if operand .resolved_value == BooleanValues .FALSE :
139
212
self .resolved_value = BooleanValues .FALSE
140
213
return
@@ -162,7 +235,7 @@ def __init__(
162
235
"""Instantiates Constant operator object.
163
236
164
237
Args:
165
- constant (BooleanValues] ): Value of constant.
238
+ constant (BooleanValues): Value of constant.
166
239
"""
167
240
super ().__init__ (constant )
168
241
@@ -195,14 +268,21 @@ def __iter__(self) -> Any:
195
268
yield self
196
269
yield from self .operand
197
270
198
- def eval (self ) -> Any :
199
- """Evaluates operator."""
271
+ def eval (self ) -> None :
272
+ """Evaluates operator.
273
+
274
+ Raises:
275
+ RuntimeError: If the operands remain unevaluated after calling ``eval``,
276
+ or if the resolved value isn't a ``BooleanValues`` type.
277
+ """
200
278
if not issubclass (type (self .operand ), Operand ):
201
- raise RuntimeError ()
279
+ raise RuntimeError (
280
+ f"Operand must be subclass of ``Operand``, but got { type (self .operand )} "
281
+ )
202
282
if self .operand .resolved_value == BooleanValues .UNEVALUATED :
203
283
self .operand .eval ()
204
- if self .operand .resolved_value == BooleanValues .UNEVALUATED :
205
- raise RuntimeError ()
284
+ if self .operand .resolved_value == BooleanValues .UNEVALUATED :
285
+ raise RuntimeError ("Operand remains unevaluated after calling ``eval`` function." )
206
286
if not isinstance (self .operand .resolved_value , BooleanValues ):
207
287
raise RuntimeError (self .operand .resolved_value )
208
288
self .resolved_value = self .operand .resolved_value
@@ -219,24 +299,34 @@ def __init__(
219
299
220
300
Args:
221
301
operands (Operand): Operand for Or-ing.
302
+
303
+ Raises:
304
+ RuntimeError: If the operands cannot be validated.
222
305
"""
223
306
self .operands : List [Operand ] = list (operands ) # type: ignore
224
307
for i in range (len (self .operands )):
225
308
self .operands [i ] = Operand .validate_operand (self .operands [i ])
226
309
super ().__init__ ()
227
310
228
311
def eval (self ) -> None :
229
- """Evaluates operator."""
312
+ """Evaluates operator.
313
+
314
+ Raises:
315
+ RuntimeError: If the operands remain unevaluated after calling ``eval``,
316
+ or if the resolved value isn't a ``BooleanValues`` type.
317
+ """
230
318
incomplete_expression = False
231
319
for operand in self .operands :
232
320
if not issubclass (type (operand ), Operand ):
233
- raise RuntimeError ()
321
+ raise RuntimeError (
322
+ f"Operand must be subclass of ``Operand``, but got { type (operand )} "
323
+ )
234
324
if operand .resolved_value == BooleanValues .UNEVALUATED :
235
325
operand .eval ()
236
- if operand .resolved_value == BooleanValues .UNEVALUATED :
237
- raise RuntimeError ()
238
- if not isinstance ( operand . resolved_value , BooleanValues ):
239
- raise RuntimeError ( )
326
+ if operand .resolved_value == BooleanValues .UNEVALUATED :
327
+ raise RuntimeError (
328
+ "Operand remains unevaluated after calling ``eval`` function."
329
+ )
240
330
if operand .resolved_value == BooleanValues .TRUE :
241
331
self .resolved_value = BooleanValues .TRUE
242
332
return
@@ -270,16 +360,21 @@ def __init__(
270
360
super ().__init__ ()
271
361
272
362
def eval (self ) -> None :
273
- """Evaluates operator."""
363
+ """Evaluates operator.
364
+
365
+ Raises:
366
+ RuntimeError: If the operands remain unevaluated after calling ``eval``,
367
+ or if the resolved value isn't a ``BooleanValues`` type.
368
+ """
274
369
275
370
if not issubclass (type (self .operand ), Operand ):
276
- raise RuntimeError ()
371
+ raise RuntimeError (
372
+ f"Operand must be subclass of ``Operand``, but got { type (self .operand )} "
373
+ )
277
374
if self .operand .resolved_value == BooleanValues .UNEVALUATED :
278
375
self .operand .eval ()
279
- if self .operand .resolved_value == BooleanValues .UNEVALUATED :
280
- raise RuntimeError ()
281
- if not isinstance (self .operand .resolved_value , BooleanValues ):
282
- raise RuntimeError ()
376
+ if self .operand .resolved_value == BooleanValues .UNEVALUATED :
377
+ raise RuntimeError ("Operand remains unevaluated after calling ``eval`` function." )
283
378
if self .operand .resolved_value == BooleanValues .TRUE :
284
379
self .resolved_value = BooleanValues .FALSE
285
380
return
@@ -324,37 +419,20 @@ def parse_filter_string(filter_string: str) -> ModelFilter:
324
419
filter_string (str): The filter string to be serialized to an object.
325
420
"""
326
421
327
- pad_alphabetic_operator = (
328
- lambda operator : " " + operator + " "
329
- if any (character .isalpha () for character in operator )
330
- else operator
331
- )
332
-
333
- acceptable_operators_in_parse_order = (
334
- list (
335
- map (
336
- pad_alphabetic_operator , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_EQUALS ]
337
- )
338
- )
339
- + list (
340
- map (pad_alphabetic_operator , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_IN ])
341
- )
342
- + list (
343
- map (pad_alphabetic_operator , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .EQUALS ])
344
- )
345
- + list (map (pad_alphabetic_operator , FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .IN ]))
346
- )
347
- for operator in acceptable_operators_in_parse_order :
422
+ for operator in ACCEPTABLE_OPERATORS_IN_PARSE_ORDER :
348
423
split_filter_string = filter_string .split (operator )
349
424
if len (split_filter_string ) == 2 :
350
425
return ModelFilter (
351
- split_filter_string [0 ].strip (), split_filter_string [1 ].strip (), operator .strip ()
426
+ key = split_filter_string [0 ].strip (),
427
+ value = split_filter_string [1 ].strip (),
428
+ operator = operator .strip (),
352
429
)
353
- raise RuntimeError (f"Cannot parse filter string: { filter_string } " )
430
+ raise ValueError (f"Cannot parse filter string: { filter_string } " )
354
431
355
432
356
433
def evaluate_filter_expression ( # pylint: disable=too-many-return-statements
357
- model_filter : ModelFilter , cached_model_value : Any
434
+ model_filter : ModelFilter ,
435
+ cached_model_value : Union [str , bool , int , float , Dict [str , Any ], List [Any ]],
358
436
) -> BooleanValues :
359
437
"""Evaluates model filter with cached model spec value, returns boolean.
360
438
@@ -379,11 +457,21 @@ def evaluate_filter_expression( # pylint: disable=too-many-return-statements
379
457
return BooleanValues .FALSE
380
458
return BooleanValues .TRUE
381
459
if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .IN ]:
382
- if cached_model_value in literal_eval (model_filter .value ):
460
+ py_obj = literal_eval (model_filter .value )
461
+ try :
462
+ iter (py_obj )
463
+ except TypeError :
464
+ return BooleanValues .FALSE
465
+ if cached_model_value in py_obj :
383
466
return BooleanValues .TRUE
384
467
return BooleanValues .FALSE
385
468
if model_filter .operator in FILTER_OPERATOR_STRING_MAPPINGS [FilterOperators .NOT_IN ]:
386
- if cached_model_value in literal_eval (model_filter .value ):
469
+ py_obj = literal_eval (model_filter .value )
470
+ try :
471
+ iter (py_obj )
472
+ except TypeError :
473
+ return BooleanValues .TRUE
474
+ if cached_model_value in py_obj :
387
475
return BooleanValues .FALSE
388
476
return BooleanValues .TRUE
389
477
raise RuntimeError (f"Bad operator: { model_filter .operator } " )
0 commit comments