Skip to content

Commit 751fee3

Browse files
committed
chore: address PR comments
1 parent 50dd33c commit 751fee3

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

src/sagemaker/jumpstart/notebook_utils.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket, get_sagemaker_version
3838
from sagemaker.session import Session
3939

40+
MAX_SEARCH_WORKERS = int(100 * 1e6 / 25 * 1e3) # max 100MB total memory, 25kB per thread)
41+
4042

4143
def _compare_model_version_tuples( # pylint: disable=too-many-return-statements
4244
model_version_1: Optional[Tuple[str, str]] = None,
@@ -392,6 +394,9 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str,
392394
)
393395
copied_filter_2 = copy.deepcopy(filter)
394396

397+
# spec is downloaded to thread's memory. since each thread
398+
# accesses a unique s3 spec, there is no need to use the JS caching utils.
399+
# spec only stays in memory for lifecycle of thread.
395400
model_specs = JumpStartModelSpecs(
396401
json.loads(
397402
sagemaker_session.read_s3_file(
@@ -426,23 +431,18 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str,
426431
f"{(model_manifest.model_id, model_manifest.version)}."
427432
)
428433

429-
max_memory_bytes = int(100 * 1e6)
430-
average_memory_bytes_per_thread = int(25 * 1e3)
431-
max_workers = int(max_memory_bytes / average_memory_bytes_per_thread)
432-
433-
executor = ThreadPoolExecutor(max_workers=max_workers)
434-
435-
futures = []
436-
for header in models_manifest_list:
437-
futures.append(executor.submit(evaluate_model, header))
438-
439-
for future in as_completed(futures):
440-
error = future.exception()
441-
if error:
442-
raise error
443-
result = future.result()
444-
if result:
445-
yield result
434+
with ThreadPoolExecutor(max_workers=MAX_SEARCH_WORKERS) as executor:
435+
futures = []
436+
for header in models_manifest_list:
437+
futures.append(executor.submit(evaluate_model, header))
438+
439+
for future in as_completed(futures):
440+
error = future.exception()
441+
if error:
442+
raise error
443+
result = future.result()
444+
if result:
445+
yield result
446446

447447

448448
def get_model_url(

0 commit comments

Comments
 (0)