module TensorStream
  # Serializer that renders a TensorStream graph as GraphML-style text.
  #
  # Output is assembled as an array of lines (nested arrays are allowed and
  # flattened) and joined with newlines at the end of #get_string.
  class Graphml < Serializer
    def initialize
    end

    # Renders the graph reachable from +tensor+.
    #
    # @param tensor [Tensor, Object] the output tensor; non-Tensor values are
    #   passed through TensorStream.convert_to_tensor first
    # @param session [#last_session_context, nil] optional session whose last
    #   run context is used to attach values to nodes; when nil an empty
    #   context is used, so no value lines are emitted
    # @return [String] all emitted lines joined with "\n"
    def get_string(tensor, session = nil)
      tensor = TensorStream.convert_to_tensor(tensor) unless tensor.is_a?(Tensor)
      @session = session
      @name = tensor.name
      @last_session_context = session ? session.last_session_context : {}
      groups = {}

      arr_buf = []
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ""
      arr_buf << ""
      arr_buf << "out"
      arr_buf << "red"
      arr_buf << ""
      arr_buf << ""
      arr_buf << " "
      arr_buf << " out"
      arr_buf << ""
      arr_buf << ""
      arr_buf << ""

      to_graph_ml(tensor, arr_buf, {}, groups)

      # Emit the collected group tree; nested groups are rendered
      # recursively by create_group.
      groups.each { |k, g| arr_buf << create_group(k, k, g) }

      output_edge(tensor, "out", arr_buf)
      arr_buf << ""
      arr_buf << ""
      arr_buf.flatten.join("\n")
    end

    private

    # Files +arr_buf+ into the nested group tree +groups+, using every
    # '/'-separated component of +name+ except the last as a group level
    # (e.g. "a/b/c" is filed under a -> b).
    #
    # @param groups [Hash{String=>Hash}] group tree; each entry is
    #   { buf: [...], group: {...} } (mutated)
    # @param name [String] the node name
    # @param arr_buf [Array] lines describing the node
    # @return [Boolean] true if filed, false when +name+ has no group part
    def add_to_group(groups, name, arr_buf)
      *group_path, _leaf = name.split('/')
      return false if group_path.empty?

      # Walk/create one level per path component. The pseudo-root lets the
      # first lookup go into +groups+ itself, so the top-level component is
      # created exactly once (the old loop nested it twice).
      target = group_path.reduce({ group: groups }) do |parent, prefix|
        find_or_create_group(prefix, parent[:group])
      end
      target[:buf] << arr_buf
      true
    end

    # Returns the group entry for +prefix+ in +groups+, creating an empty
    # { buf: [], group: {} } entry first if none exists.
    def find_or_create_group(prefix, groups)
      groups[prefix] ||= { buf: [], group: {} }
    end

    # Builds the lines for one group: header lines, its node buffers, its
    # nested groups (recursively), then closing lines.
    #
    # @param id [String] group identifier
    # @param title [String] group label
    # @param group [Hash] entry of the form { buf: [...], group: {...} }
    # @return [Array] nested array of lines
    def create_group(id, title, group)
      arr_buf = []
      arr_buf << ""
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << "#{title}"
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << ''
      arr_buf << group[:buf]
      group[:group].each { |k, g| arr_buf << create_group(k, k, g) }
      arr_buf << ''
      arr_buf << ''
      arr_buf
    end

    # Looks up the recorded value for +tensor+ in the last session context,
    # falling back to its :_cache sub-hash. Returns nil when neither has it
    # (including when no :_cache entry exists).
    def _val(tensor)
      # JSON.pretty_generate(@last_session_context[tensor.name])
      @last_session_context[tensor.name] || @last_session_context.dig(:_cache, tensor.name)
    end

    # Recursively emits node lines for +tensor+ and its inputs, then the
    # edges from each input into +tensor+.
    #
    # Node lines are filed into +groups+ by name prefix (see #add_to_group);
    # names with no '/' are filed under "program/" (inputs that are
    # Variables go under "variable/"). Edge lines go straight into +arr_buf+.
    #
    # @param tensor [Object] only Operations are rendered; anything else is ignored
    # @param arr_buf [Array] top-level output lines (mutated)
    # @param added [Hash{String=>true}] names already emitted, to avoid duplicates
    # @param groups [Hash] group tree built by #add_to_group (mutated)
    # @param _id [Integer] unused; kept for interface compatibility
    def to_graph_ml(tensor, arr_buf = [], added = {}, groups = {}, _id = 0)
      return unless tensor.is_a?(Operation)

      added[tensor.name] = true
      node_buf = []
      node_buf << ""
      node_buf << "#{tensor.operation}"
      node_buf << "#{tensor.to_math(true, 1)}"
      node_buf << "blue"
      # The value belongs to this node, so it goes in the node's own buffer.
      node_buf << "#{_val(tensor)}" if @last_session_context[tensor.name]
      node_buf << ""
      node_buf << ""
      if tensor.internal?
        node_buf << " "
      else
        node_buf << " "
      end
      node_buf << " #{tensor.operation}"
      node_buf << ""
      node_buf << ""
      node_buf << ""
      add_to_group(groups, "program/#{tensor.name}", node_buf) unless add_to_group(groups, tensor.name, node_buf)

      tensor.items.each do |item|
        next unless item
        next if added[item.name]
        next to_graph_ml(item, arr_buf, added, groups) if item.is_a?(Operation)

        added[item.name] = true
        item_buf = item_node_buf(item)
        next if add_to_group(groups, item.name, item_buf)

        fallback_prefix = item.is_a?(Variable) ? 'variable' : 'program'
        add_to_group(groups, "#{fallback_prefix}/#{item.name}", item_buf)
      end

      tensor.items.each_with_index do |item, index|
        output_edge(item, tensor, arr_buf, index) if item
      end
    end

    # Builds the node lines for a non-Operation input +item+. Variables,
    # Placeholders and other Tensors each get their own layout; any other
    # type yields an empty array. Value lines use +item+'s own recorded value.
    def item_node_buf(item)
      buf = []
      case item
      when Variable
        buf << ""
        buf << "#{item.name}"
        buf << "green"
        buf << "#{_val(item)}" if @last_session_context[item.name]
        buf << ""
        buf << ""
        buf << " "
        buf << " #{item.name}"
        buf << ""
        buf << ""
        buf << ""
      when Placeholder
        buf << ""
        buf << ""
        buf << ""
        buf << " "
        buf << " #{item.name}"
        buf << ""
        buf << ""
        buf << "#{_val(item)}" if @last_session_context[item.name]
        buf << ""
      when Tensor
        buf << ""
        buf << "#{item.name}"
        buf << "black"
        buf << ""
        buf << ""
        if item.internal?
          buf << " "
        else
          buf << " "
        end
        buf << " #{item.name}"
        buf << ""
        buf << ""
        buf << ""
      end
      buf
    end

    # Replaces every '/' in +str+ with '-'.
    def _gml_string(str)
      str.tr('/', '-')
    end

    # Appends the lines for one edge from +item+ into +tensor+.
    #
    # @param item [Tensor] the edge source
    # @param tensor [Tensor, String] the edge target; a String is used as-is
    # @param arr_buf [Array] output lines (mutated)
    # @param index [Integer] input position of +item+; 0 selects a different
    #   closing line than the other positions
    def output_edge(item, tensor, arr_buf, index = 0)
      target_name = tensor.is_a?(Tensor) ? tensor.name : tensor
      arr_buf << ""
      arr_buf << ""
      arr_buf << ""
      arr_buf << ""
      if @last_session_context.empty?
        # Without a session context, the label depends on whether the
        # source's shape is known.
        if item.shape.shape.nil?
          arr_buf << ""
        else
          arr_buf << ""
        end
      else
        arr_buf << ""
      end
      arr_buf << ""
      arr_buf << ""
      if index.zero?
        arr_buf << ""
      else
        arr_buf << ""
      end
      arr_buf << ""
      arr_buf << ""
      arr_buf << ""
    end
  end
end