Skip to content

Commit 14ff170

Browse files
Merge branch 'master' into feature/latest_exeution_logs
2 parents 090225f + 80102e5 commit 14ff170

33 files changed

+1206
-149
lines changed

doc/amazon_sagemaker_featurestore.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ location of your offline store.
202202
    role_arn = role,
203203
    s3_uri = offline_feature_store_bucket,
204204
    enable_online_store = True,
205+
    ttl_duration = None,
205206
    online_store_kms_key_id = None,
206207
    offline_store_kms_key_id = None,
207208
    disable_glue_table_creation = False,

doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,102 @@ smdistributed.modelparallel.torch.DistributedOptimizer
494494
``state_dict`` contains elements corresponding to only the current
495495
partition, or to the entire model.
496496
497+
smdistributed.modelparallel.torch.nn.FlashAttentionLayer
498+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
499+
500+
.. function:: smdistributed.modelparallel.torch.nn.FlashAttentionLayer(attention_dropout_prob=0.1, attention_head_size=None, scale_attention_scores=True, scale_attn_by_layer_idx=False, layer_idx=None, scale=None, triton_flash_attention=False, use_alibi=False)
501+
502+
This class supports
503+
`FlashAttention <https://github.com/HazyResearch/flash-attention>`_
504+
for PyTorch 2.0.
505+
It takes the ``qkv`` matrix as an argument through its ``forward`` method,
506+
computes attention scores and probabilities,
507+
and then performs the matrix multiplication with the value layers.
508+
509+
Through this class, the smp library supports
510+
custom attention masks such as Attention with
511+
Linear Biases (ALiBi), and you can activate them by setting
512+
``triton_flash_attention`` and ``use_alibi`` to ``True``.
513+
514+
Note that the Triton flash attention does not support dropout
515+
on the attention probabilities. It uses the standard lower triangular
516+
causal mask when causal mode is enabled. It also runs only
517+
on P4d and P4de instances, with fp16 or bf16.
518+
519+
This class computes the scale factor to apply when computing attention.
520+
By default, ``scale`` is set to ``None``, and it's automatically calculated.
521+
When ``scale_attention_scores`` is ``True`` (the default), you must pass a value
522+
to ``attention_head_size``. When ``scale_attn_by_layer_idx`` is ``True``,
523+
you must pass a value to ``layer_idx``. If both factors are used, they are
524+
multiplied as follows: ``(1/(sqrt(attention_head_size) * (layer_idx+1)))``.
525+
This scale calculation can be bypassed if you specify a custom scaling
526+
factor to ``scale``. In other words, if you specify a value to ``scale``, the set of parameters
527+
(``scale_attention_scores``, ``attention_head_size``, ``scale_attn_by_layer_idx``, ``layer_idx``)
528+
is overridden and ignored.
529+
530+
**Parameters**
531+
532+
* ``attention_dropout_prob`` (float): (default: 0.1) specifies dropout probability
533+
to apply to attention.
534+
* ``attention_head_size`` (int): Required when ``scale_attention_scores`` is True.
535+
When ``scale_attention_scores`` is ``True``, this contributes
536+
``1/sqrt(attention_head_size)`` to the scale factor.
537+
* ``scale_attention_scores`` (boolean): (default: True) determines whether
538+
to multiply the scale factor by ``1/sqrt(attention_head_size)``.
539+
* ``layer_idx`` (int): Required when ``scale_attn_by_layer_idx`` is ``True``.
540+
The layer id to use for scaling attention by layer id.
541+
It contributes 1/(layer_idx + 1) to the scaling factor.
542+
* ``scale_attn_by_layer_idx`` (boolean): (default: False) determines whether
543+
to multiply the scale factor by ``1/(layer_idx + 1)``.
544+
* ``scale`` (float) (default: None): If passed, this scale factor will be
545+
applied, bypassing all of the previous arguments.
546+
* ``triton_flash_attention`` (bool): (default: False) If ``True``, the Triton
547+
implementation of flash attention will be used. This is necessary to support
548+
Attention with Linear Biases (ALiBi) (see next arg). Note that this version
549+
of the kernel doesn’t support dropout.
550+
* ``use_alibi`` (bool): (default: False) If ``True``, it enables Attention with
551+
Linear Biases (ALiBi) using the mask provided.
552+
553+
.. method:: forward(self, qkv, attn_mask=None, causal=False)
554+
555+
Returns a single ``torch.Tensor`` ``(batch_size x num_heads x seq_len x head_size)``,
556+
which represents the output of attention computation.
557+
558+
**Parameters**
559+
560+
* ``qkv``: ``torch.Tensor`` in the form of ``(batch_size x seqlen x 3 x num_heads x head_size)``.
561+
* ``attn_mask``: ``torch.Tensor`` in the form of ``(batch_size x 1 x 1 x seqlen)``.
562+
By default it is ``None``, and usage of this mask needs ``triton_flash_attention``
563+
and ``use_alibi`` to be set. See how to generate the mask in the following code snippet.
564+
* ``causal``: When ``True``, it uses the standard lower triangular mask. The default is ``False``.
565+
566+
When using ALiBi, it needs an attention mask prepared like the following.
567+
568+
.. code:: python
569+
570+
def generate_alibi_attn_mask(attention_mask, batch_size, seq_length,
571+
num_attention_heads, alibi_bias_max=8):
572+
573+
device, dtype = attention_mask.device, attention_mask.dtype
574+
alibi_attention_mask = torch.zeros(
575+
1, num_attention_heads, 1, seq_length, dtype=dtype, device=device
576+
)
577+
578+
alibi_bias = torch.arange(1 - seq_length, 1, dtype=dtype, device=device).view(
579+
1, 1, 1, seq_length
580+
)
581+
m = torch.arange(1, num_attention_heads + 1, dtype=dtype, device=device)
582+
m.mul_(alibi_bias_max / num_attention_heads)
583+
alibi_bias = alibi_bias * (1.0 / (2 ** m.view(1, num_attention_heads, 1, 1)))
584+
585+
alibi_attention_mask.add_(alibi_bias)
586+
alibi_attention_mask = alibi_attention_mask[..., :seq_length, :seq_length]
587+
if attention_mask is not None and attention_mask.bool().any():
588+
alibi_attention_mask = alibi_attention_mask.masked_fill(
589+
attention_mask.bool().view(batch_size, 1, 1, seq_length), float("-inf")
590+
)
591+
592+
return alibi_attention_mask
497593
498594
smdistributed.modelparallel.torch Context Managers and Util Functions
499595
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
doc8==0.10.1
2-
Pygments==2.11.2
2+
Pygments==2.15.0

