lib/tensorflow.rb in tensorflow-0.1.0 vs lib/tensorflow.rb in tensorflow-0.1.1

- old
+ new

@@ -1,15 +1,32 @@ # dependencies require "ffi" +require "npy" +# stdlib +require "fileutils" +require "forwardable" +require "net/http" +require "tempfile" + # modules require "tensorflow/utils" require "tensorflow/context" +require "tensorflow/math" +require "tensorflow/ops" +require "tensorflow/raw_ops" require "tensorflow/tensor" require "tensorflow/variable" require "tensorflow/version" +# keras +require "tensorflow/keras/datasets/mnist" +require "tensorflow/keras/layers/dense" +require "tensorflow/keras/layers/dropout" +require "tensorflow/keras/layers/flatten" +require "tensorflow/keras/models/sequential" + module TensorFlow class Error < StandardError; end class << self attr_accessor :ffi_lib @@ -18,12 +35,16 @@ # friendlier error message autoload :FFI, "tensorflow/ffi" class << self + include Ops include Utils + extend Forwardable + def_delegators Math, :abs, :acos, :acosh, :add, :add_n, :argmax, :argmin, :asin, :asinh, :atan, :atan2, :atanh, :cos, :cosh, :cumsum, :divide, :equal, :exp, :floor, :greater, :greater_equal, :less, :less_equal, :logical_and, :logical_not, :logical_or, :maximum, :minimum, :multiply, :negative, :not_equal, :pow, :reduce_all, :reduce_any, :reduce_logsumexp, :reduce_max, :reduce_mean, :reduce_min, :reduce_prod, :reduce_sum, :round, :scalar_mul, :sigmoid, :sign, :sin, :sinh, :sqrt, :square, :subtract, :tan, :tanh, :truediv + def library_version FFI.TF_Version end def constant(value, dtype: nil, shape: nil) @@ -31,156 +52,9 @@ end def convert_to_tensor(value, dtype: nil) value = Tensor.new(value, dtype: dtype) unless value.is_a?(Tensor) value - end - - def add(x, y) - execute("Add", [x, y]) - end - - def subtract(x, y) - execute("Sub", [x, y]) - end - - def multiply(x, y) - execute("Mul", [x, y]) - end - - def divide(x, y) - execute("Div", [x, y]) - end - - def abs(x) - execute("Abs", [x]) - end - - def sqrt(x) - execute("Sqrt", [x]) - end - - def matmul(x, y) - execute("MatMul", [x, y]) - end - - def floormod(x, y) - execute("Mod", [x, y]) - end - - def range(start, limit, delta) - execute("Range", [start, limit, delta]) - end - - def transpose(x, perm: [1, 0]) - execute("Transpose", [x, perm]) - end - - def equal(x, y) - execute("Equal", [x, y]) - end - - def zeros_like(x) - execute("ZerosLike", [x]) - end - - def fill(dims, value) - execute("Fill", [dims, value]) - end - - def zeros(dims) - fill(dims, 0) - end - - def ones(dims) - fill(dims, 1) - end - - def assign_add_variable_op(resource, value) - execute("AssignAddVariableOp", [resource, value]) - end - - def assign_sub_variable_op(resource, value) - execute("AssignSubVariableOp", [resource, value]) - end - - def assign_variable_op(resource, value) - execute("AssignVariableOp", [resource, value]) - end - - def read_variable_op(resource, dtype) - execute("ReadVariableOp", [resource], dtype: dtype) - end - - def var_handle_op(dtype, shape, container: "", shared_name: "") - execute("VarHandleOp", [], container: container, shared_name: shared_name, dtype: dtype, shape: shape) - end - - def var_is_initialized_op(resource) - execute("VarIsInitializedOp", [resource]) - end - - private - - def default_context - @default_context ||= Context.new - end - - def execute(op_name, inputs = [], **attrs) - context = default_context - status = FFI.TF_NewStatus - op = FFI.TFE_NewOp(context, op_name, status) - check_status status - # TODO clean up status and op - - attrs.each do |attr_name, attr_value| - attr_name = attr_name.to_s - - is_list = ::FFI::MemoryPointer.new(:int) - type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status) - check_status status - - case FFI::AttrType[type] - when :type - FFI.TFE_OpSetAttrType(op, attr_name, attr_value) - when :shape - # TODO set value properly - FFI.TFE_OpSetAttrShape(op, attr_name, attr_value, 0, status) - check_status status - when :string - FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize) - else - raise "Unknown type: #{FFI::AttrType[type]}" - end - end - - inputs.each do |input| - input = convert_to_tensor(input) unless input.respond_to?(:to_ptr) - FFI.TFE_OpAddInput(op, input, status) - check_status status - end - - retvals = ::FFI::MemoryPointer.new(:pointer) - num_retvals = ::FFI::MemoryPointer.new(:int) - num_retvals.write_int(retvals.size) - FFI.TFE_Execute(op, retvals, num_retvals, status) - check_status status - - if num_retvals.read_int > 0 - handle = retvals.read_pointer - type = FFI.TFE_TensorHandleDataType(handle) - - case FFI::DataType[type] - when :resource - handle - else - Tensor.new(pointer: handle) - end - end - end - - def check_status(status) - Utils.check_status(status) end end end # shortcut