Sha256: 801e9d878242bf4d7a2a2042b880969377bfa8dc6578b82ed607804fdb4114dd

Contents?: true

Size: 1.93 KB

Versions: 8

Compression:

Stored size: 1.93 KB

Contents

module TensorStream
  ##
  # Class for deserialization from a YAML file.
  #
  # Rebuilds graph nodes (and any variables backing them) from the list of
  # serialized op definitions contained in a model YAML document.
  class YamlLoader
    ##
    # Args:
    # graph: Graph - graph to restore the model into
    #        (defaults to TensorStream.get_default_graph)
    def initialize(graph = nil)
      @graph = graph || TensorStream.get_default_graph
    end

    ##
    # Loads a model Yaml file and builds the model from it
    #
    # Args:
    # filename: String - Location of Yaml file
    #
    # Returns: Graph where model is restored to
    def load_from_file(filename)
      load_from_string(File.read(filename))
    end

    ##
    # Parses a model Yaml string and builds the model from it
    #
    # Each entry is read as a Hash with Symbol keys :name, :op, :inputs
    # (names looked up via the graph's get_tensor_by_name) and :attrs. An op
    # whose attrs carry a :container entry also gets a Variable created and
    # registered on the graph.
    #
    # Args:
    # buffer: String - String in Yaml format of the model
    #
    # Returns: Graph where model is restored to
    #
    # Raises: Psych::DisallowedClass if the document contains objects other
    # than plain YAML types and Symbols
    def load_from_string(buffer)
      # an empty document parses to nil; treat it as "no ops" instead of crashing
      serialized_ops = safe_load_yaml(buffer) || []
      serialized_ops.each do |op_def|
        attrs = op_def[:attrs] || {}
        inputs = Array(op_def[:inputs]).map { |i| @graph.get_tensor_by_name(i) }
        options = {}

        new_var = nil
        if attrs[:container]
          new_var = build_variable(op_def[:name], attrs[:data_type], attrs[:container])
          # on merge below, this replaces the serialized container hash with the live Variable
          options[:container] = new_var
        end

        new_op = Operation.new(@graph, inputs: inputs, options: attrs.merge(options))
        new_op.operation = op_def[:op].to_sym
        new_op.name = op_def[:name]
        new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
        new_op.rank = new_op.shape.rank
        new_op.data_type = new_op.set_data_type(attrs[:data_type])
        new_op.is_const = new_op.infer_const
        new_op.given_name = new_op.name
        new_var.op = new_op if new_var

        @graph.add_node(new_op)
      end
      @graph
    end

    private

    # Safely parses YAML allowing Symbols and aliases, on both old and new Psych.
    def safe_load_yaml(buffer)
      # Psych 3.1 introduced the keyword API; Psych 4 removed the positional form.
      if Gem::Version.new(Psych::VERSION) >= Gem::Version.new("3.1.0")
        YAML.safe_load(buffer, permitted_classes: [Symbol], aliases: true)
      else
        YAML.safe_load(buffer, [Symbol], [], true)
      end
    end

    # Creates the Variable described by an op's :container attrs and registers it on the graph.
    def build_variable(name, data_type, container)
      variable = Variable.new(data_type)
      shape = container[:shape]
      # build a fresh hash instead of mutating the parsed one; tolerate a missing :options
      var_options = (container[:options] || {}).merge(name: name)
      variable.prepare(shape.size, shape, TensorStream.get_variable_scope, var_options)
      @graph.add_variable(variable, var_options)
      variable
    end
  end
end

Version data entries

8 entries across 8 versions & 1 rubygems

Version Path
tensor_stream-1.0.8 lib/tensor_stream/graph_deserializers/yaml_loader.rb
tensor_stream-1.0.7 lib/tensor_stream/graph_deserializers/yaml_loader.rb
tensor_stream-1.0.6 lib/tensor_stream/graph_deserializers/yaml_loader.rb
tensor_stream-1.0.5 lib/tensor_stream/graph_deserializers/yaml_loader.rb
tensor_stream-1.0.4 lib/tensor_stream/graph_deserializers/yaml_loader.rb
tensor_stream-1.0.3 lib/tensor_stream/graph_deserializers/yaml_loader.rb
tensor_stream-1.0.2 lib/tensor_stream/graph_deserializers/yaml_loader.rb
tensor_stream-1.0.1 lib/tensor_stream/graph_deserializers/yaml_loader.rb