Sha256: 7e28fa53b90913bd55d0d7fc75e56e912bd9b6ad3bbb2cee7ce229f5cc923d16

Contents?: true

Size: 1.32 KB

Versions: 1

Compression:

Stored size: 1.32 KB

Contents

module HuggingFace
  # Client for the hosted Hugging Face Inference API.
  #
  # Every public method builds a connection for the chosen model and posts a
  # payload through #request, which retries while the service reports that it
  # is unavailable. The connection, request and logger primitives come from
  # BaseApi (not shown here — presumably an HTTP wrapper; verify there).
  class InferenceApi < BaseApi
    # Base URL of the hosted inference service.
    HOST = "https://api-inference.huggingface.co"
    # Maximum number of retries after the initial attempt.
    MAX_RETRY = 20

    # Default models used when the caller does not name one.
    QUESTION_ANSWERING_MODEL = 'distilbert-base-cased-distilled-squad'
    SUMMARIZATION_MODEL = "sshleifer/distilbart-xsum-12-6"
    GENERATION_MODEL = "distilgpt2"

    # Sends +input+ unchanged to +model+.
    #
    # @param input the payload, forwarded as-is
    # @param model the model identifier appended to the endpoint URL
    # @return whatever BaseApi#request returns
    def call(input:, model:)
      endpoint = connection(model)
      request(connection: endpoint, input: input)
    end

    # Asks +question+ about +context+.
    #
    # The payload is a hash with the keys +:question+ and +:context+.
    #
    # @param question the question to answer
    # @param context the text the answer is extracted from
    # @param model the model identifier (defaults to QUESTION_ANSWERING_MODEL)
    # @return whatever BaseApi#request returns
    def question_answering(question:, context:, model: QUESTION_ANSWERING_MODEL)
      payload = { question: question, context: context }
      endpoint = connection(model)

      request(connection: endpoint, input: payload)
    end

    # Summarizes +input+, sent under the +:inputs+ key.
    #
    # @param input the text to summarize
    # @param model the model identifier (defaults to SUMMARIZATION_MODEL)
    # @return whatever BaseApi#request returns
    def summarization(input:, model: SUMMARIZATION_MODEL)
      payload = { inputs: input }
      request(connection: connection(model), input: payload)
    end

    # Generates text from +input+, sent under the +:inputs+ key.
    #
    # @param input the prompt
    # @param model the model identifier (defaults to GENERATION_MODEL)
    # @return whatever BaseApi#request returns
    def text_generation(input:, model: GENERATION_MODEL)
      payload = { inputs: input }
      request(connection: connection(model), input: payload)
    end

    private

    # Builds the connection for +model+ by handing its endpoint URL to
    # BaseApi#connection.
    def connection(model)
      endpoint_url = "#{HOST}/models/#{model}"
      super(endpoint_url)
    end

    # Delegates to BaseApi#request, retrying on ServiceUnavailable.
    #
    # Waits one second between attempts and retries at most MAX_RETRY times;
    # once the budget is spent the last ServiceUnavailable is re-raised.
    def request(connection:, input:)
      attempts = 0

      begin
        super(connection: connection, input: input)
      rescue ServiceUnavailable
        # Bare raise re-raises the ServiceUnavailable currently being handled.
        raise if attempts >= MAX_RETRY

        logger.debug('Service unavailable, retrying...')
        attempts += 1
        sleep 1
        retry
      end
    end
  end
end

Version data entries

1 entries across 1 versions & 1 rubygems

Version Path
hugging-face-0.2.0 lib/hugging_face/inference_api.rb