lib/npy.rb in npy-0.1.2 vs lib/npy.rb in npy-0.2.0
- old
+ new
@@ -25,11 +25,14 @@
"<u4" => Numo::UInt32,
"<u8" => Numo::UInt64,
"<f4" => Numo::SFloat,
"<f8" => Numo::DFloat,
"<c8" => Numo::SComplex,
- "<c16" => Numo::DComplex
+ "<c16" => Numo::DComplex,
+ # must come last
+ # as save uses first match
+ "|b1" => Numo::UInt8
}
class << self
def load(path)
case path
@@ -83,22 +86,21 @@
end
header = io.read(header_len)
descr, fortran_order, shape = parse_header(header)
raise Error, "Fortran order not supported" if fortran_order
- # numo can't handle empty shapes
- empty_shape = shape.empty?
- shape = [1] if empty_shape
-
klass = TYPE_MAP[descr]
raise Error, "Type not supported: #{descr}" unless klass
# use from_string instead of from_binary for max compatibility
# from_binary introduced in 0.9.0.4
- result = klass.from_string(io.read, shape)
- result = result[0] if empty_shape
- result
+ # numo from_string can't handle rank0
+ if shape.empty?
+ klass.cast(klass.from_string(io.read, [1])[0])
+ else
+ klass.from_string(io.read, shape)
+ end
end
# TODO make private
def load_npz_io(io)
File.new(io)
@@ -143,20 +145,23 @@
save_io(f, arr)
end
end
def save_io(f, arr)
- empty_shape = arr.is_a?(Numeric)
- arr = Numo::NArray.cast([arr]) if empty_shape
- arr = Numo::NArray.cast(arr) if arr.is_a?(Array)
+ unless arr.is_a?(Numo::NArray)
+ begin
+ arr = Numo::NArray.cast(arr)
+ rescue TypeError
+ # do nothing
+ end
+ end
# desc
descr = TYPE_MAP.find { |_, v| arr.is_a?(v) }
raise Error, "Unsupported type: #{arr.class.name}" unless descr
# shape
shape = arr.shape
- shape = [] if empty_shape
# header
header = "{'descr': '#{descr[0]}', 'fortran_order': False, 'shape': (#{shape.join(", ")}#{shape.size == 1 ? "," : nil}), }".b
padding_len = 64 - (11 + header.length) % 64
padding = "\x20".b * padding_len