lib/decisiontree/id3_tree.rb in decisiontree-0.1.0 vs lib/decisiontree/id3_tree.rb in decisiontree-0.2.0

- old
+ new

@@ -1,19 +1,33 @@ #The MIT License ###Copyright (c) 2007 Ilya Grigorik <ilya AT fortehost DOT com> +###Modified in 2007 by José Ignacio Fernández <joseignacio.fernandez AT gmail DOT com> begin; require 'graph/graphviz_dot' rescue LoadError STDERR.puts "graph/graphviz_dot not installed, graphing functionality not included." end class Array def classification; collect { |v| v.last }; end - def count_p; select { |v| v.last == 1 }.size; end - def count_n; select { |v| v.last == 0 }.size; end + + # calculate Information entropy + def entropy + return 0 if empty? + + info = {} + total = 0 + each {|i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1} + + result = 0 + info.each do |symbol, count| + result += -count.to_f/total*Math.log(count.to_f/total)/Math.log(2.0) if (count > 0) + end + result + end end module DecisionTree class ID3Tree Node = Struct.new(:attribute, :threshold, :gain) @@ -28,19 +42,20 @@ when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)} when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)} end return default if data.empty? + # return classification if all examples have the same classification return data.first.last if data.classification.uniq.size == 1 # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute) performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) } max = performance.max { |a,b| a[0] <=> b[0] } best = Node.new(attributes[performance.index(max)], max[1], max[0]) @used.has_key?(best.attribute) ? 
@used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold] - tree, l = {best => {}}, ['gt', 'lt'] + tree, l = {best => {}}, ['>=', '<'] case @type when :continuous data.partition { |d| d[attributes.index(best.attribute)] > best.threshold }.each_with_index { |examples, i| tree[best][String.new(l[i])] = train(examples, attributes, (data.classification.mode rescue 0), &fitness) @@ -65,68 +80,69 @@ gain = thresholds.collect { |threshold| sp = data.partition { |d| d[attributes.index(attribute)] > threshold } pos = (sp[0].size).to_f / data.size neg = (sp[1].size).to_f / data.size - [entropy_num(data.count_p, data.count_n) - pos*entropy_num(sp[0].count_p, sp[0].count_n) - neg*entropy_num(sp[1].count_p, sp[1].count_n), threshold] + [data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold] }.max { |a,b| a[0] <=> b[0] } end # ID3 for discrete label cases def id3_discrete(data, attributes, attribute) values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } } - remainder = partitions.collect {|p| (p.size.to_f / data.size) * entropy_num(p.count_p, p.count_n)}.inject(0) {|i,s| s+=i } + remainder = partitions.collect {|p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) {|i,s| s+=i } - [entropy_num(data.count_p, data.count_n) - remainder, attributes.index(attribute)] + [data.classification.entropy - remainder, attributes.index(attribute)] end - # calculate information based on number of positive and negative classifications - def entropy_num(p,n); entropy(p.to_f/(p+n),n.to_f/(p+n)); end - - # calculate Information based on probabilities - def entropy(p, n) - p = 0 if p.nan? - n = 0 if n.nan? 
- - if(n < 0.00000001 and p < 0.00000001); return 0 - elsif (p < 0.00000001); return - n.to_f/(p+n)*Math.log(n.to_f/(p+n))/Math.log(2.0) - elsif (n < 0.00000001); return - p.to_f/(p+n)*Math.log(p.to_f/(p+n))/Math.log(2.0) - end - - return (- p.to_f/(p+n)) * Math.log(p.to_f/(p+n))/Math.log(2.0) + (- n.to_f/(p+n)) * Math.log(n.to_f/(p+n))/Math.log(2.0) + def predict(test) + @type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test) end - def predict(test); @type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree,test); end - def graph(filename) dgp = DotGraphPrinter.new(build_tree) dgp.write_to_file("#{filename}.png", "png") end private def descend_continuous(tree, test) attr = tree.to_a.first - return attr[1]['gt'] if attr[1]['gt'].is_a?(Integer) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold - return attr[1]['lt'] if attr[1]['lt'].is_a?(Integer) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold - return descend_continuous(attr[1]['gt'],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold - return descend_continuous(attr[1]['lt'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold + return @default if !attr + return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold + return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold + return descend_continuous(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold + return descend_continuous(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold end - def descend_discrete(tree,test) - attr = tree.to_a.first - return attr[1][test[@attributes.index(attr[0].attribute)]] if attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Integer) + def descend_discrete(tree, test) + attr 
= tree.to_a.first + return @default if !attr + return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash) return descend_discrete(attr[1][test[@attributes.index(attr[0].attribute)]],test) end - def build_tree(tree = @tree, root = nil) - return [[root, "#{tree == 1 ? 'true' : 'false'} \n (#{String.new(tree.to_s).object_id})"]] if tree.is_a?(Integer) - + def build_tree(tree = @tree) + return [] unless tree.is_a?(Hash) + return [["Always", @default]] if tree.empty? + attr = tree.to_a.first - mid = root.nil? ? [] : [[root, attr[0].attribute]] - links = mid + attr[1].keys.collect { |key| [attr[0].attribute, key] } - attr[1].keys.each { |key| links += build_tree(attr[1][key], key) } - + + links = attr[1].keys.collect do |key| + parent_text = "#{attr[0].attribute}\n(#{attr[0].object_id})" + if attr[1][key].is_a?(Hash) then + child = attr[1][key].to_a.first[0] + child_text = "#{child.attribute}\n(#{child.object_id})" + else + child = attr[1][key] + child_text = "#{child}\n(#{child.to_s.object_id})" + end + label_text = "#{key} #{@type == :continuous ? attr[0].threshold : ""}" + + [parent_text, child_text, label_text] + end + attr[1].keys.each { |key| links += build_tree(attr[1][key]) } + return links end end end