|
25 | 25 | import kafka.errors as Errors
|
26 | 26 | from kafka.future import Future
|
27 | 27 | from kafka.metrics.stats import Avg, Count, Max, Rate
|
| 28 | +from kafka.oauth.abstract import AbstractTokenProvider |
28 | 29 | from kafka.protocol.admin import SaslHandShakeRequest
|
29 | 30 | from kafka.protocol.commit import OffsetFetchRequest
|
30 | 31 | from kafka.protocol.metadata import MetadataRequest
|
@@ -184,6 +185,8 @@ class BrokerConnection(object):
|
184 | 185 | sasl mechanism handshake. Default: 'kafka'
|
185 | 186 | sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI
|
186 | 187 | sasl mechanism handshake. Default: one of bootstrap servers
|
| 188 | + sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider |
| 189 | + instance. (See kafka.oauth.abstract). Default: None |
187 | 190 | """
|
188 | 191 |
|
189 | 192 | DEFAULT_CONFIG = {
|
@@ -216,10 +219,11 @@ class BrokerConnection(object):
|
216 | 219 | 'sasl_plain_username': None,
|
217 | 220 | 'sasl_plain_password': None,
|
218 | 221 | 'sasl_kerberos_service_name': 'kafka',
|
219 |
| - 'sasl_kerberos_domain_name': None |
| 222 | + 'sasl_kerberos_domain_name': None, |
| 223 | + 'sasl_oauth_token_provider': None |
220 | 224 | }
|
221 | 225 | SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL')
|
222 |
| - SASL_MECHANISMS = ('PLAIN', 'GSSAPI') |
| 226 | + SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER') |
223 | 227 |
|
224 | 228 | def __init__(self, host, port, afi, **configs):
|
225 | 229 | self.host = host
|
@@ -263,7 +267,10 @@ def __init__(self, host, port, afi, **configs):
|
263 | 267 | if self.config['sasl_mechanism'] == 'GSSAPI':
|
264 | 268 | assert gssapi is not None, 'GSSAPI lib not available'
|
265 | 269 | assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl'
|
266 |
| - |
| 270 | + if self.config['sasl_mechanism'] == 'OAUTHBEARER': |
| 271 | + token_provider = self.config['sasl_oauth_token_provider'] |
| 272 | + assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl' |
| 273 | + assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()' |
267 | 274 | # This is not a general lock / this class is not generally thread-safe yet
|
268 | 275 | # However, to avoid pushing responsibility for maintaining
|
269 | 276 | # per-connection locks to the upstream client, we will use this lock to
|
@@ -537,6 +544,8 @@ def _handle_sasl_handshake_response(self, future, response):
|
537 | 544 | return self._try_authenticate_plain(future)
|
538 | 545 | elif self.config['sasl_mechanism'] == 'GSSAPI':
|
539 | 546 | return self._try_authenticate_gssapi(future)
|
| 547 | + elif self.config['sasl_mechanism'] == 'OAUTHBEARER': |
| 548 | + return self._try_authenticate_oauth(future) |
540 | 549 | else:
|
541 | 550 | return future.failure(
|
542 | 551 | Errors.UnsupportedSaslMechanismError(
|
@@ -660,6 +669,51 @@ def _try_authenticate_gssapi(self, future):
|
660 | 669 | log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name)
|
661 | 670 | return future.success(True)
|
662 | 671 |
|
| 672 | + def _try_authenticate_oauth(self, future): |
| 673 | + data = b'' |
| 674 | + |
| 675 | + msg = bytes(self._build_oauth_client_request().encode("utf-8")) |
| 676 | + size = Int32.encode(len(msg)) |
| 677 | + try: |
| 678 | + # Send SASL OAuthBearer request with OAuth token |
| 679 | + self._send_bytes_blocking(size + msg) |
| 680 | + |
| 681 | + # The server will send a zero sized message (that is Int32(0)) on success. |
| 682 | + # The connection is closed on failure |
| 683 | + data = self._recv_bytes_blocking(4) |
| 684 | + |
| 685 | + except ConnectionError as e: |
| 686 | + log.exception("%s: Error receiving reply from server", self) |
| 687 | + error = Errors.KafkaConnectionError("%s: %s" % (self, e)) |
| 688 | + self.close(error=error) |
| 689 | + return future.failure(error) |
| 690 | + |
| 691 | + if data != b'\x00\x00\x00\x00': |
| 692 | + error = Errors.AuthenticationFailedError('Unrecognized response during authentication') |
| 693 | + return future.failure(error) |
| 694 | + |
| 695 | + log.info('%s: Authenticated via OAuth', self) |
| 696 | + return future.success(True) |
| 697 | + |
| 698 | + def _build_oauth_client_request(self): |
| 699 | + token_provider = self.config['sasl_oauth_token_provider'] |
| 700 | + return "n,,\x01auth=Bearer {}{}\x01\x01".format(token_provider.token(), self._token_extensions()) |
| 701 | + |
| 702 | + def _token_extensions(self): |
| 703 | + """ |
| 704 | + Return a string representation of the OPTIONAL key-value pairs that can be sent with an OAUTHBEARER |
| 705 | + initial request. |
| 706 | + """ |
| 707 | + token_provider = self.config['sasl_oauth_token_provider'] |
| 708 | + |
| 709 | + # Only run if the #extensions() method is implemented by the clients Token Provider class |
| 710 | + # Builds up a string separated by \x01 via a dict of key value pairs |
| 711 | + if callable(getattr(token_provider, "extensions", None)) and len(token_provider.extensions()) > 0: |
| 712 | + msg = "\x01".join(["{}={}".format(k, v) for k, v in token_provider.extensions().items()]) |
| 713 | + return "\x01" + msg |
| 714 | + else: |
| 715 | + return "" |
| 716 | + |
663 | 717 | def blacked_out(self):
|
664 | 718 | """
|
665 | 719 | Return true if we are disconnected from the given node and can't
|
|
0 commit comments