@@ -499,3 +499,207 @@ async def _refresh_access_token(self) -> bool:
499
499
except Exception :
500
500
logger .exception ("Token refresh failed" )
501
501
return False
502
+
503
+
504
+ class ClientCredentialsProvider (httpx .Auth ):
505
+ """HTTPX auth using the OAuth2 client credentials grant."""
506
+
507
+ def __init__ (
508
+ self ,
509
+ server_url : str ,
510
+ client_metadata : OAuthClientMetadata ,
511
+ storage : TokenStorage ,
512
+ timeout : float = 300.0 ,
513
+ ):
514
+ self .server_url = server_url
515
+ self .client_metadata = client_metadata
516
+ self .storage = storage
517
+ self .timeout = timeout
518
+
519
+ self ._current_tokens : OAuthToken | None = None
520
+ self ._metadata : OAuthMetadata | None = None
521
+ self ._client_info : OAuthClientInformationFull | None = None
522
+ self ._token_expiry_time : float | None = None
523
+
524
+ self ._token_lock = anyio .Lock ()
525
+
526
+ def _get_authorization_base_url (self , server_url : str ) -> str :
527
+ from urllib .parse import urlparse , urlunparse
528
+
529
+ parsed = urlparse (server_url )
530
+ return urlunparse ((parsed .scheme , parsed .netloc , "" , "" , "" , "" ))
531
+
532
+ async def _discover_oauth_metadata (self , server_url : str ) -> OAuthMetadata | None :
533
+ auth_base_url = self ._get_authorization_base_url (server_url )
534
+ url = urljoin (auth_base_url , "/.well-known/oauth-authorization-server" )
535
+ headers = {"MCP-Protocol-Version" : LATEST_PROTOCOL_VERSION }
536
+
537
+ async with httpx .AsyncClient () as client :
538
+ try :
539
+ response = await client .get (url , headers = headers )
540
+ if response .status_code == 404 :
541
+ return None
542
+ response .raise_for_status ()
543
+ return OAuthMetadata .model_validate (response .json ())
544
+ except Exception :
545
+ try :
546
+ response = await client .get (url )
547
+ if response .status_code == 404 :
548
+ return None
549
+ response .raise_for_status ()
550
+ return OAuthMetadata .model_validate (response .json ())
551
+ except Exception :
552
+ logger .exception ("Failed to discover OAuth metadata" )
553
+ return None
554
+
555
+ async def _register_oauth_client (
556
+ self ,
557
+ server_url : str ,
558
+ client_metadata : OAuthClientMetadata ,
559
+ metadata : OAuthMetadata | None = None ,
560
+ ) -> OAuthClientInformationFull :
561
+ if not metadata :
562
+ metadata = await self ._discover_oauth_metadata (server_url )
563
+
564
+ if metadata and metadata .registration_endpoint :
565
+ registration_url = str (metadata .registration_endpoint )
566
+ else :
567
+ auth_base_url = self ._get_authorization_base_url (server_url )
568
+ registration_url = urljoin (auth_base_url , "/register" )
569
+
570
+ if (
571
+ client_metadata .scope is None
572
+ and metadata
573
+ and metadata .scopes_supported is not None
574
+ ):
575
+ client_metadata .scope = " " .join (metadata .scopes_supported )
576
+
577
+ registration_data = client_metadata .model_dump (
578
+ by_alias = True , mode = "json" , exclude_none = True
579
+ )
580
+
581
+ async with httpx .AsyncClient () as client :
582
+ response = await client .post (
583
+ registration_url ,
584
+ json = registration_data ,
585
+ headers = {"Content-Type" : "application/json" },
586
+ )
587
+
588
+ if response .status_code not in (200 , 201 ):
589
+ raise httpx .HTTPStatusError (
590
+ f"Registration failed: { response .status_code } " ,
591
+ request = response .request ,
592
+ response = response ,
593
+ )
594
+
595
+ return OAuthClientInformationFull .model_validate (response .json ())
596
+
597
+ def _has_valid_token (self ) -> bool :
598
+ if not self ._current_tokens or not self ._current_tokens .access_token :
599
+ return False
600
+
601
+ if self ._token_expiry_time and time .time () > self ._token_expiry_time :
602
+ return False
603
+ return True
604
+
605
+ async def _validate_token_scopes (self , token_response : OAuthToken ) -> None :
606
+ if not token_response .scope :
607
+ return
608
+
609
+ requested_scopes : set [str ] = set ()
610
+ if self .client_metadata .scope :
611
+ requested_scopes = set (self .client_metadata .scope .split ())
612
+ returned_scopes = set (token_response .scope .split ())
613
+ unauthorized_scopes = returned_scopes - requested_scopes
614
+ if unauthorized_scopes :
615
+ raise Exception (
616
+ f"Server granted unauthorized scopes: { unauthorized_scopes } ."
617
+ )
618
+ else :
619
+ granted = set (token_response .scope .split ())
620
+ logger .debug (
621
+ "No explicit scopes requested, accepting server-granted scopes: %s" ,
622
+ granted ,
623
+ )
624
+
625
+ async def initialize (self ) -> None :
626
+ self ._current_tokens = await self .storage .get_tokens ()
627
+ self ._client_info = await self .storage .get_client_info ()
628
+
629
+ async def _get_or_register_client (self ) -> OAuthClientInformationFull :
630
+ if not self ._client_info :
631
+ self ._client_info = await self ._register_oauth_client (
632
+ self .server_url , self .client_metadata , self ._metadata
633
+ )
634
+ await self .storage .set_client_info (self ._client_info )
635
+ return self ._client_info
636
+
637
+ async def _request_token (self ) -> None :
638
+ if not self ._metadata :
639
+ self ._metadata = await self ._discover_oauth_metadata (self .server_url )
640
+
641
+ client_info = await self ._get_or_register_client ()
642
+
643
+ if self ._metadata and self ._metadata .token_endpoint :
644
+ token_url = str (self ._metadata .token_endpoint )
645
+ else :
646
+ auth_base_url = self ._get_authorization_base_url (self .server_url )
647
+ token_url = urljoin (auth_base_url , "/token" )
648
+
649
+ token_data = {
650
+ "grant_type" : "client_credentials" ,
651
+ "client_id" : client_info .client_id ,
652
+ }
653
+
654
+ if client_info .client_secret :
655
+ token_data ["client_secret" ] = client_info .client_secret
656
+
657
+ if self .client_metadata .scope :
658
+ token_data ["scope" ] = self .client_metadata .scope
659
+
660
+ async with httpx .AsyncClient () as client :
661
+ response = await client .post (
662
+ token_url ,
663
+ data = token_data ,
664
+ headers = {"Content-Type" : "application/x-www-form-urlencoded" },
665
+ timeout = 30.0 ,
666
+ )
667
+
668
+ if response .status_code != 200 :
669
+ raise Exception (
670
+ f"Token request failed: { response .status_code } { response .text } "
671
+ )
672
+
673
+ token_response = OAuthToken .model_validate (response .json ())
674
+ await self ._validate_token_scopes (token_response )
675
+
676
+ if token_response .expires_in :
677
+ self ._token_expiry_time = time .time () + token_response .expires_in
678
+ else :
679
+ self ._token_expiry_time = None
680
+
681
+ await self .storage .set_tokens (token_response )
682
+ self ._current_tokens = token_response
683
+
684
+ async def ensure_token (self ) -> None :
685
+ async with self ._token_lock :
686
+ if self ._has_valid_token ():
687
+ return
688
+ await self ._request_token ()
689
+
690
+ async def async_auth_flow (
691
+ self , request : httpx .Request
692
+ ) -> AsyncGenerator [httpx .Request , httpx .Response ]:
693
+ if not self ._has_valid_token ():
694
+ await self .initialize ()
695
+ await self .ensure_token ()
696
+
697
+ if self ._current_tokens and self ._current_tokens .access_token :
698
+ request .headers ["Authorization" ] = (
699
+ f"Bearer { self ._current_tokens .access_token } "
700
+ )
701
+
702
+ response = yield request
703
+
704
+ if response .status_code == 401 :
705
+ self ._current_tokens = None
0 commit comments