lib/socketry/ssl/socket.rb in socketry-0.4.0 vs lib/socketry/ssl/socket.rb in socketry-0.5.0
- old
+ new
@@ -13,10 +13,11 @@
# @param resolver [Object] A resolver object to use for resolving DNS names
# @param socket_class [Object] Underlying socket class which implements I/O ops
# @param ssl_socket_class [Object] Class which provides the underlying SSL implementation
# @param ssl_context [OpenSSL::SSL::SSLContext] SSL configuration object
# @param ssL_params [Hash] Parameter hash to set on the given SSL context
+ #
# @return [Socketry::SSL::Socket]
def initialize(
ssl_socket_class: OpenSSL::SSL::SSLSocket,
ssl_context: OpenSSL::SSL::SSLContext.new,
ssl_params: nil,
@@ -42,10 +43,11 @@
# @param local_addr [String] DNS name or IP address to bind to locally
# @param local_port [Fixnum] Local TCP port to bind to
# @param timeout [Numeric] Number of seconds to wait before aborting connect
# @param enable_sni [true, false] (default: true) Enables Server Name Indication (SNI)
# @param verify_hostname [true, false] (default: true) Ensure server's hostname matches cert
+ #
# @raise [Socketry::AddressError] an invalid address was given
# @raise [Socketry::TimeoutError] connect operation timed out
# @raise [Socketry::SSL::Error] an error occurred negotiating an SSL connection
# @return [self]
def connect(
@@ -57,22 +59,24 @@
enable_sni: true,
verify_hostname: true
)
super(remote_addr, remote_port, local_addr: local_addr, local_port: local_port, timeout: timeout)
- @ssl_socket = OpenSSL::SSL::SSLSocket.new(@socket, @ssl_context)
+ @ssl_socket = @ssl_socket_class.new(@socket, @ssl_context)
@ssl_socket.hostname = remote_addr if enable_sni
+ @ssl_socket.sync_close = true
begin
@ssl_socket.connect_nonblock
rescue IO::WaitReadable
retry if @socket.wait_readable(timeout)
raise Socketry::TimeoutError, "connection to #{remote_addr}:#{remote_port} timed out"
rescue IO::WaitWritable
retry if @socket.wait_writable(timeout)
raise Socketry::TimeoutError, "connection to #{remote_addr}:#{remote_port} timed out"
rescue OpenSSL::SSL::SSLError => ex
+ raise Socketry::SSL::CertificateVerifyError, ex.message if ex.message.include?("certificate verify failed")
raise Socketry::SSL::Error, ex.message, ex.backtrace
end
begin
@ssl_socket.post_connection_check(remote_addr) if verify_hostname
@@ -87,31 +91,54 @@
@ssl_socket.close rescue nil
@ssl_socket = nil
raise ex
end
+ # Accept an SSL connection from a Socketry or Ruby socket
+ #
+ # @param tcp_socket [TCPSocket, Socketry::TCP::Socket] raw TCP socket to begin SSL handshake with
+ # @param timeout [Numeric, NilClass] (default nil, unlimited) seconds to wait before aborting the accept
+ #
+ # @return [self]
+ def accept(tcp_socket, timeout: nil)
+ tcp_socket = IO.try_convert(tcp_socket) || raise(TypeError, "couldn't convert #{tcp_socket.class} to IO")
+ ssl_socket = @ssl_socket_class.new(tcp_socket, @ssl_context)
+
+ begin
+ ssl_socket.accept_nonblock
+ rescue IO::WaitReadable
+ retry if IO.select([tcp_socket], nil, nil, timeout)
+ raise Socketry::TimeoutError, "failed to complete handshake after #{timeout} seconds"
+ rescue IO::WaitWritable
+ retry if IO.select(nil, [tcp_socket], nil, timeout)
+ raise Socketry::TimeoutError, "failed to complete handshake after #{timeout} seconds"
+ end
+
+ from_socket(ssl_socket)
+ end
+
# Wrap a Ruby OpenSSL::SSL::SSLSocket (or other low-level SSL socket)
#
- # @param socket [::Socket] (or specified socket_class) low-level socket to wrap
# @param ssl_socket [OpenSSL::SSL::SSLSocket] SSL socket class associated with this socket
+ #
# @return [self]
- def from_socket(socket, ssl_socket)
- raise TypeError, "expected #{@socket_class}, got #{socket.class}" unless socket.is_a?(@socket_class)
+ def from_socket(ssl_socket)
raise TypeError, "expected #{@ssl_socket_class}, got #{ssl_socket.class}" unless ssl_socket.is_a?(@ssl_socket_class)
- raise StateError, "already connected" if @socket && @socket != socket
+ raise StateError, "already connected" if @socket
- @socket = socket
+ @socket = ssl_socket.to_io
@ssl_socket = ssl_socket
@ssl_socket.sync_close = true
self
end
# Perform a non-blocking read operation
#
# @param size [Fixnum] number of bytes to attempt to read
# @param outbuf [String, NilClass] an optional buffer into which data should be read
+ #
# @raise [Socketry::Error] an I/O operation failed
# @return [String, :wait_readable] data read, or :wait_readable if operation would block
def read_nonblock(size, outbuf: nil)
case outbuf
when String
@@ -123,9 +150,10 @@
end
# Perform a non-blocking write operation
#
# @param data [String] number of bytes to attempt to read
+ #
# @raise [Socketry::Error] an I/O operation failed
# @return [Fixnum, :wait_writable] number of bytes written, or :wait_writable if op would block
def write_nonblock(data)
perform { @ssl_socket.write_nonblock(data, exception: false) }
end