001/*
002 * $Id$
003 */
004
005
006package edu.jas.util;
007
008
009import java.io.IOException;
010import java.util.Arrays;
011
012import mpi.Comm;
013import mpi.MPI;
014import mpi.MPIException;
015import mpi.Status;
016
017import org.apache.logging.log4j.Logger;
018import org.apache.logging.log4j.LogManager;
019
020import edu.jas.kern.MPIEngine;
021
022
023/**
024 * MPIChannel provides a communication channel for Java objects using MPI or
025 * TCP/IP to a given rank.
026 * @author Heinz Kredel
027 */
028public final class MPIChannel {
029
030
031    private static final Logger logger = LogManager.getLogger(MPIChannel.class);
032
033
034    public static final int CHANTAG = MPIEngine.TAG + 2;
035
036
037    /*
038     * Underlying MPI engine.
039     */
040    private final Comm engine; // essentially static when useTCP !
041
042
043    /*
044     * Size of Comm.
045     */
046    private final int size;
047
048
049    /*
050     * This rank.
051     */
052    private final int rank;
053
054
055    /*
056     * TCP/IP object channels with tags.
057     */
058    private static TaggedSocketChannel[] soc = null;
059
060
061    /*
062     * Transport layer.
063     * true: use TCP/IP socket layer, false: use MPI transport layer.
064     * Can not be set to false for OpenMPI Java: not working.
065     */
066    static final boolean useTCP = true;
067
068
069    /*
070     * Partner rank.
071     */
072    private final int partnerRank;
073
074
075    /*
076     * Message tag.
077     */
078    private final int tag;
079
080
081    /**
082     * Constructs a MPI channel on the given MPI engine.
083     * @param s MPI communicator object.
084     * @param r rank of MPI partner.
085     */
086    public MPIChannel(Comm s, int r) throws IOException, MPIException {
087        this(s, r, CHANTAG);
088    }
089
090
091    /**
092     * Constructs a MPI channel on the given MPI engine.
093     * @param s MPI communicator object.
094     * @param r rank of MPI partner.
095     * @param t tag for messages.
096     */
097    public MPIChannel(Comm s, int r, int t) throws IOException, MPIException {
098        engine = s;
099        rank = engine.Rank();
100        size = engine.Size();
101        if (r < 0 || size <= r) {
102            throw new IOException("r out of bounds: 0 <= r < size: " + r + ", " + size);
103        }
104        partnerRank = r;
105        tag = t;
106        synchronized (engine) {
107            if (soc == null && useTCP) {
108                int port = ChannelFactory.DEFAULT_PORT;
109                ChannelFactory cf;
110                if (rank == 0) {
111                    cf = new ChannelFactory(port);
112                    cf.init();
113                    soc = new TaggedSocketChannel[size];
114                    soc[0] = null;
115                    try {
116                        for (int i = 1; i < size; i++) {
117                            SocketChannel sc = cf.getChannel(); // TODO not correct wrt rank
118                            soc[i] = new TaggedSocketChannel(sc);
119                            soc[i].init();
120                        }
121                    } catch (InterruptedException e) {
122                        throw new IOException(e);
123                    }
124                    cf.terminate();
125                } else {
126                    cf = new ChannelFactory(port - 1); // in case of localhost
127                    soc = new TaggedSocketChannel[1];
128                    SocketChannel sc = cf.getChannel(MPIEngine.hostNames.get(0), port);
129                    soc[0] = new TaggedSocketChannel(sc);
130                    soc[0].init();
131                    cf.terminate();
132                }
133            }
134        }
135        logger.info("constructor: " + this.toString() + ", useTCP: " + useTCP);
136    }
137
138
139    /**
140     * Get the MPI engine.
141     */
142    public Comm getEngine() {
143        return engine;
144    }
145
146
147    /**
148     * Sends an object.
149     * @param v message object.
150     */
151    public void send(Object v) throws IOException, MPIException {
152        send(tag, v, partnerRank);
153    }
154
155
156    /**
157     * Sends an object.
158     * @param t message tag.
159     * @param v message object.
160     */
161    public void send(int t, Object v) throws IOException, MPIException {
162        send(t, v, partnerRank);
163    }
164
165
166    /**
167     * Sends an object.
168     * @param t message tag.
169     * @param v message object.
170     * @param pr partner rank.
171     */
172    void send(int t, Object v, int pr) throws IOException, MPIException {
173        if (useTCP) {
174            if (soc == null) {
175                logger.warn("soc not initialized: lost " + v);
176                return;
177            }
178            if (soc[pr] == null) {
179                logger.warn("soc[" + pr + "] not initialized: lost " + v);
180                return;
181            }
182            soc[pr].send(t, v);
183        } else {
184            Object[] va = new Object[] { v };
185            //synchronized (MPJEngine.class) {
186            engine.Send(va, 0, va.length, MPI.OBJECT, pr, t);
187            //}
188        }
189    }
190
191
192    /**
193     * Receives an object.
194     * @return a message object.
195     */
196    public Object receive() throws IOException, ClassNotFoundException, MPIException {
197        return receive(tag);
198    }
199
200
201    /**
202     * Receives an object.
203     * @param t message tag.
204     * @return a message object.
205     */
206    public Object receive(int t) throws IOException, ClassNotFoundException, MPIException {
207        if (useTCP) {
208            if (soc == null) {
209                logger.warn("soc not initialized");
210                return null;
211            }
212            if (soc[partnerRank] == null) {
213                logger.warn("soc[" + partnerRank + "] not initialized");
214                return null;
215            }
216            try {
217                return soc[partnerRank].receive(t);
218            } catch (InterruptedException e) {
219                throw new IOException(e);
220            }
221        }
222        Object[] va = new Object[1];
223        Status stat = null;
224        //synchronized (MPJEngine.class) {
225        stat = engine.Recv(va, 0, va.length, MPI.OBJECT, partnerRank, t);
226        //}
227        if (stat == null) {
228            throw new IOException("received null Status");
229        }
230        int cnt = stat.Get_count(MPI.OBJECT);
231        if (cnt == 0) {
232            throw new IOException("no object received");
233        }
234        if (cnt > 1) {
235            logger.warn("too many objects received, ignored " + (cnt - 1));
236        }
237        // int pr = stat.source;
238        // if (pr != partnerRank) {
239        //     logger.warn("received out of order message from " + pr);
240        // }
241        return va[0];
242    }
243
244
245    /**
246     * Closes the channel.
247     */
248    public void close() {
249        if (useTCP) {
250            if (soc == null) {
251                return;
252            }
253            for (int i = 0; i < soc.length; i++) {
254                if (soc[i] != null) {
255                    soc[i].close();
256                    soc[i] = null;
257                }
258            }
259        }
260    }
261
262
263    /**
264     * to string.
265     */
266    @Override
267    public String toString() {
268        return "MPIChannel(on=" + rank + ",to=" + partnerRank + ",tag=" + tag + "," + Arrays.toString(soc)
269                        + ")";
270    }
271
272}