Skip to content

Commit 457bcbb

Browse files
authored
add msk to __init__ and check for extension in conn.py
1 parent ebcfcb1 commit 457bcbb

File tree

3 files changed

+191
-1
lines changed

3 files changed

+191
-1
lines changed

kafka/conn.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,14 @@ def __init__(self, host, port, afi, **configs):
261261
assert self.config['security_protocol'] in self.SECURITY_PROTOCOLS, (
262262
'security_protocol must be in ' + ', '.join(self.SECURITY_PROTOCOLS))
263263

264+
264265
if self.config['security_protocol'] in ('SSL', 'SASL_SSL'):
265266
assert ssl_available, "Python wasn't built with SSL support"
266267

268+
if self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
269+
assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package'
270+
assert self.config['security_protocol'] == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL'
271+
267272
if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'):
268273
assert self.config['sasl_mechanism'] in sasl.MECHANISMS, (
269274
'sasl_mechanism must be one of {}'.format(', '.join(sasl.MECHANISMS.keys()))

kafka/sasl/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from kafka.sasl import gssapi, oauthbearer, plain, scram
3+
from kafka.sasl import gssapi, oauthbearer, plain, scram, msk
44

55
log = logging.getLogger(__name__)
66

@@ -10,6 +10,7 @@
1010
'PLAIN': plain,
1111
'SCRAM-SHA-256': scram,
1212
'SCRAM-SHA-512': scram,
13+
'AWS_MSK_IAM': msk,
1314
}
1415

1516

