ext/cumo/narray/gen/tmpl/batch_norm_backward.c in cumo-0.4.1 vs ext/cumo/narray/gen/tmpl/batch_norm_backward.c in cumo-0.4.2
- old
+ new
@@ -11,19 +11,19 @@
# CUDNN_DATA_HALF
raise 'not supported'
end
%>
-// gx, ggamma, gbeta = x.batch_normalizatoin_backward(gamma, gy, mean:, inv_std:, eps:, axis:)
+// gx, ggamma, gbeta = x.batch_norm_backward(gamma, gy, mean:, inv_std:, eps:, axis:)
static VALUE
<%=c_func(-1)%>(int argc, VALUE argv[], VALUE self)
{
cudnnDataType_t cudnn_dtype = <%= cudnn_dtype %>;
cudnnStatus_t status = 0;
cudnnHandle_t handle = 0;
- dtype coef_alpha = 1;
- dtype coef_beta = 0;
+ dtype coef_one = 1;
+ dtype coef_zero = 0;
VALUE x=self, gamma, gy, mean, inv_std, eps, axis, gx, ggamma, gbeta;
VALUE kw_hash = Qnil;
ID kw_table[] = {
rb_intern("mean"),
@@ -34,13 +34,13 @@
rb_intern("ggamma"),
rb_intern("gbeta")
};
VALUE opts[] = {Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef};
- cumo_narray_t *nx, *ngamma; // , *ngy;
- size_t *x_shape, *gamma_shape; // , *gy_shape, reduced_shape[CUMO_NA_MAX_DIMENSION];
- size_t x_ndim, gamma_ndim; // , gy_ndim, reduced_ndim;
+ cumo_narray_t *nx, *ngamma;
+ size_t *x_shape, *gamma_shape;
+ size_t x_ndim, gamma_ndim;
VALUE x_cont, gamma_cont, gy_cont;
cudnnTensorDescriptor_t x_desc = 0;
cudnnTensorDescriptor_t bn_desc = 0;
char *x_cont_ptr, *gamma_cont_ptr, *gy_cont_ptr, *gx_ptr, *ggamma_ptr, *gbeta_ptr;
@@ -77,32 +77,34 @@
axis_ndim = cumo_cuda_cudnn_get_int_axis(int_axis, axis);
}
CumoGetNArray(x, nx);
CumoGetNArray(gamma, ngamma);
- // CumoGetNArray(gy, ngy);
x_ndim = nx->ndim;
x_shape = nx->shape;
gamma_ndim = ngamma->ndim;
gamma_shape = ngamma->shape;
- // gy_ndim = ngy->ndim;
- // gy_shape = ngy->shape;
- // TODO: Size check of gammma, beta, running_mean, running_var, mean, inv_std
- // are equivalent with either of reduced_shape(keepdims: false) or reduced_shape(keepdims: true)
- // reduced_ndim = cumo_cuda_cudnn_ReduceShape(reduced_shape, x_ndim, x_shape, axis_ndim, int_axis, 1);
- // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_ndim, gamma_ndim);
- // for (size_t idim = 0; idim < reduced_ndim; ++idim) {
- // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(reduced_shape[idim], gamma_shape[idim]);
- // }
- // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(x_ndim, gy_ndim);
- // for (size_t idim = 0; idim < x_ndim; ++idim) {
- // CUMO_CUDA_CUDNN_CHECK_DIM_EQ(x_shape[idim], gy_shape[idim]);
- // }
+ {
+ cumo_narray_t *ngy, *nmean, *ninv_std;
+ cumo_cuda_cudnn_shape_t reduced_shape = cumo_cuda_cudnn_ReduceShape(x_ndim, x_shape, axis_ndim, int_axis, 1);
+ size_t reduced_total_size = cumo_cuda_cudnn_GetTotalSize(&reduced_shape);
- // TODO: Add ndim and shape (same with reduced) for mean and inv_std if given
+ CumoGetNArray(gy, ngy);
+ CUMO_CUDA_CUDNN_CHECK_SIZE_EQ(nx->size, ngy->size);
+ CUMO_CUDA_CUDNN_CHECK_SIZE_EQ(ngamma->size, reduced_total_size);
+ if (mean != Qnil) {
+ CumoGetNArray(mean, nmean);
+ CUMO_CUDA_CUDNN_CHECK_SIZE_EQ(nmean->size, reduced_total_size);
+ }
+ if (inv_std != Qnil) {
+ CumoGetNArray(inv_std, ninv_std);
+ CUMO_CUDA_CUDNN_CHECK_SIZE_EQ(ninv_std->size, reduced_total_size);
+ }
+ }
+
CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(x, cT);
CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(gamma, cT);
CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(gy, cT);
if (mean != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(mean, cT);
if (inv_std != Qnil) CUMO_CUDA_CUDNN_CHECK_NARRAY_TYPE(inv_std, cT);
@@ -140,13 +142,13 @@
handle = cumo_cuda_cudnn_handle();
status = cudnnBatchNormalizationBackward(
handle,
mode,
- (void*)&coef_alpha,
- (void*)&coef_beta,
- (void*)&coef_alpha,
- (void*)&coef_beta,
+ (void*)&coef_one,
+ (void*)&coef_zero,
+ (void*)&coef_one,
+ (void*)&coef_zero,
x_desc,
x_cont_ptr,
x_desc,
gy_cont_ptr,
x_desc,