ext/torch/ext.cpp in torch-rb-0.5.2 vs ext/torch/ext.cpp in torch-rb-0.5.3

- old
+ new

@@ -42,25 +42,32 @@ obj = a[i]; if (obj.is_instance_of(rb_cInteger)) { indices.push_back(from_ruby<int64_t>(obj)); } else if (obj.is_instance_of(rb_cRange)) { - torch::optional<int64_t> start_index = from_ruby<int64_t>(obj.call("begin")); - torch::optional<int64_t> stop_index = -1; + torch::optional<int64_t> start_index = torch::nullopt; + torch::optional<int64_t> stop_index = torch::nullopt; + Object begin = obj.call("begin"); + if (!begin.is_nil()) { + start_index = from_ruby<int64_t>(begin); + } + Object end = obj.call("end"); if (!end.is_nil()) { stop_index = from_ruby<int64_t>(end); } Object exclude_end = obj.call("exclude_end?"); - if (!exclude_end) { + if (stop_index.has_value() && !exclude_end) { if (stop_index.value() == -1) { stop_index = torch::nullopt; } else { stop_index = stop_index.value() + 1; } + } else if (!stop_index.has_value() && exclude_end) { + stop_index = -1; } indices.push_back(torch::indexing::Slice(start_index, stop_index)); } else if (obj.is_instance_of(rb_cTensor)) { indices.push_back(from_ruby<Tensor>(obj)); @@ -616,7 +623,9 @@ }); Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA") .add_handler<torch::Error>(handle_error) .define_singleton_method("available?", &torch::cuda::is_available) - .define_singleton_method("device_count", &torch::cuda::device_count); + .define_singleton_method("device_count", &torch::cuda::device_count) + .define_singleton_method("manual_seed", &torch::cuda::manual_seed) + .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all); }