Skip to content

Commit 31814b6

Browse files
committed
Introduce Async::Redis::Endpoint.
- Handles authentication and database selection.
1 parent 32aaf45 commit 31814b6

File tree

6 files changed

+318
-15
lines changed

6 files changed

+318
-15
lines changed

lib/async/redis/client.rb

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
require_relative 'context/pipeline'
1111
require_relative 'context/transaction'
1212
require_relative 'context/subscribe'
13-
require_relative 'protocol/resp2'
13+
require_relative 'endpoint'
1414

1515
require 'io/endpoint/host_endpoint'
1616
require 'async/pool/controller'
@@ -23,14 +23,10 @@ module Redis
2323
# Legacy.
2424
ServerError = ::Protocol::Redis::ServerError
2525

26-
def self.local_endpoint(port: 6379)
27-
::IO::Endpoint.tcp('localhost', port)
28-
end
29-
3026
class Client
3127
include ::Protocol::Redis::Methods
3228

33-
def initialize(endpoint = Redis.local_endpoint, protocol: Protocol::RESP2, **options)
29+
def initialize(endpoint = Endpoint.local, protocol: endpoint.protocol, **options)
3430
@endpoint = endpoint
3531
@protocol = protocol
3632

lib/async/redis/endpoint.rb

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
# frozen_string_literal: true
2+
3+
# Released under the MIT License.
4+
# Copyright, 2024, by Samuel Williams.
5+
6+
require 'io/endpoint'
7+
require 'io/endpoint/host_endpoint'
8+
require 'io/endpoint/ssl_endpoint'
9+
10+
require_relative 'protocol/resp2'
11+
require_relative 'protocol/authenticated'
12+
require_relative 'protocol/selected'
13+
14+
module Async
15+
module Redis
16+
def self.local_endpoint(**options)
17+
Endpoint.local(**options)
18+
end
19+
20+
# Represents a way to connect to a remote Redis server.
21+
class Endpoint < ::IO::Endpoint::Generic
22+
LOCALHOST = URI.parse("redis://localhost").freeze
23+
24+
def self.local(**options)
25+
self.new(LOCALHOST, **options)
26+
end
27+
28+
SCHEMES = {
29+
'redis' => URI::Generic,
30+
'rediss' => URI::Generic,
31+
}
32+
33+
def self.parse(string, endpoint = nil, **options)
34+
url = URI.parse(string).normalize
35+
36+
return self.new(url, endpoint, **options)
37+
end
38+
39+
# Construct an endpoint with a specified scheme, hostname, optional path, and options.
40+
#
41+
# @parameter scheme [String] The scheme to use, e.g. "redis" or "rediss".
42+
# @parameter hostname [String] The hostname to connect to (or bind to).
43+
# @parameter *options [Hash] Additional options, passed to {#initialize}.
44+
def self.for(scheme, hostname, credentials: nil, port: nil, database: nil, **options)
45+
uri_klass = SCHEMES.fetch(scheme.downcase) do
46+
raise ArgumentError, "Unsupported scheme: #{scheme.inspect}"
47+
end
48+
49+
if database
50+
path = "/#{database}"
51+
end
52+
53+
self.new(
54+
uri_klass.new(scheme, credentials&.join(":"), hostname, port, nil, path, nil, nil, nil).normalize,
55+
**options
56+
)
57+
end
58+
59+
# Coerce the given object into an endpoint.
60+
# @parameter url [String | Endpoint] The URL or endpoint to convert.
61+
def self.[](object)
62+
if object.is_a?(self)
63+
return object
64+
else
65+
self.parse(object.to_s)
66+
end
67+
end
68+
69+
# Create a new endpoint.
70+
#
71+
# @parameter url [URI] The URL to connect to.
72+
# @parameter endpoint [Endpoint] The underlying endpoint to use.
73+
# @parameter scheme [String] The scheme to use, e.g. "redis" or "rediss".
74+
# @parameter hostname [String] The hostname to connect to (or bind to), overrides the URL hostname (used for SNI).
75+
# @parameter port [Integer] The port to bind to, overrides the URL port.
76+
def initialize(url, endpoint = nil, **options)
77+
super(**options)
78+
79+
raise ArgumentError, "URL must be absolute (include scheme, host): #{url}" unless url.absolute?
80+
81+
@url = url
82+
83+
if endpoint
84+
@endpoint = self.build_endpoint(endpoint)
85+
else
86+
@endpoint = nil
87+
end
88+
end
89+
90+
def to_url
91+
url = @url.dup
92+
93+
unless default_port?
94+
url.port = self.port
95+
end
96+
97+
return url
98+
end
99+
100+
def to_s
101+
"\#<#{self.class} #{self.to_url} #{@options}>"
102+
end
103+
104+
def inspect
105+
"\#<#{self.class} #{self.to_url} #{@options.inspect}>"
106+
end
107+
108+
attr :url
109+
110+
def address
111+
endpoint.address
112+
end
113+
114+
def secure?
115+
['rediss'].include?(self.scheme)
116+
end
117+
118+
def protocol
119+
protocol = @options.fetch(:protocol, Protocol::RESP2)
120+
121+
if database = self.database
122+
protocol = Protocol::Selected.new(database, protocol)
123+
end
124+
125+
if credentials = self.credentials
126+
protocol = Protocol::Authenticated.new(credentials, protocol)
127+
end
128+
129+
return protocol
130+
end
131+
132+
def default_port
133+
6379
134+
end
135+
136+
def default_port?
137+
port == default_port
138+
end
139+
140+
def port
141+
@options[:port] || @url.port || default_port
142+
end
143+
144+
# The hostname is the server we are connecting to:
145+
def hostname
146+
@options[:hostname] || @url.hostname
147+
end
148+
149+
def scheme
150+
@options[:scheme] || @url.scheme
151+
end
152+
153+
def database
154+
@options[:database] || @url.path[1..-1].to_i
155+
end
156+
157+
def credentials
158+
@options[:credentials] || @url.userinfo&.split(":")
159+
end
160+
161+
def localhost?
162+
@url.hostname =~ /^(.*?\.)?localhost\.?$/
163+
end
164+
165+
# We don't try to validate peer certificates when talking to localhost because they would always be self-signed.
166+
def ssl_verify_mode
167+
if self.localhost?
168+
OpenSSL::SSL::VERIFY_NONE
169+
else
170+
OpenSSL::SSL::VERIFY_PEER
171+
end
172+
end
173+
174+
def ssl_context
175+
@options[:ssl_context] || OpenSSL::SSL::SSLContext.new.tap do |context|
176+
context.set_params(
177+
verify_mode: self.ssl_verify_mode
178+
)
179+
end
180+
end
181+
182+
def build_endpoint(endpoint = nil)
183+
endpoint ||= tcp_endpoint
184+
185+
if secure?
186+
# Wrap it in SSL:
187+
return ::IO::Endpoint::SSLEndpoint.new(endpoint,
188+
ssl_context: self.ssl_context,
189+
hostname: @url.hostname,
190+
timeout: self.timeout,
191+
)
192+
end
193+
194+
return endpoint
195+
end
196+
197+
def endpoint
198+
@endpoint ||= build_endpoint
199+
end
200+
201+
def endpoint=(endpoint)
202+
@endpoint = build_endpoint(endpoint)
203+
end
204+
205+
def bind(*arguments, &block)
206+
endpoint.bind(*arguments, &block)
207+
end
208+
209+
def connect(&block)
210+
endpoint.connect(&block)
211+
end
212+
213+
def each
214+
return to_enum unless block_given?
215+
216+
self.tcp_endpoint.each do |endpoint|
217+
yield self.class.new(@url, endpoint, **@options)
218+
end
219+
end
220+
221+
def key
222+
[@url, @options]
223+
end
224+
225+
def eql? other
226+
self.key.eql? other.key
227+
end
228+
229+
def hash
230+
self.key.hash
231+
end
232+
233+
protected
234+
235+
def tcp_options
236+
options = @options.dup
237+
238+
options.delete(:scheme)
239+
options.delete(:port)
240+
options.delete(:hostname)
241+
options.delete(:ssl_context)
242+
options.delete(:protocol)
243+
244+
return options
245+
end
246+
247+
def tcp_endpoint
248+
::IO::Endpoint.tcp(self.hostname, port, **tcp_options)
249+
end
250+
end
251+
end
252+
end

