Skip to content

Commit 39158b8

Browse files
author
Verdi March
committed
Change bash bootstrap to python3 bootstrap
1 parent 12c222b commit 39158b8

File tree

1 file changed

+27
-22
lines changed

1 file changed

+27
-22
lines changed

src/sagemaker/processing.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,7 +1235,7 @@ class FeatureStoreOutput(ApiObject):
12351235
class FrameworkProcessor(ScriptProcessor):
12361236
"""Handles Amazon SageMaker processing tasks for jobs using a machine learning framework."""
12371237

1238-
framework_entrypoint_command = ["/bin/bash"]
1238+
framework_entrypoint_command = ["python3"]
12391239

12401240
# Added new (kw)args for estimator. The rest are from ScriptProcessor with same defaults.
12411241
def __init__(
@@ -1436,12 +1436,12 @@ def get_run_args(
14361436
"""
14371437
# When job_name is None, the job_name to upload code (+payload) will
14381438
# differ from job_name used by run().
1439-
s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
1439+
s3_runproc_py, inputs, job_name = self._pack_and_upload_code(
14401440
code, source_dir, dependencies, git_config, job_name, inputs
14411441
)
14421442

14431443
return RunArgs(
1444-
s3_runproc_sh,
1444+
s3_runproc_py,
14451445
inputs=inputs,
14461446
outputs=outputs,
14471447
arguments=arguments,
@@ -1551,13 +1551,13 @@ def run( # type: ignore[override]
15511551
kms_key (str): The ARN of the KMS key that is used to encrypt the
15521552
user code file (default: None).
15531553
"""
1554-
s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
1554+
s3_runproc_py, inputs, job_name = self._pack_and_upload_code(
15551555
code, source_dir, dependencies, git_config, job_name, inputs
15561556
)
15571557

15581558
# Submit a processing job.
15591559
super().run(
1560-
code=s3_runproc_sh,
1560+
code=s3_runproc_py,
15611561
inputs=inputs,
15621562
outputs=outputs,
15631563
arguments=arguments,
@@ -1597,20 +1597,20 @@ def _pack_and_upload_code(self, code, source_dir, dependencies, git_config, job_
15971597
"automatically."
15981598
)
15991599

1600-
# Upload the bootstrapping code as s3://.../jobname/source/runproc.sh.
1600+
# Upload the bootstrapping code as s3://.../jobname/source/runproc.py.
16011601
entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace(
16021602
"sourcedir.tar.gz",
1603-
"runproc.sh",
1603+
"runproc.py",
16041604
)
16051605
script = estimator.uploaded_code.script_name
1606-
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
1606+
s3_runproc_py = S3Uploader.upload_string_as_file_body(
16071607
self._generate_framework_script(script),
16081608
desired_s3_uri=entrypoint_s3_uri,
16091609
sagemaker_session=self.sagemaker_session,
16101610
)
1611-
logger.info("runproc.sh uploaded to %s", s3_runproc_sh)
1611+
logger.info("runproc.py uploaded to %s", s3_runproc_py)
16121612

1613-
return s3_runproc_sh, inputs, job_name
1613+
return s3_runproc_py, inputs, job_name
16141614

16151615
def _generate_framework_script(self, user_script: str) -> str:
16161616
"""Generate the framework entrypoint file (as text) for a processing job.
@@ -1626,22 +1626,27 @@ def _generate_framework_script(self, user_script: str) -> str:
16261626
"""
16271627
return dedent(
16281628
"""\
1629-
#!/bin/bash
1629+
import os
1630+
import subprocess
1631+
import sys
1632+
import tarfile
16301633
1631-
cd /opt/ml/processing/input/code/
1632-
tar -xzf sourcedir.tar.gz
16331634
1634-
# Exit on any error. SageMaker uses error code to mark failed job.
1635-
set -e
1635+
if __name__ == "__main__":
1636+
os.chdir("/opt/ml/processing/input/code")
16361637
1637-
if [[ -f 'requirements.txt' ]]; then
1638-
# Some py3 containers has typing, which may breaks pip install
1639-
pip uninstall --yes typing
1638+
with tarfile.open("sourcedir.tar.gz", "r:gz") as tar:
1639+
tar.extractall()
16401640
1641-
pip install -r requirements.txt
1642-
fi
1641+
if os.path.isfile("requirements.txt"):
1642+
# Some py3 containers has typing, which may breaks pip install
1643+
subprocess.run(["pip", "uninstall", "--yes", "typing"])
16431644
1644-
{entry_point_command} {entry_point} "$@"
1645+
subprocess.run(["pip", "install", "-r", "requirements.txt"])
1646+
1647+
cmd = ["{entry_point_command}", "{entry_point}"] + sys.argv[1:]
1648+
print(' '.join(cmd))
1649+
subprocess.run(cmd)
16451650
"""
16461651
).format(
16471652
entry_point_command=" ".join(self.command),
@@ -1683,7 +1688,7 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput
16831688
# Follow the exact same mechanism that ScriptProcessor does, which
16841689
# is to inject the S3 code artifact as a processing input. Note that
16851690
# the framework processor takes over /opt/ml/processing/input/code for
1686-
# sourcedir.tar.gz, and let ScriptProcessor to place runproc.sh under
1691+
# sourcedir.tar.gz, and lets ScriptProcessor place runproc.py under
16871692
# /opt/ml/processing/input/{self._CODE_CONTAINER_INPUT_NAME}.
16881693
#
16891694
# See:

0 commit comments

Comments
 (0)