001/*
002 * $Id$
003 */
004
005package edu.jas.kern;
006
007
008import java.io.IOException;
009import java.util.ArrayList;
010import java.util.Arrays;
011
012import mpi.Comm;
013import mpi.Intracomm;
014import mpi.MPI;
015import mpi.MPIException;
016import mpi.Status;
017
018import org.apache.logging.log4j.Logger;
019import org.apache.logging.log4j.LogManager;
020
021
/**
 * MPI engine, provides global MPI service. <b>Note:</b> could eventually be
 * done directly with MPI, but provides logging. <b>Usage:</b> To obtain a
 * reference to the MPI service communicator use
 * <code>MPIEngine.getCommunicator()</code>. Once an engine has been created it
 * must be shut down to exit JAS with <code>MPIEngine.terminate()</code>.
 * @author Heinz Kredel
 */
030
031public final class MPIEngine {
032
033
034    private static final Logger logger = LogManager.getLogger(MPIEngine.class);
035
036
037    private static final boolean debug = logger.isDebugEnabled();
038
039
040    /**
041     * Command line arguments. Required for MPI runtime system.
042     */
043    protected static String[] cmdline;
044
045
046    /**
047     * Hostnames of MPI partners.
048     */
049    public static ArrayList<String> hostNames = new ArrayList<String>();
050
051
052    /**
053     * Flag for MPI usage. <b>Note:</b> Only introduced because Google app
054     * engine does not support MPI.
055     */
056    public static boolean NO_MPI = false;
057
058
059    /**
060     * Number of processors.
061     */
062    public static final int N_CPUS = Runtime.getRuntime().availableProcessors();
063
064
065    /*
066     * Core number of threads.
067     * N_CPUS x 1.5, x 2, x 2.5, min 3, ?.
068     */
069    public static final int N_THREADS = (N_CPUS < 3 ? 3 : N_CPUS + N_CPUS / 2);
070
071
072    /**
073     * MPI communicator engine.
074     */
075    static Intracomm mpiComm;
076
077
078    /**
079     * MPI engine base tag number.
080     */
081    public static final int TAG = 11;
082
083
084    /**
085     * Hostname suffix.
086     */
087    public static final String hostSuf = "-ib";
088
089
090    // /*
091    //  * Send locks per tag.
092    //  */
093    // private static SortedMap<Integer,Object> sendLocks = new TreeMap<Integer,Object>();
094
095
096    // /*
097    //  * receive locks per tag.
098    //  */
099    // private static SortedMap<Integer,Object> recvLocks = new TreeMap<Integer,Object>();
100
101
102    /**
103     * No public constructor.
104     */
105    private MPIEngine() {
106    }
107
108
109    /**
110     * Set the commandline.
111     * @param args the command line to use for the MPI runtime system.
112     */
113    public static synchronized void setCommandLine(String[] args) {
114        cmdline = args;
115    }
116
117
118    /**
119     * Test if a pool is running.
120     * @return true if a thread pool has been started or is running, else false.
121     */
122    public static synchronized boolean isRunning() {
123        if (mpiComm == null) {
124            return false;
125        }
126        return true;
127    }
128
129
130    /**
131     * Get the MPI communicator.
132     * @return a Communicator constructed for cmdline.
133     */
134    public static synchronized Comm getCommunicator() throws IOException, MPIException {
135        if (cmdline == null) {
136            throw new IllegalArgumentException("command line not set");
137        }
138        return getCommunicator(cmdline);
139    }
140
141
142    /**
143     * Get the MPI communicator.
144     * @param args the command line to use for the MPI runtime system.
145     * @return a Communicator.
146     */
147    public static synchronized Comm getCommunicator(String[] args) throws IOException, MPIException {
148        if (NO_MPI) {
149            return null;
150        }
151        if (mpiComm == null) {
152            //String[] args = new String[] { }; //"-np " + N_THREADS };
153            if (args == null) {
154                throw new IllegalArgumentException("command line is null");
155            }
156            cmdline = args;
157            args = MPI.Init(args);
158            //int tl = MPI.Init_thread(args,MPI.THREAD_MULTIPLE);
159            logger.info("MPI initialized on " + MPI.Get_processor_name());
160            //logger.info("thread level MPI.THREAD_MULTIPLE: " + MPI.THREAD_MULTIPLE 
161            //            + ", provided: " + tl);
162            if (debug) {
163                logger.debug("remaining args: " + Arrays.toString(args));
164            }
165            mpiComm = MPI.COMM_WORLD;
166            int size = mpiComm.Size();
167            int rank = mpiComm.Rank();
168            logger.info("MPI size = " + size + ", rank = " + rank);
169            // maintain list of hostnames of partners
170            hostNames.ensureCapacity(size);
171            for (int i = 0; i < size; i++) {
172                hostNames.add("");
173            }
174            String myhost = MPI.Get_processor_name();
175            if ( myhost.matches("\\An\\d*") ) { // bwGRiD node names n010207
176                myhost += hostSuf;
177            }
178            if ( myhost.matches("kredel.*") ) { 
179                myhost = "localhost";
180            }
181            hostNames.set(rank, myhost);
182            if (rank == 0) {
183                String[] va = new String[1];
184                va[0] = hostNames.get(0);
185                mpiComm.Bcast(va, 0, va.length, MPI.OBJECT, 0);
186                for (int i = 1; i < size; i++) {
187                    Status stat = mpiComm.Recv(va, 0, va.length, MPI.OBJECT, i, TAG);
188                    if (stat == null) {
189                        throw new IOException("no Status received");
190                        //throw new MPIException("no Status received");
191                    }
192                    int cnt = stat.Get_count(MPI.OBJECT);
193                    if (cnt == 0) {
194                        throw new IOException("no Object received");
195                        //throw new MPIException("no object received");
196                    }
197                    String v = va[0];
198                    hostNames.set(i, v);
199                }
200                logger.info("MPI partner host names = " + hostNames);
201            } else {
202                String[] va = new String[1];
203                mpiComm.Bcast(va, 0, va.length, MPI.OBJECT, 0);
204                hostNames.set(0, va[0]);
205                va[0] = hostNames.get(rank);
206                mpiComm.Send(va, 0, va.length, MPI.OBJECT, 0, TAG);
207            }
208        }
209        return mpiComm;
210    }
211
212
213    /**
214     * Stop execution.
215     */
216    public static synchronized void terminate() {
217        if (mpiComm == null) {
218            return;
219        }
220        try {
221            logger.info("terminating MPI on rank = " + mpiComm.Rank());
222            mpiComm = null;
223            MPI.Finalize();
224        } catch (MPIException e) {
225            e.printStackTrace();
226        }
227    }
228
229
230    /**
231     * Set no MPI usage.
232     */
233    public static synchronized void setNoMPI() {
234        NO_MPI = true;
235        terminate();
236    }
237
238
239    /**
240     * Set MPI usage.
241     */
242    public static synchronized void setMPI() {
243        NO_MPI = false;
244    }
245
246
247    // /*
248    //  * Get send lock per tag.
249    //  * @param tag message tag.
250    //  * @return a lock for sends.
251    //  */
252    // public static synchronized Object getSendLock(int tag) {
253    //     tag = 11; // one global lock
254    //     Object lock = sendLocks.get(tag);
255    //     if ( lock == null ) {
256    //         lock = new Object();
257    //         sendLocks.put(tag,lock);
258    //     }
259    //     return lock;
260    // }
261
262
263    // /*
264    //  * Get receive lock per tag.
265    //  * @param tag message tag.
266    //  * @return a lock for receives.
267    //  */
268    // public static synchronized Object getRecvLock(int tag) {
269    //     Object lock = recvLocks.get(tag);
270    //     if ( lock == null ) {
271    //         lock = new Object();
272    //         recvLocks.put(tag,lock);
273    //     }
274    //     return lock;
275    // }
276
277
278    // /*
279    //  * Wait for termination of a mpi Request.
280    //  * @param req a Request.
281    //  * @return a Status after termination of req.Wait().
282    //  */
283    // public static Status waitRequest(final Request req) throws MPIException {
284    //     if ( req == null || req.Is_null() ) {
285    //         throw new IllegalArgumentException("null request");
286    //     }
287    //     int delay = 50;
288    //     int delcnt = 0;
289    //     Status stat = null;
290    //     while (true) {
291    //         synchronized (MPIEngine.class) { // global static lock
292    //             stat = req.Test(); 
293    //             logger.info("Request: " + req + ", Status: " + stat);
294    //             if (stat != null) {
295    //                 logger.info("Status: index = " + stat.index + ", source = " + stat.source
296    //                                   + ", tag = " + stat.tag);
297    //                 if (!stat.Test_cancelled()) {
298    //                     logger.info("enter req.Wait(): " + Thread.currentThread().toString());
299    //                     return req.Wait(); // should terminate immediately
300    //                 }
301    //             }
302    //         }
303    //         try {
304    //             Thread.currentThread().sleep(delay); // varied a bit
305    //         } catch (InterruptedException e) {
306    //             logger.info("sleep interrupted");
307    //             e.printStackTrace();
308    //         }
309    //         delcnt++; 
310    //         if ( delcnt % 7 != 0 ) {
311    //             delay++;
312    //             logger.info("delay(" + delay + "): " + Thread.currentThread().toString());
313    //         } 
314    //     }
315    // }
316
317}