ext/mpi/mpi.c in ruby-mpi-0.2.0 vs ext/mpi/mpi.c in ruby-mpi-0.3.0
- old
+ new
@@ -4,42 +4,49 @@
#include "ruby.h"
#include "narray.h"
#include "mpi.h"
-#define OBJ2C(rb_obj, len, buffer, typ) \
+#define OBJ2C(rb_obj, len, buffer, typ, off) \
{\
if (TYPE(rb_obj) == T_STRING) {\
- len = RSTRING_LEN(rb_obj);\
- buffer = (void*)StringValuePtr(rb_obj);\
- typ = MPI_CHAR;\
- } else if (IsNArray(rb_obj)) { \
+ if (len==0) len = RSTRING_LEN(rb_obj);\
+ buffer = (void*)(StringValuePtr(rb_obj) + off);\
+ typ = MPI_BYTE;\
+ } else if (IsNArray(rb_obj)) {\
struct NARRAY *a;\
GetNArray(rb_obj, a);\
buffer = (void*)(a->ptr);\
- len = a->total;\
+ if (len==0) len = a->total;\
switch (a->type) {\
case NA_BYTE:\
typ = MPI_BYTE;\
+ buffer = (void*)((char*)buffer + off);\
break;\
case NA_SINT:\
typ = MPI_SHORT;\
+ buffer = (void*)((char*)buffer + off*4);\
break;\
case NA_LINT:\
typ = MPI_LONG;\
+ buffer = (void*)((char*)buffer + off*8);\
break;\
case NA_SFLOAT:\
typ = MPI_FLOAT;\
+ buffer = (void*)((char*)buffer + off*4);\
break;\
case NA_DFLOAT:\
typ = MPI_DOUBLE;\
+ buffer = (void*)((char*)buffer + off*8);\
break;\
case NA_SCOMPLEX:\
typ = MPI_2COMPLEX;\
+ buffer = (void*)((char*)buffer + off*8);\
break;\
case NA_DCOMPLEX:\
typ = MPI_2DOUBLE_COMPLEX;\
+ buffer = (void*)((char*)buffer + off*16);\
break;\
default:\
rb_raise(rb_eArgError, "narray type is invalid");\
}\
} else {\
@@ -52,52 +59,40 @@
static VALUE eBUFFER, eCOUNT, eTYPE, eTAG, eCOMM, eRANK, eREQUEST, eROOT, eGROUP, eOP, eTOPOLOGY, eDIMS, eARG, eUNKNOWN, eTRUNCATE, eOTHER, eINTERN, eIN_STATUS, ePENDING, eACCESS, eAMODE, eASSERT, eBAD_FILE, eBASE, eCONVERSION, eDISP, eDUP_DATAREP, eFILE_EXISTS, eFILE_IN_USE, eFILE, eINFO_KEY, eINFO_NOKEY, eINFO_VALUE, eINFO, eIO, eKEYVAL, eLOCKTYPE, eNAME, eNO_MEM, eNOT_SAME, eNO_SPACE, eNO_SUCH_FILE, ePORT, eQUOTA, eREAD_ONLY, eRMA_CONFLICT, eRMA_SYNC, eSERVICE, eSIZE, eSPAWN, eUNSUPPORTED_DATAREP, eUNSUPPORTED_OPERATION, eWIN, eLASTCODE, eSYSRESOURCE;
struct _Comm {
MPI_Comm Comm;
+ bool free;
};
struct _Request {
MPI_Request Request;
+ bool free;
};
struct _Op {
MPI_Op Op;
+ bool free;
};
struct _Errhandler {
MPI_Errhandler Errhandler;
+ bool free;
};
static bool _initialized = false;
static bool _finalized = false;
-#define DEF_FREE(name) \
-static void \
-name ## _free(void *ptr)\
-{\
- struct _ ## name *obj;\
- obj = (struct _ ## name*) ptr;\
- if (!_finalized)\
- MPI_ ## name ## _free(&(obj->name)); \
- free(obj);\
-}
-DEF_FREE(Comm)
-DEF_FREE(Request)
-DEF_FREE(Op)
-DEF_FREE(Errhandler)
+#define CAE_ERR(type) case MPI_ERR_ ## type: rb_raise(e ## type,"%s",str); break
static void
-Status_free(void *ptr)
-{
- free((MPI_Status*) ptr);
-}
-
-
-#define CAE_ERR(type) case MPI_ERR_ ## type: rb_raise(e ## type,""); break
-static void
check_error(int error)
{
- switch (error) {
- case MPI_SUCCESS: break;
+ if (error == MPI_SUCCESS) return;
+ int code, len;
+ char str[MPI_MAX_ERROR_STRING];
+ if (MPI_Error_class(error, &code)!=MPI_SUCCESS || MPI_Error_string(error, str, &len)!=MPI_SUCCESS)
+ rb_raise(rb_eRuntimeError, "unknown error occuerd in MPI call");
+
+ switch (code) {
CAE_ERR(BUFFER);
CAE_ERR(COUNT);
CAE_ERR(TYPE);
CAE_ERR(TAG);
CAE_ERR(COMM);
@@ -152,19 +147,52 @@
CAE_ERR(LASTCODE);
#ifdef MPI_ERR_SYSRESOURCE
CAE_ERR(SYSRESOURCE);
#endif
default:
- rb_raise(rb_eRuntimeError, "unknown error");
+ rb_raise(rb_eRuntimeError, "unknown error: %d", code);
}
}
+#define DEF_FREE(name, capit) \
+static void \
+name ## _free(void *ptr)\
+{\
+ struct _ ## name *obj;\
+ obj = (struct _ ## name*) ptr;\
+ if (!_finalized && obj->free && obj->name!=MPI_ ## capit ##_NULL)\
+ check_error(MPI_ ## name ## _free(&(obj->name))); \
+ free(obj);\
+}
+#define DEF_FREE2(name, capit) \
+static void \
+name ## _free2(void *ptr)\
+{\
+ struct _ ## name *obj;\
+ obj = (struct _ ## name*) ptr;\
+ free(obj);\
+}
+DEF_FREE(Comm, COMM)
+DEF_FREE(Request, REQUEST)
+DEF_FREE(Op, OP)
+DEF_FREE(Errhandler, ERRHANDLER)
+DEF_FREE2(Comm, COMM)
+DEF_FREE2(Op, OP)
+DEF_FREE2(Errhandler, ERRHANDLER)
+static void
+Status_free(void *ptr)
+{
+ free((MPI_Status*) ptr);
+}
+
+
#define DEF_CONST(v, const, name) \
{\
v = ALLOC(struct _ ## v);\
v->v = const;\
- rb_define_const(c ## v, #name, Data_Wrap_Struct(c ## v, NULL, v ## _free, v)); \
+ v->free = false;\
+ rb_define_const(c ## v, #name, Data_Wrap_Struct(c ## v, NULL, v ## _free2, v)); \
}
static void
_finalize()
{
@@ -214,11 +242,11 @@
// define MPI::Comm::WORLD
struct _Comm *Comm;
DEF_CONST(Comm, MPI_COMM_WORLD, WORLD);
- MPI_Errhandler_set(MPI_COMM_WORLD, MPI_ERRORS_RETURN);
+ check_error(MPI_Errhandler_set(MPI_COMM_WORLD, MPI_ERRORS_RETURN));
// define MPI::Op::???
struct _Op *Op;
DEF_CONST(Op, MPI_MAX, MAX);
DEF_CONST(Op, MPI_MIN, MIN);
@@ -247,11 +275,21 @@
{
_finalize();
return self;
}
+static VALUE
+rb_m_abort(VALUE self, VALUE rcomm, VALUE rerror)
+{
+ struct _Comm *comm;
+ int ierror;
+ Data_Get_Struct(rcomm, struct _Comm, comm);
+ ierror = MPI_Abort(comm->Comm, NUM2INT(rerror));
+ return INT2NUM(ierror);
+}
+
// MPI::Comm
static VALUE
rb_comm_alloc(VALUE klass)
{
struct _Comm *ptr = ALLOC(struct _Comm);
@@ -260,10 +298,11 @@
static VALUE
rb_comm_initialize(VALUE self)
{
rb_raise(rb_eRuntimeError, "not developed yet");
// MPI_Comm_create()
+ // comm->free = true;
}
static VALUE
rb_comm_size(VALUE self)
{
struct _Comm *comm;
@@ -283,15 +322,15 @@
}
static VALUE
rb_comm_send(VALUE self, VALUE rb_obj, VALUE rb_dest, VALUE rb_tag)
{
void* buffer;
- int len, dest, tag;
+ int len=0, dest, tag;
MPI_Datatype type;
struct _Comm *comm;
- OBJ2C(rb_obj, len, buffer, type);
+ OBJ2C(rb_obj, len, buffer, type, 0);
dest = NUM2INT(rb_dest);
tag = NUM2INT(rb_tag);
Data_Get_Struct(self, struct _Comm, comm);
check_error(MPI_Send(buffer, len, type, dest, tag, comm->Comm));
@@ -299,78 +338,103 @@
}
static VALUE
rb_comm_isend(VALUE self, VALUE rb_obj, VALUE rb_dest, VALUE rb_tag)
{
void* buffer;
- int len, dest, tag;
+ int len=0, dest, tag;
MPI_Datatype type;
struct _Comm *comm;
struct _Request *request;
VALUE rb_request;
- OBJ2C(rb_obj, len, buffer, type);
+ OBJ2C(rb_obj, len, buffer, type, 0);
dest = NUM2INT(rb_dest);
tag = NUM2INT(rb_tag);
Data_Get_Struct(self, struct _Comm, comm);
rb_request = Data_Make_Struct(cRequest, struct _Request, NULL, Request_free, request);
+ request->free = true;
check_error(MPI_Isend(buffer, len, type, dest, tag, comm->Comm, &(request->Request)));
return rb_request;
}
static VALUE
-rb_comm_recv(VALUE self, VALUE rb_obj, VALUE rb_source, VALUE rb_tag)
+rb_comm_recv(int argc, VALUE *argv, VALUE self)
{
+ VALUE rb_obj, rb_source, rb_tag;
+ VALUE rb_len, rb_offset; // option
void* buffer;
- int len, source, tag;
+ int source, tag, len = 0, offset = 0;
MPI_Datatype type;
MPI_Status *status;
struct _Comm *comm;
- OBJ2C(rb_obj, len, buffer, type);
+ rb_scan_args(argc, argv, "32", &rb_obj, &rb_source, &rb_tag, &rb_len, &rb_offset);
+
+ if (rb_len != Qnil) {
+ len = NUM2INT(rb_len);
+ }
+ if (rb_offset != Qnil) {
+ offset = NUM2INT(rb_offset);
+ }
+
+ OBJ2C(rb_obj, len, buffer, type, offset);
source = NUM2INT(rb_source);
tag = NUM2INT(rb_tag);
Data_Get_Struct(self, struct _Comm, comm);
status = ALLOC(MPI_Status);
check_error(MPI_Recv(buffer, len, type, source, tag, comm->Comm, status));
return Data_Wrap_Struct(cStatus, NULL, Status_free, status);
}
static VALUE
-rb_comm_irecv(VALUE self, VALUE rb_obj, VALUE rb_source, VALUE rb_tag)
+rb_comm_irecv(int argc, VALUE *argv, VALUE self)
{
+ VALUE rb_obj, rb_source, rb_tag;
+ VALUE rb_len, rb_offset; // option
void* buffer;
- int len, source, tag;
+ int source, tag, len = 0, offset = 0;
MPI_Datatype type;
struct _Comm *comm;
struct _Request *request;
VALUE rb_request;
- OBJ2C(rb_obj, len, buffer, type);
+ rb_scan_args(argc, argv, "32", &rb_obj, &rb_source, &rb_tag, &rb_len, &rb_offset);
+
+ if (rb_len != Qnil) {
+ len = NUM2INT(rb_len);
+ }
+ if (rb_offset != Qnil) {
+ offset = NUM2INT(rb_offset);
+ }
+
+ OBJ2C(rb_obj, len, buffer, type, offset);
source = NUM2INT(rb_source);
tag = NUM2INT(rb_tag);
+
Data_Get_Struct(self, struct _Comm, comm);
rb_request = Data_Make_Struct(cRequest, struct _Request, NULL, Request_free, request);
+ request->free = true;
check_error(MPI_Irecv(buffer, len, type, source, tag, comm->Comm, &(request->Request)));
return rb_request;
}
static VALUE
rb_comm_gather(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf, VALUE rb_root)
{
void *sendbuf, *recvbuf = NULL;
- int sendcount, recvcount = 0;
- MPI_Datatype sendtype, recvtype = NULL;
+ int sendcount=0, recvcount = 0;
+ MPI_Datatype sendtype, recvtype = 0;
int root, rank, size;
struct _Comm *comm;
- OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype);
+ OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype, 0);
root = NUM2INT(rb_root);
Data_Get_Struct(self, struct _Comm, comm);
check_error(MPI_Comm_rank(comm->Comm, &rank));
check_error(MPI_Comm_size(comm->Comm, &size));
if (rank == root) {
- OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype);
+ OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype, 0);
if (recvcount < sendcount*size)
rb_raise(rb_eArgError, "recvbuf is too small");
recvcount = sendcount;
}
check_error(MPI_Gather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm->Comm));
@@ -378,54 +442,54 @@
}
static VALUE
rb_comm_allgather(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf)
{
void *sendbuf, *recvbuf;
- int sendcount, recvcount;
+ int sendcount=0, recvcount=0;
MPI_Datatype sendtype, recvtype;
int rank, size;
struct _Comm *comm;
- OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype);
+ OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype, 0);
Data_Get_Struct(self, struct _Comm, comm);
check_error(MPI_Comm_rank(comm->Comm, &rank));
check_error(MPI_Comm_size(comm->Comm, &size));
- OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype);
+ OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype, 0);
if (recvcount < sendcount*size)
rb_raise(rb_eArgError, "recvbuf is too small");
recvcount = sendcount;
check_error(MPI_Allgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm->Comm));
return Qnil;
}
static VALUE
rb_comm_bcast(VALUE self, VALUE rb_buffer, VALUE rb_root)
{
void *buffer;
- int count;
+ int count=0;
MPI_Datatype type;
int root;
struct _Comm *comm;
- OBJ2C(rb_buffer, count, buffer, type);
+ OBJ2C(rb_buffer, count, buffer, type, 0);
root = NUM2INT(rb_root);
Data_Get_Struct(self, struct _Comm, comm);
check_error(MPI_Bcast(buffer, count, type, root, comm->Comm));
return Qnil;
}
static VALUE
rb_comm_scatter(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf, VALUE rb_root)
{
void *sendbuf = NULL, *recvbuf;
- int sendcount = 0, recvcount;
- MPI_Datatype sendtype = NULL, recvtype;
+ int sendcount = 0, recvcount=0;
+ MPI_Datatype sendtype = 0, recvtype;
int root, rank, size;
struct _Comm *comm;
- OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype);
+ OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype, 0);
root = NUM2INT(rb_root);
Data_Get_Struct(self, struct _Comm, comm);
check_error(MPI_Comm_rank(comm->Comm, &rank));
check_error(MPI_Comm_size(comm->Comm, &size));
if (rank == root) {
- OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype);
+ OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype, 0);
if (sendcount > recvcount*size)
rb_raise(rb_eArgError, "recvbuf is too small");
sendcount = recvcount;
}
check_error(MPI_Scatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm->Comm));
@@ -433,21 +497,21 @@
}
static VALUE
rb_comm_sendrecv(VALUE self, VALUE rb_sendbuf, VALUE rb_dest, VALUE rb_sendtag, VALUE rb_recvbuf, VALUE rb_source, VALUE rb_recvtag)
{
void *sendbuf, *recvbuf;
- int sendcount, recvcount;
+ int sendcount=0, recvcount=0;
MPI_Datatype sendtype, recvtype;
int dest, source;
int sendtag, recvtag;
int size;
struct _Comm *comm;
MPI_Status *status;
- OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype);
+ OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype, 0);
Data_Get_Struct(self, struct _Comm, comm);
check_error(MPI_Comm_size(comm->Comm, &size));
- OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype);
+ OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype, 0);
dest = NUM2INT(rb_dest);
source = NUM2INT(rb_source);
sendtag = NUM2INT(rb_sendtag);
recvtag = NUM2INT(rb_recvtag);
status = ALLOC(MPI_Status);
@@ -456,18 +520,18 @@
}
static VALUE
rb_comm_alltoall(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf)
{
void *sendbuf, *recvbuf;
- int sendcount, recvcount;
+ int sendcount=0, recvcount=0;
MPI_Datatype sendtype, recvtype;
int size;
struct _Comm *comm;
- OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype);
+ OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype, 0);
Data_Get_Struct(self, struct _Comm, comm);
check_error(MPI_Comm_size(comm->Comm, &size));
- OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype);
+ OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype, 0);
if (recvcount < sendcount)
rb_raise(rb_eArgError, "recvbuf is too small");
recvcount = recvcount/size;
sendcount = sendcount/size;
check_error(MPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm->Comm));
@@ -475,22 +539,22 @@
}
static VALUE
rb_comm_reduce(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf, VALUE rb_op, VALUE rb_root)
{
void *sendbuf, *recvbuf = NULL;
- int sendcount, recvcount = 0;
- MPI_Datatype sendtype, recvtype = NULL;
+ int sendcount=0, recvcount = 0;
+ MPI_Datatype sendtype, recvtype = 0;
int root, rank, size;
struct _Comm *comm;
struct _Op *op;
- OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype);
+ OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype, 0);
root = NUM2INT(rb_root);
Data_Get_Struct(self, struct _Comm, comm);
check_error(MPI_Comm_rank(comm->Comm, &rank));
check_error(MPI_Comm_size(comm->Comm, &size));
if (rank == root) {
- OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype);
+ OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype, 0);
if (recvcount != sendcount)
rb_raise(rb_eArgError, "sendbuf and recvbuf has the same length");
if (recvtype != sendtype)
rb_raise(rb_eArgError, "sendbuf and recvbuf has the same type");
}
@@ -500,20 +564,20 @@
}
static VALUE
rb_comm_allreduce(VALUE self, VALUE rb_sendbuf, VALUE rb_recvbuf, VALUE rb_op)
{
void *sendbuf, *recvbuf;
- int sendcount, recvcount;
+ int sendcount=0, recvcount=0;
MPI_Datatype sendtype, recvtype;
int rank, size;
struct _Comm *comm;
struct _Op *op;
- OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype);
+ OBJ2C(rb_sendbuf, sendcount, sendbuf, sendtype, 0);
Data_Get_Struct(self, struct _Comm, comm);
check_error(MPI_Comm_rank(comm->Comm, &rank));
check_error(MPI_Comm_size(comm->Comm, &size));
- OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype);
+ OBJ2C(rb_recvbuf, recvcount, recvbuf, recvtype, 0);
if (recvcount != sendcount)
rb_raise(rb_eArgError, "sendbuf and recvbuf has the same length");
if (recvtype != sendtype)
rb_raise(rb_eArgError, "sendbuf and recvbuf has the same type");
Data_Get_Struct(rb_op, struct _Op, op);
@@ -527,22 +591,23 @@
struct _Errhandler *errhandler;
VALUE rb_errhandler;
Data_Get_Struct(self, struct _Comm, comm);
rb_errhandler = Data_Make_Struct(cErrhandler, struct _Errhandler, NULL, Errhandler_free, errhandler);
- MPI_Comm_get_errhandler(comm->Comm, &(errhandler->Errhandler));
+ errhandler->free = false;
+ check_error(MPI_Comm_get_errhandler(comm->Comm, &(errhandler->Errhandler)));
return rb_errhandler;
}
static VALUE
rb_comm_set_Errhandler(VALUE self, VALUE rb_errhandler)
{
struct _Comm *comm;
struct _Errhandler *errhandler;
Data_Get_Struct(self, struct _Comm, comm);
Data_Get_Struct(rb_errhandler, struct _Errhandler, errhandler);
- MPI_Comm_set_errhandler(comm->Comm, errhandler->Errhandler);
+ check_error(MPI_Comm_set_errhandler(comm->Comm, errhandler->Errhandler));
return self;
}
static VALUE
rb_comm_barrier(VALUE self)
{
@@ -605,10 +670,11 @@
// MPI
mMPI = rb_define_module("MPI");
rb_define_module_function(mMPI, "Init", rb_m_init, -1);
rb_define_module_function(mMPI, "Finalize", rb_m_finalize, -1);
+ rb_define_module_function(mMPI, "Abort", rb_m_abort, 2);
rb_define_const(mMPI, "VERSION", INT2NUM(MPI_VERSION));
rb_define_const(mMPI, "SUBVERSION", INT2NUM(MPI_SUBVERSION));
rb_define_const(mMPI, "SUCCESS", INT2NUM(MPI_SUCCESS));
rb_define_const(mMPI, "PROC_NULL", INT2NUM(MPI_PROC_NULL));
@@ -618,11 +684,11 @@
rb_define_private_method(cComm, "initialize", rb_comm_initialize, 0);
rb_define_method(cComm, "rank", rb_comm_rank, 0);
rb_define_method(cComm, "size", rb_comm_size, 0);
rb_define_method(cComm, "Send", rb_comm_send, 3);
rb_define_method(cComm, "Isend", rb_comm_isend, 3);
- rb_define_method(cComm, "Recv", rb_comm_recv, 3);
- rb_define_method(cComm, "Irecv", rb_comm_irecv, 3);
+ rb_define_method(cComm, "Recv", rb_comm_recv, -1);
+ rb_define_method(cComm, "Irecv", rb_comm_irecv, -1);
rb_define_method(cComm, "Gather", rb_comm_gather, 3);
rb_define_method(cComm, "Allgather", rb_comm_allgather, 2);
rb_define_method(cComm, "Bcast", rb_comm_bcast, 2);
rb_define_method(cComm, "Scatter", rb_comm_scatter, 3);
rb_define_method(cComm, "Sendrecv", rb_comm_sendrecv, 6);