Sha256: e6b8e648ec9fc3dd05ef55c760cd156b21b32ac146252a562b93a9cfa788a593
Contents?: true
Size: 1.73 KB
Versions: 3
Compression:
Stored size: 1.73 KB
Contents
module HuggingFace
  # Client for the hosted Hugging Face Inference API.
  #
  # Every public method builds a connection for the chosen model and sends the
  # payload through BaseApi#request, retrying while the service reports that
  # it is unavailable.
  #
  # NOTE(review): build_connection, logger, the parent #request and the
  # ServiceUnavailable error are not defined in this file; they are presumably
  # provided by BaseApi / the HuggingFace namespace - confirm there.
  class InferenceApi < BaseApi
    HOST = "https://api-inference.huggingface.co"

    # Retry connecting to the model for 1 minute
    # (MAX_RETRY attempts with a one-second pause between them).
    MAX_RETRY = 60

    # Default models that can be overridden by the 'model' param
    QUESTION_ANSWERING_MODEL = 'distilbert-base-cased-distilled-squad'
    SUMMARIZATION_MODEL = "sshleifer/distilbart-xsum-12-6"
    GENERATION_MODEL = "distilgpt2"
    EMBEDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

    # Sends an arbitrary, caller-built payload to an arbitrary model.
    #
    # @param input [Object] payload passed through to the request unchanged
    # @param model [String] model identifier used to build the endpoint URL
    # @return whatever BaseApi#request returns
    def call(input:, model:)
      request(connection: connection(model), input: input)
    end

    # Asks +question+ about +context+.
    #
    # @param question [String]
    # @param context [String]
    # @param model [String] defaults to QUESTION_ANSWERING_MODEL
    # @return whatever BaseApi#request returns
    def question_answering(question:, context:, model: QUESTION_ANSWERING_MODEL)
      input = { question: question, context: context }

      request connection: connection(model), input: input
    end

    # Summarizes +input+.
    #
    # @param input [Object] sent to the API under the :inputs key
    # @param model [String] defaults to SUMMARIZATION_MODEL
    # @return whatever BaseApi#request returns
    def summarization(input:, model: SUMMARIZATION_MODEL)
      request connection: connection(model), input: { inputs: input }
    end

    # Generates text continuing +input+.
    #
    # @param input [Object] sent to the API under the :inputs key
    # @param model [String] defaults to GENERATION_MODEL
    # @return whatever BaseApi#request returns
    def text_generation(input:, model: GENERATION_MODEL)
      request connection: connection(model), input: { inputs: input }
    end

    # Computes embeddings for +input+ through the feature-extraction pipeline.
    #
    # @param input [Object] sent to the API under the :inputs key
    # @param model [String] defaults to EMBEDING_MODEL; any other model is
    #   also routed to the feature-extraction pipeline endpoint
    # @return whatever BaseApi#request returns
    def embedding(input:, model: EMBEDING_MODEL)
      request connection: connection(model, feature_extraction: true), input: { inputs: input }
    end

    private

    # Builds the connection for +model+.
    #
    # @param model [String] model identifier
    # @param feature_extraction [Boolean] when true the pipeline endpoint is
    #   used instead of the plain model endpoint. Defaults to true only for
    #   EMBEDING_MODEL, so #call with that model is routed as before.
    def connection(model, feature_extraction: model == EMBEDING_MODEL)
      if feature_extraction
        build_connection "#{HOST}/pipeline/feature-extraction/#{model}"
      else
        build_connection "#{HOST}/models/#{model}"
      end
    end

    # Delegates to BaseApi#request, retrying up to MAX_RETRY times (one second
    # apart) while ServiceUnavailable is raised.
    #
    # @raise [ServiceUnavailable] once the retry budget is exhausted
    def request(connection:, input:)
      retries = 0

      # The explicit begin block is required: `retry` re-runs only this block,
      # so the counter initialised above is not reset on each attempt.
      begin
        super(connection: connection, input: input)
      rescue ServiceUnavailable
        # Out of attempts: re-raise the error currently being handled.
        raise if retries >= MAX_RETRY

        logger.debug('Service unavailable, retrying...')
        retries += 1
        sleep 1
        retry
      end
    end
  end
end
Version data entries
3 entries across 3 versions & 1 rubygems
Version | Path |
---|---|
hugging-face-0.3.2 | lib/hugging_face/inference_api.rb |
hugging-face-0.3.1 | lib/hugging_face/inference_api.rb |
hugging-face-0.3.0 | lib/hugging_face/inference_api.rb |