Sha256: 31ba2a0649733b94f858e4e010c5fe3829381cd9e4ec16a58ca0bdef50667804
Contents?: true
Size: 1.69 KB
Versions: 5
Compression:
Stored size: 1.69 KB
Contents
require 'json' module TensorStream module Train # High level class used for loading and saving variables class Saver def save(session, outputfile, global_step: nil, latest_filename: nil, meta_graph_suffix: 'meta', write_meta_graph: true, write_state: true, strip_default_attrs: false) vars = TensorStream::Graph.get_default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES) variables = {} graph = {} gs = eval_global_step(session, global_step) output_dump = { variables: variables, graph: graph, global_step: gs } vars.each do |variable| variables[variable.name] = variable.value end basename = File.basename(outputfile) path = File.dirname(outputfile) new_filename = File.join(path, [basename, gs].compact.join('-')) File.write(new_filename, output_dump.to_json) path end def restore(_session, inputfile) input_dump = JSON.parse(File.read(inputfile)) vars = TensorStream::Graph.get_default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES) vars.each do |variable| variable.value = input_dump['variables'][variable.name] end end private def eval_global_step(session, global_step) return nil if global_step.nil? if global_step.is_a?(Tensor) session.last_session_context(global_step.name) elsif global_step.is_a?(String) || global_step.is_a?(Symbol) session.last_session_context(global_step) else global_step.to_i end end end end end
Version data entries
5 entries across 5 versions & 1 rubygems