lib/async/redis/protocol/authenticated.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class AuthenticationError < StandardError
1818
#
1919
# @parameter credentials [Array] The credentials to use for authentication.
2020
# @parameter protocol [Object] The delegated protocol for connecting.
21-
def initialize(credentials, protocol: Async::Redis::Protocol::RESP2)
21+
def initialize(credentials, protocol = Async::Redis::Protocol::RESP2)
2222
@credentials = credentials
2323
@protocol = protocol
2424
end

lib/async/redis/protocol/selected.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class SelectionError < StandardError
1818
#
1919
# @parameter index [Integer] The database index to select.
2020
# @parameter protocol [Object] The delegated protocol for connecting.
21-
def initialize(index, protocol: Async::Redis::Protocol::RESP2)
21+
def initialize(index, protocol = Async::Redis::Protocol::RESP2)
2222
@index = index
2323
@protocol = protocol
2424
end

test/async/redis/disconnect.rb

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,30 @@
1010

1111
describe Async::Redis::Client do
1212
include Sus::Fixtures::Async::ReactorContext
13-
14-
let(:endpoint) {::IO::Endpoint.tcp('localhost', 5555)}
15-
13+
14+
# Intended to not be connected:
15+
let(:endpoint) {Async::Redis::Endpoint.local(port: 5555)}
16+
17+
before do
18+
@server_endpoint = ::IO::Endpoint.tcp("localhost").bound
19+
end
20+
21+
after do
22+
@server_endpoint&.close
23+
end
24+
1625
it "should raise error on unexpected disconnect" do
17-
server_task = reactor.async do
18-
endpoint.accept do |connection|
26+
server_task = Async do
27+
@server_endpoint.accept do |connection|
1928
connection.read(8)
2029
connection.close
2130
end
2231
end
23-
24-
client = Async::Redis::Client.new(endpoint)
32+
33+
client = Async::Redis::Client.new(
34+
@server_endpoint.local_address_endpoint,
35+
protocol: Async::Redis::Protocol::RESP2,
36+
)
2537

