lib/tensorflow.rb in tensorflow-0.1.0 vs lib/tensorflow.rb in tensorflow-0.1.1
- removed (present only in 0.1.0)
+ added (present only in 0.1.1)
@@ -1,15 +1,32 @@
# dependencies
require "ffi"
+require "npy"
+# stdlib
+require "fileutils"
+require "forwardable"
+require "net/http"
+require "tempfile"
+
# modules
require "tensorflow/utils"
require "tensorflow/context"
+require "tensorflow/math"
+require "tensorflow/ops"
+require "tensorflow/raw_ops"
require "tensorflow/tensor"
require "tensorflow/variable"
require "tensorflow/version"
+# keras
+require "tensorflow/keras/datasets/mnist"
+require "tensorflow/keras/layers/dense"
+require "tensorflow/keras/layers/dropout"
+require "tensorflow/keras/layers/flatten"
+require "tensorflow/keras/models/sequential"
+
module TensorFlow
  class Error < StandardError; end
  class << self
    attr_accessor :ffi_lib
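
The new requires above pair the npy gem with net/http, tempfile, and fileutils from the stdlib, which is the usual recipe for a dataset loader such as keras/datasets/mnist: fetch an .npz archive over HTTP, cache it on disk, and read the arrays back. The sketch below is illustrative only and assumes that pattern; the URL, cache path, and method name are placeholders, not code from this gem.

require "fileutils"
require "net/http"
require "npy"

# Hypothetical download-and-cache loader suggested by the new dependencies;
# not the gem's actual implementation.
def load_mnist(cache_dir: File.expand_path("~/.tensorflow-ruby/datasets"))
  FileUtils.mkdir_p(cache_dir)
  path = File.join(cache_dir, "mnist.npz")

  unless File.exist?(path)
    # standard Keras MNIST archive; downloaded once, then reused from the cache
    uri = URI("https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz")
    File.binwrite(path, Net::HTTP.get(uri))
  end

  data = Npy.load_npz(path) # reads the arrays stored in the archive
  [[data["x_train"], data["y_train"]], [data["x_test"], data["y_test"]]]
end
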
@@ -18,12 +35,16 @@
  # friendlier error message
  autoload :FFI, "tensorflow/ffi"
  class << self
+     include Ops
    include Utils
+     extend Forwardable
+     def_delegators Math, :abs, :acos, :acosh, :add, :add_n, :argmax, :argmin, :asin, :asinh, :atan, :atan2, :atanh, :cos, :cosh, :cumsum, :divide, :equal, :exp, :floor, :greater, :greater_equal, :less, :less_equal, :logical_and, :logical_not, :logical_or, :maximum, :minimum, :multiply, :negative, :not_equal, :pow, :reduce_all, :reduce_any, :reduce_logsumexp, :reduce_max, :reduce_mean, :reduce_min, :reduce_prod, :reduce_sum, :round, :scalar_mul, :sigmoid, :sign, :sin, :sinh, :sqrt, :square, :subtract, :tan, :tanh, :truediv
+
    def library_version
      FFI.TF_Version
    end
    def constant(value, dtype: nil, shape: nil)
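
The delegation added above means every name in the def_delegators list becomes a module-level method on TensorFlow that forwards to TensorFlow::Math, with the remaining ops coming from include Ops. A minimal, self-contained sketch of that stdlib Forwardable pattern, using toy module names rather than anything from this gem:

require "forwardable"

module Calculator
  # stands in for TensorFlow::Math
  module Arithmetic
    def self.add(x, y)
      x + y
    end

    def self.square(x)
      x * x
    end
  end

  class << self
    extend Forwardable
    # Calculator.add / Calculator.square now forward to Arithmetic,
    # mirroring the def_delegators Math, ... line in the hunk above
    def_delegators Arithmetic, :add, :square
  end
end

Calculator.add(2, 3)  # => 5
Calculator.square(4)  # => 16
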
@@ -31,156 +52,9 @@
    end
    def convert_to_tensor(value, dtype: nil)
      value = Tensor.new(value, dtype: dtype) unless value.is_a?(Tensor)
      value
-     end
-
-     def add(x, y)
-       execute("Add", [x, y])
-     end
-
-     def subtract(x, y)
-       execute("Sub", [x, y])
-     end
-
-     def multiply(x, y)
-       execute("Mul", [x, y])
-     end
-
-     def divide(x, y)
-       execute("Div", [x, y])
-     end
-
-     def abs(x)
-       execute("Abs", [x])
-     end
-
-     def sqrt(x)
-       execute("Sqrt", [x])
-     end
-
-     def matmul(x, y)
-       execute("MatMul", [x, y])
-     end
-
-     def floormod(x, y)
-       execute("Mod", [x, y])
-     end
-
-     def range(start, limit, delta)
-       execute("Range", [start, limit, delta])
-     end
-
-     def transpose(x, perm: [1, 0])
-       execute("Transpose", [x, perm])
-     end
-
-     def equal(x, y)
-       execute("Equal", [x, y])
-     end
-
-     def zeros_like(x)
-       execute("ZerosLike", [x])
-     end
-
-     def fill(dims, value)
-       execute("Fill", [dims, value])
-     end
-
-     def zeros(dims)
-       fill(dims, 0)
-     end
-
-     def ones(dims)
-       fill(dims, 1)
-     end
-
-     def assign_add_variable_op(resource, value)
-       execute("AssignAddVariableOp", [resource, value])
-     end
-
-     def assign_sub_variable_op(resource, value)
-       execute("AssignSubVariableOp", [resource, value])
-     end
-
-     def assign_variable_op(resource, value)
-       execute("AssignVariableOp", [resource, value])
-     end
-
-     def read_variable_op(resource, dtype)
-       execute("ReadVariableOp", [resource], dtype: dtype)
-     end
-
-     def var_handle_op(dtype, shape, container: "", shared_name: "")
-       execute("VarHandleOp", [], container: container, shared_name: shared_name, dtype: dtype, shape: shape)
-     end
-
-     def var_is_initialized_op(resource)
-       execute("VarIsInitializedOp", [resource])
-     end
-
-     private
-
-     def default_context
-       @default_context ||= Context.new
-     end
-
-     def execute(op_name, inputs = [], **attrs)
-       context = default_context
-       status = FFI.TF_NewStatus
-       op = FFI.TFE_NewOp(context, op_name, status)
-       check_status status
-       # TODO clean up status and op
-
-       attrs.each do |attr_name, attr_value|
-         attr_name = attr_name.to_s
-
-         is_list = ::FFI::MemoryPointer.new(:int)
-         type = FFI.TFE_OpGetAttrType(op, attr_name, is_list, status)
-         check_status status
-
-         case FFI::AttrType[type]
-         when :type
-           FFI.TFE_OpSetAttrType(op, attr_name, attr_value)
-         when :shape
-           # TODO set value properly
-           FFI.TFE_OpSetAttrShape(op, attr_name, attr_value, 0, status)
-           check_status status
-         when :string
-           FFI.TFE_OpSetAttrString(op, attr_name, attr_value, attr_value.bytesize)
-         else
-           raise "Unknown type: #{FFI::AttrType[type]}"
-         end
-       end
-
-       inputs.each do |input|
-         input = convert_to_tensor(input) unless input.respond_to?(:to_ptr)
-         FFI.TFE_OpAddInput(op, input, status)
-         check_status status
-       end
-
-       retvals = ::FFI::MemoryPointer.new(:pointer)
-       num_retvals = ::FFI::MemoryPointer.new(:int)
-       num_retvals.write_int(retvals.size)
-       FFI.TFE_Execute(op, retvals, num_retvals, status)
-       check_status status
-
-       if num_retvals.read_int > 0
-         handle = retvals.read_pointer
-         type = FFI.TFE_TensorHandleDataType(handle)
-
-         case FFI::DataType[type]
-         when :resource
-           handle
-         else
-           Tensor.new(pointer: handle)
-         end
-       end
-     end
-
-     def check_status(status)
-       Utils.check_status(status)
    end
  end
end
# shortcut
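
For callers, the long removed block is a refactor rather than a removal: the hand-written wrappers move into the new math/ops/raw_ops files, and the module-level names stay reachable through the include Ops and def_delegators lines above. An illustrative call pattern (the tensor values are placeholders, not taken from the gem):

require "tensorflow"

# These module-level calls existed as hand-written wrappers in 0.1.0 and
# are served by TensorFlow::Math via def_delegators in 0.1.1.
x = TensorFlow.constant(3.0)
y = TensorFlow.constant(4.0)

TensorFlow.add(x, y)      # runs the "Add" op
TensorFlow.multiply(x, y) # runs the "Mul" op
TensorFlow.sqrt(y)        # runs the "Sqrt" op
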