# Author::    Sergio Fierens (Implementation, Quinlan is
# the creator of the algorithm)
# License::   MPL 1.1
# Project::   ai4r
# Url::       http://ai4r.rubyforge.org/
#
# You can redistribute it and/or modify it under the terms of
# the Mozilla Public License version 1.1 as published by the
# Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt

require File.dirname(__FILE__) + '/classifier_helper'

module Ai4r

  module Classifiers

    # = Introduction
    # This is an implementation of the ID3 algorithm (Quinlan).
    # Given a set of preclassified examples, it builds a top-down
    # induction of a decision tree, biased by the information gain and
    # entropy measure.
    #
    # * http://en.wikipedia.org/wiki/Decision_tree
    # * http://en.wikipedia.org/wiki/ID3_algorithm
    #
    # = How to use it
    #
    #   DATA_LABELS = [ 'city', 'age_range', 'gender', 'marketing_target' ]
    #
    #   DATA_SET = [ ['New York', '<30',     'M', 'Y'],
    #                ['Chicago',  '<30',     'M', 'Y'],
    #                ['Chicago',  '<30',     'F', 'Y'],
    #                ['New York', '<30',     'M', 'Y'],
    #                ['New York', '<30',     'M', 'Y'],
    #                ['Chicago',  '[30-50)', 'M', 'Y'],
    #                ['New York', '[30-50)', 'F', 'N'],
    #                ['Chicago',  '[30-50)', 'F', 'Y'],
    #                ['New York', '[30-50)', 'F', 'N'],
    #                ['Chicago',  '[50-80]', 'M', 'N'],
    #                ['New York', '[50-80]', 'F', 'N'],
    #                ['New York', '[50-80]', 'M', 'N'],
    #                ['Chicago',  '[50-80]', 'M', 'N'],
    #                ['New York', '[50-80]', 'F', 'N'],
    #                ['Chicago',  '>80',     'F', 'Y']
    #              ]
    #
    #   id3 = DecisionTree::ID3.new(DATA_SET, DATA_LABELS)
    #
    #   id3.to_s
    #   # => if age_range=='<30' then marketing_target='Y'
    #        elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y'
    #        elsif age_range=='[30-50)' and city=='New York' then marketing_target='N'
    #        elsif age_range=='[50-80]' then marketing_target='N'
    #        elsif age_range=='>80' then marketing_target='Y'
    #        else raise 'There was not enough information during training to do a proper induction for this data element' end
    #
    #   id3.eval(['New York', '<30', 'M'])
    #   # => 'Y'
    #
    # = A better way to load the data
    #
    # In real life you will use a lot more training examples, with more
    # attributes. Consider moving your data to an external CSV
    # (comma-separated values) file.
    #
    #   data_set = []
    #   CSV::Reader.parse(File.open("#{File.dirname(__FILE__)}/data_set.csv", 'r')) do |row|
    #     data_set << row
    #   end
    #   data_labels = data_set.shift
    #
    #   id3 = DecisionTree::ID3.new(data_set, data_labels)
    #
    # = A nice tip for data evaluation
    #
    #   id3 = DecisionTree::ID3.new(DATA_SET, DATA_LABELS)
    #   age_range = '<30'
    #   marketing_target = nil
    #   eval id3.to_s
    #   puts marketing_target
    #   # => 'Y'
    #
    # = More about ID3 and decision trees
    #
    # * http://en.wikipedia.org/wiki/Decision_tree
    # * http://en.wikipedia.org/wiki/ID3_algorithm
    #
    # = About the project
    # Author::    Sergio Fierens
    # License::   MPL 1.1
    class ID3

      attr_reader :data_labels
      include ClassifierHelper

      # Create a new decision tree. If your data is classified with N attributes
      # and M examples, then your data examples must have the following format:
      #
      #   [ [ATT1_VAL1, ATT2_VAL1, ATT3_VAL1, ... , ATTN_VAL1, CATEGORY_VAL1],
      #     [ATT1_VAL2, ATT2_VAL2, ATT3_VAL2, ... , ATTN_VAL2, CATEGORY_VAL2],
      #     ...
      #     [ATT1_VALM, ATT2_VALM, ATT3_VALM, ... , ATTN_VALM, CATEGORY_VALM],
      #   ]
      #
      # e.g.