2638
expect do
2739
client.call("GET", "test")

test/async/redis/endpoint.rb

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# frozen_string_literal: true
2+
3+
# Released under the MIT License.
4+
# Copyright, 2024, by Samuel Williams.
5+
6+
require 'async/redis/client'
7+
require 'async/redis/protocol/authenticated'
8+
require 'sus/fixtures/async'
9+
10+
describe Async::Redis::Protocol::Authenticated do
11+
include Sus::Fixtures::Async::ReactorContext
12+
13+
let(:endpoint) {Async::Redis.local_endpoint}
14+
let(:credentials) {["testuser", "testpassword"]}
15+
let(:protocol) {subject.new(credentials)}
16+
let(:client) {Async::Redis::Client.new(endpoint, protocol: protocol)}
17+
18+
before do
19+
# Setup ACL user with limited permissions for testing.
20+
admin_client = Async::Redis::Client.new(endpoint)
21+
admin_client.call("ACL", "SETUSER", "testuser", "on", ">" + credentials[1], "+ping", "+auth")
22+
ensure
23+
admin_client.close
24+
end
25+
26+
after do
27+
# Cleanup ACL user after tests.
28+
admin_client = Async::Redis::Client.new(endpoint)
29+
admin_client.call("ACL", "DELUSER", "testuser")
30+
admin_client.close
31+
end
32+
33+
it "can authenticate and send allowed commands" do
34+
response = client.call("PING")
35+
expect(response).to be == "PONG"
36+
end
37+
38+
it "rejects commands not allowed by ACL" do
39+
expect do
40+
client.call("SET", "key", "value")
41+
end.to raise_exception(Protocol::Redis::ServerError, message: be =~ /NOPERM/)
42+
end
43+
end

0 commit comments

Comments
 (0)