Sha256: eab64dd2643536a25f1cc9605ba80ab65eec873be378c4a9cd39540e88825ba2

Contents?: true

Size: 1.63 KB

Versions: 3

Compression:

Stored size: 1.63 KB

Contents

module Ankusa
  # Positive floating-point infinity. Its negation is used below as the
  # log-likelihood of an event with probability zero.
  INFTY = 1.0 / 0.0

  # Naive Bayes text classifier. Per-word probabilities and document counts
  # are obtained through helpers that are not defined in this file
  # (get_word_probs, doc_count_totals, @storage, @classnames) -- presumably
  # supplied by the Classifier mixin; verify there.
  class NaiveBayesClassifier
    include Classifier

    # Returns the class with the highest log likelihood for +text+.
    #
    # +classes+ is an optional array of class names to restrict the decision
    # to; when nil, log_likelihoods falls back to @classnames.
    #
    # Returns nil when there is nothing to decide: either no classes were
    # scored, or every scored class has exactly the same likelihood.
    def classify(text, classes=nil)
      result = log_likelihoods(text, classes)

      # An empty result used to raise NoMethodError (nil.first); treat it the
      # same way as a tie across all classes and report "no answer".
      return nil if result.empty? || result.values.uniq.size == 1

      result.max_by { |_, likelihood| likelihood }.first
    end

    # Returns a hash mapping each class name to its normalized probability
    # for +text+.
    #
    # +classnames+ is an optional array of classes to look at.
    #
    # If every class has a log likelihood of -INFTY, nothing can be
    # normalized and every class maps to 0.
    def classifications(text, classnames=nil)
      result = log_likelihoods(text, classnames)

      # Shift by the largest log likelihood before exponentiating. The shift
      # cancels out in the normalization below, but it keeps Math.exp from
      # underflowing to 0.0 for every class when the text is long.
      peak = result.values.max
      result.keys.each do |k|
        # -INFTY is handled explicitly so that (-INFTY) - (-INFTY), which is
        # NaN, is never computed.
        result[k] = (result[k] == -INFTY) ? 0 : Math.exp(result[k] - peak)
      end

      # normalize to get probs; the explicit seed keeps the sum numeric (0)
      # instead of nil when there are no classes at all
      sum = result.values.inject(0) { |x, y| x + y }
      unless sum.zero?
        result.keys.each { |k| result[k] = result[k] / sum }
      end

      result
    end

    # Returns a hash mapping each class name to the log likelihood of +text+
    # under that class (word likelihoods plus the class prior).
    #
    # +classnames+ is an optional array of classes to look at; it defaults to
    # @classnames. Missing keys in the returned hash read as 0.
    def log_likelihoods(text, classnames=nil)
      classnames ||= @classnames
      result = Hash.new 0

      TextHash.new(text).each do |word, count|
        probs = get_word_probs(word, classnames)
        classnames.each do |k|
          # A word with zero probability for a class drives that class's log
          # likelihood to negative infinity.
          result[k] += probs[k] > 0 ? (Math.log(probs[k]) * count) : -INFTY
        end
      end

      # add the prior, with add-one smoothing: each class counts one extra
      # document, so the denominator grows by the number of classes
      doc_counts = doc_count_totals.select { |k, _| classnames.include? k }.map { |_, v| v }
      doc_count_total = (doc_counts.inject(0) { |x, y| x + y } + classnames.length).to_f

      classnames.each do |k|
        result[k] += Math.log((@storage.get_doc_count(k) + 1).to_f / doc_count_total)
      end

      result
    end

  end

end

Version data entries

3 entries across 3 versions & 1 rubygems

Version Path
ankusa-0.1.0 lib/ankusa/naive_bayes.rb
ankusa-0.0.16 lib/ankusa/naive_bayes.rb
ankusa-0.0.15 lib/ankusa/naive_bayes.rb