ext/cumo/narray/gen/tmpl/accum_index_kernel.cu in cumo-0.1.2 vs ext/cumo/narray/gen/tmpl/accum_index_kernel.cu in cumo-0.2.0

- old
+ new

@@ -15,11 +15,11 @@ struct cumo_<%=type_name%>_min_index_int<%=i%>_impl { struct MinAndArgMin { dtype min; idx_t argmin; }; - __device__ MinAndArgMin Identity() { return {DATA_MAX, 0}; } + __device__ MinAndArgMin Identity(idx_t index) { return {DATA_MAX, index}; } __device__ MinAndArgMin MapIn(dtype in, idx_t index) { return {in, index}; } __device__ void Reduce(MinAndArgMin next, MinAndArgMin& accum) { if (accum.min > next.min) { accum = next; } @@ -30,10 +30,10 @@ struct cumo_<%=type_name%>_max_index_int<%=i%>_impl { struct MaxAndArgMax { dtype max; idx_t argmax; }; - __device__ MaxAndArgMax Identity() { return {DATA_MIN, 0}; } + __device__ MaxAndArgMax Identity(idx_t index) { return {DATA_MIN, index}; } __device__ MaxAndArgMax MapIn(dtype in, idx_t index) { return {in, index}; } __device__ void Reduce(MaxAndArgMax next, MaxAndArgMax& accum) { if (accum.max < next.max) { accum = next; }