Diff of ext/cumo/narray/gen/tmpl/batch_norm_backward.c between cumo-0.4.1 and cumo-0.4.2

- old
+ new

@@ -11,19 +11,19 @@ # CUDNN_DATA_HALF raise 'not supported' end %> -// gx, ggamma, gbeta = x.batch_normalizatoin_backward(gamma, gy, mean:, inv_std:, eps:, axis:) +// gx, ggamma, gbeta = x.batch_norm_backward(gamma, gy, mean:, inv_std:, eps:, axis:) static VALUE <%=c_func(-1)%>(int argc, VALUE argv[], VALUE self) { cudnnDataType_t cudnn_dtype = <%= cudnn_dtype %>; cudnnStatus_t status = 0; cudnnHandle_t handle = 0; - dtype coef_alpha = 1; - dtype coef_beta = 0; + dtype coef_one = 1; + dtype coef_zero = 0; VALUE x=self, gamma, gy, mean, inv_std, eps, axis, gx, ggamma, gbeta; VALUE kw_hash = Qnil; ID kw_table[] = { rb_intern("mean"), @@ -34,13 +34,13 @@ rb_intern("ggamma"), rb_intern("gbeta") }; VALUE opts[] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef}; - cumo_narray_t *nx, *ngamma; // , *ngy; - size_t *x_shape, *gamma_shape; // , *gy_shape, reduced_shape[CUMO_NA_MAX_DIMENSION]; - size_t x_ndim, gamma_ndim; // , gy_ndim, reduced_ndim; + cumo_narray_t *nx, *ngamma; + size_t *x_shape, *gamma_shape; + size_t x_ndim, gamma_ndim; VALUE x_cont, gamma_cont, gy_cont; cudnnTensorDescriptor_t x_desc = 0; cudnnTensorDescriptor_t bn_desc = 0; char *x_cont_ptr, *gamma_cont_ptr, *gy_cont_ptr, *gx_ptr, *ggamma_ptr, *gbeta_ptr; @@ -77,32 +77,34 @@ axis_ndim = cumo_cuda_cudnn_get_int_axis(int_axis, axis); } CumoGetNArray(x, nx); CumoGetNArray(gamma, ngamma); - // CumoGetNArray(gy, ngy); x_ndim = nx->ndim; x_shape = nx->shape; gamma_ndim = ngamma->ndim; gamma_shape = ngamma->shape; - // gy_ndim = ngy->ndim; - // gy_shape = ngy->shape; - // TODO: Size check of gammma, beta, running_mean, running_var, mean, inv_std - // are equivalent with either of reduced_shape(keepdims: false) or reduced_shape(keepdims: true) - // reduced_ndim = cumo_cuda_cudnn_ReduceShape(reduced_shape, x_ndim, x_shape, axis_ndim, int_axis, 1); - // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_ndim, gamma_ndim); - // for (size_t idim = 0; idim < reduced_ndim; ++idim) { - // 
CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_shape[idim], gamma_shape[idim]); - // } - // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(x_ndim, gy_ndim); - // for (size_t idim = 0; idim < x_ndim; ++idim) { - // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(x_shape[idim], gy_shape[idim]); - // } + { + cumo_narray_t *ngy, *nmean, *ninv_std; + cumo_cuda_cudnn_shape_t reduced_shape = cumo_cuda_cudnn_ReduceShape(x_ndim, x_shape, axis_ndim, int_axis, 1); + size_t reduced_total_size = cumo_cuda_cudnn_GetTotalSize(&reduced_shape); - // TODO: Add ndim and shape (same with reduced) for mean and inv_std if given + CumoGetNArray(gy, ngy); + CUMO_CUDA_CUDNN_CHECK_SIZE_EQ(nx->size, ngy->size); + CUMO_CUDA_CUDNN_CHECK_SIZE_EQ(ngamma->size, reduced_total_size); + if (mean != Qnil) { + CumoGetNArray(mean, nmean); + CUMO_CUDA_CUDNN_CHECK_SIZE_EQ(nmean->size, reduced_total_size); + } + if (inv_std != Qnil) { + CumoGetNArray(inv_std, ninv_std); + CUMO_CUDA_CUDNN_CHECK_SIZE_EQ(ninv_std->size, reduced_total_size); + } + } + CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(x, cT); CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(gamma, cT); CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(gy, cT); if (mean != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(mean, cT); if (inv_std != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(inv_std, cT); @@ -140,13 +142,13 @@ handle = cumo_cuda_cudnn_handle(); status = cudnnBatchNormalizationBackward( handle, mode, - (void*)&coef_alpha, - (void*)&coef_beta, - (void*)&coef_alpha, - (void*)&coef_beta, + (void*)&coef_one, + (void*)&coef_zero, + (void*)&coef_one, + (void*)&coef_zero, x_desc, x_cont_ptr, x_desc, gy_cont_ptr, x_desc,