Skip to content

Commit a93fbc7

Browse files
committed
feat: parallelize notebook search utils
1 parent 0b7bb64 commit a93fbc7

File tree

2 files changed

+159
-199
lines changed

2 files changed

+159
-199
lines changed

src/sagemaker/jumpstart/notebook_utils.py

Lines changed: 96 additions & 123 deletions
Original file line number | Diff line number | Diff line change
@@ -14,13 +14,15 @@
1414
from __future__ import absolute_import
1515
import copy
1616

17+
from concurrent.futures import ThreadPoolExecutor, as_completed
18+
1719
from functools import cmp_to_key
18-
import os
20+
import json
1921
from typing import Any, Generator, List, Optional, Tuple, Union, Set, Dict
2022
from packaging.version import Version
2123
from sagemaker.jumpstart import accessors
2224
from sagemaker.jumpstart.constants import (
23-
ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING,
25+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2426
JUMPSTART_DEFAULT_REGION_NAME,
2527
)
2628
from sagemaker.jumpstart.enums import JumpStartScriptScope
@@ -31,7 +33,8 @@
3133
SpecialSupportedFilterKeys,
3234
)
3335
from sagemaker.jumpstart.filters import Constant, ModelFilter, Operator, evaluate_filter_expression
34-
from sagemaker.jumpstart.utils import get_sagemaker_version
36+
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
37+
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_sagemaker_version
3538

3639

