Sha256: c6d8a1c5945d79f919de2f9d82d194d97da70c5f40367c252fc1cadb470c0318

Contents?: true

Size: 1.7 KB

Versions: 7

Compression:

Stored size: 1.7 KB

Contents

require 'json'

module TensorStream
  module Train
    # High level class used for loading and saving variables.
    #
    # A checkpoint is a single JSON document of the form
    #   { "variables" => { name => value, ... }, "graph" => {}, "global_step" => step_or_null }
    class Saver
      # Saves the current value of every variable in the default graph's
      # GLOBAL_VARIABLES collection to a JSON checkpoint file.
      #
      # When +global_step+ resolves to a value, it is appended to the file name
      # as "<basename>-<step>" (e.g. "model-100"); otherwise +outputfile+ is used as is.
      #
      # @param session [Object] session used to resolve +global_step+ when it is a
      #   Tensor, String or Symbol (via +last_session_context+)
      # @param outputfile [String] checkpoint path, without any step suffix
      # @param global_step [Tensor, String, Symbol, #to_i, nil] step recorded in the
      #   checkpoint and appended to the file name
      # @param latest_filename [Object] accepted for TensorFlow API compatibility; ignored
      # @param meta_graph_suffix [Object] accepted for TensorFlow API compatibility; ignored
      # @param write_meta_graph [Object] accepted for TensorFlow API compatibility; ignored
      # @param write_state [Object] accepted for TensorFlow API compatibility; ignored
      # @param strip_default_attrs [Object] accepted for TensorFlow API compatibility; ignored
      # @return [String] the directory the checkpoint file was written into
      def save(session, outputfile, global_step: nil,
               latest_filename: nil,
               meta_graph_suffix: 'meta',
               write_meta_graph: true,
               write_state: true,
               strip_default_attrs: false)
        vars = graph_global_variables
        gs = eval_global_step(session, global_step)

        variables = vars.each_with_object({}) do |variable, acc|
          acc[variable.name] = variable.read_value
        end

        output_dump = {
          variables: variables,
          graph: {}, # graph serialization is not implemented; key kept for file-format stability
          global_step: gs
        }

        path = File.dirname(outputfile)
        # compact drops a nil step, so no "-" suffix is added without a global step
        new_filename = File.join(path, [File.basename(outputfile), gs].compact.join('-'))
        File.write(new_filename, output_dump.to_json)

        path
      end

      # Restores variable values from a checkpoint written by #save.
      #
      # Variables in the default graph that have no entry in the checkpoint
      # (for example, ones created after the checkpoint was saved) are left
      # untouched rather than being overwritten with nil.
      #
      # @param _session [Object] unused
      # @param inputfile [String] full checkpoint path, including any "-<step>"
      #   suffix that #save appended
      # @return [Array] the GLOBAL_VARIABLES collection that was iterated
      def restore(_session, inputfile)
        input_dump = JSON.parse(File.read(inputfile))
        # tolerate a missing or null "variables" entry instead of raising NoMethodError
        saved = input_dump['variables'] || {}

        graph_global_variables.each do |variable|
          next unless saved.key?(variable.name)

          variable.value = saved[variable.name]
        end
      end

      private

      # Variables registered in the default graph's GLOBAL_VARIABLES collection.
      def graph_global_variables
        TensorStream::Graph.get_default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)
      end

      # Resolves +global_step+ to the value stored in the checkpoint.
      #
      # Tensors are looked up by name, and Strings/Symbols are looked up directly,
      # via +session.last_session_context+. Any other non-nil value is coerced with
      # +to_i+. Returns nil when no global step was given.
      def eval_global_step(session, global_step)
        case global_step
        when nil
          nil
        when Tensor
          session.last_session_context(global_step.name)
        when String, Symbol
          session.last_session_context(global_step)
        else
          global_step.to_i
        end
      end
    end
  end
end

Version data entries

7 entries across 7 versions & 1 rubygems

Version Path
tensor_stream-0.6.1 lib/tensor_stream/train/saver.rb
tensor_stream-0.6.0 lib/tensor_stream/train/saver.rb
tensor_stream-0.5.1 lib/tensor_stream/train/saver.rb
tensor_stream-0.5.0 lib/tensor_stream/train/saver.rb
tensor_stream-0.4.1 lib/tensor_stream/train/saver.rb
tensor_stream-0.4.0 lib/tensor_stream/train/saver.rb
tensor_stream-0.3.0 lib/tensor_stream/train/saver.rb