Sha256: 91b93490f97070b746888043c7f38308ad91f6d41a4c83219f490c2bdb81fdd5
Contents?: true
Size: 1.58 KB
Versions: 3
Compression:
Stored size: 1.58 KB
Contents
require "csv"
require_relative "downloader"

module DNN
  # Raised when the Iris CSV contains a row that cannot be interpreted.
  class DNN_Iris_LoadError < DNN_Error; end

  # Loader for the UCI Iris dataset: 4 numeric features per sample and one of
  # three class labels (SETOSA, VERSICOLOR, VIRGINICA).
  module Iris
    # Location of the raw dataset. It is cached locally at the path returned
    # by url_to_file_name (a "downloads" directory next to this file).
    URL_CSV = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"

    # Iris-setosa
    SETOSA = 0
    # Iris-versicolor
    VERSICOLOR = 1
    # Iris-virginica
    VIRGINICA = 2

    # Maps the class-name column of the CSV to its numeric label.
    CLASS_LABELS = {
      "Iris-setosa" => SETOSA,
      "Iris-versicolor" => VERSICOLOR,
      "Iris-virginica" => VIRGINICA,
    }.freeze
    private_constant :CLASS_LABELS

    # Downloads the Iris CSV unless a cached copy already exists at
    # url_to_file_name(URL_CSV).
    # NOTE(review): assumes Downloader.download writes to that same path —
    # confirm against downloader.rb.
    def self.downloads
      return if File.exist?(url_to_file_name(URL_CSV))

      Downloader.download(URL_CSV)
    end

    # Loads the Iris dataset, downloading it first if necessary.
    #
    # @param shuffle [Boolean] when true, permute the samples; rows of x and
    #   the matching entries of y are permuted together.
    # @param shuffle_seed [Integer] seed for the permutation (used only when
    #   +shuffle+ is true). The global RNG used by Kernel#rand is not touched.
    # @return [Array(Numo::SFloat, Numo::SFloat)] +[x, y]+ where x has shape
    #   (samples, 4) with columns sepal length, sepal width, petal length,
    #   petal width, and y has shape (samples) holding SETOSA / VERSICOLOR /
    #   VIRGINICA labels.
    # @raise [DNN_Iris_LoadError] if a row carries an unknown class name.
    def self.load(shuffle = false, shuffle_seed = rand(1 << 31))
      downloads
      # Blank lines in the file come back from CSV as empty rows; skip them.
      rows = CSV.read(url_to_file_name(URL_CSV)).reject(&:empty?)
      x = Numo::SFloat.zeros(rows.length, 4)
      y = Numo::SFloat.zeros(rows.length)
      rows.each_with_index do |(sepal_length, sepal_width, petal_length, petal_width, classes), i|
        x[i, 0] = sepal_length.to_f
        x[i, 1] = sepal_width.to_f
        x[i, 2] = petal_length.to_f
        x[i, 3] = petal_width.to_f
        y[i] = CLASS_LABELS.fetch(classes) do
          raise DNN_Iris_LoadError, "Unknown class name '#{classes}' for iris"
        end
      end
      if shuffle
        # A private RNG keeps the caller's global random state intact (no
        # srand reset) and avoids Random::DEFAULT, which Ruby 3.2 removed.
        indices = (0...rows.length).to_a.shuffle(random: Random.new(shuffle_seed))
        # Gather into fresh arrays rather than scattering an array into
        # itself (`x[idx, true] = x`), where source and destination alias
        # the same buffer.
        x = x[indices, true]
        y = y[indices]
      end
      [x, y]
    end

    # Maps a download URL to its local cache path:
    # "<directory of this file>/downloads/<last path segment of url>".
    private_class_method def self.url_to_file_name(url)
      __dir__ + "/downloads/" + url.match(%r{.+/(.+)$})[1]
    end
  end
end
Version data entries
3 entries across 3 versions & 1 rubygems
Version | Path |
---|---|
ruby-dnn-0.10.4 | lib/dnn/iris.rb |
ruby-dnn-0.10.3 | lib/dnn/iris.rb |
ruby-dnn-0.10.2 | lib/dnn/iris.rb |