Sha256: 3cd7ad5fd03aea25892631d1008fab949844eb6758c50968d96003a8c1cff023

Contents?: true

Size: 1.66 KB

Versions: 2

Compression:

Stored size: 1.66 KB

Contents

require 'fileutils'
require 'json'

module TensorStream
  module Train
    # Persists the values of the default graph's GLOBAL_VARIABLES collection to
    # a JSON checkpoint file and loads them back.
    #
    # A checkpoint is a JSON object of the form
    #   { "variables": { name => value, ... }, "graph": {}, "global_step": gs }
    # where "graph" is currently always written as an empty object.
    class Saver
      # Writes the current value of every variable in the default graph's
      # GLOBAL_VARIABLES collection to a JSON checkpoint.
      #
      # The file is "<outputfile>-<global_step>" when a global step resolves to
      # a non-nil value, otherwise +outputfile+ itself. Missing parent
      # directories are created.
      #
      # @param session [Object] only consulted when +global_step+ is a Tensor,
      #   String or Symbol (see #eval_global_step)
      # @param outputfile [String] checkpoint path prefix
      # @param global_step [Tensor, String, Symbol, #to_i, nil] optional step,
      #   appended to the file name and stored under "global_step"
      # @param latest_filename [Object] accepted for API compatibility; unused
      # @param meta_graph_suffix [String] accepted for API compatibility; unused
      # @param write_meta_graph [Boolean] accepted for API compatibility; unused
      # @param write_state [Boolean] accepted for API compatibility; unused
      # @param strip_default_attrs [Boolean] accepted for API compatibility; unused
      # @return [String] the directory the checkpoint was written into (not the
      #   checkpoint file path itself)
      def save(session, outputfile,
               global_step: nil,
               latest_filename: nil,
               meta_graph_suffix: 'meta',
               write_meta_graph: true,
               write_state: true,
               strip_default_attrs: false)
        vars = graph_variables
        gs = eval_global_step(session, global_step)

        variables = vars.each_with_object({}) do |variable, memo|
          memo[variable.name] = variable.value
        end
        output_dump = { variables: variables, graph: {}, global_step: gs }

        path = File.dirname(outputfile)
        # Without this, File.write raises Errno::ENOENT for a fresh directory.
        FileUtils.mkdir_p(path)
        # compact drops a nil step so no trailing "-" is appended.
        new_filename = File.join(path, [File.basename(outputfile), gs].compact.join('-'))
        File.write(new_filename, output_dump.to_json)

        path
      end

      # Loads variable values from a checkpoint written by #save.
      #
      # Each variable in the default graph's GLOBAL_VARIABLES collection whose
      # name appears in the checkpoint is assigned the saved value. Variables
      # absent from the checkpoint keep their current value instead of being
      # overwritten with nil.
      #
      # @param session [Object] currently unused
      # @param inputfile [String] full path of the checkpoint file, including
      #   any "-<global_step>" suffix that #save appended
      # @return [Array] the variables that were iterated
      # @raise [KeyError] if the file has no top-level "variables" entry
      def restore(session, inputfile)
        saved = JSON.parse(File.read(inputfile)).fetch('variables')

        graph_variables.each do |variable|
          # JSON object keys are always strings (to_json stringifies them on save).
          key = variable.name.to_s
          variable.value = saved[key] if saved.key?(key)
        end
      end

      private

      # Variables in the default graph's GLOBAL_VARIABLES collection. A nil
      # collection is treated as empty.
      def graph_variables
        Array(TensorStream::Graph.get_default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES))
      end

      # Resolves +global_step+ to the value recorded with the checkpoint.
      #
      # @return [Object, nil] nil when no step was given. For a Tensor (by its
      #   name) or a String/Symbol name, whatever session.last_session_context
      #   returns. Otherwise global_step.to_i.
      def eval_global_step(session, global_step)
        return nil if global_step.nil?

        # NOTE(review): assumes Session#last_session_context accepts a tensor
        # name and returns that tensor's last evaluated value; confirm against
        # the Session implementation.
        case global_step
        when Tensor then session.last_session_context(global_step.name)
        when String, Symbol then session.last_session_context(global_step)
        else global_step.to_i
        end
      end
    end
  end
end

Version data entries

2 entries across 2 versions & 1 rubygems

Version Path
tensor_stream-0.1.1 lib/tensor_stream/train/saver.rb
tensor_stream-0.1.0 lib/tensor_stream/train/saver.rb