Skip to content

Improve KafkaConnection with more tests #196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Aug 22, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 60 additions & 18 deletions kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,28 +67,38 @@ def __repr__(self):
###################

def _raise_connection_error(self):
self._dirty = True
# Cleanup socket if we have one
if self._sock:
self.close()

# And then raise
raise ConnectionError("Kafka @ {0}:{1} went away".format(self.host, self.port))

def _read_bytes(self, num_bytes):
bytes_left = num_bytes
responses = []

log.debug("About to read %d bytes from Kafka", num_bytes)
if self._dirty:

# Make sure we have a connection
if not self._sock:
self.reinit()

while bytes_left:

try:
data = self._sock.recv(min(bytes_left, 4096))

# Receiving empty string from recv signals
# that the socket is in error. we will never get
# more data from this socket
if data == '':
raise socket.error('Not enough data to read message -- did server kill socket?')

except socket.error:
log.exception('Unable to receive data from Kafka')
self._raise_connection_error()

if data == '':
log.error("Not enough data to read this response")
self._raise_connection_error()

bytes_left -= len(data)
log.debug("Read %d/%d bytes from Kafka", num_bytes - bytes_left, num_bytes)
responses.append(data)
Expand All @@ -102,26 +112,34 @@ def _read_bytes(self, num_bytes):
# TODO multiplex socket communication to allow for multi-threaded clients

def send(self, request_id, payload):
"Send a request to Kafka"
"""
Send a request to Kafka
param: request_id -- can be any int (used only for debug logging...)
param: payload -- an encoded kafka packet (see KafkaProtocol)
"""

log.debug("About to send %d bytes to Kafka, request %d" % (len(payload), request_id))

# Make sure we have a connection
if not self._sock:
self.reinit()

try:
if self._dirty:
self.reinit()
sent = self._sock.sendall(payload)
if sent is not None:
self._raise_connection_error()
self._sock.sendall(payload)
except socket.error:
log.exception('Unable to send payload to Kafka')
self._raise_connection_error()

def recv(self, request_id):
"""
Get a response from Kafka
Get a response packet from Kafka
param: request_id -- can be any int (only used for debug logging...)
returns encoded kafka packet response from server as type str
"""
log.debug("Reading response %d from Kafka" % request_id)

# Read the size off of the header
resp = self._read_bytes(4)

(size,) = struct.unpack('>i', resp)

# Read the remainder of the response
Expand All @@ -132,22 +150,46 @@ def copy(self):
"""
Create an inactive copy of the connection object
A reinit() has to be done on the copy before it can be used again
return a new KafkaConnection object
"""
c = copy.deepcopy(self)
c._sock = None
return c

def close(self):
"""
Close this connection
Shutdown and close the connection socket
"""
log.debug("Closing socket connection for %s:%d" % (self.host, self.port))
if self._sock:
# Call shutdown to be a good TCP client
# But expect an error if the socket has already been
# closed by the server
try:
self._sock.shutdown(socket.SHUT_RDWR)
except socket.error:
pass

# Closing the socket should always succeed
self._sock.close()
self._sock = None
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could also set self._dirty = None here

else:
log.debug("No socket found to close!")

def reinit(self):
"""
Re-initialize the socket connection
close current socket (if open)
and start a fresh connection
raise ConnectionError on error
"""
self.close()
self._sock = socket.create_connection((self.host, self.port), self.timeout)
self._dirty = False
log.debug("Reinitializing socket connection for %s:%d" % (self.host, self.port))

if self._sock:
self.close()

try:
self._sock = socket.create_connection((self.host, self.port), self.timeout)
except socket.error:
log.exception('Unable to connect to kafka broker at %s:%d' % (self.host, self.port))
self._raise_connection_error()
139 changes: 117 additions & 22 deletions test/test_conn.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,52 @@
import os
import random
import socket
import struct

import mock
import unittest2
import kafka.conn

from kafka.common import *
from kafka.conn import *

class ConnTest(unittest2.TestCase):
def setUp(self):
self.config = {
'host': 'localhost',
'port': 9090,
'request_id': 0,
'payload': 'test data',
'payload2': 'another packet'
}

# Mocking socket.create_connection will cause _sock to always be a
# MagicMock()
patcher = mock.patch('socket.create_connection', spec=True)
self.MockCreateConn = patcher.start()
self.addCleanup(patcher.stop)

# Also mock socket.sendall() to appear successful
socket.create_connection().sendall.return_value = None

