Sha256: 17dd74b273897dbe0288826831bd9141a3be24391492c1ae0e13357bb64f043c

Contents?: true

Size: 1.92 KB

Versions: 15

Compression:

Stored size: 1.92 KB

Contents

module TensorStream
  module Train
    # Convenience methods used for training, centred on the "global step"
    # variable. These are instance methods of a mixin: include or extend
    # this module to use them.
    module Utils
      # Creates the global step variable.
      #
      # The variable is a scalar (+shape: []+) +:int64+ created through
      # TensorStream.variable_scope.get_variable, initialized with
      # TensorStream.zeros_initializer, marked non-trainable, and registered in
      # both the GLOBAL_VARIABLES and GLOBAL_STEP collections.
      #
      # @param graph [Object, nil] graph checked for an existing global step;
      #   defaults to TensorStream.get_default_graph when nil.
      # @return whatever TensorStream.variable_scope.get_variable returns.
      # @raise [TensorStream::ValueError] if get_global_step already finds a
      #   global step in the target graph.
      def create_global_step(graph = nil)
        target_graph = graph || TensorStream.get_default_graph
        raise TensorStream::ValueError, '"global_step" already exists.' unless get_global_step(target_graph).nil?

        # NOTE(review): the existence check above uses +target_graph+, but the
        # variable is created through TensorStream.variable_scope, which is not
        # given that graph. If a non-default graph is passed, this may check one
        # graph and create the variable in another. Confirm against
        # variable_scope's behaviour before relying on the +graph+ argument.
        TensorStream.variable_scope.get_variable(
          TensorStream::GraphKeys::GLOBAL_STEP,
          shape: [],
          dtype: :int64,
          initializer: TensorStream.zeros_initializer,
          trainable: false,
          collections: [TensorStream::GraphKeys::GLOBAL_VARIABLES, TensorStream::GraphKeys::GLOBAL_STEP]
        )
      end

      # Looks up the global step tensor in a graph.
      #
      # Resolution order:
      # 1. If the GLOBAL_STEP collection is nil or empty, fall back to the
      #    tensor named 'global_step:0'. Return nil if that lookup raises
      #    TensorStream::KeyError.
      # 2. If the collection holds exactly one tensor, return it.
      # 3. If the collection holds several tensors, the result is ambiguous:
      #    log an error and return nil.
      #
      # @param graph [Object, nil] graph to search; defaults to
      #   TensorStream.get_default_graph when nil.
      # @return the global step tensor, or nil if none (or more than one) is found.
      def get_global_step(graph = nil)
        target_graph = graph || TensorStream.get_default_graph
        candidates = target_graph.get_collection(TensorStream::GraphKeys::GLOBAL_STEP)

        if candidates.nil? || candidates.empty?
          begin
            target_graph.get_tensor_by_name('global_step:0')
          rescue TensorStream::KeyError
            # A missing tensor simply means no global step exists yet.
            nil
          end
        elsif candidates.size == 1
          candidates.first
        else
          TensorStream.logger.error("Multiple tensors in global_step collection.")
          nil
        end
      end
    end
  end
end

Version data entries

15 entries across 15 versions & 1 rubygem

Version Path
tensor_stream-1.0.0 lib/tensor_stream/train/utils.rb
tensor_stream-1.0.0.pre.rc1 lib/tensor_stream/train/utils.rb
tensor_stream-0.9.10 lib/tensor_stream/train/utils.rb
tensor_stream-0.9.9 lib/tensor_stream/train/utils.rb
tensor_stream-0.9.8 lib/tensor_stream/train/utils.rb
tensor_stream-0.9.7 lib/tensor_stream/train/utils.rb
tensor_stream-0.9.6 lib/tensor_stream/train/utils.rb
tensor_stream-0.9.5 lib/tensor_stream/train/utils.rb
tensor_stream-0.9.2 lib/tensor_stream/train/utils.rb
tensor_stream-0.9.1 lib/tensor_stream/train/utils.rb
tensor_stream-0.9.0 lib/tensor_stream/train/utils.rb
tensor_stream-0.8.6 lib/tensor_stream/train/utils.rb
tensor_stream-0.8.5 lib/tensor_stream/train/utils.rb
tensor_stream-0.8.1 lib/tensor_stream/train/utils.rb
tensor_stream-0.8.0 lib/tensor_stream/train/utils.rb