lib/decisiontree/id3_tree.rb in decisiontree-0.4.0 vs lib/decisiontree/id3_tree.rb in decisiontree-0.5.0
- old
+ new
@@ -1,12 +1,10 @@
# The MIT License
#
### Copyright (c) 2007 Ilya Grigorik <ilya AT igvita DOT com>
### Modified in 2007 by José Ignacio Fernández <joseignacio.fernandez AT gmail DOT com>
-require 'graphr'
-
class Object
# Serializes the receiver with Marshal.dump and writes the result to
# +filename+, truncating any existing file at that path.
#
# The file is opened in binary mode: Marshal output is raw bytes, and a
# text-mode handle may apply newline translation on some platforms
# (e.g. Windows), which would make the dump unreadable by Marshal.load.
#
# filename - path of the file to (over)write.
#
# Returns the value of the block, i.e. the File object that was written to
# (already closed by the time this method returns).
# Raises whatever Marshal.dump raises for objects that cannot be dumped,
# and whatever File.open raises when the path cannot be opened for writing.
def save_to_file(filename)
  File.open(filename, 'wb') { |f| f << Marshal.dump(self) }
end
@@ -42,10 +40,11 @@
@used, @tree, @type = {}, {}, type
@data, @attributes, @default = data, attributes, default
end
def train(data=@data, attributes=@attributes, default=@default)
+ attributes = attributes.map {|e| e.to_s}
initialize(attributes, data, default, @type)
# Remove samples with same attributes leaving most common classification
data2 = data.inject({}) {|hash, d| hash[d.slice(0..-2)] ||= Hash.new(0); hash[d.slice(0..-2)][d.last] += 1; hash }.map{|key,val| key + [val.sort_by{ |k, v| v }.last.first]}
@@ -67,13 +66,18 @@
return default if data.empty?
# return classification if all examples have the same classification
return data.first.last if data.classification.uniq.size == 1
- # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute)
+ # Choose best attribute:
+ # 1. enumerate all attributes
+ # 2. Pick best attribute
+ # 3. If attributes all score the same, then pick a random one to avoid infinite recursion.
performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) }
max = performance.max { |a,b| a[0] <=> b[0] }
+ min = performance.min { |a,b| a[0] <=> b[0] }
+ max = performance.shuffle.first if max[0] == min[0]
best = Node.new(attributes[performance.index(max)], max[1], max[0])
best.threshold = nil if @type == :discrete
@used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold]
tree, l = {best => {}}, ['>=', '<']
@@ -125,12 +129,13 @@
# Runs +test+ through the tree stored in @tree by handing both to #descend,
# and returns whatever #descend returns.
#
# test - the sample to evaluate; it is passed to #descend unchanged.
def predict(test)
  root = @tree
  descend(root, test)
end
- def graph(filename)
+ def graph(filename, file_type = "png")
+ require 'graphr'
dgp = DotGraphPrinter.new(build_tree)
- dgp.write_to_file("#{filename}.png", "png")
+ dgp.write_to_file("#{filename}.#{file_type}", file_type)
end
def ruleset
rs = Ruleset.new(@attributes, @data, @default, @type)
rs.rules = build_rules