lib/tomoto/lda.rb in tomoto-0.1.1 vs lib/tomoto/lda.rb in tomoto-0.1.2
--- tomoto-0.1.1 (old)
+++ tomoto-0.1.2 (new)
@@ -3,11 +3,11 @@
def self.new(tw: :one, min_cf: 0, min_df: 0, rm_top: 0, k: 1, alpha: 0.1, eta: 0.01, seed: nil)
model = _new(to_tw(tw), k, alpha, eta, seed || -1)
model.instance_variable_set(:@min_cf, min_cf)
model.instance_variable_set(:@min_df, min_df)
model.instance_variable_set(:@rm_top, rm_top)
- model
+ init_params(model, binding)
end
def self.load(filename)
model = new
model._load(filename)
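
With the constructor change above, the keyword arguments passed to new are now captured on the model (via init_params, defined in the last hunk below) so summary can report them later. A rough illustration, assuming the signature shown above; each value is stored as its .inspect string:

model = Tomoto::LDA.new(k: 5, alpha: 0.2)
model.instance_variable_get(:@init_params)
# => {:tw=>":one", :min_cf=>"0", :min_df=>"0", :rm_top=>"0", :k=>"5", :alpha=>"0.2", :eta=>"0.01", :seed=>"nil"}
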
@@ -30,21 +30,57 @@
def save(filename, full: true)
_save(filename, full)
end
+ # returns string instead of printing
+ def summary(initial_hp: true, params: true, topic_word_top_n: 5)
+ summary = []
+
+ summary << "<Basic Info>"
+ basic_info(summary)
+ summary << "|"
+
+ summary << "<Training Info>"
+ training_info(summary)
+ summary << "|"
+
+ if initial_hp
+ summary << "<Initial Parameters>"
+ initial_params_info(summary)
+ summary << "|"
+ end
+
+ if params
+ summary << "<Parameters>"
+ params_info(summary)
+ summary << "|"
+ end
+
+ if topic_word_top_n > 0
+ summary << "<Topics>"
+ topics_info(summary, topic_word_top_n: topic_word_top_n)
+ summary << "|"
+ end
+
+ # skip ending |
+ summary.pop
+
+ summary.join("\n")
+ end
+
def topic_words(topic_id = nil, top_n: 10)
if topic_id
_topic_words(topic_id, top_n)
else
k.times.map { |i| _topic_words(i, top_n) }
end
end
- def train(iterations = 10, workers: 0)
+ def train(iterations = 10, workers: 0, parallel: :default)
prepare
- _train(iterations, workers)
+ _train(iterations, workers, to_ps(parallel))
end
def tw
TERM_WEIGHT[_tw]
end
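
Taken together, a minimal usage sketch of the new summary method and the parallel: keyword (documents and iteration counts are illustrative; :default is the scheme already used as the keyword default above):

model = Tomoto::LDA.new(k: 2)
model.add_doc("the quick brown fox jumps over the lazy dog")
model.add_doc("lorem ipsum dolor sit amet")
model.train(100, workers: 2, parallel: :default)
# summary returns a string rather than printing, so print it explicitly
puts model.summary(topic_word_top_n: 3)
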
@@ -62,14 +98,70 @@
raise "cannot add_doc() after train()" if defined?(@prepared)
doc = doc.split(/[[:space:]]+/) unless doc.is_a?(Array)
doc
end
+ def basic_info(summary)
+ sum = used_vocab_freq.sum.to_f
+ mapped = used_vocab_freq.map { |v| v / sum }
+ entropy = mapped.map { |v| v * Math.log(v) }.sum
+
+ summary << "| #{self.class.name.sub("Tomoto::", "")} (current version: #{VERSION})"
+ summary << "| #{num_docs} docs, #{num_words} words"
+ summary << "| Total Vocabs: #{vocabs.size}, Used Vocabs: #{used_vocabs.size}"
+ summary << "| Entropy of words: %.5f" % entropy
+ summary << "| Removed Vocabs: #{removed_top_words.any? ? removed_top_words.join(" ") : "<NA>"}"
+ end
+
+ def training_info(summary)
+ summary << "| Iterations: #{global_step}, Burn-in steps: #{burn_in}"
+ summary << "| Optimization Interval: #{optim_interval}"
+ summary << "| Log-likelihood per word: %.5f" % ll_per_word
+ end
+
+ def initial_params_info(summary)
+ if defined?(@init_params)
+ @init_params.each do |k, v|
+ summary << "| #{k}: #{v}"
+ end
+ else
+ summary << "| Not Available"
+ end
+ end
+
+ def params_info(summary)
+ summary << "| alpha (Dirichlet prior on the per-document topic distributions)"
+ summary << "| #{alpha}"
+ summary << "| eta (Dirichlet prior on the per-topic word distribution)"
+ summary << "| %.5f" % eta
+ end
+
+ def topics_info(summary, topic_word_top_n:)
+ counts = count_by_topics
+ topic_words(top_n: topic_word_top_n).each_with_index do |words, i|
+ summary << "| ##{i} (#{counts[i]}) : #{words.keys.join(" ")}"
+ end
+ end
+
+ def to_ps(ps)
+ PARALLEL_SCHEME.index(ps) || (raise ArgumentError, "Invalid parallel scheme: #{ps}")
+ end
+
class << self
private
def to_tw(tw)
TERM_WEIGHT.index(tw) || (raise ArgumentError, "Invalid tw: #{tw}")
+ end
+
+ def init_params(model, binding)
+ init_params = {}
+ method(:new).parameters.each do |v|
+ next if v[0] != :key
+ init_params[v[1]] = binding.local_variable_get(v[1]).inspect
+ end
+ model.instance_variable_set(:@init_params, init_params)
+ model
end
end
end
end
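
For reference, the string built by summary above has this shape: section headers in angle brackets, detail lines prefixed with "| ", and a bare "|" separating sections (the trailing separator is popped before joining). Values in <...> are placeholders, not real output:

<Basic Info>
| LDA (current version: 0.1.2)
| <num_docs> docs, <num_words> words
| Total Vocabs: <n>, Used Vocabs: <n>
| Entropy of words: <float>
| Removed Vocabs: <NA>
|
<Training Info>
| Iterations: <global_step>, Burn-in steps: <burn_in>
| Optimization Interval: <optim_interval>
| Log-likelihood per word: <float>
|
<Initial Parameters>
| tw: :one
| k: <k>
| ...
|
<Parameters>
| alpha (Dirichlet prior on the per-document topic distributions)
| <alpha value(s)>
| eta (Dirichlet prior on the per-topic word distribution)
| <float>
|
<Topics>
| #0 (<word count>) : <top words>
| #1 (<word count>) : <top words>
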