@@ -1235,7 +1235,7 @@ class FeatureStoreOutput(ApiObject):
1235
1235
class FrameworkProcessor (ScriptProcessor ):
1236
1236
"""Handles Amazon SageMaker processing tasks for jobs using a machine learning framework."""
1237
1237
1238
- framework_entrypoint_command = ["/bin/bash " ]
1238
+ framework_entrypoint_command = ["python3 " ]
1239
1239
1240
1240
# Added new (kw)args for estimator. The rest are from ScriptProcessor with same defaults.
1241
1241
def __init__ (
@@ -1436,12 +1436,12 @@ def get_run_args(
1436
1436
"""
1437
1437
# When job_name is None, the job_name to upload code (+payload) will
1438
1438
# differ from job_name used by run().
1439
- s3_runproc_sh , inputs , job_name = self ._pack_and_upload_code (
1439
+ s3_runproc_py , inputs , job_name = self ._pack_and_upload_code (
1440
1440
code , source_dir , dependencies , git_config , job_name , inputs
1441
1441
)
1442
1442
1443
1443
return RunArgs (
1444
- s3_runproc_sh ,
1444
+ s3_runproc_py ,
1445
1445
inputs = inputs ,
1446
1446
outputs = outputs ,
1447
1447
arguments = arguments ,
@@ -1551,13 +1551,13 @@ def run( # type: ignore[override]
1551
1551
kms_key (str): The ARN of the KMS key that is used to encrypt the
1552
1552
user code file (default: None).
1553
1553
"""
1554
- s3_runproc_sh , inputs , job_name = self ._pack_and_upload_code (
1554
+ s3_runproc_py , inputs , job_name = self ._pack_and_upload_code (
1555
1555
code , source_dir , dependencies , git_config , job_name , inputs
1556
1556
)
1557
1557
1558
1558
# Submit a processing job.
1559
1559
super ().run (
1560
- code = s3_runproc_sh ,
1560
+ code = s3_runproc_py ,
1561
1561
inputs = inputs ,
1562
1562
outputs = outputs ,
1563
1563
arguments = arguments ,
@@ -1597,20 +1597,20 @@ def _pack_and_upload_code(self, code, source_dir, dependencies, git_config, job_
1597
1597
"automatically."
1598
1598
)
1599
1599
1600
- # Upload the bootstrapping code as s3://.../jobname/source/runproc.sh .
1600
+ # Upload the bootstrapping code as s3://.../jobname/source/runproc.py .
1601
1601
entrypoint_s3_uri = estimator .uploaded_code .s3_prefix .replace (
1602
1602
"sourcedir.tar.gz" ,
1603
- "runproc.sh " ,
1603
+ "runproc.py " ,
1604
1604
)
1605
1605
script = estimator .uploaded_code .script_name
1606
- s3_runproc_sh = S3Uploader .upload_string_as_file_body (
1606
+ s3_runproc_py = S3Uploader .upload_string_as_file_body (
1607
1607
self ._generate_framework_script (script ),
1608
1608
desired_s3_uri = entrypoint_s3_uri ,
1609
1609
sagemaker_session = self .sagemaker_session ,
1610
1610
)
1611
- logger .info ("runproc.sh uploaded to %s" , s3_runproc_sh )
1611
+ logger .info ("runproc.py uploaded to %s" , s3_runproc_py )
1612
1612
1613
- return s3_runproc_sh , inputs , job_name
1613
+ return s3_runproc_py , inputs , job_name
1614
1614
1615
1615
def _generate_framework_script (self , user_script : str ) -> str :
1616
1616
"""Generate the framework entrypoint file (as text) for a processing job.
@@ -1626,22 +1626,27 @@ def _generate_framework_script(self, user_script: str) -> str:
1626
1626
"""
1627
1627
return dedent (
1628
1628
"""\
1629
- #!/bin/bash
1629
+ import os
1630
+ import subprocess
1631
+ import sys
1632
+ import tarfile
1630
1633
1631
- cd /opt/ml/processing/input/code/
1632
- tar -xzf sourcedir.tar.gz
1633
1634
1634
- # Exit on any error. SageMaker uses error code to mark failed job.
1635
- set -e
1635
+ if __name__ == "__main__":
1636
+ os.chdir("/opt/ml/processing/input/code")
1636
1637
1637
- if [[ -f 'requirements.txt' ]]; then
1638
- # Some py3 containers has typing, which may breaks pip install
1639
- pip uninstall --yes typing
1638
+ with tarfile.open("sourcedir.tar.gz", "r:gz") as tar:
1639
+ tar.extractall()
1640
1640
1641
- pip install -r requirements.txt
1642
- fi
1641
+ if os.path.isfile("requirements.txt"):
1642
+ # Some py3 containers have typing, which may break pip install
1643
+ subprocess.run(["pip", "uninstall", "--yes", "typing"])
1643
1644
1644
- {entry_point_command} {entry_point} "$@"
1645
+ subprocess.run(["pip", "install", "-r", "requirements.txt"])
1646
+
1647
+ cmd = ["{entry_point_command}", "{entry_point}"] + sys.argv[1:]
1648
+ print(' '.join(cmd))
1649
+ subprocess.run(cmd)
1645
1650
"""
1646
1651
).format (
1647
1652
entry_point_command = " " .join (self .command ),
@@ -1683,7 +1688,7 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput
1683
1688
# Follow the exact same mechanism that ScriptProcessor does, which
1684
1689
# is to inject the S3 code artifact as a processing input. Note that
1685
1690
# framework processor takes over /opt/ml/processing/input/code for
1686
- # sourcedir.tar.gz, and let ScriptProcessor to place runproc.sh under
1691
+ # sourcedir.tar.gz, and let ScriptProcessor place runproc.py under
1687
1692
# /opt/ml/processing/input/{self._CODE_CONTAINER_INPUT_NAME}.
1688
1693
#
1689
1694
# See:
0 commit comments