# frozen_string_literal: true

require 'rack'
require 'thrift'
require 'logger'

# Rack application that serves one Thrift processor over HTTP POST and picks
# the Thrift protocol for each request from the request's Content-Type header.
class MultiprotocolThriftRackApp
  # Rack env key under which the request's Content-Type header is stored.
  CONTENT_TYPE_ENV = 'CONTENT_TYPE'
  # `private` does not apply to constants, so hide this one explicitly.
  private_constant :CONTENT_TYPE_ENV

  # @param processor [#process] Thrift processor; it is handed the same
  #   protocol object as both input and output protocol.
  # @param protocol_factory_map [Hash, Array] pairs of
  #   `protocol_factory => content_types`, where each factory responds to
  #   `get_protocol(transport)` and `content_types` responds to `include?`.
  #   A frozen copy is stored; the argument itself is left untouched.
  # @param logger [Logger] receives content-type matching messages
  #   (defaults to an INFO-level logger on $stderr).
  # @param buffered [Boolean] when true, wraps the stream transport in
  #   Thrift::BufferedTransport.
  def initialize(
    processor,
    protocol_factory_map,
    logger: default_logger,
    buffered: false
  )
    @processor = processor
    # Freeze a copy: freezing the argument directly would make the caller's
    # own Hash immutable as a hidden side effect.
    @protocol_factory_map = protocol_factory_map.dup.freeze
    @logger = logger
    @buffered = buffered
  end

  # Rack entry point.
  #
  # @param env [Hash] Rack environment
  # @return [Array(Integer, Hash, #each)] Rack response triplet: 400 for a
  #   non-POST request or an unmapped Content-Type; otherwise 200 with the
  #   processor's output as body and the request's Content-Type echoed back.
  def call(env)
    request = Rack::Request.new(env)
    return failure_response('Not POST method') unless request.post?

    protocol_factory, content_type = find_protocol_factory(request)
    return failure_response('Unknown Content-Type') unless protocol_factory

    successful_response(request.body, protocol_factory, content_type)
  end

  private

  def default_logger
    Logger.new($stderr, level: Logger::INFO)
  end

  # Plain-text 400 response. The Rack SPEC requires `call` to return an Array
  # of exactly three values, so the Response is converted with `finish`.
  def failure_response(error_message)
    Rack::Response.new(
      error_message,
      400,
      { Rack::CONTENT_TYPE => 'text/plain' }
    ).finish
  end

  def fetch_content_type(request)
    request.get_header(CONTENT_TYPE_ENV)
  end

  # Strips parameters and normalizes case, e.g.
  # "Application/X-Thrift; charset=utf-8" -> "application/x-thrift".
  # Returns nil for a missing or empty header.
  def bare_media_type(content_type)
    return if content_type.nil?

    media_type = content_type.split(';', 2).first.to_s.strip.downcase
    media_type unless media_type.empty?
  end

  # Logs the matching outcome: debug on a match, error when nothing matched.
  def debug_protocol_factory(protocol_factory, content_type)
    if protocol_factory
      @logger.debug("Match #{content_type} for #{protocol_factory}")
    else
      @logger.error("Unexpected Content-Type #{content_type}")
    end
  end

  # First factory (in map order) whose content-type list includes
  # +content_type+, or nil.
  def protocol_factory_for(content_type)
    protocol_factory, = @protocol_factory_map.find do |(_protocol_factory, content_types)|
      content_types.include?(content_type)
    end
    protocol_factory
  end

  # Resolves the protocol factory for +request+. An exact match on the raw
  # Content-Type header wins (the original behaviour); only if that fails is
  # the bare media type tried, so headers that carry parameters such as
  # "; charset=utf-8" or differ only in case are still recognised.
  #
  # @return [Array(Object, String)] [factory or nil, raw Content-Type header]
  def find_protocol_factory(request)
    content_type = fetch_content_type(request)
    media_type = bare_media_type(content_type)

    protocol_factory = protocol_factory_for(content_type)
    protocol_factory ||= protocol_factory_for(media_type) if media_type

    debug_protocol_factory(protocol_factory, content_type)

    [protocol_factory, content_type]
  end

  # Optionally wraps the raw transport in Thrift::BufferedTransport.
  def build_transport(raw_transport)
    if @buffered
      Thrift::BufferedTransport.new(raw_transport)
    else
      raw_transport
    end
  end

  # Runs the processor with the request body as input and the Rack::Response
  # as output stream (the constructor yields the response before returning),
  # then converts the response into the Rack [status, headers, body] triplet.
  def successful_response(request_body, protocol_factory, content_type)
    Rack::Response.new(
      [],
      200,
      { Rack::CONTENT_TYPE => content_type }
    ) do |response|
      raw_transport = Thrift::IOStreamTransport.new(request_body, response)
      transport = build_transport(raw_transport)
      protocol = protocol_factory.get_protocol(transport)
      @processor.process(protocol, protocol)
    end.finish
  end
end