1
- from json import JSONDecodeError
2
1
import logging
2
+ from json import JSONDecodeError
3
3
from typing import Callable
4
4
from typing import Optional
5
5
from typing import Union
12
12
from idpyoidc .client .exception import OidcServiceError
13
13
from idpyoidc .client .exception import ParseError
14
14
from idpyoidc .client .service import REQUEST_INFO
15
- from idpyoidc .client .service import SUCCESSFUL
16
15
from idpyoidc .client .service import Service
16
+ from idpyoidc .client .service import SUCCESSFUL
17
17
from idpyoidc .client .util import do_add_ons
18
18
from idpyoidc .client .util import get_deserialization_method
19
19
from idpyoidc .configure import Configuration
26
26
27
27
logger = logging .getLogger (__name__ )
28
28
29
- Version = "2.0"
30
-
31
29
32
30
class ExpiredToken (Exception ):
33
31
pass
@@ -40,20 +38,20 @@ class Client(Entity):
40
38
client_type = "oauth2"
41
39
42
40
def __init__ (
43
- self ,
44
- keyjar : Optional [KeyJar ] = None ,
45
- config : Optional [Union [dict , Configuration ]] = None ,
46
- services : Optional [dict ] = None ,
47
- httpc : Optional [Callable ] = None ,
48
- httpc_params : Optional [dict ] = None ,
49
- context : Optional [OidcContext ] = None ,
50
- upstream_get : Optional [Callable ] = None ,
51
- key_conf : Optional [dict ] = None ,
52
- entity_id : Optional [str ] = "" ,
53
- verify_ssl : Optional [bool ] = True ,
54
- jwks_uri : Optional [str ] = "" ,
55
- client_type : Optional [str ] = "" ,
56
- ** kwargs
41
+ self ,
42
+ keyjar : Optional [KeyJar ] = None ,
43
+ config : Optional [Union [dict , Configuration ]] = None ,
44
+ services : Optional [dict ] = None ,
45
+ httpc : Optional [Callable ] = None ,
46
+ httpc_params : Optional [dict ] = None ,
47
+ context : Optional [OidcContext ] = None ,
48
+ upstream_get : Optional [Callable ] = None ,
49
+ key_conf : Optional [dict ] = None ,
50
+ entity_id : Optional [str ] = "" ,
51
+ verify_ssl : Optional [bool ] = True ,
52
+ jwks_uri : Optional [str ] = "" ,
53
+ client_type : Optional [str ] = "" ,
54
+ ** kwargs
57
55
):
58
56
"""
59
57
@@ -70,7 +68,11 @@ def __init__(
70
68
:return: Client instance
71
69
"""
72
70
73
- if not client_type :
71
+ if client_type :
72
+ self .client_type = client_type
73
+ elif config and 'client_type' in config :
74
+ client_type = self .client_type = config ["client_type" ]
75
+ else :
74
76
client_type = self .client_type
75
77
76
78
if verify_ssl is False :
@@ -80,6 +82,8 @@ def __init__(
80
82
else :
81
83
httpc_params = {"verify" : False }
82
84
85
+ jwks_uri = jwks_uri or config .get ('jwks_uri' , '' )
86
+
83
87
Entity .__init__ (
84
88
self ,
85
89
keyjar = keyjar ,
@@ -106,12 +110,12 @@ def __init__(
106
110
do_add_ons (_add_ons , self ._service )
107
111
108
112
def do_request (
109
- self ,
110
- request_type : str ,
111
- response_body_type : Optional [str ] = "" ,
112
- request_args : Optional [dict ] = None ,
113
- behaviour_args : Optional [dict ] = None ,
114
- ** kwargs
113
+ self ,
114
+ request_type : str ,
115
+ response_body_type : Optional [str ] = "" ,
116
+ request_args : Optional [dict ] = None ,
117
+ behaviour_args : Optional [dict ] = None ,
118
+ ** kwargs
115
119
):
116
120
_srv = self ._service [request_type ]
117
121
@@ -134,14 +138,14 @@ def set_client_id(self, client_id):
134
138
self .get_context ().set ("client_id" , client_id )
135
139
136
140
def get_response (
137
- self ,
138
- service : Service ,
139
- url : str ,
140
- method : Optional [str ] = "GET" ,
141
- body : Optional [dict ] = None ,
142
- response_body_type : Optional [str ] = "" ,
143
- headers : Optional [dict ] = None ,
144
- ** kwargs
141
+ self ,
142
+ service : Service ,
143
+ url : str ,
144
+ method : Optional [str ] = "GET" ,
145
+ body : Optional [dict ] = None ,
146
+ response_body_type : Optional [str ] = "" ,
147
+ headers : Optional [dict ] = None ,
148
+ ** kwargs
145
149
):
146
150
"""
147
151
@@ -177,14 +181,14 @@ def get_response(
177
181
return self .parse_request_response (service , resp , response_body_type , ** kwargs )
178
182
179
183
def service_request (
180
- self ,
181
- service : Service ,
182
- url : str ,
183
- method : Optional [str ] = "GET" ,
184
- body : Optional [dict ] = None ,
185
- response_body_type : Optional [str ] = "" ,
186
- headers : Optional [dict ] = None ,
187
- ** kwargs
184
+ self ,
185
+ service : Service ,
186
+ url : str ,
187
+ method : Optional [str ] = "GET" ,
188
+ body : Optional [dict ] = None ,
189
+ response_body_type : Optional [str ] = "" ,
190
+ headers : Optional [dict ] = None ,
191
+ ** kwargs
188
192
) -> Message :
189
193
"""
190
194
The method that sends the request and handles the response returned.
@@ -312,17 +316,20 @@ def dynamic_provider_info_discovery(client: Client, behaviour_args: Optional[dic
312
316
:param behaviour_args:
313
317
:param client: A :py:class:`idpyoidc.client.oidc.Client` instance
314
318
"""
319
+
320
+ if client .client_type == 'oidc' and client .get_service ("provider_info" ):
321
+ service = 'provider_info'
322
+ elif client .client_type == 'oauth2' and client .get_service ('server_metadata' ):
323
+ service = 'server_metadata'
324
+ else :
325
+ raise ConfigurationError ("Can not do dynamic provider info discovery" )
326
+
327
+ _context = client .get_context ()
315
328
try :
316
- client . get_service ( "provider_info" )
329
+ _context . set ( "issuer" , _context . config [ "srv_discovery_url" ] )
317
330
except KeyError :
318
- raise ConfigurationError ("Can not do dynamic provider info discovery" )
319
- else :
320
- _context = client .get_context ()
321
- try :
322
- _context .set ("issuer" , _context .config ["srv_discovery_url" ])
323
- except KeyError :
324
- pass
331
+ pass
325
332
326
- response = client .do_request ("provider_info" , behaviour_args = behaviour_args )
327
- if is_error_message (response ):
328
- raise OidcServiceError (response ["error" ])
333
+ response = client .do_request (service , behaviour_args = behaviour_args )
334
+ if is_error_message (response ):
335
+ raise OidcServiceError (response ["error" ])
0 commit comments