Sha256: 54626aa1c595ef6195cfc13c345cbebc87ba4b94bf1c30f42ec220e797e5ca9a

Contents?: true

Size: 1.61 KB

Versions: 1

Compression:

Stored size: 1.61 KB

Contents

require 'json'

module TensorStream
  # Serializes a TensorStream graph into TensorFlow-style protobuf text
  # format: one `node { ... }` entry per graph node.
  class Pbtext
    # TensorStream data type => [TensorFlow DataType enum name,
    #                            TensorProto field that stores its values].
    TYPE_INFO = {
      int16: %w[DT_INT16 int_val],
      int32: %w[DT_INT32 int_val],
      int64: %w[DT_INT64 int64_val],
      float: %w[DT_FLOAT float_val],
      float32: %w[DT_FLOAT float_val],
      float64: %w[DT_DOUBLE double_val],
      boolean: %w[DT_BOOL bool_val]
    }.freeze

    # Used for data types without a mapping above (historical output of this
    # serializer).
    UNKNOWN_TYPE_INFO = %w[DT_UNKNOWN float_val].freeze

    def initialize
    end

    # Placeholder: file serialization is not implemented; always returns nil.
    def serialize(session, filename, tensor)
    end

    # Renders every node of +graph+ as protobuf text.
    #
    # Operations get their op name, their non-nil inputs and a "T" type attr.
    # Constant tensors become "Const" nodes with "dtype" and "value" attrs.
    # Any other node is emitted with its name only.
    #
    # @param graph [#nodes] object whose +nodes+ yields (name, node) pairs
    # @return [String] newline-separated pbtext (also left in @lines)
    def get_string(graph)
      @lines = []
      graph.nodes.each do |_name, node|
        @lines << "node {"
        @lines << "  name: #{node.name.to_json}"
        if node.is_a?(TensorStream::Operation)
          @lines << "  op: #{node.operation.to_json}"
          node.items.each do |input|
            next unless input

            @lines << "  input: #{input.name.to_json}"
          end
          # AttrValue text format requires the `type:` field name.
          pb_attr('T', "type: #{sym_to_protobuf_type(node.data_type)}")
        elsif node.is_a?(TensorStream::Tensor) && node.is_const
          @lines << "  op: \"Const\""
          # TensorFlow's Const op declares its element type under "dtype".
          pb_attr('dtype', "type: #{sym_to_protobuf_type(node.data_type)}")
          pb_attr('value', tensor_value(node))
        end
        @lines << "}"
      end
      @lines.join("\n")
    end

    private

    # Builds the lines of a TensorProto literal for a constant tensor.
    # Nested array values are flattened (row-major) into repeated value
    # fields and their dimensions emitted as tensor_shape; scalars get an
    # empty shape. Assumes nested arrays are rectangular.
    def tensor_value(tensor)
      dtype, field = type_info(tensor.data_type)
      value = tensor.value
      lines = ["tensor {", "  dtype: #{dtype}", "  tensor_shape {"]
      shape_of(value).each do |size|
        lines.push("    dim {", "      size: #{size}", "    }")
      end
      lines << "  }"
      Array(value).flatten.each { |v| lines << "  #{field}: #{v}" }
      lines << "}"
    end

    # Dimensions of a nested array, following first elements; [] for scalars.
    def shape_of(value)
      dims = []
      while value.is_a?(Array)
        dims << value.size
        value = value.first
      end
      dims
    end

    # [enum name, value field] for +type+, falling back to UNKNOWN_TYPE_INFO.
    def type_info(type)
      TYPE_INFO.fetch(type, UNKNOWN_TYPE_INFO)
    end

    # TensorFlow DataType enum name for a TensorStream data type symbol.
    def sym_to_protobuf_type(type)
      type_info(type).first
    end

    # Appends an `attr { key: ... value { ... } }` entry to @lines.
    # +value+ is a single line or an array of lines (a nested message).
    def pb_attr(key, value)
      @lines << "  attr {"
      @lines << "    key: \"#{key}\""
      @lines << "    value {"
      Array(value).each { |v| @lines << "      #{v}" }
      @lines << "    }"
      @lines << "  }"
    end
  end

end

Version data entries

1 entries across 1 versions & 1 rubygems

Version Path
tensor_stream-0.1.5 lib/tensor_stream/graph_serializers/pbtext.rb