lib/torch/inspector.rb in torch-rb-0.1.1 vs lib/torch/inspector.rb in torch-rb-0.1.2

- old
+ new

@@ -1,7 +1,8 @@ module Torch module Inspector + # TODO make more performance, especially when summarizing def inspect data = if numel == 0 "[]" elsif dim == 0 @@ -18,11 +19,11 @@ end if floating_point? sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4 - all_int = values.all? { |v| v == v.to_i } + all_int = values.all? { |v| v.finite? && v == v.to_i } decimal = all_int ? 1 : 4 total += sci ? 10 : decimal + 1 + max.to_i.to_s.size if sci @@ -33,11 +34,13 @@ else total += max.to_s.size fmt = "%#{total}d" end - inspect_level(to_a, fmt, dim - 1) + summarize = numel > 1000 + + inspect_level(to_a, fmt, dim - 1, 0, summarize) end attributes = [] if requires_grad attributes << "requires_grad: true" @@ -49,14 +52,33 @@ "tensor(#{data}#{attributes.map { |a| ", #{a}" }.join("")})" end private - def inspect_level(arr, fmt, total, level = 0) + # TODO DRY code + def inspect_level(arr, fmt, total, level, summarize) if level == total - "[#{arr.map { |v| fmt % v }.join(", ")}]" + cols = + if summarize && arr.size > 7 + arr[0..2].map { |v| fmt % v } + + ["..."] + + arr[-3..-1].map { |v| fmt % v } + else + arr.map { |v| fmt % v } + end + + "[#{cols.join(", ")}]" else - "[#{arr.map { |row| inspect_level(row, fmt, total, level + 1) }.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]" + rows = + if summarize && arr.size > 7 + arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } + + ["..."] + + arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } + else + arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) } + end + + "[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]" end end end end