lib/decisiontree/id3_tree.rb in decisiontree-0.4.0 vs lib/decisiontree/id3_tree.rb in decisiontree-0.5.0

- old
+ new

@@ -1,12 +1,10 @@ # The MIT License # ### Copyright (c) 2007 Ilya Grigorik <ilya AT igvita DOT com> ### Modifed at 2007 by José Ignacio Fernández <joseignacio.fernandez AT gmail DOT com> -require 'graphr' - class Object def save_to_file(filename) File.open(filename, 'w+' ) { |f| f << Marshal.dump(self) } end @@ -42,10 +40,11 @@ @used, @tree, @type = {}, {}, type @data, @attributes, @default = data, attributes, default end def train(data=@data, attributes=@attributes, default=@default) + attributes = attributes.map {|e| e.to_s} initialize(attributes, data, default, @type) # Remove samples with same attributes leaving most common classification data2 = data.inject({}) {|hash, d| hash[d.slice(0..-2)] ||= Hash.new(0); hash[d.slice(0..-2)][d.last] += 1; hash }.map{|key,val| key + [val.sort_by{ |k, v| v }.last.first]} @@ -67,13 +66,18 @@ return default if data.empty? # return classification if all examples have the same classification return data.first.last if data.classification.uniq.size == 1 - # Choose best attribute (1. enumerate all attributes / 2. Pick best attribute) + # Choose best attribute: + # 1. enumerate all attributes + # 2. Pick best attribute + # 3. If attributes all score the same, then pick a random one to avoid infinite recursion. performance = attributes.collect { |attribute| fitness_for(attribute).call(data, attributes, attribute) } max = performance.max { |a,b| a[0] <=> b[0] } + min = performance.min { |a,b| a[0] <=> b[0] } + max = performance.shuffle.first if max[0] == min[0] best = Node.new(attributes[performance.index(max)], max[1], max[0]) best.threshold = nil if @type == :discrete @used.has_key?(best.attribute) ? @used[best.attribute] += [best.threshold] : @used[best.attribute] = [best.threshold] tree, l = {best => {}}, ['>=', '<'] @@ -125,12 +129,13 @@ def predict(test) descend(@tree, test) end - def graph(filename) + def graph(filename, file_type = "png") + require 'graphr' dgp = DotGraphPrinter.new(build_tree) - dgp.write_to_file("#{filename}.png", "png") + dgp.write_to_file("#{filename}.#{file_type}", file_type) end def ruleset rs = Ruleset.new(@attributes, @data, @default, @type) rs.rules = build_rules