ext/torch/ext.cpp in torch-rb-0.5.2 vs ext/torch/ext.cpp in torch-rb-0.5.3
- old
+ new
@@ -42,25 +42,32 @@
obj = a[i];
if (obj.is_instance_of(rb_cInteger)) {
indices.push_back(from_ruby<int64_t>(obj));
} else if (obj.is_instance_of(rb_cRange)) {
- torch::optional<int64_t> start_index = from_ruby<int64_t>(obj.call("begin"));
- torch::optional<int64_t> stop_index = -1;
+ torch::optional<int64_t> start_index = torch::nullopt;
+ torch::optional<int64_t> stop_index = torch::nullopt;
+ Object begin = obj.call("begin");
+ if (!begin.is_nil()) {
+ start_index = from_ruby<int64_t>(begin);
+ }
+
Object end = obj.call("end");
if (!end.is_nil()) {
stop_index = from_ruby<int64_t>(end);
}
Object exclude_end = obj.call("exclude_end?");
- if (!exclude_end) {
+ if (stop_index.has_value() && !exclude_end) {
if (stop_index.value() == -1) {
stop_index = torch::nullopt;
} else {
stop_index = stop_index.value() + 1;
}
+ } else if (!stop_index.has_value() && exclude_end) {
+ stop_index = -1;
}
indices.push_back(torch::indexing::Slice(start_index, stop_index));
} else if (obj.is_instance_of(rb_cTensor)) {
indices.push_back(from_ruby<Tensor>(obj));
@@ -616,7 +623,9 @@
});
Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
.add_handler<torch::Error>(handle_error)
.define_singleton_method("available?", &torch::cuda::is_available)
- .define_singleton_method("device_count", &torch::cuda::device_count);
+ .define_singleton_method("device_count", &torch::cuda::device_count)
+ .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
+ .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
}