Sha256: 83ffc85268c4fefa8d17135b1232cbbee698f739d35337acba575cf220acc675

Contents?: true

Size: 1.33 KB

Versions: 8

Compression:

Stored size: 1.33 KB

Contents

module CIFAR10
  # Pixel bytes per image in the CIFAR-10 binary format (32 x 32 pixels x 3 channels).
  IMAGE_SIZE = 3072
  # Bytes per record in a binary batch file: one label byte followed by IMAGE_SIZE pixel bytes.
  RECORD_SIZE = IMAGE_SIZE + 1

  # Loads training batch +index+.
  #
  # On the first call the batch is parsed from "#{dir}/data_batch_#{index}.bin".
  # The result is then cached as "CIFAR-10-train#{index}.marshal" in the current directory.
  # Later calls read that cache instead.
  #
  # @param index [Integer] batch number used in the file names
  # @return [Array(Array<Array<Integer>>, Array<Integer>)] +[images, labels]+, where each
  #   image is an array of IMAGE_SIZE byte values (0..255) and each label is a byte value
  def self.load_train(index)
    load_cached("CIFAR-10-train#{index}.marshal") do
      parse_batch("#{dir}/data_batch_#{index}.bin")
    end
  end

  # Loads the test batch from "#{dir}/test_batch.bin".
  #
  # The result is cached as "CIFAR-10-test.marshal" in the current directory.
  #
  # @return [Array(Array<Array<Integer>>, Array<Integer>)] +[images, labels]+ (same shape
  #   as load_train)
  def self.load_test
    load_cached("CIFAR-10-test.marshal") do
      parse_batch("#{dir}/test_batch.bin")
    end
  end

  # One-hot encodes integer class labels.
  #
  # @param y_data [Enumerable<Integer>] labels in 0...num_classes
  # @param num_classes [Integer] width of each one-hot row (CIFAR-10 has 10 classes)
  # @return [Array<Array<Integer>>] one row per label, with a 1 at the label's index and 0 elsewhere
  # @raise [ArgumentError] if a label is outside 0...num_classes. Without this check, a
  #   large label would silently grow the row and a negative one would index from the end.
  def self.categorical(y_data, num_classes = 10)
    y_data.map do |label|
      unless (0...num_classes).cover?(label)
        raise ArgumentError, "label #{label.inspect} out of range 0...#{num_classes}"
      end

      Array.new(num_classes, 0).tap { |row| row[label] = 1 }
    end
  end

  # Directory (relative to the current working directory) holding the binary batch files.
  def self.dir
    "cifar-10-batches-bin"
  end

  # Returns the dataset cached at +cache_path+ if that file exists.
  # Otherwise it yields to build the dataset, writes the result to +cache_path+, and returns it.
  #
  # NOTE(review): Marshal.load can instantiate arbitrary objects, so the cache file must come
  # from a trusted source. Delete the cache files to force a re-parse.
  def self.load_cached(cache_path)
    return Marshal.load(File.binread(cache_path)) if File.exist?(cache_path)

    yield.tap { |data| File.binwrite(cache_path, Marshal.dump(data)) }
  end
  private_class_method :load_cached

  # Parses a CIFAR-10 binary batch file into +[images, labels]+.
  #
  # Each RECORD_SIZE-byte record is one label byte followed by its pixel bytes.
  # each_slice walks the bytes once. This avoids repeatedly removing elements from the front of a huge array.
  def self.parse_batch(path)
    images = []
    labels = []
    File.binread(path).unpack("C*").each_slice(RECORD_SIZE) do |label, *pixels|
      labels << label
      images << pixels
    end
    [images, labels]
  end
  private_class_method :parse_batch
end

Version data entries

8 entries across 8 versions & 1 rubygem

Version Path
nn-2.4.0 lib/nn/cifar10.rb
nn-2.3.0 lib/nn/cifar10.rb
nn-2.2.0 lib/nn/cifar10.rb
nn-2.1.0 lib/nn/cifar10.rb
nn-2.0.0 lib/nn/cifar10.rb
nn-2.0.1 lib/nn/cifar10.rb
nn-2.0 lib/nn/cifar10.rb
nn-1.8 lib/nn/cifar10.rb