Main Page   Namespace List   Class Hierarchy   Compound List   File List   Namespace Members   Compound Members   File Members   Related Pages   Examples  

ContainerComm.h

Go to the documentation of this file.
00001 #ifndef CONTAINERCOMM_H
00002 #define CONTAINERCOMM_H
00003 
00004 #include "SundanceDefs.h"
00005 #include "TSFArray.h"
00006 #include "MPIComm.h"
00007 #include "MPITraits.h"
00008 
00009 namespace Sundance
00010 {
00011   /** \ingroup MPI
00012    * MPI communication of templated containers
00013    */
00014 
00015   template <class T> class ContainerComm
00016     {
00017     public:
00018       /** broadcast a single object */
00019       static void bcast(T& x, int src, const MPIComm& comm);
00020 
00021       /** send a single object */
00022       static void send(const T& x, int tag, int dest,
00023                        const MPIComm& comm);
00024 
00025       /** recv a single object */
00026       static void recv(T& x, int tag, int src,
00027                        const MPIComm& comm);
00028 
00029       /** bcast an array of objects */
00030       static void bcast(TSFArray<T>& x, int src, const MPIComm& comm);
00031 
00032       /** send an array of objects */
00033       static void send(const TSFArray<T>& x, int tag, int dest,
00034                        const MPIComm& comm);
00035 
00036       /** recv an array of objects */
00037       static void recv(TSFArray<T>& x, int tag, int src,
00038                        const MPIComm& comm);
00039 
00040       /** bcast an array of arrays  */
00041       static void bcast(TSFArray<TSFArray<T> >& x,
00042                         int src, const MPIComm& comm);
00043 
00044       /** AllGather: each process sends a single object to all other procs */
00045       static void allGather(const T& outgoing,
00046                             TSFArray<T>& incoming,
00047                             const MPIComm& comm);
00048 
00049       /** All-to-all scatter/gather, each proc sends an object to every proc */
00050       static void allToAll(const TSFArray<T>& outgoing,
00051                            TSFArray<TSFArray<T> >& incoming,
00052                            const MPIComm& comm);
00053 
00054       /** All-to-all scatter/gather, each proc sends an array to every proc */
00055       static void allToAll(const TSFArray<TSFArray<T> >& outgoing,
00056                            TSFArray<TSFArray<T> >& incoming,
00057                            const MPIComm& comm);
00058 
00059       /** sum local values from all processors with rank < myRank */
00060       static void accumulate(const T& localValue, TSFArray<T>& sums,
00061                              const MPIComm& comm);
00062 
00063     private:
00064       /** build a 1D array and an offset list from a 2D array */
00065       static void getBigTSFArray(const TSFArray<TSFArray<T> >& x,
00066                                  TSFArray<T>& bigTSFArray,
00067                                  TSFArray<int>& offsets);
00068 
00069       /** reassemble a 2D array from a 1D array and an offset table */
00070       static void getSmallTSFArrays(const TSFArray<T>& bigTSFArray,
00071                                     const TSFArray<int>& offsets,
00072                                     TSFArray<TSFArray<T> >& x);
00073 
00074 
00075     };
00076 
00077   /** \ingroup MPI
00078    * Specialiaztion of ContainerComm<T> to string
00079    */
00080   template <> class ContainerComm<string>
00081     {
00082     public:
00083       static void bcast(string& x, int src, const MPIComm& comm);
00084 
00085       static void send(const string& x, int tag, int dest,
00086                        const MPIComm& comm);
00087 
00088       static void recv(string& x, int tag, int src,
00089                        const MPIComm& comm);
00090 
00091       /** bcast an array of objects */
00092       static void bcast(TSFArray<string>& x, int src, const MPIComm& comm);
00093 
00094       /** send an array of objects */
00095       static void send(const TSFArray<string>& x, int tag, int dest,
00096                        const MPIComm& comm);
00097 
00098       /** recv an array of objects */
00099       static void recv(TSFArray<string>& x, int tag, int src,
00100                        const MPIComm& comm);
00101 
00102       /** bcast an array of arrays  */
00103       static void bcast(TSFArray<TSFArray<string> >& x,
00104                         int src, const MPIComm& comm);
00105 
00106       /** AllGather: each process sends a single object to all other procs */
00107       static void allGather(const string& outgoing,
00108                             TSFArray<string>& incoming,
00109                             const MPIComm& comm);
00110 
00111     private:
00112       /** get a single big array of characters from an array of strings */
00113       static void getBigTSFArray(const TSFArray<string>& x,
00114                                  TSFArray<char>& bigTSFArray,
00115                                  TSFArray<int>& offsets);
00116 
00117       /** recover an array of strings from a single big array and
00118        * and offset table */
00119       static void getStrings(const TSFArray<char>& bigTSFArray,
00120                              const TSFArray<int>& offsets,
00121                              TSFArray<string>& x);
00122     };
00123 
00124 
00125   /* --------- generic functions for primitives ------------------- */
00126 
00127   template <class T> inline void ContainerComm<T>::bcast(T& x, int src,
00128                                                          const MPIComm& comm)
00129     {
00130       comm.bcast((void*)&x, 1, MPITraits<T>::type(), src);
00131     }
00132 
00133   template <class T> inline void ContainerComm<T>::send(const T& x, int tag, int dest,
00134                                                         const MPIComm& comm)
00135     {
00136       comm.send((void*)&x, 1, MPITraits<T>::type(), tag, dest);
00137     }
00138 
00139   template <class T> inline void ContainerComm<T>::recv(T& x, int tag, int src,
00140                                                         const MPIComm& comm)
00141     {
00142       comm.recv((void*)&x, 1, MPITraits<T>::type(), tag, src);
00143     }
00144 
00145   /* ----------- generic functions for arrays of primitives ----------- */
00146 
00147   template <class T>
00148     inline void ContainerComm<T>::bcast(TSFArray<T>& x, int src, const MPIComm& comm)
00149     {
00150       try
00151         {
00152           int len = x.length();
00153           ContainerComm<int>::bcast(len, src, comm);
00154 
00155           if (comm.getRank() != src)
00156             {
00157               x.resize(len);
00158             }
00159           if (len==0) return;
00160 
00161           /* then broadcast the contents */
00162           comm.bcast((void*) &(x[0]), (int) len,
00163                      MPITraits<T>::type(), src);
00164         }
00165       catch(exception& e)
00166         {
00167           TSFError::trace(e, "in TSFArrayComm::bcast(TSFArray<T>)");
00168         }
00169     }
00170 
00171 
00172   template <class T>
00173     inline void ContainerComm<T>::send(const TSFArray<T>& x, int tag,
00174                                        int dest, const MPIComm& comm)
00175     {
00176       try
00177         {
00178           /* first send the length */
00179           int len = x.length();
00180           ContainerComm<int>::send(len, tag, dest, comm);
00181 
00182           /* now send the data */
00183           comm.send((void*) &(x[0]), len,
00184                     MPITraits<T>::type(), tag, dest);
00185         }
00186       catch(exception& e)
00187         {
00188           TSFError::trace(e, "in TSFArrayComm::send(int)");
00189         }
00190     }
00191 
00192   template <class T>
00193     inline void ContainerComm<T>::recv(TSFArray<T>& x, int tag, int src, const MPIComm& comm)
00194     {
00195       try
00196         {
00197           /* first recv the length */
00198           int len;
00199           ContainerComm<int>::recv(len, tag, src, comm);
00200 
00201           /* set the size of the array to the recvd length */
00202           x.resize(len);
00203 
00204           /* recv the data */
00205           comm.recv((void*)&(x[0]), len, MPITraits<T>::type(), tag, src);
00206         }
00207       catch(exception& e)
00208         {
00209           TSFError::trace(e, "in TSFArrayComm::recv(int)");
00210         }
00211     }
00212 
00213   /* ---------- generic function for arrays of arrays ----------- */
00214 
00215   template <class T>
00216     inline void ContainerComm<T>::bcast(TSFArray<TSFArray<T> >& x, int src, const MPIComm& comm)
00217     {
00218       try
00219         {
00220           TSFArray<T> bigTSFArray;
00221           TSFArray<int> offsets;
00222 
00223           if (src==comm.getRank())
00224             {
00225               getBigTSFArray(x, bigTSFArray, offsets);
00226             }
00227 
00228           bcast(bigTSFArray, src, comm);
00229           ContainerComm<int>::bcast(offsets, src, comm);
00230 
00231           if (src != comm.getRank())
00232             {
00233               getSmallTSFArrays(bigTSFArray, offsets, x);
00234             }
00235         }
00236       catch(exception& e)
00237         {
00238           TSFError::trace(e, "in TSFArrayComm::bcast(TSFArray<T>)");
00239         }
00240     }
00241 
00242   /* ---------- generic gather and scatter ------------------------ */
00243 
00244   template <class T> inline
00245     void ContainerComm<T>::allToAll(const TSFArray<T>& outgoing,
00246                                     TSFArray<TSFArray<T> >& incoming,
00247                                     const MPIComm& comm)
00248     {
00249       try
00250         {
00251           int numProcs = comm.getNProc();
00252 
00253           // catch degenerate case
00254           if (numProcs==1)
00255             {
00256               incoming.resize(1);
00257               incoming[0] = outgoing;
00258               return;
00259             }
00260 
00261           T* sendBuf = new T[numProcs * outgoing.length()];
00262           if (sendBuf==0)
00263             TSFError::raise("Comm::allToAll failed to allocate sendBuf");
00264           T* recvBuf = new T[numProcs * outgoing.length()];
00265           if (recvBuf==0)
00266             TSFError::raise("Comm::allToAll failed to allocate recvBuf");
00267 
00268           int i;
00269           for (i=0; i<numProcs; i++)
00270             {
00271               for (int j=0; j<outgoing.length(); j++)
00272                 {
00273                   sendBuf[i*outgoing.length() + j] = outgoing[j];
00274                 }
00275             }
00276 
00277           comm.allToAll(sendBuf, outgoing.length(), MPITraits<T>::type(),
00278                         recvBuf, outgoing.length(), MPITraits<T>::type());
00279 
00280           incoming.resize(numProcs);
00281 
00282           for (i=0; i<numProcs; i++)
00283             {
00284               incoming[i].resize(outgoing.length());
00285               for (int j=0; j<outgoing.length(); j++)
00286                 {
00287                   incoming[i][j] = recvBuf[i*outgoing.length() + j];
00288                 }
00289             }
00290 
00291           delete [] sendBuf;
00292           delete [] recvBuf;
00293         }
00294       catch(exception& e)
00295         {
00296           TSFError::trace(e, "in Comm::allToAll(const TSFArray<int>& outgoing, ...)");
00297         }
00298     }
00299 
00300   template <class T> inline
00301     void ContainerComm<T>::allToAll(const TSFArray<TSFArray<T> >& outgoing,
00302                                     TSFArray<TSFArray<T> >& incoming, const MPIComm& comm)
00303     {
00304       try
00305         {
00306           int numProcs = comm.getNProc();
00307 
00308           // catch degenerate case
00309           if (numProcs==1)
00310             {
00311               incoming = outgoing;
00312               return;
00313             }
00314 
00315           int* sendMesgLength = new int[numProcs];
00316           if (sendMesgLength==0)
00317             TSFError::raise("failed to allocate sendMesgLength");
00318           int* recvMesgLength = new int[numProcs];
00319           if (recvMesgLength==0)
00320             TSFError::raise("failed to allocate recvMesgLength");
00321 
00322           int p = 0;
00323           for (p=0; p<numProcs; p++)
00324             {
00325               sendMesgLength[p] = outgoing[p].length();
00326             }
00327 
00328           comm.allToAll(sendMesgLength, 1, MPIComm::INT,
00329                         recvMesgLength, 1, MPIComm::INT);
00330 
00331 
00332           int totalSendLength = 0;
00333           int totalRecvLength = 0;
00334           for (p=0; p<numProcs; p++)
00335             {
00336               totalSendLength += sendMesgLength[p];
00337               totalRecvLength += recvMesgLength[p];
00338             }
00339 
00340           T* sendBuf = new T[totalSendLength];
00341           if (sendBuf==0)
00342             TSFError::raise("failed to allocate sendBuf");
00343           T* recvBuf = new T[totalRecvLength];
00344           if (recvBuf==0)
00345             TSFError::raise("failed to allocate recvBuf");
00346 
00347           int* sendDisp = new int[numProcs];
00348           if (sendDisp==0)
00349             TSFError::raise("failed to allocate sendDisp");
00350           int* recvDisp = new int[numProcs];
00351           if (recvDisp==0)
00352             TSFError::raise("failed to allocate recvDisp");
00353 
00354           int count = 0;
00355           sendDisp[0] = 0;
00356           recvDisp[0] = 0;
00357 
00358           for (p=0; p<numProcs; p++)
00359             {
00360               for (int i=0; i<outgoing[p].length(); i++)
00361                 {
00362                   sendBuf[count] = outgoing[p][i];
00363                   count++;
00364                 }
00365               if (p>0)
00366                 {
00367                   sendDisp[p] = sendDisp[p-1] + sendMesgLength[p-1];
00368                   recvDisp[p] = recvDisp[p-1] + recvMesgLength[p-1];
00369                 }
00370             }
00371 
00372           comm.allToAllv(sendBuf, sendMesgLength,
00373                          sendDisp, MPITraits<T>::type(),
00374                          recvBuf, recvMesgLength,
00375                          recvDisp, MPITraits<T>::type());
00376 
00377           incoming.resize(numProcs);
00378           for (p=0; p<numProcs; p++)
00379             {
00380               incoming[p].resize(recvMesgLength[p]);
00381               for (int i=0; i<recvMesgLength[p]; i++)
00382                 {
00383                   incoming[p][i] = recvBuf[recvDisp[p] + i];
00384                 }
00385             }
00386 
00387           delete [] sendBuf;
00388           delete [] sendMesgLength;
00389           delete [] sendDisp;
00390           delete [] recvBuf;
00391           delete [] recvMesgLength;
00392           delete [] recvDisp;
00393         }
00394       catch(exception& e)
00395         {
00396           TSFError::trace(e, "in TSFArrayComm::allToAll(const TSFArray<TSFArray<T> >& outgoing...)");
00397         }
00398     }
00399 
00400   template <class T> inline
00401     void ContainerComm<T>::allGather(const T& outgoing, TSFArray<T>& incoming,
00402                                      const MPIComm& comm)
00403     {
00404       int nProc = comm.getNProc();
00405       incoming.resize(nProc);
00406 
00407       if (nProc==1)
00408         {
00409           incoming[0] = outgoing;
00410         }
00411       else
00412         {
00413           comm.allGather((void*) &outgoing, 1, MPITraits<T>::type(),
00414                          (void*) &(incoming[0]), 1, MPITraits<T>::type());
00415         }
00416     }
00417 
00418   template <class T> inline
00419     void ContainerComm<T>::accumulate(const T& localValue, TSFArray<T>& sums,
00420                                       const MPIComm& comm)
00421     {
00422       TSFArray<T> contributions;
00423       allGather(localValue, contributions, comm);
00424       sums.resize(comm.getNProc());
00425       sums[0] = 0;
00426 
00427       for (int i=0; i<comm.getNProc()-1; i++)
00428         {
00429           sums[i+1] = sums[i] + contributions[i];
00430         }
00431     }
00432 
00433 
00434 
00435 
00436   template <class T> inline
00437     void ContainerComm<T>::getBigTSFArray(const TSFArray<TSFArray<T> >& x, TSFArray<T>& bigTSFArray,
00438                                           TSFArray<int>& offsets)
00439     {
00440       offsets.resize(x.length()+1);
00441       int totalLength = 0;
00442 
00443       for (int i=0; i<x.length(); i++)
00444         {
00445           offsets[i] = totalLength;
00446           totalLength += x[i].length();
00447         }
00448       offsets[x.length()] = totalLength;
00449 
00450       bigTSFArray.resize(totalLength);
00451 
00452       for (int i=0; i<x.length(); i++)
00453         {
00454           for (int j=0; j<x[i].length(); j++)
00455             {
00456               bigTSFArray[offsets[i]+j] = x[i][j];
00457             }
00458         }
00459     }
00460 
00461   template <class T> inline
00462     void ContainerComm<T>::getSmallTSFArrays(const TSFArray<T>& bigTSFArray,
00463                                              const TSFArray<int>& offsets,
00464                                              TSFArray<TSFArray<T> >& x)
00465     {
00466       x.resize(offsets.length()-1);
00467       for (int i=0; i<x.length(); i++)
00468         {
00469           x[i].resize(offsets[i+1]-offsets[i]);
00470           for (int j=0; j<x[i].length(); j++)
00471             {
00472               x[i][j] = bigTSFArray[offsets[i] + j];
00473             }
00474         }
00475     }
00476 
00477 
00478   /* --------------- string specializations --------------------- */
00479 
00480   inline void ContainerComm<string>::bcast(string& x,
00481                                            int src, const MPIComm& comm)
00482     {
00483       int len = x.length();
00484       ContainerComm<int>::bcast(len, src, comm);
00485 
00486       x.resize(len);
00487       comm.bcast((void*)&(x[0]), len, MPITraits<char>::type(), src);
00488     }
00489 
00490   inline void ContainerComm<string>::send(const string& x, int tag, int dest,
00491                                           const MPIComm& comm)
00492     {
00493       int len = x.length();
00494       ContainerComm<int>::send(len, tag, dest, comm);
00495 
00496       void* start = (void*) x.c_str();
00497       comm.send(start, len, MPITraits<char>::type(), tag, dest);
00498     }
00499 
00500   inline void ContainerComm<string>::recv(string& x, int tag, int src,
00501                                           const MPIComm& comm)
00502     {
00503       int len;
00504       ContainerComm<int>::recv(len, tag, src, comm);
00505 
00506       x.resize(len);
00507 
00508       void* start = (void*) x.c_str();
00509       comm.recv(start, len, MPITraits<char>::type(), tag, src);
00510     }
00511 
00512   inline void ContainerComm<string>::bcast(TSFArray<string>& x, int src,
00513                                            const MPIComm& comm)
00514     {
00515       try
00516         {
00517           /* begin by packing all the data into a big char array. This will
00518            * take a little time, but will be cheaper than multiple MPI calls */
00519           TSFArray<char> bigTSFArray;
00520           TSFArray<int> offsets;
00521           if (comm.getRank()==src)
00522             {
00523               getBigTSFArray(x, bigTSFArray, offsets);
00524             }
00525 
00526           /* now broadcast the big array and the offsets */
00527           ContainerComm<char>::bcast(bigTSFArray, src, comm);
00528           ContainerComm<int>::bcast(offsets, src, comm);
00529 
00530           /* finally, reassemble the array of strings */
00531           if (comm.getRank() != src)
00532             {
00533               getStrings(bigTSFArray, offsets, x);
00534             }
00535 
00536         }
00537       catch(exception& e)
00538         {
00539           TSFError::trace(e, "in bcast(TSFArray<string>)");
00540         }
00541     }
00542 
00543   inline void ContainerComm<string>::bcast(TSFArray<TSFArray<string> >& x,
00544                                            int src, const MPIComm& comm)
00545     {
00546       try
00547         {
00548           int len = x.length();
00549           ContainerComm<int>::bcast(len, src, comm);
00550 
00551           x.resize(len);
00552           for (int i=0; i<len; i++)
00553             {
00554               ContainerComm<string>::bcast(x[i], src, comm);
00555             }
00556         }
00557       catch(exception& e)
00558         {
00559           TSFError::trace(e, "in ContainerComm<string>::bcast(TSFArray<TSFArray<string>>)");
00560         }
00561     }
00562 
00563 
00564   inline void ContainerComm<string>::send(const TSFArray<string>& x, int tag,
00565                                           int dest, const MPIComm& comm)
00566     {
00567       try
00568         {
00569           TSFArray<char> bigTSFArray;
00570           TSFArray<int> offsets;
00571 
00572           getBigTSFArray(x, bigTSFArray, offsets);
00573 
00574           ContainerComm<int>::send(offsets, tag, dest, comm);
00575           ContainerComm<char>::send(bigTSFArray, tag, dest, comm);
00576         }
00577       catch(exception& e)
00578         {
00579           TSFError::trace(e, "in send(TSFArray<string>)");
00580         }
00581     }
00582 
00583   inline void ContainerComm<string>::recv(TSFArray<string>& x,
00584                                           int tag, int src, const MPIComm& comm)
00585     {
00586       try
00587         {
00588           TSFArray<char> bigTSFArray;
00589           TSFArray<int> offsets;
00590 
00591           ContainerComm<int>::recv(offsets, tag, src, comm);
00592           ContainerComm<char>::recv(bigTSFArray, tag, src, comm);
00593 
00594           getStrings(bigTSFArray, offsets, x);
00595         }
00596       catch(exception& e)
00597         {
00598           TSFError::trace(e, "in recv(TSFArray<string>)");
00599         }
00600     }
00601 
00602   inline void ContainerComm<string>::allGather(const string& outgoing,
00603                                                TSFArray<string>& incoming,
00604                                                const MPIComm& comm)
00605     {
00606       int nProc = comm.getNProc();
00607 
00608       int sendCount = outgoing.length();
00609 
00610       incoming.resize(nProc);
00611 
00612       int* recvCounts = new int[nProc];
00613       int* recvDisplacements = new int[nProc];
00614 
00615       /* share lengths with all procs */
00616       comm.allGather((void*) &sendCount, 1, MPIComm::INT,
00617                      (void*) recvCounts, 1, MPIComm::INT);
00618 
00619 
00620       int recvSize = 0;
00621       recvDisplacements[0] = 0;
00622       for (int i=0; i<nProc; i++)
00623         {
00624           recvSize += recvCounts[i];
00625           if (i < nProc-1)
00626             {
00627               recvDisplacements[i+1] = recvDisplacements[i]+recvCounts[i];
00628             }
00629         }
00630 
00631       char* recvBuf = new char[recvSize];
00632 
00633       comm.allGatherv((void*) outgoing.c_str(), sendCount, MPIComm::CHAR,
00634                       recvBuf, recvCounts, recvDisplacements, MPIComm::CHAR);
00635 
00636       for (int j=0; j<nProc; j++)
00637         {
00638           char* start = recvBuf + recvDisplacements[j];
00639           char* tmp = new char[recvCounts[j]+1];
00640           memcpy(tmp, start, recvCounts[j]);
00641           tmp[recvCounts[j]] = '\0';
00642           incoming[j] = string(tmp);
00643           delete [] tmp;
00644         }
00645 
00646       delete [] recvCounts;
00647       delete [] recvDisplacements;
00648       delete [] recvBuf;
00649     }
00650 
00651 
00652   inline void ContainerComm<string>::getBigTSFArray(const TSFArray<string>& x,
00653                                                     TSFArray<char>& bigTSFArray,
00654                                                     TSFArray<int>& offsets)
00655     {
00656       offsets.resize(x.length()+1);
00657       int totalLength = 0;
00658 
00659       for (int i=0; i<x.length(); i++)
00660         {
00661           offsets[i] = totalLength;
00662           totalLength += x[i].length();
00663         }
00664       offsets[x.length()] = totalLength;
00665 
00666       bigTSFArray.resize(totalLength);
00667 
00668       for (int i=0; i<x.length(); i++)
00669         {
00670           for (unsigned int j=0; j<x[i].length(); j++)
00671             {
00672               bigTSFArray[offsets[i]+j] = x[i][j];
00673             }
00674         }
00675     }
00676 
00677   inline void ContainerComm<string>::getStrings(const TSFArray<char>& bigTSFArray,
00678                                                 const TSFArray<int>& offsets,
00679                                                 TSFArray<string>& x)
00680     {
00681       x.resize(offsets.length()-1);
00682       for (int i=0; i<x.length(); i++)
00683         {
00684           x[i].resize(offsets[i+1]-offsets[i]);
00685           for (unsigned int j=0; j<x[i].length(); j++)
00686             {
00687               x[i][j] = bigTSFArray[offsets[i] + j];
00688             }
00689         }
00690     }
00691 }
00692 
00693 
00694 #endif
00695 
00696 

Contact:
Kevin Long (krlong@ca.sandia.gov)


Documentation generated by