@@ -88,20 +88,69 @@ def __init__(self):
88
88
"appmetadata-{}-{}" .format (environment or "" , client_id or "" ),
89
89
}
90
90
91
- def find (self , credential_type , target = None , query = None ):
92
- target = target or []
91
+ def _get_access_token (
92
+ self ,
93
+ home_account_id , environment , client_id , realm , target , # Together they form a compound key
94
+ default = None ,
95
+ ): # O(1)
96
+ return self ._get (
97
+ self .CredentialType .ACCESS_TOKEN ,
98
+ self .key_makers [TokenCache .CredentialType .ACCESS_TOKEN ](
99
+ home_account_id = home_account_id ,
100
+ environment = environment ,
101
+ client_id = client_id ,
102
+ realm = realm ,
103
+ target = " " .join (target ),
104
+ ),
105
+ default = default )
106
+
107
+ def _get_app_metadata (self , environment , client_id , default = None ): # O(1)
108
+ return self ._get (
109
+ self .CredentialType .APP_METADATA ,
110
+ self .key_makers [TokenCache .CredentialType .APP_METADATA ](
111
+ environment = environment ,
112
+ client_id = client_id ,
113
+ ),
114
+ default = default )
115
+
116
+ def _get (self , credential_type , key , default = None ): # O(1)
117
+ with self ._lock :
118
+ return self ._cache .get (credential_type , {}).get (key , default )
119
+
120
+ def _find (self , credential_type , target = None , query = None ): # O(n) generator
121
+ """Returns a generator of matching entries.
122
+
123
+ It is O(1) for AT hits, and O(n) for other types.
124
+ Note that it holds a lock during the entire search.
125
+ """
126
+ target = sorted (target or []) # Match the order sorted by add()
93
127
assert isinstance (target , list ), "Invalid parameter type"
128
+
129
+ preferred_result = None
130
+ if (credential_type == self .CredentialType .ACCESS_TOKEN
131
+ and "home_account_id" in query and "environment" in query
132
+ and "client_id" in query and "realm" in query and target
133
+ ): # Special case for O(1) AT lookup
134
+ preferred_result = self ._get_access_token (
135
+ query ["home_account_id" ], query ["environment" ],
136
+ query ["client_id" ], query ["realm" ], target )
137
+ if preferred_result :
138
+ yield preferred_result
139
+
94
140
target_set = set (target )
95
141
with self ._lock :
96
142
# Since the target inside token cache key is (per schema) unsorted,
97
143
# there is no point to attempt an O(1) key-value search here.
98
144
# So we always do an O(n) in-memory search.
99
- return [entry
100
- for entry in self ._cache .get (credential_type , {}).values ()
101
- if is_subdict_of (query or {}, entry )
102
- and (target_set <= set (entry .get ("target" , "" ).split ())
103
- if target else True )
104
- ]
145
+ for entry in self ._cache .get (credential_type , {}).values ():
146
+ if is_subdict_of (query or {}, entry ) and (
147
+ target_set <= set (entry .get ("target" , "" ).split ())
148
+ if target else True ):
149
+ if entry != preferred_result : # Avoid yielding the same entry twice
150
+ yield entry
151
+
152
+ def find (self , credential_type , target = None , query = None ): # Obsolete. Use _find() instead.
153
+ return list (self ._find (credential_type , target = target , query = query ))
105
154
106
155
def add (self , event , now = None ):
107
156
"""Handle a token obtaining event, and add tokens into cache."""
@@ -160,7 +209,7 @@ def __add(self, event, now=None):
160
209
decode_id_token (id_token , client_id = event ["client_id" ]) if id_token else {})
161
210
client_info , home_account_id = self .__parse_account (response , id_token_claims )
162
211
163
- target = ' ' .join (event .get ("scope" ) or []) # Per schema, we don't sort it
212
+ target = ' ' .join (sorted ( event .get ("scope" ) or [])) # Schema should have required sorting
164
213
165
214
with self ._lock :
166
215
now = int (time .time () if now is None else now )
0 commit comments