001/*
002 * $Id$
003 */
004
005
006package edu.jas.util;
007
008
009import java.io.IOException;
010import java.util.Arrays;
011
012import org.apache.logging.log4j.Logger;
013import org.apache.logging.log4j.LogManager;
014
015import edu.jas.kern.MPJEngine;
016
017import mpi.Comm;
018import mpi.MPI;
019import mpi.MPIException;
020import mpi.Status;
021
022
023/**
024 * MPJChannel provides a communication channel for Java objects using MPI or
025 * TCP/IP to a given rank. Can use MPI transport layer for "niodev" with
026 * FastMPJ.
027 * @author Heinz Kredel
028 */
029public final class MPJChannel {
030
031
032    private static final Logger logger = LogManager.getLogger(MPJChannel.class);
033
034
035    public static final int CHANTAG = MPJEngine.TAG + 2;
036
037
038    /*
039     * Underlying MPI engine.
040     */
041    private final Comm engine; // essentially static when useTCP !
042
043
044    /*
045     * Size of Comm.
046     */
047    private final int size;
048
049
050    /*
051     * This rank.
052     */
053    private final int rank;
054
055
056    /*
057     * TCP/IP object channels with tags.
058     */
059    private static TaggedSocketChannel[] soc = null;
060
061
062    /*
063     * Transport layer.
064     * true: use TCP/IP socket layer, false: use MPI transport layer.
065     * Can be set to false for "niodev" with FastMPJ.
066     */
067    static boolean useTCP = false;
068
069
070    /*
071     * Partner rank.
072     */
073    private final int partnerRank;
074
075
076    /*
077     * Message tag.
078     */
079    private final int tag;
080
081
082    /**
083     * Constructs a MPI channel on the given MPI engine.
084     * @param s MPI communicator object.
085     * @param r rank of MPI partner.
086     */
087    public MPJChannel(Comm s, int r) throws IOException, MPIException {
088        this(s, r, CHANTAG);
089    }
090
091
092    /**
093     * Constructs a MPI channel on the given MPI engine.
094     * @param s MPI communicator object.
095     * @param r rank of MPI partner.
096     * @param t tag for messages.
097     */
098    public MPJChannel(Comm s, int r, int t) throws IOException, MPIException {
099        engine = s;
100        rank = engine.Rank();
101        size = engine.Size();
102        if (r < 0 || size <= r) {
103            throw new IOException("r out of bounds: 0 <= r < size: " + r + ", " + size);
104        }
105        partnerRank = r;
106        tag = t;
107        synchronized (engine) {
108            if (soc == null && useTCP) {
109                int port = ChannelFactory.DEFAULT_PORT;
110                ChannelFactory cf;
111                if (rank == 0) {
112                    cf = new ChannelFactory(port);
113                    cf.init();
114                    soc = new TaggedSocketChannel[size];
115                    soc[0] = null;
116                    try {
117                        for (int i = 1; i < size; i++) {
118                            SocketChannel sc = cf.getChannel(); // TODO not correct wrt rank
119                            soc[i] = new TaggedSocketChannel(sc);
120                            soc[i].init();
121                        }
122                    } catch (InterruptedException e) {
123                        throw new IOException(e);
124                    }
125                    cf.terminate();
126                } else {
127                    cf = new ChannelFactory(port - 1); // in case of localhost
128                    soc = new TaggedSocketChannel[1];
129                    SocketChannel sc = cf.getChannel(MPJEngine.hostNames.get(0), port);
130                    soc[0] = new TaggedSocketChannel(sc);
131                    soc[0].init();
132                    cf.terminate();
133                }
134            }
135        }
136        logger.info("constructor: " + this.toString() + ", useTCP: " + useTCP);
137    }
138
139
140    /**
141     * Get the MPI engine.
142     */
143    public Comm getEngine() {
144        return engine;
145    }
146
147
148    /**
149     * Sends an object.
150     * @param v message object.
151     */
152    public void send(Object v) throws IOException, MPIException {
153        send(tag, v, partnerRank);
154    }
155
156
157    /**
158     * Sends an object.
159     * @param t message tag.
160     * @param v message object.
161     */
162    public void send(int t, Object v) throws IOException, MPIException {
163        send(t, v, partnerRank);
164    }
165
166
167    /**
168     * Sends an object.
169     * @param t message tag.
170     * @param v message object.
171     * @param pr partner rank.
172     */
173    void send(int t, Object v, int pr) throws IOException, MPIException {
174        if (useTCP) {
175            if (soc == null) {
176                logger.warn("soc not initialized: lost " + v);
177                return;
178            }
179            if (soc[pr] == null) {
180                logger.warn("soc[" + pr + "] not initialized: lost " + v);
181                return;
182            }
183            soc[pr].send(t, v);
184        } else {
185            Object[] va = new Object[] { v };
186            //synchronized (MPJEngine.class) {
187            engine.Send(va, 0, va.length, MPI.OBJECT, pr, t);
188            //}
189        }
190    }
191
192
193    /**
194     * Receives an object.
195     * @return a message object.
196     */
197    public Object receive() throws IOException, ClassNotFoundException, MPIException {
198        return receive(tag);
199    }
200
201
202    /**
203     * Receives an object.
204     * @param t message tag.
205     * @return a message object.
206     */
207    public Object receive(int t) throws IOException, ClassNotFoundException, MPIException {
208        if (useTCP) {
209            if (soc == null) {
210                logger.warn("soc not initialized");
211                return null;
212            }
213            if (soc[partnerRank] == null) {
214                logger.warn("soc[" + partnerRank + "] not initialized");
215                return null;
216            }
217            try {
218                return soc[partnerRank].receive(t);
219            } catch (InterruptedException e) {
220                throw new IOException(e);
221            }
222        }
223        Object[] va = new Object[1];
224        Status stat = null;
225        //synchronized (MPJEngine.class) {
226        stat = engine.Recv(va, 0, va.length, MPI.OBJECT, partnerRank, t);
227        //}
228        if (stat == null) {
229            throw new IOException("received null Status");
230        }
231        int cnt = stat.Get_count(MPI.OBJECT);
232        if (cnt == 0) {
233            throw new IOException("no object received");
234        }
235        if (cnt > 1) {
236            logger.warn("too many objects received, ignored " + (cnt - 1));
237        }
238        // int pr = stat.source;
239        // if (pr != partnerRank) {
240        //     logger.warn("received out of order message from " + pr);
241        // }
242        return va[0];
243    }
244
245
246    /**
247     * Closes the channel.
248     */
249    public void close() {
250        if (useTCP) {
251            if (soc == null) {
252                return;
253            }
254            for (int i = 0; i < soc.length; i++) {
255                if (soc[i] != null) {
256                    soc[i].close();
257                    soc[i] = null;
258                }
259            }
260        }
261    }
262
263
264    /**
265     * to string.
266     */
267    @Override
268    public String toString() {
269        return "MPJChannel(on=" + rank + ",to=" + partnerRank + ",tag=" + tag + "," + Arrays.toString(soc)
270                        + ")";
271    }
272
273}