lib/npy.rb in npy-0.1.2 vs lib/npy.rb in npy-0.2.0

- old
+ new

@@ -25,11 +25,14 @@ "<u4" => Numo::UInt32, "<u8" => Numo::UInt64, "<f4" => Numo::SFloat, "<f8" => Numo::DFloat, "<c8" => Numo::SComplex, - "<c16" => Numo::DComplex + "<c16" => Numo::DComplex, + # must come last + # as save uses first match + "|b1" => Numo::UInt8 } class << self def load(path) case path @@ -83,22 +86,21 @@ end header = io.read(header_len) descr, fortran_order, shape = parse_header(header) raise Error, "Fortran order not supported" if fortran_order - # numo can't handle empty shapes - empty_shape = shape.empty? - shape = [1] if empty_shape - klass = TYPE_MAP[descr] raise Error, "Type not supported: #{descr}" unless klass # use from_string instead of from_binary for max compatibility # from_binary introduced in 0.9.0.4 - result = klass.from_string(io.read, shape) - result = result[0] if empty_shape - result + # numo from_string can't handle rank0 + if shape.empty? + klass.cast(klass.from_string(io.read, [1])[0]) + else + klass.from_string(io.read, shape) + end end # TODO make private def load_npz_io(io) File.new(io) @@ -143,20 +145,23 @@ save_io(f, arr) end end def save_io(f, arr) - empty_shape = arr.is_a?(Numeric) - arr = Numo::NArray.cast([arr]) if empty_shape - arr = Numo::NArray.cast(arr) if arr.is_a?(Array) + unless arr.is_a?(Numo::NArray) + begin + arr = Numo::NArray.cast(arr) + rescue TypeError + # do nothing + end + end # desc descr = TYPE_MAP.find { |_, v| arr.is_a?(v) } raise Error, "Unsupported type: #{arr.class.name}" unless descr # shape shape = arr.shape - shape = [] if empty_shape # header header = "{'descr': '#{descr[0]}', 'fortran_order': False, 'shape': (#{shape.join(", ")}#{shape.size == 1 ? "," : nil}), }".b padding_len = 64 - (11 + header.length) % 64 padding = "\x20".b * padding_len