lib/tensorflow/utils.rb in tensorflow-0.1.1 vs lib/tensorflow/utils.rb in tensorflow-0.1.2

- old
+ new

@@ -1,7 +1,20 @@ module TensorFlow module Utils + NUMO_TYPE_MAP = { + int8: Numo::Int8, + int16: Numo::Int16, + int32: Numo::Int32, + int64: Numo::Int64, + uint8: Numo::UInt8, + uint16: Numo::UInt16, + uint32: Numo::UInt32, + uint64: Numo::UInt64, + float: Numo::SFloat, + double: Numo::DFloat + } + class << self def check_status(status) if FFI.TF_GetCode(status) != 0 raise Error, FFI.TF_Message(status) end @@ -24,62 +37,127 @@ is_list = ::FFI::MemoryPointer.new(:int) type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status) check_status status - case FFI::AttrType[type] - when :string - FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize) - # when :int - # when :float - # when :bool - when :type - FFI.TFE_OpSetAttrType(op, attr_name, attr_value) - when :shape - # TODO set value properly - FFI.TFE_OpSetAttrShape(op, attr_name, nil, 0, status) - check_status status - # when :tensor - # when :placeholder - # when :func + if is_list.read_int == 1 + num_values = attr_value.size + + case FFI::AttrType[type] + when :int + values = ::FFI::MemoryPointer.new(:int64, num_values) + values.write_array_of_int64(attr_value) + FFI.TFE_OpSetAttrIntList(op, attr_name, values, num_values) + when :float + values = ::FFI::MemoryPointer.new(:float, num_values) + values.write_array_of_float(attr_value) + FFI.TFE_OpSetAttrFloatList(op, attr_name, values, num_values) + when :shape + dims_ptrs = + attr_value.map do |shape| + ptr = ::FFI::MemoryPointer.new(:int64, shape.size) + ptr.write_array_of_int64(shape) + end + dims = ::FFI::MemoryPointer.new(:pointer, num_values) + dims.write_array_of_pointer(dims_ptrs) + + num_dims = ::FFI::MemoryPointer.new(:int, num_values) + num_dims.write_array_of_int(attr_value.map(&:size)) + + FFI.TFE_OpSetAttrShapeList(op, attr_name, dims, num_dims, num_values, status) + when :type + values = ::FFI::MemoryPointer.new(:int, num_values) + types = + attr_value.map do |v| + if v.is_a?(Symbol) + FFI::DataType[v] + else + v + end + end + values.write_array_of_int(types) + FFI.TFE_OpSetAttrTypeList(op, attr_name, values, num_values) + else + raise "Unknown list type: #{FFI::AttrType[type]}" + end else - raise "Unknown type: #{FFI::AttrType[type]}" + case FFI::AttrType[type] + when :string + FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize) + when :int + FFI.TFE_OpSetAttrInt(op, attr_name, attr_value) + when :float + FFI.TFE_OpSetAttrFloat(op, attr_name, attr_value) + when :bool + FFI.TFE_OpSetAttrBool(op, attr_name, attr_value ? 1 : 0) + when :type + attr_value = FFI::DataType[attr_value] if attr_value.is_a?(Symbol) + FFI.TFE_OpSetAttrType(op, attr_name, attr_value) + when :shape + ptr = ::FFI::MemoryPointer.new(:int64, attr_value.size) + ptr.write_array_of_int64(attr_value) + FFI.TFE_OpSetAttrShape(op, attr_name, ptr, attr_value.size, status) + check_status status + # when :tensor + # when :placeholder + # when :func + else + raise "Unknown type: #{FFI::AttrType[type]}" + end end end - inputs.each do |input| - input = TensorFlow.convert_to_tensor(input) unless input.respond_to?(:to_ptr) - FFI.TFE_OpAddInput(op, input, status) + inputs.each_with_index do |input, i| + # TODO handle this better + if op_name == "TensorSliceDataset" && i == 0 + input_ptr = ::FFI::MemoryPointer.new(:pointer, input.size) + input_ptr.write_array_of_pointer(input) + FFI.TFE_OpAddInputList(op, input_ptr, input.size, status) + else + raise "Missing argument" if input.nil? + + input = TensorFlow.convert_to_tensor(input) unless input.respond_to?(:to_ptr) + FFI.TFE_OpAddInput(op, input, status) + end check_status status end - retvals = ::FFI::MemoryPointer.new(:pointer) + # TODO decide how many retvals to allocate + retvals = ::FFI::MemoryPointer.new(:pointer, 2) num_retvals = ::FFI::MemoryPointer.new(:int) num_retvals.write_int(retvals.size) FFI.TFE_Execute(op, retvals, num_retvals, status) check_status status - if num_retvals.read_int > 0 - handle = retvals.read_pointer - type = FFI.TFE_TensorHandleDataType(handle) + n = num_retvals.read_int + if n > 0 + retvals = + retvals.read_array_of_pointer(n).map do |handle| + Tensor.new(pointer: handle) + end - case FFI::DataType[type] - when :resource - handle - else - Tensor.new(pointer: handle) - end + # TODO handle case where n = 1 and still want an array for retvals + n == 1 ? retvals.first : retvals end ensure FFI.TF_DeleteStatus(status) if status FFI.TFE_DeleteOp(op) if op end def infer_type(value) - if value.all? { |v| v.is_a?(String) } + if value.is_a?(Numo::NArray) + type = NUMO_TYPE_MAP.find { |k, v| value.is_a?(v) } + if type + type.first + else + raise Error, "Unable to infer data type" + end + elsif value.empty? + raise Error, "Unable to infer data type" + elsif value.all? { |v| v.is_a?(String) } :string - elsif value.all? { |v| v == true || v == false } + elsif value.all? { |v| v.is_a?(TrueClass) || v.is_a?(FalseClass) } :bool elsif value.all? { |v| v.is_a?(Integer) } if value.all? { |v| v >= -2147483648 && v <= 2147483647 } :int32 else @@ -92,50 +170,17 @@ else raise Error, "Unable to infer data type" end end - def load_dataset(path, url) - # TODO handle this better - raise "No HOME" unless ENV["HOME"] - datasets_dir = "#{ENV["HOME"]}/.keras/datasets" - FileUtils.mkdir_p(datasets_dir) - - path = "#{datasets_dir}/#{path}" - Utils.download_file(url, path) unless File.exist?(path) - Npy.load_npz(path) - end - - def download_file(url, dest) - uri = URI(url) - - temp_dir ||= File.dirname(Tempfile.new("tensorflow")) - temp_path = "#{temp_dir}/#{Time.now.to_f}" # TODO better name - - # Net::HTTP automatically adds Accept-Encoding for compression - # of response bodies and automatically decompresses gzip - # and deflateresponses unless a Range header was sent. - # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html - Net::HTTP.start(uri.host, uri.port, use_ssl: true) do |http| - request = Net::HTTP::Get.new(uri) - - print("Downloading dataset") - i = 0 - File.open(temp_path, "wb") do |f| - http.request(request) do |response| - response.read_body do |chunk| - f.write(chunk) - - # print progress - putc "." if i % 50 == 0 - i += 1 - end - end - puts # newline + def to_tensor_array(values) + values.map do |v| + if v.is_a?(Tensor) + v + else + TensorFlow.convert_to_tensor(v) end end - - FileUtils.mv(temp_path, dest) end end end end