Sha256: 91b93490f97070b746888043c7f38308ad91f6d41a4c83219f490c2bdb81fdd5

Contents?: true

Size: 1.58 KB

Versions: 3

Compression:

Stored size: 1.58 KB

Contents

require "csv"
require_relative "downloader"

module DNN
  # Raised when the Iris CSV contains a row that cannot be interpreted.
  class DNN_Iris_LoadError < DNN_Error; end

  # Loader for the UCI Iris flower dataset.
  #
  # The CSV is fetched from URL_CSV (via Downloader) the first time it is
  # needed and cached in the "downloads" directory next to this file.
  module Iris
    URL_CSV = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"

    # Iris-setosa
    SETOSA = 0
    # Iris-versicolor
    VERSICOLOR = 1
    # Iris-virginica
    VIRGINICA = 2

    # Class names as they appear in the CSV's fifth column, mapped to label ids.
    CLASS_IDS = {
      "Iris-setosa" => SETOSA,
      "Iris-versicolor" => VERSICOLOR,
      "Iris-virginica" => VIRGINICA
    }.freeze

    # Downloads the CSV unless a file with the expected local name already exists.
    def self.downloads
      return if File.exist?(url_to_file_name(URL_CSV))

      Downloader.download(URL_CSV)
    end

    # Loads the Iris dataset.
    #
    # @param shuffle [Boolean] whether to randomly permute the samples.
    # @param shuffle_seed [Integer] seed for the permutation; only used when
    #   +shuffle+ is true. The same seed always yields the same order.
    # @return [Array(Numo::SFloat, Numo::SFloat)] +[x, y]+ where +x+ has one row
    #   per sample with four columns (sepal length, sepal width, petal length,
    #   petal width) and +y+ holds the matching label ids
    #   (SETOSA, VERSICOLOR or VIRGINICA).
    # @raise [DNN_Iris_LoadError] if a row carries an unknown class name.
    def self.load(shuffle = false, shuffle_seed = rand(1 << 31))
      downloads
      # Blank lines in the file come back from CSV as empty rows; drop them.
      rows = CSV.read(url_to_file_name(URL_CSV)).reject(&:empty?)
      x = Numo::SFloat.zeros(rows.length, 4)
      y = Numo::SFloat.zeros(rows.length)
      rows.each_with_index do |(sepal_length, sepal_width, petal_length, petal_width, class_name), i|
        x[i, 0] = sepal_length.to_f
        x[i, 1] = sepal_width.to_f
        x[i, 2] = petal_length.to_f
        x[i, 3] = petal_width.to_f
        y[i] = CLASS_IDS.fetch(class_name) do
          raise DNN_Iris_LoadError, "Unknown class name '#{class_name}' for iris"
        end
      end
      if shuffle
        # A dedicated generator avoids touching the global RNG (the old
        # srand/Random::DEFAULT dance reset it, and Random::DEFAULT is gone
        # in Ruby 3.2+). It yields the same permutation as srand(seed) did.
        indexs = (0...rows.length).to_a.shuffle(random: Random.new(shuffle_seed))
        # Store from copies: writing an array into a fancy-indexed view of
        # itself could overwrite rows before they are read.
        x[indexs, true] = x.dup
        y[indexs] = y.dup
      end
      [x, y]
    end

    # Returns the local cache path for the file at +url+
    # ("<this dir>/downloads/<last URL segment>").
    def self.url_to_file_name(url)
      File.join(__dir__, "downloads", File.basename(url))
    end
    private_class_method :url_to_file_name
  end
end

Version data entries

3 entries across 3 versions & 1 rubygems

Version Path
ruby-dnn-0.10.4 lib/dnn/iris.rb
ruby-dnn-0.10.3 lib/dnn/iris.rb
ruby-dnn-0.10.2 lib/dnn/iris.rb