kafka/sasl/msk.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import datetime
2+
import hashlib
3+
import hmac
4+
import json
5+
import string
6+
7+
from kafka.vendor.six.moves import urllib
8+
9+
10+
class AwsMskIamClient(object):
    """Build the authentication payload for the AWS_MSK_IAM SASL mechanism.

    The payload is a JSON document carrying an AWS Signature Version 4
    signature computed over a synthetic ``GET /`` request whose query string
    names the ``kafka-cluster:Connect`` action.

    The signing date and timestamp are captured once, when the instance is
    constructed; every derived value (scope, credential, signature) is
    computed lazily from the instance attributes.
    """

    # Characters that are never percent-encoded by _uriencode().
    UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~'

    def __init__(self, host, access_key, secret_key, region, token=None):
        """
        Arguments:
            host (str): The hostname of the broker.
            access_key (str): An AWS_ACCESS_KEY_ID.
            secret_key (str): An AWS_SECRET_ACCESS_KEY.
            region (str): An AWS_REGION.
            token (Optional[str]): An AWS_SESSION_TOKEN if using temporary
                credentials.
        """
        self.algorithm = 'AWS4-HMAC-SHA256'
        self.expires = '900'
        self.hashfunc = hashlib.sha256
        # (name, value) pairs that take part in the signature.
        self.headers = [
            ('host', host)
        ]
        self.version = '2020_10_22'

        self.service = 'kafka-cluster'
        self.action = '{}:Connect'.format(self.service)

        # Take a single reading of the clock so that the date used in the
        # scope and the full timestamp can never straddle midnight.
        now = datetime.datetime.utcnow()
        self.datestamp = now.strftime('%Y%m%d')
        self.timestamp = now.strftime('%Y%m%dT%H%M%SZ')

        self.host = host
        self.access_key = access_key
        self.secret_key = secret_key
        self.region = region
        self.token = token

    @property
    def _credential(self):
        """
        Returns (str):
            The access key followed by the credential scope, '/'-separated.
        """
        return '{0.access_key}/{0._scope}'.format(self)

    @property
    def _scope(self):
        """
        Returns (str):
            The credential scope in the format:
                <Datestamp>/<Region>/<Service>/aws4_request
        """
        return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self)

    @property
    def _signed_headers(self):
        """
        Returns (str):
            An alphabetically sorted, semicolon-delimited list of lowercase
            request header names.
        """
        return ';'.join(sorted(k.lower() for k, _ in self.headers))

    @property
    def _canonical_headers(self):
        """
        Returns (str):
            A newline-delimited list of header names and values, each line
            terminated by a newline. Header names are lowercased and the
            lines are sorted by header name so that they always agree with
            _signed_headers.
        """
        normalized = sorted(
            ((name.lower(), value) for name, value in self.headers),
            key=lambda pair: pair[0],
        )
        return '\n'.join(map(':'.join, normalized)) + '\n'

    @property
    def _canonical_request(self):
        """
        Returns (str):
            An AWS Signature Version 4 canonical request in the format:
                <Method>\n
                <Path>\n
                <CanonicalQueryString>\n
                <CanonicalHeaders>\n
                <SignedHeaders>\n
                <HashedPayload>
        """
        # The hashed_payload is always an empty string for MSK.
        hashed_payload = self.hashfunc(b'').hexdigest()
        return '\n'.join((
            'GET',
            '/',
            self._canonical_querystring,
            self._canonical_headers,
            self._signed_headers,
            hashed_payload,
        ))

    @property
    def _canonical_querystring(self):
        """
        Returns (str):
            A '&'-separated list of URI-encoded key/value pairs.
        """
        params = [
            ('Action', self.action),
            ('X-Amz-Algorithm', self.algorithm),
            ('X-Amz-Credential', self._credential),
            ('X-Amz-Date', self.timestamp),
            ('X-Amz-Expires', self.expires),
        ]
        # The session token is only present for temporary credentials.
        if self.token:
            params.append(('X-Amz-Security-Token', self.token))
        params.append(('X-Amz-SignedHeaders', self._signed_headers))

        return '&'.join(
            self._uriencode(k) + '=' + self._uriencode(v) for k, v in params
        )

    @property
    def _signing_key(self):
        """
        Returns (bytes):
            An AWS Signature V4 signing key generated from the secret_key, date,
            region, service, and request type.
        """
        key = ('AWS4' + self.secret_key).encode('utf-8')
        # Each step keys the next HMAC with the previous digest.
        for part in (self.datestamp, self.region, self.service, 'aws4_request'):
            key = self._hmac(key, part)
        return key

    @property
    def _signing_str(self):
        """
        Returns (str):
            A string used to sign the AWS Signature V4 payload in the format:
                <Algorithm>\n
                <Timestamp>\n
                <Scope>\n
                <CanonicalRequestHash>
        """
        canonical_request_hash = self.hashfunc(
            self._canonical_request.encode('utf-8')
        ).hexdigest()
        return '\n'.join((
            self.algorithm,
            self.timestamp,
            self._scope,
            canonical_request_hash,
        ))

    def _uriencode(self, msg):
        """
        Arguments:
            msg (str): A string to URI-encode.

        Returns (str):
            The URI-encoded version of the provided msg, following the encoding
            rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode
        """
        return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS)

    def _hmac(self, key, msg):
        """
        Arguments:
            key (bytes): A key to use for the HMAC digest.
            msg (str): A value to include in the HMAC digest.

        Returns (bytes):
            An HMAC digest of the given key and msg.
        """
        return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest()

    def first_message(self):
        """
        Returns (bytes):
            An encoded JSON authentication payload that can be sent to the
            broker.
        """
        signature = hmac.new(
            self._signing_key,
            self._signing_str.encode('utf-8'),
            digestmod=self.hashfunc,
        ).hexdigest()
        msg = {
            'version': self.version,
            'host': self.host,
            'user-agent': 'kafka-python',
            'action': self.action,
            'x-amz-algorithm': self.algorithm,
            'x-amz-credential': self._credential,
            'x-amz-date': self.timestamp,
            'x-amz-signedheaders': self._signed_headers,
            'x-amz-expires': self.expires,
            'x-amz-signature': signature,
        }
        if self.token:
            msg['x-amz-security-token'] = self.token

        # Compact separators: no whitespace is emitted between JSON tokens.
        return json.dumps(msg, separators=(',', ':')).encode('utf-8')

0 commit comments

Comments
 (0)