Sha256: ad728be2c85c51a97e55bfd8e8337000f4f5fd0c597053c3c3ea03751c5a5aac
Contents?: true
Size: 1.2 KB
Versions: 2
Compression:
Stored size: 1.2 KB
Contents
module TensorStream
  ##
  # Utility class to convert variables to constants for production deployment.
  #
  # Loads a serialized model and a checkpoint, then re-serializes the graph with
  # every :variable_v2 node replaced by a :const operation carrying the
  # variable's container value.
  class Freezer
    include OpHelper

    ##
    # Freezes a model.
    #
    # Steps, all inside TensorStream.graph.as_default:
    # 1. Read +model_file+ and load it with YamlLoader.
    # 2. Restore variables from +checkpoint_file+ with Train::Saver#restore.
    # 3. Serialize the graph with TensorStream::Yaml#get_string, substituting
    #    a :const op for each :variable_v2 node.
    # 4. Write the resulting string to +output_file+.
    #
    # @param model_file [String] path of the model file to read
    # @param checkpoint_file [String] checkpoint handed to Saver#restore
    # @param output_file [String] path the serialized, frozen graph is written to
    def convert(model_file, checkpoint_file, output_file)
      TensorStream.graph.as_default do |active_graph|
        YamlLoader.new.load_from_string(File.read(model_file))
        TensorStream::Train::Saver.new.restore(nil, checkpoint_file)

        frozen = TensorStream::Yaml.new.get_string(active_graph) do |graph, node_key|
          tensor = graph.get_tensor_by_name(node_key)
          # Anything that is not a variable is serialized unchanged.
          next tensor unless tensor.operation == :variable_v2

          variable_as_const(active_graph, tensor)
        end

        File.write(output_file, frozen)
      end
    end

    private

    # Builds a :const Operation on +graph+ mirroring +variable+'s name, data type
    # and current container value. The container presumably holds the value just
    # restored from the checkpoint; confirm against Saver#restore.
    def variable_as_const(graph, variable)
      value = variable.container
      options = { value: value, data_type: variable.data_type, shape: shape_eval(value) }

      TensorStream::Operation.new(graph, inputs: [], options: options).tap do |op|
        op.name = variable.name
        op.operation = :const
        op.data_type = variable.data_type
        op.shape = TensorShape.new(shape_eval(value))
      end
    end
  end
end
Version data entries
2 entries across 2 versions & 1 rubygem
Version | Path |
---|---|
tensor_stream-0.9.10 | lib/tensor_stream/utils/freezer.rb |
tensor_stream-0.9.9 | lib/tensor_stream/utils/freezer.rb |