lib/npy.rb in npy-0.1.1 vs lib/npy.rb in npy-0.1.2
- old
+ new
@@ -1,9 +1,13 @@
# dependencies
require "numo/narray"
require "zip"
+# stdlib
+require "stringio"
+require "tempfile"
+
# modules
require "npy/file"
require "npy/version"
module Npy
@@ -26,30 +30,44 @@
"<c16" => Numo::DComplex
}
class << self
def load(path)
- with_file(path) do |f|
- load_io(f)
+ case path
+ when IO, StringIO
+ load_io(path)
+ else
+ load_file(path)
end
end
def load_npz(path)
- with_file(path) do |f|
- load_npz_io(f)
+ case path
+ when IO, StringIO
+ load_npz_io(path)
+ else
+ load_npz_file(path)
end
end
def load_string(byte_str)
load_io(StringIO.new(byte_str))
end
- # rubyzip not playing nicely with StringIO
- # def load_npz_string(byte_str)
- # load_npz_io(StringIO.new(byte_str))
- # end
+ def load_npz_string(byte_str)
+ # not playing nicely with StringIO
+ file = Tempfile.new("npy")
+ begin
+ file.write(byte_str)
+ load_npz_io(file)
+ ensure
+ file.close
+ file.unlink
+ end
+ end
+ # TODO make private
def load_io(io)
magic = io.read(6)
raise Error, "Invalid npy format" unless magic&.b == MAGIC_STR
version = io.read(2)
@@ -79,54 +97,70 @@
result = klass.from_string(io.read, shape)
result = result[0] if empty_shape
result
end
+ # TODO make private
def load_npz_io(io)
File.new(io)
end
def save(path, arr)
- ::File.open(path, "wb") do |f|
- save_io(f, arr)
+ case path
+ when IO, StringIO
+ save_io(path, arr)
+ else
+ save_file(path, arr)
end
true
end
- def save_npz(path, **arrs)
- # use File.open instead passing path to zip file
- # so it overrides instead of appends
- ::File.open(path, "wb") do |f|
- Zip::File.open(f, Zip::File::CREATE) do |zipfile|
- arrs.each do |k, v|
- zipfile.get_output_stream("#{k}.npy") do |f2|
- save_io(f2, v)
- end
- end
- end
+ def save_npz(path, arrs)
+ case path
+ when IO, StringIO
+ save_npz_io(path, arrs)
+ else
+ save_npz_file(path, arrs)
end
true
end
private
+ def load_file(path)
+ with_file(path, "rb") do |f|
+ load_io(f)
+ end
+ end
+
+ def load_npz_file(path)
+ with_file(path, "rb") do |f|
+ load_npz_io(f)
+ end
+ end
+
+ def save_file(path, arr)
+ with_file(path, "wb") do |f|
+ save_io(f, arr)
+ end
+ end
+
def save_io(f, arr)
empty_shape = arr.is_a?(Numeric)
arr = Numo::NArray.cast([arr]) if empty_shape
arr = Numo::NArray.cast(arr) if arr.is_a?(Array)
# desc
- descr = TYPE_MAP.find { |k, v| arr.is_a?(v) }
+ descr = TYPE_MAP.find { |_, v| arr.is_a?(v) }
raise Error, "Unsupported type: #{arr.class.name}" unless descr
# shape
shape = arr.shape
- shape << "" if shape.size == 1
shape = [] if empty_shape
# header
- header = "{'descr': '#{descr[0]}', 'fortran_order': False, 'shape': (#{shape.join(", ")}), }".b
+ header = "{'descr': '#{descr[0]}', 'fortran_order': False, 'shape': (#{shape.join(", ")}#{shape.size == 1 ? "," : nil}), }".b
padding_len = 64 - (11 + header.length) % 64
padding = "\x20".b * padding_len
header = "#{header}#{padding}\n"
f.write(MAGIC_STR)
@@ -134,12 +168,28 @@
f.write([header.bytesize].pack("S<"))
f.write(header)
f.write(arr.to_string)
end
- def with_file(path)
- ::File.open(path, "rb") do |f|
+ def save_npz_file(path, arrs)
+ with_file(path, "wb") do |f|
+ save_npz_io(f, arrs)
+ end
+ end
+
+ def save_npz_io(f, arrs)
+ Zip::File.open(f, Zip::File::CREATE) do |zipfile|
+ arrs.each do |k, v|
+ zipfile.get_output_stream("#{k}.npy") do |f2|
+ save_io(f2, v)
+ end
+ end
+ end
+ end
+
+ def with_file(path, mode)
+ ::File.open(path, mode) do |f|
yield f
end
end
# header is Python dict, so use regex to parse
@@ -156,10 +206,10 @@
fortran_order = m[1] == "True"
# shape
m = /'shape': *\(([^)]*)\)/.match(header)
# no space in split for max compatibility
- shape = m[1].split(",").map(&:to_i)
+ shape = m[1].strip.split(",").map(&:to_i)
[descr, fortran_order, shape]
end
end
end