lib/torch/inspector.rb in torch-rb-0.2.6 vs lib/torch/inspector.rb in torch-rb-0.2.7
- old
+ new
@@ -1,89 +1,264 @@
+# mirrors PyTorch's torch/_tensor_str.py
module Torch
module Inspector
- # TODO make more performant, especially when summarizing
- # how? only read data that will be displayed
- def inspect
- data =
- if numel == 0
- "[]"
- elsif dim == 0
- item
- else
- summarize = numel > 1000
+ PRINT_OPTS = {
+ precision: 4,
+ threshold: 1000,
+ edgeitems: 3,
+ linewidth: 80,
+ sci_mode: nil
+ }
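These defaults mirror PyTorch's print options. PRINT_OPTS is a plain mutable hash, so output can be tuned by mutating it directly; a hypothetical usage sketch (values assumed, not an official torch-rb API):

    require "torch"

    # show fewer decimal places and summarize sooner (assumed values)
    Torch::Inspector::PRINT_OPTS[:precision] = 2
    Torch::Inspector::PRINT_OPTS[:threshold] = 10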
- if dtype == :bool
- fmt = "%s"
- else
- values = _flat_data
- abs = values.select { |v| v != 0 }.map(&:abs)
- max = abs.max || 1
- min = abs.min || 1
+ class Formatter
+ def initialize(tensor)
+ @floating_dtype = tensor.floating_point?
+ @complex_dtype = tensor.complex?
+ @int_mode = true
+ @sci_mode = false
+ @max_width = 1
- total = 0
- if values.any? { |v| v < 0 }
- total += 1
- end
+ tensor_view = Torch.no_grad { tensor.reshape(-1) }
- if floating_point?
- sci = max > 1e8 || max < 1e-4
+ if !@floating_dtype
+ tensor_view.each do |value|
+ value_str = value.item.to_s
+ @max_width = [@max_width, value_str.length].max
+ end
+ else
+ nonzero_finite_vals = Torch.masked_select(tensor_view, Torch.isfinite(tensor_view) & tensor_view.ne(0))
- all_int = values.all? { |v| v.finite? && v == v.to_i }
- decimal = all_int ? 1 : 4
+ # no valid number, do nothing
+ return if nonzero_finite_vals.numel == 0
- total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
+ # Convert to double for easy calculation. HalfTensor overflows with 1e8, and there's no div() on CPU.
+ nonzero_finite_abs = nonzero_finite_vals.abs.double
+ nonzero_finite_min = nonzero_finite_abs.min.double
+ nonzero_finite_max = nonzero_finite_abs.max.double
- if sci
- fmt = "%#{total}.4e"
- else
- fmt = "%#{total}.#{decimal}f"
+ nonzero_finite_vals.each do |value|
+ if value.item != value.item.ceil
+ @int_mode = false
+ break
+ end
+ end
+
+ if @int_mode
+ # in int_mode for floats, all numbers are integers, and we append a decimal to finite values
+ # to indicate that the tensor is of floating type. add 1 to the length to account for this.
+ if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8
+ @sci_mode = true
+ nonzero_finite_vals.each do |value|
+ value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
+ @max_width = [@max_width, value_str.length].max
end
else
- total += max.to_s.size
- fmt = "%#{total}d"
+ nonzero_finite_vals.each do |value|
+ value_str = "%.0f" % value.item
+ @max_width = [@max_width, value_str.length + 1].max
+ end
end
+ else
+ # Check if scientific representation should be used.
+ if nonzero_finite_max / nonzero_finite_min > 1000.0 || nonzero_finite_max > 1.0e8 || nonzero_finite_min < 1.0e-4
+ @sci_mode = true
+ nonzero_finite_vals.each do |value|
+ value_str = "%.#{PRINT_OPTS[:precision]}e" % value.item
+ @max_width = [@max_width, value_str.length].max
+ end
+ else
+ nonzero_finite_vals.each do |value|
+ value_str = "%.#{PRINT_OPTS[:precision]}f" % value.item
+ @max_width = [@max_width, value_str.length].max
+ end
+ end
end
-
- inspect_level(to_a, fmt, dim - 1, 0, summarize)
end
- attributes = []
- if requires_grad
- attributes << "requires_grad: true"
+ @sci_mode = PRINT_OPTS[:sci_mode] unless PRINT_OPTS[:sci_mode].nil?
end
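To make the branching above concrete: a float tensor whose finite values are all whole numbers stays in int_mode, while a dynamic range over 1000x or magnitudes outside roughly [1e-4, 1e8] flip sci_mode. A sketch of the expected decisions (example values assumed, not taken from the test suite):

    # all whole numbers -> int_mode, elements render like "1."
    Torch::Inspector::Formatter.new(Torch.tensor([1.0, 2.0, 100.0]))

    # max/min ratio is 10000 (> 1000) -> sci_mode, elements render like "1.0000e-04"
    Torch::Inspector::Formatter.new(Torch.tensor([0.0001, 1.0]))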
- if ![:float32, :int64, :bool].include?(dtype)
- attributes << "dtype: #{dtype.inspect}"
+
+ def width
+ @max_width
end
- "tensor(#{data}#{attributes.map { |a| ", #{a}" }.join("")})"
+ def format(value)
+ value = value.item
+
+ if @floating_dtype
+ if @sci_mode
+ ret = "%#{@max_width}.#{PRINT_OPTS[:precision]}e" % value
+ elsif @int_mode
+ ret = String.new("%.0f" % value)
+ unless value.infinite? || value.nan?
+ ret += "."
+ end
+ else
+ ret = "%.#{PRINT_OPTS[:precision]}f" % value
+ end
+ elsif @complex_dtype
+ # complex formatting is not implemented yet
+ raise NotImplementedYet
+ else
+ ret = value.to_s
+ end
+ # Ruby raises an error for a negative repeat count (Python returns ""), so clamp at 0
+ " " * [@max_width - ret.size, 0].max + ret
+ end
end
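A quick sketch of the Formatter in isolation: the width is computed once for the whole tensor, so every element is right-padded to the widest one (indexing with t[i] yields the 0-d tensor that format expects; values assumed):

    t = Torch.tensor([1.0, 22.5])
    f = Torch::Inspector::Formatter.new(t)
    f.format(t[1])  # => "22.5000" (precision 4, no sci_mode)
    f.format(t[0])  # => " 1.0000" (padded to the shared width of 7)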
+ def inspect
+ Torch.no_grad do
+ str_intern(self)
+ end
+ rescue => e
+ # prevent stack error
+ puts e.backtrace.join("\n")
+ "Error inspecting tensor: #{e.inspect}"
+ end
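Putting the pieces together, output should look roughly like this (a sketch; exact spacing comes from the Formatter):

    Torch.tensor([1.0, 2.0]).inspect
    # => "tensor([1., 2.])"  (int_mode floats keep a trailing dot)

    Torch.tensor([[1, 2], [3, 4]]).inspect
    # => "tensor([[1, 2],\n        [3, 4]])"  (rows align under the prefix)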
+
private
- # TODO DRY code
- def inspect_level(arr, fmt, total, level, summarize)
- if level == total
- cols =
- if summarize && arr.size > 7
- arr[0..2].map { |v| fmt % v } +
- ["..."] +
- arr[-3..-1].map { |v| fmt % v }
- else
- arr.map { |v| fmt % v }
- end
+ # TODO update to match _tensor_str.py more closely
+ def str_intern(slf)
+ prefix = "tensor("
+ indent = prefix.length
+ suffixes = []
- "[#{cols.join(", ")}]"
+ has_default_dtype = [:float32, :int64, :bool].include?(slf.dtype)
+
+ if slf.numel == 0 && !slf.sparse?
+ # Explicitly print the shape if it is not (0,), to match NumPy behavior
+ if slf.dim != 1
+ suffixes << "size: #{slf.shape.inspect}"
+ end
+
+ # In an empty tensor, there are no elements to infer if the dtype
+ # should be int64, so it must be shown explicitly.
+ if slf.dtype != :int64
+ suffixes << "dtype: #{slf.dtype.inspect}"
+ end
+ tensor_str = "[]"
else
- rows =
- if summarize && arr.size > 7
- arr[0..2].map { |row| inspect_level(row, fmt, total, level + 1, summarize) } +
- ["..."] +
- arr[-3..-1].map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
- else
- arr.map { |row| inspect_level(row, fmt, total, level + 1, summarize) }
- end
+ if !has_default_dtype
+ suffixes << "dtype: #{slf.dtype.inspect}"
+ end
- "[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
+ if slf.layout != :strided
+ tensor_str = tensor_str(slf.to_dense, indent)
+ else
+ tensor_str = tensor_str(slf, indent)
+ end
end
+
+ if slf.layout != :strided
+ suffixes << "layout: #{slf.layout.inspect}"
+ end
+
+ # TODO show grad_fn
+ if slf.requires_grad?
+ suffixes << "requires_grad: true"
+ end
+
+ add_suffixes(prefix + tensor_str, suffixes, indent, slf.sparse?)
+ end
+
+ def add_suffixes(tensor_str, suffixes, indent, force_newline)
+ tensor_strs = [tensor_str]
+ # rfind in Python returns -1 when not found
+ last_line_len = tensor_str.length - (tensor_str.rindex("\n") || -1) + 1
+ suffixes.each do |suffix|
+ suffix_len = suffix.length
+ if force_newline || last_line_len + suffix_len + 2 > PRINT_OPTS[:linewidth]
+ tensor_strs << ",\n" + " " * indent + suffix
+ last_line_len = indent + suffix_len
+ force_newline = false
+ else
+ tensor_strs << ", " + suffix
+ last_line_len += suffix_len + 2
+ end
+ end
+ tensor_strs << ")"
+ tensor_strs.join("")
+ end
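For example, an empty non-vector float tensor picks up both suffixes, and add_suffixes keeps them on one line while it fits within linewidth (a sketch of the expected result, assuming Torch.zeros accepts an array size as in torch-rb):

    Torch.zeros([2, 0]).inspect
    # => "tensor([], size: [2, 0], dtype: :float32)"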
+
+ def tensor_str(slf, indent)
+ return "[]" if slf.numel == 0
+
+ summarize = slf.numel > PRINT_OPTS[:threshold]
+
+ if slf.dtype == :float16 || slf.dtype == :bfloat16
+ slf = slf.float
+ end
+ formatter = Formatter.new(summarize ? summarized_data(slf) : slf)
+ tensor_str_with_formatter(slf, indent, formatter, summarize)
+ end
+
+ def summarized_data(slf)
+ edgeitems = PRINT_OPTS[:edgeitems]
+
+ dim = slf.dim
+ if dim == 0
+ slf
+ elsif dim == 1
+ if slf.size(0) > 2 * edgeitems
+ Torch.cat([slf[0...edgeitems], slf[-edgeitems..-1]])
+ else
+ slf
+ end
+ elsif slf.size(0) > 2 * edgeitems
+ start = edgeitems.times.map { |i| slf[i] }
+ finish = (slf.length - edgeitems).upto(slf.length - 1).map { |i| slf[i] }
+ Torch.stack((start + finish).map { |x| summarized_data(x) })
+ else
+ Torch.stack(slf.map { |x| summarized_data(x) })
+ end
+ end
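So with the default edgeitems of 3, a long vector is cut down to its two ends before any formatting happens; roughly (a sketch using Torch.arange):

    t = Torch.arange(0, 100)
    # summarized_data(t) is then equivalent to:
    Torch.cat([t[0...3], t[-3..-1]])  # values 0, 1, 2, 97, 98, 99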
+
+ def tensor_str_with_formatter(slf, indent, formatter, summarize)
+ edgeitems = PRINT_OPTS[:edgeitems]
+
+ dim = slf.dim
+
+ return scalar_str(slf, formatter) if dim == 0
+ return vector_str(slf, indent, formatter, summarize) if dim == 1
+
+ if summarize && slf.size(0) > 2 * edgeitems
+ slices = (
+ edgeitems.times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) } +
+ ["..."] +
+ ((slf.length - edgeitems)...slf.length).map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }
+ )
+ else
+ slices = slf.size(0).times.map { |i| tensor_str_with_formatter(slf[i], indent + 1, formatter, summarize) }
+ end
+
+ tensor_str = slices.join("," + "\n" * (dim - 1) + " " * (indent + 1))
+ "[" + tensor_str + "]"
+ end
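The "\n" * (dim - 1) separator is what produces blank lines between higher-dimensional slices: matrix rows are separated by a single newline, 2-d slices of a 3-d tensor by two. A sketch of the expected shape of the output:

    Torch.ones([2, 2, 2]).inspect
    # => "tensor([[[1., 1.],\n" \
    #    "         [1., 1.]],\n" \
    #    "\n" \
    #    "        [[1., 1.],\n" \
    #    "         [1., 1.]]])"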
+
+ def scalar_str(slf, formatter)
+ formatter.format(slf)
+ end
+
+ def vector_str(slf, indent, formatter, summarize)
+ # element length includes the comma and space between elements
+ element_length = formatter.width + 2
+ elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
+ char_per_line = element_length * elements_per_line
+
+ if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
+ # keep the list flat so the line chunking below counts individual elements
+ data = (
+ slf[0...PRINT_OPTS[:edgeitems]].map { |val| formatter.format(val) } +
+ [" ..."] +
+ slf[-PRINT_OPTS[:edgeitems]..-1].map { |val| formatter.format(val) }
+ )
+ else
+ data = slf.map { |val| formatter.format(val) }
+ end
+
+ data_lines = (0...data.length).step(elements_per_line).map { |i| data[i...(i + elements_per_line)] }
+ lines = data_lines.map { |line| line.join(", ") }
+ "[" + lines.join("," + "\n" + " " * (indent + 1)) + "]"
end
end
end
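As a worked example of vector_str's wrapping arithmetic: with the default linewidth of 80 and the 7-character "tensor(" indent, an element width of 6 gives element_length 8, and floor((80 - 7) / 8) = 9 elements fit per line, so a long 1-d tensor wraps after every ninth value (numbers chosen for illustration):

    element_length = 6 + 2                            # formatter width + ", "
    [1, ((80 - 7) / element_length.to_f).floor].max   # => 9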