Skip to content

Commit 54e2736

Browse files
committed
feat: Changes suggested in feedback
1 parent e932be8 commit 54e2736

File tree

7 files changed

+49
-51
lines changed

7 files changed

+49
-51
lines changed

supertokens_python/recipe/session/session_class.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
1414
import json
15-
from typing import Any, Dict, Union, List, TypeVar
15+
from typing import Any, Dict, List, TypeVar, Union
1616

1717
from supertokens_python.recipe.session.exceptions import (
18-
raise_unauthorised_exception,
1918
raise_invalid_claims_exception,
19+
raise_unauthorised_exception,
2020
)
2121

22-
from .interfaces import SessionContainer, SessionClaimValidator, SessionClaim
22+
from .interfaces import SessionClaim, SessionClaimValidator, SessionContainer
2323
from .utils import update_claims_in_payload_if_needed, validate_claims_in_payload
2424

2525
_T = TypeVar("_T")
@@ -154,6 +154,7 @@ async def fetch_and_set_claim(
154154
) -> None:
155155
if user_context is None:
156156
user_context = {}
157+
157158
update = await claim.build(self.get_user_id(), user_context)
158159
return await self.merge_into_access_token_payload(update, user_context)
159160

@@ -172,6 +173,9 @@ async def set_claim_value(
172173
async def get_claim_value(
173174
self, claim: SessionClaim[Any], user_context: Union[Dict[str, Any], None] = None
174175
) -> Union[Any, None]:
176+
if user_context is None:
177+
user_context = {}
178+
175179
return claim.get_value_from_payload(
176180
self.get_access_token_payload(user_context), user_context
177181
)
@@ -181,12 +185,18 @@ async def remove_claim(
181185
) -> None:
182186
if user_context is None:
183187
user_context = {}
188+
184189
update = claim.remove_from_payload_by_merge_({}, user_context)
185190
return await self.merge_into_access_token_payload(update, user_context)
186191

187192
async def merge_into_access_token_payload(
188-
self, access_token_payload_update: Dict[str, Any], user_context: Dict[str, Any]
193+
self,
194+
access_token_payload_update: Dict[str, Any],
195+
user_context: Union[Dict[str, Any], None] = None,
189196
) -> None:
197+
if user_context is None:
198+
user_context = {}
199+
190200
update_payload = {
191201
**self.get_access_token_payload(user_context),
192202
**access_token_payload_update,

supertokens_python/recipe/session/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ async def validate_claims_in_payload(
552552
log_debug_message(
553553
"validate_claims_in_payload %s validate res %s",
554554
validator.id,
555-
json.dumps(claim_validation_res),
555+
json.dumps(claim_validation_res.__dict__),
556556
)
557557
if (
558558
not claim_validation_res.is_valid

tests/sessions/claims/test_fetch_and_set_claim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ async def test_should_not_change_if_claim_fetch_value_returns_none():
2626
wraps=session.merge_into_access_token_payload,
2727
) as mock:
2828
await session.fetch_and_set_claim(NoneClaim)
29-
mock.assert_called_once_with({}, None)
29+
mock.assert_called_once_with({}, {})
3030

3131

3232
async def test_should_update_if_claim_fetch_value_returns_value(timestamp: int):
@@ -44,4 +44,4 @@ async def test_should_update_if_claim_fetch_value_returns_value(timestamp: int):
4444
wraps=session.merge_into_access_token_payload,
4545
) as mock:
4646
await session.fetch_and_set_claim(TrueClaim)
47-
mock.assert_called_once_with({"st-true": {"t": timestamp, "v": True}}, None)
47+
mock.assert_called_once_with({"st-true": {"t": timestamp, "v": True}}, {})

tests/sessions/claims/test_primitive_claim.py

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,11 @@ async def test_validators_should_not_validate_empty_payload():
9999
claim = PrimitiveClaim("key", sync_fetch_value)
100100
res = await claim.validators.has_value(val).validate({}, {})
101101

102-
assert res == {
103-
"isValid": False,
104-
"reason": {
105-
"expectedValue": val,
106-
"actualValue": None,
107-
"message": "wrong value",
108-
},
102+
assert res.is_valid is False
103+
assert res.reason == {
104+
"expectedValue": val,
105+
"actualValue": None,
106+
"message": "wrong value",
109107
}
110108

111109

@@ -114,13 +112,11 @@ async def test_should_not_validate_mismatching_payload():
114112
payload = await claim.build("user_id")
115113
res = await claim.validators.has_value(val2).validate(payload, {})
116114

117-
assert res == {
118-
"isValid": False,
119-
"reason": {
120-
"expectedValue": val2,
121-
"actualValue": val,
122-
"message": "wrong value",
123-
},
115+
assert res.is_valid is False
116+
assert res.reason == {
117+
"expectedValue": val2,
118+
"actualValue": val,
119+
"message": "wrong value",
124120
}
125121

126122

@@ -129,7 +125,7 @@ async def test_validator_should_validate_matching_payload():
129125
payload = await claim.build("user_id")
130126
res = await claim.validators.has_value(val).validate(payload, {})
131127

132-
assert res == {"isValid": True}
128+
assert res.is_valid is True
133129

134130

135131
async def test_should_validate_old_values_as_well(patch_get_timestamp_ms: MagicMock):
@@ -140,7 +136,7 @@ async def test_should_validate_old_values_as_well(patch_get_timestamp_ms: MagicM
140136
patch_get_timestamp_ms.return_value += 100 # type: ignore
141137

142138
res = await claim.validators.has_value(val).validate(payload, {})
143-
assert res == {"isValid": True}
139+
assert res.is_valid is True
144140

145141

146142
async def test_should_refetch_if_value_not_set():
@@ -165,36 +161,32 @@ async def test_validator_should_not_refetch_if_value_is_set():
165161
async def test_should_not_validate_empty_payload():
166162
claim = PrimitiveClaim("key", sync_fetch_value)
167163
res = await claim.validators.has_fresh_value(val, 600).validate({}, {})
168-
assert res == {
169-
"isValid": False,
170-
"reason": {
171-
"expectedValue": val,
172-
"actualValue": None,
173-
"message": "value does not exist", # TODO: Validate that this is actually correct.
174-
# because this makes sense yet the node PR isn't aligned with this.
175-
},
164+
assert res.is_valid is False
165+
assert res.reason == {
166+
"expectedValue": val,
167+
"actualValue": None,
168+
"message": "value does not exist", # TODO: Validate that this is actually correct.
169+
# because this makes sense yet the node PR isn't aligned with this.
176170
}
177171

178172

179173
async def test_has_fresh_value_should_not_validate_mismatching_payload():
180174
claim = PrimitiveClaim("key", sync_fetch_value)
181175
payload = await claim.build("user_id")
182176
res = await claim.validators.has_fresh_value(val2, 600).validate(payload, {})
183-
assert res == {
184-
"isValid": False,
185-
"reason": {
186-
"expectedValue": val2,
187-
"actualValue": val,
188-
"message": "wrong value",
189-
},
177+
assert res.is_valid is False
178+
assert res.reason == {
179+
"expectedValue": val2,
180+
"actualValue": val,
181+
"message": "wrong value",
190182
}
191183

192184

193185
async def test_should_validate_matching_payload():
194186
claim = PrimitiveClaim("key", sync_fetch_value)
195187
payload = await claim.build("user_id")
196188
res = await claim.validators.has_fresh_value(val, 600).validate(payload, {})
197-
assert res == {"isValid": True}
189+
assert res.is_valid is True
198190

199191

200192
async def test_should_not_validate_old_values_as_well(
@@ -208,13 +200,11 @@ async def test_should_not_validate_old_values_as_well(
208200
patch_get_timestamp_ms.return_value += 100 * SECONDS # type: ignore
209201

210202
res = await claim.validators.has_fresh_value(val, 10).validate(payload, {})
211-
assert res == {
212-
"isValid": False,
213-
"reason": {
214-
"ageInSeconds": 100,
215-
"maxAgeInSeconds": 10,
216-
"message": "expired",
217-
},
203+
assert res.is_valid is False
204+
assert res.reason == {
205+
"ageInSeconds": 100,
206+
"maxAgeInSeconds": 10,
207+
"message": "expired",
218208
}
219209

220210

tests/sessions/claims/test_set_claim_value.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ async def test_should_merge_the_right_value(timestamp: int):
3737
wraps=session.merge_into_access_token_payload,
3838
) as mock:
3939
await session.set_claim_value(TrueClaim, False)
40-
mock.assert_called_once_with({"st-true": {"t": timestamp, "v": False}}, None)
40+
mock.assert_called_once_with({"st-true": {"t": timestamp, "v": False}}, {})
4141

4242

4343
async def test_should_overwrite_claim_value(timestamp: int):
@@ -54,7 +54,7 @@ async def test_should_overwrite_claim_value(timestamp: int):
5454

5555
# Payload should be updated now:
5656
payload = s.get_access_token_payload()
57-
assert payload == {"st-true": {"t": timestamp, "v": "NEW_TRUE"}}
57+
assert payload == {"st-true": {"t": timestamp, "v": False}}
5858

5959

6060
async def test_should_overwrite_claim_value_using_session_handle(timestamp: int):

tests/sessions/claims/test_verify_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ async def test_should_reject_if_assert_claims_returns_an_error(
327327
[
328328
ClaimValidationError(
329329
"test_id",
330-
{"msg": "test_reason"},
330+
{"message": "test_reason"},
331331
)
332332
],
333333
)

tests/sessions/claims/test_with_jwt.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ async def create_api(request: Request): # type: ignore
4242
async def test_should_create_the_right_access_token_payload_with_claims_and_JWT_enabled(
4343
fastapi_client: TestClient,
4444
):
45-
# TODO: FIXME
4645
init(**get_st_init_args(TrueClaim, jwt=JWTConfig(enable=True))) # type:ignore
4746
start_st()
4847

@@ -52,7 +51,6 @@ async def test_should_create_the_right_access_token_payload_with_claims_and_JWT_
5251
session_info = await get_session_information(session_handle)
5352
assert session_info is not None
5453
access_token_payload = session_info.access_token_payload
55-
# TODO: .sub and .iss should be undefined as per node PR
5654
assert access_token_payload["jwt"] is not None
5755
assert access_token_payload["_jwtPName"] == "jwt"
5856

0 commit comments

Comments
 (0)