Skip to content

Commit 7374fcb

Browse files
unit tests
1 parent 2d26009 commit 7374fcb

7 files changed

+565
-72
lines changed

src/aws_encryption_sdk/streaming_client.py

Lines changed: 73 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -954,8 +954,72 @@ def _prep_message(self):
954954
self._prep_non_framed()
955955
self._message_prepped = True
956956

957-
# TODO-MPL: Refactor this function, remove linter disablers
958-
def _read_header(self): # noqa pylint: disable=too-many-branches
957+
def _create_decrypt_materials_request(self, header):
958+
"""
959+
Create a DecryptionMaterialsRequest based on whether
960+
the StreamDecryptor was provided encryption_context on decrypt
961+
(i.e. expects to use required encryption context CMM from the MPL).
962+
"""
963+
# If encryption_context is provided on decrypt,
964+
# pass it to the DecryptionMaterialsRequest as reproduced_encryption_context
965+
if hasattr(self.config, "encryption_context"):
966+
return DecryptionMaterialsRequest(
967+
encrypted_data_keys=header.encrypted_data_keys,
968+
algorithm=header.algorithm,
969+
encryption_context=header.encryption_context,
970+
commitment_policy=self.config.commitment_policy,
971+
reproduced_encryption_context=self.config.encryption_context
972+
)
973+
return DecryptionMaterialsRequest(
974+
encrypted_data_keys=header.encrypted_data_keys,
975+
algorithm=header.algorithm,
976+
encryption_context=header.encryption_context,
977+
commitment_policy=self.config.commitment_policy,
978+
)
979+
980+
def _validate_parsed_header(
981+
self,
982+
header,
983+
header_auth,
984+
raw_header,
985+
):
986+
"""
987+
Pass arguments from this StreamDecryptor to validate_header based on whether
988+
the StreamDecryptor has the _required_encryption_context attribute
989+
(i.e. is using the required encryption context CMM from the MPL).
990+
"""
991+
# If _required_encryption_context is present,
992+
# serialize it and pass it to validate_header.
993+
if hasattr(self, "_required_encryption_context") \
994+
and self._required_encryption_context is not None:
995+
# The authenticated only encryption context is all encryption context key-value pairs where the
996+
# key exists in Required Encryption Context Keys. It is then serialized according to the
997+
# message header Key Value Pairs.
998+
required_ec_serialized = \
999+
aws_encryption_sdk.internal.formatting.encryption_context.serialize_encryption_context(
1000+
self._required_encryption_context
1001+
)
1002+
1003+
validate_header(
1004+
header=header,
1005+
header_auth=header_auth,
1006+
# When verifying the header, the AAD input to the authenticated encryption algorithm
1007+
# specified by the algorithm suite is the message header body and the serialized
1008+
# authenticated only encryption context.
1009+
raw_header=raw_header + required_ec_serialized,
1010+
data_key=self._derived_data_key
1011+
)
1012+
else:
1013+
validate_header(
1014+
header=header,
1015+
header_auth=header_auth,
1016+
raw_header=raw_header,
1017+
data_key=self._derived_data_key
1018+
)
1019+
1020+
return header, header_auth
1021+
1022+
def _read_header(self):
9591023
"""Reads the message header from the input stream.
9601024
9611025
:returns: tuple containing deserialized header and header_auth objects
@@ -981,24 +1045,7 @@ def _read_header(self): # noqa pylint: disable=too-many-branches
9811045
)
9821046
)
9831047

984-
# If encryption_context is provided on decrypt,
985-
# pass it to the DecryptionMaterialsRequest
986-
if hasattr(self.config, "encryption_context"):
987-
decrypt_materials_request = DecryptionMaterialsRequest(
988-
encrypted_data_keys=header.encrypted_data_keys,
989-
algorithm=header.algorithm,
990-
encryption_context=header.encryption_context,
991-
commitment_policy=self.config.commitment_policy,
992-
reproduced_encryption_context=self.config.encryption_context
993-
)
994-
else:
995-
decrypt_materials_request = DecryptionMaterialsRequest(
996-
encrypted_data_keys=header.encrypted_data_keys,
997-
algorithm=header.algorithm,
998-
encryption_context=header.encryption_context,
999-
commitment_policy=self.config.commitment_policy,
1000-
)
1001-
1048+
decrypt_materials_request = self._create_decrypt_materials_request(header)
10021049
decryption_materials = self.config.materials_manager.decrypt_materials(request=decrypt_materials_request)
10031050

10041051
# If the materials_manager passed required_encryption_context_keys,
@@ -1049,36 +1096,12 @@ def _read_header(self): # noqa pylint: disable=too-many-branches
10491096
"Key commitment validation failed. Key identity does not match the identity asserted in the "
10501097
"message. Halting processing of this message."
10511098
)
1052-
1053-
# If _required_encryption_context is present,
1054-
# serialize it and pass it to validate_header.
1055-
if self._required_encryption_context is not None:
1056-
# The authenticated only encryption context is all encryption context key-value pairs where the
1057-
# key exists in Required Encryption Context Keys. It is then serialized according to the
1058-
# message header Key Value Pairs.
1059-
required_ec_serialized = \
1060-
aws_encryption_sdk.internal.formatting.encryption_context.serialize_encryption_context(
1061-
self._required_encryption_context
1062-
)
1063-
1064-
validate_header(
1065-
header=header,
1066-
header_auth=header_auth,
1067-
# When verifying the header, the AAD input to the authenticated encryption algorithm
1068-
# specified by the algorithm suite is the message header body and the serialized
1069-
# authenticated only encryption context.
1070-
raw_header=raw_header + required_ec_serialized,
1071-
data_key=self._derived_data_key
1072-
)
1073-
else:
1074-
validate_header(
1075-
header=header,
1076-
header_auth=header_auth,
1077-
raw_header=raw_header,
1078-
data_key=self._derived_data_key
1079-
)
1080-
1081-
return header, header_auth
1099+
1100+
return self._validate_parsed_header(
1101+
header=header,
1102+
header_auth=header_auth,
1103+
raw_header=raw_header,
1104+
)
10821105

10831106
def _prep_non_framed(self):
10841107
"""Prepare the opening data for a non-framed message."""

test/mpl/unit/test_material_managers_mpl_cmm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
mock_mpl_cmm = MagicMock(__class__=MPL_ICryptographicMaterialsManager)
3939
mock_mpl_encryption_materials = MagicMock(__class__=MPL_EncryptionMaterials)
4040
mock_mpl_decrypt_materials = MagicMock(__class__=MPL_DecryptionMaterials)
41+
mock_reproduced_encryption_context = MagicMock(__class_=dict)
4142

4243

4344
mock_edk = MagicMock(__class__=Native_EncryptedDataKey)
@@ -259,6 +260,7 @@ def test_GIVEN_valid_request_WHEN_create_mpl_decrypt_materials_input_from_reques
259260
for mock_edks in [no_mock_edks, one_mock_edk, two_mock_edks]:
260261

261262
mock_decryption_materials_request.encrypted_data_keys = mock_edks
263+
mock_decryption_materials_request.reproduced_encryption_context = mock_reproduced_encryption_context
262264

263265
# When: _create_mpl_decrypt_materials_input_from_request
264266
output = CryptoMaterialsManagerFromMPL._create_mpl_decrypt_materials_input_from_request(
@@ -271,6 +273,7 @@ def test_GIVEN_valid_request_WHEN_create_mpl_decrypt_materials_input_from_reques
271273
assert output.algorithm_suite_id == mock_algorithm_id
272274
assert output.commitment_policy == mock_commitment_policy
273275
assert output.encryption_context == mock_decryption_materials_request.encryption_context
276+
assert output.reproduced_encryption_context == mock_reproduced_encryption_context
274277

275278
assert len(output.encrypted_data_keys) == len(mock_edks)
276279
for i in range(len(output.encrypted_data_keys)):

test/mpl/unit/test_material_managers_mpl_materials.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,19 @@ def test_GIVEN_valid_signing_key_WHEN_EncryptionMaterials_get_signing_key_THEN_r
160160
assert output == mock_signing_key
161161

162162

163+
def test_GIVEN_valid_required_encryption_context_keys_WHEN_EncryptionMaterials_get_required_encryption_context_keys_THEN_returns_required_encryption_context_keys():
164+
# Given: valid required encryption context keys
165+
mock_required_encryption_context_keys = MagicMock(__class__=bytes)
166+
mock_mpl_encryption_materials.required_encryption_context_keys = mock_required_encryption_context_keys
167+
168+
# When: get required encryption context keys
169+
mpl_encryption_materials = EncryptionMaterialsFromMPL(mpl_materials=mock_mpl_encryption_materials)
170+
output = mpl_encryption_materials.required_encryption_context_keys
171+
172+
# Then: returns required encryption context keys
173+
assert output == mock_required_encryption_context_keys
174+
175+
163176
def test_GIVEN_valid_data_key_WHEN_DecryptionMaterials_get_data_key_THEN_returns_data_key():
164177
# Given: valid MPL data key
165178
mock_data_key = MagicMock(__class__=bytes)
@@ -187,3 +200,29 @@ def test_GIVEN_valid_verification_key_WHEN_DecryptionMaterials_get_verification_
187200

188201
# Then: returns verification key
189202
assert output == mock_verification_key
203+
204+
205+
def test_GIVEN_valid_encryption_context_WHEN_DecryptionMaterials_get_encryption_context_THEN_returns_encryption_context():
206+
# Given: valid encryption context
207+
mock_encryption_context = MagicMock(__class__=Dict[str, str])
208+
mock_mpl_decrypt_materials.encryption_context = mock_encryption_context
209+
210+
# When: get encryption context
211+
mpl_decryption_materials = DecryptionMaterialsFromMPL(mpl_materials=mock_mpl_decrypt_materials)
212+
output = mpl_decryption_materials.encryption_context
213+
214+
# Then: returns valid encryption context
215+
assert output == mock_encryption_context
216+
217+
218+
def test_GIVEN_valid_required_encryption_context_keys_WHEN_DecryptionMaterials_get_required_encryption_context_keys_THEN_returns_required_encryption_context_keys():
219+
# Given: valid required encryption context keys
220+
mock_required_encryption_context_keys = MagicMock(__class__=bytes)
221+
mock_mpl_decrypt_materials.required_encryption_context_keys = mock_required_encryption_context_keys
222+
223+
# When: get required encryption context keys
224+
mpl_decryption_materials = DecryptionMaterialsFromMPL(mpl_materials=mock_mpl_decrypt_materials)
225+
output = mpl_decryption_materials.required_encryption_context_keys
226+
227+
# Then: returns required encryption context keys
228+
assert output == mock_required_encryption_context_keys

test/unit/test_serialize.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def apply_fixtures(self):
7979
"aws_encryption_sdk.internal.formatting.serialize.aws_encryption_sdk.internal.utils.validate_frame_length"
8080
)
8181
self.mock_valid_frame_length = self.mock_valid_frame_length_patcher.start()
82+
self.mock_required_ec_bytes = MagicMock()
8283
# Set up mock signer
8384
self.mock_signer = MagicMock()
8485
self.mock_signer.update.return_value = None
@@ -167,6 +168,31 @@ def test_serialize_header_auth_v1_no_signer(self):
167168
data_encryption_key=VALUES["data_key_obj"],
168169
)
169170

171+
@patch("aws_encryption_sdk.internal.formatting.serialize.header_auth_iv")
172+
def test_GIVEN_required_ec_bytes_WHEN_serialize_header_auth_v1_THEN_aad_has_required_ec_bytes(self, mock_header_auth_iv):
173+
"""Validate that the _create_header_auth function
174+
behaves as expected for SerializationVersion.V1
175+
when required_ec_bytes are provided.
176+
"""
177+
self.mock_encrypt.return_value = VALUES["header_auth_base"]
178+
test = aws_encryption_sdk.internal.formatting.serialize.serialize_header_auth(
179+
version=SerializationVersion.V1,
180+
algorithm=self.mock_algorithm,
181+
header=VALUES["serialized_header"],
182+
data_encryption_key=sentinel.encryption_key,
183+
signer=self.mock_signer,
184+
required_ec_bytes=self.mock_required_ec_bytes,
185+
)
186+
self.mock_encrypt.assert_called_once_with(
187+
algorithm=self.mock_algorithm,
188+
key=sentinel.encryption_key,
189+
plaintext=b"",
190+
associated_data=VALUES["serialized_header"] + self.mock_required_ec_bytes,
191+
iv=mock_header_auth_iv.return_value,
192+
)
193+
self.mock_signer.update.assert_called_once_with(VALUES["serialized_header_auth"])
194+
assert test == VALUES["serialized_header_auth"]
195+
170196
@patch("aws_encryption_sdk.internal.formatting.serialize.header_auth_iv")
171197
def test_serialize_header_auth_v2(self, mock_header_auth_iv):
172198
"""Validate that the _create_header_auth function
@@ -203,6 +229,30 @@ def test_serialize_header_auth_v2_no_signer(self):
203229
data_encryption_key=VALUES["data_key_obj"],
204230
)
205231

