lib/npy.rb in npy-0.1.1 vs lib/npy.rb in npy-0.1.2

- old
+ new

@@ -1,9 +1,13 @@ # dependencies require "numo/narray" require "zip" +# stdlib +require "stringio" +require "tempfile" + # modules require "npy/file" require "npy/version" module Npy @@ -26,30 +30,44 @@ "<c16" => Numo::DComplex } class << self def load(path) - with_file(path) do |f| - load_io(f) + case path + when IO, StringIO + load_io(path) + else + load_file(path) end end def load_npz(path) - with_file(path) do |f| - load_npz_io(f) + case path + when IO, StringIO + load_npz_io(path) + else + load_npz_file(path) end end def load_string(byte_str) load_io(StringIO.new(byte_str)) end - # rubyzip not playing nicely with StringIO - # def load_npz_string(byte_str) - # load_npz_io(StringIO.new(byte_str)) - # end + def load_npz_string(byte_str) + # not playing nicely with StringIO + file = Tempfile.new("npy") + begin + file.write(byte_str) + load_npz_io(file) + ensure + file.close + file.unlink + end + end + # TODO make private def load_io(io) magic = io.read(6) raise Error, "Invalid npy format" unless magic&.b == MAGIC_STR version = io.read(2) @@ -79,54 +97,70 @@ result = klass.from_string(io.read, shape) result = result[0] if empty_shape result end + # TODO make private def load_npz_io(io) File.new(io) end def save(path, arr) - ::File.open(path, "wb") do |f| - save_io(f, arr) + case path + when IO, StringIO + save_io(path, arr) + else + save_file(path, arr) end true end - def save_npz(path, **arrs) - # use File.open instead passing path to zip file - # so it overrides instead of appends - ::File.open(path, "wb") do |f| - Zip::File.open(f, Zip::File::CREATE) do |zipfile| - arrs.each do |k, v| - zipfile.get_output_stream("#{k}.npy") do |f2| - save_io(f2, v) - end - end - end + def save_npz(path, arrs) + case path + when IO, StringIO + save_npz_io(path, arrs) + else + save_npz_file(path, arrs) end true end private + def load_file(path) + with_file(path, "rb") do |f| + load_io(f) + end + end + + def load_npz_file(path) + with_file(path, "rb") do |f| + load_npz_io(f) + end + end + + def save_file(path, arr) + with_file(path, "wb") do |f| + save_io(f, arr) + end + end + def save_io(f, arr) empty_shape = arr.is_a?(Numeric) arr = Numo::NArray.cast([arr]) if empty_shape arr = Numo::NArray.cast(arr) if arr.is_a?(Array) # desc - descr = TYPE_MAP.find { |k, v| arr.is_a?(v) } + descr = TYPE_MAP.find { |_, v| arr.is_a?(v) } raise Error, "Unsupported type: #{arr.class.name}" unless descr # shape shape = arr.shape - shape << "" if shape.size == 1 shape = [] if empty_shape # header - header = "{'descr': '#{descr[0]}', 'fortran_order': False, 'shape': (#{shape.join(", ")}), }".b + header = "{'descr': '#{descr[0]}', 'fortran_order': False, 'shape': (#{shape.join(", ")}#{shape.size == 1 ? "," : nil}), }".b padding_len = 64 - (11 + header.length) % 64 padding = "\x20".b * padding_len header = "#{header}#{padding}\n" f.write(MAGIC_STR) @@ -134,12 +168,28 @@ f.write([header.bytesize].pack("S<")) f.write(header) f.write(arr.to_string) end - def with_file(path) - ::File.open(path, "rb") do |f| + def save_npz_file(path, arrs) + with_file(path, "wb") do |f| + save_npz_io(f, arrs) + end + end + + def save_npz_io(f, arrs) + Zip::File.open(f, Zip::File::CREATE) do |zipfile| + arrs.each do |k, v| + zipfile.get_output_stream("#{k}.npy") do |f2| + save_io(f2, v) + end + end + end + end + + def with_file(path, mode) + ::File.open(path, mode) do |f| yield f end end # header is Python dict, so use regex to parse @@ -156,10 +206,10 @@ fortran_order = m[1] == "True" # shape m = /'shape': *\(([^)]*)\)/.match(header) # no space in split for max compatibility - shape = m[1].split(",").map(&:to_i) + shape = m[1].strip.split(",").map(&:to_i) [descr, fortran_order, shape] end end end