|
14 | 14 |
|
15 | 15 | import os
|
16 | 16 |
|
| 17 | +import boto3 |
17 | 18 | import pytest
|
| 19 | +from botocore.config import Config |
| 20 | +from sagemaker import Session |
18 | 21 | from sagemaker.fw_registry import default_framework_uri
|
19 | 22 |
|
20 | 23 | from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor, Processor
|
|
23 | 26 | from tests.integ.kms_utils import get_or_create_kms_key
|
24 | 27 |
|
# IAM role name passed as ``role=`` to the processors built in this module.
ROLE = "SageMakerRole"
# Region used for the boto3 session when no boto config is supplied.
DEFAULT_REGION = "us-west-2"
# Bucket name handed to ``Session(default_bucket=...)`` by the custom-bucket
# fixture; the tests assert it appears in the uploaded inputs' S3 URIs.
CUSTOM_BUCKET_PATH = "sagemaker-custom-bucket"
| 31 | + |
| 32 | + |
@pytest.fixture(scope="module")
def sagemaker_session_with_custom_bucket(
    boto_config, sagemaker_client_config, sagemaker_runtime_config
):
    """Return a SageMaker ``Session`` whose default bucket is ``CUSTOM_BUCKET_PATH``.

    Args:
        boto_config: Keyword arguments for ``boto3.Session``. When falsy, a
            session pinned to ``DEFAULT_REGION`` is created instead.
        sagemaker_client_config: Keyword arguments forwarded to
            ``boto_session.client("sagemaker", ...)``. May be empty or ``None``.
        sagemaker_runtime_config: Keyword arguments forwarded to
            ``boto_session.client("sagemaker-runtime", ...)``. When falsy, no
            runtime client is built and ``None`` is passed to ``Session``.

    Returns:
        A ``Session`` constructed with ``default_bucket=CUSTOM_BUCKET_PATH``.
    """
    boto_session = (
        boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
    )

    # Work on a copy: the incoming dict comes from a fixture that other fixtures
    # and tests may also receive, so adding the retry default in place would
    # leak this module's setting into them.
    client_config = dict(sagemaker_client_config or {})
    client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
    # The retry default guarantees a non-empty config, so the client is always
    # created (the former ``else None`` branch could never be taken).
    sagemaker_client = boto_session.client("sagemaker", **client_config)

    runtime_client = (
        boto_session.client("sagemaker-runtime", **sagemaker_runtime_config)
        if sagemaker_runtime_config
        else None
    )

    return Session(
        boto_session=boto_session,
        sagemaker_client=sagemaker_client,
        sagemaker_runtime_client=runtime_client,
        default_bucket=CUSTOM_BUCKET_PATH,
    )
26 | 58 |
|
27 | 59 |
|
28 | 60 | @pytest.fixture(scope="module")
|
@@ -170,6 +202,89 @@ def test_sklearn_with_customizations(
|
170 | 202 | assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
|
171 | 203 |
|
172 | 204 |
|
def test_sklearn_with_custom_default_bucket(
    sagemaker_session_with_custom_bucket,
    image_uri,
    sklearn_full_version,
    cpu_instance_type,
    output_kms_key,
):
    """Run an SKLearn processing job through a session with a custom default bucket.

    Launches ``dummy_script.py`` with one local input file and one output, waits
    for completion, then checks the job description: both uploaded inputs (the
    data file and the code) must have S3 URIs containing ``CUSTOM_BUCKET_PATH``,
    and the remaining job settings must match what was requested.
    """
    input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")

    sklearn_processor = SKLearnProcessor(
        framework_version=sklearn_full_version,
        role=ROLE,
        command=["python3"],
        instance_type=cpu_instance_type,
        instance_count=1,
        volume_size_in_gb=100,
        volume_kms_key=None,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=3600,
        base_job_name="test-sklearn-with-customizations",
        env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"},
        tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}],
        sagemaker_session=sagemaker_session_with_custom_bucket,
    )

    sklearn_processor.run(
        code=os.path.join(DATA_DIR, "dummy_script.py"),
        inputs=[
            ProcessingInput(
                source=input_file_path,
                destination="/opt/ml/processing/input/container/path/",
                input_name="dummy_input",
                s3_data_type="S3Prefix",
                s3_input_mode="File",
                s3_data_distribution_type="FullyReplicated",
                s3_compression_type="None",
            )
        ],
        outputs=[
            ProcessingOutput(
                source="/opt/ml/processing/output/container/path/",
                output_name="dummy_output",
                s3_upload_mode="EndOfJob",
            )
        ],
        arguments=["-v"],
        wait=True,
        logs=True,
    )

    job_description = sklearn_processor.latest_job.describe()

    # Both the user-supplied input and the code input must sit in the custom bucket.
    assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input"
    assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]

    assert job_description["ProcessingInputs"][1]["InputName"] == "code"
    assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"]

    assert job_description["ProcessingJobName"].startswith("test-sklearn-with-customizations")

    assert job_description["ProcessingJobStatus"] == "Completed"

    assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key
    assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output"

    # Compare against the fixture value rather than a literal instance type, so
    # the check follows whatever instance type the job was launched with.
    assert job_description["ProcessingResources"] == {
        "ClusterConfig": {
            "InstanceCount": 1,
            "InstanceType": cpu_instance_type,
            "VolumeSizeInGB": 100,
        }
    }

    assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
    assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
        "python3",
        "/opt/ml/processing/input/code/dummy_script.py",
    ]
    assert job_description["AppSpecification"]["ImageUri"] == image_uri

    assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}

    assert ROLE in job_description["RoleArn"]

    assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
