lib/torch/native/function.rb in torch-rb-0.3.6 vs lib/torch/native/function.rb in torch-rb-0.3.7

- old
+ new

@@ -4,13 +4,14 @@ attr_reader :function, :tensor_options def initialize(function) @function = function - tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)" - @tensor_options = @function["func"].include?(tensor_options_str) - @function["func"].sub!(tensor_options_str, ")") + # note: don't modify function in-place + @tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)" + @tensor_options = @function["func"].include?(@tensor_options_str) + @out = out_size > 0 && base_name[-1] != "_" end def func @func ||= @function["func"] end @@ -29,11 +30,11 @@ def args @args ||= begin args = [] pos = true - args_str = func.split("(", 2).last.split(") ->").first + args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first args_str.split(", ").each do |a| if a == "*" pos = false next end @@ -70,16 +71,92 @@ end next if t == "Generator?" next if t == "MemoryFormat" next if t == "MemoryFormat?" - args << {name: k, type: t, default: d, pos: pos, has_default: has_default} + args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default} end args end end + def arg_checkers + @arg_checkers ||= begin + checkers = {} + arg_types.each do |k, t| + checker = + case t + when "Tensor" + ->(v) { v.is_a?(Tensor) } + when "Tensor?" + ->(v) { v.nil? || v.is_a?(Tensor) } + when "Tensor[]", "Tensor?[]" + ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } } + when "int" + if k == :reduction + ->(v) { v.is_a?(String) } + else + ->(v) { v.is_a?(Integer) } + end + when "int?" + ->(v) { v.is_a?(Integer) || v.nil? } + when "float?" + ->(v) { v.is_a?(Numeric) || v.nil? } + when "bool?" + ->(v) { v == true || v == false || v.nil? } + when "float" + ->(v) { v.is_a?(Numeric) } + when /int\[.*\]/ + ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } } + when "Scalar" + ->(v) { v.is_a?(Numeric) } + when "Scalar?" + ->(v) { v.is_a?(Numeric) || v.nil? } + when "ScalarType" + ->(v) { false } # not supported yet + when "ScalarType?" + ->(v) { v.nil? } + when "bool" + ->(v) { v == true || v == false } + when "str" + ->(v) { v.is_a?(String) } + else + raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}." + end + checkers[k] = checker + end + checkers + end + end + + def int_array_lengths + @int_array_lengths ||= begin + ret = {} + arg_types.each do |k, t| + if t.match?(/\Aint\[.+\]\z/) + size = t[4..-2] + raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/ + ret[k] = size.to_i + end + end + ret + end + end + + def arg_names + @arg_names ||= args.map { |a| a[:name] } + end + + def arg_types + @arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h + end + + def arg_defaults + # TODO find out why can't use select here + @arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h + end + def out_size @out_size ||= func.split("->").last.count("!") end def ret_size @@ -88,11 +165,15 @@ def ret_array? @ret_array ||= func.split("->").last.include?('[]') end + def ret_void? + func.split("->").last.strip == "()" + end + def out? - out_size > 0 && base_name[-1] != "_" + @out end def ruby_name @ruby_name ||= begin name = base_name