lib/tensorflow/tensor.rb in tensorflow-0.1.2 vs lib/tensorflow/tensor.rb in tensorflow-0.2.0
- old
+ new
@@ -103,23 +103,13 @@
when :complex64
data_pointer.read_array_of_float(element_count * 2).each_slice(2).map { |v| Complex(*v) }
when :complex128
data_pointer.read_array_of_double(element_count * 2).each_slice(2).map { |v| Complex(*v) }
when :string
- # string tensor format
- # https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57
- start_offset_size = element_count * 8
- offsets = data_pointer.read_array_of_uint64(element_count)
- byte_size = FFI.TF_TensorByteSize(tensor_pointer)
+ tf_string_size = 24
element_count.times.map do |i|
- str_len = (offsets[i + 1] || (byte_size - start_offset_size)) - offsets[i]
- str = (data_pointer + start_offset_size + offsets[i]).read_bytes(str_len)
- dst = ::FFI::MemoryPointer.new(:char, str.bytesize + 100)
- dst_len = ::FFI::MemoryPointer.new(:size_t)
- FFI.TF_StringDecode(str, str.bytesize, dst, dst_len, @status)
- check_status @status
- dst.read_pointer.read_bytes(dst_len.read_int32)
+ FFI.TF_StringGetDataPointer(data_pointer + i * tf_string_size)
end
when :bool
data_pointer.read_array_of_int8(element_count).map { |v| v == 1 }
when :resource, :variant
return data_pointer
@@ -224,23 +214,16 @@
d = d.first
end
shape
end
- # string tensor format
- # https://github.com/tensorflow/tensorflow/blob/5453aee48858fd375172d7ae22fad1557e8557d6/tensorflow/c/tf_tensor.h#L57
def string_ptr(data)
- start_offset_size = data.size * 8
- offsets = [0]
- data.each do |str|
- offsets << offsets.last + str.bytesize + 1
- end
- data_ptr = ::FFI::MemoryPointer.new(:char, start_offset_size + offsets.pop)
- data_ptr.write_array_of_uint64(offsets)
- data.zip(offsets) do |str, offset|
- dst_len = FFI.TF_StringEncodedSize(str.bytesize)
- FFI.TF_StringEncode(str, str.bytesize, data_ptr + start_offset_size + offset, dst_len, @status)
- check_status @status
+ tf_string_size = 24
+ data_ptr = ::FFI::MemoryPointer.new(:char, data.size * tf_string_size)
+ data.each_with_index do |str, i|
+ offset = data_ptr + i * tf_string_size
+ FFI.TF_StringInit(offset)
+ FFI.TF_StringCopy(offset, str, str.bytesize)
end
data_ptr
end
def check_status(status)