Sha256: abc37c5fad88a6aa3aa0a9c85c62489d2830ee8e5300be8bb473c2881c6b04fa

Contents?: true

Size: 1.94 KB

Versions: 1

Compression:

Stored size: 1.94 KB

Contents

module FastText
  # Supervised text classifier on top of the native fastText extension
  # (+Ext::Model+). Construction and the +m+ / +prep_text+ helpers come from
  # the +Model+ superclass, which is defined elsewhere.
  class Classifier < Model
    # Default training options. Options supplied by the caller (+@options+)
    # are merged over these in #fit. NOTE(review): keys appear to mirror
    # fastText's training arguments — confirm against Ext::Model#train.
    DEFAULT_OPTIONS = {
      lr: 0.1,
      lr_update_rate: 100,
      dim: 100,
      ws: 5,
      epoch: 5,
      min_count: 1,
      min_count_label: 0,
      neg: 5,
      word_ngrams: 1,
      loss: "softmax",
      model: "supervised",
      bucket: 2_000_000,
      minn: 0,
      maxn: 0,
      thread: 3,
      t: 0.0001,
      label_prefix: "__label__",
      verbose: 2,
      pretrained_vectors: "",
      save_output: false,
      # seed: 0
    }.freeze

    # Trains a supervised model, creating the native model on first use.
    #
    # @param x [String, Array<String>] path to a fastText-formatted file, or
    #   an array of documents
    # @param y [Array, nil] one label (or array of labels) per document;
    #   must be nil when +x+ is a file path
    # @return the value returned by the extension's +train+ call
    # @raise [ArgumentError] if +y+ is given with a file path, or is missing
    #   or of a different length than an in-memory +x+
    def fit(x, y = nil)
      options = DEFAULT_OPTIONS.merge(@options)
      # Label in-memory documents with the prefix the model is about to be
      # trained with, so a custom :label_prefix option is honored.
      with_input_path(x, y, prefix: options[:label_prefix]) do |input|
        @m ||= Ext::Model.new
        m.train(options.merge(input: input, model: "supervised"))
      end
    end

    # Predicts labels for a single document.
    #
    # @param text [String] document to classify
    # @param k [Integer] maximum number of labels to return
    # @param threshold [Float] minimum score for a label to be returned
    # @return [Hash{String => Float}] label (prefix removed) => score
    #   (presumably a probability — confirm against Ext::Model#predict)
    # TODO support array of text
    def predict(text, k: 1, threshold: 0.0)
      predictions = m.predict(prep_text(text), k, threshold)
      prefix = label_prefix # look it up once, not once per label
      predictions.map do |score, label|
        [remove_prefix(label, prefix), score]
      end.to_h
    end

    # Evaluates the trained model on labeled data.
    #
    # @param x [String, Array<String>] file path or array of documents
    # @param y [Array, nil] labels for an in-memory +x+; nil for a file path
    # @param k [Integer] number of top predictions considered per example
    # @return [Hash] +:examples+, +:precision+ and +:recall+ as reported by
    #   the extension's +test+ call
    # @raise [ArgumentError] on the same invalid +x+/+y+ combinations as #fit
    def test(x, y = nil, k: 1)
      # Without an explicit prefix, in-memory data is labeled with the
      # trained model's own label prefix.
      res = with_input_path(x, y) { |input| m.test(input, k) }
      {
        examples: res[0],
        precision: res[1],
        recall: res[2]
      }
    end

    # Quantizes the underlying model in place via the extension.
    # TODO support options
    def quantize
      m.quantize({})
    end

    # Lists the labels known to the trained model, with the prefix removed.
    #
    # @param include_freq [Boolean] when true, return label => frequency
    # @return [Array<String>, Hash{String => Object}]
    def labels(include_freq: false)
      names, freqs = m.labels
      prefix = label_prefix
      names = names.map { |name| remove_prefix(name, prefix) }
      include_freq ? names.zip(freqs).to_h : names
    end

    private

    # Yields a path to fastText-formatted input and returns the block's value.
    #
    # A String +x+ is treated as an existing file path and yielded as-is.
    # Otherwise +x+/+y+ are written to a temporary file, which is deleted
    # once the block finishes. Holding the Tempfile object for the duration
    # of the block also stops its finalizer from unlinking the file while
    # native code is still reading the path.
    #
    # @param prefix [String, nil] label prefix for in-memory data; nil means
    #   use the trained model's prefix
    # @raise [ArgumentError] for invalid +x+/+y+ combinations
    def with_input_path(x, y, prefix: nil)
      if x.is_a?(String)
        raise ArgumentError, "Cannot pass y with file" if y
        return yield(x)
      end

      raise ArgumentError, "Must pass y when x is not a file path" unless y
      # Only compare sizes when both are cheaply measurable (e.g. Arrays);
      # lazy enumerables are passed through unchecked.
      if x.respond_to?(:length) && y.respond_to?(:length) && x.length != y.length
        raise ArgumentError, "x and y must be the same size"
      end

      prefix ||= label_prefix
      tempfile = Tempfile.new("fasttext")
      begin
        write_examples(tempfile, x, y, prefix)
        tempfile.close
        yield tempfile.path
      ensure
        tempfile.close! # close (if still open) and unlink
      end
    end

    # Writes one line per document: the prefixed labels, then the document
    # text with newlines replaced by spaces so it stays on a single line.
    def write_examples(io, x, y, prefix)
      x.zip(y) do |xi, yi|
        # Interpolation (rather than String#+) also accepts Integer/Symbol labels.
        parts = Array(yi).map { |label| "#{prefix}#{label}" }
        parts << xi.tr("\n", " ")
        io << parts.join(" ") << "\n"
      end
    end

    # Strips +prefix+ from the start of +label+; labels without it are
    # returned unchanged.
    def remove_prefix(label, prefix = label_prefix)
      label.start_with?(prefix) ? label[prefix.length..-1] : label
    end

    # The label prefix of the underlying native model.
    def label_prefix
      m.label_prefix
    end
  end
end

Version data entries

1 entry across 1 version & 1 rubygem

Version Path
fasttext-0.1.0 lib/fasttext/classifier.rb