lib/amqp/client/connection.rb in amqp-client-1.0.2 vs lib/amqp/client/connection.rb in amqp-client-1.1.0

- old
+ new

@@ -11,51 +11,35 @@ class Client # Represents a single established AMQP connection class Connection # Establish a connection to an AMQP broker # @param uri [String] URL on the format amqp://username:password@hostname/vhost, use amqps:// for encrypted connection - # @param read_loop_thread [Boolean] Set to false if you manually want to run the {#read_loop} + # @param read_loop_thread [Boolean] If true run {#read_loop} in a background thread, + # otherwise the user have to run it explicitly, without {#read_loop} the connection won't function # @option options [Boolean] connection_name (PROGRAM_NAME) Set a name for the connection to be able to identify # the client from the broker # @option options [Boolean] verify_peer (true) Verify broker's TLS certificate, set to false for self-signed certs + # @option options [Integer] connect_timeout (30) TCP connection timeout # @option options [Integer] heartbeat (0) Heartbeat timeout, defaults to 0 and relies on TCP keepalive instead # @option options [Integer] frame_max (131_072) Maximum frame size, # the smallest of the client's and the broker's values will be used # @option options [Integer] channel_max (2048) Maxium number of channels the client will be allowed to have open. # Maxium allowed is 65_536. The smallest of the client's and the broker's value will be used. # @return [Connection] - def self.connect(uri, read_loop_thread: true, **options) + def initialize(uri = "", read_loop_thread: true, **options) uri = URI.parse(uri) tls = uri.scheme == "amqps" port = port_from_env || uri.port || (tls ? 5671 : 5672) host = uri.host || "localhost" user = uri.user || "guest" password = uri.password || "guest" - vhost = URI.decode_www_form_component(uri.path[1..-1] || "/") + vhost = URI.decode_www_form_component(uri.path[1..] || "/") options = URI.decode_www_form(uri.query || "").map! { |k, v| [k.to_sym, v] }.to_h.merge(options) - socket = Socket.tcp host, port, connect_timeout: 20, resolv_timeout: 5 - enable_tcp_keepalive(socket) - if tls - cert_store = OpenSSL::X509::Store.new - cert_store.set_default_paths - context = OpenSSL::SSL::SSLContext.new - context.cert_store = cert_store - context.verify_mode = OpenSSL::SSL::VERIFY_PEER unless [false, "false", "none"].include? options[:verify_peer] - socket = OpenSSL::SSL::SSLSocket.new(socket, context) - socket.sync_close = true # closing the TLS socket also closes the TCP socket - socket.hostname = host # SNI host - socket.connect - socket.post_connection_check(host) || raise(Error, "TLS certificate hostname doesn't match requested") - end + socket = open_socket(host, port, tls, options) channel_max, frame_max, heartbeat = establish(socket, user, password, vhost, options) - Connection.new(socket, channel_max, frame_max, heartbeat, read_loop_thread: read_loop_thread) - end - # Requires an already established TCP/TLS socket - # @api private - def initialize(socket, channel_max, frame_max, heartbeat, read_loop_thread: true) @socket = socket @channel_max = channel_max.zero? ? 65_536 : channel_max @frame_max = frame_max @heartbeat = heartbeat @channels = {} @@ -64,10 +48,17 @@ @write_lock = Mutex.new @blocked = nil Thread.new { read_loop } if read_loop_thread end + # Alias for {#initialize} + # @see #initialize + # @deprecated + def self.connect(uri, read_loop_thread: true, **options) + new(uri, read_loop_thread: read_loop_thread, **options) + end + # The max frame size negotiated between the client and the broker # @return [Integer] attr_reader :frame_max # Custom inspect @@ -181,11 +172,13 @@ nil # ignore read errors ensure @closed ||= [400, "unknown"] @replies.close begin - @socket.close + @write_lock.synchronize do + @socket.close + end rescue IOError, OpenSSL::OpenSSLError, SystemCallError nil end end @@ -214,11 +207,11 @@ return false when 51 # connection#close-ok @replies.push [:close_ok] return false when 60 # connection#blocked - reason_len = buf.unpack1("@4 C") + reason_len = buf.getbyte(4) reason = buf.byteslice(5, reason_len).force_encoding("utf-8") @blocked = reason @write_lock.lock when 61 # connection#unblocked @blocked = nil @@ -254,11 +247,11 @@ else raise Error::UnsupportedMethodFrame, class_id, method_id end when 50 # queue case method_id when 11 # declare-ok - queue_name_len = buf.unpack1("@4 C") + queue_name_len = buf.getbyte(4) queue_name = buf.byteslice(5, queue_name_len).force_encoding("utf-8") message_count, consumer_count = buf.byteslice(5 + queue_name_len, 8).unpack("L> L>") @channels[channel_id].reply [:queue_declare_ok, queue_name, message_count, consumer_count] when 21 # bind-ok @channels[channel_id].reply [:queue_bind_ok] @@ -274,58 +267,58 @@ when 60 # basic case method_id when 11 # qos-ok @channels[channel_id].reply [:basic_qos_ok] when 21 # consume-ok - tag_len = buf.unpack1("@4 C") + tag_len = buf.getbyte(4) tag = buf.byteslice(5, tag_len).force_encoding("utf-8") @channels[channel_id].reply [:basic_consume_ok, tag] when 30 # cancel - tag_len = buf.unpack1("@4 C") + tag_len = buf.getbyte(4) tag = buf.byteslice(5, tag_len).force_encoding("utf-8") - no_wait = buf[5 + tag_len].ord == 1 + no_wait = buf.getbyte(5 + tag_len) == 1 @channels[channel_id].close_consumer(tag) write_bytes FrameBytes.basic_cancel_ok(@id, tag) unless no_wait when 31 # cancel-ok - tag_len = buf.unpack1("@4 C") + tag_len = buf.getbyte(4) tag = buf.byteslice(5, tag_len).force_encoding("utf-8") @channels[channel_id].reply [:basic_cancel_ok, tag] when 50 # return reply_code, reply_text_len = buf.unpack("@4 S> C") pos = 7 reply_text = buf.byteslice(pos, reply_text_len).force_encoding("utf-8") pos += reply_text_len - exchange_len = buf[pos].ord + exchange_len = buf.getbyte(pos) pos += 1 exchange = buf.byteslice(pos, exchange_len).force_encoding("utf-8") pos += exchange_len - routing_key_len = buf[pos].ord + routing_key_len = buf.getbyte(pos) pos += 1 routing_key = buf.byteslice(pos, routing_key_len).force_encoding("utf-8") @channels[channel_id].message_returned(reply_code, reply_text, exchange, routing_key) when 60 # deliver - ctag_len = buf[4].ord + ctag_len = buf.getbyte(4) consumer_tag = buf.byteslice(5, ctag_len).force_encoding("utf-8") pos = 5 + ctag_len delivery_tag, redelivered, exchange_len = buf.byteslice(pos, 10).unpack("Q> C C") pos += 8 + 1 + 1 exchange = buf.byteslice(pos, exchange_len).force_encoding("utf-8") pos += exchange_len - rk_len = buf[pos].ord + rk_len = buf.getbyte(pos) pos += 1 routing_key = buf.byteslice(pos, rk_len).force_encoding("utf-8") @channels[channel_id].message_delivered(consumer_tag, delivery_tag, redelivered == 1, exchange, routing_key) when 71 # get-ok delivery_tag, redelivered, exchange_len = buf.unpack("@4 Q> C C") pos = 14 exchange = buf.byteslice(pos, exchange_len).force_encoding("utf-8") pos += exchange_len - routing_key_len = buf[pos].ord + routing_key_len = buf.getbyte(pos) pos += 1 routing_key = buf.byteslice(pos, routing_key_len).force_encoding("utf-8") - pos += routing_key_len - _message_count = buf.byteslice(pos, 4).unpack1("L>") + # pos += routing_key_len + # message_count = buf.byteslice(pos, 4).unpack1("L>") @channels[channel_id].message_delivered(nil, delivery_tag, redelivered == 1, exchange, routing_key) when 72 # get-empty @channels[channel_id].basic_get_empty when 80 # ack delivery_tag, multiple = buf.unpack("@4 Q> C") @@ -375,13 +368,36 @@ end frame_type == expected_frame_type || raise(Error::UnexpectedFrame.new(expected_frame_type, frame_type)) args end + # Connect to the host/port, optionally establish a TLS connection + # @return [Socket] + # @return [OpenSSL::SSL::SSLSocket] + def open_socket(host, port, tls, options) + connect_timeout = options.fetch(:connect_timeout, 30).to_i + socket = Socket.tcp host, port, connect_timeout: connect_timeout + enable_tcp_keepalive(socket) + if tls + cert_store = OpenSSL::X509::Store.new + cert_store.set_default_paths + context = OpenSSL::SSL::SSLContext.new + context.cert_store = cert_store + verify_peer = [false, "false", "none"].include? options[:verify_peer] + context.verify_mode = OpenSSL::SSL::VERIFY_PEER unless verify_peer + socket = OpenSSL::SSL::SSLSocket.new(socket, context) + socket.sync_close = true # closing the TLS socket also closes the TCP socket + socket.hostname = host # SNI host + socket.connect + socket.post_connection_check(host) || raise(Error, "TLS certificate hostname doesn't match requested") + end + socket + end + # Negotiate a connection # @return [Array<Integer, Integer, Integer>] channel_max, frame_max, heartbeat - def self.establish(socket, user, password, vhost, options) + def establish(socket, user, password, vhost, options) channel_max, frame_max, heartbeat = nil socket.write "AMQP\x00\x00\x09\x01" buf = String.new(capacity: 4096) loop do begin @@ -389,11 +405,11 @@ rescue IOError, OpenSSL::OpenSSLError, SystemCallError => e raise Error, "Could not establish AMQP connection: #{e.message}" end type, channel_id, frame_size = buf.unpack("C S> L>") - frame_end = buf[frame_size + 7].ord + frame_end = buf.getbyte(frame_size + 7) raise UnexpectedFrameEndError, frame_end if frame_end != 206 case type when 1 # method frame class_id, method_id = buf.unpack("@7 S> S>") @@ -435,25 +451,28 @@ nil end raise e end - def self.enable_tcp_keepalive(socket) + # Enable TCP keepalive, which is prefered to heartbeats + # @return [void] + def enable_tcp_keepalive(socket) socket.setsockopt(Socket::SOL_SOCKET, Socket::SO_KEEPALIVE, true) socket.setsockopt(Socket::SOL_TCP, Socket::TCP_KEEPIDLE, 60) socket.setsockopt(Socket::SOL_TCP, Socket::TCP_KEEPINTVL, 10) socket.setsockopt(Socket::SOL_TCP, Socket::TCP_KEEPCNT, 3) rescue StandardError => e warn "AMQP-Client could not enable TCP keepalive on socket. #{e.inspect}" end - def self.port_from_env + # Fetch the AMQP port number from ENV + # @return [Integer] A port number + # @return [nil] When the environment variable AMQP_PORT isn't set + def port_from_env return unless (port = ENV["AMQP_PORT"]) port.to_i end - - private_class_method :establish, :enable_tcp_keepalive, :port_from_env CLIENT_PROPERTIES = { capabilities: { authentication_failure_close: true, publisher_confirms: true,