|
14 | 14 | from __future__ import absolute_import
|
15 | 15 | import copy
|
16 | 16 |
|
| 17 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
| 18 | + |
17 | 19 | from functools import cmp_to_key
|
18 |
| -import os |
| 20 | +import json |
19 | 21 | from typing import Any, Generator, List, Optional, Tuple, Union, Set, Dict
|
20 | 22 | from packaging.version import Version
|
21 | 23 | from sagemaker.jumpstart import accessors
|
22 | 24 | from sagemaker.jumpstart.constants import (
|
23 |
| - ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, |
| 25 | + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, |
24 | 26 | JUMPSTART_DEFAULT_REGION_NAME,
|
25 | 27 | )
|
26 | 28 | from sagemaker.jumpstart.enums import JumpStartScriptScope
|
|
31 | 33 | SpecialSupportedFilterKeys,
|
32 | 34 | )
|
33 | 35 | from sagemaker.jumpstart.filters import Constant, ModelFilter, Operator, evaluate_filter_expression
|
34 |
| -from sagemaker.jumpstart.utils import get_sagemaker_version |
| 36 | +from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs |
| 37 | +from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_sagemaker_version |
35 | 38 |
|
36 | 39 |
|
37 | 40 | def _compare_model_version_tuples( # pylint: disable=too-many-return-statements
|
@@ -285,160 +288,130 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
|
285 | 288 | results. (Default: False).
|
286 | 289 | """
|
287 | 290 |
|
288 |
| - class _ModelSearchContext: |
289 |
| - """Context manager for conducting model searches.""" |
290 |
| - |
291 |
| - def __init__(self): |
292 |
| - """Initialize context manager.""" |
293 |
| - |
294 |
| - self.old_disable_js_logging_env_var_value = os.environ.get( |
295 |
| - ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING |
296 |
| - ) |
297 |
| - |
298 |
| - def __enter__(self, *args, **kwargs): |
299 |
| - """Enter context. |
300 |
| -
|
301 |
| - Disable JumpStart logs to avoid excessive logging. |
302 |
| - """ |
| 291 | + models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region) |
303 | 292 |
|
304 |
| - os.environ[ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING] = "true" |
| 293 | + if isinstance(filter, str): |
| 294 | + filter = Identity(filter) |
305 | 295 |
|
306 |
| - def __exit__(self, *args, **kwargs): |
307 |
| - """Exit context. |
| 296 | + manifest_keys = set(models_manifest_list[0].__slots__) |
308 | 297 |
|
309 |
| - Restore JumpStart logging settings, and reset cache so |
310 |
| - new logs would appear for models previously searched. |
311 |
| - """ |
| 298 | + all_keys: Set[str] = set() |
312 | 299 |
|
313 |
| - if self.old_disable_js_logging_env_var_value: |
314 |
| - os.environ[ |
315 |
| - ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING |
316 |
| - ] = self.old_disable_js_logging_env_var_value |
317 |
| - else: |
318 |
| - os.environ.pop(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, None) |
319 |
| - accessors.JumpStartModelsAccessor.reset_cache() |
| 300 | + model_filters: Set[ModelFilter] = set() |
320 | 301 |
|
321 |
| - with _ModelSearchContext(): |
322 |
| - |
323 |
| - if isinstance(filter, str): |
324 |
| - filter = Identity(filter) |
325 |
| - |
326 |
| - models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region) |
327 |
| - manifest_keys = set(models_manifest_list[0].__slots__) |
| 302 | + for operator in _model_filter_in_operator_generator(filter): |
| 303 | + model_filter = operator.unresolved_value |
| 304 | + key = model_filter.key |
| 305 | + all_keys.add(key) |
| 306 | + model_filters.add(model_filter) |
328 | 307 |
|
329 |
| - all_keys: Set[str] = set() |
| 308 | + for key in all_keys: |
| 309 | + if "." in key: |
| 310 | + raise NotImplementedError(f"No support for multiple level metadata indexing ('{key}').") |
330 | 311 |
|
331 |
| - model_filters: Set[ModelFilter] = set() |
| 312 | + metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS |
332 | 313 |
|
333 |
| - for operator in _model_filter_in_operator_generator(filter): |
334 |
| - model_filter = operator.unresolved_value |
335 |
| - key = model_filter.key |
336 |
| - all_keys.add(key) |
337 |
| - model_filters.add(model_filter) |
| 314 | + required_manifest_keys = manifest_keys.intersection(metadata_filter_keys) |
| 315 | + possible_spec_keys = metadata_filter_keys - manifest_keys |
338 | 316 |
|
339 |
| - for key in all_keys: |
340 |
| - if "." in key: |
341 |
| - raise NotImplementedError( |
342 |
| - f"No support for multiple level metadata indexing ('{key}')." |
343 |
| - ) |
| 317 | + is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys |
| 318 | + is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys |
344 | 319 |
|
345 |
| - metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS |
| 320 | + def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, str]]: |
346 | 321 |
|
347 |
| - required_manifest_keys = manifest_keys.intersection(metadata_filter_keys) |
348 |
| - possible_spec_keys = metadata_filter_keys - manifest_keys |
| 322 | + copied_filter = copy.deepcopy(filter) |
349 | 323 |
|
350 |
| - unrecognized_keys: Set[str] = set() |
| 324 | + manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {} |
351 | 325 |
|
352 |
| - is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys |
353 |
| - is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys |
| 326 | + model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {} |
354 | 327 |
|
355 |
| - for model_manifest in models_manifest_list: |
| 328 | + for val in required_manifest_keys: |
| 329 | + manifest_specs_cached_values[val] = getattr(model_manifest, val) |
356 | 330 |
|
357 |
| - copied_filter = copy.deepcopy(filter) |
| 331 | + if is_task_filter: |
| 332 | + manifest_specs_cached_values[ |
| 333 | + SpecialSupportedFilterKeys.TASK |
| 334 | + ] = extract_framework_task_model(model_manifest.model_id)[1] |
358 | 335 |
|
359 |
| - manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {} |
| 336 | + if is_framework_filter: |
| 337 | + manifest_specs_cached_values[ |
| 338 | + SpecialSupportedFilterKeys.FRAMEWORK |
| 339 | + ] = extract_framework_task_model(model_manifest.model_id)[0] |
360 | 340 |
|
361 |
| - model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {} |
| 341 | + if Version(model_manifest.min_version) > Version(get_sagemaker_version()): |
| 342 | + return None |
362 | 343 |
|
363 |
| - for val in required_manifest_keys: |
364 |
| - manifest_specs_cached_values[val] = getattr(model_manifest, val) |
| 344 | + _populate_model_filters_to_resolved_values( |
| 345 | + manifest_specs_cached_values, |
| 346 | + model_filters_to_resolved_values, |
| 347 | + model_filters, |
| 348 | + ) |
365 | 349 |
|
366 |
| - if is_task_filter: |
367 |
| - manifest_specs_cached_values[ |
368 |
| - SpecialSupportedFilterKeys.TASK |
369 |
| - ] = extract_framework_task_model(model_manifest.model_id)[1] |
| 350 | + _put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values) |
370 | 351 |
|
371 |
| - if is_framework_filter: |
372 |
| - manifest_specs_cached_values[ |
373 |
| - SpecialSupportedFilterKeys.FRAMEWORK |
374 |
| - ] = extract_framework_task_model(model_manifest.model_id)[0] |
| 352 | + copied_filter.eval() |
375 | 353 |
|
376 |
| - if Version(model_manifest.min_version) > Version(get_sagemaker_version()): |
377 |
| - continue |
| 354 | + if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]: |
| 355 | + if copied_filter.resolved_value == BooleanValues.TRUE: |
| 356 | + return (model_manifest.model_id, model_manifest.version) |
| 357 | + return None |
378 | 358 |
|
379 |
| - _populate_model_filters_to_resolved_values( |
380 |
| - manifest_specs_cached_values, |
381 |
| - model_filters_to_resolved_values, |
382 |
| - model_filters, |
| 359 | + if copied_filter.resolved_value == BooleanValues.UNEVALUATED: |
| 360 | + raise RuntimeError( |
| 361 | + "Filter expression in unevaluated state after using " |
| 362 | + "values from model manifest. Model ID and version that " |
| 363 | + f"is failing: {(model_manifest.model_id, model_manifest.version)}." |
383 | 364 | )
|
| 365 | + copied_filter_2 = copy.deepcopy(filter) |
384 | 366 |
|
385 |
| - _put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values) |
386 |
| - |
387 |
| - copied_filter.eval() |
388 |
| - |
389 |
| - if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]: |
390 |
| - if copied_filter.resolved_value == BooleanValues.TRUE: |
391 |
| - yield (model_manifest.model_id, model_manifest.version) |
392 |
| - continue |
393 |
| - |
394 |
| - if copied_filter.resolved_value == BooleanValues.UNEVALUATED: |
395 |
| - raise RuntimeError( |
396 |
| - "Filter expression in unevaluated state after using " |
397 |
| - "values from model manifest. Model ID and version that " |
398 |
| - f"is failing: {(model_manifest.model_id, model_manifest.version)}." |
| 367 | + model_specs = JumpStartModelSpecs( |
| 368 | + json.loads( |
| 369 | + DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file( |
| 370 | + get_jumpstart_content_bucket(), model_manifest.spec_key |
399 | 371 | )
|
400 |
| - copied_filter_2 = copy.deepcopy(filter) |
401 |
| - |
402 |
| - model_specs = accessors.JumpStartModelsAccessor.get_model_specs( |
403 |
| - region=region, |
404 |
| - model_id=model_manifest.model_id, |
405 |
| - version=model_manifest.version, |
406 | 372 | )
|
| 373 | + ) |
407 | 374 |
|
408 |
| - model_specs_keys = set(model_specs.__slots__) |
| 375 | + for val in possible_spec_keys: |
| 376 | + if hasattr(model_specs, val): |
| 377 | + manifest_specs_cached_values[val] = getattr(model_specs, val) |
409 | 378 |
|
410 |
| - unrecognized_keys -= model_specs_keys |
411 |
| - unrecognized_keys_for_single_spec = possible_spec_keys - model_specs_keys |
412 |
| - unrecognized_keys.update(unrecognized_keys_for_single_spec) |
| 379 | + _populate_model_filters_to_resolved_values( |
| 380 | + manifest_specs_cached_values, |
| 381 | + model_filters_to_resolved_values, |
| 382 | + model_filters, |
| 383 | + ) |
| 384 | + _put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values) |
413 | 385 |
|
414 |
| - for val in possible_spec_keys: |
415 |
| - if hasattr(model_specs, val): |
416 |
| - manifest_specs_cached_values[val] = getattr(model_specs, val) |
| 386 | + copied_filter_2.eval() |
417 | 387 |
|
418 |
| - _populate_model_filters_to_resolved_values( |
419 |
| - manifest_specs_cached_values, |
420 |
| - model_filters_to_resolved_values, |
421 |
| - model_filters, |
422 |
| - ) |
423 |
| - _put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values) |
| 388 | + if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED: |
| 389 | + if copied_filter_2.resolved_value == BooleanValues.TRUE or ( |
| 390 | + BooleanValues.UNKNOWN and list_incomplete_models |
| 391 | + ): |
| 392 | + return (model_manifest.model_id, model_manifest.version) |
| 393 | + return None |
424 | 394 |
|
425 |
| - copied_filter_2.eval() |
| 395 | + raise RuntimeError( |
| 396 | + "Filter expression in unevaluated state after using values from model specs. " |
| 397 | + "Model ID and version that is failing: " |
| 398 | + f"{(model_manifest.model_id, model_manifest.version)}." |
| 399 | + ) |
426 | 400 |
|
427 |
| - if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED: |
428 |
| - if copied_filter_2.resolved_value == BooleanValues.TRUE or ( |
429 |
| - BooleanValues.UNKNOWN and list_incomplete_models |
430 |
| - ): |
431 |
| - yield (model_manifest.model_id, model_manifest.version) |
432 |
| - continue |
| 401 | + max_memory = int(100 * 1e6) |
| 402 | + average_memory_per_thread = int(25 * 1e3) |
| 403 | + max_workers = int(max_memory / average_memory_per_thread) |
433 | 404 |
|
434 |
| - raise RuntimeError( |
435 |
| - "Filter expression in unevaluated state after using values from model specs. " |
436 |
| - "Model ID and version that is failing: " |
437 |
| - f"{(model_manifest.model_id, model_manifest.version)}." |
438 |
| - ) |
| 405 | + executor = ThreadPoolExecutor(max_workers=max_workers) |
| 406 | + |
| 407 | + futures = [] |
| 408 | + for header in models_manifest_list: |
| 409 | + futures.append(executor.submit(evaluate_model, header)) |
439 | 410 |
|
440 |
| - if len(unrecognized_keys) > 0: |
441 |
| - raise RuntimeError(f"Unrecognized keys: {str(unrecognized_keys)}") |
| 411 | + for future in as_completed(futures): |
| 412 | + result = future.result() |
| 413 | + if result: |
| 414 | + yield result |
442 | 415 |
|
443 | 416 |
|
444 | 417 | def get_model_url(
|
|
0 commit comments