src/sagemaker/estimator.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1737,21 +1737,41 @@ def register(
17371737

17381738
@property
17391739
def model_data(self):
1740-
"""str: The model location in S3. Only set if Estimator has been ``fit()``."""
1740+
"""str or dict: The model location in S3. Only set if Estimator has been ``fit()``."""
17411741
if self.latest_training_job is not None and not isinstance(
17421742
self.sagemaker_session, PipelineSession
17431743
):
1744-
model_uri = self.sagemaker_session.sagemaker_client.describe_training_job(
1744+
job_details = self.sagemaker_session.sagemaker_client.describe_training_job(
17451745
TrainingJobName=self.latest_training_job.name
1746-
)["ModelArtifacts"]["S3ModelArtifacts"]
1747-
else:
1748-
logger.warning(
1749-
"No finished training job found associated with this estimator. Please make sure "
1750-
"this estimator is only used for building workflow config"
17511746
)
1752-
model_uri = os.path.join(
1753-
self.output_path, self._current_job_name, "output", "model.tar.gz"
1747+
model_uri = job_details["ModelArtifacts"]["S3ModelArtifacts"]
1748+
compression_type = job_details.get("OutputDataConfig", {}).get(
1749+
"CompressionType", "GZIP"
17541750
)
1751+
if compression_type == "GZIP":
1752+
return model_uri
1753+
# fail fast if we don't recognize training output compression type
1754+
if compression_type not in {"GZIP", "NONE"}:
1755+
raise ValueError(
1756+
f'Unrecognized training job output data compression type "{compression_type}"'
1757+
)
1758+
# model data is in uncompressed form. NOTE: SageMaker Hosting mandates presence of
1759+
# trailing forward slash in S3 model data URI, so append one if necessary.
1760+
if not model_uri.endswith("/"):
1761+
model_uri += "/"
1762+
return {
1763+
"S3DataSource": {
1764+
"S3Uri": model_uri,
1765+
"S3DataType": "S3Prefix",
1766+
"CompressionType": "None",
1767+
}
1768+
}
1769+
1770+
logger.warning(
1771+
"No finished training job found associated with this estimator. Please make sure "
1772+
"this estimator is only used for building workflow config"
1773+
)
1774+
model_uri = os.path.join(self.output_path, self._current_job_name, "output", "model.tar.gz")
17551775
return model_uri
17561776

17571777
@abstractmethod

src/sagemaker/feature_store/feature_group.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@
6161
FeatureParameter,
6262
TableFormatEnum,
6363
DeletionModeEnum,
64+
TtlDuration,
65+
OnlineStoreConfigUpdate,
6466
)
6567
from sagemaker.utils import resolve_value_from_config
6668

@@ -523,6 +525,7 @@ def create(
523525
role_arn: str = None,
524526
online_store_kms_key_id: str = None,
525527
enable_online_store: bool = False,
528+
ttl_duration: TtlDuration = None,
526529
offline_store_kms_key_id: str = None,
527530
disable_glue_table_creation: bool = False,
528531
data_catalog_config: DataCatalogConfig = None,
@@ -539,6 +542,7 @@ def create(
539542
event_time_feature_name (str): name of the event time feature.
540543
role_arn (str): ARN of the role used to call CreateFeatureGroup.
541544
online_store_kms_key_id (str): KMS key ARN for online store (default: None).
545+
ttl_duration (TtlDuration): Default time to live duration for records (default: None).
542546
enable_online_store (bool): whether to enable online store or not (default: False).
543547
offline_store_kms_key_id (str): KMS key ARN for offline store (default: None).
544548
If a KMS encryption key is not specified, SageMaker encrypts all data at
@@ -592,7 +596,10 @@ def create(
592596

593597
# online store configuration
594598
if enable_online_store:
595-
online_store_config = OnlineStoreConfig(enable_online_store=enable_online_store)
599+
online_store_config = OnlineStoreConfig(
600+
enable_online_store=enable_online_store,
601+
ttl_duration=ttl_duration,
602+
)
596603
if online_store_kms_key_id is not None:
597604
online_store_config.online_store_security_config = OnlineStoreSecurityConfig(
598605
kms_key_id=online_store_kms_key_id
@@ -633,21 +640,37 @@ def describe(self, next_token: str = None) -> Dict[str, Any]:
633640
feature_group_name=self.name, next_token=next_token
634641
)
635642

636-
def update(self, feature_additions: Sequence[FeatureDefinition]) -> Dict[str, Any]:
643+
def update(
644+
self,
645+
feature_additions: Sequence[FeatureDefinition] = None,
646+
online_store_config: OnlineStoreConfigUpdate = None,
647+
) -> Dict[str, Any]:
637648
"""Update a FeatureGroup by adding new features and/or updating its online store config.
638649
639650
Args:
640651
feature_additions (Sequence[FeatureDefinition]): list of feature definitions to be added.
652+
online_store_config (OnlineStoreConfigUpdate): online store config to be updated.
641653
642654
Returns:
643655
Response dict from service.
644656
"""
645657

658+
if feature_additions is None:
659+
feature_additions_parameter = None
660+
else:
661+
feature_additions_parameter = [
662+
feature_addition.to_dict() for feature_addition in feature_additions
663+
]
664+
665+
if online_store_config is None:
666+
online_store_config_parameter = None
667+
else:
668+
online_store_config_parameter = online_store_config.to_dict()
669+
646670
return self.sagemaker_session.update_feature_group(
647671
feature_group_name=self.name,
648-
feature_additions=[
649-
feature_addition.to_dict() for feature_addition in feature_additions
650-
],
672+
feature_additions=feature_additions_parameter,
673+
online_store_config=online_store_config_parameter,
651674
)
652675

653676
def update_feature_metadata(
@@ -756,7 +779,9 @@ def load_feature_definitions(
756779
return self.feature_definitions
757780

758781
def get_record(
759-
self, record_identifier_value_as_string: str, feature_names: Sequence[str] = None
782+
self,
783+
record_identifier_value_as_string: str,
784+
feature_names: Sequence[str] = None,
760785
) -> Sequence[Dict[str, str]]:
761786
"""Get a single record in a FeatureGroup
762787
@@ -772,14 +797,24 @@ def get_record(
772797
feature_names=feature_names,
773798
).get("Record")
774799

775-
def put_record(self, record: Sequence[FeatureValue]):
800+
def put_record(self, record: Sequence[FeatureValue], ttl_duration: TtlDuration = None):
776801
"""Put a single record in the FeatureGroup.
777802
778803
Args:
779804
record (Sequence[FeatureValue]): a list of feature values.
805+
ttl_duration (TtlDuration): customer specified ttl duration.
780806
"""
807+
808+
if ttl_duration is not None:
809+
return self.sagemaker_session.put_record(
810+
feature_group_name=self.name,
811+
record=[value.to_dict() for value in record],
812+
ttl_duration=ttl_duration.to_dict(),
813+
)
814+
781815
return self.sagemaker_session.put_record(
782-
feature_group_name=self.name, record=[value.to_dict() for value in record]
816+
feature_group_name=self.name,
817+
record=[value.to_dict() for value in record],
783818
)
784819

785820
def delete_record(

src/sagemaker/feature_store/feature_store.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,18 +137,27 @@ def list_feature_groups(
137137
next_token=next_token,
138138
)
139139

140-
def batch_get_record(self, identifiers: Sequence[Identifier]) -> Dict[str, Any]:
140+
def batch_get_record(
141+
self,
142+
identifiers: Sequence[Identifier],
143+
expiration_time_response: str = None,
144+
) -> Dict[str, Any]:
141145
"""Get record in batch from FeatureStore
142146
143147
Args:
144148
identifiers (Sequence[Identifier]): A list of identifiers to uniquely identify records
145149
in FeatureStore.
150+
expiration_time_response (str): whether the response includes each record's
151+
ExpiresAt value; either "Enabled" or "Disabled".
146152
147153
Returns:
148154
Response dict from service.
149155
"""
150156
batch_get_record_identifiers = [identifier.to_dict() for identifier in identifiers]
151-
return self.sagemaker_session.batch_get_record(identifiers=batch_get_record_identifiers)
157+
return self.sagemaker_session.batch_get_record(
158+
identifiers=batch_get_record_identifiers,
159+
expiration_time_response=expiration_time_response,
160+
)
152161

153162
def search(
154163
self,

src/sagemaker/feature_store/inputs.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,17 +84,43 @@ def to_dict(self) -> Dict[str, Any]:
8484
return Config.construct_dict(KmsKeyId=self.kms_key_id)
8585

8686

87+
@attr.s
88+
class TtlDuration(Config):
89+
"""TtlDuration for records in online FeatureStore.
90+
91+
Attributes:
92+
unit (str): time unit.
93+
value (int): time value.
94+
"""
95+
96+
unit: str = attr.ib()
97+
value: int = attr.ib()
98+
99+
def to_dict(self) -> Dict[str, Any]:
100+
"""Construct a dictionary based on the attributes.
101+
102+
Returns:
103+
dict represents the attributes.
104+
"""
105+
return Config.construct_dict(
106+
Unit=self.unit,
107+
Value=self.value,
108+
)
109+
110+
87111
@attr.s
88112
class OnlineStoreConfig(Config):
89113
"""OnlineStoreConfig for FeatureStore.
90114
91115
Attributes:
92116
enable_online_store (bool): whether to enable the online store.
93117
online_store_security_config (OnlineStoreSecurityConfig): configuration of security setting.
118+
ttl_duration (TtlDuration): Default time to live duration for records.
94119
"""
95120

96121
enable_online_store: bool = attr.ib(default=True)
97122
online_store_security_config: OnlineStoreSecurityConfig = attr.ib(default=None)
123+
ttl_duration: TtlDuration = attr.ib(default=None)
98124

99125
def to_dict(self) -> Dict[str, Any]:
100126
"""Construct a dictionary based on the attributes.
@@ -105,6 +131,28 @@ def to_dict(self) -> Dict[str, Any]:
105131
return Config.construct_dict(
106132
EnableOnlineStore=self.enable_online_store,
107133
SecurityConfig=self.online_store_security_config,
134+
TtlDuration=self.ttl_duration,
135+
)
136+
137+
138+
@attr.s
139+
class OnlineStoreConfigUpdate(Config):
140+
"""OnlineStoreConfigUpdate for FeatureStore.
141+
142+
Attributes:
143+
ttl_duration (TtlDuration): Default time to live duration for records.
144+
"""
145+
146+
ttl_duration: TtlDuration = attr.ib(default=None)
147+
148+
def to_dict(self) -> Dict[str, Any]:
149+
"""Construct a dictionary based on the attributes.
150+
151+
Returns:
152+
dict represents the attributes.
153+
"""
154+
return Config.construct_dict(
155+
TtlDuration=self.ttl_duration,
108156
)
109157

110158

@@ -379,3 +427,13 @@ class DeletionModeEnum(Enum):
379427

380428
SOFT_DELETE = "SoftDelete"
381429
HARD_DELETE = "HardDelete"
430+
431+
432+
class ExpirationTimeResponseEnum(Enum):
433+
"""Enum for toggling whether ExpiresAt is returned in the response.
434+
435+
The ExpirationTimeResponse for toggling the response of ExpiresAt can be Disabled or Enabled.
436+
"""
437+
438+
DISABLED = "Disabled"
439+
ENABLED = "Enabled"

0 commit comments

Comments
 (0)