lib/torch/inspector.rb in torch-rb-0.1.1 vs lib/torch/inspector.rb in torch-rb-0.1.2
- old
+ new
@@ -1,7 +1,8 @@
module Torch
module Inspector
+ # TODO make more performance, especially when summarizing
def inspect
data =
if numel == 0
"[]"
elsif dim == 0
@@ -18,11 +19,11 @@
end
if floating_point?
sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
- all_int = values.all? { |v| v == v.to_i }
+ all_int = values.all? { |v| v.finite? && v == v.to_i }
decimal = all_int ? 1 : 4
total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
if sci
@@ -33,11 +34,13 @@
else
total += max.to_s.size
fmt = "%#{total}d"
end
- inspect_level(to_a, fmt, dim - 1)
+ summarize = numel > 1000
+
+ inspect_level(to_a, fmt, dim - 1, 0, summarize)
end
attributes = []
if requires_grad
attributes << "requires_grad: true"
@@ -49,14 +52,33 @@
"tensor(#{data}#{attributes.map { |a| ", #{a}" }.join("")})"
end
private
- def inspect_level(arr, fmt, total, level = 0)
+ # TODO DRY code
+ def inspect_level(arr, fmt, total, level, summarize)
if level == total
- "[#{arr.map { |v| fmt % v }.join(", ")}]"
+ cols =
+ if summarize && arr.size > 7
+ arr[0..2].map { |v| fmt % v } +
+ ["..."] +
+ arr[-3..-1].map { |v| fmt % v }
+ else
+ arr.map { |v| fmt % v }
+ end
+
+ "[#{cols.join(", ")}]"
else
- "[#{arr.map { |row| inspect_level(row, fmt, total, level + 1) }.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
+ rows =
+ if summarize && arr.size > 7
+ arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
+ ["..."] +
+ arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
+ else
+ arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
+ end
+
+ "[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
end
end
end
end