ext/cumo/narray/gen/tmpl/gemm.c in cumo-0.1.0 vs ext/cumo/narray/gen/tmpl/gemm.c in cumo-0.1.1

- old
+ new

@@ -21,61 +21,236 @@ when 'dcomplex' 'cuDoubleComplex' end %> -#define args_t <%=name%>_args_t +#define ROW_SIZE(na) ((na)->shape[(na)->ndim-2]) +#define COL_SIZE(na) ((na)->shape[(na)->ndim-1]) +#define CHECK_NARRAY_TYPE(x,t) \ + if (rb_obj_class(x)!=(t)) { \ + rb_raise(rb_eTypeError,"invalid NArray type (class)"); \ + } + +// Error Class ?? +#define CHECK_DIM_GE(na,nd) \ + if ((na)->ndim<(nd)) { \ + rb_raise(cumo_na_eShapeError, \ + "n-dimension=%d, but >=%d is expected", \ + (na)->ndim, (nd)); \ + } + +#define CHECK_DIM_EQ(na1,nd) \ + if ((na1)->ndim != (nd)) { \ + rb_raise(cumo_na_eShapeError, \ + "dimention mismatch: %d != %d", \ + (na1)->ndim, (nd)); \ + } + +#define CHECK_SQUARE(name,na) \ + if ((na)->shape[(na)->ndim-1] != (na)->shape[(na)->ndim-2]) { \ + rb_raise(cumo_na_eShapeError,"%s is not square matrix",name); \ + } + +#define CHECK_SIZE_GE(na,sz) \ + if ((na)->size < (size_t)(sz)) { \ + rb_raise(cumo_na_eShapeError, \ + "NArray size must be >= %"SZF"u",(size_t)(sz));\ + } +#define CHECK_NON_EMPTY(na) \ + if ((na)->size==0) { \ + rb_raise(cumo_na_eShapeError,"empty NArray"); \ + } + +#define CHECK_SIZE_EQ(n,m) \ + if ((n)!=(m)) { \ + rb_raise(cumo_na_eShapeError, \ + "size mismatch: %"SZF"d != %"SZF"d", \ + (size_t)(n),(size_t)(m)); \ + } + +#define CHECK_SAME_SHAPE(na1,na2) \ + { int i; \ + CHECK_DIM_EQ(na1,na2->ndim); \ + for (i=0; i<na1->ndim; i++) { \ + CHECK_SIZE_EQ(na1->shape[i],na2->shape[i]); \ + } \ + } + +#define CHECK_INT_EQ(sm,m,sn,n) \ + if ((m) != (n)) { \ + rb_raise(cumo_na_eShapeError, \ + "%s must be == %s: %s=%d %s=%d", \ + sm,sn,sm,m,sn,n); \ + } + +// Error Class ?? +#define CHECK_LEADING_GE(sld,ld,sn,n) \ + if ((ld) < (n)) { \ + rb_raise(cumo_na_eShapeError, \ + "%s must be >= max(%s,1): %s=%d %s=%d", \ + sld,sn,sld,ld,sn,n); \ + } + +#define COPY_OR_CAST_TO(a,T) \ + { \ + if (rb_obj_class(a) == (T)) { \ + if (!CUMO_TEST_INPLACE(a)) { \ + a = cumo_na_copy(a); \ + } \ + } else { \ + a = rb_funcall(T,rb_intern("cast"),1,a); \ + } \ + } + typedef struct { - // enum CBLAS_ORDER order; // cuBLAS does not have order (row-major or column-major) option - cublasOperation_t transa, transb; - cublasSideMode_t side; - cublasFillMode_t uplo; - cublasDiagType_t diag; - dtype alpha, beta; - int m, n, k; -} args_t; + dtype alpha, beta; + int m, n, k; +} gemm_args_t; -static void -<%=c_iter%>(na_loop_t *const lp) +typedef struct { + int ld; + int stride; // in element count + cublasOperation_t trans; + VALUE a; +} gemm_layout_t; + +static bool +is_f_contiguous(VALUE a) { - dtype *a, *b; - int lda, ldb; - dtype *c; - int ldc; - args_t *g; - static cublasHandle_t handle = 0; + int i; + ssize_t s0; + cumo_narray_t *na; - a = (dtype*)NDL_PTR(lp,0); - b = (dtype*)NDL_PTR(lp,1); - c = (dtype*)NDL_PTR(lp,2); - g = (args_t*)(lp->opt_ptr); + switch(CUMO_RNARRAY_TYPE(a)) { + case CUMO_NARRAY_DATA_T: + case CUMO_NARRAY_FILEMAP_T: + return CUMO_TEST_COLUMN_MAJOR(a); + case CUMO_NARRAY_VIEW_T: + CumoGetNArray(a, na); - lda = NDL_STEP(lp,0) / sizeof(dtype); - ldb = NDL_STEP(lp,1) / sizeof(dtype); - ldc = NDL_STEP(lp,2) / sizeof(dtype); + // not contiguous if it has index + for (i = 0; i < CUMO_NA_NDIM(na); ++i) { + if (CUMO_NA_IS_INDEX_AT(na, i)) return false; + } - //printf("transa=%d transb=%d m=%d n=%d k=%d lda=%d ldb=%d ldc=%d\n",g->transa,g->transb,g->m,g->n,g->k,lda,ldb,ldc); + // check f-contiguous + s0 = cumo_na_element_stride(a); + for (i = 0; i < CUMO_NA_NDIM(na); ++i) { + if (CUMO_NA_SHAPE(na)[i] == 1) continue; + if (CUMO_NA_STRIDE_AT(na, i) != s0) return false; + s0 *= CUMO_NA_SHAPE(na)[i]; + } + return true; + default: + rb_raise(rb_eArgError, "NArray type : %d is not supported", CUMO_RNARRAY_TYPE(a)); + } +} +static bool +is_c_contiguous(VALUE a) +{ + return cumo_na_check_contiguous(a) == Qtrue; +} + +static gemm_layout_t +make_gemm_layout(VALUE a) +{ + cumo_narray_t *na; + gemm_layout_t layout; + + CumoGetNArray(a, na); + + if (cumo_na_debug_flag) { + printf("ndim==2 && f_contiguous:%d, c_contiguous:%d\n", + CUMO_NA_NDIM(na) == 2 && is_f_contiguous(a), is_c_contiguous(a)); + } + + if (CUMO_NA_NDIM(na) == 2 && is_f_contiguous(a)) { + layout.ld = ROW_SIZE(na); + layout.trans = CUBLAS_OP_T; + layout.a = a; + } else { + layout.ld = COL_SIZE(na); + layout.trans = CUBLAS_OP_N; // transposed + // force c-contiguous + layout.a = is_c_contiguous(a) ? a : rb_funcall(a, rb_intern("dup"), 0); + } + layout.stride = ROW_SIZE(na) * COL_SIZE(na); + return layout; +} + +extern int cumo_na_debug_flag; // narray.c + +static void +print_gemm_args(gemm_args_t* g, gemm_layout_t* a_layout, gemm_layout_t* b_layout, int stridec, int batch_count) +{ + printf("transb=%d transa=%d, n=%d, m=%d, k=%d, ldb=%d, lda=%d, ldc=n=%d, strideb=%d, stridea=%d stridec=%d batch_count=%d\n", + (int)b_layout->trans, + (int)a_layout->trans, + (int)g->n, + (int)g->m, + (int)g->k, + (int)b_layout->ld, + (int)a_layout->ld, + (int)g->n, + (int)b_layout->stride, + (int)a_layout->stride, + (int)stridec, + (int)batch_count); +} + +static void +<%=c_iter%>(VALUE a, VALUE b, VALUE c, gemm_args_t *g) +{ + gemm_layout_t a_layout, b_layout; + cublasHandle_t handle = 0; + cublasStatus_t status = 0; + cumo_narray_t* nc; + int stridec = 0; + int batch_count = 0; + // Note that cuBLAS uses the column major matrix representation. // We use technic which following site describes: // https://www.christophlassner.de/using-blas-from-c-with-row-major-data.html // // b^T = nxk matrix // a^T = kxm matrix // c^T = nxm matrix // c^T = b^T * a^T // - // cublasSgemm(handle,transb,transa,n,m,k,&alpha,b,n,a,k,&beta,c,n); + // cublasSgemm(handle,transb,transa,n,m,k,&alpha,b,ldb,a,lda,&beta,c,ldc=n); - // TODO(sonots): Create another handle for another cuda device or cpu thread - if (!handle) { - cublasCreate(&handle); - } - cublas<%=func_prefix%>gemm(handle, g->transb, g->transa, g->n, g->m, g->k, (<%=cutype%>*)(&g->alpha), (<%=cutype%>*)b, ldb, (<%=cutype%>*)a, lda, (<%=cutype%>*)(&g->beta), (<%=cutype%>*)c, ldc); - // TODO(sonots): Destroy correctly - //cublasDestroy(handle); + a_layout = make_gemm_layout(a); + b_layout = make_gemm_layout(b); + + CumoGetNArray(c, nc); + stridec = ROW_SIZE(nc) * COL_SIZE(nc); + batch_count = CUMO_NA_SIZE(nc) / stridec; + + if (cumo_na_debug_flag) print_gemm_args(g, &a_layout, &b_layout, stridec, batch_count); + handle = cumo_cuda_cublas_handle(); + status = cublas<%=func_prefix%>gemmStridedBatched( + handle, + b_layout.trans, + a_layout.trans, + g->n, + g->m, + g->k, + (<%=cutype%>*)(&g->alpha), + (<%=cutype%>*)(cumo_na_get_pointer_for_read(b_layout.a) + cumo_na_get_offset(b_layout.a)), + b_layout.ld, + b_layout.stride, + (<%=cutype%>*)(cumo_na_get_pointer_for_read(a_layout.a) + cumo_na_get_offset(a_layout.a)), + a_layout.ld, + a_layout.stride, + (<%=cutype%>*)(&g->beta), + (<%=cutype%>*)(cumo_na_get_pointer_for_write(c) + cumo_na_get_offset(c)), + g->n, + stridec, + batch_count); + cumo_cuda_cublas_check_status(status); } /* <% # ext/numo/linalg/blas/gen/decl.rb @@ -88,116 +263,85 @@ end def opt(v,tp=nil,*a) tp ||= "String or Symbol" case v - when /^order$/ - "@param #{v} [#{tp}] if 'R': Row-major, if 'C': Column-major. (default='R')" - when /^uplo$/ - "@param #{v} [#{tp}] if 'U': Upper triangle, if 'L': Lower triangle. (default='U')" - when /^side$/ - "@param #{v} [#{tp}] if 'L': op(A)\\*B (left-side op), if 'R': B\\*op(A) (right-side op). (default='L')" - when /^diag$/ - "@param #{v} [#{tp}] if 'U': assumed to be unit triangular, if 'N': not assumed to be unit triangular. (default='U')" - when /^trans(\w+)?$/ - b = a[0] || $1 - "@param #{v} [#{tp}] if 'N': Not transpose #{b}, if 'T': Transpose #{b}. (default='N')" when "alpha" "@param #{v} [Float] (default=1.0)" when "beta" "@param #{v} [Float] (default=0.0)" else "@param #{v} [#{tp}] #{a[0]}" end end %> <% - args_v = "a, b, [c, alpha:1, beta:0, transa:'N', transb:'N']" + args_v = "a, b, [c, alpha:1, beta:0]" params = [ mat("a"), mat("b"), mat("c","optional",:inplace), opt("alpha"), opt("beta"), - opt("transa"), - opt("transb"), ].select{|x| x}.join("\n ") %> @overload <%=name%>(<%=args_v%>) <%=params%> @return [<%=class_name%>] returns c = alpha\*op( A )\*op( B ) + beta\*C. <%=description%> */ static VALUE <%=c_func(-1)%>(int argc, VALUE argv[], VALUE self) { - VALUE a=self, b, c=Qnil, alpha, beta; - narray_t *na1, *na2; - int ma, ka, kb, nb, tmp; - size_t shape[2]; - ndfunc_arg_in_t ain[3] = {{cT,2},{cT,2},{OVERWRITE,2}}; - ndfunc_arg_out_t aout[1] = {{cT,2,shape}}; - ndfunc_t ndf = {<%=c_iter%>, NO_LOOP, 3, 0, ain, aout}; + VALUE a=self, b, c=Qnil, alpha, beta; + cumo_narray_t *na, *nb; - args_t g; + gemm_args_t g; VALUE kw_hash = Qnil; - ID kw_table[4] = {rb_intern("alpha"),rb_intern("beta"),rb_intern("transa"),rb_intern("transb")}; - VALUE opts[4] = {Qundef,Qundef,Qundef,Qundef}; + ID kw_table[2] = {rb_intern("alpha"), rb_intern("beta")}; + VALUE opts[2] = {Qundef, Qundef}; rb_scan_args(argc, argv, "11:", &b, &c, &kw_hash); - rb_get_kwargs(kw_hash, kw_table, 0, 4, opts); - alpha = option_value(opts[0],Qnil); - g.alpha = RTEST(alpha) ? m_num_to_data(alpha) : m_one; - beta = option_value(opts[1],Qnil); - g.beta = RTEST(beta) ? m_num_to_data(beta) : m_zero; - g.transa = option_trans(opts[2]); - g.transb = option_trans(opts[3]); + rb_get_kwargs(kw_hash, kw_table, 0, 2, opts); + alpha = cumo_cuda_cublas_option_value(opts[0],Qnil); + g.alpha = RTEST(alpha) ? m_num_to_data(alpha) : m_one; + beta = cumo_cuda_cublas_option_value(opts[1],Qnil); + g.beta = RTEST(beta) ? m_num_to_data(beta) : m_zero; - GetNArray(a,na1); - GetNArray(b,na2); - CHECK_DIM_GE(na1,2); - CHECK_DIM_GE(na2,2); - ma = ROW_SIZE(na1); // m - ka = COL_SIZE(na1); // k - kb = ROW_SIZE(na2); // k - nb = COL_SIZE(na2); // n + CumoGetNArray(a, na); + CumoGetNArray(b, nb); + CHECK_DIM_GE(na, 2); + CHECK_DIM_GE(nb, 2); - SWAP_IFTR(g.transa, ma, ka, tmp); - SWAP_IFTR(g.transb, kb, nb, tmp); - CHECK_INT_EQ("ka",ka,"kb",kb); - g.m = ma; - g.n = nb; - g.k = ka; + if (ROW_SIZE(nb) != COL_SIZE(na)) { + rb_raise(cumo_na_eShapeError,"ROW_SIZE(b)=%d must equal to COL_SIZE(a)=%d", + (int)ROW_SIZE(nb), (int)COL_SIZE(na)); + } - SWAP_IFROW(ma, nb, tmp); + g.m = ROW_SIZE(na); + g.k = COL_SIZE(na); + g.n = COL_SIZE(nb); if (c == Qnil) { // c is not given. - ndfunc_arg_in_t ain_init = {sym_init,0}; - ain[2] = ain_init; - ndf.nout = 1; - c = INT2FIX(0); - shape[0] = nb; - shape[1] = ma; + int ndim = CUMO_NA_NDIM(na); + size_t *shape = ALLOCA_N(size_t, ndim); + memcpy(shape, CUMO_NA_SHAPE(na), sizeof(size_t) * (ndim - 1)); // ... x m x k + shape[ndim - 1] = g.n; // ... x m x n + c = cumo_na_new(cT, ndim, shape); } else { - narray_t *na3; - int nc; - COPY_OR_CAST_TO(c,cT); - GetNArray(c,na3); - CHECK_DIM_GE(na3,2); - nc = ROW_SIZE(na3); - if (nc < nb) { - rb_raise(nary_eShapeError,"nc=%d must be >= nb=%d",nc,nb); + cumo_narray_t *nc; + COPY_OR_CAST_TO(c, cT); + CumoGetNArray(c, nc); + CHECK_DIM_GE(nc, 2); + if (ROW_SIZE(nc) != ROW_SIZE(na)) { + rb_raise(cumo_na_eShapeError,"ROW_SIZE(c)=%d must equal to ROW_SIZE(a)=%d", + (int)ROW_SIZE(nc), (int)ROW_SIZE(na)); } - //CHECK_LEADING_GE("ldc",g.ldc,"m",ma); - } - { - VALUE ans = na_ndloop3(&ndf, &g, 3, a, b, c); - - if (ndf.nout == 1) { // c is not given. - return ans; - } else { - return c; + if (COL_SIZE(nc) != COL_SIZE(nb)) { + rb_raise(cumo_na_eShapeError,"COL_SIZE(c)=%d must equal to COL_SIZE(b)=%d", + (int)COL_SIZE(nc), (int)COL_SIZE(nc)); } } -} -#undef args_t + <%=c_iter%>(a, b, c, &g); + return c; +}