232+
@patch("aws_encryption_sdk.internal.formatting.serialize.header_auth_iv")
233+
def test_GIVEN_required_ec_bytes_WHEN_serialize_header_auth_v2_THEN_aad_has_required_ec_bytes(self, mock_header_auth_iv):
234+
"""Validate that the _create_header_auth function
235+
behaves as expected for SerializationVersion.V2.
236+
"""
237+
self.mock_encrypt.return_value = VALUES["header_auth_base"]
238+
test = aws_encryption_sdk.internal.formatting.serialize.serialize_header_auth(
239+
version=SerializationVersion.V2,
240+
algorithm=self.mock_algorithm,
241+
header=VALUES["serialized_header_v2_committing"],
242+
data_encryption_key=sentinel.encryption_key,
243+
signer=self.mock_signer,
244+
required_ec_bytes=self.mock_required_ec_bytes,
245+
)
246+
self.mock_encrypt.assert_called_once_with(
247+
algorithm=self.mock_algorithm,
248+
key=sentinel.encryption_key,
249+
plaintext=b"",
250+
associated_data=VALUES["serialized_header_v2_committing"] + self.mock_required_ec_bytes,
251+
iv=mock_header_auth_iv.return_value,
252+
)
253+
self.mock_signer.update.assert_called_once_with(VALUES["serialized_header_auth_v2"])
254+
assert test == VALUES["serialized_header_auth_v2"]
255+
206256
def test_serialize_non_framed_open(self):
207257
"""Validate that the serialize_non_framed_open
208258
function behaves as expected.

test/unit/test_streaming_client_configs.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import pytest
1717
import six
18-
from mock import patch
18+
from mock import MagicMock, patch
1919

2020
from aws_encryption_sdk import CommitmentPolicy
2121
from aws_encryption_sdk.internal.defaults import ALGORITHM, FRAME_LENGTH, LINE_LENGTH
@@ -33,7 +33,10 @@
3333
# Ideally, this logic would be based on mocking imports and testing logic,
3434
# but doing that introduces errors that cause other tests to fail.
3535
try:
36-
from aws_cryptographic_materialproviders.mpl.references import IKeyring
36+
from aws_cryptographic_materialproviders.mpl.references import (
37+
ICryptographicMaterialsManager,
38+
IKeyring,
39+
)
3740
HAS_MPL = True
3841

3942
from aws_encryption_sdk.materials_managers.mpl.cmm import CryptoMaterialsManagerFromMPL
@@ -236,24 +239,21 @@ def test_client_configs_with_mpl(
236239
assert test.materials_manager is not None
237240

238241
# If materials manager was provided, it should be directly used
239-
if hasattr(kwargs, "materials_manager"):
242+
if "materials_manager" in kwargs:
240243
assert kwargs["materials_manager"] == test.materials_manager
241244

242-
# If MPL keyring was provided, it should be wrapped in MPL materials manager
243-
if hasattr(kwargs, "keyring"):
244-
assert test.keyring is not None
245-
assert test.keyring == kwargs["keyring"]
246-
assert isinstance(test.keyring, IKeyring)
247-
assert isinstance(test.materials_manager, CryptoMaterialsManagerFromMPL)
248-
249245
# If native key_provider was provided, it should be wrapped in native materials manager
250-
if hasattr(kwargs, "key_provider"):
246+
elif "key_provider" in kwargs:
251247
assert test.key_provider is not None
252248
assert test.key_provider == kwargs["key_provider"]
253249
assert isinstance(test.materials_manager, DefaultCryptoMaterialsManager)
254250

251+
else:
252+
raise ValueError(f"Test did not find materials_manager or key_provider. {kwargs}")
253+
255254

256-
# This needs its own test; pytest parametrize cannot use a conditionally-loaded type
255+
# This is an addition to test_client_configs_with_mpl;
256+
# This needs its own test; pytest's parametrize cannot use a conditionally-loaded type (IKeyring)
257257
@pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation")
258258
def test_keyring_client_config_with_mpl(
259259
):
@@ -265,16 +265,30 @@ def test_keyring_client_config_with_mpl(
265265

266266
test = _ClientConfig(**kwargs)
267267

268-
# In all cases, config should have a materials manager
269268
assert test.materials_manager is not None
270269

271-
# If materials manager was provided, it should be directly used
272-
if hasattr(kwargs, "materials_manager"):
273-
assert kwargs["materials_manager"] == test.materials_manager
270+
assert test.keyring is not None
271+
assert test.keyring == kwargs["keyring"]
272+
assert isinstance(test.keyring, IKeyring)
273+
assert isinstance(test.materials_manager, CryptoMaterialsManagerFromMPL)
274+
275+
276+
# This is an addition to test_client_configs_with_mpl;
277+
# This needs its own test; pytest's parametrize cannot use a conditionally-loaded type (MPL CMM)
278+
@pytest.mark.skipif(not HAS_MPL, reason="Test should only be executed with MPL in installation")
279+
def test_mpl_cmm_client_config_with_mpl(
280+
):
281+
mock_mpl_cmm = MagicMock(__class__=ICryptographicMaterialsManager)
282+
kwargs = {
283+
"source": b"",
284+
"materials_manager": mock_mpl_cmm,
285+
"commitment_policy": CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT
286+
}
287+
288+
test = _ClientConfig(**kwargs)
274289

275-
# If MPL keyring was provided, it should be wrapped in MPL materials manager
276-
if hasattr(kwargs, "keyring"):
277-
assert test.keyring is not None
278-
assert test.keyring == kwargs["keyring"]
279-
assert isinstance(test.keyring, IKeyring)
280-
assert isinstance(test.materials_manager, CryptoMaterialsManagerFromMPL)
290+
assert test.materials_manager is not None
291+
# Assert that the MPL CMM is wrapped in the native interface
292+
assert isinstance(test.materials_manager, CryptoMaterialsManagerFromMPL)
293+
# Assert the MPL CMM is used by the native interface
294+
assert test.materials_manager.mpl_cmm == mock_mpl_cmm

0 commit comments

Comments
 (0)