Sha256: 0f732fb8a90d5fb0c96fc9b62ab800fd25f4e5d042585baf6e7e71da654e5423
Contents?: true
Size: 842 Bytes
Versions: 36
Compression:
Stored size: 842 Bytes
Contents
module Torch
  module NN
    # Module that resizes its input by delegating to F.interpolate.
    #
    # The target is described either by an explicit +size+ or by a
    # +scale_factor+; both are stored as given (after float conversion of
    # the scale factor) and handed to F.interpolate on every forward call.
    class Upsample < Module
      # @param size          requested output size, stored unchanged
      # @param scale_factor  a single multiplier or an Array of multipliers;
      #                      every value is converted with +to_f+, a falsy
      #                      value is stored as nil
      # @param mode          interpolation mode name passed through to
      #                      F.interpolate (defaults to "nearest")
      # @param align_corners passed through to F.interpolate unchanged
      def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
        super()
        @size = size
        @mode = mode
        @align_corners = align_corners
        # An `if` with no matching branch evaluates to nil, which is exactly
        # the value wanted when no scale factor was supplied.
        @scale_factor =
          if scale_factor.is_a?(Array)
            scale_factor.map { |factor| factor.to_f }
          elsif scale_factor
            scale_factor.to_f
          end
      end

      # Applies F.interpolate to +input+ using the settings captured at
      # construction time and returns its result.
      def forward(input)
        F.interpolate(
          input,
          size: @size,
          scale_factor: @scale_factor,
          mode: @mode,
          align_corners: @align_corners
        )
      end

      # Short description of the layer configuration: the scale factor when
      # one is set, otherwise the size, always followed by the mode.
      def extra_inspect
        target =
          if @scale_factor.nil?
            "size: #{@size.inspect}"
          else
            "scale_factor: #{@scale_factor.inspect}"
          end
        "#{target}, mode: #{@mode.inspect}"
      end
    end
  end
end
Version data entries
36 entries across 36 versions & 1 rubygems