lib/tensorflow/tensor.rb in tensorflow-0.1.1 vs lib/tensorflow/tensor.rb in tensorflow-0.1.2
- old
+ new
@@ -4,51 +4,64 @@
@status = FFI.TF_NewStatus
if pointer
@pointer = pointer
else
- data = Array(value)
+ data = value
+ data = Array(data) unless data.is_a?(Array) || data.is_a?(Numo::NArray)
shape ||= calculate_shape(value)
if shape.size > 0
dims_ptr = ::FFI::MemoryPointer.new(:int64, shape.size)
dims_ptr.write_array_of_int64(shape)
else
dims_ptr = nil
end
- data = data.flatten
-
- dtype ||= Utils.infer_type(data)
- type = FFI::DataType[dtype]
- case dtype
- when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
- data_ptr = ::FFI::MemoryPointer.new(dtype, data.size)
- data_ptr.send("write_array_of_#{dtype}", data)
- when :bfloat16
- # https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
- data_ptr = ::FFI::MemoryPointer.new(:int8, data.size * 2)
- data_ptr.write_bytes(data.map { |v| [v].pack("g")[0..1] }.join)
- when :complex64
- data_ptr = ::FFI::MemoryPointer.new(:float, data.size * 2)
- data_ptr.write_array_of_float(data.flat_map { |v| [v.real, v.imaginary] })
- when :complex128
- data_ptr = ::FFI::MemoryPointer.new(:double, data.size * 2)
- data_ptr.write_array_of_double(data.flat_map { |v| [v.real, v.imaginary] })
- when :string
- data_ptr = string_ptr(data)
- when :bool
- data_ptr = ::FFI::MemoryPointer.new(:int8, data.size)
- data_ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 })
+ if data.is_a?(Numo::NArray)
+ dtype ||= Utils.infer_type(data)
+ # TODO use Numo read pointer?
+ data_ptr = ::FFI::MemoryPointer.new(:uchar, data.byte_size)
+ data_ptr.write_bytes(data.to_string)
else
- raise "Unknown type: #{dtype}"
+ data = data.flatten
+ dtype ||= Utils.infer_type(data)
+ case dtype
+ when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
+ data_ptr = ::FFI::MemoryPointer.new(dtype, data.size)
+ data_ptr.send("write_array_of_#{dtype}", data)
+ when :bfloat16
+ # https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
+ data_ptr = ::FFI::MemoryPointer.new(:int8, data.size * 2)
+ data_ptr.write_bytes(data.map { |v| [v].pack("g")[0..1] }.join)
+ when :complex64
+ data_ptr = ::FFI::MemoryPointer.new(:float, data.size * 2)
+ data_ptr.write_array_of_float(data.flat_map { |v| [v.real, v.imaginary] })
+ when :complex128
+ data_ptr = ::FFI::MemoryPointer.new(:double, data.size * 2)
+ data_ptr.write_array_of_double(data.flat_map { |v| [v.real, v.imaginary] })
+ when :string
+ data_ptr = string_ptr(data)
+ when :bool
+ data_ptr = ::FFI::MemoryPointer.new(:int8, data.size)
+ data_ptr.write_array_of_int8(data.map { |v| v ? 1 : 0 })
+ else
+ raise "Unknown type: #{dtype}"
+ end
end
+ type = FFI::DataType[dtype]
+
callback = ::FFI::Function.new(:void, [:pointer, :size_t, :pointer]) do |data, len, arg|
# FFI handles deallocation
end
+ # keep data pointer alive for duration of object
+ @data_ptr = data_ptr
+ @dims_ptr = dims_ptr
+ @callback = callback
+
tensor = FFI.TF_NewTensor(type, dims_ptr, shape.size, data_ptr, data_ptr.size, callback, nil)
@pointer = FFI.TFE_NewTensorHandle(tensor, @status)
check_status @status
end
@@ -73,10 +86,14 @@
def %(other)
Math.floormod(self, other)
end
+ def -@
+ Math.negative(self)
+ end
+
def value
value =
case dtype
when :float, :double, :int32, :uint8, :int16, :int8, :int64, :uint16, :uint32, :uint64
data_pointer.send("read_array_of_#{dtype}", element_count)
@@ -90,14 +107,23 @@
when :string
# string tensor format
# https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57
start_offset_size = element_count * 8
offsets = data_pointer.read_array_of_uint64(element_count)
- element_count.times.map { |i| (data_pointer + start_offset_size + offsets[i]).read_string }
+ byte_size = FFI.TF_TensorByteSize(tensor_pointer)
+ element_count.times.map do |i|
+ str_len = (offsets[i + 1] || (byte_size - start_offset_size)) - offsets[i]
+ str = (data_pointer + start_offset_size + offsets[i]).read_bytes(str_len)
+ dst = ::FFI::MemoryPointer.new(:char, str.bytesize + 100)
+ dst_len = ::FFI::MemoryPointer.new(:size_t)
+ FFI.TF_StringDecode(str, str.bytesize, dst, dst_len, @status)
+ check_status @status
+ dst.read_pointer.read_bytes(dst_len.read_int32)
+ end
when :bool
data_pointer.read_array_of_int8(element_count).map { |v| v == 1 }
- when :resource
+ when :resource, :variant
return data_pointer
else
raise "Unknown type: #{dtype}"
end
@@ -133,12 +159,18 @@
def to_ptr
@pointer
end
+ def numo
+ klass = Utils::NUMO_TYPE_MAP[dtype]
+ raise "Unknown type: #{dtype}" unless klass
+ klass.cast(value)
+ end
+
def inspect
- inspection = %w(value shape dtype).map { |v| "#{v}: #{send(v).inspect}"}
+ inspection = %w(numo shape dtype).map { |v| "#{v}: #{send(v).inspect}"}
"#<#{self.class} #{inspection.join(", ")}>"
end
def self.finalize(pointer, status, tensor)
# must use proc instead of stabby lambda
@@ -162,13 +194,17 @@
check_status @status
ret
end
def data_pointer
+ FFI.TF_TensorData(tensor_pointer)
+ end
+
+ def tensor_pointer
tensor = FFI.TFE_TensorHandleResolve(@pointer, @status)
check_status @status
- FFI.TF_TensorData(tensor)
+ tensor
end
def reshape(arr, dims)
return arr.first if dims.empty?
arr = arr.flatten
@@ -177,10 +213,12 @@
end
arr.to_a
end
def calculate_shape(value)
+ return value.shape if value.respond_to?(:shape)
+
shape = []
d = value
while d.is_a?(Array)
shape << d.size
d = d.first
@@ -197,10 +235,12 @@
offsets << offsets.last + str.bytesize + 1
end
data_ptr = ::FFI::MemoryPointer.new(:char, start_offset_size + offsets.pop)
data_ptr.write_array_of_uint64(offsets)
data.zip(offsets) do |str, offset|
- (data_ptr + start_offset_size + offset).write_string(str)
+ dst_len = FFI.TF_StringEncodedSize(str.bytesize)
+ FFI.TF_StringEncode(str, str.bytesize, data_ptr + start_offset_size + offset, dst_len, @status)
+ check_status @status
end
data_ptr
end
def check_status(status)