Sha256: eef0a14d26b693ce5d9896bb8ce1eea3c4f045a22761416c830fe1880332c0e7

Contents?: true

Size: 1.96 KB

Versions: 4

Compression:

Stored size: 1.96 KB

Contents

require 'rbbt/vector/model/huggingface'
# HuggingfaceModel specialization for the "MaskedLM" task.
#
# Training input is a list of texts containing "[MASK]" placeholders and, for
# each text, the list of label tokens that fill those placeholders in order.
class MaskedLMModel < HuggingfaceModel

  # Sets up the model and registers the training procedure.
  #
  # checkpoint    - forwarded to HuggingfaceModel#initialize.
  # dir           - forwarded to HuggingfaceModel#initialize (default: nil).
  # model_options - options Hash; :max_length is defaulted to 128 when absent.
  #                 :training_args and :class_weights are read at training time.
  def initialize(checkpoint, dir = nil, model_options = {})

    model_options = Misc.add_defaults model_options, :max_length => 128
    super("MaskedLM", checkpoint, dir, model_options)

    train_model do |texts,labels|
      model, tokenizer = self.init
      max_length = @model_options[:max_length]
      mask_id = tokenizer.mask_token_id

      dataset = []
      texts.zip(labels).each do |text,label_values|
        # Park the user's masks under a temporary marker so that each one can
        # be expanded, in order, to as many "[MASK]" tokens as its label needs.
        fixed_text = text.gsub("[MASK]", "[PENDINGMASK]")
        label_tokens = label_values.collect{|label| tokenizer.convert_tokens_to_ids(label) }
        label_tokens.each do |ids|
          ids = [ids] unless Array === ids
          fixed_text.sub!("[PENDINGMASK]", "[MASK]" * ids.length)
        end

        # Pad and truncate to the configured :max_length. Without an explicit
        # max_length the option set in #initialize would never be applied.
        tokenized_text = tokenizer.call(fixed_text, truncation: true, padding: "max_length", max_length: max_length)
        input_ids = tokenized_text["input_ids"].to_a
        attention_mask = tokenized_text["attention_mask"].to_a

        # Mask positions take the next pending label id; every other position
        # is labelled -100.
        all_label_tokens = label_tokens.flatten
        label_ids = input_ids.collect do |id|
          if id == mask_id
            all_label_tokens.shift
          else
            -100
          end
        end
        dataset << {input_ids: input_ids, labels: label_ids, attention_mask: attention_mask}
      end

      # The dataset is handed over to the Python side as a file with one JSON
      # record per line.
      dataset_file = File.join(@directory, 'dataset.json')
      Open.write(dataset_file, dataset.collect{|e| e.to_json} * "\n")

      training_args_obj = RbbtPython.call_method("rbbt_dm.huggingface", :training_args, @model_path, @model_options[:training_args])
      data_collator = RbbtPython.class_new_obj("transformers", "DefaultDataCollator", {})
      RbbtPython.call_method("rbbt_dm.huggingface", :train_model, model, tokenizer, training_args_obj, dataset_file, @model_options[:class_weights], data_collator: data_collator)

      # Persist the trained model and its tokenizer when a model path is set.
      if @model_path
        model.save_pretrained(@model_path)
        tokenizer.save_pretrained(@model_path)
      end
    end

  end
end

Version data entries

4 entries across 4 versions & 1 rubygems

Version Path
rbbt-dm-1.3.2 lib/rbbt/vector/model/huggingface/masked_lm.rb
rbbt-dm-1.3.0 lib/rbbt/vector/model/huggingface/masked_lm.rb
rbbt-dm-1.2.10 lib/rbbt/vector/model/huggingface/masked_lm.rb
rbbt-dm-1.2.9 lib/rbbt/vector/model/huggingface/masked_lm.rb