lib/torch/native/function.rb in torch-rb-0.3.6 vs lib/torch/native/function.rb in torch-rb-0.3.7
- old
+ new
@@ -4,13 +4,14 @@
attr_reader :function, :tensor_options
def initialize(function)
@function = function
- tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
- @tensor_options = @function["func"].include?(tensor_options_str)
- @function["func"].sub!(tensor_options_str, ")")
+ # note: don't modify function in-place
+ @tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
+ @tensor_options = @function["func"].include?(@tensor_options_str)
+ @out = out_size > 0 && base_name[-1] != "_"
end
def func
@func ||= @function["func"]
end
@@ -29,11 +30,11 @@
def args
@args ||= begin
args = []
pos = true
- args_str = func.split("(", 2).last.split(") ->").first
+ args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
args_str.split(", ").each do |a|
if a == "*"
pos = false
next
end
@@ -70,16 +71,92 @@
end
next if t == "Generator?"
next if t == "MemoryFormat"
next if t == "MemoryFormat?"
- args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
+ args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
end
args
end
end
+ def arg_checkers
+ @arg_checkers ||= begin
+ checkers = {}
+ arg_types.each do |k, t|
+ checker =
+ case t
+ when "Tensor"
+ ->(v) { v.is_a?(Tensor) }
+ when "Tensor?"
+ ->(v) { v.nil? || v.is_a?(Tensor) }
+ when "Tensor[]", "Tensor?[]"
+ ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
+ when "int"
+ if k == :reduction
+ ->(v) { v.is_a?(String) }
+ else
+ ->(v) { v.is_a?(Integer) }
+ end
+ when "int?"
+ ->(v) { v.is_a?(Integer) || v.nil? }
+ when "float?"
+ ->(v) { v.is_a?(Numeric) || v.nil? }
+ when "bool?"
+ ->(v) { v == true || v == false || v.nil? }
+ when "float"
+ ->(v) { v.is_a?(Numeric) }
+ when /int\[.*\]/
+ ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
+ when "Scalar"
+ ->(v) { v.is_a?(Numeric) }
+ when "Scalar?"
+ ->(v) { v.is_a?(Numeric) || v.nil? }
+ when "ScalarType"
+ ->(v) { false } # not supported yet
+ when "ScalarType?"
+ ->(v) { v.nil? }
+ when "bool"
+ ->(v) { v == true || v == false }
+ when "str"
+ ->(v) { v.is_a?(String) }
+ else
+ raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}."
+ end
+ checkers[k] = checker
+ end
+ checkers
+ end
+ end
+
+ def int_array_lengths
+ @int_array_lengths ||= begin
+ ret = {}
+ arg_types.each do |k, t|
+ if t.match?(/\Aint\[.+\]\z/)
+ size = t[4..-2]
+ raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
+ ret[k] = size.to_i
+ end
+ end
+ ret
+ end
+ end
+
+ def arg_names
+ @arg_names ||= args.map { |a| a[:name] }
+ end
+
+ def arg_types
+ @arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
+ end
+
+ def arg_defaults
+ # TODO find out why can't use select here
+ @arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
+ end
+
def out_size
@out_size ||= func.split("->").last.count("!")
end
def ret_size
@@ -88,11 +165,15 @@
def ret_array?
@ret_array ||= func.split("->").last.include?('[]')
end
+ def ret_void?
+ func.split("->").last.strip == "()"
+ end
+
def out?
- out_size > 0 && base_name[-1] != "_"
+ @out
end
def ruby_name
@ruby_name ||= begin
name = base_name