|
15 | 15 | import copy
|
16 | 16 |
|
17 | 17 | from functools import cmp_to_key
|
| 18 | +import os |
18 | 19 | from typing import Any, Generator, List, Optional, Tuple, Union, Set, Dict
|
19 | 20 | from packaging.version import Version
|
20 | 21 | from sagemaker.jumpstart import accessors
|
21 |
| -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME |
| 22 | +from sagemaker.jumpstart.constants import ( |
| 23 | + ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, |
| 24 | + JUMPSTART_DEFAULT_REGION_NAME, |
| 25 | +) |
22 | 26 | from sagemaker.jumpstart.enums import JumpStartScriptScope
|
23 | 27 | from sagemaker.jumpstart.filters import (
|
24 | 28 | SPECIAL_SUPPORTED_FILTER_KEYS,
|
@@ -281,126 +285,160 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
|
281 | 285 | results. (Default: False).
|
282 | 286 | """
|
283 | 287 |
|
284 |
| - if isinstance(filter, str): |
285 |
| - filter = Identity(filter) |
| 288 | + class _ModelSearchContext: |
| 289 | + """Context manager for conducting model searches.""" |
286 | 290 |
|
287 |
| - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region) |
288 |
| - manifest_keys = set(models_manifest_list[0].__slots__) |
| 291 | + def __init__(self): |
| 292 | + """Initialize context manager.""" |
289 | 293 |
|
290 |
| - all_keys: Set[str] = set() |
| 294 | + self.old_disable_js_logging_env_var_value = os.environ.get( |
| 295 | + ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING |
| 296 | + ) |
291 | 297 |
|
292 |
| - model_filters: Set[ModelFilter] = set() |
| 298 | + def __enter__(self, *args, **kwargs): |
| 299 | + """Enter context. |
293 | 300 |
|
294 |
| - for operator in _model_filter_in_operator_generator(filter): |
295 |
| - model_filter = operator.unresolved_value |
296 |
| - key = model_filter.key |
297 |
| - all_keys.add(key) |
298 |
| - model_filters.add(model_filter) |
| 301 | + JumpStart logs get disabled to avoid excessive logging. |
| 302 | + """ |
299 | 303 |
|
300 |
| - for key in all_keys: |
301 |
| - if "." in key: |
302 |
| - raise NotImplementedError(f"No support for multiple level metadata indexing ('{key}').") |
| 304 | + os.environ[ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING] = "true" |
303 | 305 |
|
304 |
| - metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS |
| 306 | + def __exit__(self, *args, **kwargs): |
| 307 | + """Exit context. |
305 | 308 |
|
306 |
| - required_manifest_keys = manifest_keys.intersection(metadata_filter_keys) |
307 |
| - possible_spec_keys = metadata_filter_keys - manifest_keys |
| 309 | + Restore JumpStart logging settings, and reset cache so |
| 310 | + new logs would appear for models previously searched. |
| 311 | + """ |
308 | 312 |
|
309 |
| - unrecognized_keys: Set[str] = set() |
| 313 | + if self.old_disable_js_logging_env_var_value: |
| 314 | + os.environ[ |
| 315 | + ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING |
| 316 | + ] = self.old_disable_js_logging_env_var_value |
| 317 | + else: |
| 318 | + os.environ.pop(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, None) |
| 319 | + accessors.JumpStartModelsAccessor.reset_cache() |
310 | 320 |
|
311 |
| - is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys |
312 |
| - is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys |
313 |
| - is_supported_model_filter = SpecialSupportedFilterKeys.SUPPORTED_MODEL in all_keys |
| 321 | + with _ModelSearchContext(): |
314 | 322 |
|
315 |
| - for model_manifest in models_manifest_list: |
| 323 | + if isinstance(filter, str): |
| 324 | + filter = Identity(filter) |
316 | 325 |
|
317 |
| - copied_filter = copy.deepcopy(filter) |
| 326 | + models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region) |
| 327 | + manifest_keys = set(models_manifest_list[0].__slots__) |
318 | 328 |
|
319 |
| - manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {} |
| 329 | + all_keys: Set[str] = set() |
320 | 330 |
|
321 |
| - model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {} |
| 331 | + model_filters: Set[ModelFilter] = set() |
322 | 332 |
|
323 |
| - for val in required_manifest_keys: |
324 |
| - manifest_specs_cached_values[val] = getattr(model_manifest, val) |
| 333 | + for operator in _model_filter_in_operator_generator(filter): |
| 334 | + model_filter = operator.unresolved_value |
| 335 | + key = model_filter.key |
| 336 | + all_keys.add(key) |
| 337 | + model_filters.add(model_filter) |
325 | 338 |
|
326 |
| - if is_task_filter: |
327 |
| - manifest_specs_cached_values[ |
328 |
| - SpecialSupportedFilterKeys.TASK |
329 |
| - ] = extract_framework_task_model(model_manifest.model_id)[1] |
| 339 | + for key in all_keys: |
| 340 | + if "." in key: |
| 341 | + raise NotImplementedError( |
| 342 | + f"No support for multiple level metadata indexing ('{key}')." |
| 343 | + ) |
330 | 344 |
|
331 |
| - if is_framework_filter: |
332 |
| - manifest_specs_cached_values[ |
333 |
| - SpecialSupportedFilterKeys.FRAMEWORK |
334 |
| - ] = extract_framework_task_model(model_manifest.model_id)[0] |
| 345 | + metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS |
335 | 346 |
|
336 |
| - if is_supported_model_filter: |
337 |
| - manifest_specs_cached_values[SpecialSupportedFilterKeys.SUPPORTED_MODEL] = Version( |
338 |
| - model_manifest.min_version |
339 |
| - ) <= Version(get_sagemaker_version()) |
| 347 | + required_manifest_keys = manifest_keys.intersection(metadata_filter_keys) |
| 348 | + possible_spec_keys = metadata_filter_keys - manifest_keys |
340 | 349 |
|
341 |
| - _populate_model_filters_to_resolved_values( |
342 |
| - manifest_specs_cached_values, |
343 |
| - model_filters_to_resolved_values, |
344 |
| - model_filters, |
345 |
| - ) |
| 350 | + unrecognized_keys: Set[str] = set() |
346 | 351 |
|
347 |
| - _put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values) |
| 352 | + is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys |
| 353 | + is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys |
348 | 354 |
|
349 |
| - copied_filter.eval() |
| 355 | + for model_manifest in models_manifest_list: |
350 | 356 |
|
351 |
| - if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]: |
352 |
| - if copied_filter.resolved_value == BooleanValues.TRUE: |
353 |
| - yield (model_manifest.model_id, model_manifest.version) |
354 |
| - continue |
| 357 | + copied_filter = copy.deepcopy(filter) |
355 | 358 |
|
356 |
| - if copied_filter.resolved_value == BooleanValues.UNEVALUATED: |
357 |
| - raise RuntimeError( |
358 |
| - "Filter expression in unevaluated state after using values from model manifest. " |
359 |
| - "Model ID and version that is failing: " |
360 |
| - f"{(model_manifest.model_id, model_manifest.version)}." |
| 359 | + manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {} |
| 360 | + |
| 361 | + model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {} |
| 362 | + |
| 363 | + for val in required_manifest_keys: |
| 364 | + manifest_specs_cached_values[val] = getattr(model_manifest, val) |
| 365 | + |
| 366 | + if is_task_filter: |
| 367 | + manifest_specs_cached_values[ |
| 368 | + SpecialSupportedFilterKeys.TASK |
| 369 | + ] = extract_framework_task_model(model_manifest.model_id)[1] |
| 370 | + |
| 371 | + if is_framework_filter: |
| 372 | + manifest_specs_cached_values[ |
| 373 | + SpecialSupportedFilterKeys.FRAMEWORK |
| 374 | + ] = extract_framework_task_model(model_manifest.model_id)[0] |
| 375 | + |
| 376 | + if Version(model_manifest.min_version) > Version(get_sagemaker_version()): |
| 377 | + continue |
| 378 | + |
| 379 | + _populate_model_filters_to_resolved_values( |
| 380 | + manifest_specs_cached_values, |
| 381 | + model_filters_to_resolved_values, |
| 382 | + model_filters, |
361 | 383 | )
|
362 |
| - copied_filter_2 = copy.deepcopy(filter) |
363 | 384 |
|
364 |
| - model_specs = accessors.JumpStartModelsAccessor.get_model_specs( |
365 |
| - region=region, |
366 |
| - model_id=model_manifest.model_id, |
367 |
| - version=model_manifest.version, |
368 |
| - ) |
| 385 | + _put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values) |
369 | 386 |
|
370 |
| - model_specs_keys = set(model_specs.__slots__) |
| 387 | + copied_filter.eval() |
371 | 388 |
|
372 |
| - unrecognized_keys -= model_specs_keys |
373 |
| - unrecognized_keys_for_single_spec = possible_spec_keys - model_specs_keys |
374 |
| - unrecognized_keys.update(unrecognized_keys_for_single_spec) |
| 389 | + if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]: |
| 390 | + if copied_filter.resolved_value == BooleanValues.TRUE: |
| 391 | + yield (model_manifest.model_id, model_manifest.version) |
| 392 | + continue |
375 | 393 |
|
376 |
| - for val in possible_spec_keys: |
377 |
| - if hasattr(model_specs, val): |
378 |
| - manifest_specs_cached_values[val] = getattr(model_specs, val) |
| 394 | + if copied_filter.resolved_value == BooleanValues.UNEVALUATED: |
| 395 | + raise RuntimeError( |
| 396 | + "Filter expression in unevaluated state after using " |
| 397 | + "values from model manifest. Model ID and version that " |
| 398 | + f"is failing: {(model_manifest.model_id, model_manifest.version)}." |
| 399 | + ) |
| 400 | + copied_filter_2 = copy.deepcopy(filter) |
379 | 401 |
|
380 |
| - _populate_model_filters_to_resolved_values( |
381 |
| - manifest_specs_cached_values, |
382 |
| - model_filters_to_resolved_values, |
383 |
| - model_filters, |
384 |
| - ) |
385 |
| - _put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values) |
| 402 | + model_specs = accessors.JumpStartModelsAccessor.get_model_specs( |
| 403 | + region=region, |
| 404 | + model_id=model_manifest.model_id, |
| 405 | + version=model_manifest.version, |
| 406 | + ) |
386 | 407 |
|
387 |
| - copied_filter_2.eval() |
| 408 | + model_specs_keys = set(model_specs.__slots__) |
388 | 409 |
|
389 |
| - if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED: |
390 |
| - if copied_filter_2.resolved_value == BooleanValues.TRUE or ( |
391 |
| - BooleanValues.UNKNOWN and list_incomplete_models |
392 |
| - ): |
393 |
| - yield (model_manifest.model_id, model_manifest.version) |
394 |
| - continue |
| 410 | + unrecognized_keys -= model_specs_keys |
| 411 | + unrecognized_keys_for_single_spec = possible_spec_keys - model_specs_keys |
| 412 | + unrecognized_keys.update(unrecognized_keys_for_single_spec) |
395 | 413 |
|
396 |
| - raise RuntimeError( |
397 |
| - "Filter expression in unevaluated state after using values from model specs. " |
398 |
| - "Model ID and version that is failing: " |
399 |
| - f"{(model_manifest.model_id, model_manifest.version)}." |
400 |
| - ) |
| 414 | + for val in possible_spec_keys: |
| 415 | + if hasattr(model_specs, val): |
| 416 | + manifest_specs_cached_values[val] = getattr(model_specs, val) |
| 417 | + |
| 418 | + _populate_model_filters_to_resolved_values( |
| 419 | + manifest_specs_cached_values, |
| 420 | + model_filters_to_resolved_values, |
| 421 | + model_filters, |
| 422 | + ) |
| 423 | + _put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values) |
| 424 | + |
| 425 | + copied_filter_2.eval() |
| 426 | + |
| 427 | + if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED: |
| 428 | + if copied_filter_2.resolved_value == BooleanValues.TRUE or ( |
| 429 | + BooleanValues.UNKNOWN and list_incomplete_models |
| 430 | + ): |
| 431 | + yield (model_manifest.model_id, model_manifest.version) |
| 432 | + continue |
| 433 | + |
| 434 | + raise RuntimeError( |
| 435 | + "Filter expression in unevaluated state after using values from model specs. " |
| 436 | + "Model ID and version that is failing: " |
| 437 | + f"{(model_manifest.model_id, model_manifest.version)}." |
| 438 | + ) |
401 | 439 |
|
402 |
| - if len(unrecognized_keys) > 0: |
403 |
| - raise RuntimeError(f"Unrecognized keys: {str(unrecognized_keys)}") |
| 440 | + if len(unrecognized_keys) > 0: |
| 441 | + raise RuntimeError(f"Unrecognized keys: {str(unrecognized_keys)}") |
404 | 442 |
|
405 | 443 |
|
406 | 444 | def get_model_url(
|
|
0 commit comments