lib/avro/ipc.rb in avro-1.3.0 vs lib/avro/ipc.rb in avro-1.3.3

- old
+ new

@@ -11,11 +11,10 @@ # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -require 'stringio' module Avro::IPC class AvroRemoteError < Avro::AvroError; end @@ -160,29 +159,32 @@ datum_writer.write(request_datum, encoder) end def read_handshake_response(decoder) handshake_response = HANDSHAKE_REQUESTOR_READER.read(decoder) - case match = handshake_response['match'] + we_have_matching_schema = false + + case handshake_response['match'] when 'BOTH' self.send_protocol = false - true + we_have_matching_schema = true when 'CLIENT' raise AvroError.new('Handshake failure. match == CLIENT') if send_protocol - self.remote_protocol = handshake_response['serverProtocol'] + self.remote_protocol = Avro::Protocol.parse(handshake_response['serverProtocol']) self.remote_hash = handshake_response['serverHash'] self.send_protocol = false - false + we_have_matching_schema = true when 'NONE' raise AvroError.new('Handshake failure. match == NONE') if send_protocol - self.remote_protocol = handshake_response['serverProtocol'] + self.remote_protocol = Avro::Protocol.parse(handshake_response['serverProtocol']) self.remote_hash = handshake_response['serverHash'] self.send_protocol = true - false else raise AvroError.new("Unexpected match: #{match}") end + + return we_have_matching_schema end def read_call_response(message_name, decoder) # The format of a call response is: # * response metadata, a map with values of type bytes @@ -234,23 +236,21 @@ @local_hash = self.local_protocol.md5 @protocol_cache = {} protocol_cache[local_hash] = local_protocol end - def respond(transport) - # Called by a server to deserialize a request, compute and serialize - # a response or error. Compare to 'handle()' in Thrift. - - call_request = transport.read_framed_message + # Called by a server to deserialize a request, compute and serialize + # a response or error. Compare to 'handle()' in Thrift. + def respond(call_request) buffer_decoder = Avro::IO::BinaryDecoder.new(StringIO.new(call_request)) buffer_writer = StringIO.new('', 'w+') buffer_encoder = Avro::IO::BinaryEncoder.new(buffer_writer) error = nil response_metadata = {} begin - remote_protocol = process_handshake(transport, buffer_decoder, buffer_encoder) + remote_protocol = process_handshake(buffer_decoder, buffer_encoder) # handshake failure unless remote_protocol return buffer_writer.string end @@ -298,20 +298,21 @@ self.write_error(SYSTEM_ERROR_SCHEMA, error, buffer_encoder) end buffer_writer.string end - def process_handshake(transport, decoder, encoder) + def process_handshake(decoder, encoder) handshake_request = HANDSHAKE_RESPONDER_READER.read(decoder) handshake_response = {} # determine the remote protocol client_hash = handshake_request['clientHash'] client_protocol = handshake_request['clientProtocol'] remote_protocol = protocol_cache[client_hash] + if !remote_protocol && client_protocol - remote_protocol = protocol.parse(client_protocol) + remote_protocol = Avro::Protocol.parse(client_protocol) protocol_cache[client_hash] = remote_protocol end # evaluate remote's guess of the local protocol server_hash = handshake_request['serverHash'] @@ -420,24 +421,115 @@ total_bytes_sent += bytes_sent end end def write_buffer_length(n) - bytes_sent = sock.write([n].pack('I')) + bytes_sent = sock.write([n].pack('N')) if bytes_sent == 0 raise ConnectionClosedException.new("socket sent 0 bytes") end end def read_buffer_length read = sock.read(BUFFER_HEADER_LENGTH) if read == '' || read == nil raise ConnectionClosedException.new("Socket read 0 bytes.") end - read.unpack('I')[0] + read.unpack('N')[0] end def close sock.close + end + end + + class ConnectionClosedError < StandardError; end + + class FramedWriter + attr_reader :writer + def initialize(writer) + @writer = writer + end + + def write_framed_message(message) + message_size = message.size + total_bytes_sent = 0 + while message_size - total_bytes_sent > 0 + if message_size - total_bytes_sent > BUFFER_SIZE + buffer_size = BUFFER_SIZE + else + buffer_size = message_size - total_bytes_sent + end + write_buffer(message[total_bytes_sent, buffer_size]) + total_bytes_sent += buffer_size + end + write_buffer_size(0) + end + + def to_s; writer.string; end + + private + def write_buffer(chunk) + buffer_size = chunk.size + write_buffer_size(buffer_size) + writer << chunk + end + + def write_buffer_size(n) + writer.write([n].pack('N')) + end + end + + class FramedReader + attr_reader :reader + + def initialize(reader) + @reader = reader + end + + def read_framed_message + message = [] + loop do + buffer = "" + buffer_size = read_buffer_size + + return message.join if buffer_size == 0 + + while buffer.size < buffer_size + chunk = reader.read(buffer_size - buffer.size) + chunk_error?(chunk) + buffer << chunk + end + message << buffer + end + end + + private + def read_buffer_size + header = reader.read(BUFFER_HEADER_LENGTH) + chunk_error?(header) + header.unpack('N')[0] + end + + def chunk_error?(chunk) + raise ConnectionClosedError.new("Reader read 0 bytes") if chunk == '' + end + end + + # Only works for clients. Sigh. + class HTTPTransceiver + attr_reader :remote_name, :host, :port + def initialize(host, port) + @host, @port = host, port + @remote_name = "#{host}:#{port}" + end + + def transceive(message) + writer = FramedWriter.new(StringIO.new) + writer.write_framed_message(message) + resp = Net::HTTP.start(host, port) do |http| + http.post('/', writer.to_s, {'Content-Type' => 'avro/binary'}) + end + FramedReader.new(StringIO.new(resp.body)).read_framed_message end end end