Sha256: b351e37fdb0bd9ca9ca31a9c6bf1b352d4e4a27791d198bdddbaea3434f6da45
Contents?: true
Size: 710 Bytes
Versions: 9
Compression:
Stored size: 710 Bytes
Contents
module TensorStream
  # Helpers for instrumenting the default graph for debugging purposes.
  module Debugging
    extend TensorStream::OpHelper

    # Rewires the inputs of every Operation in the default graph so that each
    # floating point Tensor input is routed through TensorStream.check_numerics.
    #
    # Inputs are handled as follows:
    # * nil inputs stay nil (their position in the input list is preserved);
    # * Variable inputs are left untouched;
    # * Tensor inputs whose data_type is listed in
    #   TensorStream::Ops::FLOATING_POINT_TYPES are replaced by a check_numerics
    #   op labelled "<op name>/<input name>" and named "check/<op name>/<input name>";
    # * every other input is passed through unchanged.
    #
    # @return [Array] the operations whose inputs were processed
    def add_check_numerics_ops
      # Snapshot the operations up front: the check ops created below are not
      # part of this list, so they are not themselves re-processed.
      operations = TensorStream.get_default_graph.nodes.values.select do |candidate|
        candidate.is_a?(Operation)
      end

      operations.each do |op|
        op.inputs = op.inputs.map do |source|
          # Short-circuit order matters: nil first, then Variable, then the
          # Tensor/data type test.
          wrap = !source.nil? &&
                 !source.is_a?(Variable) &&
                 source.is_a?(Tensor) &&
                 TensorStream::Ops::FLOATING_POINT_TYPES.include?(source.data_type)

          if wrap
            label = "#{op.name}/#{source.name}"
            TensorStream.check_numerics(source, label, name: "check/#{label}")
          else
            source
          end
        end
      end
    end
  end
end
Version data entries
9 entries across 9 versions & 1 rubygems