Sha256: 8caf19195339ba0587feca3a0ec4fd6d064930b2bd7d13fc0c847f4792e06656

Contents?: true

Size: 1.41 KB

Versions: 4

Compression:

Stored size: 1.41 KB

Contents

require File.join(File.expand_path(File.dirname(__FILE__)),'../../../..', 'test_helper.rb')
require 'rbbt/vector/model/huggingface/masked_lm'

# Integration test for MaskedLMModel: fine-tunes a masked language model on
# sentences containing a new "[GENE]" token, then checks the learned
# predictions on the trained model and on two models constructed afterwards
# over the same model directory (via MaskedLMModel and via VectorModel).
#
# NOTE(review): the checkpoint name looks like a HuggingFace hub id, so this
# presumably needs network access and the Python transformers stack — confirm.
class TestMaskedLM < Test::Unit::TestCase
  def test_train_new_word
    TmpFile.with_file do |model_dir|
      checkpoint_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
      learner = MaskedLMModel.new(checkpoint_name, model_dir,
                                  tokenizer_args: { max_length: 16, model_max_length: 16 })

      hf_model, hf_tokenizer = learner.init
      register_gene_token(hf_model, hf_tokenizer)

      # Each round feeds the same two labelled examples, in this order.
      # NOTE(review): this sentence says "on tumor cells" while the query in
      # check_learned_predictions says "in tumor cells" — confirm intentional.
      100.times do
        learner.add("This [GENE] is [MASK] on tumor cells.", ["expressed"])
        learner.add("This [MASK] is expressed.", ["[GENE]"])
      end

      # Before training, the model is expected to predict "protein" here.
      assert_equal "protein", learner.eval(["This [MASK] is expressed."])

      learner.train
      check_learned_predictions(learner)

      # The assertions expect the trained state to be picked up from model_dir.
      # NOTE(review): unlike the first constructor call, max_length is passed
      # as a top-level option here rather than inside tokenizer_args — confirm
      # MaskedLMModel honours it in this position.
      check_learned_predictions(MaskedLMModel.new(checkpoint_name, model_dir, max_length: 16))

      check_learned_predictions(VectorModel.new(model_dir))
    end
  end

  private

  # Adds "[GENE]" to the tokenizer vocabulary when it is absent, then resizes
  # the model's token embeddings to the tokenizer's new length.
  def register_gene_token(model, tokenizer)
    return unless tokenizer.vocab["[GENE]"].nil?

    tokenizer.add_tokens("[GENE]")
    model.resize_token_embeddings(tokenizer.__len__)
  end

  # Asserts that +model+ fills both masked queries with the trained answers.
  def check_learned_predictions(model)
    assert_equal "[GENE]", model.eval(["This [MASK] is expressed."])
    assert_equal "expressed", model.eval(["This [GENE] is [MASK] in tumor cells."])
  end
end

Version data entries

4 entries across 4 versions & 1 rubygem

Version Path
rbbt-dm-1.3.2 test/rbbt/vector/model/huggingface/test_masked_lm.rb
rbbt-dm-1.3.0 test/rbbt/vector/model/huggingface/test_masked_lm.rb
rbbt-dm-1.2.10 test/rbbt/vector/model/huggingface/test_masked_lm.rb
rbbt-dm-1.2.9 test/rbbt/vector/model/huggingface/test_masked_lm.rb