Commit 89e18a9

Author: Roja Reddy Sareddy (committed)
fix: Added handler for pipeline variable while creating processing job
1 parent 16b6f0c commit 89e18a9

File tree

2 files changed: +296 -2 lines changed


src/sagemaker/processing.py

Lines changed: 42 additions & 1 deletion
@@ -60,9 +60,10 @@
 )
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.entities import PipelineVariable
-from sagemaker.workflow.execution_variables import ExecutionVariables
+from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables
 from sagemaker.workflow.functions import Join
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline
+from sagemaker.workflow.parameters import Parameter
 
 logger = logging.getLogger(__name__)
 
@@ -314,6 +315,15 @@ def _normalize_args(
                 "code argument has to be a valid S3 URI or local file path "
                 + "rather than a pipeline variable"
             )
+        if arguments is not None:
+            normalized_arguments = []
+            for arg in arguments:
+                if isinstance(arg, PipelineVariable):
+                    normalized_value = self._normalize_pipeline_variable(arg)
+                    normalized_arguments.append(normalized_value)
+                else:
+                    normalized_arguments.append(str(arg))
+            arguments = normalized_arguments
 
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
 
@@ -499,6 +509,37 @@ def _normalize_outputs(self, outputs=None):
             normalized_outputs.append(output)
         return normalized_outputs
 
+    def _normalize_pipeline_variable(self, value):
+        """Helper function to normalize PipelineVariable objects"""
+        try:
+            if isinstance(value, Parameter):
+                return str(value.default_value) if value.default_value is not None else None
+
+            elif isinstance(value, ExecutionVariable):
+                return f"{value.name}"
+
+            elif isinstance(value, Join):
+                normalized_values = [
+                    self._normalize_pipeline_variable(v) if isinstance(v, PipelineVariable) else str(v)
+                    for v in value.values
+                ]
+                return value.on.join(normalized_values)
+
+            elif isinstance(value, PipelineVariable):
+                if hasattr(value, 'default_value'):
+                    return str(value.default_value)
+                elif hasattr(value, 'expr'):
+                    return str(value.expr)
+
+            return str(value)
+
+        except AttributeError as e:
+            raise ValueError(f"Missing required attribute while normalizing {type(value).__name__}: {e}")
+        except TypeError as e:
+            raise ValueError(f"Type error while normalizing {type(value).__name__}: {e}")
+        except Exception as e:
+            raise ValueError(f"Error normalizing {type(value).__name__}: {e}")
+
 
 class ScriptProcessor(Processor):
     """Handles Amazon SageMaker processing tasks for jobs using a machine learning framework."""

tests/unit/test_processing.py

Lines changed: 254 additions & 1 deletion
@@ -46,8 +46,9 @@
 from sagemaker.fw_utils import UploadedCode
 from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig
 from sagemaker.workflow.functions import Join
-from sagemaker.workflow.execution_variables import ExecutionVariables
+from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables
 from tests.unit import SAGEMAKER_CONFIG_PROCESSING_JOB
+from sagemaker.workflow.parameters import ParameterString, Parameter
 
 BUCKET_NAME = "mybucket"
 REGION = "us-west-2"
@@ -1717,3 +1718,255 @@ def _get_describe_response_inputs_and_ouputs():
         "ProcessingInputs": _get_expected_args_all_parameters(None)["inputs"],
         "ProcessingOutputConfig": _get_expected_args_all_parameters(None)["output_config"],
     }
+
+# Parameters
+def _get_data_inputs_with_parameters():
+    return [
+        ProcessingInput(
+            source=ParameterString(
+                name="input_data",
+                default_value="s3://dummy-bucket/input"
+            ),
+            destination="/opt/ml/processing/input",
+            input_name="input-1"
+        )
+    ]
+
+
+def _get_data_outputs_with_parameters():
+    return [
+        ProcessingOutput(
+            source="/opt/ml/processing/output",
+            destination=ParameterString(
+                name="output_data",
+                default_value="s3://dummy-bucket/output"
+            ),
+            output_name="output-1"
+        )
+    ]
+
+
+def _get_expected_args_with_parameters(job_name):
+    return {
+        "inputs": [{
+            "InputName": "input-1",
+            "S3Input": {
+                "S3Uri": "s3://dummy-bucket/input",
+                "LocalPath": "/opt/ml/processing/input",
+                "S3DataType": "S3Prefix",
+                "S3InputMode": "File",
+                "S3DataDistributionType": "FullyReplicated",
+                "S3CompressionType": "None"
+            }
+        }],
+        "output_config": {
+            "Outputs": [{
+                "OutputName": "output-1",
+                "S3Output": {
+                    "S3Uri": "s3://dummy-bucket/output",
+                    "LocalPath": "/opt/ml/processing/output",
+                    "S3UploadMode": "EndOfJob"
+                }
+            }]
+        },
+        "job_name": job_name,
+        "resources": {
+            "ClusterConfig": {
+                "InstanceType": "ml.m4.xlarge",
+                "InstanceCount": 1,
+                "VolumeSizeInGB": 100,
+                "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key"
+            }
+        },
+        "stopping_condition": {"MaxRuntimeInSeconds": 3600},
+        "app_specification": {
+            "ImageUri": "custom-image-uri",
+            "ContainerArguments": [
+                "--input-data",
+                "s3://dummy-bucket/input-param",
+                "--output-path",
+                "s3://dummy-bucket/output-param"
+            ],
+            "ContainerEntrypoint": ["python3"]
+        },
+        "environment": {"my_env_variable": "my_env_variable_value"},
+        "network_config": {
+            "EnableNetworkIsolation": True,
+            "EnableInterContainerTrafficEncryption": True,
+            "VpcConfig": {
+                "Subnets": ["my_subnet_id"],
+                "SecurityGroupIds": ["my_security_group_id"]
+            }
+        },
+        "role_arn": "dummy/role",
+        "tags": [{"Key": "my-tag", "Value": "my-tag-value"}],
+        "experiment_config": {"ExperimentName": "AnExperiment"}
+    }
+
+
+@patch("os.path.exists", return_value=True)
+@patch("os.path.isfile", return_value=True)
+@patch("sagemaker.utils.repack_model")
+@patch("sagemaker.utils.create_tar_file")
+@patch("sagemaker.session.Session.upload_data")
+def test_script_processor_with_parameter_string(
+    upload_data_mock,
+    create_tar_file_mock,
+    repack_model_mock,
+    exists_mock,
+    isfile_mock,
+    sagemaker_session,
+):
+    """Test ScriptProcessor with ParameterString arguments"""
+    upload_data_mock.return_value = "s3://mocked_s3_uri_from_upload_data"
+
+    # Setup processor
+    processor = ScriptProcessor(
+        role="arn:aws:iam::012345678901:role/SageMakerRole",  # Updated role ARN
+        image_uri="custom-image-uri",
+        command=["python3"],
+        instance_type="ml.m4.xlarge",
+        instance_count=1,
+        volume_size_in_gb=100,
+        volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
+        output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
+        max_runtime_in_seconds=3600,
+        base_job_name="test_processor",
+        env={"my_env_variable": "my_env_variable_value"},
+        tags=[{"Key": "my-tag", "Value": "my-tag-value"}],
+        network_config=NetworkConfig(
+            subnets=["my_subnet_id"],
+            security_group_ids=["my_security_group_id"],
+            enable_network_isolation=True,
+            encrypt_inter_container_traffic=True,
+        ),
+        sagemaker_session=sagemaker_session,
+    )
+
+    input_param = ParameterString(
+        name="input_param",
+        default_value="s3://dummy-bucket/input-param"
+    )
+    output_param = ParameterString(
+        name="output_param",
+        default_value="s3://dummy-bucket/output-param"
+    )
+    exec_var = ExecutionVariable(
+        name="ExecutionTest"
+    )
+    join_var = Join(
+        on="/",
+        values=["s3://bucket", "prefix", "file.txt"]
+    )
+    dummy_str_var = "test-variable"
+
+    # Define expected arguments
+    expected_args = {
+        "inputs": [
+            {
+                "InputName": "input-1",
+                "AppManaged": False,
+                "S3Input": {
+                    "S3Uri": ParameterString(
+                        name="input_data",
+                        default_value="s3://dummy-bucket/input"
+                    ),
+                    "LocalPath": "/opt/ml/processing/input",
+                    "S3DataType": "S3Prefix",
+                    "S3InputMode": "File",
+                    "S3DataDistributionType": "FullyReplicated",
+                    "S3CompressionType": "None"
+                }
+            },
+            {
+                "InputName": "code",
+                "AppManaged": False,
+                "S3Input": {
+                    "S3Uri": "s3://mocked_s3_uri_from_upload_data",
+                    "LocalPath": "/opt/ml/processing/input/code",
+                    "S3DataType": "S3Prefix",
+                    "S3InputMode": "File",
+                    "S3DataDistributionType": "FullyReplicated",
+                    "S3CompressionType": "None"
+                }
+            }
+        ],
+        "output_config": {
+            "Outputs": [
+                {
+                    "OutputName": "output-1",
+                    "AppManaged": False,
+                    "S3Output": {
+                        "S3Uri": ParameterString(
+                            name="output_data",
+                            default_value="s3://dummy-bucket/output"
+                        ),
+                        "LocalPath": "/opt/ml/processing/output",
+                        "S3UploadMode": "EndOfJob"
+                    }
+                }
+            ],
+            "KmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key"
+        },
+        "job_name": "test_job",
+        "resources": {
+            "ClusterConfig": {
+                "InstanceType": "ml.m4.xlarge",
+                "InstanceCount": 1,
+                "VolumeSizeInGB": 100,
+                "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key"
+            }
+        },
+        "stopping_condition": {"MaxRuntimeInSeconds": 3600},
+        "app_specification": {
+            "ImageUri": "custom-image-uri",
+            "ContainerArguments": [
+                "--input-data",
+                "s3://dummy-bucket/input-param",
+                "--output-path",
+                "s3://dummy-bucket/output-param",
+                "--exec-arg", "ExecutionTest",
+                "--join-arg", "s3://bucket/prefix/file.txt",
+                "--string-param", "test-variable"
+            ],
+            "ContainerEntrypoint": ["python3", "/opt/ml/processing/input/code/processing_code.py"]
+        },
+        "environment": {"my_env_variable": "my_env_variable_value"},
+        "network_config": {
+            "EnableNetworkIsolation": True,
+            "EnableInterContainerTrafficEncryption": True,
+            "VpcConfig": {
+                "SecurityGroupIds": ["my_security_group_id"],
+                "Subnets": ["my_subnet_id"]
+            }
+        },
+        "role_arn": "arn:aws:iam::012345678901:role/SageMakerRole",
+        "tags": [{"Key": "my-tag", "Value": "my-tag-value"}],
+        "experiment_config": {"ExperimentName": "AnExperiment"}
+    }
+
+    # Run processor
+    processor.run(
+        code="/local/path/to/processing_code.py",
+        inputs=_get_data_inputs_with_parameters(),
+        outputs=_get_data_outputs_with_parameters(),
+        arguments=[
+            "--input-data",
+            input_param,
+            "--output-path",
+            output_param,
+            "--exec-arg", exec_var,
+            "--join-arg", join_var,
+            "--string-param", dummy_str_var
+        ],
+        wait=True,
+        logs=False,
+        job_name="test_job",
+        experiment_config={"ExperimentName": "AnExperiment"},
+    )
+
+    # Assert
+    sagemaker_session.process.assert_called_with(**expected_args)
+    assert "test_job" in processor._current_job_name
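
The new behavior is exercised by test_script_processor_with_parameter_string above. As a quick sketch (assuming pytest and the SDK's test dependencies are installed), just this test can be run programmatically:

# Sketch only: run the single unit test added by this commit.
import pytest

pytest.main(["-q", "tests/unit/test_processing.py::test_script_processor_with_parameter_string"])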