| 286 | + |
| 287 | + |
173 | 288 | def test_sklearn_with_no_inputs_or_outputs(
|
174 | 289 | sagemaker_session, image_uri, sklearn_full_version, cpu_instance_type
|
175 | 290 | ):
|
@@ -405,3 +520,72 @@ def test_processor(sagemaker_session, image_uri, cpu_instance_type, output_kms_k
|
405 | 520 | assert ROLE in job_description["RoleArn"]
|
406 | 521 |
|
407 | 522 | assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
|
| 523 | + |
| 524 | + |
def test_processor_with_custom_bucket(
    sagemaker_session_with_custom_bucket, image_uri, cpu_instance_type, output_kms_key
):
    """Run a generic ``Processor`` job through a session with a custom default bucket.

    Uploads ``dummy_script.py`` as the ``code`` input, waits for the job to
    finish, then checks the job description: the input's S3 URI must contain
    ``CUSTOM_BUCKET_PATH`` and the remaining job settings must match what was
    requested.
    """
    script_path = os.path.join(DATA_DIR, "dummy_script.py")

    processor = Processor(
        role=ROLE,
        image_uri=image_uri,
        instance_count=1,
        instance_type=cpu_instance_type,
        entrypoint=["python3", "/opt/ml/processing/input/code/dummy_script.py"],
        volume_size_in_gb=100,
        volume_kms_key=None,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=3600,
        base_job_name="test-processor",
        env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"},
        tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}],
        sagemaker_session=sagemaker_session_with_custom_bucket,
    )

    processor.run(
        inputs=[
            ProcessingInput(
                source=script_path, destination="/opt/ml/processing/input/code/", input_name="code"
            )
        ],
        outputs=[
            ProcessingOutput(
                source="/opt/ml/processing/output/container/path/",
                output_name="dummy_output",
                s3_upload_mode="EndOfJob",
            )
        ],
        arguments=["-v"],
        wait=True,
        logs=True,
    )

    job_description = processor.latest_job.describe()

    # The uploaded script must sit in the custom bucket.
    assert job_description["ProcessingInputs"][0]["InputName"] == "code"
    assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]

    assert job_description["ProcessingJobName"].startswith("test-processor")

    assert job_description["ProcessingJobStatus"] == "Completed"

    assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key
    assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output"

    # Compare against the fixture value rather than a literal instance type, so
    # the check follows whatever instance type the job was launched with.
    assert job_description["ProcessingResources"] == {
        "ClusterConfig": {
            "InstanceCount": 1,
            "InstanceType": cpu_instance_type,
            "VolumeSizeInGB": 100,
        }
    }

    assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
    assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
        "python3",
        "/opt/ml/processing/input/code/dummy_script.py",
    ]
    assert job_description["AppSpecification"]["ImageUri"] == image_uri

    assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}

    assert ROLE in job_description["RoleArn"]

    assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
0 commit comments