lib/torch.rb in torch-rb-0.3.5 vs lib/torch.rb in torch-rb-0.3.6
- old
+ new
@@ -237,28 +237,25 @@
end
end
cls
end
- FloatTensor = _make_tensor_class(:float32)
- DoubleTensor = _make_tensor_class(:float64)
- HalfTensor = _make_tensor_class(:float16)
- ByteTensor = _make_tensor_class(:uint8)
- CharTensor = _make_tensor_class(:int8)
- ShortTensor = _make_tensor_class(:int16)
- IntTensor = _make_tensor_class(:int32)
- LongTensor = _make_tensor_class(:int64)
- BoolTensor = _make_tensor_class(:bool)
+ DTYPE_TO_CLASS = {
+ float32: "FloatTensor",
+ float64: "DoubleTensor",
+ float16: "HalfTensor",
+ uint8: "ByteTensor",
+ int8: "CharTensor",
+ int16: "ShortTensor",
+ int32: "IntTensor",
+ int64: "LongTensor",
+ bool: "BoolTensor"
+ }
- CUDA::FloatTensor = _make_tensor_class(:float32, true)
- CUDA::DoubleTensor = _make_tensor_class(:float64, true)
- CUDA::HalfTensor = _make_tensor_class(:float16, true)
- CUDA::ByteTensor = _make_tensor_class(:uint8, true)
- CUDA::CharTensor = _make_tensor_class(:int8, true)
- CUDA::ShortTensor = _make_tensor_class(:int16, true)
- CUDA::IntTensor = _make_tensor_class(:int32, true)
- CUDA::LongTensor = _make_tensor_class(:int64, true)
- CUDA::BoolTensor = _make_tensor_class(:bool, true)
+ DTYPE_TO_CLASS.each do |dtype, class_name|
+ const_set(class_name, _make_tensor_class(dtype))
+ CUDA.const_set(class_name, _make_tensor_class(dtype, true))
+ end
class << self
# Torch.float, Torch.long, etc
DTYPE_TO_ENUM.each_key do |dtype|
define_method(dtype) do