Sha256: d7bf5b1e512670fcc0de57b9bfb4be0058ac8d3bc41e828d0096b48fcec9dbe4
Contents?: true
Size: 1.57 KB
Versions: 1
Compression:
Stored size: 1.57 KB
Contents
require 'faraday' module HuggingFace class InterfaceApi HOST = "https://api-inference.huggingface.co" MAX_RETRY = 2 HTTP_SEVICE_UNAVAILABLE = 503 QUESTION_ANSWERING_MODEL = 'distilbert-base-cased-distilled-squad' SUMMARIZATION_MODEL = "sshleifer/distilbart-xsum-12-6" GENERATION_MODEL = "distilgpt2" def initialize(api_token:) @headers = { 'Authorization' => 'Bearer ' + api_token, 'Content-Type' => 'application/json' } end def call(input:, model:) request(connection: connection(model), input: input) end def question_answering(question:, context:, model: QUESTION_ANSWERING_MODEL) input = { question: question, context: context } request(connection: connection(model), input: input) end def summarization(input:, model: SUMMARIZATION_MODEL) request(connection: connection(model), input: { inputs: input }) end def text_generation(input:, model: GENERATION_MODEL) request(connection: connection(model), input: { inputs: input }) end private def request(connection:, input:) retries = 0 while retries < MAX_RETRY response = connection.post { |req| req.body = input.to_json } break if response.success? if response.status == HTTP_SEVICE_UNAVAILABLE retries += 1 sleep 1 redo end raise "Error: #{response.body}" end return JSON.parse(response.body) end def connection(model) Faraday.new(url: "#{HOST}/models/#{model}" , headers: @headers) end end end
Version data entries
1 entries across 1 versions & 1 rubygems
Version | Path |
---|---|
hugging-face-0.1.0 | lib/hugging_face/interface_api.rb |