Skip to content

py_version doc fixes #5048

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 23 additions & 24 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,37 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Utility methods used by framework classes"""
"""Utility methods used by framework classes."""
from __future__ import absolute_import

import json
import logging
import os
import re
import time
import shutil
import tempfile
import time
from collections import namedtuple
from typing import List, Optional, Union, Dict
from typing import Dict, List, Optional, Union

from packaging import version

import sagemaker.image_uris
import sagemaker.utils
from sagemaker.deprecations import deprecation_warn_base, renamed_kwargs, renamed_warning
from sagemaker.instance_group import InstanceGroup
from sagemaker.s3_utils import s3_path_join
from sagemaker.session_settings import SessionSettings
import sagemaker.utils
from sagemaker.workflow import is_pipeline_variable

from sagemaker.deprecations import renamed_warning, renamed_kwargs
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.deprecations import deprecation_warn_base

logger = logging.getLogger(__name__)

_TAR_SOURCE_FILENAME = "source.tar.gz"

UploadedCode = namedtuple("UploadedCode", ["s3_prefix", "script_name"])
"""sagemaker.fw_utils.UploadedCode: An object containing the S3 prefix and script name.

This is for the source code used for the entry point with an ``Estimator``. It can be
instantiated with positional or keyword arguments.
"""
Expand Down Expand Up @@ -211,7 +211,7 @@ def validate_source_code_input_against_pipeline_variables(
git_config: Optional[Dict[str, str]] = None,
enable_network_isolation: Union[bool, PipelineVariable] = False,
):
"""Validate source code input against pipeline variables
"""Validate source code input against pipeline variables.

Args:
entry_point (str or PipelineVariable): The path to the local Python source file that
Expand Down Expand Up @@ -481,7 +481,7 @@ def tar_and_upload_dir(


def _list_files_to_compress(script, directory):
"""Placeholder docstring"""
"""Placeholder docstring."""
if directory is None:
return [script]

Expand Down Expand Up @@ -585,7 +585,6 @@ def model_code_key_prefix(code_location_key_prefix, model_name, image):
The location returned is a potential concatenation of 2 parts
1. code_location_key_prefix if it exists
2. model_name or a name derived from the image

Args:
code_location_key_prefix (str): the s3 key prefix from code_location
model_name (str): the name of the model
Expand Down Expand Up @@ -620,8 +619,6 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
"enabled": True
}
}


"""
if training_instance_type == "local" or distribution is None:
return
Expand All @@ -646,7 +643,7 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
def profiler_config_deprecation_warning(
profiler_config, image_uri, framework_name, framework_version
):
"""Put out a deprecation message for if framework profiling is specified TF >= 2.12 and PT >= 2.0"""
"""Deprecation message if framework profiling is specified TF >= 2.12 and PT >= 2.0."""
if profiler_config is None or profiler_config.framework_profile_params is None:
return

