lib/torch.rb in torch-rb-0.3.2 vs lib/torch.rb in torch-rb-0.3.3
- old
+ new
@@ -458,9 +458,20 @@
def zeros_like(input, **options)
zeros(input.size, **like_options(input, options))
end
+ def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+ if center
+ signal_dim = input.dim
+ extended_shape = [1] * (3 - signal_dim) + input.size
+ pad = n_fft.div(2).to_i
+ input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
+ input = input.view(input.shape[-signal_dim..-1])
+ end
+ _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
+ end
+
private
def to_ivalue(obj)
case obj
when String