diff --git a/.ameba.yml b/.ameba.yml new file mode 100644 index 0000000..f7e64c7 --- /dev/null +++ b/.ameba.yml @@ -0,0 +1,2 @@ +Excluded: + - test/ diff --git a/.gitignore b/.gitignore index cf69edc..d186f7a 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ *.tar.gz *.swp /builds/ +shard.override.yml diff --git a/CHANGELOG.md b/CHANGELOG.md index 809c0c4..89673f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [v2.0.0] - 2024-02-19 + +- Rewrite of the proxy where Channels are pooled rather than connections. When a client opens a channel it will get a channel on a shared upstream connection, the proxy will remap the channel numbers between the two. Many client connections can therefor share a single upstream connection. Upside is that way fewer connections are needed to the upstream server, downside is that if there's a misbehaving client, for which the server closes the connection, all channels for other clients on that shared connection will also be closed. + ## [v1.0.0] - 2024-02-19 - Nothing changed from v0.8.14 diff --git a/shard.lock b/shard.lock index d1b5452..434be48 100644 --- a/shard.lock +++ b/shard.lock @@ -2,17 +2,13 @@ version: 2.0 shards: ameba: git: https://github.com/crystal-ameba/ameba.git - version: 1.5.0 + version: 1.6.1 amq-protocol: git: https://github.com/cloudamqp/amq-protocol.cr.git - version: 1.1.4 + version: 1.1.14 amqp-client: git: https://github.com/cloudamqp/amqp-client.cr.git - version: 1.0.11 - - logger: - git: https://github.com/84codes/logger.cr.git - version: 1.0.2 + version: 1.2.1 diff --git a/shard.yml b/shard.yml index 09febe2..9a4d6dd 100644 --- a/shard.yml +++ b/shard.yml @@ -1,5 +1,5 @@ name: amqproxy -version: 1.0.0 +version: 2.0.0 authors: - CloudAMQP @@ -11,8 +11,6 @@ targets: dependencies: amq-protocol: github: cloudamqp/amq-protocol.cr - logger: - github: 84codes/logger.cr development_dependencies: amqp-client: diff --git a/spec/amqproxy_spec.cr b/spec/amqproxy_spec.cr index db58892..3679cf0 100644 --- a/spec/amqproxy_spec.cr +++ b/spec/amqproxy_spec.cr @@ -2,17 +2,17 @@ require "./spec_helper" describe AMQProxy::Server do it "keeps connections open" do - s = AMQProxy::Server.new("127.0.0.1", 5672, false, Logger::DEBUG) + s = AMQProxy::Server.new("127.0.0.1", 5672, false) begin spawn { s.listen("127.0.0.1", 5673) } Fiber.yield 10.times do AMQP::Client.start("amqp://localhost:5673") do |conn| - conn.channel + ch = conn.channel + ch.basic_publish "foobar", "amq.fanout", "" s.client_connections.should eq 1 s.upstream_connections.should eq 1 end - sleep 0.1 end s.client_connections.should eq 0 s.upstream_connections.should eq 1 @@ -22,7 +22,7 @@ describe AMQProxy::Server do end it "publish and consume works" do - server = AMQProxy::Server.new("127.0.0.1", 5672, false, Logger::DEBUG) + server = AMQProxy::Server.new("127.0.0.1", 5672, false) begin spawn { server.listen("127.0.0.1", 5673) } Fiber.yield @@ -38,18 +38,19 @@ describe AMQProxy::Server do queue = channel.queue(queue_name) queue.publish_confirm(message_payload) end - sleep 0.1 end + sleep 0.1 AMQP::Client.start("amqp://localhost:5673") do |conn| channel = conn.channel - channel.basic_consume(queue_name, block: true, tag: "AMQProxy specs") do |msg| + channel.basic_consume(queue_name, no_ack: false, tag: "AMQProxy specs") do |msg| body = msg.body_io.to_s if body == message_payload channel.basic_ack(msg.delivery_tag) num_received_messages += 1 end end + sleep 0.1 end num_received_messages.should eq num_messages_to_publish @@ -58,8 +59,30 @@ describe AMQProxy::Server do end end + it "a client can open all channels" do + s = AMQProxy::Server.new("127.0.0.1", 5672, false) + begin + spawn { s.listen("127.0.0.1", 5673) } + Fiber.yield + max = 4000 + AMQP::Client.start("amqp://localhost:5673?channel_max=#{max}") do |conn| + conn.channel_max.should eq max + conn.channel_max.times do + conn.channel + end + s.client_connections.should eq 1 + s.upstream_connections.should eq 2 + end + sleep 0.1 + s.client_connections.should eq 0 + s.upstream_connections.should eq 2 + ensure + s.stop_accepting_clients + end + end + it "can reconnect if upstream closes" do - s = AMQProxy::Server.new("127.0.0.1", 5672, false, Logger::DEBUG) + s = AMQProxy::Server.new("127.0.0.1", 5672, false) begin spawn { s.listen("127.0.0.1", 5673) } Fiber.yield @@ -83,7 +106,7 @@ describe AMQProxy::Server do it "responds to upstream heartbeats" do system("#{MAYBE_SUDO}rabbitmqctl eval 'application:set_env(rabbit, heartbeat, 1).' > /dev/null").should be_true - s = AMQProxy::Server.new("127.0.0.1", 5672, false, Logger::DEBUG) + s = AMQProxy::Server.new("127.0.0.1", 5672, false) begin spawn { s.listen("127.0.0.1", 5673) } Fiber.yield @@ -102,7 +125,7 @@ describe AMQProxy::Server do it "supports waiting for client connections on graceful shutdown" do started = Time.utc.to_unix - s = AMQProxy::Server.new("127.0.0.1", 5672, false, Logger::DEBUG, 5) + s = AMQProxy::Server.new("127.0.0.1", 5672, false, 5) wait_for_channel = Channel(Int32).new # channel used to wait for certain calls, to test certain behaviour spawn do s.listen("127.0.0.1", 5673) @@ -133,11 +156,11 @@ describe AMQProxy::Server do end wait_for_channel.receive.should eq 2 # wait 2 s.client_connections.should eq 2 - s.upstream_connections.should eq 2 + s.upstream_connections.should eq 1 spawn s.stop_accepting_clients wait_for_channel.receive.should eq 3 # wait 3 s.client_connections.should eq 1 - s.upstream_connections.should eq 2 # since connection stays open + s.upstream_connections.should eq 1 # since connection stays open spawn do begin AMQP::Client.start("amqp://localhost:5673") do |conn| @@ -153,7 +176,7 @@ describe AMQProxy::Server do end wait_for_channel.receive.should eq 4 # wait 4 s.client_connections.should eq 1 # since the new connection should not have worked - s.upstream_connections.should eq 2 # since connections stay open + s.upstream_connections.should eq 1 # since connections stay open wait_for_channel.receive.should eq 5 # wait 5 s.client_connections.should eq 0 # since now the server should be closed s.upstream_connections.should eq 1 diff --git a/spec/spec_helper.cr b/spec/spec_helper.cr index d0978b6..924f505 100644 --- a/spec/spec_helper.cr +++ b/spec/spec_helper.cr @@ -4,33 +4,3 @@ require "../src/amqproxy/version" require "amqp-client" MAYBE_SUDO = (ENV.has_key?("NO_SUDO") || `id -u` == "0\n") ? "" : "sudo " - -# Spec timeout borrowed from Crystal project: -# https://github.com/crystal-lang/crystal/blob/1.10.1/spec/support/mt_abort_timeout.cr - -private SPEC_TIMEOUT = 15.seconds - -Spec.around_each do |example| - done = Channel(Exception?).new - - spawn(same_thread: true) do - begin - example.run - rescue e - done.send(e) - else - done.send(nil) - end - end - - timeout = SPEC_TIMEOUT - - select - when res = done.receive - raise res if res - when timeout(timeout) - _it = example.example - ex = Spec::AssertionFailed.new("spec timed out after #{timeout}", _it.file, _it.line) - _it.parent.report(:fail, _it.description, _it.file, _it.line, timeout, ex) - end -end diff --git a/src/amqproxy.cr b/src/amqproxy.cr index ee113ad..691cbc7 100644 --- a/src/amqproxy.cr +++ b/src/amqproxy.cr @@ -3,12 +3,12 @@ require "./amqproxy/server" require "option_parser" require "uri" require "ini" -require "logger" +require "log" class AMQProxy::CLI @listen_address = ENV["LISTEN_ADDRESS"]? || "localhost" @listen_port = ENV["LISTEN_PORT"]? || 5673 - @log_level : Logger::Severity = Logger::INFO + @log_level : Log::Severity = Log::Severity::Info @idle_connection_timeout : Int32 = ENV.fetch("IDLE_CONNECTION_TIMEOUT", "5").to_i @upstream = ENV["AMQP_URL"]? @@ -19,7 +19,7 @@ class AMQProxy::CLI section.each do |key, value| case key when "upstream" then @upstream = value - when "log_level" then @log_level = Logger::Severity.parse(value) + when "log_level" then @log_level = Log::Severity.parse(value) when "idle_connection_timeout" then @idle_connection_timeout = value.to_i else raise "Unsupported config #{name}/#{key}" end @@ -29,7 +29,7 @@ class AMQProxy::CLI case key when "port" then @listen_port = value when "bind", "address" then @listen_address = value - when "log_level" then @log_level = Logger::Severity.parse(value) + when "log_level" then @log_level = Log::Severity.parse(value) else raise "Unsupported config #{name}/#{key}" end end @@ -50,7 +50,7 @@ class AMQProxy::CLI parser.on("-t IDLE_CONNECTION_TIMEOUT", "--idle-connection-timeout=SECONDS", "Maxiumum time in seconds an unused pooled connection stays open (default 5s)") do |v| @idle_connection_timeout = v.to_i end - parser.on("-d", "--debug", "Verbose logging") { @log_level = Logger::DEBUG } + parser.on("-d", "--debug", "Verbose logging") { @log_level = Log::Severity::Debug } parser.on("-c FILE", "--config=FILE", "Load config file") { |v| parse_config(v) } parser.on("-h", "--help", "Show this help") { puts parser.to_s; exit 0 } parser.on("-v", "--version", "Display version") { puts AMQProxy::VERSION.to_s; exit 0 } @@ -71,15 +71,23 @@ class AMQProxy::CLI port = u.port || default_port tls = u.scheme == "amqps" - server = AMQProxy::Server.new(u.host || "", port, tls, @log_level, @idle_connection_timeout) + log_backend = if ENV.has_key?("JOURNAL_STREAM") + Log::IOBackend.new(formatter: JournalLogFormat, dispatcher: ::Log::DirectDispatcher) + else + Log::IOBackend.new(formatter: StdoutLogFormat, dispatcher: ::Log::DirectDispatcher) + end + Log.setup_from_env(default_level: @log_level, backend: log_backend) + + server = AMQProxy::Server.new(u.host || "", port, tls, @idle_connection_timeout) first_shutdown = true shutdown = ->(_s : Signal) do if first_shutdown first_shutdown = false server.stop_accepting_clients - else server.disconnect_clients + else + server.close_sockets end end Signal::INT.trap &shutdown @@ -92,6 +100,28 @@ class AMQProxy::CLI sleep 0.2 end end + + struct JournalLogFormat < Log::StaticFormatter + def run + source + context(before: '[', after: ']') + string ' ' + message + exception + end + end + + struct StdoutLogFormat < Log::StaticFormatter + def run + timestamp + severity + source(before: ' ') + context(before: '[', after: ']') + string ' ' + message + exception + end + end end AMQProxy::CLI.new.run diff --git a/src/amqproxy/channel_pool.cr b/src/amqproxy/channel_pool.cr new file mode 100644 index 0000000..65336b6 --- /dev/null +++ b/src/amqproxy/channel_pool.cr @@ -0,0 +1,85 @@ +require "openssl" +require "log" +require "./records" +require "./upstream" + +module AMQProxy + class ChannelPool + Log = ::Log.for(self) + @lock = Mutex.new + @upstreams = Deque(Upstream).new + + def initialize(@host : String, @port : Int32, @tls_ctx : OpenSSL::SSL::Context::Client?, @credentials : Credentials, @idle_connection_timeout : Int32) + spawn shrink_pool_loop, name: "shrink pool loop" + end + + def get(downstream_channel : DownstreamChannel) : UpstreamChannel + at_channel_max = 0 + @lock.synchronize do + loop do + if upstream = @upstreams.shift? + next if upstream.closed? + begin + upstream_channel = upstream.open_channel_for(downstream_channel) + @upstreams.unshift(upstream) + return upstream_channel + rescue Upstream::ChannelMaxReached + @upstreams.push(upstream) + at_channel_max += 1 + add_upstream if at_channel_max == @upstreams.size + end + else + add_upstream + end + end + end + end + + private def add_upstream + upstream = Upstream.new(@host, @port, @tls_ctx, @credentials) + Log.info { "Adding upstream connection" } + @upstreams.unshift upstream + rescue ex : IO::Error + raise Upstream::Error.new ex.message, cause: ex + end + + def connections + @upstreams.size + end + + def close + Log.info { "Closing all upstream connections" } + @lock.synchronize do + while u = @upstreams.shift? + begin + u.close "AMQProxy shutdown" + rescue ex + Log.error { "Problem closing upstream: #{ex.inspect}" } + end + end + end + end + + private def shrink_pool_loop + loop do + sleep @idle_connection_timeout.seconds + @lock.synchronize do + (@upstreams.size - 1).times do # leave at least one connection + u = @upstreams.pop + if u.active_channels.zero? + begin + u.close "Pooled connection closed due to inactivity" + rescue ex + Log.error { "Problem closing upstream: #{ex.inspect}" } + end + elsif u.closed? + Log.error { "Removing closed upstream connection from pool" } + else + @upstreams.unshift u + end + end + end + end + end + end +end diff --git a/src/amqproxy/client.cr b/src/amqproxy/client.cr index 7ce89da..98135b7 100644 --- a/src/amqproxy/client.cr +++ b/src/amqproxy/client.cr @@ -1,77 +1,157 @@ require "socket" require "amq-protocol" require "./version" +require "./upstream" +require "./records" module AMQProxy - struct Client - @lock = Mutex.new + class Client + Log = ::Log.for(self) + getter credentials : Credentials + @channel_map = Hash(UInt16, UpstreamChannel).new + @outgoing_frames = Channel(AMQ::Protocol::Frame).new(128) + @frame_max : UInt32 + @channel_max : UInt16 + @heartbeat : UInt16 def initialize(@socket : TCPSocket) + set_socket_options(@socket) + tune_ok, @credentials = negotiate(@socket) + @frame_max = tune_ok.frame_max + @channel_max = tune_ok.channel_max + @heartbeat = tune_ok.heartbeat + spawn write_loop end - def read_loop(upstream : Upstream) - socket = @socket + # frames from enduser + def read_loop(channel_pool, socket = @socket) # ameba:disable Metrics/CyclomaticComplexity + Log.context.set(remote_address: socket.remote_address.to_s) + Log.debug { "Connected" } loop do - AMQ::Protocol::Frame.from_io(socket, IO::ByteFormat::NetworkEndian) do |frame| - case frame - when AMQ::Protocol::Frame::Heartbeat - socket.write_bytes frame, IO::ByteFormat::NetworkEndian - socket.flush - when AMQ::Protocol::Frame::Connection::CloseOk - return - else - if response_frame = upstream.write frame - socket.write_bytes response_frame, IO::ByteFormat::NetworkEndian - socket.flush - return if response_frame.is_a? AMQ::Protocol::Frame::Connection::CloseOk - end + case frame = AMQ::Protocol::Frame.from_io(socket, IO::ByteFormat::NetworkEndian) + when AMQ::Protocol::Frame::Heartbeat then write frame + when AMQ::Protocol::Frame::Connection::CloseOk then return + when AMQ::Protocol::Frame::Connection::Close + close_all_upstream_channels + write AMQ::Protocol::Frame::Connection::CloseOk.new + return + when AMQ::Protocol::Frame::Channel::Open + raise "Channel already opened" if @channel_map.has_key? frame.channel + upstream_channel = channel_pool.get(DownstreamChannel.new(self, frame.channel)) + @channel_map[frame.channel] = upstream_channel + write AMQ::Protocol::Frame::Channel::OpenOk.new(frame.channel) + when AMQ::Protocol::Frame::Channel::Close + if upstream_channel = @channel_map.delete(frame.channel) + upstream_channel.unassign + end + write AMQ::Protocol::Frame::Channel::CloseOk.new(frame.channel) + when AMQ::Protocol::Frame::Channel::CloseOk + # noop + when frame.channel.zero? + Log.error { "Unexpected connection frame: #{frame}" } + close_connection(540_u16, "NOT_IMPLEMENTED", frame) + else + src_channel = frame.channel + begin + upstream_channel = @channel_map[frame.channel] + upstream_channel.write(frame) + rescue ex : Upstream::WriteError + close_channel(src_channel) + rescue KeyError + close_connection(504_u16, "CHANNEL_ERROR - Channel #{frame.channel} not open", frame) end end end - rescue ex : Upstream::WriteError - upstream_disconnected rescue ex : IO::EOFError - raise Error.new("Client disconnected", ex) unless @socket.closed? - rescue ex - raise ReadError.new "Client read error", ex + Log.debug { "Disconnected" } + rescue ex : IO::Error + Log.error(exception: ex) { "IO error" } unless socket.closed? + rescue ex : Upstream::AccessError + Log.error { "Access refused, reason: #{ex.message}" } + close_connection(403_u16, ex.message || "ACCESS_REFUSED") + rescue ex : Upstream::Error + Log.error(exception: ex) { "Upstream error" } + close_connection(503_u16, "UPSTREAM_ERROR - #{ex.message}") + else + Log.debug { "Disconnected" } + ensure + @outgoing_frames.close + close_all_upstream_channels + end + + private def write_loop(socket = @socket) + while frame = @outgoing_frames.receive? + socket.write_bytes frame, IO::ByteFormat::NetworkEndian + socket.flush unless expect_more_publish_frames?(frame) + + break if frame.is_a? AMQ::Protocol::Frame::Connection::CloseOk + end + rescue ex : IO::Error + raise ex unless socket.closed? ensure - @socket.close rescue nil + @outgoing_frames.close + socket.close rescue nil + close_all_upstream_channels end - # Send frame to client + # Send frame to client, channel id should already be remapped by the caller def write(frame : AMQ::Protocol::Frame) - @lock.synchronize do - socket = @socket - return if socket.closed? - frame.to_io(socket, IO::ByteFormat::NetworkEndian) - socket.flush - case frame - when AMQ::Protocol::Frame::Connection::CloseOk - socket.close - end + @outgoing_frames.send frame + end + + def close_connection(code, text, frame = nil) + case frame + when AMQ::Protocol::Frame::Method + write AMQ::Protocol::Frame::Connection::Close.new(code, text, frame.class_id, frame.method_id) + else + write AMQ::Protocol::Frame::Connection::Close.new(code, text, 0_u16, 0_u16) end - rescue ex : Socket::Error - raise WriteError.new "Error writing to client", ex end - def upstream_disconnected - write AMQ::Protocol::Frame::Connection::Close.new(0_u16, - "UPSTREAM_ERROR", - 0_u16, 0_u16) - rescue WriteError + def close_channel(id) + write AMQ::Protocol::Frame::Channel::Close.new(id, 500_u16, "UPSTREAM_DISCONNECTED", 0_u16, 0_u16) + end + + private def close_all_upstream_channels + @channel_map.each_value do |upstream_channel| + upstream_channel.unassign + rescue Upstream::WriteError + next # Nothing to do + end + @channel_map.clear + end + + private def expect_more_publish_frames?(frame) : Bool + case frame + when AMQ::Protocol::Frame::Basic::Publish then true + when AMQ::Protocol::Frame::Header then frame.body_size != 0 + when AMQ::Protocol::Frame::Body then frame.bytesize == @frame_max + else false + end end def close write AMQ::Protocol::Frame::Connection::Close.new(0_u16, "AMQProxy shutdown", 0_u16, 0_u16) + # @socket.read_timeout = 1.seconds end + # Close the outgoing frames channel which will let write_loop close the socket def close_socket - @socket.close rescue nil + @outgoing_frames.close end - def self.negotiate(socket) + private def set_socket_options(socket = @socket) + socket.sync = false + socket.keepalive = true + socket.tcp_nodelay = true + socket.tcp_keepalive_idle = 60 + socket.tcp_keepalive_count = 3 + socket.tcp_keepalive_interval = 10 + end + + private def negotiate(socket = @socket) proto = uninitialized UInt8[8] socket.read_fully(proto.to_slice) @@ -82,69 +162,63 @@ module AMQProxy raise IO::EOFError.new("Invalid protocol start") end - props = AMQ::Protocol::Table.new({ - product: "AMQProxy", - version: VERSION, - capabilities: { - consumer_priorities: true, - exchange_exchange_bindings: true, - "connection.blocked": true, - authentication_failure_close: true, - per_consumer_qos: true, - "basic.nack": true, - direct_reply_to: true, - publisher_confirms: true, - consumer_cancel_notify: true, - }, - }) - start = AMQ::Protocol::Frame::Connection::Start.new(server_properties: props) + start = AMQ::Protocol::Frame::Connection::Start.new(server_properties: ServerProperties) start.to_io(socket, IO::ByteFormat::NetworkEndian) socket.flush user = password = "" - AMQ::Protocol::Frame.from_io(socket, IO::ByteFormat::NetworkEndian) do |frame| - start_ok = frame.as(AMQ::Protocol::Frame::Connection::StartOk) - case start_ok.mechanism - when "PLAIN" - resp = start_ok.response - if i = resp.index('\u0000', 1) - user = resp[1...i] - password = resp[(i + 1)..-1] - else - raise "Invalid authentication information encoding" - end - when "AMQPLAIN" - io = IO::Memory.new(start_ok.response) - tbl = AMQ::Protocol::Table.from_io(io, IO::ByteFormat::NetworkEndian, - start_ok.response.size.to_u32) - user = tbl["LOGIN"].as(String) - password = tbl["PASSWORD"].as(String) - else raise "Unsupported authentication mechanism: #{start_ok.mechanism}" + start_ok = AMQ::Protocol::Frame.from_io(socket).as(AMQ::Protocol::Frame::Connection::StartOk) + case start_ok.mechanism + when "PLAIN" + resp = start_ok.response + if i = resp.index('\u0000', 1) + user = resp[1...i] + password = resp[(i + 1)..-1] + else + raise "Invalid authentication information encoding" end + when "AMQPLAIN" + io = IO::Memory.new(start_ok.response) + tbl = AMQ::Protocol::Table.from_io(io, IO::ByteFormat::NetworkEndian, start_ok.response.size.to_u32) + user = tbl["LOGIN"].as(String) + password = tbl["PASSWORD"].as(String) + else raise "Unsupported authentication mechanism: #{start_ok.mechanism}" end - tune = AMQ::Protocol::Frame::Connection::Tune.new(frame_max: 131072_u32, channel_max: 0_u16, heartbeat: 0_u16) + tune = AMQ::Protocol::Frame::Connection::Tune.new(frame_max: 131072_u32, channel_max: UInt16::MAX, heartbeat: 0_u16) tune.to_io(socket, IO::ByteFormat::NetworkEndian) socket.flush - AMQ::Protocol::Frame.from_io socket, IO::ByteFormat::NetworkEndian do |_tune_ok| - end + tune_ok = AMQ::Protocol::Frame.from_io(socket).as(AMQ::Protocol::Frame::Connection::TuneOk) - vhost = "" - AMQ::Protocol::Frame.from_io(socket, IO::ByteFormat::NetworkEndian) do |frame| - open = frame.as(AMQ::Protocol::Frame::Connection::Open) - vhost = open.vhost - end + open = AMQ::Protocol::Frame.from_io(socket).as(AMQ::Protocol::Frame::Connection::Open) + vhost = open.vhost open_ok = AMQ::Protocol::Frame::Connection::OpenOk.new open_ok.to_io(socket, IO::ByteFormat::NetworkEndian) socket.flush - {vhost, user, password} + {tune_ok, Credentials.new(user, password, vhost)} rescue ex raise NegotiationError.new "Client negotiation failed", ex end + ServerProperties = AMQ::Protocol::Table.new({ + product: "AMQProxy", + version: VERSION, + capabilities: { + consumer_priorities: true, + exchange_exchange_bindings: true, + "connection.blocked": false, + authentication_failure_close: true, + per_consumer_qos: true, + "basic.nack": true, + direct_reply_to: true, + publisher_confirms: true, + consumer_cancel_notify: true, + }, + }) + class Error < Exception; end class ReadError < Error; end diff --git a/src/amqproxy/pool.cr b/src/amqproxy/pool.cr deleted file mode 100644 index 1ec9476..0000000 --- a/src/amqproxy/pool.cr +++ /dev/null @@ -1,88 +0,0 @@ -require "openssl" - -module AMQProxy - class Pool - getter :size - @tls_ctx : OpenSSL::SSL::Context::Client? - - def initialize(@host : String, @port : Int32, tls : Bool, @log : Logger, @idle_connection_timeout : Int32) - @pools = Hash(Tuple(String, String, String), Deque(Upstream)).new do |h, k| - h[k] = Deque(Upstream).new - end - @lock = Mutex.new - @size = 0 - @tls_ctx = OpenSSL::SSL::Context::Client.new if tls - spawn shrink_pool_loop, name: "shrink pool loop" - end - - def borrow(user : String, password : String, vhost : String, client : Client, & : Upstream -> _) - u = @lock.synchronize do - c = @pools[{user, password, vhost}].pop? - if c.nil? || c.closed? - c = Upstream.new(@host, @port, @tls_ctx, @log).connect(user, password, vhost) - @size += 1 - end - c.current_client = client - c - end - - yield u - ensure - @lock.synchronize do - if u.nil? - @size -= 1 - @log.error "Upstream connection could not be established" - elsif u.closed? - @size -= 1 - @log.error "Upstream connection closed when returned" - else - u.client_disconnected - u.last_used = Time.monotonic - @pools[{user, password, vhost}].push u - end - end - end - - def close - @lock.synchronize do - @pools.each_value do |q| - while u = q.shift? - begin - u.close "AMQProxy shutdown" - rescue ex - @log.error "Problem closing upstream: #{ex.inspect}" - end - end - end - @size = 0 - end - end - - private def shrink_pool_loop - loop do - sleep 5.seconds - @lock.synchronize do - max_connection_age = Time.monotonic - @idle_connection_timeout.seconds - @pools.each_value do |q| - q.size.times do - u = q.shift - if u.last_used < max_connection_age - @size -= 1 - begin - u.close "Pooled connection closed due to inactivity" - rescue ex - @log.error "Problem closing upstream: #{ex.inspect}" - end - elsif u.closed? - @size -= 1 - @log.error "Removing closed upstream connection from pool" - else - q.push u - end - end - end - end - end - end - end -end diff --git a/src/amqproxy/records.cr b/src/amqproxy/records.cr new file mode 100644 index 0000000..1dec6c2 --- /dev/null +++ b/src/amqproxy/records.cr @@ -0,0 +1,37 @@ +require "./upstream" +require "./client" + +module AMQProxy + record UpstreamChannel, upstream : Upstream, channel : UInt16 do + def write(frame) + frame.channel = @channel + @upstream.write frame + end + + def unassign + @upstream.unassign_channel(@channel) + end + end + + record DownstreamChannel, client : Client, channel : UInt16 do + def write(frame) + frame.channel = @channel + @client.write(frame) + end + + def close + @client.close_channel(@channel) + end + end + + record Credentials, user : String, password : String, vhost : String +end + +# Be able to overwrite channel id +module AMQ + module Protocol + abstract struct Frame + setter channel + end + end +end diff --git a/src/amqproxy/server.cr b/src/amqproxy/server.cr index 004dbdc..b588257 100644 --- a/src/amqproxy/server.cr +++ b/src/amqproxy/server.cr @@ -1,32 +1,22 @@ require "socket" -require "logger" +require "log" require "amq-protocol" -require "./pool" +require "./channel_pool" require "./client" require "./upstream" module AMQProxy class Server - def initialize(upstream_host, upstream_port, upstream_tls, log_level = Logger::INFO, idle_connection_timeout = 5) - @log = Logger.new(STDOUT) - @log.level = log_level - journald = - {% if flag?(:unix) %} - if journal_stream = ENV.fetch("JOURNAL_STREAM", nil) - stdout_stat = STDOUT.info.@stat - journal_stream == "#{stdout_stat.st_dev}:#{stdout_stat.st_ino}" - end - {% else %} - false - {% end %} - @log.formatter = Logger::Formatter.new do |_severity, datetime, _progname, message, io| - io << datetime << ": " unless journald - io << message + Log = ::Log.for(self) + @clients_lock = Mutex.new + @clients = Array(Client).new + + def initialize(upstream_host, upstream_port, upstream_tls, idle_connection_timeout = 5) + tls_ctx = OpenSSL::SSL::Context::Client.new if upstream_tls + @channel_pools = Hash(Credentials, ChannelPool).new do |hash, credentials| + hash[credentials] = ChannelPool.new(upstream_host, upstream_port, tls_ctx, credentials, idle_connection_timeout) end - @clients_lock = Mutex.new - @clients = Array(Client).new - @pool = Pool.new(upstream_host, upstream_port, upstream_tls, @log, idle_connection_timeout) - @log.info "Proxy upstream: #{upstream_host}:#{upstream_port} #{upstream_tls ? "TLS" : ""}" + Log.info { "Proxy upstream: #{upstream_host}:#{upstream_port} #{upstream_tls ? "TLS" : ""}" } end def client_connections @@ -34,65 +24,46 @@ module AMQProxy end def upstream_connections - @pool.size + @channel_pools.each_value.sum &.connections end def listen(address, port) - @socket = socket = TCPServer.new(address, port) - @log.info "Proxy listening on #{socket.local_address}" - while client = socket.accept? - addr = client.remote_address - spawn handle_connection(client, addr), name: "handle connection #{addr}" + @server = server = TCPServer.new(address, port) + Log.info { "Proxy listening on #{server.local_address}" } + while socket = server.accept? + addr = socket.remote_address + spawn handle_connection(socket, addr), name: "Client#read_loop #{addr}" end - @log.info "Proxy stopping accepting connections" + Log.info { "Proxy stopping accepting connections" } end def stop_accepting_clients - @socket.try &.close + @server.try &.close end def disconnect_clients + Log.info { "Disconnecting clients" } @clients_lock.synchronize do @clients.each &.close # send Connection#Close frames end - sleep 1 # wait for clients to disconnect voluntarily + end + + def close_sockets + Log.info { "Closing client sockets" } @clients_lock.synchronize do @clients.each &.close_socket # close sockets forcefully end end private def handle_connection(socket, remote_address) - socket.sync = false - socket.keepalive = true - socket.tcp_nodelay = true - socket.tcp_keepalive_idle = 60 - socket.tcp_keepalive_count = 3 - socket.tcp_keepalive_interval = 10 - @log.debug { "Client connected: #{remote_address}" } - vhost, user, password = Client.negotiate(socket) c = Client.new(socket) active_client(c) do - @pool.borrow(user, password, vhost, c) do |u| - # print "\r#{@clients.size} clients\t\t #{@pool.size} upstreams" - c.read_loop(u) - end - rescue ex : Upstream::AccessError - @log.error { "Access refused for user '#{user}' to vhost '#{vhost}', reason: #{ex.message}" } - close = AMQ::Protocol::Frame::Connection::Close.new(403_u16, "ACCESS_REFUSED - #{ex.message}", 0_u16, 0_u16) - close.to_io socket, IO::ByteFormat::NetworkEndian - socket.flush - rescue ex : Upstream::Error - @log.error { "Upstream error for user '#{user}' to vhost '#{vhost}': #{ex.inspect} (cause: #{ex.cause.inspect})" } - close = AMQ::Protocol::Frame::Connection::Close.new(403_u16, "UPSTREAM_ERROR", 0_u16, 0_u16) - close.to_io socket, IO::ByteFormat::NetworkEndian - socket.flush + channel_pool = @channel_pools[c.credentials] + c.read_loop(channel_pool) end - rescue ex : Client::Error - @log.debug { "Client disconnected: #{remote_address}: #{ex.inspect}" } - ensure - @log.debug { "Client disconnected: #{remote_address}" } - socket.close rescue nil - # print "\r#{@clients.size} clients\t\t #{@pool.size} upstreams" + rescue ex # only raise from constructor, when negotating + Log.debug { "Client connection failure (#{remote_address}) #{ex.inspect}" } + socket.close end private def active_client(client, &) diff --git a/src/amqproxy/upstream.cr b/src/amqproxy/upstream.cr index a588faa..acb1824 100644 --- a/src/amqproxy/upstream.cr +++ b/src/amqproxy/upstream.cr @@ -1,17 +1,20 @@ require "socket" require "openssl" +require "log" require "./client" +require "./channel_pool" module AMQProxy class Upstream - property last_used = Time.monotonic - setter current_client : Client? + Log = ::Log.for(self) @socket : IO - @open_channels = Set(UInt16).new @unsafe_channels = Set(UInt16).new + @channels = Hash(UInt16, DownstreamChannel?).new + @channels_lock = Mutex.new + @channel_max : UInt16 @lock = Mutex.new - def initialize(@host : String, @port : Int32, @tls_ctx : OpenSSL::SSL::Context::Client?, @log : Logger) + def initialize(@host : String, @port : Int32, @tls_ctx : OpenSSL::SSL::Context::Client?, credentials) tcp_socket = TCPSocket.new(@host, @port) tcp_socket.sync = false tcp_socket.keepalive = true @@ -21,106 +24,139 @@ module AMQProxy tcp_socket.tcp_nodelay = true @socket = if tls_ctx = @tls_ctx - OpenSSL::SSL::Socket::Client.new(tcp_socket, tls_ctx, hostname: @host).tap do |c| - c.sync_close = true - end + tls_socket = OpenSSL::SSL::Socket::Client.new(tcp_socket, tls_ctx, hostname: @host) + tls_socket.sync_close = true + tls_socket else tcp_socket end + @channel_max = start(credentials) + spawn read_loop(@socket, tcp_socket.remote_address.to_s) end - def connect(user : String, password : String, vhost : String) - start(user, password, vhost) - spawn read_loop, name: "upstream read loop #{@host}:#{@port}" - self - rescue ex : IO::Error | OpenSSL::SSL::Error - raise Error.new "Cannot establish connection to upstream", ex + def open_channel_for(downstream_channel : DownstreamChannel) : UpstreamChannel + @channels_lock.synchronize do + 1_u16.upto(@channel_max) do |i| + if @channels.has_key?(i) + if @channels[i].nil? + @channels[i] = downstream_channel + return UpstreamChannel.new(self, i) # reuse + else + next # in use + end + else + @channels[i] = downstream_channel + send AMQ::Protocol::Frame::Channel::Open.new(i) + return UpstreamChannel.new(self, i) + end + end + raise ChannelMaxReached.new + end + end + + def unassign_channel(channel : UInt16) + @channels_lock.synchronize do + if @unsafe_channels.delete channel + send AMQ::Protocol::Frame::Channel::Close.new(channel, 0u16, "", 0u16, 0u16) + @channels.delete channel + else + @channels[channel] = nil # keep for reuse + end + end + end + + def channels + @channels.size + end + + def active_channels + @channels.count { |_, v| !v.nil? } end # Frames from upstream (to client) - def read_loop # ameba:disable Metrics/CyclomaticComplexity - socket = @socket + private def read_loop(socket, remote_address : String) # ameba:disable Metrics/CyclomaticComplexity + Log.context.set(remote_address: remote_address) loop do - AMQ::Protocol::Frame.from_io(socket, IO::ByteFormat::NetworkEndian) do |frame| - case frame - when AMQ::Protocol::Frame::Channel::OpenOk - @open_channels.add(frame.channel) - when AMQ::Protocol::Frame::Channel::Close, - AMQ::Protocol::Frame::Channel::CloseOk - @open_channels.delete(frame.channel) - @unsafe_channels.delete(frame.channel) - when AMQ::Protocol::Frame::Connection::CloseOk - return - when AMQ::Protocol::Frame::Heartbeat - write frame - next + case frame = AMQ::Protocol::Frame.from_io(socket, IO::ByteFormat::NetworkEndian) + when AMQ::Protocol::Frame::Heartbeat then send frame + when AMQ::Protocol::Frame::Connection::Close + close_all_client_channels + begin + send AMQ::Protocol::Frame::Connection::CloseOk.new + rescue WriteError end - if client = @current_client - begin - client.write(frame) - rescue ex - @log.error "#{frame.inspect} could not be sent to client: #{ex.inspect}" - client.close_socket # close the socket of the client so that the client's read_loop exits - client_disconnected - end - elsif !frame.is_a? AMQ::Protocol::Frame::Channel::CloseOk - @log.error "Receiving #{frame.inspect} but no client to delivery to" - if body = frame.as? AMQ::Protocol::Frame::Body - body.body.skip(body.body_size) + return + when AMQ::Protocol::Frame::Connection::CloseOk then return + when AMQ::Protocol::Frame::Channel::OpenOk # we assume it always succeeds + when AMQ::Protocol::Frame::Channel::Close # when upstream server requested a channel close + @channels_lock.synchronize do + @unsafe_channels.delete(frame.channel) + if downstream_channel = @channels.delete(frame.channel) + downstream_channel.write frame end end - if frame.is_a? AMQ::Protocol::Frame::Connection::Close - @log.error "Upstream closed connection: #{frame.reply_text}" - begin - write AMQ::Protocol::Frame::Connection::CloseOk.new - rescue ex : WriteError - @log.error "Error writing CloseOk to upstream: #{ex.inspect}" - end - return + when AMQ::Protocol::Frame::Channel::CloseOk # when channel pool requested channel close + else + if downstream_channel = @channels[frame.channel]? + downstream_channel.write(frame) + else + Log.debug { "Frame for unmapped channel from upstream: #{frame}" } + send AMQ::Protocol::Frame::Channel::Close.new(frame.channel, 500_u16, + "DOWNSTREAM_DISCONNECTED", 0_u16, 0_u16) end end end rescue ex : IO::Error | OpenSSL::SSL::Error - @log.error "Error reading from upstream: #{ex.inspect_with_backtrace}" unless @socket.closed? + Log.error(exception: ex) { "Error reading from upstream" } unless socket.closed? ensure - @socket.close unless @socket.closed? - @current_client.try &.close_socket + socket.close rescue nil + close_all_client_channels + end + + def closed? + @socket.closed? end - SAFE_BASIC_METHODS = {40, 10} # qos and publish + private def close_all_client_channels + Log.debug { "Closing all client channels for closed upstream" } + @channels_lock.synchronize do + cnt = 0 + @channels.each_value do |downstream_channel| + if dch = downstream_channel + dch.close + cnt += 1 + end + end + Log.debug { "Upstream connection closed, closing #{cnt} client channels" } unless cnt.zero? + @channels.clear + end + end - # Send frames to upstream (often from the client) - def write(frame : AMQ::Protocol::Frame) + # Forward frames from client to upstream + def write(frame : AMQ::Protocol::Frame) : Nil case frame + when AMQ::Protocol::Frame::Basic::Publish, + AMQ::Protocol::Frame::Basic::Qos when AMQ::Protocol::Frame::Basic::Get - unless frame.no_ack - @unsafe_channels.add(frame.channel) - end - when AMQ::Protocol::Frame::Basic - unless SAFE_BASIC_METHODS.includes? frame.method_id - @unsafe_channels.add(frame.channel) - end - when AMQ::Protocol::Frame::Confirm + @unsafe_channels.add(frame.channel) unless frame.no_ack + when AMQ::Protocol::Frame::Basic, + AMQ::Protocol::Frame::Confirm, + AMQ::Protocol::Frame::Tx @unsafe_channels.add(frame.channel) - when AMQ::Protocol::Frame::Tx - @unsafe_channels.add(frame.channel) - when AMQ::Protocol::Frame::Connection::Close - return AMQ::Protocol::Frame::Connection::CloseOk.new - when AMQ::Protocol::Frame::Channel::Open - if @open_channels.includes? frame.channel - return AMQ::Protocol::Frame::Channel::OpenOk.new(frame.channel) - end - when AMQ::Protocol::Frame::Channel::Close - unless @unsafe_channels.includes? frame.channel - return AMQ::Protocol::Frame::Channel::CloseOk.new(frame.channel) - end + when AMQ::Protocol::Frame::Connection + raise "Connection frames should not be sent through here: #{frame}" + when AMQ::Protocol::Frame::Channel + raise "Channel frames should not be sent through here: #{frame}" end + send frame + end + + private def send(frame : AMQ::Protocol::Frame) : Nil @lock.synchronize do @socket.write_bytes frame, IO::ByteFormat::NetworkEndian @socket.flush - nil rescue ex : IO::Error | OpenSSL::SSL::Error - @socket.close + @socket.close rescue nil raise WriteError.new "Error writing to upstream", ex end end @@ -140,69 +176,44 @@ module AMQProxy @socket.closed? end - def client_disconnected - @current_client = nil - return if closed? - @lock.synchronize do - @open_channels.each do |ch| - if @unsafe_channels.includes? ch - close = AMQ::Protocol::Frame::Channel::Close.new(ch, 200_u16, "Client disconnected", 0_u16, 0_u16) - close.to_io @socket, IO::ByteFormat::NetworkEndian - @socket.flush - end - end - end - end - - private def start(user, password, vhost) + private def start(credentials) : UInt16 @socket.write AMQ::Protocol::PROTOCOL_START_0_9_1.to_slice @socket.flush # assert correct frame type - AMQ::Protocol::Frame.from_io(@socket, IO::ByteFormat::NetworkEndian) { |f| f.as(AMQ::Protocol::Frame::Connection::Start) } - - props = AMQ::Protocol::Table.new({ - connection_name: "AMQProxy #{VERSION}", - product: "AMQProxy", - version: VERSION, - capabilities: { - consumer_priorities: true, - exchange_exchange_bindings: true, - "connection.blocked": true, - authentication_failure_close: true, - per_consumer_qos: true, - "basic.nack": true, - direct_reply_to: true, - publisher_confirms: true, - consumer_cancel_notify: true, - }, - }) - start_ok = AMQ::Protocol::Frame::Connection::StartOk.new(response: "\u0000#{user}\u0000#{password}", - client_properties: props, mechanism: "PLAIN", locale: "en_US") - start_ok.to_io @socket, IO::ByteFormat::NetworkEndian - @socket.flush + AMQ::Protocol::Frame.from_io(@socket).as(AMQ::Protocol::Frame::Connection::Start) - tune = AMQ::Protocol::Frame.from_io(@socket, IO::ByteFormat::NetworkEndian) { |f| f.as(AMQ::Protocol::Frame::Connection::Tune) } - tune_ok = AMQ::Protocol::Frame::Connection::TuneOk.new(tune.channel_max, tune.frame_max, tune.heartbeat) - tune_ok.to_io @socket, IO::ByteFormat::NetworkEndian + response = "\u0000#{credentials.user}\u0000#{credentials.password}" + start_ok = AMQ::Protocol::Frame::Connection::StartOk.new(response: response, client_properties: ClientProperties, mechanism: "PLAIN", locale: "en_US") + @socket.write_bytes start_ok, IO::ByteFormat::NetworkEndian @socket.flush - open = AMQ::Protocol::Frame::Connection::Open.new(vhost: vhost) - open.to_io @socket, IO::ByteFormat::NetworkEndian + case tune = AMQ::Protocol::Frame.from_io(@socket) + when AMQ::Protocol::Frame::Connection::Tune + channel_max = tune.channel_max.zero? ? UInt16::MAX : tune.channel_max + tune_ok = AMQ::Protocol::Frame::Connection::TuneOk.new(channel_max, tune.frame_max, tune.heartbeat) + @socket.write_bytes tune_ok, IO::ByteFormat::NetworkEndian + @socket.flush + when AMQ::Protocol::Frame::Connection::Close + send_close_ok + raise AccessError.new tune.reply_text + else + raise "Unexpected frame on connection to upstream: #{tune}" + end + + open = AMQ::Protocol::Frame::Connection::Open.new(vhost: credentials.vhost) + @socket.write_bytes open, IO::ByteFormat::NetworkEndian @socket.flush - AMQ::Protocol::Frame.from_io(@socket, IO::ByteFormat::NetworkEndian) do |f| - case f - when AMQ::Protocol::Frame::Connection::Close - close_ok = AMQ::Protocol::Frame::Connection::CloseOk.new - close_ok.to_io @socket, IO::ByteFormat::NetworkEndian - @socket.flush - @socket.close - raise AccessError.new f.reply_text - when AMQ::Protocol::Frame::Connection::OpenOk - true - end + case f = AMQ::Protocol::Frame.from_io(@socket, IO::ByteFormat::NetworkEndian) + when AMQ::Protocol::Frame::Connection::OpenOk + when AMQ::Protocol::Frame::Connection::Close + send_close_ok + raise AccessError.new f.reply_text + else + raise "Unexpected frame on connection to upstream: #{f}" end + channel_max rescue ex : AccessError raise ex rescue ex @@ -210,10 +221,35 @@ module AMQProxy raise Error.new ex.message, cause: ex end + private def send_close_ok + @socket.write_bytes AMQ::Protocol::Frame::Connection::CloseOk.new, IO::ByteFormat::NetworkEndian + @socket.flush + @socket.close + end + + ClientProperties = AMQ::Protocol::Table.new({ + connection_name: "AMQProxy #{VERSION}", + product: "AMQProxy", + version: VERSION, + capabilities: { + consumer_priorities: true, + exchange_exchange_bindings: true, + "connection.blocked": false, + authentication_failure_close: true, + per_consumer_qos: true, + "basic.nack": true, + direct_reply_to: true, + publisher_confirms: true, + consumer_cancel_notify: true, + }, + }) + class Error < Exception; end class AccessError < Error; end class WriteError < Error; end + + class ChannelMaxReached < Error; end end end diff --git a/src/amqproxy/version.cr b/src/amqproxy/version.cr index 9593395..e70d7e2 100644 --- a/src/amqproxy/version.cr +++ b/src/amqproxy/version.cr @@ -1,3 +1,3 @@ module AMQProxy - VERSION = {{ `shards version`.stringify }} + VERSION = {{ `shards version`.stringify.chomp }} end