Expand Down Expand Up @@ -692,6 +689,7 @@ def validate_smdistributed(
framework_name (str): A string representing the name of framework selected.
framework_version (str): A string representing the framework version selected.
py_version (str): A string representing the python version selected.
Ex: `py38, py39, py310, py311`
distribution (dict): A dictionary with information to enable distributed training.
(Defaults to None if distributed training is not enabled.) For example:

Expand Down Expand Up @@ -763,7 +761,8 @@ def _validate_smdataparallel_args(
instance_type (str): A string representing the type of training instance selected. Ex: `ml.p3.16xlarge`
framework_name (str): A string representing the name of framework selected. Ex: `tensorflow`
framework_version (str): A string representing the framework version selected. Ex: `2.3.1`
py_version (str): A string representing the python version selected. Ex: `py3`
py_version (str): A string representing the python version selected.
Ex: `py38, py39, py310, py311`
distribution (dict): A dictionary with information to enable distributed training.
(Defaults to None if distributed training is not enabled.) Ex:

Expand Down Expand Up @@ -847,6 +846,7 @@ def validate_distribution(
framework_name (str): A string representing the name of framework selected.
framework_version (str): A string representing the framework version selected.
py_version (str): A string representing the python version selected.
Ex: `py38, py39, py310, py311`
image_uri (str): A string representing a Docker image URI.
kwargs(dict): Additional kwargs passed to this function

Expand Down Expand Up @@ -953,7 +953,7 @@ def validate_distribution(


def validate_distribution_for_instance_type(instance_type, distribution):
"""Check if the provided distribution strategy is supported for the instance_type
"""Check if the provided distribution strategy is supported for the instance_type.

Args:
instance_type (str): A string representing the type of training instance selected.
Expand Down Expand Up @@ -1010,6 +1010,7 @@ def validate_torch_distributed_distribution(
}
framework_version (str): A string representing the framework version selected.
py_version (str): A string representing the python version selected.
Ex: `py38, py39, py310, py311`
image_uri (str): A string representing a Docker image URI.
entry_point (str or PipelineVariable): The absolute or relative path to the local Python
source file that should be executed as the entry point to
Expand Down Expand Up @@ -1072,7 +1073,7 @@ def validate_torch_distributed_distribution(


def _is_gpu_instance(instance_type):
"""Returns bool indicating whether instance_type supports GPU
"""Returns bool indicating whether instance_type supports GPU.

Args:
instance_type (str): Name of the instance_type to check against.
Expand All @@ -1091,7 +1092,7 @@ def _is_gpu_instance(instance_type):


def _is_trainium_instance(instance_type):
"""Returns bool indicating whether instance_type is a Trainium instance
"""Returns bool indicating whether instance_type is a Trainium instance.

Args:
instance_type (str): Name of the instance_type to check against.
Expand All @@ -1107,7 +1108,7 @@ def _is_trainium_instance(instance_type):


def python_deprecation_warning(framework, latest_supported_version):
"""Placeholder docstring"""
"""Placeholder docstring."""
return PYTHON_2_DEPRECATION_WARNING.format(
framework=framework, latest_supported_version=latest_supported_version
)
Expand All @@ -1121,7 +1122,6 @@ def _region_supports_debugger(region_name):

Returns:
bool: Whether or not the region supports Amazon SageMaker Debugger.

"""
return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS

Expand All @@ -1134,7 +1134,6 @@ def _region_supports_profiler(region_name):

Returns:
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.

"""
return region_name.lower() not in PROFILER_UNSUPPORTED_REGIONS

Expand Down Expand Up @@ -1162,7 +1161,8 @@ def validate_version_or_image_args(framework_version, py_version, image_uri):

Args:
framework_version (str): The version of the framework.
py_version (str): The version of Python.
py_version (str): A string representing the python version selected.
Ex: `py38, py39, py310, py311`
image_uri (str): The URI of the image.

Raises:
Expand Down Expand Up @@ -1194,9 +1194,8 @@ def create_image_uri(
instance_type (str): SageMaker instance type. Used to determine device
type (cpu/gpu/family-specific optimized).
framework_version (str): The version of the framework.
py_version (str): Optional. Python version. If specified, should be one
of 'py2' or 'py3'. If not specified, image uri will not include a
python component.
py_version (str): Optional. Python version Ex: `py38, py39, py310, py311`.
If not specified, image uri will not include a python component.
account (str): AWS account that contains the image. (default:
'520713654638')
accelerator_type (str): SageMaker Elastic Inference accelerator type.
Expand Down
14 changes: 5 additions & 9 deletions src/sagemaker/huggingface/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,13 @@

import logging
import re
from typing import Optional, Union, Dict
from typing import Dict, Optional, Union

from sagemaker.estimator import Framework, EstimatorBase
from sagemaker.fw_utils import (
framework_name_from_image,
validate_distribution,
)
from sagemaker.estimator import EstimatorBase, Framework
from sagemaker.fw_utils import framework_name_from_image, validate_distribution
from sagemaker.huggingface.model import HuggingFaceModel
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
from sagemaker.workflow.entities import PipelineVariable

logger = logging.getLogger("sagemaker")
Expand Down Expand Up @@ -66,7 +62,7 @@ def __init__(
Args:
py_version (str): Python version you want to use for executing your model training
code. Defaults to ``None``. Required unless ``image_uri`` is provided. If
using PyTorch, the current supported version is ``py36``. If using TensorFlow,
using PyTorch, the current supported version is ``py39``. If using TensorFlow,
the current supported version is ``py37``.
entry_point (str or PipelineVariable): Path (absolute or relative) to the Python source
file which should be executed as the entry point to training.
Expand Down
38 changes: 19 additions & 19 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,51 +18,51 @@
"""
from __future__ import absolute_import

import logging
import os
import pathlib
import logging
import re
from copy import copy
from textwrap import dedent
from typing import Dict, List, Optional, Union
from copy import copy
import re

import attr

from six.moves.urllib.parse import urlparse
from six.moves.urllib.request import url2pathname

from sagemaker import s3
from sagemaker.apiutils._base_types import ApiObject
from sagemaker.config import (
PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
PROCESSING_JOB_ENVIRONMENT_PATH,
PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
PROCESSING_JOB_KMS_KEY_ID_PATH,
PROCESSING_JOB_ROLE_ARN_PATH,
PROCESSING_JOB_SECURITY_GROUP_IDS_PATH,
PROCESSING_JOB_SUBNETS_PATH,
PROCESSING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
PROCESSING_JOB_VOLUME_KMS_KEY_ID_PATH,
PROCESSING_JOB_ROLE_ARN_PATH,
PROCESSING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
PROCESSING_JOB_ENVIRONMENT_PATH,
)
from sagemaker.dataset_definition.inputs import DatasetDefinition, S3Input
from sagemaker.job import _Job
from sagemaker.local import LocalSession
from sagemaker.network import NetworkConfig
from sagemaker.s3 import S3Uploader
from sagemaker.session import Session
from sagemaker.utils import (
Tags,
base_name_from_image,
check_and_get_run_experiment_config,
format_tags,
get_config_value,
name_from_base,
check_and_get_run_experiment_config,
resolve_value_from_config,
resolve_class_attribute_from_config,
Tags,
format_tags,
resolve_value_from_config,
)
from sagemaker.session import Session
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.functions import Join
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
from sagemaker.apiutils._base_types import ApiObject
from sagemaker.s3 import S3Uploader

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1465,7 +1465,7 @@ def __init__(
instance_type (str or PipelineVariable): The type of EC2 instance to use for
processing, for example, 'ml.c4.xlarge'.
py_version (str): Python version you want to use for executing your
model training code. One of 'py2' or 'py3'. Defaults to 'py3'. Value
model training code. Ex `py38, py39, py310, py311`. Value
is ignored when ``image_uri`` is provided.
image_uri (str or PipelineVariable): The URI of the Docker image to use for the
processing jobs (default: None).
Expand Down
Loading