lib/npy.rb in npy-0.1.0 vs lib/npy.rb in npy-0.1.1

- old
+ new

@@ -9,10 +9,25 @@ module Npy class Error < StandardError; end MAGIC_STR = "\x93NUMPY".b + TYPE_MAP = { + "|i1" => Numo::Int8, + "<i2" => Numo::Int16, + "<i4" => Numo::Int32, + "<i8" => Numo::Int64, + "|u1" => Numo::UInt8, + "<u2" => Numo::UInt16, + "<u4" => Numo::UInt32, + "<u8" => Numo::UInt64, + "<f4" => Numo::SFloat, + "<f8" => Numo::DFloat, + "<c8" => Numo::SComplex, + "<c16" => Numo::DComplex + } + class << self def load(path) with_file(path) do |f| load_io(f) end @@ -35,76 +50,116 @@ def load_io(io) magic = io.read(6) raise Error, "Invalid npy format" unless magic&.b == MAGIC_STR - major_version = io.read(1) - minor_version = io.read(1) - raise Error, "Unsupported version" unless major_version == "\x01".b + version = io.read(2) - header_len = io.read(2).unpack1("S<") + header_len = + case version + when "\x01\x00".b + io.read(2).unpack1("S<") + when "\x02\x00".b, "\x03\x00".b + io.read(4).unpack1("I<") + else + raise Error, "Unsupported version" + end header = io.read(header_len) descr, fortran_order, shape = parse_header(header) raise Error, "Fortran order not supported" if fortran_order - klass = - case descr - when "|i1" - Numo::Int8 - when "<i2" - Numo::Int16 - when "<i4" - Numo::Int32 - when "<i8" - Numo::Int64 - when "|u1" - Numo::UInt8 - when "<u2" - Numo::UInt16 - when "<u4" - Numo::UInt32 - when "<u8" - Numo::UInt64 - when "<f4" - Numo::SFloat - when "<f8" - Numo::DFloat - when "<c8" - Numo::SComplex - when "<c16" - Numo::DComplex - else - raise Error, "Type not supported: #{descr}" - end + # numo can't handle empty shapes + empty_shape = shape.empty? + shape = [1] if empty_shape - klass.from_binary(io.read, shape) + klass = TYPE_MAP[descr] + raise Error, "Type not supported: #{descr}" unless klass + + # use from_string instead of from_binary for max compatibility + # from_binary introduced in 0.9.0.4 + result = klass.from_string(io.read, shape) + result = result[0] if empty_shape + result end def load_npz_io(io) File.new(io) end + def save(path, arr) + ::File.open(path, "wb") do |f| + save_io(f, arr) + end + true + end + + def save_npz(path, **arrs) + # use File.open instead passing path to zip file + # so it overrides instead of appends + ::File.open(path, "wb") do |f| + Zip::File.open(f, Zip::File::CREATE) do |zipfile| + arrs.each do |k, v| + zipfile.get_output_stream("#{k}.npy") do |f2| + save_io(f2, v) + end + end + end + end + true + end + private + def save_io(f, arr) + empty_shape = arr.is_a?(Numeric) + arr = Numo::NArray.cast([arr]) if empty_shape + arr = Numo::NArray.cast(arr) if arr.is_a?(Array) + + # desc + descr = TYPE_MAP.find { |k, v| arr.is_a?(v) } + raise Error, "Unsupported type: #{arr.class.name}" unless descr + + # shape + shape = arr.shape + shape << "" if shape.size == 1 + shape = [] if empty_shape + + # header + header = "{'descr': '#{descr[0]}', 'fortran_order': False, 'shape': (#{shape.join(", ")}), }".b + padding_len = 64 - (11 + header.length) % 64 + padding = "\x20".b * padding_len + header = "#{header}#{padding}\n" + + f.write(MAGIC_STR) + f.write("\x01\x00".b) + f.write([header.bytesize].pack("S<")) + f.write(header) + f.write(arr.to_string) + end + def with_file(path) ::File.open(path, "rb") do |f| yield f end end # header is Python dict, so use regex to parse def parse_header(header) + # sanity check + raise Error, "Bad header" if !header || header[-1] != "\n" + # descr - m = /'descr': '([^']+)'/.match(header) + m = /'descr': *'([^']+)'/.match(header) descr = m[1] # fortran_order - m = /'fortran_order': ([^,]+)/.match(header) + m = /'fortran_order': *([^,]+)/.match(header) fortran_order = m[1] == "True" # shape - m = /'shape': \(([^)]+)\)/.match(header) - shape = m[1].split(", ").map(&:to_i) + m = /'shape': *\(([^)]*)\)/.match(header) + # no space in split for max compatibility + shape = m[1].split(",").map(&:to_i) [descr, fortran_order, shape] end end end