Sha256: b0d03c256f79bbe0021e10ce48ef4534fee1cb4451d6094620106f3d5d4a30dd

Contents?: true

Size: 1.61 KB

Versions: 8

Compression:

Stored size: 1.61 KB

Contents

module TensorStream
  ##
  # Utility class to convert variables to constants for production deployment.
  class Freezer
    include TensorStream::OpHelper

    ##
    # Utility method to convert variables to constants for production deployment.
    #
    # Inside +TensorStream.graph.as_default+, loads +model.yaml+ from
    # +checkpoint_folder+ with +YamlLoader+, restores the checkpoint into
    # +session+ with +TensorStream::Train::Saver+, and then writes a YAML
    # serialization of the graph to +output_file+ in which:
    # * every +:variable_v2+ node is replaced by a +:const+ node built from the
    #   variable's +container+ value, keeping its name and data type;
    # * every +:assign+ node is dropped;
    # * every node that consumes an +:assign+ node is dropped.
    # Nodes are "dropped" by returning +nil+ from the +TensorStream::Yaml#get_string+
    # block. NOTE(review): assumed to mean "omit this node" — confirm in get_string.
    #
    # @param session session handed to +TensorStream::Train::Saver#restore+
    # @param checkpoint_folder [String] directory containing +model.yaml+ and the checkpoint
    # @param output_file [String] destination path for the frozen graph YAML
    # @raise [Errno::ENOENT] if +model.yaml+ does not exist in +checkpoint_folder+
    def convert(session, checkpoint_folder, output_file)
      model_file = File.join(checkpoint_folder, "model.yaml")
      TensorStream.graph.as_default do |current_graph|
        YamlLoader.new.load_from_string(File.read(model_file))
        saver = TensorStream::Train::Saver.new
        saver.restore(session, checkpoint_folder)

        # The assign ops themselves are skipped in the :assign branch below; this
        # set holds the names of the nodes that *consume* them, which are skipped too.
        remove_nodes = assign_consumer_names(current_graph)

        output_buffer = TensorStream::Yaml.new.get_string(current_graph) do |graph, node_key|
          node = graph.get_tensor_by_name(node_key)
          case node.operation
          when :variable_v2
            variable_to_const(current_graph, node)
          when :assign
            nil
          else
            remove_nodes.include?(node.name) ? nil : node
          end
        end
        File.write(output_file, output_buffer)
      end
    end

    private

    # Returns a Set of the names of every node that consumes an :assign op in +graph+.
    def assign_consumer_names(graph)
      assign_ops = graph.nodes.values.select do |op|
        op.is_a?(TensorStream::Operation) && op.operation == :assign
      end
      # Set.new already de-duplicates, so no separate uniq is needed.
      Set.new(assign_ops.flat_map { |op| op.consumers.to_a })
    end

    # Builds a :const operation in +graph+ that carries +node+'s container value
    # (presumably the value restored from the checkpoint — confirm in Saver#restore),
    # reusing the variable's name and data type.
    def variable_to_const(graph, node)
      value = node.container
      # shape_eval walks the whole value, so compute it once and reuse it.
      shape = shape_eval(value)
      options = {
        value: value,
        data_type: node.data_type,
        shape: shape,
      }
      const_op = TensorStream::Operation.new(graph, inputs: [], options: options)
      const_op.name = node.name
      const_op.operation = :const
      const_op.data_type = node.data_type
      const_op.shape = TensorShape.new(shape)
      const_op
    end
  end
end

Version data entries

8 entries across 8 versions & 1 rubygem

Version Path
tensor_stream-1.0.8 lib/tensor_stream/utils/freezer.rb
tensor_stream-1.0.7 lib/tensor_stream/utils/freezer.rb
tensor_stream-1.0.6 lib/tensor_stream/utils/freezer.rb
tensor_stream-1.0.5 lib/tensor_stream/utils/freezer.rb
tensor_stream-1.0.4 lib/tensor_stream/utils/freezer.rb
tensor_stream-1.0.3 lib/tensor_stream/utils/freezer.rb
tensor_stream-1.0.2 lib/tensor_stream/utils/freezer.rb
tensor_stream-1.0.1 lib/tensor_stream/utils/freezer.rb