#   [ ['New York', '<30',     'M', 'Y'],
      #     ['Chicago',  '<30',     'M', 'Y'],
      #     ['Chicago',  '<30',     'F', 'Y'],
      #     ['New York', '<30',     'M', 'Y'],
      #     ['New York', '<30',     'M', 'Y'],
      #     ['Chicago',  '[30-50)', 'M', 'Y'],
      #     ['New York', '[30-50)', 'F', 'N'],
      #     ['Chicago',  '[30-50)', 'F', 'Y'],
      #     ['New York', '[30-50)', 'F', 'N'],
      #     ['Chicago',  '[50-80]', 'M', 'N'],
      #     ['New York', '[50-80]', 'F', 'N'],
      #     ['New York', '[50-80]', 'M', 'N'],
      #     ['Chicago',  '[50-80]', 'M', 'N'],
      #     ['New York', '[50-80]', 'F', 'N'],
      #     ['Chicago',  '>80',     'F', 'Y']
      #   ]
      #
      # Data labels must have the following format:
      #   [ 'city', 'age_range', 'gender', 'marketing_target' ]
      #
      # If you do not provide labels for your data, the following labels will
      # be created by default:
      #   [ 'ATTRIBUTE_1', 'ATTRIBUTE_2', 'ATTRIBUTE_3', 'CATEGORY' ]
      #
      # The examples are first handed to check_data_examples, the labels are
      # stored in @data_labels (readable through #data_labels), and the tree is
      # then built by preprocess_data.
      #
      # Returns self, so calls can be chained.
      def build(data_examples, data_labels = nil)
        check_data_examples(data_examples)
        # Fall back to generated labels only when the caller supplied none.
        @data_labels = data_labels || default_data_labels(data_examples)
        preprocess_data(data_examples)
        self
      end

      # You can evaluate new data, predicting its category.
      # e.g.
      #   id3.eval(['New York', '<30', 'F'])
      #   # => 'Y'
      #
      # Returns nil when no tree is available (@tree is not set).
      def eval(data)
        return nil unless @tree

        @tree.value(data)
      end

      # This method returns the generated rules in ruby code.
      # e.g.
# # id3.to_s # # => if age_range=='<30' then marketing_target='Y' # elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y' # elsif age_range=='[30-50)' and city=='New York' then marketing_target='N' # elsif age_range=='[50-80]' then marketing_target='N' # elsif age_range=='>80' then marketing_target='Y' # else raise 'There was not enough information during training to do a proper induction for this data element' end # # It is a nice way to inspect induction results, and also to execute them: # age_range = '<30' # marketing_target = nil # eval id3.to_s # puts marketing_target # # => 'Y' def to_s rules = @tree.get_rules rules = rules.collect do |rule| "#{rule[0..-2].join(' and ')} then #{rule.last}" end return "if #{rules.join("\nelsif ")}\nelse raise 'There was not enough information during training to do a proper induction for this data element' end" end private def preprocess_data(data_examples) @tree = build_node(data_examples) end private def build_node(data_examples, flag_att = []) return ErrorNode.new if data_examples.length == 0 domain = domain(data_examples) return CategoryNode.new(@data_labels.last, domain.last[0]) if domain.last.length == 1 min_entropy_index = min_entropy_index(data_examples, domain, flag_att) flag_att << min_entropy_index split_data_examples = split_data_examples(data_examples, domain, min_entropy_index) return CategoryNode.new(@data_labels.last, most_freq(data_examples, domain)) if split_data_examples.length == 1 nodes = split_data_examples.collect do |partial_data_examples| build_node(partial_data_examples, flag_att) end return EvaluationNode.new(@data_labels, min_entropy_index, domain[min_entropy_index], nodes) end private def self.sum(values) values.inject( 0 ) { |sum,x| sum+x } end private def self.log2(z) return 0.0 if z == 0 Math.log(z)/LOG2 end private def most_freq(examples, domain) freqs = [] domain.last.length.times { freqs << 0} examples.each do |example| cat_index = domain.last.index(example.last) freq = 
freqs[cat_index] + 1 freqs[cat_index] = freq end max_freq = freqs.max max_freq_index = freqs.index(max_freq) domain.last[max_freq_index] end private def split_data_examples(data_examples, domain, att_index) data_examples_array = [] att_value_examples = {} data_examples.each do |example| example_set = att_value_examples[example[att_index]] example_set = [] if !example_set example_set << example att_value_examples.store(example[att_index], example_set) end att_value_examples.each_pair do |att_value, example_set| att_value_index = domain[att_index].index(att_value) data_examples_array[att_value_index] = example_set end return data_examples_array end private def min_entropy_index(data_examples, domain, flag_att=[]) min_entropy = nil min_index = 0 domain[0..-2].each_index do |index| freq_grid = freq_grid(index, data_examples, domain) entropy = entropy(freq_grid, data_examples.length) if (!min_entropy || entropy < min_entropy) && !flag_att.include?(index) min_entropy = entropy min_index = index end end return min_index end private def domain(data_examples) #return build_domains(data_examples) domain = [] @data_labels.length.times { domain << [] } data_examples.each do |data| data.each_index do |i| domain[i] << data[i] if i