lib/tensorflow/tensor.rb in tensorflow-0.1.1 vs lib/tensorflow/tensor.rb in tensorflow-0.1.2

- old
+ new

@@ -4,51 +4,64 @@ @status = FFI.TF_NewStatus if pointer @pointer = pointer else - data = Array(value) + data = value + data = Array(data) unless data.is_a?(Array) || data.is_a?(Numo::NArray) shape ||= calculate_shape(value) if shape.size > 0 dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size) dims_ptr.write_array_of_int64(shape) else dims_ptr = nil end - data = data.flatten - - dtype ||= Utils.infer_type(data) - type = FFI::DataType[dtype] - case dtype - when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64 - data_ptr = ::FFI::MemoryPointer.new(dtype, data.size) - data_ptr.send("write_array_of_#{dtype}", data) - when :bfloat16 - # https://en.wikipedia.org/wiki/Bfloat16_floating-point_format - data_ptr = ::FFI::MemoryPointer.new(:int8, data.size * 2) - data_ptr.write_bytes(data.map { |v| [v].pack("g")[0..1] }.join) - when :complex64 - data_ptr = ::FFI::MemoryPointer.new(:float, data.size * 2) - data_ptr.write_array_of_float(data.flat_map { |v| [v.real, v.imaginary] }) - when :complex128 - data_ptr = ::FFI::MemoryPointer.new(:double, data.size * 2) - data_ptr.write_array_of_double(data.flat_map { |v| [v.real, v.imaginary] }) - when :string - data_ptr = string_ptr(data) - when :bool - data_ptr = ::FFI::MemoryPointer.new(:int8, data.size) - data_ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 }) + if data.is_a?(Numo::NArray) + dtype ||= Utils.infer_type(data) + # TODO use Numo read pointer? + data_ptr = ::FFI::MemoryPointer.new(:uchar, data.byte_size) + data_ptr.write_bytes(data.to_string) else - raise "Unknown type: #{dtype}" + data = data.flatten + dtype ||= Utils.infer_type(data) + case dtype + when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64 + data_ptr = ::FFI::MemoryPointer.new(dtype, data.size) + data_ptr.send("write_array_of_#{dtype}", data) + when :bfloat16 + # https://en.wikipedia.org/wiki/Bfloat16_floating-point_format + data_ptr = ::FFI::MemoryPointer.new(:int8, data.size * 2) + data_ptr.write_bytes(data.map { |v| [v].pack("g")[0..1] }.join) + when :complex64 + data_ptr = ::FFI::MemoryPointer.new(:float, data.size * 2) + data_ptr.write_array_of_float(data.flat_map { |v| [v.real, v.imaginary] }) + when :complex128 + data_ptr = ::FFI::MemoryPointer.new(:double, data.size * 2) + data_ptr.write_array_of_double(data.flat_map { |v| [v.real, v.imaginary] }) + when :string + data_ptr = string_ptr(data) + when :bool + data_ptr = ::FFI::MemoryPointer.new(:int8, data.size) + data_ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 }) + else + raise "Unknown type: #{dtype}" + end end + type = FFI::DataType[dtype] + callback = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |data, len, arg| # FFI handles deallocation end + # keep data pointer alive for duration of object + @data_ptr = data_ptr + @dims_ptr = dims_ptr + @callback = callback + tensor = FFI.TF_NewTensor(type, dims_ptr, shape.size, data_ptr, data_ptr.size, callback, nil) @pointer = FFI.TFE_NewTensorHandle(tensor, @status) check_status @status end @@ -73,10 +86,14 @@ def %(other) Math.floormod(self, other) end + def -@ + Math.negative(self) + end + def value value = case dtype when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64 data_pointer.send("read_array_of_#{dtype}", element_count) @@ -90,14 +107,23 @@ when :string # string tensor format # https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57 start_offset_size = element_count * 8 offsets = data_pointer.read_array_of_uint64(element_count) - element_count.times.map { |i| (data_pointer + start_offset_size + offsets[i]).read_string } + byte_size = FFI.TF_TensorByteSize(tensor_pointer) + element_count.times.map do |i| + str_len = (offsets[i + 1] || (byte_size - start_offset_size)) - offsets[i] + str = (data_pointer + start_offset_size + offsets[i]).read_bytes(str_len) + dst = ::FFI::MemoryPointer.new(:char, str.bytesize + 100) + dst_len = ::FFI::MemoryPointer.new(:size_t) + FFI.TF_StringDecode(str, str.bytesize, dst, dst_len, @status) + check_status @status + dst.read_pointer.read_bytes(dst_len.read_int32) + end when :bool data_pointer.read_array_of_int8(element_count).map { |v| v == 1 } - when :resource + when :resource, :variant return data_pointer else raise "Unknown type: #{dtype}" end @@ -133,12 +159,18 @@ def to_ptr @pointer end + def numo + klass = Utils::NUMO_TYPE_MAP[dtype] + raise "Unknown type: #{dtype}" unless klass + klass.cast(value) + end + def inspect - inspection = %w(value shape dtype).map { |v| "#{v}: #{send(v).inspect}"} + inspection = %w(numo shape dtype).map { |v| "#{v}: #{send(v).inspect}"} "#<#{self.class} #{inspection.join(", ")}>" end def self.finalize(pointer, status, tensor) # must use proc instead of stabby lambda @@ -162,13 +194,17 @@ check_status @status ret end def data_pointer + FFI.TF_TensorData(tensor_pointer) + end + + def tensor_pointer tensor = FFI.TFE_TensorHandleResolve(@pointer, @status) check_status @status - FFI.TF_TensorData(tensor) + tensor end def reshape(arr, dims) return arr.first if dims.empty? arr = arr.flatten @@ -177,10 +213,12 @@ end arr.to_a end def calculate_shape(value) + return value.shape if value.respond_to?(:shape) + shape = [] d = value while d.is_a?(Array) shape << d.size d = d.first @@ -197,10 +235,12 @@ offsets << offsets.last + str.bytesize + 1 end data_ptr = ::FFI::MemoryPointer.new(:char, start_offset_size + offsets.pop) data_ptr.write_array_of_uint64(offsets) data.zip(offsets) do |str, offset| - (data_ptr + start_offset_size + offset).write_string(str) + dst_len = FFI.TF_StringEncodedSize(str.bytesize) + FFI.TF_StringEncode(str, str.bytesize, data_ptr + start_offset_size + offset, dst_len, @status) + check_status @status end data_ptr end def check_status(status)