ext/cumo/narray/gen/tmpl/accum_index_kernel.cu in cumo-0.1.2 vs ext/cumo/narray/gen/tmpl/accum_index_kernel.cu in cumo-0.2.0
- old
+ new
@@ -15,11 +15,11 @@
struct cumo_<%=type_name%>_min_index_int<%=i%>_impl {
struct MinAndArgMin {
dtype min;
idx_t argmin;
};
- __device__ MinAndArgMin Identity() { return {DATA_MAX, 0}; }
+ __device__ MinAndArgMin Identity(idx_t index) { return {DATA_MAX, index}; }
__device__ MinAndArgMin MapIn(dtype in, idx_t index) { return {in, index}; }
__device__ void Reduce(MinAndArgMin next, MinAndArgMin& accum) {
if (accum.min > next.min) {
accum = next;
}
@@ -30,10 +30,10 @@
struct cumo_<%=type_name%>_max_index_int<%=i%>_impl {
struct MaxAndArgMax {
dtype max;
idx_t argmax;
};
- __device__ MaxAndArgMax Identity() { return {DATA_MIN, 0}; }
+ __device__ MaxAndArgMax Identity(idx_t index) { return {DATA_MIN, index}; }
__device__ MaxAndArgMax MapIn(dtype in, idx_t index) { return {in, index}; }
__device__ void Reduce(MaxAndArgMax next, MaxAndArgMax& accum) {
if (accum.max < next.max) {
accum = next;
}