lib/torch.rb in torch-rb-0.3.5 vs lib/torch.rb in torch-rb-0.3.6

- old
+ new

@@ -237,28 +237,25 @@ end end cls end - FloatTensor = _make_tensor_class(:float32) - DoubleTensor = _make_tensor_class(:float64) - HalfTensor = _make_tensor_class(:float16) - ByteTensor = _make_tensor_class(:uint8) - CharTensor = _make_tensor_class(:int8) - ShortTensor = _make_tensor_class(:int16) - IntTensor = _make_tensor_class(:int32) - LongTensor = _make_tensor_class(:int64) - BoolTensor = _make_tensor_class(:bool) + DTYPE_TO_CLASS = { + float32: "FloatTensor", + float64: "DoubleTensor", + float16: "HalfTensor", + uint8: "ByteTensor", + int8: "CharTensor", + int16: "ShortTensor", + int32: "IntTensor", + int64: "LongTensor", + bool: "BoolTensor" + } - CUDA::FloatTensor = _make_tensor_class(:float32, true) - CUDA::DoubleTensor = _make_tensor_class(:float64, true) - CUDA::HalfTensor = _make_tensor_class(:float16, true) - CUDA::ByteTensor = _make_tensor_class(:uint8, true) - CUDA::CharTensor = _make_tensor_class(:int8, true) - CUDA::ShortTensor = _make_tensor_class(:int16, true) - CUDA::IntTensor = _make_tensor_class(:int32, true) - CUDA::LongTensor = _make_tensor_class(:int64, true) - CUDA::BoolTensor = _make_tensor_class(:bool, true) + DTYPE_TO_CLASS.each do |dtype, class_name| + const_set(class_name, _make_tensor_class(dtype)) + CUDA.const_set(class_name, _make_tensor_class(dtype, true)) + end class << self # Torch.float, Torch.long, etc DTYPE_TO_ENUM.each_key do |dtype| define_method(dtype) do