module TensorStream
  # Serializer that walks the graph reachable from a tensor and renders it as
  # a list of text lines (the class and method names indicate GraphML output).
  #
  # Operations and their non-operation inputs become nodes, '/'-separated
  # tensor names become nested groups, and every (input, consumer) pair
  # becomes an edge.
  #
  # NOTE(review): most markup string literals in this class are empty or
  # whitespace-only, so interpolated values are not wrapped in any markup --
  # verify these literals hold the intended GraphML text.
  class Graphml < Serializer
    def initialize
    end

    # Renders the graph that produces +tensor+.
    #
    # @param tensor [Tensor, Object] root tensor; anything that is not a Tensor
    #   is first passed through TensorStream.convert_to_tensor
    # @param session [Object, nil] optional session; when given, its
    #   +last_session_context+ supplies the values attached to nodes
    # @return [String] the rendered lines joined with "\n"
    def get_string(tensor, session = nil)
      tensor = TensorStream.convert_to_tensor(tensor) unless tensor.is_a?(Tensor)
      @session = session
      @name = tensor.name
      @last_session_context = session ? session.last_session_context : {}
      groups = {}

      # Fixed preamble lines, then the synthetic "out" node ("out" / "red").
      arr_buf = []
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ""
      arr_buf << ""
      arr_buf << "out"
      arr_buf << "red"
      arr_buf << ""
      arr_buf << ""
      arr_buf << " "
      arr_buf << " out"
      arr_buf << ""
      arr_buf << ""
      arr_buf << ""

      to_graph_ml(tensor, arr_buf, {}, groups)

      # Node lines were collected into groups during the walk; emit them now.
      groups.each do |k, g|
        arr_buf << create_group(k, k, g)
      end

      # Connect the root tensor to the synthetic "out" node.
      output_edge(tensor, "out", arr_buf)
      arr_buf << ""
      arr_buf << ""
      # Group buffers are nested arrays, hence the flatten before joining.
      arr_buf.flatten.join("\n")
    end

    private

    # Files +arr_buf+ under the nested group path formed by every
    # '/'-separated component of +name+ except the last, creating groups as
    # needed: "a/b/op" lands in groups["a"][:group]["b"][:buf].
    #
    # @return [Boolean] false (nothing stored) when +name+ has no group
    #   component, true otherwise
    def add_to_group(groups, name, arr_buf)
      *group_path, _leaf = name.split('/')
      return false if group_path.empty?

      # Start from a pseudo-group whose children are the top-level groups and
      # descend one component at a time (each component visited exactly once).
      target = group_path.inject({ group: groups }) { |parent, prefix| find_or_create_group(prefix, parent[:group]) }
      target[:buf] << arr_buf
      true
    end

    # Returns groups[prefix], creating an empty { buf: [], group: {} } entry
    # on first use.
    def find_or_create_group(prefix, groups)
      groups[prefix] ||= { buf: [], group: {} }
    end

    # Lines for one group: header lines around +title+, the group's own member
    # lines, then each sub-group rendered recursively, then closing lines.
    #
    # NOTE(review): +id+ is not referenced by any literal below -- confirm.
    #
    # @param group [Hash] { buf: [lines...], group: { name => sub_group } }
    # @return [Array] nested array of lines (flattened by get_string)
    def create_group(id, title, group)
      arr_buf = []
      arr_buf << ""
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << '' + title + ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << group[:buf]
      group[:group].each do |k, g|
        arr_buf << create_group(k, k, g)
      end
      arr_buf << ''
      arr_buf << ''
      arr_buf
    end

    # Value recorded for +tensor+ in the last session context, falling back to
    # the context's :_cache hash. Returns nil when neither has an entry,
    # including when the context has no :_cache at all.
    def _val(tensor)
      cache = @last_session_context[:_cache]
      @last_session_context[tensor.name] || (cache && cache[tensor.name])
    end

    # Depth-first walk from +tensor+. It files a node for every operation, and
    # for each operation's non-operation inputs, into +groups+, and appends one
    # edge per input to +arr_buf+. A non-Operation +tensor+ is ignored here;
    # such tensors are only rendered as inputs of an operation.
    #
    # @param added [Hash] tensor name => true for everything already rendered;
    #   shared across the recursion so each tensor is emitted once
    # @param _id [Integer] unused; kept for interface compatibility
    def to_graph_ml(tensor, arr_buf = [], added = {}, groups = {}, _id = 0)
      return unless tensor.is_a?(Operation)

      added[tensor.name] = true
      place_in_group(groups, tensor.name, operation_node_buf(tensor), 'program')

      tensor.inputs.each do |input|
        next unless input
        next if added[input.name]
        # Operation inputs render themselves (and their own inputs) recursively.
        next to_graph_ml(input, arr_buf, added, groups) if input.is_a?(Operation)

        added[input.name] = true
        fallback = input.is_a?(Variable) ? 'variable' : 'program'
        place_in_group(groups, input.name, input_node_buf(input), fallback)
      end

      # Edges are emitted even for inputs that were rendered earlier in the walk.
      tensor.inputs.each_with_index do |input, index|
        next unless input

        output_edge(input, tensor, arr_buf, index)
      end
    end

    # Lines for an operation's node: its operation name, its to_math(true, 1)
    # rendering, the colour "blue" and, when the session recorded one, its value.
    def operation_node_buf(tensor)
      buf = []
      buf << ""
      buf << "#{tensor.operation}"
      buf << "#{tensor.to_math(true, 1)}"
      buf << "blue"
      # The value belongs inside this node's own lines, not the top-level buffer.
      buf << "#{_val(tensor)}" if @last_session_context[tensor.name]
      buf << ""
      buf << ""
      # NOTE(review): both branches currently emit the same text.
      buf << (tensor.internal? ? " " : " ")
      buf << " #{tensor.operation}"
      buf << ""
      buf << ""
      buf << ""
      buf
    end

    # Lines for a non-operation input node (Variable, Placeholder or plain
    # Tensor); an empty array for anything else. Values shown are the input's
    # own, looked up by the same name that gates them.
    def input_node_buf(input)
      buf = []
      case input
      when Variable
        buf << ""
        buf << "#{input.name}"
        buf << "green"
        buf << "#{_val(input)}" if @last_session_context[input.name]
        buf << ""
        buf << ""
        buf << " "
        buf << " #{input.name}"
        buf << ""
        buf << ""
        buf << ""
      when Placeholder
        buf << ""
        buf << ""
        buf << ""
        buf << " "
        buf << " #{input.name}"
        buf << ""
        buf << ""
        buf << "#{_val(input)}" if @last_session_context[input.name]
        buf << ""
      when Tensor
        buf << ""
        buf << "#{input.name}"
        buf << "black"
        buf << ""
        buf << ""
        # NOTE(review): both branches currently emit the same text.
        buf << (input.internal? ? " " : " ")
        buf << " #{input.name}"
        buf << ""
        buf << ""
        buf << ""
      end
      buf
    end

    # Files +buf+ by +name+'s own group path, or under "<fallback>/<name>" when
    # +name+ yields no group component.
    def place_in_group(groups, name, buf, fallback)
      add_to_group(groups, name, buf) || add_to_group(groups, "#{fallback}/#{name}", buf)
    end

    # Replaces every '/' in +str+ with '-'.
    # NOTE(review): not referenced elsewhere in this class -- confirm it is still needed.
    def _gml_string(str)
      str.tr('/', '-')
    end

    # Appends the lines for one edge from +input+ to +tensor+.
    #
    # @param input [Tensor] source of the edge
    # @param tensor [Tensor, String] consumer, or a literal node name such as "out"
    # @param index [Integer] position of +input+ among the consumer's inputs
    # @return [Array] +arr_buf+
    def output_edge(input, tensor, arr_buf, index = 0)
      # NOTE(review): target_name is not referenced by any literal below -- confirm.
      target_name = tensor.is_a?(Tensor) ? tensor.name : tensor
      arr_buf << ""
      arr_buf << ""
      arr_buf << ""
      arr_buf << ""
      arr_buf << if !@last_session_context.empty?
                   ""
                 elsif input.shape.shape.nil?
                   ""
                 else
                   ""
                 end
      arr_buf << ""
      arr_buf << ""
      arr_buf << (index.zero? ? "" : "")
      arr_buf << ""
      arr_buf << ""
      arr_buf << ""
    end
  end
end