lib/tensorflow/utils.rb in tensorflow-0.1.0 vs lib/tensorflow/utils.rb in tensorflow-0.1.1

- old
+ new

@@ -1,22 +1,141 @@ module TensorFlow module Utils - def self.check_status(status) - if FFI.TF_GetCode(status) != 0 - raise Error, FFI.TF_Message(status) + class << self + def check_status(status) + if FFI.TF_GetCode(status) != 0 + raise Error, FFI.TF_Message(status) + end end - end - def self.infer_type(value) - case value - when String - :string - when Float - :float - when true, false - :bool - else - :int32 + def default_context + @default_context ||= Context.new + end + + def execute(op_name, inputs = [], **attrs) + context = default_context + status = FFI.TF_NewStatus # TODO reuse status between ops? + op = FFI.TFE_NewOp(context, op_name, status) + check_status status + + attrs.each do |attr_name, attr_value| + next if attr_value.nil? + + attr_name = attr_name.to_s + + is_list = ::FFI::MemoryPointer.new(:int) + type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status) + check_status status + + case FFI::AttrType[type] + when :string + FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize) + # when :int + # when :float + # when :bool + when :type + FFI.TFE_OpSetAttrType(op, attr_name, attr_value) + when :shape + # TODO set value properly + FFI.TFE_OpSetAttrShape(op, attr_name, nil, 0, status) + check_status status + # when :tensor + # when :placeholder + # when :func + else + raise "Unknown type: #{FFI::AttrType[type]}" + end + end + + inputs.each do |input| + input = TensorFlow.convert_to_tensor(input) unless input.respond_to?(:to_ptr) + FFI.TFE_OpAddInput(op, input, status) + check_status status + end + + retvals = ::FFI::MemoryPointer.new(:pointer) + num_retvals = ::FFI::MemoryPointer.new(:int) + num_retvals.write_int(retvals.size) + FFI.TFE_Execute(op, retvals, num_retvals, status) + check_status status + + if num_retvals.read_int > 0 + handle = retvals.read_pointer + type = FFI.TFE_TensorHandleDataType(handle) + + case FFI::DataType[type] + when :resource + handle + else + Tensor.new(pointer: handle) + end + end + ensure + FFI.TF_DeleteStatus(status) if status + FFI.TFE_DeleteOp(op) if op + end + + def infer_type(value) + if value.all? { |v| v.is_a?(String) } + :string + elsif value.all? { |v| v == true || v == false } + :bool + elsif value.all? { |v| v.is_a?(Integer) } + if value.all? { |v| v >= -2147483648 && v <= 2147483647 } + :int32 + else + :int64 + end + elsif value.all? { |v| v.is_a?(Complex) } + :complex128 + elsif value.all? { |v| v.is_a?(Numeric) } + :float + else + raise Error, "Unable to infer data type" + end + end + + def load_dataset(path, url) + # TODO handle this better + raise "No HOME" unless ENV["HOME"] + datasets_dir = "#{ENV["HOME"]}/.keras/datasets" + FileUtils.mkdir_p(datasets_dir) + + path = "#{datasets_dir}/#{path}" + Utils.download_file(url, path) unless File.exist?(path) + Npy.load_npz(path) + end + + def download_file(url, dest) + uri = URI(url) + + temp_dir ||= File.dirname(Tempfile.new("tensorflow")) + temp_path = "#{temp_dir}/#{Time.now.to_f}" # TODO better name + + # Net::HTTP automatically adds Accept-Encoding for compression + # of response bodies and automatically decompresses gzip + # and deflateresponses unless a Range header was sent. + # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html + Net::HTTP.start(uri.host, uri.port, use_ssl: true) do |http| + request = Net::HTTP::Get.new(uri) + + print("Downloading dataset") + i = 0 + File.open(temp_path, "wb") do |f| + http.request(request) do |response| + response.read_body do |chunk| + f.write(chunk) + + # print progress + putc "." if i % 50 == 0 + i += 1 + end + end + puts # newline + end + end + + FileUtils.mv(temp_path, dest) end end end end