3740
def _compare_model_version_tuples( # pylint: disable=too-many-return-statements
@@ -285,160 +288,130 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
285288
results. (Default: False).
286289
"""
287290

288-
class _ModelSearchContext:
289-
"""Context manager for conducting model searches."""
290-
291-
def __init__(self):
292-
"""Initialize context manager."""
293-
294-
self.old_disable_js_logging_env_var_value = os.environ.get(
295-
ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING
296-
)
297-
298-
def __enter__(self, *args, **kwargs):
299-
"""Enter context.
300-
301-
Disable JumpStart logs to avoid excessive logging.
302-
"""
291+
models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region)
303292

304-
os.environ[ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING] = "true"
293+
if isinstance(filter, str):
294+
filter = Identity(filter)
305295

306-
def __exit__(self, *args, **kwargs):
307-
"""Exit context.
296+
manifest_keys = set(models_manifest_list[0].__slots__)
308297

309-
Restore JumpStart logging settings, and reset cache so
310-
new logs would appear for models previously searched.
311-
"""
298+
all_keys: Set[str] = set()
312299

313-
if self.old_disable_js_logging_env_var_value:
314-
os.environ[
315-
ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING
316-
] = self.old_disable_js_logging_env_var_value
317-
else:
318-
os.environ.pop(ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING, None)
319-
accessors.JumpStartModelsAccessor.reset_cache()
300+
model_filters: Set[ModelFilter] = set()
320301

321-
with _ModelSearchContext():
322-
323-
if isinstance(filter, str):
324-
filter = Identity(filter)
325-
326-
models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region)
327-
manifest_keys = set(models_manifest_list[0].__slots__)
302+
for operator in _model_filter_in_operator_generator(filter):
303+
model_filter = operator.unresolved_value
304+
key = model_filter.key
305+
all_keys.add(key)
306+
model_filters.add(model_filter)
328307

329-
all_keys: Set[str] = set()
308+
for key in all_keys:
309+
if "." in key:
310+
raise NotImplementedError(f"No support for multiple level metadata indexing ('{key}').")
330311

331-
model_filters: Set[ModelFilter] = set()
312+
metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS
332313

333-
for operator in _model_filter_in_operator_generator(filter):
334-
model_filter = operator.unresolved_value
335-
key = model_filter.key
336-
all_keys.add(key)
337-
model_filters.add(model_filter)
314+
required_manifest_keys = manifest_keys.intersection(metadata_filter_keys)
315+
possible_spec_keys = metadata_filter_keys - manifest_keys
338316

339-
for key in all_keys:
340-
if "." in key:
341-
raise NotImplementedError(
342-
f"No support for multiple level metadata indexing ('{key}')."
343-
)
317+
is_task_filter = SpecialSupportedFilterKeys.TASK in all_keys
318+
is_framework_filter = SpecialSupportedFilterKeys.FRAMEWORK in all_keys
344319

345-
metadata_filter_keys = all_keys - SPECIAL_SUPPORTED_FILTER_KEYS
320+
def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str, str]]:
    """Evaluate the search filter against a single manifest entry.

    The filter is first resolved using only values available on the manifest
    header (plus the task/framework derived from the model ID). Only when
    that is inconclusive are the model specs downloaded from S3 and the
    filter evaluated a second time.

    Args:
        model_manifest (JumpStartModelHeader): Manifest entry of the model
            to evaluate.

    Returns:
        Optional[Tuple[str, str]]: ``(model_id, version)`` if the model
            matches the filter, or if the filter is still unknown after
            reading the specs and ``list_incomplete_models`` is set.
            ``None`` if the model does not match, or if its ``min_version``
            is greater than the version from ``get_sagemaker_version()``.

    Raises:
        RuntimeError: If the filter expression is left in the unevaluated
            state after either evaluation pass.
    """
    model_id_and_version = (model_manifest.model_id, model_manifest.version)

    # Each model gets its own copy so resolved values never leak between
    # evaluations of the shared ``filter`` expression.
    copied_filter = copy.deepcopy(filter)

    manifest_specs_cached_values: Dict[str, Union[bool, int, float, str, dict, list]] = {}

    model_filters_to_resolved_values: Dict[ModelFilter, BooleanValues] = {}

    for val in required_manifest_keys:
        manifest_specs_cached_values[val] = getattr(model_manifest, val)

    if is_task_filter:
        manifest_specs_cached_values[
            SpecialSupportedFilterKeys.TASK
        ] = extract_framework_task_model(model_manifest.model_id)[1]

    if is_framework_filter:
        manifest_specs_cached_values[
            SpecialSupportedFilterKeys.FRAMEWORK
        ] = extract_framework_task_model(model_manifest.model_id)[0]

    if Version(model_manifest.min_version) > Version(get_sagemaker_version()):
        return None

    # First pass: resolve the filter with manifest-level values only.
    _populate_model_filters_to_resolved_values(
        manifest_specs_cached_values,
        model_filters_to_resolved_values,
        model_filters,
    )

    _put_resolved_booleans_into_filter(copied_filter, model_filters_to_resolved_values)

    copied_filter.eval()

    if copied_filter.resolved_value in [BooleanValues.TRUE, BooleanValues.FALSE]:
        if copied_filter.resolved_value == BooleanValues.TRUE:
            return model_id_and_version
        return None

    if copied_filter.resolved_value == BooleanValues.UNEVALUATED:
        raise RuntimeError(
            "Filter expression in unevaluated state after using "
            "values from model manifest. Model ID and version that "
            f"is failing: {(model_manifest.model_id, model_manifest.version)}."
        )

    # Second pass: the manifest was not enough, so fetch the model specs.
    copied_filter_2 = copy.deepcopy(filter)

    # NOTE(review): the content bucket is resolved without the enclosing
    # function's ``region`` although the manifest was fetched with
    # ``region=region`` -- confirm whether the bucket lookup should get it too.
    model_specs = JumpStartModelSpecs(
        json.loads(
            DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file(
                get_jumpstart_content_bucket(), model_manifest.spec_key
            )
        )
    )

    for val in possible_spec_keys:
        if hasattr(model_specs, val):
            manifest_specs_cached_values[val] = getattr(model_specs, val)

    _populate_model_filters_to_resolved_values(
        manifest_specs_cached_values,
        model_filters_to_resolved_values,
        model_filters,
    )
    _put_resolved_booleans_into_filter(copied_filter_2, model_filters_to_resolved_values)

    copied_filter_2.eval()

    if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED:
        # Compare the resolved value against UNKNOWN. Using the bare enum
        # member as the condition made this clause independent of the
        # model's result, so non-matching (FALSE) models were returned
        # whenever ``list_incomplete_models`` was set.
        if copied_filter_2.resolved_value == BooleanValues.TRUE or (
            copied_filter_2.resolved_value == BooleanValues.UNKNOWN and list_incomplete_models
        ):
            return model_id_and_version
        return None

    raise RuntimeError(
        "Filter expression in unevaluated state after using values from model specs. "
        "Model ID and version that is failing: "
        f"{(model_manifest.model_id, model_manifest.version)}."
    )
426400

427-
if copied_filter_2.resolved_value != BooleanValues.UNEVALUATED:
428-
if copied_filter_2.resolved_value == BooleanValues.TRUE or (
429-
BooleanValues.UNKNOWN and list_incomplete_models
430-
):
431-
yield (model_manifest.model_id, model_manifest.version)
432-
continue
401+
max_memory = int(100 * 1e6)
402+
average_memory_per_thread = int(25 * 1e3)
403+
max_workers = int(max_memory / average_memory_per_thread)
433404

434-
raise RuntimeError(
435-
"Filter expression in unevaluated state after using values from model specs. "
436-
"Model ID and version that is failing: "
437-
f"{(model_manifest.model_id, model_manifest.version)}."
438-
)
405+
executor = ThreadPoolExecutor(max_workers=max_workers)
406+
407+
futures = []
408+
for header in models_manifest_list:
409+
futures.append(executor.submit(evaluate_model, header))
439410

440-
if len(unrecognized_keys) > 0:
441-
raise RuntimeError(f"Unrecognized keys: {str(unrecognized_keys)}")
411+
for future in as_completed(futures):
412+
result = future.result()
413+
if result:
414+
yield result
442415

443416

444417
def get_model_url(

0 commit comments

Comments
 (0)