Skip to content

Commit efeb037

Browse files
authored
Merge pull request #258 from keisku/support-ipv6
fix: Support IPv6
1 parent dd38fed commit efeb037

File tree

3 files changed

+71
-3
lines changed

3 files changed

+71
-3
lines changed

bmemcached/protocol.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from urllib.parse import SplitResult # type: ignore[import-not-found]
1010

1111
import zlib
12+
from ipaddress import ip_address
1213
from io import BytesIO
1314
import six
1415
from six import binary_type, text_type
@@ -144,9 +145,7 @@ def _open_connection(self):
144145

145146
try:
146147
if self.host:
147-
self.connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
148-
self.connection.settimeout(self.socket_timeout)
149-
self.connection.connect((self.host, self.port))
148+
self.connection = socket.create_connection((self.host, self.port), self.socket_timeout)
150149

151150
if self.tls_context:
152151
self.connection = self.tls_context.wrap_socket(
@@ -174,11 +173,38 @@ def split_host_port(cls, server):
174173
175174
Port defaults to 11211.
176175
176+
When using IPv6 with a specified port, the address must be enclosed in brackets.
177+
If the port is not specified, brackets are optional.
178+
177179
>>> split_host_port('127.0.0.1:11211')
178180
('127.0.0.1', 11211)
179181
>>> split_host_port('127.0.0.1')
180182
('127.0.0.1', 11211)
183+
>>> split_host_port('::1')
184+
('::1', 11211)
185+
>>> split_host_port('[::1]')
186+
('::1', 11211)
187+
>>> split_host_port('[::1]:11211')
188+
('::1', 11211)
181189
"""
190+
default_port = 11211
191+
192+
def is_ip_address(address):
193+
try:
194+
ip_address(address)
195+
return True
196+
except ValueError:
197+
return False
198+
199+
if is_ip_address(server):
200+
return server, default_port
201+
202+
if server.startswith('['):
203+
host, _, port = server[1:].partition(']')
204+
if not is_ip_address(host):
205+
raise ValueError('{} is not a valid IPv6 address'.format(server))
206+
return host, default_port if not port else int(port.lstrip(':'))
207+
182208
u = SplitResult("", server, "", "", "")
183209
return u.hostname, 11211 if u.port is None else u.port
184210

test/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,16 @@ def memcached_socket():
4141
yield p
4242
p.kill()
4343
p.wait()
44+
45+
46+
@pytest.yield_fixture(scope="session", autouse=True)
47+
def memcached_ipv6():
48+
p = subprocess.Popen(
49+
["memcached", "-l::1"],
50+
stdout=subprocess.PIPE,
51+
stderr=subprocess.PIPE,
52+
)
53+
time.sleep(0.1)
54+
yield p
55+
p.kill()
56+
p.wait()

test/test_server_parsing.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,38 @@ def testNoPortGiven(self):
2626
self.assertEqual(server.host, os.environ['MEMCACHED_HOST'])
2727
self.assertEqual(server.port, 11211)
2828

29+
def testIPv6(self):
30+
server = bmemcached.protocol.Protocol('[::1]')
31+
self.assertEqual(server.host, '::1')
32+
self.assertEqual(server.port, 11211)
33+
server = bmemcached.protocol.Protocol('::1')
34+
self.assertEqual(server.host, '::1')
35+
self.assertEqual(server.port, 11211)
36+
server = bmemcached.protocol.Protocol('[2001:db8::2]')
37+
self.assertEqual(server.host, '2001:db8::2')
38+
self.assertEqual(server.port, 11211)
39+
server = bmemcached.protocol.Protocol('2001:db8::2')
40+
self.assertEqual(server.host, '2001:db8::2')
41+
self.assertEqual(server.port, 11211)
42+
# Since `2001:db8::2:8080` is a valid IPv6 address,
43+
# it is ambiguous whether to split it into `2001:db8::2` and `8080`
44+
# or treat it as `2001:db8::2:8080`.
45+
# Therefore, it will be treated as `2001:db8::2:8080`.
46+
server = bmemcached.protocol.Protocol('2001:db8::2:8080')
47+
self.assertEqual(server.host, '2001:db8::2:8080')
48+
self.assertEqual(server.port, 11211)
49+
server = bmemcached.protocol.Protocol('[::1]:5000')
50+
self.assertEqual(server.host, '::1')
51+
self.assertEqual(server.port, 5000)
52+
server = bmemcached.protocol.Protocol('[2001:db8::2]:5000')
53+
self.assertEqual(server.host, '2001:db8::2')
54+
self.assertEqual(server.port, 5000)
55+
2956
def testInvalidPort(self):
3057
with self.assertRaises(ValueError):
3158
bmemcached.protocol.Protocol('{}:blah'.format(os.environ['MEMCACHED_HOST']))
59+
with self.assertRaises(ValueError):
60+
bmemcached.protocol.Protocol('[::1]:blah')
3261

3362
def testNonStandardPort(self):
3463
server = bmemcached.protocol.Protocol('{}:5000'.format(os.environ['MEMCACHED_HOST']))

0 commit comments

Comments
 (0)