001/*
002 * $Id$
003 */
004
005package edu.jas.util;
006
007
008import java.io.IOException;
009import java.util.AbstractMap;
010import java.util.ArrayList;
011import java.util.Collection;
012import java.util.Iterator;
013import java.util.List;
014import java.util.Set;
015import java.util.SortedMap;
016import java.util.TreeMap;
017
018import mpi.Comm;
019import mpi.MPI;
020import mpi.MPIException;
021import mpi.Status;
022
023import org.apache.logging.log4j.Logger;
024import org.apache.logging.log4j.LogManager;
025
026import edu.jas.kern.MPIEngine;
027
028
029/**
030 * Distributed version of a HashTable using MPI. Implemented with a SortedMap /
031 * TreeMap to keep the sequence order of elements. Implemented using MPI
032 * transport or TCP transport.
033 * @author Heinz Kredel
034 */
035
036public class DistHashTableMPI<K, V> extends AbstractMap<K, V> {
037
038
039    private static final Logger logger = LogManager.getLogger(DistHashTableMPI.class);
040
041
042    private static final boolean debug = logger.isDebugEnabled();
043
044
045    /*
046     * Backing data structure.
047     */
048    protected final SortedMap<K, V> theList;
049
050
051    /*
052     * Thread for receiving pairs.
053     */
054    protected DHTMPIListener<K, V> listener;
055
056
057    /*
058     * MPI communicator.
059     */
060    protected final Comm engine;
061
062
063    /*
064     * Size of Comm.
065     */
066    private final int size;
067
068
069    /*
070     * This rank.
071     */
072    private final int rank;
073
074
075    /**
076     * Message tag for DHT communicaton.
077     */
078    public static final int DHTTAG = MPIEngine.TAG + 1;
079
080
081    /*
082     * TCP/IP object channels.
083     */
084    private final SocketChannel[] soc;
085
086
087    /**
088     * Transport layer.
089     * true: use TCP/IP socket layer, false: use MPI transport layer.
090     */
091    static final boolean useTCP = false;
092
093
094    /**
095     * DistHashTableMPI.
096     */
097    public DistHashTableMPI() throws MPIException, IOException {
098        this(MPIEngine.getCommunicator());
099    }
100
101
102    /**
103     * DistHashTableMPI.
104     * @param args command line for MPI runtime system.
105     */
106    public DistHashTableMPI(String[] args) throws MPIException, IOException {
107        this(MPIEngine.getCommunicator(args));
108    }
109
110
111    /**
112     * DistHashTableMPI.
113     * @param cm MPI communicator to use.
114     */
115    public DistHashTableMPI(Comm cm) throws MPIException, IOException {
116        engine = cm;
117        rank = engine.Rank();
118        size = engine.Size();
119        if (useTCP) { // && soc == null
120            int port = ChannelFactory.DEFAULT_PORT + 11;
121            ChannelFactory cf;
122            if (rank == 0) {
123                cf = new ChannelFactory(port); 
124                cf.init();
125                soc = new SocketChannel[size];
126                soc[0] = null;
127                try {
128                    for (int i = 1; i < size; i++) {
129                        SocketChannel sc = cf.getChannel(); // TODO not correct wrt rank
130                        soc[i] = sc;
131                    }
132                } catch (InterruptedException e) {
133                    throw new IOException(e);
134                }
135                cf.terminate();
136            } else {
137                cf = new ChannelFactory(port-1); // in case of localhost
138                soc = new SocketChannel[1];
139                SocketChannel sc = cf.getChannel(MPIEngine.hostNames.get(0), port);
140                soc[0] = sc;
141                cf.terminate();
142            }
143        } else {
144            soc = null;
145        }
146        theList = new TreeMap<K, V>();
147        //theList = new ConcurrentSkipListMap<K, V>(); // Java 1.6
148        listener = new DHTMPIListener<K, V>(engine, soc, theList);
149        logger.info("constructor: " + rank + "/" + size + ", useTCP: " + useTCP);
150    }
151
152
153    /**
154     * Hash code.
155     */
156    @Override
157    public int hashCode() {
158        return theList.hashCode();
159    }
160
161
162    /**
163     * Equals.
164     */
165    @Override
166    public boolean equals(Object o) {
167        return theList.equals(o);
168    }
169
170
171    /**
172     * Contains key.
173     */
174    @Override
175    public boolean containsKey(Object o) {
176        return theList.containsKey(o);
177    }
178
179
180    /**
181     * Contains value.
182     */
183    @Override
184    public boolean containsValue(Object o) {
185        return theList.containsValue(o);
186    }
187
188
189    /**
190     * Get the values as Collection.
191     */
192    @Override
193    public Collection<V> values() {
194        synchronized (theList) {
195            return new ArrayList<V>(theList.values());
196        }
197    }
198
199
200    /**
201     * Get the keys as set.
202     */
203    @Override
204    public Set<K> keySet() {
205        synchronized (theList) {
206            return theList.keySet();
207        }
208    }
209
210
211    /**
212     * Get the entries as Set.
213     */
214    @Override
215    public Set<Entry<K, V>> entrySet() {
216        synchronized (theList) {
217            return theList.entrySet();
218        }
219    }
220
221
222    /**
223     * Get the internal list, convert from Collection.
224     */
225    public List<V> getValueList() {
226        synchronized (theList) {
227            return new ArrayList<V>(theList.values());
228        }
229    }
230
231
232    /**
233     * Get the internal sorted map. For synchronization purpose in normalform.
234     */
235    public SortedMap<K, V> getList() {
236        return theList;
237    }
238
239
240    /**
241     * Size of the (local) list.
242     */
243    @Override
244    public int size() {
245        synchronized (theList) {
246            return theList.size();
247        }
248    }
249
250
251    /**
252     * Is the List empty?
253     */
254    @Override
255    public boolean isEmpty() {
256        synchronized (theList) {
257            return theList.isEmpty();
258        }
259    }
260
261
262    /**
263     * List key iterator.
264     */
265    public Iterator<K> iterator() {
266        synchronized (theList) {
267            return theList.keySet().iterator();
268        }
269    }
270
271
272    /**
273     * List value iterator.
274     */
275    public Iterator<V> valueIterator() {
276        synchronized (theList) {
277            return theList.values().iterator();
278        }
279    }
280
281
282    /**
283     * Put object to the distributed hash table. Blocks until the key value pair
284     * is send and received from the server.
285     * @param key
286     * @param value
287     */
288    public void putWait(K key, V value) {
289        put(key, value); // = send
290        // assume key does not change multiple times before test:
291        V val = null;
292        do {
293            val = getWait(key);
294            //System.out.print("#");
295        } while (!value.equals(val));
296    }
297
298
299    /**
300     * Put object to the distributed hash table. Returns immediately after
301     * sending does not block.
302     * @param key
303     * @param value
304     */
305    @Override
306    public V put(K key, V value) {
307        if (key == null || value == null) {
308            throw new NullPointerException("null keys or values not allowed");
309        }
310        try {
311            DHTTransport<K, V> tc = DHTTransport.<K, V> create(key, value);
312            for (int i = 1; i < size; i++) { // send not to self.listener
313                if (useTCP) {
314                    soc[i].send(tc);
315                } else {
316                    DHTTransport[] tcl = new DHTTransport[] { tc };
317                    synchronized (MPIEngine.class) { // do not remove
318                        engine.Send(tcl, 0, tcl.length, MPI.OBJECT, i, DHTTAG);
319                    }
320                }
321            }
322            synchronized (theList) { // add to self.listener
323                theList.put(key, value); //avoid seri: tc.key(), tc.value());
324                theList.notifyAll();
325            }
326            if (debug) {
327                K k = tc.key();
328                if (!key.equals(k)) {
329                    logger.warn("deserial(serial)) != key: " + key + " != " + k);
330                }
331                V v = tc.value();
332                if (!value.equals(v)) {
333                    logger.warn("deserial(serial)) != value: " + value + " != " + v);
334                }
335
336            }
337            //System.out.println("send: "+tc);
338        } catch (ClassNotFoundException e) {
339            logger.info("sending(key=" + key + ")");
340            logger.warn("send " + e);
341            e.printStackTrace();
342        } catch (MPIException e) {
343            logger.info("sending(key=" + key + ")");
344            logger.warn("send " + e);
345            e.printStackTrace();
346        } catch (IOException e) {
347            logger.info("sending(key=" + key + ")");
348            logger.warn("send " + e);
349            e.printStackTrace();
350        }
351        return null;
352    }
353
354
355    /**
356     * Get value under key from DHT. Blocks until the object is send and
357     * received from the server (actually it blocks until some value under key
358     * is received).
359     * @param key
360     * @return the value stored under the key.
361     */
362    public V getWait(K key) {
363        V value = null;
364        try {
365            synchronized (theList) {
366                value = theList.get(key);
367                while (value == null) {
368                    //System.out.print("-");
369                    theList.wait(100);
370                    value = theList.get(key);
371                }
372            }
373        } catch (InterruptedException e) {
374            //Thread.currentThread().interrupt();
375            e.printStackTrace();
376            return value;
377        }
378        return value;
379    }
380
381
382    /**
383     * Get value under key from DHT. If no value is jet available null is
384     * returned.
385     * @param key
386     * @return the value stored under the key.
387     */
388    @Override
389    public V get(Object key) {
390        synchronized (theList) {
391            return theList.get(key);
392        }
393    }
394
395
396    /**
397     * Clear the List. Caveat: must be called on all clients.
398     */
399    @Override
400    public void clear() {
401        // send clear message to others
402        synchronized (theList) {
403            theList.clear();
404        }
405    }
406
407
408    /**
409     * Initialize and start the list thread.
410     */
411    public void init() {
412        logger.info("init " + listener + ", theList = " + theList);
413        if (listener == null) {
414            return;
415        }
416        if (listener.isDone()) {
417            return;
418        }
419        if (debug) {
420            logger.debug("initialize " + listener);
421        }
422        synchronized (theList) {
423            listener.start();
424        }
425    }
426
427
428    /**
429     * Terminate the list thread.
430     */
431    public void terminate() {
432        if (listener == null) {
433            return;
434        }
435        if (debug) {
436            Runtime rt = Runtime.getRuntime();
437            logger.debug("terminate " + listener + ", runtime = " + rt.hashCode());
438        }
439        listener.setDone();
440        DHTTransport<K, V> tc = new DHTTransportTerminate<K, V>();
441        try {
442            if (rank == 0) {
443                //logger.info("send(" + rank + ") terminate");
444                for (int i = 1; i < size; i++) { // send not to self.listener
445                    if (useTCP) {
446                        soc[i].send(tc);
447                    } else {
448                        DHTTransport[] tcl = new DHTTransport[] { tc };
449                        synchronized (MPIEngine.class) { // do not remove
450                            engine.Send(tcl, 0, tcl.length, MPI.OBJECT, i, DHTTAG);
451                        }
452                    }
453                }
454            }
455        } catch (MPIException e) {
456            logger.info("sending(terminate)");
457            logger.info("send " + e);
458            e.printStackTrace();
459        } catch (IOException e) {
460            logger.info("sending(terminate)");
461            logger.info("send " + e);
462            e.printStackTrace();
463        }
464        try {
465            while (listener.isAlive()) {
466                //System.out.print("+++++");
467                listener.join(999);
468                listener.interrupt();
469            }
470        } catch (InterruptedException e) {
471            //Thread.currentThread().interrupt();
472        }
473        listener = null;
474    }
475
476}
477
478
479/**
480 * Thread to comunicate with the other DHT lists.
481 */
482class DHTMPIListener<K, V> extends Thread {
483
484
485    private static final Logger logger = LogManager.getLogger(DHTMPIListener.class);
486
487
488    private static final boolean debug = logger.isDebugEnabled();
489
490
491    private final Comm engine;
492
493
494    private final SortedMap<K, V> theList;
495
496
497    private final SocketChannel[] soc;
498
499
500    private boolean goon;
501
502
503    /**
504     * Constructor.
505     */
506    DHTMPIListener(Comm cm, SocketChannel[] s, SortedMap<K, V> list) {
507        engine = cm;
508        theList = list;
509        goon = true;
510        soc = s;
511    }
512
513
514    /**
515     * Test if done.
516     */
517    boolean isDone() {
518        return !goon;
519    }
520
521
522    /**
523     * Set to done status.
524     */
525    void setDone() {
526        goon = false;
527    }
528
529
530    /**
531     * run.
532     */
533    @SuppressWarnings("unchecked")
534    @Override
535    public void run() {
536        logger.info("listener run() " + this);
537        int rank = -1;
538        DHTTransport<K, V> tc;
539        //goon = true;
540        while (goon) {
541            tc = null;
542            try {
543                if (rank < 0) {
544                    rank = engine.Rank();
545                }
546                if (rank == 0) {
547                    logger.info("listener on rank 0 stopped");
548                    goon = false;
549                    continue;
550                }
551                Object to = null;
552                if (DistHashTableMPI.useTCP) {
553                    to = soc[0].receive();
554                } else {
555                    DHTTransport[] tcl = new DHTTransport[1];
556                    Status stat = null;
557                    synchronized (MPIEngine.class) { // do not remove global static lock, // only from 0:
558                        stat = engine.Recv(tcl, 0, tcl.length, MPI.OBJECT, 0,             //MPI.ANY_SOURCE,
559                                        DistHashTableMPI.DHTTAG);
560                    }
561                    //logger.info("waitRequest done: stat = " + stat);
562                    if (stat == null) {
563                        goon = false;
564                        break;
565                    }
566                    int cnt = stat.Get_count(MPI.OBJECT);
567                    if (cnt == 0) {
568                        goon = false;
569                        break;
570                    } else if (cnt > 1) {
571                        logger.warn("ignoring " + (cnt - 1) + " received objects");
572                    }
573                    to = tcl[0];
574                }
575                tc = (DHTTransport<K, V>) to;
576                if (debug) {
577                    logger.debug("receive(" + tc + ")");
578                }
579                if (tc instanceof DHTTransportTerminate) {
580                    logger.info("receive(" + rank + ") terminate");
581                    goon = false;
582                    break;
583                }
584                if (this.isInterrupted()) {
585                    goon = false;
586                    break;
587                }
588                K key = tc.key();
589                if (key != null) {
590                    logger.info("receive(" + rank + "), key=" + key);
591                    V val = tc.value();
592                    synchronized (theList) {
593                        theList.put(key, val);
594                        theList.notifyAll();
595                    }
596                }
597            } catch (MPIException e) {
598                goon = false;
599                logger.warn("receive(MPI) " + e);
600                //e.printStackTrace();
601            } catch (ClassNotFoundException e) {
602                goon = false;
603                logger.info("receive(Class) " + e);
604                e.printStackTrace();
605            } catch (Exception e) {
606                goon = false;
607                logger.info("receive " + e);
608                e.printStackTrace();
609            }
610        }
611        logger.info("terminated at " + rank);
612    }
613
614}