Sha256: 7bceff284a5afaae4ef17bdf73114db6c6be58136a0d1358a6faa72fd9e4fadd

Contents?: true

Size: 1.89 KB

Versions: 5

Compression:

Stored size: 1.89 KB

Contents

module TensorStream
  ## Random-number-generating ops (glorot_uniform, random_uniform and
  ## random_standard_normal), registered on whichever evaluator class
  ## includes this module.
  module RandomOps
    # Include hook: registers the random ops on +klass+ through its
    # +register_op+ class method.
    #
    # Each op block receives (context, tensor, inputs) and only uses +tensor+.
    # Op parameters are read from +tensor.options+. The output shape is
    # +options[:shape]+, falling back to +tensor.shape.shape+. The blocks call
    # +_get_randomizer+ and +generate_vector+, so the including class must
    # provide both.
    def RandomOps.included(klass)
      klass.class_eval do
        # Glorot (Xavier) uniform initializer: samples
        # rand * (maxval - minval) + minval with maxval = -minval = limit and
        # limit = sqrt(6 / (fan_in + fan_out)).
        # Fans are [1, 1] for a scalar shape, [1, n] for a vector and
        # [shape.first, shape.last] otherwise.
        # NOTE(review): TensorFlow computes fans differently for rank-1 shapes
        # (fan_in = fan_out = n) and rank>2 shapes (multiplies in the
        # receptive-field size); confirm this simplification is intended.
        register_op :glorot_uniform, no_eval: true do |_context, tensor, _inputs|
          seed = tensor.options[:seed]
          random = _get_randomizer(tensor, seed)

          shape = tensor.options[:shape] || tensor.shape.shape
          fan_in, fan_out = if shape.size.zero?
                              [1, 1]
                            elsif shape.size == 1
                              [1, shape[0]]
                            else
                              [shape[0], shape.last]
                            end

          limit = Math.sqrt(6.0 / (fan_in + fan_out))

          minval = -limit
          maxval = limit

          generator = -> { random.rand * (maxval - minval) + minval }
          generate_vector(shape, generator: generator)
        end

        # Uniform samples computed as rand * (maxval - minval) + minval.
        # minval defaults to 0 and maxval to 1 when absent from the options.
        register_op :random_uniform, no_eval: true do |_context, tensor, _inputs|
          maxval = tensor.options.fetch(:maxval, 1)
          minval = tensor.options.fetch(:minval, 0)
          seed = tensor.options[:seed]

          random = _get_randomizer(tensor, seed)
          generator = -> { random.rand * (maxval - minval) + minval }
          shape = tensor.options[:shape] || tensor.shape.shape
          generate_vector(shape, generator: generator)
        end

        # Normally distributed samples drawn through RandomGaussian, fed by
        # the op's randomizer. mean and stddev default to 0.0 and 1.0 (a
        # standard normal) when they are absent from the options.
        register_op :random_standard_normal, no_eval: true do |_context, tensor, _inputs|
          seed = tensor.options[:seed]
          random = _get_randomizer(tensor, seed)
          mean = tensor.options.fetch(:mean, 0.0)
          stddev = tensor.options.fetch(:stddev, 1.0)
          # The lambda captures the *variable* `random`, not its current
          # value, so `random` must not be reassigned after this line.
          gaussian = RandomGaussian.new(mean, stddev, -> { random.rand })
          generator = -> { gaussian.rand }
          shape = tensor.options[:shape] || tensor.shape.shape
          generate_vector(shape, generator: generator)
        end
      end
    end
  end
end

Version data entries

5 entries across 5 versions & 1 rubygem

Version Path
tensor_stream-0.9.2 lib/tensor_stream/evaluator/ruby/random_ops.rb
tensor_stream-0.9.1 lib/tensor_stream/evaluator/ruby/random_ops.rb
tensor_stream-0.9.0 lib/tensor_stream/evaluator/ruby/random_ops.rb
tensor_stream-0.8.6 lib/tensor_stream/evaluator/ruby/random_ops.rb
tensor_stream-0.8.5 lib/tensor_stream/evaluator/ruby/random_ops.rb