lib/rubydns/resolver.rb in rubydns-0.8.5 vs lib/rubydns/resolver.rb in rubydns-0.9.0

- old
+ new

@@ -16,255 +16,203 @@ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -require_relative 'message' -require_relative 'binary_string' +require_relative 'handler' require 'securerandom' +require 'celluloid/io' module RubyDNS class InvalidProtocolError < StandardError end + class InvalidResponseError < StandardError + end + class ResolutionFailure < StandardError end class Resolver + include Celluloid::IO + # Servers are specified in the same manor as options[:listen], e.g. # [:tcp/:udp, address, port] # In the case of multiple servers, they will be checked in sequence. def initialize(servers, options = {}) @servers = servers @options = options + + @logger = options[:logger] || Celluloid.logger end # Provides the next sequence identification number which is used to keep track of DNS messages. def next_id! # Using sequential numbers for the query ID is generally a bad thing because over UDP they can be spoofed. 16-bits isn't hard to guess either, but over UDP we also use a random port, so this makes effectively 32-bits of entropy to guess per request. SecureRandom.random_number(2**16) end # Look up a named resource of the given resource_class. - def query(name, resource_class = Resolv::DNS::Resource::IN::A, &block) + def query(name, resource_class = Resolv::DNS::Resource::IN::A) message = Resolv::DNS::Message.new(next_id!) message.rd = 1 message.add_question name, resource_class - send_message(message, &block) + dispatch_request(message) end - - def send_message(message, &block) - Request.fetch(message, @servers, @options, &block) - end - - # Yields a list of `Resolv::IPv4` and `Resolv::IPv6` addresses for the given `name` and `resource_class`. - def addresses_for(name, resource_class = Resolv::DNS::Resource::IN::A, &block) - query(name, resource_class) do |response| - # Resolv::DNS::Name doesn't retain the trailing dot. - name = name.sub(/\.$/, '') + + # Yields a list of `Resolv::IPv4` and `Resolv::IPv6` addresses for the given `name` and `resource_class`. Raises a ResolutionFailure if no severs respond. + def addresses_for(name, resource_class = Resolv::DNS::Resource::IN::A, options = {}) + (options[:retries] || 5).times do + response = query(name, resource_class) - case response - when Message - yield response.answer.select{|record| record[0].to_s == name}.collect{|record| record[2].address} - else - yield [] + if response + # Resolv::DNS::Name doesn't retain the trailing dot. + name = name.sub(/\.$/, '') + + return response.answer.select{|record| record[0].to_s == name}.collect{|record| record[2].address} end + + # Wait 10ms before trying again: + sleep 0.01 end + + abort ResolutionFailure.new("No server replied.") end - - # Manages a single DNS question message across one or more servers. - class Request - include EventMachine::Deferrable + + def request_timeout + @options[:timeout] || 1 + end + + # Send the message to available servers. If no servers respond correctly, nil is returned. This result indicates a failure of the resolver to correctly contact any server and get a valid response. + def dispatch_request(message) + request = Request.new(message, @servers) - def self.fetch(*args) - request = self.new(*args) + request.each do |server| + @logger.debug "[#{message.id}] Sending request to server #{server.inspect}" if @logger - request.callback do |message| - yield message - end - - request.errback do |error| - # In the case of a timeout, error will be nil, so we make one up. + begin + response = nil - yield error + timeout(request_timeout) do + response = try_server(request, server) + end + + if valid_response(message, response) + return response + end + rescue Task::TimeoutError + @logger.debug "[#{message.id}] Request timed out!" if @logger + rescue InvalidResponseError + @logger.warn "[#{message.id}] Invalid response from network: #{$!}!" if @logger + rescue DecodeError + @logger.warn "[#{message.id}] Error while decoding data from network: #{$!}!" if @logger + rescue IOError + @logger.warn "[#{message.id}] Error while reading from network: #{$!}!" if @logger end - - request.run! end - def initialize(message, servers, options = {}, &block) + return nil + end + + private + + def try_server(request, server) + case server[0] + when :udp + try_udp_server(request, server[1], server[2]) + when :tcp + try_tcp_server(request, server[1], server[2]) + else + raise InvalidProtocolError.new(server) + end + end + + def valid_response(message, response) + if response.tc != 0 + @logger.warn "[#{message.id}] Received truncated response!" if @logger + elsif response.id != message.id + @logger.warn "[#{message.id}] Received response with incorrect message id: #{response.id}!" if @logger + else + @logger.debug "[#{message.id}] Received valid response with #{response.answer.count} answer(s)." if @logger + + return true + end + + return false + end + + def try_udp_server(request, host, port) + socket = UDPSocket.new + + socket.send(request.packet, 0, host, port) + + data, (_, remote_port) = socket.recvfrom(UDP_TRUNCATION_SIZE) + # Need to check host, otherwise security issue. + + # May indicate some kind of spoofing attack: + if port != remote_port + raise InvalidResponseError.new("Data was not received from correct remote port (#{port} != #{remote_port})") + end + + message = RubyDNS::decode_message(data) + ensure + socket.close if socket + end + + def try_tcp_server(request, host, port) + begin + socket = TCPSocket.new(host, port) + rescue Errno::EALREADY + raise IOError.new("Could not connect to remote host!") + end + + StreamTransport.write_chunk(socket, request.packet) + + input_data = StreamTransport.read_chunk(socket) + + message = RubyDNS::decode_message(input_data) + rescue Errno::ECONNREFUSED => error + raise IOError.new(error.message) + rescue Errno::EPIPE => error + raise IOError.new(error.message) + rescue Errno::ECONNRESET => error + raise IOError.new(error.message) + ensure + socket.close if socket + end + + # Manages a single DNS question message across one or more servers. + class Request + def initialize(message, servers) @message = message @packet = message.encode @servers = servers.dup # We select the protocol based on the size of the data: if @packet.bytesize > UDP_TRUNCATION_SIZE @servers.delete_if{|server| server[0] == :udp} end - - # Measured in seconds: - @timeout = options[:timeout] || 1 - - @logger = options[:logger] end attr :message attr :packet attr :logger - def run! - try_next_server! - end - - # Once either an exception or message is received, we update the status of this request. - def process_response!(response) - finish_request! - - if Exception === response - @logger.warn "[#{@message.id}] Failure while processing response #{response}!" if @logger - RubyDNS.log_exception(@logger, response) if @logger + def each(&block) + @servers.each do |server| + next if @packet.bytesize > UDP_TRUNCATION_SIZE - try_next_server! - elsif response.tc != 0 - @logger.warn "[#{@message.id}] Received truncated response!" if @logger - - try_next_server! - elsif response.id != @message.id - @logger.warn "[#{@message.id}] Received response with incorrect message id: #{response.id}" if @logger - - try_next_server! - else - @logger.debug "[#{@message.id}] Received valid response #{response.inspect}" if @logger - - succeed response + yield server end end - - private - - # Closes any connections and cancels any timeout. - def finish_request! - cancel_timeout - - # Cancel an existing request if it is in flight: - if @request - @request.close_connection - @request = nil - end - end - - def try_next_server! - if @servers.size > 0 - @server = @servers.shift - - @logger.debug "[#{@message.id}] Sending request to server #{@server.inspect}" if @logger - - # We make requests one at a time to the given server, naturally the servers are ordered in terms of priority. - case @server[0] - when :udp - @request = UDPRequestHandler.open(@server[1], @server[2], self) - when :tcp - @request = TCPRequestHandler.open(@server[1], @server[2], self) - else - raise InvalidProtocolError.new(@server) - end - - # Setting up the timeout... - timeout(@timeout) - else - fail ResolutionFailure.new("No available servers responded to the request.") - end - end - - def timeout seconds - cancel_timeout - - @deferred_timeout = EventMachine::Timer.new(seconds) do - @logger.debug "[#{@message.id}] Request timed out!" if @logger - - finish_request! - - try_next_server! - end - end - - module UDPRequestHandler - def self.open(host, port, request) - # Open a datagram socket... a random socket chosen by the OS by specifying 0 for the port: - EventMachine::open_datagram_socket('', 0, self, request, host, port) - end - - def initialize(request, host, port) - @request = request - @host = host - @port = port - end - - def post_init - # Sending question to remote DNS server... - send_datagram(@request.packet, @host, @port) - end - - def receive_data(data) - # local_port, local_ip = Socket.unpack_sockaddr_in(get_sockname) - # puts "Socket name: #{local_ip}:#{local_port}" - - # Receiving response from remote DNS server... - message = RubyDNS::decode_message(data) - - # The message id must match, and it can't be truncated: - @request.process_response!(message) - rescue Resolv::DNS::DecodeError => error - @request.process_response!(error) - end - end - - module TCPRequestHandler - def self.open(host, port, request) - EventMachine::connect(host, port, TCPRequestHandler, request) - end - - def initialize(request) - @request = request - @buffer = nil - @length = nil - end - - def post_init - data = @request.packet - - send_data([data.bytesize].pack('n')) - send_data data - end - - def receive_data(data) - # We buffer data until we've received the entire packet: - @buffer ||= BinaryStringIO.new - @buffer.write(data) - # If we've received enough data and we haven't figured out the length yet... - if @length == nil and @buffer.size > 2 - # Extract the length from the buffer: - @length = @buffer.string.byteslice(0, 2).unpack('n')[0] - end - - # If we know what the length is, and we've got that much data, we can decode the message: - if @length != nil and @buffer.size >= (@length + 2) - data = @buffer.string.byteslice(2, @length) - - message = RubyDNS::decode_message(data) - - @request.process_response!(message) - end - - # If we have received more data than expected, should this be an error? - rescue Resolv::DNS::DecodeError => error - @request.process_response!(error) - end + def update_id!(id) + @message.id = id + @packet = @message.encode end end end -end \ No newline at end of file +end