ext/mpi/mpi.c in ruby-mpi-0.1.0 vs ext/mpi/mpi.c in ruby-mpi-0.2.0

- old
+ new

@@ -51,26 +51,47 @@ static VALUE cComm, cRequest, cOp, cErrhandler, cStatus; static VALUE eBUFFER, eCOUNT, eTYPE, eTAG, eCOMM, eRANK, eREQUEST, eROOT, eGROUP, eOP, eTOPOLOGY, eDIMS, eARG, eUNKNOWN, eTRUNCATE, eOTHER, eINTERN, eIN_STATUS, ePENDING, eACCESS, eAMODE, eASSERT, eBAD_FILE, eBASE, eCONVERSION, eDISP, eDUP_DATAREP, eFILE_EXISTS, eFILE_IN_USE, eFILE, eINFO_KEY, eINFO_NOKEY, eINFO_VALUE, eINFO, eIO, eKEYVAL, eLOCKTYPE, eNAME, eNO_MEM, eNOT_SAME, eNO_SPACE, eNO_SUCH_FILE, ePORT, eQUOTA, eREAD_ONLY, eRMA_CONFLICT, eRMA_SYNC, eSERVICE, eSIZE, eSPAWN, eUNSUPPORTED_DATAREP, eUNSUPPORTED_OPERATION, eWIN, eLASTCODE, eSYSRESOURCE; struct _Comm { - MPI_Comm comm; + MPI_Comm Comm; }; struct _Request { - MPI_Request request; + MPI_Request Request; }; struct _Op { - MPI_Op op; + MPI_Op Op; }; struct _Errhandler { - MPI_Errhandler errhandler; + MPI_Errhandler Errhandler; }; static bool _initialized = false; static bool _finalized = false; +#define DEF_FREE(name) \ +static void \ +name ## _free(void *ptr)\ +{\ + struct _ ## name *obj;\ + obj = (struct _ ## name*) ptr;\ + if (!_finalized)\ + MPI_ ## name ## _free(&(obj->name)); \ + free(obj);\ +} +DEF_FREE(Comm) +DEF_FREE(Request) +DEF_FREE(Op) +DEF_FREE(Errhandler) +static void +Status_free(void *ptr) +{ + free((MPI_Status*) ptr); +} + + #define CAE_ERR(type) case MPI_ERR_ ## type: rb_raise(e ## type,""); break static void check_error(int error) { switch (error) { @@ -127,37 +148,42 @@ CAE_ERR(SPAWN); CAE_ERR(UNSUPPORTED_DATAREP); CAE_ERR(UNSUPPORTED_OPERATION); CAE_ERR(WIN); CAE_ERR(LASTCODE); +#ifdef MPI_ERR_SYSRESOURCE CAE_ERR(SYSRESOURCE); +#endif default: rb_raise(rb_eRuntimeError, "unknown error"); } } -#define DEF_CONST(st, v, const, name, klass) \ +#define DEF_CONST(v, const, name) \ {\ - v = ALLOC(struct st);\ + v = ALLOC(struct _ ## v);\ v->v = const;\ - rb_define_const(klass, #name, Data_Wrap_Struct(klass, 0, -1, v)); \ + rb_define_const(c ## v, #name, Data_Wrap_Struct(c ## v, NULL, v ## _free, v)); \ } +static void +_finalize() +{ + if(_initialized && !_finalized) { + _finalized = true; + check_error(MPI_Finalize()); + } +} static VALUE rb_m_init(int argc, VALUE *argv, VALUE self) { VALUE argary; int cargc; char ** cargv; VALUE progname; int i; - if (_initialized) - return self; - else - _initialized = true; - rb_scan_args(argc, argv, "01", &argary); if (NIL_P(argary)) { argary = rb_const_get(rb_cObject, rb_intern("ARGV")); cargc = RARRAY_LEN(argary); @@ -176,49 +202,48 @@ else cargv[i+1] = (char*)""; } cargc++; - MPI_Init(&cargc, &cargv); + check_error(MPI_Init(&cargc, &cargv)); + if (_initialized) + return self; + else + _initialized = true; + atexit(_finalize); + + // define MPI::Comm::WORLD - struct _Comm *comm; - DEF_CONST(_Comm, comm, MPI_COMM_WORLD, WORLD, cComm); + struct _Comm *Comm; + DEF_CONST(Comm, MPI_COMM_WORLD, WORLD); MPI_Errhandler_set(MPI_COMM_WORLD, MPI_ERRORS_RETURN); // define MPI::Op::??? - struct _Op *op; - DEF_CONST(_Op, op, MPI_MAX, MAX, cOp); - DEF_CONST(_Op, op, MPI_MIN, MIN, cOp); - DEF_CONST(_Op, op, MPI_SUM, SUM, cOp); - DEF_CONST(_Op, op, MPI_PROD, PROD, cOp); - DEF_CONST(_Op, op, MPI_LAND, LAND, cOp); - DEF_CONST(_Op, op, MPI_BAND, BAND, cOp); - DEF_CONST(_Op, op, MPI_LOR, LOR, cOp); - DEF_CONST(_Op, op, MPI_BOR, BOR, cOp); - DEF_CONST(_Op, op, MPI_LXOR, LXOR, cOp); - DEF_CONST(_Op, op, MPI_BXOR, BXOR, cOp); - DEF_CONST(_Op, op, MPI_MAXLOC, MAXLOC, cOp); - DEF_CONST(_Op, op, MPI_MINLOC, MINLOC, cOp); - DEF_CONST(_Op, op, MPI_REPLACE, REPLACE, cOp); + struct _Op *Op; + DEF_CONST(Op, MPI_MAX, MAX); + DEF_CONST(Op, MPI_MIN, MIN); + DEF_CONST(Op, MPI_SUM, SUM); + DEF_CONST(Op, MPI_PROD, PROD); + DEF_CONST(Op, MPI_LAND, LAND); + DEF_CONST(Op, MPI_BAND, BAND); + DEF_CONST(Op, MPI_LOR, LOR); + DEF_CONST(Op, MPI_BOR, BOR); + DEF_CONST(Op, MPI_LXOR, LXOR); + DEF_CONST(Op, MPI_BXOR, BXOR); + DEF_CONST(Op, MPI_MAXLOC, MAXLOC); + DEF_CONST(Op, MPI_MINLOC, MINLOC); + DEF_CONST(Op, MPI_REPLACE, REPLACE); // define MPI::Errhandler::ERRORS_ARE_FATAL, ERRORS_RETURN - struct _Errhandler *errhandler; - DEF_CONST(_Errhandler, errhandler, MPI_ERRORS_ARE_FATAL, ERRORS_ARE_FATAL, cErrhandler); - DEF_CONST(_Errhandler, errhandler, MPI_ERRORS_RETURN, ERRORS_RETURN, cErrhandler); + struct _Errhandler *Errhandler; + DEF_CONST(Errhandler, MPI_ERRORS_ARE_FATAL, ERRORS_ARE_FATAL); + DEF_CONST(Errhandler, MPI_ERRORS_RETURN, ERRORS_RETURN); return self; } -static void -_finalize() -{ - if(_initialized && !_finalized) { - _finalized = true; - check_error(MPI_Finalize()); - } -} static VALUE rb_m_finalize(VALUE self) { _finalize(); return self; @@ -228,11 +253,11 @@ // MPI::Comm static VALUE rb_comm_alloc(VALUE klass) { struct _Comm *ptr = ALLOC(struct _Comm); - return Data_Wrap_Struct(klass, 0, -1, ptr); + return Data_Wrap_Struct(klass, NULL, Comm_free, ptr); } static VALUE rb_comm_initialize(VALUE self) { rb_raise(rb_eRuntimeError, "not developed yet"); @@ -242,20 +267,20 @@ rb_comm_size(VALUE self) { struct _Comm *comm; int size; Data_Get_Struct(self, struct _Comm, comm); - check_error(MPI_Comm_size(comm->comm, &size)); + check_error(MPI_Comm_size(comm->Comm, &size)); return INT2NUM(size); } static VALUE rb_comm_rank(VALUE self) { struct _Comm *comm; int rank; Data_Get_Struct(self, struct _Comm, comm); - check_error(MPI_Comm_rank(comm->comm, &rank)); + check_error(MPI_Comm_rank(comm->Comm, &rank)); return INT2NUM(rank); } static VALUE rb_comm_send(VALUE self, VALUE rb_obj, VALUE rb_dest, VALUE rb_tag) { @@ -266,11 +291,11 @@ OBJ2C(rb_obj, len, buffer, type); dest = NUM2INT(rb_dest); tag = NUM2INT(rb_tag); Data_Get_Struct(self, struct _Comm, comm); - check_error(MPI_Send(buffer, len, type, dest, tag, comm->comm)); + check_error(MPI_Send(buffer, len, type, dest, tag, comm->Comm)); return Qnil; } static VALUE rb_comm_isend(VALUE self, VALUE rb_obj, VALUE rb_dest, VALUE rb_tag) @@ -284,12 +309,12 @@ OBJ2C(rb_obj, len, buffer, type); dest = NUM2INT(rb_dest); tag = NUM2INT(rb_tag); Data_Get_Struct(self, struct _Comm, comm); - rb_request = Data_Make_Struct(cRequest, struct _Request, 0, -1, request); - check_error(MPI_Isend(buffer, len, type, dest, tag, comm->comm, &(request->request))); + rb_request = Data_Make_Struct(cRequest, struct _Request, NULL, Request_free, request); + check_error(MPI_Isend(buffer, len, type, dest, tag, comm->Comm, &(request->Request))); return rb_request; } static VALUE rb_comm_recv(VALUE self, VALUE rb_obj, VALUE rb_source, VALUE rb_tag) @@ -304,13 +329,13 @@ source = NUM2INT(rb_source); tag = NUM2INT(rb_tag); Data_Get_Struct(self, struct _Comm, comm); status = ALLOC(MPI_Status); - check_error(MPI_Recv(buffer, len, type, source, tag, comm->comm, status)); + check_error(MPI_Recv(buffer, len, type, source, tag, comm->Comm, status)); - return Data_Wrap_Struct(cStatus, 0, -1, status); + return Data_Wrap_Struct(cStatus, NULL, Status_free, status); } static VALUE rb_comm_irecv(VALUE self, VALUE rb_obj, VALUE rb_source, VALUE rb_tag) { void* buffer; @@ -322,12 +347,12 @@ OBJ2C(rb_obj, len, buffer, type); source = NUM2INT(rb_source); tag = NUM2INT(rb_tag); Data_Get_Struct(self, struct _Comm, comm); - rb_request = Data_Make_Struct(cRequest, struct _Request, 0, -1, request); - check_error(MPI_Irecv(buffer, len, type, source, tag, comm->comm, &(request->request))); + rb_request = Data_Make_Struct(cRequest, struct _Request, NULL, Request_free, request); + check_error(MPI_Irecv(buffer, len, type, source, tag, comm->Comm, &(request->Request))); return rb_request; } static VALUE rb_comm_gather(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf, VALUE rb_root) @@ -338,19 +363,19 @@ int root, rank, size; struct _Comm *comm; OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype); root = NUM2INT(rb_root); Data_Get_Struct(self, struct _Comm, comm); - check_error(MPI_Comm_rank(comm->comm, &rank)); - check_error(MPI_Comm_size(comm->comm, &size)); + check_error(MPI_Comm_rank(comm->Comm, &rank)); + check_error(MPI_Comm_size(comm->Comm, &size)); if (rank == root) { OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype); if (recvcount < sendcount*size) rb_raise(rb_eArgError, "recvbuf is too small"); recvcount = sendcount; } - check_error(MPI_Gather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm->comm)); + check_error(MPI_Gather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm->Comm)); return Qnil; } static VALUE rb_comm_allgather(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf) { @@ -359,17 +384,17 @@ MPI_Datatype sendtype, recvtype; int rank, size; struct _Comm *comm; OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype); Data_Get_Struct(self, struct _Comm, comm); - check_error(MPI_Comm_rank(comm->comm, &rank)); - check_error(MPI_Comm_size(comm->comm, &size)); + check_error(MPI_Comm_rank(comm->Comm, &rank)); + check_error(MPI_Comm_size(comm->Comm, &size)); OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype); if (recvcount < sendcount*size) rb_raise(rb_eArgError, "recvbuf is too small"); recvcount = sendcount; - check_error(MPI_Allgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm->comm)); + check_error(MPI_Allgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm->Comm)); return Qnil; } static VALUE rb_comm_bcast(VALUE self, VALUE rb_buffer, VALUE rb_root) { @@ -379,11 +404,11 @@ int root; struct _Comm *comm; OBJ2C(rb_buffer, count, buffer, type); root = NUM2INT(rb_root); Data_Get_Struct(self, struct _Comm, comm); - check_error(MPI_Bcast(buffer, count, type, root, comm->comm)); + check_error(MPI_Bcast(buffer, count, type, root, comm->Comm)); return Qnil; } static VALUE rb_comm_scatter(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf, VALUE rb_root) { @@ -393,39 +418,61 @@ int root, rank, size; struct _Comm *comm; OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype); root = NUM2INT(rb_root); Data_Get_Struct(self, struct _Comm, comm); - check_error(MPI_Comm_rank(comm->comm, &rank)); - check_error(MPI_Comm_size(comm->comm, &size)); + check_error(MPI_Comm_rank(comm->Comm, &rank)); + check_error(MPI_Comm_size(comm->Comm, &size)); if (rank == root) { OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype); if (sendcount > recvcount*size) rb_raise(rb_eArgError, "recvbuf is too small"); sendcount = recvcount; } - check_error(MPI_Scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm->comm)); + check_error(MPI_Scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm->Comm)); return Qnil; } static VALUE +rb_comm_sendrecv(VALUE self, VALUE rb_sendbuf, VALUE rb_dest, VALUE rb_sendtag, VALUE rb_recvbuf, VALUE rb_source, VALUE rb_recvtag) +{ + void *sendbuf, *recvbuf; + int sendcount, recvcount; + MPI_Datatype sendtype, recvtype; + int dest, source; + int sendtag, recvtag; + int size; + struct _Comm *comm; + MPI_Status *status; + OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype); + Data_Get_Struct(self, struct _Comm, comm); + check_error(MPI_Comm_size(comm->Comm, &size)); + OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype); + dest = NUM2INT(rb_dest); + source = NUM2INT(rb_source); + sendtag = NUM2INT(rb_sendtag); + recvtag = NUM2INT(rb_recvtag); + status = ALLOC(MPI_Status); + check_error(MPI_Sendrecv(sendbuf, sendcount, sendtype, dest, sendtag, recvbuf, recvcount, recvtype, source, recvtag, comm->Comm, status)); + return Data_Wrap_Struct(cStatus, NULL, Status_free, status); +} +static VALUE rb_comm_alltoall(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf) { void *sendbuf, *recvbuf; int sendcount, recvcount; MPI_Datatype sendtype, recvtype; - int rank, size; + int size; struct _Comm *comm; OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype); Data_Get_Struct(self, struct _Comm, comm); - check_error(MPI_Comm_rank(comm->comm, &rank)); - check_error(MPI_Comm_size(comm->comm, &size)); + check_error(MPI_Comm_size(comm->Comm, &size)); OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype); if (recvcount < sendcount) rb_raise(rb_eArgError, "recvbuf is too small"); recvcount = recvcount/size; sendcount = sendcount/size; - check_error(MPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm->comm)); + check_error(MPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm->Comm)); return Qnil; } static VALUE rb_comm_reduce(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf, VALUE rb_op, VALUE rb_root) { @@ -436,21 +483,21 @@ struct _Comm *comm; struct _Op *op; OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype); root = NUM2INT(rb_root); Data_Get_Struct(self, struct _Comm, comm); - check_error(MPI_Comm_rank(comm->comm, &rank)); - check_error(MPI_Comm_size(comm->comm, &size)); + check_error(MPI_Comm_rank(comm->Comm, &rank)); + check_error(MPI_Comm_size(comm->Comm, &size)); if (rank == root) { OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype); if (recvcount != sendcount) rb_raise(rb_eArgError, "sendbuf and recvbuf has the same length"); if (recvtype != sendtype) rb_raise(rb_eArgError, "sendbuf and recvbuf has the same type"); } Data_Get_Struct(rb_op, struct _Op, op); - check_error(MPI_Reduce(sendbuf, recvbuf, sendcount, sendtype, op->op, root, comm->comm)); + check_error(MPI_Reduce(sendbuf, recvbuf, sendcount, sendtype, op->Op, root, comm->Comm)); return Qnil; } static VALUE rb_comm_allreduce(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf, VALUE rb_op) { @@ -460,65 +507,73 @@ int rank, size; struct _Comm *comm; struct _Op *op; OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype); Data_Get_Struct(self, struct _Comm, comm); - check_error(MPI_Comm_rank(comm->comm, &rank)); - check_error(MPI_Comm_size(comm->comm, &size)); + check_error(MPI_Comm_rank(comm->Comm, &rank)); + check_error(MPI_Comm_size(comm->Comm, &size)); OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype); if (recvcount != sendcount) rb_raise(rb_eArgError, "sendbuf and recvbuf has the same length"); if (recvtype != sendtype) rb_raise(rb_eArgError, "sendbuf and recvbuf has the same type"); Data_Get_Struct(rb_op, struct _Op, op); - check_error(MPI_Allreduce(sendbuf, recvbuf, recvcount, recvtype, op->op, comm->comm)); + check_error(MPI_Allreduce(sendbuf, recvbuf, recvcount, recvtype, op->Op, comm->Comm)); return Qnil; } static VALUE rb_comm_get_Errhandler(VALUE self) { struct _Comm *comm; struct _Errhandler *errhandler; VALUE rb_errhandler; Data_Get_Struct(self, struct _Comm, comm); - rb_errhandler = Data_Make_Struct(cErrhandler, struct _Errhandler, 0, -1, errhandler); - MPI_Comm_get_errhandler(comm->comm, &(errhandler->errhandler)); + rb_errhandler = Data_Make_Struct(cErrhandler, struct _Errhandler, NULL, Errhandler_free, errhandler); + MPI_Comm_get_errhandler(comm->Comm, &(errhandler->Errhandler)); return rb_errhandler; } static VALUE rb_comm_set_Errhandler(VALUE self, VALUE rb_errhandler) { struct _Comm *comm; struct _Errhandler *errhandler; Data_Get_Struct(self, struct _Comm, comm); Data_Get_Struct(rb_errhandler, struct _Errhandler, errhandler); - MPI_Comm_set_errhandler(comm->comm, errhandler->errhandler); + MPI_Comm_set_errhandler(comm->Comm, errhandler->Errhandler); return self; } +static VALUE +rb_comm_barrier(VALUE self) +{ + struct _Comm *comm; + Data_Get_Struct(self, struct _Comm, comm); + check_error(MPI_Barrier(comm->Comm)); + return self; +} // MPI::Request static VALUE rb_request_wait(VALUE self) { MPI_Status *status; struct _Request *request; Data_Get_Struct(self, struct _Request, request); status = ALLOC(MPI_Status); - check_error(MPI_Wait(&(request->request), status)); - return Data_Wrap_Struct(cStatus, 0, -1, status); + check_error(MPI_Wait(&(request->Request), status)); + return Data_Wrap_Struct(cStatus, NULL, Status_free, status); } // MPI::Errhandler static VALUE rb_errhandler_eql(VALUE self, VALUE other) { struct _Errhandler *eh0, *eh1; Data_Get_Struct(self, struct _Errhandler, eh0); Data_Get_Struct(other, struct _Errhandler, eh1); - return eh0->errhandler == eh1->errhandler ? Qtrue : Qfalse; + return eh0->Errhandler == eh1->Errhandler ? Qtrue : Qfalse; } // MPI::Status static VALUE rb_status_source(VALUE self) @@ -546,19 +601,18 @@ void Init_mpi() { rb_require("narray"); - atexit(_finalize); - // MPI mMPI = rb_define_module("MPI"); rb_define_module_function(mMPI, "Init", rb_m_init, -1); rb_define_module_function(mMPI, "Finalize", rb_m_finalize, -1); rb_define_const(mMPI, "VERSION", INT2NUM(MPI_VERSION)); rb_define_const(mMPI, "SUBVERSION", INT2NUM(MPI_SUBVERSION)); rb_define_const(mMPI, "SUCCESS", INT2NUM(MPI_SUCCESS)); + rb_define_const(mMPI, "PROC_NULL", INT2NUM(MPI_PROC_NULL)); // MPI::Comm cComm = rb_define_class_under(mMPI, "Comm", rb_cObject); // rb_define_alloc_func(cComm, rb_comm_alloc); rb_define_private_method(cComm, "initialize", rb_comm_initialize, 0); @@ -570,14 +624,16 @@ rb_define_method(cComm, "Irecv", rb_comm_irecv, 3); rb_define_method(cComm, "Gather", rb_comm_gather, 3); rb_define_method(cComm, "Allgather", rb_comm_allgather, 2); rb_define_method(cComm, "Bcast", rb_comm_bcast, 2); rb_define_method(cComm, "Scatter", rb_comm_scatter, 3); + rb_define_method(cComm, "Sendrecv", rb_comm_sendrecv, 6); rb_define_method(cComm, "Alltoall", rb_comm_alltoall, 2); rb_define_method(cComm, "Reduce", rb_comm_reduce, 4); rb_define_method(cComm, "Allreduce", rb_comm_allreduce, 3); rb_define_method(cComm, "Errhandler", rb_comm_get_Errhandler, 0); rb_define_method(cComm, "Errhandler=", rb_comm_set_Errhandler, 1); + rb_define_method(cComm, "Barrier", rb_comm_barrier, 0); // MPI::Request cRequest = rb_define_class_under(mMPI, "Request", rb_cObject); rb_define_method(cRequest, "Wait", rb_request_wait, 0);