# frozen_string_literal: true

module OpenTracing
  module Instrumentation
    module Thrift
      # Trace thrift processor: every processed message is wrapped in an
      # active span whose name comes from the configured operation name
      # builder and whose tags come from the configured tags builder.
      #
      # Usage:
      #   processor =
      #     OrderService::Processor.new(orders_handler)
      #   traced_processor =
      #     OpenTracing::Instrumentation::Thrift::TracedProcessor.new(processor)
      class TracedProcessor
        extend Forwardable

        # Protocol decorator that memoizes the result of +read_message_begin+,
        # so the message header can be read once for tracing and then again
        # by the wrapped processor, which receives this same decorator.
        #
        # @private
        class ReadCachedProtocol
          include ::Thrift::ProtocolDecorator

          # @return the header (name, type, seq_id) returned by the first
          #   +read_message_begin+ call on the wrapped protocol; subsequent
          #   calls return the cached value.
          def read_message_begin
            @read_message_begin ||= @protocol.read_message_begin
          end

          # Two cached protocols are equal when they decorate equal protocols.
          # Objects of any other class compare unequal instead of raising.
          def ==(other)
            other.is_a?(ReadCachedProtocol) && protocol == other.protocol
          end

          protected

          attr_reader :protocol
        end

        # @param processor [Thrift::Processor] processor to trace
        # @param config [TracedProcessorConfig] tracing configuration; it is
        #   yielded to the block (if given) and then copied, so later changes
        #   to the passed object do not affect this processor.
        # @yieldparam config [TracedProcessorConfig] configuration to customize
        def initialize(processor, config: TracedProcessorConfig.new)
          @processor = processor
          yield config if block_given?
          @config = config.dup
        end

        # Processes one message through the wrapped processor inside a span.
        #
        # @param iproto [Thrift::Protocol] input protocol
        # @param oproto [Thrift::Protocol] output protocol
        # @return whatever the wrapped processor's +process+ returns
        def process(iproto, oproto)
          trace_process(iproto) do |cached_iproto|
            processor.process(
              wrap_protocol(cached_iproto),
              wrap_protocol(oproto),
            )
          end
        end

        private

        attr_reader :processor
        attr_reader :config

        def_delegators :config,
                       :tracer,
                       :trace_protocol,
                       :error_writer,
                       :operation_name_builder,
                       :tags_builder,
                       :logger

        # Reads the message header through a caching decorator, starts a span
        # for it, yields the decorator and always closes the span afterwards.
        # Exceptions raised while processing are recorded on the span (when
        # one exists) and re-raised unchanged; failures inside the tracing
        # helpers are only logged so they never replace the original error.
        #
        # @return the value of the block
        def trace_process(iproto)
          cached_iproto = ReadCachedProtocol.new(iproto)

          start_time = Time.now

          name, type, seq_id = cached_iproto.read_message_begin

          scope = safe_start_scope(iproto, name, type, seq_id, start_time)

          yield cached_iproto
        rescue StandardError => e
          # scope is nil when the failure happened before a span was started.
          safe_write_error(scope, e)
          raise e
        ensure
          safe_close_scope(scope)
        end

        # Builds the operation name and tags, then starts an active span.
        #
        # @return the scope returned by +tracer.start_active_span+, or +nil+
        #   if any step raised (the error is logged).
        def safe_start_scope(protocol, name, type, seq_id, start_time)
          operation_name = operation_name_builder.build(name, type, seq_id)
          tags = tags_builder.build_tags(protocol, name, type)
          tracer.start_active_span(
            operation_name,
            start_time: start_time,
            tags: tags,
          )
        rescue StandardError => e
          logger&.error(e)
          # Return nil explicitly: the logger's return value (stdlib
          # Logger#error returns true) would otherwise be mistaken for a scope.
          nil
        end

        # Records +error+ on the scope's span when a span is available. A
        # failure of the error writer itself is logged, not raised, so it
        # cannot mask +error+.
        def safe_write_error(scope, error)
          span = scope&.span
          error_writer.write_error(span, error) if span
        rescue StandardError => e
          logger&.error(e)
        end

        # Closes +scope+ if present; any error raised while closing is logged.
        def safe_close_scope(scope)
          scope&.close
        rescue StandardError => e
          logger&.error(e)
        end

        # Wraps +protocol+ in a TracedProtocol that uses this processor's
        # tracer when +trace_protocol+ is enabled; otherwise returns it as is.
        def wrap_protocol(protocol)
          return protocol unless trace_protocol

          TracedProtocol.new(protocol) do |config|
            config.tracer = tracer
          end
        end
      end
    end
  end
end