lib/decisiontree/id3_tree.rb in decisiontree-0.1.0 vs lib/decisiontree/id3_tree.rb in decisiontree-0.2.0
- old
+ new
@@ -1,19 +1,33 @@
#The MIT License
###Copyright (c) 2007 Ilya Grigorik <ilya AT fortehost DOT com>
+###Modifed at 2007 by José Ignacio Fernández <joseignacio.fernandez AT gmail DOT com>
begin;
require 'graph/graphviz_dot'
rescue LoadError
STDERR.puts "graph/graphviz_dot not installed, graphing functionality not included."
end
class Array
def classification; collect { |v| v.last }; end
- def count_p; select { |v| v.last == 1 }.size; end
- def count_n; select { |v| v.last == 0 }.size; end
+
+ # calculate Information entropy
+ def entropy
+ return 0 if empty?
+
+ info = {}
+ total = 0
+ each {|i| info[i] = !info[i] ? 1 : (info[i] + 1); total += 1}
+
+ result = 0
+ info.each do |symbol, count|
+ result += -count.to_f/total*Math.log(count.to_f/total)/Math.log(2.0) if (count > 0)
+ end
+ result
+ end
end
module DecisionTree
class ID3Tree
Node = Struct.new(:attribute, :threshold, :gain)
@@ -28,19 +42,20 @@
when :discrete; fitness = proc{|a,b,c| id3_discrete(a,b,c)}
when :continuous; fitness = proc{|a,b,c| id3_continuous(a,b,c)}
end
return default if data.empty?
+
# return classification if all examples have the same classification
return data.first.last if data.classification.uniq.size == 1
# Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
performance = attributes.collect { |attribute| fitness.call(data, attributes, attribute) }
max = performance.max { |a,b| a[0] <=> b[0] }
best = Node.new(attributes[performance.index(max)], max[1], max[0])
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
- tree, l = {best => {}}, ['gt', 'lt']
+ tree, l = {best => {}}, ['>=', '<']
case @type
when :continuous
data.partition { |d| d[attributes.index(best.attribute)] > best.threshold }.each_with_index { |examples, i|
tree[best][String.new(l[i])] = train(examples, attributes, (data.classification.mode rescue 0), &fitness)
@@ -65,68 +80,69 @@
gain = thresholds.collect { |threshold|
sp = data.partition { |d| d[attributes.index(attribute)] > threshold }
pos = (sp[0].size).to_f / data.size
neg = (sp[1].size).to_f / data.size
- [entropy_num(data.count_p, data.count_n) - pos*entropy_num(sp[0].count_p, sp[0].count_n) - neg*entropy_num(sp[1].count_p, sp[1].count_n), threshold]
+ [data.classification.entropy - pos*sp[0].classification.entropy - neg*sp[1].classification.entropy, threshold]
}.max { |a,b| a[0] <=> b[0] }
end
# ID3 for discrete label cases
def id3_discrete(data, attributes, attribute)
values = data.collect { |d| d[attributes.index(attribute)] }.uniq.sort
partitions = values.collect { |val| data.select { |d| d[attributes.index(attribute)] == val } }
- remainder = partitions.collect {|p| (p.size.to_f / data.size) * entropy_num(p.count_p, p.count_n)}.inject(0) {|i,s| s+=i }
+ remainder = partitions.collect {|p| (p.size.to_f / data.size) * p.classification.entropy}.inject(0) {|i,s| s+=i }
- [entropy_num(data.count_p, data.count_n) - remainder, attributes.index(attribute)]
+ [data.classification.entropy - remainder, attributes.index(attribute)]
end
- # calculate information based on number of positive and negative classifications
- def entropy_num(p,n); entropy(p.to_f/(p+n),n.to_f/(p+n)); end
-
- # calculate Information based on probabilities
- def entropy(p, n)
- p = 0 if p.nan?
- n = 0 if n.nan?
-
- if(n < 0.00000001 and p < 0.00000001); return 0
- elsif (p < 0.00000001); return - n.to_f/(p+n)*Math.log(n.to_f/(p+n))/Math.log(2.0)
- elsif (n < 0.00000001); return - p.to_f/(p+n)*Math.log(p.to_f/(p+n))/Math.log(2.0)
- end
-
- return (- p.to_f/(p+n)) * Math.log(p.to_f/(p+n))/Math.log(2.0) + (- n.to_f/(p+n)) * Math.log(n.to_f/(p+n))/Math.log(2.0)
+ def predict(test)
+ @type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree, test)
end
- def predict(test); @type == :discrete ? descend_discrete(@tree, test) : descend_continuous(@tree,test); end
-
def graph(filename)
dgp = DotGraphPrinter.new(build_tree)
dgp.write_to_file("#{filename}.png", "png")
end
private
def descend_continuous(tree, test)
attr = tree.to_a.first
- return attr[1]['gt'] if attr[1]['gt'].is_a?(Integer) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
- return attr[1]['lt'] if attr[1]['lt'].is_a?(Integer) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
- return descend_continuous(attr[1]['gt'],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
- return descend_continuous(attr[1]['lt'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+ return @default if !attr
+ return attr[1]['>='] if !attr[1]['>='].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+ return attr[1]['<'] if !attr[1]['<'].is_a?(Hash) and test[@attributes.index(attr.first.attribute)] < attr.first.threshold
+ return descend_continuous(attr[1]['>='],test) if test[@attributes.index(attr.first.attribute)] >= attr.first.threshold
+ return descend_continuous(attr[1]['<'],test) if test[@attributes.index(attr.first.attribute)] < attr.first.threshold
end
- def descend_discrete(tree,test)
- attr = tree.to_a.first
- return attr[1][test[@attributes.index(attr[0].attribute)]] if attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Integer)
+ def descend_discrete(tree, test)
+ attr = tree.to_a.first
+ return @default if !attr
+ return attr[1][test[@attributes.index(attr[0].attribute)]] if !attr[1][test[@attributes.index(attr[0].attribute)]].is_a?(Hash)
return descend_discrete(attr[1][test[@attributes.index(attr[0].attribute)]],test)
end
- def build_tree(tree = @tree, root = nil)
- return [[root, "#{tree == 1 ? 'true' : 'false'} \n (#{String.new(tree.to_s).object_id})"]] if tree.is_a?(Integer)
-
+ def build_tree(tree = @tree)
+ return [] unless tree.is_a?(Hash)
+ return [["Always", @default]] if tree.empty?
+
attr = tree.to_a.first
- mid = root.nil? ? [] : [[root, attr[0].attribute]]
- links = mid + attr[1].keys.collect { |key| [attr[0].attribute, key] }
- attr[1].keys.each { |key| links += build_tree(attr[1][key], key) }
-
+
+ links = attr[1].keys.collect do |key|
+ parent_text = "#{attr[0].attribute}\n(#{attr[0].object_id})"
+ if attr[1][key].is_a?(Hash) then
+ child = attr[1][key].to_a.first[0]
+ child_text = "#{child.attribute}\n(#{child.object_id})"
+ else
+ child = attr[1][key]
+ child_text = "#{child}\n(#{child.to_s.object_id})"
+ end
+ label_text = "#{key} #{@type == :continuous ? attr[0].threshold : ""}"
+
+ [parent_text, child_text, label_text]
+ end
+ attr[1].keys.each { |key| links += build_tree(attr[1][key]) }
+
return links
end
end
end