Sha256: 3cd7ad5fd03aea25892631d1008fab949844eb6758c50968d96003a8c1cff023
Contents?: true
Size: 1.66 KB
Versions: 2
Compression:
Stored size: 1.66 KB
Contents
require 'json'

module TensorStream
  module Train
    # Saves the values of every variable in the default graph's
    # GLOBAL_VARIABLES collection to a JSON checkpoint file, and restores
    # them from such a file.
    class Saver
      # Writes a JSON checkpoint of all global variables.
      #
      # The checkpoint is written to +outputfile+, or to "<outputfile>-<step>"
      # when a global step resolves to a non-nil value. The JSON document holds
      # three keys: "variables" (variable name => value), "graph" (currently
      # always an empty object) and "global_step" (the resolved step or null).
      #
      # @param session [Object] used to resolve +global_step+ through
      #   +last_session_context+ when it is a Tensor, String or Symbol
      # @param outputfile [String] checkpoint path prefix
      # @param global_step [Tensor, String, Symbol, #to_i, nil] optional step
      #   appended to the file name
      # @param latest_filename [Object] accepted but currently ignored
      # @param meta_graph_suffix [Object] accepted but currently ignored
      # @param write_meta_graph [Object] accepted but currently ignored
      # @param write_state [Object] accepted but currently ignored
      # @param strip_default_attrs [Object] accepted but currently ignored
      # @return [String] the directory the checkpoint was written into
      #   (not the full file path)
      def save(session, outputfile, global_step: nil,
               latest_filename: nil,
               meta_graph_suffix: 'meta',
               write_meta_graph: true,
               write_state: true,
               strip_default_attrs: false)
        variables = global_variables.each_with_object({}) do |variable, values|
          values[variable.name] = variable.value
        end
        step = eval_global_step(session, global_step)

        output_dump = {
          variables: variables,
          graph: {},
          global_step: step
        }

        directory = File.dirname(outputfile)
        # A nil step is dropped by #compact, so no "-" suffix is appended.
        filename = File.join(directory, [File.basename(outputfile), step].compact.join('-'))
        File.write(filename, output_dump.to_json)

        directory
      end

      # Loads variable values from a checkpoint written by #save.
      #
      # Only variables whose names appear in the checkpoint are assigned.
      # Variables missing from it keep their current value instead of being
      # overwritten with nil.
      #
      # @param session [Object] unused
      # @param inputfile [String] full path of the checkpoint file, including
      #   any "-<step>" suffix added by #save
      # @return [Array] the GLOBAL_VARIABLES collection that was restored into
      def restore(session, inputfile)
        saved = JSON.parse(File.read(inputfile))['variables'] || {}

        global_variables.each do |variable|
          # JSON object keys always come back as strings.
          key = variable.name.to_s
          next unless saved.key?(key)

          variable.value = saved[key]
        end
      end

      private

      # The variables registered in the default graph's GLOBAL_VARIABLES
      # collection.
      def global_variables
        TensorStream::Graph.get_default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)
      end

      # Resolves +global_step+ to the value used in the checkpoint file name.
      #
      # Tensors and tensor names (String/Symbol) are looked up through
      # +session.last_session_context+. Anything else is converted with #to_i.
      #
      # @return [Object, nil] the resolved step, or nil when no step was given
      def eval_global_step(session, global_step)
        return nil if global_step.nil?

        case global_step
        when Tensor then session.last_session_context(global_step.name)
        when String, Symbol then session.last_session_context(global_step)
        else global_step.to_i
        end
      end
    end
  end
end
Version data entries
2 entries across 2 versions & 1 rubygems
Version | Path |
---|---|
tensor_stream-0.1.1 | lib/tensor_stream/train/saver.rb |
tensor_stream-0.1.0 | lib/tensor_stream/train/saver.rb |