Sha256: 17dd74b273897dbe0288826831bd9141a3be24391492c1ae0e13357bb64f043c
Contents?: true
Size: 1.92 KB
Versions: 15
Compression:
Stored size: 1.92 KB
Contents
module TensorStream module Train # convenience methods used for training module Utils def create_global_step(graph = nil) target_graph = graph || TensorStream.get_default_graph raise TensorStream::ValueError, '"global_step" already exists.' unless get_global_step(target_graph).nil? TensorStream.variable_scope.get_variable(TensorStream::GraphKeys::GLOBAL_STEP, shape: [], dtype: :int64, initializer: TensorStream.zeros_initializer, trainable: false, collections: [TensorStream::GraphKeys::GLOBAL_VARIABLES, TensorStream::GraphKeys::GLOBAL_STEP]) end def get_global_step(graph = nil) target_graph = graph || TensorStream.get_default_graph global_step_tensors = target_graph.get_collection(TensorStream::GraphKeys::GLOBAL_STEP) global_step_tensor = if global_step_tensors.nil? || global_step_tensors.empty? begin target_graph.get_tensor_by_name('global_step:0') rescue TensorStream::KeyError nil end elsif global_step_tensors.size == 1 global_step_tensors[0] else TensorStream.logger.error("Multiple tensors in global_step collection.") nil end global_step_tensor end end end end
Version data entries
15 entries across 15 versions & 1 rubygem