ext/cumo/cuda/cublas.c in cumo-0.1.0 vs ext/cumo/cuda/cublas.c in cumo-0.1.1

- old
+ new

@@ -2,64 +2,82 @@ #include <assert.h> #include <ruby.h> #include "cumo/narray.h" #include "cumo/template.h" +#include "cumo/cuda/runtime.h" -//static void *blas_handle = 0; -//static char *blas_prefix = 0; +VALUE cumo_cuda_eCublasError; +VALUE cumo_cuda_mCublas; +#define eCublasError cumo_cuda_eCublasError +#define mCublas cumo_cuda_mCublas +static char* +get_cublas_error_msg(cublasStatus_t error) { + switch (error) { +#define RETURN_MSG(msg) \ + case msg: \ + return #msg + + RETURN_MSG(CUBLAS_STATUS_SUCCESS); + RETURN_MSG(CUBLAS_STATUS_NOT_INITIALIZED); + RETURN_MSG(CUBLAS_STATUS_ALLOC_FAILED); + RETURN_MSG(CUBLAS_STATUS_INVALID_VALUE); + RETURN_MSG(CUBLAS_STATUS_ARCH_MISMATCH); + RETURN_MSG(CUBLAS_STATUS_MAPPING_ERROR); + RETURN_MSG(CUBLAS_STATUS_EXECUTION_FAILED); + RETURN_MSG(CUBLAS_STATUS_INTERNAL_ERROR); + RETURN_MSG(CUBLAS_STATUS_NOT_SUPPORTED); + RETURN_MSG(CUBLAS_STATUS_LICENSE_ERROR); + +#undef RETURN_MSG + } + abort(); // never reach +} + +void +cumo_cuda_cublas_check_status(cublasStatus_t status) +{ + if (status != 0) { + rb_raise(cumo_cuda_eCublasError, "%s (error=%d)", get_cublas_error_msg(status), status); + } +} + +// Lazily initialize cublas handle, and cache it +cublasHandle_t +cumo_cuda_cublas_handle() +{ + static cublasHandle_t *handles = 0; // handle is never destroyed + if (handles == 0) { + int i; + int device_count = cumo_cuda_runtime_get_device_count(); + handles = malloc(sizeof(cublasHandle_t) * device_count); + for (i = 0; i < device_count; ++i) { + handles[i] = 0; + } + } + int device = cumo_cuda_runtime_get_device(); + if (handles[device] == 0) { + cublasCreate(&handles[device]); + } + return handles[device]; +} + VALUE -cumo_cublas_option_value(VALUE value, VALUE default_value) +cumo_cuda_cublas_option_value(VALUE value, VALUE default_value) { switch(TYPE(value)) { case T_NIL: case T_UNDEF: return default_value; } return value; } -//enum CBLAS_ORDER -//cumo_cublas_option_order(VALUE order) -//{ -// int opt; -// char *ptr; -// -// switch(TYPE(order)) { -// case T_NIL: -// case T_UNDEF: -// case T_FALSE: -// return CblasRowMajor; -// case T_TRUE: -// return CblasColMajor; -// case T_FIXNUM: -// opt = FIX2INT(order); -// if (opt >= CblasRowMajor && opt <= CblasColMajor) { -// return opt; -// } -// break; -// case T_SYMBOL: -// order = rb_sym2str(order); -// case T_STRING: -// ptr = RSTRING_PTR(order); -// if (RSTRING_LEN(order) > 0) { -// switch(ptr[0]){ -// case 'R': case 'r': -// return CblasRowMajor; -// case 'C': case 'c': -// return CblasColMajor; -// } -// } -// break; -// } -// rb_raise(rb_eArgError,"invalid value for CBLAS_ORDER"); -// return 0; -//} - +#if 0 cublasOperation_t -cumo_cublas_option_trans(VALUE trans) +cumo_cuda_cublas_option_trans(VALUE trans) { int opt; char *ptr; switch(TYPE(trans)) { @@ -92,187 +110,19 @@ break; } rb_raise(rb_eArgError, "invalid value for cublasOperation_t"); return 0; } +#endif -cublasFillMode_t -cumo_cublas_option_uplo(VALUE uplo) +void +Init_cumo_cuda_cublas(void) { - int opt; - char *ptr; + VALUE mCumo = rb_define_module("Cumo"); + VALUE mCUDA = rb_define_module_under(mCumo, "CUDA"); - switch(TYPE(uplo)) { - case T_NIL: - case T_UNDEF: - case T_FALSE: - return CUBLAS_FILL_MODE_UPPER; - case T_TRUE: - return CUBLAS_FILL_MODE_LOWER; - case T_FIXNUM: - opt = FIX2INT(uplo); - switch(opt){ - case CUBLAS_FILL_MODE_UPPER: - case CUBLAS_FILL_MODE_LOWER: - return opt; - } - break; - case T_SYMBOL: - uplo = rb_sym2str(uplo); - case T_STRING: - ptr = RSTRING_PTR(uplo); - if (RSTRING_LEN(uplo) > 0) { - switch(ptr[0]){ - case 'U': case 'u': - return CUBLAS_FILL_MODE_UPPER; - case 'L': case 'l': - return CUBLAS_FILL_MODE_LOWER; - } - } - break; - } - rb_raise(rb_eArgError, "invalid value for cublasFillMode_t"); - return 0; + /* + Document-module: Cumo::Cublas + */ + mCublas = rb_define_module_under(mCUDA, "Cublas"); + eCublasError = rb_define_class_under(mCUDA, "CublasError", rb_eStandardError); } - -cublasDiagType_t -cumo_cublas_option_diag(VALUE diag) -{ - int opt; - char *ptr; - - switch(TYPE(diag)) { - case T_NIL: - case T_UNDEF: - case T_FALSE: - return CUBLAS_DIAG_NON_UNIT; - case T_TRUE: - return CUBLAS_DIAG_UNIT; - case T_FIXNUM: - opt = FIX2INT(diag); - switch(opt){ - case CUBLAS_DIAG_NON_UNIT: - case CUBLAS_DIAG_UNIT: - return opt; - } - break; - case T_SYMBOL: - diag = rb_sym2str(diag); - case T_STRING: - ptr = RSTRING_PTR(diag); - if (RSTRING_LEN(diag) > 0) { - switch(ptr[0]){ - case 'N': case 'n': - return CUBLAS_DIAG_NON_UNIT; - case 'U': case 'u': - return CUBLAS_DIAG_UNIT; - } - } - break; - } - rb_raise(rb_eArgError, "invalid value for cublasDiagType_t"); - return 0; -} - -cublasSideMode_t -cumo_cublas_option_side(VALUE side) -{ - int opt; - char *ptr; - - switch(TYPE(side)) { - case T_NIL: - case T_UNDEF: - case T_FALSE: - return CUBLAS_SIDE_LEFT; - case T_TRUE: - return CUBLAS_SIDE_RIGHT; - case T_FIXNUM: - opt = FIX2INT(side); - switch(opt){ - case CUBLAS_SIDE_LEFT: - case CUBLAS_SIDE_RIGHT: - return opt; - } - break; - case T_SYMBOL: - side = rb_sym2str(side); - case T_STRING: - ptr = RSTRING_PTR(side); - if (RSTRING_LEN(side) > 0) { - switch(ptr[0]){ - case 'L': case 'l': - return CUBLAS_SIDE_LEFT; - case 'R': case 'r': - return CUBLAS_SIDE_RIGHT; - } - } - break; - } - rb_raise(rb_eArgError, "invalid value for cublasSideMode_t"); - return 0; -} - -//void -//cumo_cublas_check_func(void **func, const char *name) -//{ -// char *s, *error; -// -// if (*func==0) { -// if (blas_handle==0) { -// rb_raise(rb_eRuntimeError,"BLAS library is not loaded"); -// } -// if (blas_prefix==0) { -// rb_raise(rb_eRuntimeError,"CBLAS prefix is not set"); -// } -// s = alloca(strlen(blas_prefix)+strlen(name)+1); -// strcpy(s,blas_prefix); -// strcat(s,name); -// dlerror(); -// *func = dlsym(blas_handle, s); -// error = dlerror(); -// if (error != NULL) { -// rb_raise(rb_eRuntimeError, "%s", error); -// } -// } -//} - -//static VALUE -//blas_s_prefix_set(VALUE mod, VALUE prefix) -//{ -// long len; -// -// if (TYPE(prefix) != T_STRING) { -// rb_raise(rb_eTypeError,"argument must be string"); -// } -// if (blas_prefix) { -// free(blas_prefix); -// } -// len = RSTRING_LEN(prefix); -// blas_prefix = malloc(len+1); -// strcpy(blas_prefix, StringValueCStr(prefix)); -// return prefix; -//} - -//void -//Init_blas(void) -//{ -// VALUE mN; -// -// mN = rb_define_module("Numo"); -// /* -// Document-module: Numo::Linalg -// */ -// mLinalg = rb_define_module_under(mN, "Linalg"); -// mBlas = rb_define_module_under(mLinalg, "Blas"); -// -// rb_define_module_function(mBlas, "dlopen", blas_s_dlopen, -1); -// rb_define_module_function(mBlas, "prefix=", blas_s_prefix_set, 1); -// -// blas_prefix = malloc(strlen("cublas_")+1); // default prefix -// strcpy(blas_prefix,"cublas_"); -// -// Init_cumo_linalg_blas_s(); -// Init_cumo_linalg_blas_d(); -// Init_cumo_linalg_blas_c(); -// Init_cumo_linalg_blas_z(); -//}