lib/torch.rb in torch-rb-0.5.0 vs lib/torch.rb in torch-rb-0.5.1
- old
+ new
@@ -259,10 +259,12 @@
elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
bytes = args.first.bytes
Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
elsif args.size == 1 && args.first.is_a?(Array)
Torch.tensor(args.first, dtype: dtype, device: device)
+ elsif args.size == 0
+ Torch.empty(0, dtype: dtype, device: device)
else
Torch.empty(*args, dtype: dtype, device: device)
end
end
TENSOR_TYPE_CLASSES << cls
@@ -432,23 +434,19 @@
def zeros_like(input, **options)
zeros(input.size, **like_options(input, options))
end
- def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+ # center option
+ def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
if center
signal_dim = input.dim
extended_shape = [1] * (3 - signal_dim) + input.size
pad = n_fft.div(2).to_i
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
input = input.view(input.shape[-signal_dim..-1])
end
- _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
- end
-
- def clamp(tensor, min, max)
- tensor = _clamp_min(tensor, min)
- _clamp_max(tensor, max)
+ _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
end
private
def to_ivalue(obj)