Sha256: 7feb35d1e1c8ee33a31db2e25d7cc263d3bf040618d654d8f3491df03a59c0b2

Contents?: true

Size: 620 Bytes

Versions: 4

Compression:

Stored size: 620 Bytes

Contents

module TensorStream
  module Debugging
    extend TensorStream::OpHelper

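    # Rewires every operation in the default graph so that each
    # floating-point input first passes through a check_numerics op,
    # labelled with the consuming node's name and the input's name.
    def add_check_numerics_ops
      graph = TensorStream.get_default_graph
      # Snapshot the operations first, so the check ops created below are not rewired themselves.
      nodes_to_process = graph.nodes.values.select { |node| node.is_a?(Operation) }

      nodes_to_process.each do |node|
        # nil slots are kept (instead of compacted away) so positional inputs stay aligned.
        node.inputs = node.inputs.collect do |input|
          if input && TensorStream::Ops::FLOATING_POINT_TYPES.include?(input.data_type)
            TensorStream.check_numerics(input, "#{node.name}/#{input.name}", name: "check/#{node.name}/#{input.name}")
          else
            # Non-float inputs (and nil slots) are passed through unchanged.
            input
          end
        end
      end
    end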
  end
end
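The rewiring idea behind `add_check_numerics_ops` can be illustrated without TensorStream at all. The following is a minimal sketch on a hypothetical toy graph: `Node`, `FLOAT_TYPES`, `NumericCheckError`, `check_node`, `wrap_with_checks` and `evaluate` are invented for illustration and are not part of the TensorStream API. It assumes a check op simply passes its input through and raises when the value is NaN or infinite.

```ruby
# Hypothetical toy graph illustrating the check-numerics rewiring idea.
# None of these names come from TensorStream.

FLOAT_TYPES = %i[float16 float32 float64].freeze

class NumericCheckError < StandardError; end

# name: op name, data_type: output dtype, inputs: input nodes (may contain nil),
# fn: lambda that computes this node's value from its evaluated inputs.
Node = Struct.new(:name, :data_type, :inputs, :fn)

# Build a pass-through node that raises if its input value is NaN or infinite.
def check_node(input, message)
  checker = lambda do |vals|
    value = vals[0]
    if value.is_a?(Float) && (value.nan? || value.infinite?)
      raise NumericCheckError, "#{message}: #{value}"
    end
    value
  end
  Node.new("check/#{message}", input.data_type, [input], checker)
end

# Route each floating-point input of every given node through a check node.
# The node list is fixed up front, so the new check nodes are never rewired
# themselves; non-float inputs and nil slots stay where they are.
def wrap_with_checks(nodes)
  nodes.each do |node|
    node.inputs = node.inputs.map do |input|
      next input if input.nil? || !FLOAT_TYPES.include?(input.data_type)

      check_node(input, "#{node.name}/#{input.name}")
    end
  end
end

# Evaluate a node recursively (no memoization, which is fine for a tiny graph).
def evaluate(node)
  node.fn.call(node.inputs.map { |input| input && evaluate(input) })
end

a   = Node.new("a", :float32, [], ->(_) { 1.0 })
b   = Node.new("b", :float32, [], ->(_) { 0.0 })
n   = Node.new("n", :int32, [], ->(_) { 3 })
div = Node.new("div", :float32, [a, b], ->(v) { v[0] / v[1] }) # 1.0 / 0.0 is Infinity
out = Node.new("out", :float32, [div, n], ->(v) { v[0] * v[1] })

p evaluate(out) # prints Infinity: no error without checks

wrap_with_checks([a, b, n, div, out])
begin
  evaluate(out)
rescue NumericCheckError => e
  puts e.message # prints out/div: Infinity
end
```

Because each check node is named after the consuming node and its input, the error points at the edge where a non-finite value first crossed (here `div` feeding `out`), not just at the final output.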

Version data entries

4 entries across 4 versions & 1 rubygem

Version Path
tensor_stream-0.6.1 lib/tensor_stream/debugging/debugging.rb
tensor_stream-0.6.0 lib/tensor_stream/debugging/debugging.rb
tensor_stream-0.5.1 lib/tensor_stream/debugging/debugging.rb
tensor_stream-0.5.0 lib/tensor_stream/debugging/debugging.rb