lib/rubydns/resolver.rb in rubydns-0.8.5 vs lib/rubydns/resolver.rb in rubydns-0.9.0
- old
+ new
@@ -16,255 +16,203 @@
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
-require_relative 'message'
-require_relative 'binary_string'
+require_relative 'handler'
require 'securerandom'
+require 'celluloid/io'
module RubyDNS
class InvalidProtocolError < StandardError
end
+ class InvalidResponseError < StandardError
+ end
+
class ResolutionFailure < StandardError
end
class Resolver
+ include Celluloid::IO
+
# Servers are specified in the same manor as options[:listen], e.g.
# [:tcp/:udp, address, port]
# In the case of multiple servers, they will be checked in sequence.
def initialize(servers, options = {})
@servers = servers
@options = options
+
+ @logger = options[:logger] || Celluloid.logger
end
# Provides the next sequence identification number which is used to keep track of DNS messages.
def next_id!
# Using sequential numbers for the query ID is generally a bad thing because over UDP they can be spoofed. 16-bits isn't hard to guess either, but over UDP we also use a random port, so this makes effectively 32-bits of entropy to guess per request.
SecureRandom.random_number(2**16)
end
# Look up a named resource of the given resource_class.
- def query(name, resource_class = Resolv::DNS::Resource::IN::A, &block)
+ def query(name, resource_class = Resolv::DNS::Resource::IN::A)
message = Resolv::DNS::Message.new(next_id!)
message.rd = 1
message.add_question name, resource_class
- send_message(message, &block)
+ dispatch_request(message)
end
-
- def send_message(message, &block)
- Request.fetch(message, @servers, @options, &block)
- end
-
- # Yields a list of `Resolv::IPv4` and `Resolv::IPv6` addresses for the given `name` and `resource_class`.
- def addresses_for(name, resource_class = Resolv::DNS::Resource::IN::A, &block)
- query(name, resource_class) do |response|
- # Resolv::DNS::Name doesn't retain the trailing dot.
- name = name.sub(/\.$/, '')
+
+ # Yields a list of `Resolv::IPv4` and `Resolv::IPv6` addresses for the given `name` and `resource_class`. Raises a ResolutionFailure if no severs respond.
+ def addresses_for(name, resource_class = Resolv::DNS::Resource::IN::A, options = {})
+ (options[:retries] || 5).times do
+ response = query(name, resource_class)
- case response
- when Message
- yield response.answer.select{|record| record[0].to_s == name}.collect{|record| record[2].address}
- else
- yield []
+ if response
+ # Resolv::DNS::Name doesn't retain the trailing dot.
+ name = name.sub(/\.$/, '')
+
+ return response.answer.select{|record| record[0].to_s == name}.collect{|record| record[2].address}
end
+
+ # Wait 10ms before trying again:
+ sleep 0.01
end
+
+ abort ResolutionFailure.new("No server replied.")
end
-
- # Manages a single DNS question message across one or more servers.
- class Request
- include EventMachine::Deferrable
+
+ def request_timeout
+ @options[:timeout] || 1
+ end
+
+ # Send the message to available servers. If no servers respond correctly, nil is returned. This result indicates a failure of the resolver to correctly contact any server and get a valid response.
+ def dispatch_request(message)
+ request = Request.new(message, @servers)
- def self.fetch(*args)
- request = self.new(*args)
+ request.each do |server|
+ @logger.debug "[#{message.id}] Sending request to server #{server.inspect}" if @logger
- request.callback do |message|
- yield message
- end
-
- request.errback do |error|
- # In the case of a timeout, error will be nil, so we make one up.
+ begin
+ response = nil
- yield error
+ timeout(request_timeout) do
+ response = try_server(request, server)
+ end
+
+ if valid_response(message, response)
+ return response
+ end
+ rescue Task::TimeoutError
+ @logger.debug "[#{message.id}] Request timed out!" if @logger
+ rescue InvalidResponseError
+ @logger.warn "[#{message.id}] Invalid response from network: #{$!}!" if @logger
+ rescue DecodeError
+ @logger.warn "[#{message.id}] Error while decoding data from network: #{$!}!" if @logger
+ rescue IOError
+ @logger.warn "[#{message.id}] Error while reading from network: #{$!}!" if @logger
end
-
- request.run!
end
- def initialize(message, servers, options = {}, &block)
+ return nil
+ end
+
+ private
+
+ def try_server(request, server)
+ case server[0]
+ when :udp
+ try_udp_server(request, server[1], server[2])
+ when :tcp
+ try_tcp_server(request, server[1], server[2])
+ else
+ raise InvalidProtocolError.new(server)
+ end
+ end
+
+ def valid_response(message, response)
+ if response.tc != 0
+ @logger.warn "[#{message.id}] Received truncated response!" if @logger
+ elsif response.id != message.id
+ @logger.warn "[#{message.id}] Received response with incorrect message id: #{response.id}!" if @logger
+ else
+ @logger.debug "[#{message.id}] Received valid response with #{response.answer.count} answer(s)." if @logger
+
+ return true
+ end
+
+ return false
+ end
+
+ def try_udp_server(request, host, port)
+ socket = UDPSocket.new
+
+ socket.send(request.packet, 0, host, port)
+
+ data, (_, remote_port) = socket.recvfrom(UDP_TRUNCATION_SIZE)
+ # Need to check host, otherwise security issue.
+
+ # May indicate some kind of spoofing attack:
+ if port != remote_port
+ raise InvalidResponseError.new("Data was not received from correct remote port (#{port} != #{remote_port})")
+ end
+
+ message = RubyDNS::decode_message(data)
+ ensure
+ socket.close if socket
+ end
+
+ def try_tcp_server(request, host, port)
+ begin
+ socket = TCPSocket.new(host, port)
+ rescue Errno::EALREADY
+ raise IOError.new("Could not connect to remote host!")
+ end
+
+ StreamTransport.write_chunk(socket, request.packet)
+
+ input_data = StreamTransport.read_chunk(socket)
+
+ message = RubyDNS::decode_message(input_data)
+ rescue Errno::ECONNREFUSED => error
+ raise IOError.new(error.message)
+ rescue Errno::EPIPE => error
+ raise IOError.new(error.message)
+ rescue Errno::ECONNRESET => error
+ raise IOError.new(error.message)
+ ensure
+ socket.close if socket
+ end
+
+ # Manages a single DNS question message across one or more servers.
+ class Request
+ def initialize(message, servers)
@message = message
@packet = message.encode
@servers = servers.dup
# We select the protocol based on the size of the data:
if @packet.bytesize > UDP_TRUNCATION_SIZE
@servers.delete_if{|server| server[0] == :udp}
end
-
- # Measured in seconds:
- @timeout = options[:timeout] || 1
-
- @logger = options[:logger]
end
attr :message
attr :packet
attr :logger
- def run!
- try_next_server!
- end
-
- # Once either an exception or message is received, we update the status of this request.
- def process_response!(response)
- finish_request!
-
- if Exception === response
- @logger.warn "[#{@message.id}] Failure while processing response #{response}!" if @logger
- RubyDNS.log_exception(@logger, response) if @logger
+ def each(&block)
+ @servers.each do |server|
+ next if @packet.bytesize > UDP_TRUNCATION_SIZE
- try_next_server!
- elsif response.tc != 0
- @logger.warn "[#{@message.id}] Received truncated response!" if @logger
-
- try_next_server!
- elsif response.id != @message.id
- @logger.warn "[#{@message.id}] Received response with incorrect message id: #{response.id}" if @logger
-
- try_next_server!
- else
- @logger.debug "[#{@message.id}] Received valid response #{response.inspect}" if @logger
-
- succeed response
+ yield server
end
end
-
- private
-
- # Closes any connections and cancels any timeout.
- def finish_request!
- cancel_timeout
-
- # Cancel an existing request if it is in flight:
- if @request
- @request.close_connection
- @request = nil
- end
- end
-
- def try_next_server!
- if @servers.size > 0
- @server = @servers.shift
-
- @logger.debug "[#{@message.id}] Sending request to server #{@server.inspect}" if @logger
-
- # We make requests one at a time to the given server, naturally the servers are ordered in terms of priority.
- case @server[0]
- when :udp
- @request = UDPRequestHandler.open(@server[1], @server[2], self)
- when :tcp
- @request = TCPRequestHandler.open(@server[1], @server[2], self)
- else
- raise InvalidProtocolError.new(@server)
- end
-
- # Setting up the timeout...
- timeout(@timeout)
- else
- fail ResolutionFailure.new("No available servers responded to the request.")
- end
- end
-
- def timeout seconds
- cancel_timeout
-
- @deferred_timeout = EventMachine::Timer.new(seconds) do
- @logger.debug "[#{@message.id}] Request timed out!" if @logger
-
- finish_request!
-
- try_next_server!
- end
- end
-
- module UDPRequestHandler
- def self.open(host, port, request)
- # Open a datagram socket... a random socket chosen by the OS by specifying 0 for the port:
- EventMachine::open_datagram_socket('', 0, self, request, host, port)
- end
-
- def initialize(request, host, port)
- @request = request
- @host = host
- @port = port
- end
-
- def post_init
- # Sending question to remote DNS server...
- send_datagram(@request.packet, @host, @port)
- end
-
- def receive_data(data)
- # local_port, local_ip = Socket.unpack_sockaddr_in(get_sockname)
- # puts "Socket name: #{local_ip}:#{local_port}"
-
- # Receiving response from remote DNS server...
- message = RubyDNS::decode_message(data)
-
- # The message id must match, and it can't be truncated:
- @request.process_response!(message)
- rescue Resolv::DNS::DecodeError => error
- @request.process_response!(error)
- end
- end
-
- module TCPRequestHandler
- def self.open(host, port, request)
- EventMachine::connect(host, port, TCPRequestHandler, request)
- end
-
- def initialize(request)
- @request = request
- @buffer = nil
- @length = nil
- end
-
- def post_init
- data = @request.packet
-
- send_data([data.bytesize].pack('n'))
- send_data data
- end
-
- def receive_data(data)
- # We buffer data until we've received the entire packet:
- @buffer ||= BinaryStringIO.new
- @buffer.write(data)
- # If we've received enough data and we haven't figured out the length yet...
- if @length == nil and @buffer.size > 2
- # Extract the length from the buffer:
- @length = @buffer.string.byteslice(0, 2).unpack('n')[0]
- end
-
- # If we know what the length is, and we've got that much data, we can decode the message:
- if @length != nil and @buffer.size >= (@length + 2)
- data = @buffer.string.byteslice(2, @length)
-
- message = RubyDNS::decode_message(data)
-
- @request.process_response!(message)
- end
-
- # If we have received more data than expected, should this be an error?
- rescue Resolv::DNS::DecodeError => error
- @request.process_response!(error)
- end
+ def update_id!(id)
+ @message.id = id
+ @packet = @message.encode
end
end
end
-end
\ No newline at end of file
+end