Sha256: e7509556e4dec94834e275786f141d094a6fcae4792644c699f18f3a7f0453bf
Contents?: true
Size: 1.69 KB
Versions: 9
Compression:
Stored size: 1.69 KB
Contents
module TensorStream
  module Train
    # Helpers used by training code to find or create the graph's
    # "global_step" counter.
    module Utils
      # Creates the global step variable through the current variable scope.
      # The variable is a scalar (shape []), dtype :int64, zero-initialized,
      # non-trainable, and registered in both the GLOBAL_VARIABLES and
      # GLOBAL_STEP collections.
      #
      # @param graph [Object, nil] graph searched for an existing global step;
      #   when nil, TensorStream.get_default_graph is used.
      # @return whatever +variable_scope.get_variable+ returns.
      # @raise [TensorStream::ValueError] if get_global_step already finds one.
      #
      # NOTE(review): +graph+ only drives the existence check; the variable is
      # created via TensorStream.variable_scope, which presumably targets the
      # default graph. Confirm before passing a non-default graph.
      def create_global_step(graph = nil)
        g = graph || TensorStream.get_default_graph

        existing = get_global_step(g)
        raise TensorStream::ValueError, '"global_step" already exists.' unless existing.nil?

        member_of = [
          TensorStream::GraphKeys::GLOBAL_VARIABLES,
          TensorStream::GraphKeys::GLOBAL_STEP,
        ]

        TensorStream.variable_scope.get_variable(
          TensorStream::GraphKeys::GLOBAL_STEP,
          shape: [],
          dtype: :int64,
          initializer: TensorStream.zeros_initializer,
          trainable: false,
          collections: member_of
        )
      end

      # Looks up the global step tensor of a graph.
      #
      # Resolution order:
      #   1. If the GLOBAL_STEP collection is nil or empty, look the tensor up
      #      by the name "global_step:0"; return nil if that raises
      #      TensorStream::KeyError.
      #   2. If the collection holds exactly one entry, return it.
      #   3. Otherwise, log an error and return nil.
      #
      # @param graph [Object, nil] graph to search; defaults to
      #   TensorStream.get_default_graph when nil.
      # @return the global step tensor, or nil if none is unambiguously found.
      def get_global_step(graph = nil)
        g = graph || TensorStream.get_default_graph
        candidates = g.get_collection(TensorStream::GraphKeys::GLOBAL_STEP)

        if candidates.nil? || candidates.empty?
          begin
            return g.get_tensor_by_name("global_step:0")
          rescue TensorStream::KeyError
            return nil
          end
        end

        return candidates.first if candidates.size == 1

        TensorStream.logger.error("Multiple tensors in global_step collection.")
        nil
      end
    end
  end
end
Version data entries
9 entries across 9 versions & 1 rubygem