# The MIT License
# Copyright (c) 2007 Ilya Grigorik
begin
  require 'graph/graphviz_dot'
rescue LoadError
  STDERR.puts "graph/graphviz_dot not installed, graphing functionality not included."
end

# Helpers for arrays of training examples, where each example is a row of
# attribute values followed by a 0/1 label: [v1, v2, ..., label].
class Array
  # All labels (the last element of every row).
  def classification; collect { |v| v.last }; end

  # Count of positive (label == 1) examples.
  def count_p; select { |v| v.last == 1 }.size; end

  # Count of negative (label == 0) examples.
  def count_n; select { |v| v.last == 0 }.size; end
end

module DecisionTree
  # ID3 decision tree for binary (0/1) classification. Supports
  # :discrete attributes (branch on every observed value) and
  # :continuous attributes (branch on a numeric threshold).
  class ID3Tree
    # A split point in the learned tree.
    #   attribute - name of the attribute chosen at this node
    #   threshold - numeric split value (:continuous) or the attribute's
    #               index as returned by the fitness function (:discrete)
    #   gain      - information gain achieved by the split
    Node = Struct.new(:attribute, :threshold, :gain)

    # attributes - ordered attribute names, aligned with data row columns
    # data       - training rows: [v1, v2, ..., label(0|1)]
    # default    - classification returned for an empty partition
    # type       - :discrete or :continuous
    def initialize(attributes, data, default, type)
      @used, @tree, @type = {}, {}, type
      @data, @attributes, @default = data, attributes, default
    end

    # Recursively build the tree and store it in @tree. Returns either a
    # bare classification (Integer leaf) or a nested Hash of
    # {Node => {branch_key => subtree-or-classification}}.
    def train(data = @data, attributes = @attributes, default = @default)
      # Choose the gain measure appropriate for the attribute type.
      case @type
      when :discrete   then fitness = proc { |a, b, c| id3_discrete(a, b, c) }
      when :continuous then fitness = proc { |a, b, c| id3_continuous(a, b, c) }
      end

      return default if data.empty?

      # All remaining examples agree -- that shared label is the leaf.
      return data.first.last if data.classification.uniq.size == 1

      # Score every attribute and split on the one with the highest gain.
      performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) }
      max = performance.max { |a, b| a[0] <=> b[0] }
      best = Node.new(attributes[performance.index(max)], max[1], max[0])

      # Remember the chosen threshold so id3_continuous won't retry it.
      (@used[best.attribute] ||= []) << best.threshold

      tree, branch_labels = { best => {} }, ['gt', 'lt']
      case @type
      when :continuous
        # partition yields [rows above threshold, rows at-or-below], which
        # lines up positionally with the 'gt'/'lt' branch labels.
        data.partition { |d| d[attributes.index(best.attribute)] > best.threshold }.each_with_index do |examples, i|
          tree[best][branch_labels[i]] = train(examples, attributes, (data.classification.mode rescue 0))
        end
      when :discrete
        values = data.collect { |d| d[attributes.index(best.attribute)] }.uniq.sort
        partitions = values.collect { |val| data.select { |d| d[attributes.index(best.attribute)] == val } }
        # NOTE(review): subtracting values[i] (an attribute VALUE) from the
        # attribute-name list is almost always a no-op; removing
        # best.attribute instead would shift the positional indices used to
        # address data columns, so the original expression is kept as-is.
        partitions.each_with_index do |examples, i|
          tree[best][values[i]] = train(examples, attributes - [values[i]], (data.classification.mode rescue 0))
        end
      end

      @tree = tree
    end

    # ID3 gain for a continuous attribute: evaluate the midpoint between
    # each pair of successive observed values and return the best
    # [gain, threshold] pair (nil when no unused thresholds remain).
    def id3_continuous(data, attributes, attribute)
      values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
      thresholds = []
      values.each_index do |i|
        # Midpoint of consecutive values; the final entry degenerates to
        # the last value itself.
        thresholds.push((values[i] + (values[i + 1].nil? ? values[i] : values[i + 1])).to_f / 2)
      end
      thresholds -= @used[attribute] if @used.has_key? attribute

      thresholds.collect { |threshold|
        sp = data.partition { |d| d[attributes.index(attribute)] > threshold }
        pos = sp[0].size.to_f / data.size
        neg = sp[1].size.to_f / data.size
        [entropy_num(data.count_p, data.count_n) -
           pos * entropy_num(sp[0].count_p, sp[0].count_n) -
           neg * entropy_num(sp[1].count_p, sp[1].count_n), threshold]
      }.max { |a, b| a[0] <=> b[0] }
    end

    # ID3 gain for a discrete attribute. Returns [gain, attribute_index].
    def id3_discrete(data, attributes, attribute)
      values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
      partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
      remainder = partitions.collect { |p|
        (p.size.to_f / data.size) * entropy_num(p.count_p, p.count_n)
      }.inject(0) { |sum, r| sum + r }

      [entropy_num(data.count_p, data.count_n) - remainder, attributes.index(attribute)]
    end

    # Entropy computed from raw positive/negative example counts.
    def entropy_num(p, n); entropy(p.to_f / (p + n), n.to_f / (p + n)); end

    # Binary entropy of the (p, n) distribution, guarding against NaN
    # (empty partitions yield 0/0) and log(0) on one-sided splits.
    def entropy(p, n)
      p = 0 if p.nan?
      n = 0 if n.nan?
      if n < 0.00000001 and p < 0.00000001
        return 0
      elsif p < 0.00000001
        return -n.to_f / (p + n) * Math.log(n.to_f / (p + n)) / Math.log(2.0)
      elsif n < 0.00000001
        return -p.to_f / (p + n) * Math.log(p.to_f / (p + n)) / Math.log(2.0)
      end
      (-p.to_f / (p + n)) * Math.log(p.to_f / (p + n)) / Math.log(2.0) +
        (-n.to_f / (p + n)) * Math.log(n.to_f / (p + n)) / Math.log(2.0)
    end

    # Classify a single example row (attribute values only, no label).
    def predict(test)
      @type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test)
    end

    # Render the learned tree to <filename>.png via graphviz.
    # BUGFIX: the output path now interpolates the filename argument;
    # previously it was a literal string that ignored the parameter.
    def graph(filename)
      dgp = DotGraphPrinter.new(build_tree)
      dgp.write_to_file("#{filename}.png", "png")
    end

    private

    def descend_continuous(tree, test)
      attr = tree.to_a.first
      # NOTE(review): training splits with '>' but prediction routes the
      # 'gt' branch with '>='; examples exactly at a threshold may go to a
      # different branch than during training. Preserved as-is.
      return attr[1]['gt'] if attr[1]['gt'].is_a?(Integer) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
      return attr[1]['lt'] if attr[1]['lt'].is_a?(Integer) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
      return descend_continuous(attr[1]['gt'], test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
      return descend_continuous(attr[1]['lt'], test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
    end

    def descend_discrete(tree, test)
      attr = tree.to_a.first
      # Leaf reached when the branch value is already a classification.
      return attr[1][test[@attributes.index(attr[0].attribute)]] if attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Integer)
      descend_discrete(attr[1][test[@attributes.index(attr[0].attribute)]], test)
    end

    # Flatten the tree into [parent, child] link pairs for DotGraphPrinter.
    # Leaf labels embed an object_id so identical classifications render as
    # distinct graph nodes.
    def build_tree(tree = @tree, root = nil)
      return [[root, "#{tree == 1 ? 'true' : 'false'} \n (#{String.new(tree.to_s).object_id})"]] if tree.is_a?(Integer)

      attr = tree.to_a.first
      mid = root.nil? ? [] : [[root, attr[0].attribute]]
      links = mid + attr[1].keys.collect { |key| [attr[0].attribute, key] }
      attr[1].keys.each { |key| links += build_tree(attr[1][key], key) }
      links
    end
  end
end