Sha256: b351e37fdb0bd9ca9ca31a9c6bf1b352d4e4a27791d198bdddbaea3434f6da45

Contents?: true

Size: 710 Bytes

Versions: 9

Compression:

Stored size: 710 Bytes

Contents

module TensorStream
  # Debugging helpers that operate on the default graph.
  module Debugging
    extend TensorStream::OpHelper

    # Rewrites the inputs of every Operation node in the default graph so
    # that each floating point Tensor input is replaced by the result of
    # TensorStream.check_numerics for that input.
    #
    # Inputs that are nil, that are Variables, that are not Tensors, or
    # whose data_type is not listed in Ops::FLOATING_POINT_TYPES are left
    # in place unchanged.
    #
    # @return [Array] the Operation nodes whose inputs were reassigned
    def add_check_numerics_ops
      operations = TensorStream.get_default_graph.nodes.values.select do |candidate|
        candidate.is_a?(Operation)
      end

      operations.each do |op|
        op.inputs = op.inputs.map do |source|
          # Checks run in the same order as a nil / Variable / Tensor /
          # data type cascade, so data_type is only read on real Tensors.
          needs_check = !source.nil? &&
                        !source.is_a?(Variable) &&
                        source.is_a?(Tensor) &&
                        TensorStream::Ops::FLOATING_POINT_TYPES.include?(source.data_type)
          next source unless needs_check

          TensorStream.check_numerics(source, "#{op.name}/#{source.name}", name: "check/#{op.name}/#{source.name}")
        end
      end
    end
  end
end

Version data entries

9 entries across 9 versions & 1 rubygems

Version Path
tensor_stream-1.0.9 lib/tensor_stream/debugging/debugging.rb
tensor_stream-1.0.8 lib/tensor_stream/debugging/debugging.rb
tensor_stream-1.0.7 lib/tensor_stream/debugging/debugging.rb
tensor_stream-1.0.6 lib/tensor_stream/debugging/debugging.rb
tensor_stream-1.0.5 lib/tensor_stream/debugging/debugging.rb
tensor_stream-1.0.4 lib/tensor_stream/debugging/debugging.rb
tensor_stream-1.0.3 lib/tensor_stream/debugging/debugging.rb
tensor_stream-1.0.2 lib/tensor_stream/debugging/debugging.rb
tensor_stream-1.0.1 lib/tensor_stream/debugging/debugging.rb