@@ -41,13 +41,15 @@ def __init__(
41
41
42
42
43
43
class Provider :
44
- def __init__ (self , id : str ): # pylint: disable=redefined-builtin
44
+ def __init__ (
45
+ self , id : str , config : ProviderConfigForClient
46
+ ): # pylint: disable=redefined-builtin
45
47
self .id = id
46
- self .config = ProviderConfigForClientType ( "temp" )
48
+ self .config = config
47
49
48
50
async def get_config_for_client_type ( # pylint: disable=no-self-use
49
51
self , client_type : Optional [str ], user_context : Dict [str , Any ]
50
- ) -> ProviderConfigForClientType :
52
+ ) -> ProviderConfigForClient :
51
53
_ = client_type
52
54
__ = user_context
53
55
raise NotImplementedError ()
@@ -110,60 +112,6 @@ def to_json(self) -> Dict[str, Any]:
110
112
return {k : v for k , v in res .items () if v is not None }
111
113
112
114
113
- class ProviderConfigForClientType :
114
- def __init__ (
115
- self ,
116
- client_id : str ,
117
- client_secret : Optional [str ] = None ,
118
- scope : Optional [List [str ]] = None ,
119
- force_pkce : bool = False ,
120
- additional_config : Optional [Dict [str , Any ]] = None ,
121
- name : Optional [str ] = None ,
122
- authorization_endpoint : Optional [str ] = None ,
123
- authorization_endpoint_query_params : Optional [
124
- Dict [str , Union [str , None ]]
125
- ] = None ,
126
- token_endpoint : Optional [str ] = None ,
127
- token_endpoint_body_params : Optional [Dict [str , Union [str , None ]]] = None ,
128
- user_info_endpoint : Optional [str ] = None ,
129
- user_info_endpoint_query_params : Optional [Dict [str , Union [str , None ]]] = None ,
130
- user_info_endpoint_headers : Optional [Dict [str , Union [str , None ]]] = None ,
131
- jwks_uri : Optional [str ] = None ,
132
- oidc_discovery_endpoint : Optional [str ] = None ,
133
- user_info_map : Optional [UserInfoMap ] = None ,
134
- require_email : bool = True ,
135
- generate_fake_email : Optional [
136
- Callable [[str , str , Dict [str , Any ]], Awaitable [str ]]
137
- ] = None ,
138
- validate_id_token_payload : Optional [
139
- Callable [
140
- [Dict [str , Any ], ProviderConfigForClientType , Dict [str , Any ]],
141
- Awaitable [None ],
142
- ]
143
- ] = None ,
144
- ):
145
- self .client_id = client_id
146
- self .client_secret = client_secret
147
- self .scope = scope
148
- self .force_pkce = force_pkce
149
- self .additional_config = additional_config
150
-
151
- self .name = name
152
- self .authorization_endpoint = authorization_endpoint
153
- self .authorization_endpoint_query_params = authorization_endpoint_query_params
154
- self .token_endpoint = token_endpoint
155
- self .token_endpoint_body_params = token_endpoint_body_params
156
- self .user_info_endpoint = user_info_endpoint
157
- self .user_info_endpoint_query_params = user_info_endpoint_query_params
158
- self .user_info_endpoint_headers = user_info_endpoint_headers
159
- self .jwks_uri = jwks_uri
160
- self .oidc_discovery_endpoint = oidc_discovery_endpoint
161
- self .user_info_map = user_info_map
162
- self .require_email = require_email
163
- self .validate_id_token_payload = validate_id_token_payload
164
- self .generate_fake_email = generate_fake_email
165
-
166
-
167
115
class UserFields :
168
116
def __init__ (
169
117
self ,
@@ -201,12 +149,11 @@ def to_json(self) -> Dict[str, Any]:
201
149
}
202
150
203
151
204
- class ProviderConfig :
152
+ class CommonProviderConfig :
205
153
def __init__ (
206
154
self ,
207
155
third_party_id : str ,
208
156
name : Optional [str ] = None ,
209
- clients : Optional [List [ProviderClientConfig ]] = None ,
210
157
authorization_endpoint : Optional [str ] = None ,
211
158
authorization_endpoint_query_params : Optional [
212
159
Dict [str , Union [str , None ]]
@@ -222,7 +169,7 @@ def __init__(
222
169
require_email : bool = True ,
223
170
validate_id_token_payload : Optional [
224
171
Callable [
225
- [Dict [str , Any ], ProviderConfigForClientType , Dict [str , Any ]],
172
+ [Dict [str , Any ], ProviderConfigForClient , Dict [str , Any ]],
226
173
Awaitable [None ],
227
174
]
228
175
] = None ,
@@ -232,7 +179,6 @@ def __init__(
232
179
):
233
180
self .third_party_id = third_party_id
234
181
self .name = name
235
- self .clients = clients
236
182
self .authorization_endpoint = authorization_endpoint
237
183
self .authorization_endpoint_query_params = authorization_endpoint_query_params
238
184
self .token_endpoint = token_endpoint
@@ -251,9 +197,6 @@ def to_json(self) -> Dict[str, Any]:
251
197
res = {
252
198
"thirdPartyId" : self .third_party_id ,
253
199
"name" : self .name ,
254
- "clients" : [c .to_json () for c in self .clients ]
255
- if self .clients is not None
256
- else [],
257
200
"authorizationEndpoint" : self .authorization_endpoint ,
258
201
"authorizationEndpointQueryParams" : self .authorization_endpoint_query_params ,
259
202
"tokenEndpoint" : self .token_endpoint ,
@@ -272,6 +215,132 @@ def to_json(self) -> Dict[str, Any]:
272
215
return {k : v for k , v in res .items () if v is not None }
273
216
274
217
218
+ class ProviderConfigForClient (ProviderClientConfig , CommonProviderConfig ):
219
+ def __init__ (
220
+ self ,
221
+ # ProviderClientConfig:
222
+ client_id : str ,
223
+ client_secret : Optional [str ] = None ,
224
+ client_type : Optional [str ] = None ,
225
+ scope : Optional [List [str ]] = None ,
226
+ force_pkce : bool = False ,
227
+ additional_config : Optional [Dict [str , Any ]] = None ,
228
+ # CommonProviderConfig:
229
+ name : Optional [str ] = None ,
230
+ authorization_endpoint : Optional [str ] = None ,
231
+ authorization_endpoint_query_params : Optional [
232
+ Dict [str , Union [str , None ]]
233
+ ] = None ,
234
+ token_endpoint : Optional [str ] = None ,
235
+ token_endpoint_body_params : Optional [Dict [str , Union [str , None ]]] = None ,
236
+ user_info_endpoint : Optional [str ] = None ,
237
+ user_info_endpoint_query_params : Optional [Dict [str , Union [str , None ]]] = None ,
238
+ user_info_endpoint_headers : Optional [Dict [str , Union [str , None ]]] = None ,
239
+ jwks_uri : Optional [str ] = None ,
240
+ oidc_discovery_endpoint : Optional [str ] = None ,
241
+ user_info_map : Optional [UserInfoMap ] = None ,
242
+ require_email : bool = True ,
243
+ validate_id_token_payload : Optional [
244
+ Callable [
245
+ [Dict [str , Any ], ProviderConfigForClient , Dict [str , Any ]],
246
+ Awaitable [None ],
247
+ ]
248
+ ] = None ,
249
+ generate_fake_email : Optional [
250
+ Callable [[str , str , Dict [str , Any ]], Awaitable [str ]]
251
+ ] = None ,
252
+ ):
253
+ ProviderClientConfig .__init__ (
254
+ self ,
255
+ client_id ,
256
+ client_secret ,
257
+ client_type ,
258
+ scope ,
259
+ force_pkce ,
260
+ additional_config ,
261
+ )
262
+ CommonProviderConfig .__init__ (
263
+ self ,
264
+ "temp" ,
265
+ name ,
266
+ authorization_endpoint ,
267
+ authorization_endpoint_query_params ,
268
+ token_endpoint ,
269
+ token_endpoint_body_params ,
270
+ user_info_endpoint ,
271
+ user_info_endpoint_query_params ,
272
+ user_info_endpoint_headers ,
273
+ jwks_uri ,
274
+ oidc_discovery_endpoint ,
275
+ user_info_map ,
276
+ require_email ,
277
+ validate_id_token_payload ,
278
+ generate_fake_email ,
279
+ )
280
+
281
+ def to_json (self ) -> Dict [str , Any ]:
282
+ d1 = ProviderClientConfig .to_json (self )
283
+ d2 = CommonProviderConfig .to_json (self )
284
+ return {** d1 , ** d2 }
285
+
286
+
287
+ class ProviderConfig (CommonProviderConfig ):
288
+ def __init__ (
289
+ self ,
290
+ third_party_id : str ,
291
+ name : Optional [str ] = None ,
292
+ clients : Optional [List [ProviderClientConfig ]] = None ,
293
+ authorization_endpoint : Optional [str ] = None ,
294
+ authorization_endpoint_query_params : Optional [
295
+ Dict [str , Union [str , None ]]
296
+ ] = None ,
297
+ token_endpoint : Optional [str ] = None ,
298
+ token_endpoint_body_params : Optional [Dict [str , Union [str , None ]]] = None ,
299
+ user_info_endpoint : Optional [str ] = None ,
300
+ user_info_endpoint_query_params : Optional [Dict [str , Union [str , None ]]] = None ,
301
+ user_info_endpoint_headers : Optional [Dict [str , Union [str , None ]]] = None ,
302
+ jwks_uri : Optional [str ] = None ,
303
+ oidc_discovery_endpoint : Optional [str ] = None ,
304
+ user_info_map : Optional [UserInfoMap ] = None ,
305
+ require_email : bool = True ,
306
+ validate_id_token_payload : Optional [
307
+ Callable [
308
+ [Dict [str , Any ], ProviderConfigForClient , Dict [str , Any ]],
309
+ Awaitable [None ],
310
+ ]
311
+ ] = None ,
312
+ generate_fake_email : Optional [
313
+ Callable [[str , str , Dict [str , Any ]], Awaitable [str ]]
314
+ ] = None ,
315
+ ):
316
+ super ().__init__ (
317
+ third_party_id ,
318
+ name ,
319
+ authorization_endpoint ,
320
+ authorization_endpoint_query_params ,
321
+ token_endpoint ,
322
+ token_endpoint_body_params ,
323
+ user_info_endpoint ,
324
+ user_info_endpoint_query_params ,
325
+ user_info_endpoint_headers ,
326
+ jwks_uri ,
327
+ oidc_discovery_endpoint ,
328
+ user_info_map ,
329
+ require_email ,
330
+ validate_id_token_payload ,
331
+ generate_fake_email ,
332
+ )
333
+ self .clients = clients
334
+
335
+ def to_json (self ) -> Dict [str , Any ]:
336
+ d = CommonProviderConfig .to_json (self )
337
+
338
+ if self .clients is not None :
339
+ d ["clients" ] = [c .to_json () for c in self .clients ]
340
+
341
+ return d
342
+
343
+
275
344
class ProviderInput :
276
345
def __init__ (
277
346
self ,
0 commit comments