|
37 | 37 | from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_sagemaker_version
|
38 | 38 | from sagemaker.session import Session
|
39 | 39 |
|
# Upper bound on the number of concurrent search threads: a 100 MB total
# memory budget divided by an estimated 25 kB per thread, i.e. 4000 workers.
# The divisor must be parenthesized: `a / b * c` evaluates as `(a / b) * c`,
# which would produce 4e9 here and leave the thread pool effectively unbounded.
MAX_SEARCH_WORKERS = int((100 * 1e6) / (25 * 1e3))
| 41 | + |
40 | 42 |
|
41 | 43 | def _compare_model_version_tuples( # pylint: disable=too-many-return-statements
|
42 | 44 | model_version_1: Optional[Tuple[str, str]] = None,
|
@@ -392,6 +394,9 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str,
|
392 | 394 | )
|
393 | 395 | copied_filter_2 = copy.deepcopy(filter)
|
394 | 396 |
|
| 397 | + # The spec is downloaded into the thread's memory. Since each thread |
| 398 | + # accesses a unique S3 spec, there is no need to use the JS caching utils. |
| 399 | + # The spec only stays in memory for the lifecycle of the thread. |
395 | 400 | model_specs = JumpStartModelSpecs(
|
396 | 401 | json.loads(
|
397 | 402 | sagemaker_session.read_s3_file(
|
@@ -426,23 +431,18 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str,
|
426 | 431 | f"{(model_manifest.model_id, model_manifest.version)}."
|
427 | 432 | )
|
428 | 433 |
|
429 |
| - max_memory_bytes = int(100 * 1e6) |
430 |
| - average_memory_bytes_per_thread = int(25 * 1e3) |
431 |
| - max_workers = int(max_memory_bytes / average_memory_bytes_per_thread) |
432 |
| - |
433 |
| - executor = ThreadPoolExecutor(max_workers=max_workers) |
434 |
| - |
435 |
| - futures = [] |
436 |
| - for header in models_manifest_list: |
437 |
| - futures.append(executor.submit(evaluate_model, header)) |
438 |
| - |
439 |
| - for future in as_completed(futures): |
440 |
| - error = future.exception() |
441 |
| - if error: |
442 |
| - raise error |
443 |
| - result = future.result() |
444 |
| - if result: |
445 |
| - yield result |
| 434 | + with ThreadPoolExecutor(max_workers=MAX_SEARCH_WORKERS) as executor: |
| 435 | + futures = [] |
| 436 | + for header in models_manifest_list: |
| 437 | + futures.append(executor.submit(evaluate_model, header)) |
| 438 | + |
| 439 | + for future in as_completed(futures): |
| 440 | + error = future.exception() |
| 441 | + if error: |
| 442 | + raise error |
| 443 | + result = future.result() |
| 444 | + if result: |
| 445 | + yield result |
446 | 446 |
|
447 | 447 |
|
448 | 448 | def get_model_url(
|
|
0 commit comments