Sha256: e7509556e4dec94834e275786f141d094a6fcae4792644c699f18f3a7f0453bf

Contents?: true

Size: 1.69 KB

Versions: 9

Compression:

Stored size: 1.69 KB

Contents

module TensorStream
  module Train
    # Helpers for creating and locating the "global step" tensor used
    # while training.
    module Utils
      # Creates the global step variable after checking that none exists yet.
      #
      # The variable is requested from TensorStream.variable_scope under the
      # GraphKeys::GLOBAL_STEP name as a scalar (shape []) :int64, zero
      # initialized, non-trainable, and registered in both the
      # GLOBAL_VARIABLES and GLOBAL_STEP collections.
      #
      # NOTE(review): +graph+ is only consulted for the "already exists"
      # check; the variable itself is created through
      # TensorStream.variable_scope. Confirm that this targets the same graph
      # when a non-default graph is passed.
      #
      # @param graph [Object, nil] graph to inspect; the default graph is
      #   used when nil
      # @return [Object] whatever +get_variable+ returns for the new variable
      # @raise [TensorStream::ValueError] if a global step is already present
      def create_global_step(graph = nil)
        graph ||= TensorStream.get_default_graph
        existing_step = get_global_step(graph)
        raise TensorStream::ValueError, '"global_step" already exists.' unless existing_step.nil?

        step_key = TensorStream::GraphKeys::GLOBAL_STEP
        TensorStream.variable_scope.get_variable(
          step_key,
          shape: [],
          dtype: :int64,
          initializer: TensorStream.zeros_initializer,
          trainable: false,
          collections: [TensorStream::GraphKeys::GLOBAL_VARIABLES, step_key]
        )
      end

      # Looks up the global step tensor of a graph.
      #
      # Resolution order:
      # 1. If the GLOBAL_STEP collection is nil or empty, fall back to the
      #    tensor named "global_step:0"; a TensorStream::KeyError from that
      #    lookup is translated into nil.
      # 2. If the collection holds exactly one entry, return that entry.
      # 3. Otherwise log an error and return nil.
      #
      # @param graph [Object, nil] graph to inspect; the default graph is
      #   used when nil
      # @return [Object, nil] the global step tensor, or nil when it cannot
      #   be determined
      def get_global_step(graph = nil)
        graph ||= TensorStream.get_default_graph
        candidates = graph.get_collection(TensorStream::GraphKeys::GLOBAL_STEP)

        if candidates.nil? || candidates.empty?
          # Nothing registered in the collection: try the conventional name.
          begin
            return graph.get_tensor_by_name("global_step:0")
          rescue TensorStream::KeyError
            return nil
          end
        end

        return candidates[0] if candidates.size == 1

        # More than one candidate is ambiguous; report it and give up.
        TensorStream.logger.error("Multiple tensors in global_step collection.")
        nil
      end
    end
  end
end

Version data entries

9 entries across 9 versions & 1 rubygems

Version Path
tensor_stream-1.0.9 lib/tensor_stream/train/utils.rb
tensor_stream-1.0.8 lib/tensor_stream/train/utils.rb
tensor_stream-1.0.7 lib/tensor_stream/train/utils.rb
tensor_stream-1.0.6 lib/tensor_stream/train/utils.rb
tensor_stream-1.0.5 lib/tensor_stream/train/utils.rb
tensor_stream-1.0.4 lib/tensor_stream/train/utils.rb
tensor_stream-1.0.3 lib/tensor_stream/train/utils.rb
tensor_stream-1.0.2 lib/tensor_stream/train/utils.rb
tensor_stream-1.0.1 lib/tensor_stream/train/utils.rb