# And mock socket.recv() to return two payloads, then '', then raise
# Note that this currently ignores the num_bytes parameter to sock.recv()
payload_size = len(self.config['payload'])
payload2_size = len(self.config['payload2'])
socket.create_connection().recv.side_effect = [
struct.pack('>i', payload_size),
struct.pack('>%ds' % payload_size, self.config['payload']),
struct.pack('>i', payload2_size),
struct.pack('>%ds' % payload2_size, self.config['payload2']),
''
]

# Create a connection object
self.conn = KafkaConnection(self.config['host'], self.config['port'])

# Reset any mock counts caused by __init__
socket.create_connection.reset_mock()

def test_collect_hosts__happy_path(self):
hosts = "localhost:1234,localhost"
results = kafka.conn.collect_hosts(hosts)
results = collect_hosts(hosts)

self.assertEqual(set(results), set([
('localhost', 1234),
Expand All @@ -20,7 +59,7 @@ def test_collect_hosts__string_list(self):
'localhost',
]

results = kafka.conn.collect_hosts(hosts)
results = collect_hosts(hosts)

self.assertEqual(set(results), set([
('localhost', 1234),
Expand All @@ -29,41 +68,97 @@ def test_collect_hosts__string_list(self):

def test_collect_hosts__with_spaces(self):
hosts = "localhost:1234, localhost"
results = kafka.conn.collect_hosts(hosts)
results = collect_hosts(hosts)

self.assertEqual(set(results), set([
('localhost', 1234),
('localhost', 9092),
]))

@unittest2.skip("Not Implemented")
def test_send(self):
pass
self.conn.send(self.config['request_id'], self.config['payload'])
self.conn._sock.sendall.assert_called_with(self.config['payload'])

def test_init_creates_socket_connection(self):
KafkaConnection(self.config['host'], self.config['port'])
socket.create_connection.assert_called_with((self.config['host'], self.config['port']), DEFAULT_SOCKET_TIMEOUT_SECONDS)

def test_init_failure_raises_connection_error(self):

def raise_error(*args):
raise socket.error

assert socket.create_connection is self.MockCreateConn
socket.create_connection.side_effect=raise_error
with self.assertRaises(ConnectionError):
KafkaConnection(self.config['host'], self.config['port'])

@unittest2.skip("Not Implemented")
def test_send__reconnects_on_dirty_conn(self):
pass

@unittest2.skip("Not Implemented")
# Dirty the connection
try:
self.conn._raise_connection_error()
except ConnectionError:
pass

# Now test that sending attempts to reconnect
self.assertEqual(socket.create_connection.call_count, 0)
self.conn.send(self.config['request_id'], self.config['payload'])
self.assertEqual(socket.create_connection.call_count, 1)

def test_send__failure_sets_dirty_connection(self):
pass

@unittest2.skip("Not Implemented")
def raise_error(*args):
raise socket.error

assert isinstance(self.conn._sock, mock.Mock)
self.conn._sock.sendall.side_effect=raise_error
try:
self.conn.send(self.config['request_id'], self.config['payload'])
except ConnectionError:
self.assertIsNone(self.conn._sock)

def test_recv(self):
pass

@unittest2.skip("Not Implemented")
self.assertEquals(self.conn.recv(self.config['request_id']), self.config['payload'])

def test_recv__reconnects_on_dirty_conn(self):
pass

@unittest2.skip("Not Implemented")
# Dirty the connection
try:
self.conn._raise_connection_error()
except ConnectionError:
pass

# Now test that recv'ing attempts to reconnect
self.assertEqual(socket.create_connection.call_count, 0)
self.conn.recv(self.config['request_id'])
self.assertEqual(socket.create_connection.call_count, 1)

def test_recv__failure_sets_dirty_connection(self):
pass

@unittest2.skip("Not Implemented")
def raise_error(*args):
raise socket.error

# test that recv'ing attempts to reconnect
assert isinstance(self.conn._sock, mock.Mock)
self.conn._sock.recv.side_effect=raise_error
try:
self.conn.recv(self.config['request_id'])
except ConnectionError:
self.assertIsNone(self.conn._sock)

def test_recv__doesnt_consume_extra_data_in_stream(self):
pass

@unittest2.skip("Not Implemented")
# Here just test that each call to recv will return a single payload
self.assertEquals(self.conn.recv(self.config['request_id']), self.config['payload'])
self.assertEquals(self.conn.recv(self.config['request_id']), self.config['payload2'])

def test_close__object_is_reusable(self):
pass

# test that sending to a closed connection
# will re-connect and send data to the socket
self.conn.close()
self.conn.send(self.config['request_id'], self.config['payload'])
self.assertEqual(socket.create_connection.call_count, 1)
self.conn._sock.sendall.assert_called_with(self.config['payload'])