001/*
002 * $Id$
003 */
004
005package edu.jas.vector;
006
007
008import java.io.Serializable;
009import java.util.ArrayList;
010import java.util.List;
011
012import org.apache.logging.log4j.LogManager;
013import org.apache.logging.log4j.Logger;
014
015import edu.jas.structure.RingElem;
016
017
018/**
019 * Linear algebra methods. Implements linear algebra computations and tests,
020 * mainly based on Gauss elimination. Partly based on <a href=
021 * "https://en.wikipedia.org/wiki/LU_decomposition">LU_decomposition</a>.
022 * Computation of Null space basis, row echelon form, inverses and ranks.
023 * @param <C> coefficient type
024 * @author Heinz Kredel
025 */
026
027public class LinAlg<C extends RingElem<C>> implements Serializable {
028
029
030    private static final Logger logger = LogManager.getLogger(LinAlg.class);
031
032
033    //private static final boolean debug = logger.isDebugEnabled();
034
035
036    /**
037     * Constructor.
038     */
039    public LinAlg() {
040    }
041
042
043    /**
044     * Matrix LU decomposition. Matrix A is replaced by its LU decomposition. A
045     * contains a copy of both matrices L-E and U as A=(L-E)+U such that
046     * P*A=L*U. The permutation matrix is not stored as a matrix, but in an
047     * integer vector P of size N+1 containing column indexes where the
048     * permutation matrix has "1". The last element P[N]=S+N, where S is the
049     * number of row exchanges needed for determinant computation, det(P)=(-1)^S
050     * @param A a n&times;n matrix.
051     * @return permutation vector P and modified matrix A.
052     */
053    public List<Integer> decompositionLU(GenMatrix<C> A) {
054        if (A == null) {
055            return null;
056        }
057        GenMatrixRing<C> ring = A.ring;
058        int N = ring.rows;
059        int M = ring.cols;
060        int NM = Math.min(N,M);
061        if (N != M) {
062            logger.warn("nosquare matrix");
063        }
064        List<Integer> P = new ArrayList<Integer>(NM + 1);
065        for (int i = 0; i <= NM; i++) {
066            P.add(i); //Unit permutation matrix, P[NM] initialized with NM
067        }
068        ArrayList<ArrayList<C>> mat = A.matrix;
069        for (int i = 0; i < NM; i++) {
070            int imax = i;
071            C maxA = ring.coFac.getZERO();
072            for (int k = i; k < N; k++) {
073                // absA = fabs(A[k][i])
074                C absA = mat.get(k).get(i).abs();
075                if (absA.compareTo(maxA) > 0) {
076                    maxA = absA;
077                    imax = k;
078                    break; // first
079                }
080            }
081            if (maxA.isZERO()) {
082                logger.warn("matrix is degenerate at col {}", i);
083                mat.get(i).set(i, ring.coFac.getZERO()); // already zero
084                //continue;
085                P.clear();
086                return P; //failure, matrix is degenerate
087            }
088            if (imax != i) {
089                //pivoting P
090                int j = P.get(i);
091                P.set(i, P.get(imax));
092                P.set(imax, j);
093                //System.out.println("new pivot " + imax); // + ", P = " + P);
094                //pivoting rows of A
095                ArrayList<C> ptr = mat.get(i);
096                mat.set(i, mat.get(imax));
097                mat.set(imax, ptr);
098                //counting pivots starting from NM (for determinant)
099                P.set(NM, P.get(NM) + 1);
100            }
101            C dd = mat.get(i).get(i).inverse();
102            for (int j = i + 1; j < N; j++) {
103                // A[j][i] /= A[i][i];
104                C d = mat.get(j).get(i).multiply(dd); //divide(dd);
105                mat.get(j).set(i, d);
106                for (int k = i + 1; k < M; k++) {
107                    // A[j][k] -= A[j][i] * A[i][k];
108                    C a = mat.get(j).get(i).multiply(mat.get(i).get(k));
109                    mat.get(j).set(k, mat.get(j).get(k).subtract(a));
110                }
111            }
112            //System.out.println("row(last) = " + mat.get(NM-1));
113        }
114        return P;
115    }
116
117
118    /**
119     * Solve with LU decomposition.
120     * @param A a n&times;n matrix in LU decomposition.
121     * @param P permutation vector.
122     * @param b right hand side vector.
123     * @return x solution vector of A*x = b.
124     */
125    public GenVector<C> solveLU(GenMatrix<C> A, List<Integer> P, GenVector<C> b) {
126        if (A == null || b == null) {
127            return null;
128        }
129        if (P.size() == 0) {
130            return null;
131        }
132        GenMatrixRing<C> ring = A.ring;
133        int N = ring.rows;
134        GenVectorModul<C> xfac = new GenVectorModul<C>(ring.coFac, N);
135        GenVector<C> x = new GenVector<C>(xfac);
136        List<C> vec = x.val;
137        ArrayList<ArrayList<C>> mat = A.matrix;
138        for (int i = 0; i < N; i++) {
139            //x[i] = b[P[i]];
140            vec.set(i, b.get(P.get(i)));
141            C xi = vec.get(i);
142            for (int k = 0; k < i; k++) {
143                //x[i] -= A[i][k] * x[k];
144                C ax = mat.get(i).get(k).multiply(vec.get(k));
145                xi = xi.subtract(ax);
146            }
147            vec.set(i, xi);
148        }
149        //System.out.println("vec = " + vec);
150        for (int i = N - 1; i >= 0; i--) {
151            C xi = vec.get(i);
152            for (int k = i + 1; k < N; k++) {
153                //x[i] -= A[i][k] * x[k];
154                C ax = mat.get(i).get(k).multiply(vec.get(k));
155                xi = xi.subtract(ax);
156            }
157            vec.set(i, xi);
158            //x[i] /= A[i][i];
159            vec.set(i, xi.divide(mat.get(i).get(i)));
160        }
161        return x;
162    }
163
164
165    /**
166     * Solve linear system of equations.
167     * @param A a n&times;n matrix.
168     * @param b right hand side vector.
169     * @return x solution vector of A*x = b.
170     */
171    public GenVector<C> solve(GenMatrix<C> A, GenVector<C> b) {
172        if (A == null || b == null) {
173            return null;
174        }
175        GenMatrix<C> Ap = A.copy();
176        List<Integer> P = decompositionLU(Ap);
177        if (P.size() == 0) {
178            System.out.println("undecomposable");
179            return b.modul.getZERO();
180        }
181        GenVector<C> x = solveLU(Ap, P, b);
182        return x;
183    }
184
185
186    /**
187     * Determinant with LU decomposition.
188     * @param A a n&times;n matrix in LU decomposition.
189     * @param P permutation vector.
190     * @return d determinant of A.
191     */
192    public C determinantLU(GenMatrix<C> A, List<Integer> P) {
193        if (A == null) {
194            return null;
195        }
196        if (P.size() == 0) {
197            return A.ring.coFac.getZERO();
198        }
199        int N = A.ring.rows; //P.size() - 1 - 1;
200        ArrayList<ArrayList<C>> mat = A.matrix;
201        // det = A[0][0];
202        C det = mat.get(0).get(0);
203        for (int i = 1; i < N; i++) {
204            //det *= A[i][i];
205            det = det.multiply(mat.get(i).get(i));
206        }
207        //return (P[N] - N) % 2 == 0 ? det : -det
208        int s = P.get(N) - N;
209        if (s % 2 != 0) {
210            det = det.negate();
211        }
212        return det;
213    }
214
215
216    /**
217     * Inverse with LU decomposition.
218     * @param A a n&times;n matrix in LU decomposition.
219     * @param P permutation vector.
220     * @return inv(A) with A * inv(A) == 1.
221     */
222    public GenMatrix<C> inverseLU(GenMatrix<C> A, List<Integer> P) {
223        GenMatrixRing<C> ring = A.ring;
224        GenMatrix<C> inv = new GenMatrix<C>(ring);
225        int N = ring.rows; //P.size() - 1 - 1;
226        ArrayList<ArrayList<C>> mat = A.matrix;
227        ArrayList<ArrayList<C>> imat = inv.matrix;
228        for (int j = 0; j < N; j++) {
229            // transform right hand vector with L matrix
230            for (int i = 0; i < N; i++) {
231                //IA[i][j] = P[i] == j ? 1.0 : 0.0;
232                C e = (P.get(i) == j) ? ring.coFac.getONE() : ring.coFac.getZERO();
233                imat.get(i).set(j, e);
234                C b = e; //imat.get(i).get(j);
235                for (int k = 0; k < i; k++) {
236                    //IA[i][j] -= A[i][k] * IA[k][j];
237                    C a = mat.get(i).get(k).multiply(imat.get(k).get(j));
238                    b = b.subtract(a);
239                }
240                imat.get(i).set(j, b);
241            }
242            // solve inverse matrix column with U matrix
243            for (int i = N - 1; i >= 0; i--) {
244                C b = imat.get(i).get(j);
245                for (int k = i + 1; k < N; k++) {
246                    //IA[i][j] -= A[i][k] * IA[k][j];
247                    C a = mat.get(i).get(k).multiply(imat.get(k).get(j));
248                    b = b.subtract(a);
249                }
250                imat.get(i).set(j, b);
251                //IA[i][j] /= A[i][i];
252                C e = b; //imat.get(i).get(j);
253                e = e.divide(mat.get(i).get(i));
254                imat.get(i).set(j, e);
255            }
256        }
257        return inv;
258    }
259
260
261    /**
262     * Matrix Null Space basis, cokernel. From the transpose matrix At it
263     * computes the kernel with At*v_i = 0.
264     * @param A a n&times;n matrix.
265     * @return V a list of basis vectors (v_1, ..., v_k) with v_i*A == 0.
266     */
267    public List<GenVector<C>> nullSpaceBasis(GenMatrix<C> A) {
268        if (A == null) {
269            return null;
270        }
271        GenMatrixRing<C> ring = A.ring;
272        int N = ring.rows;
273        int M = ring.cols;
274        if (N != M) {
275            logger.warn("nosquare matrix");
276        }
277        List<GenVector<C>> nspb = new ArrayList<GenVector<C>>();
278        GenVectorModul<C> vfac = new GenVectorModul<C>(ring.coFac, M);
279        ArrayList<ArrayList<C>> mat = A.matrix;
280        for (int i = 0; i < N; i++) {
281            C maxA, absA;
282            // search privot imax
283            int imax = i;
284            maxA = ring.coFac.getZERO();
285            for (int k = i; k < M; k++) { // k = 0 ?
286                // absA = fabs(A[k][i])
287                absA = mat.get(i).get(k).abs();
288                if (absA.compareTo(maxA) > 0 && maxA.isZERO()) {
289                    maxA = absA;
290                    imax = k;
291                }
292            }
293            logger.info("pivot: {}, i = {}, maxA = {}", imax, i, maxA);
294            if (maxA.isZERO()) {
295                // check for complete zero row or left pivot
296                int imaxl = i;
297                for (int k = 0; k < i; k++) { // k = 0 ?
298                    // absA = fabs(A[k][i])
299                    absA = mat.get(i).get(k).abs();
300                    if (absA.compareTo(maxA) > 0) { // first or last imax: && maxA.isZERO()
301                        imaxl = k;
302                        // check if upper triangular column is zero
303                        boolean iszero = true;
304                        for (int m = 0; m < i; m++) {
305                            C amm = mat.get(m).get(imaxl).abs();
306                            if (!amm.isZERO()) {
307                                iszero = false;
308                                break;
309                            }
310                        }
311                        if (iszero) { // left pivot okay
312                            imax = imaxl;
313                            logger.info("pivot*: {}, i = {}, absA = {}", imax, i, absA);
314                            maxA = ring.coFac.getONE();
315                        }
316                    }
317                }
318                if (maxA.isZERO()) { // complete zero row
319                    continue;
320                }
321            }
322            if (imax < M) { //!= i
323                //normalize column i
324                C mp = mat.get(i).get(imax).inverse();
325                for (int k = 0; k < N; k++) { // k = i ?
326                    C b = mat.get(k).get(imax);
327                    b = b.multiply(mp);
328                    mat.get(k).set(imax, b);
329                }
330                //pivoting columns of A
331                if (imax != i) {
332                    for (int k = 0; k < N; k++) {
333                        C b = mat.get(k).get(i);
334                        mat.get(k).set(i, mat.get(k).get(imax));
335                        mat.get(k).set(imax, b);
336                    }
337                }
338                //eliminate rest of row i via column operations
339                for (int j = 0; j < M; j++) {
340                    if (i == j) { // is already normalized
341                        continue;
342                    }
343                    C mm = mat.get(i).get(j);
344                    for (int k = 0; k < N; k++) { // or k = 0
345                        C b = mat.get(k).get(j);
346                        C c = mat.get(k).get(i);
347                        C d = b.subtract(c.multiply(mm));
348                        mat.get(k).set(j, d);
349                    }
350                }
351            }
352        }
353        // convert to A-I
354        for (int i = 0; i < N; i++) {
355            C b = mat.get(i).get(i);
356            b = b.subtract(ring.coFac.getONE());
357            mat.get(i).set(i, b);
358        }
359        //System.out.println("mat-1 = " + A);
360        // read off non zero rows of A
361        for (int i = 0; i < N; i++) {
362            List<C> row = mat.get(i);
363            boolean iszero = true;
364            for (int k = 0; k < M; k++) {
365                if (!row.get(k).isZERO()) {
366                    iszero = false;
367                    break;
368                }
369            }
370            if (!iszero) {
371                GenVector<C> v = new GenVector<C>(vfac, row);
372                nspb.add(v);
373            }
374        }
375        return nspb;
376    }
377
378
379    /**
380     * Rank via null space.
381     * @param A a n&times;n matrix.
382     * @return r rank of A.
383     */
384    public long rankNS(GenMatrix<C> A) {
385        if (A == null) {
386            return -1l;
387        }
388        GenMatrix<C> Ap = A.copy();
389        long n = Math.min(A.ring.rows, A.ring.cols);
390        List<GenVector<C>> ns = nullSpaceBasis(Ap);
391        long s = ns.size();
392        return n - s;
393    }
394
395
396    /**
397     * Matrix row echelon form construction. Matrix A is replaced by its row
398     * echelon form, an upper triangle matrix.
399     * @param A a n&times;m matrix.
400     * @return A row echelon form of A, matrix A is modified.
401     */
402    public GenMatrix<C> rowEchelonForm(GenMatrix<C> A) {
403        if (A == null) {
404            return null;
405        }
406        GenMatrixRing<C> ring = A.ring;
407        int N = ring.rows;
408        int M = ring.cols;
409        if (N != M) {
410            logger.warn("nosquare matrix");
411        }
412        int kmax = 0;
413        ArrayList<ArrayList<C>> mat = A.matrix;
414        for (int i = 0; i < N;) {
415            int imax = i;
416            C maxA = ring.coFac.getZERO();
417            // search non-zero rows
418            for (int k = i; k < N; k++) {
419                // absA = fabs(A[k][i])
420                C absA = mat.get(k).get(kmax).abs();
421                if (absA.compareTo(maxA) > 0) {
422                    maxA = absA;
423                    imax = k;
424                    break; // first
425                }
426            }
427            if (maxA.isZERO()) {
428                //System.out.println("matrix is zero at col " + kmax);
429                kmax++;
430                if (kmax >= M) {
431                    break;
432                }
433                continue;
434            }
435            //System.out.println("matrix is non zero at row " + imax);
436            if (imax != i) {
437                //swap pivoting rows of A
438                ArrayList<C> ptr = mat.get(i);
439                mat.set(i, mat.get(imax));
440                mat.set(imax, ptr);
441            }
442            // A[j][i] /= A[i][i];
443            C dd = mat.get(i).get(kmax).inverse();
444            //System.out.println("matrix is non zero at row " + imax + ", dd = " + dd);
445            for (int k = kmax; k < M; k++) {
446                C d = mat.get(i).get(k).multiply(dd); //divide(dd);
447                mat.get(i).set(k, d);
448            }
449            for (int j = i + 1; j < N; j++) {
450                for (int k = kmax; k < M; k++) {
451                    // A[j][k] -= A[j][k] * A[i][k];
452                    C a = mat.get(j).get(k).multiply(mat.get(i).get(k));
453                    if (a.isZERO()) {
454                        continue;
455                    }
456                    mat.get(j).set(k, mat.get(j).get(k).subtract(a));
457                }
458            }
459            mat.get(i).set(kmax, ring.coFac.getONE());
460            i++;
461            kmax++;
462            if (kmax >= M) {
463                break;
464            }
465            //System.out.println("rowEch(last) = " + mat.get(N-1));
466        }
467        return A;
468    }
469
470
471    /**
472     * Rank via row echelon form.
473     * @param A a n&times;n matrix.
474     * @return r rank of A.
475     */
476    public long rankRE(GenMatrix<C> A) {
477        if (A == null) {
478            return -1l;
479        }
480        long n = A.ring.rows;
481        long m = A.ring.cols;
482        ArrayList<ArrayList<C>> mat = A.matrix;
483        // count non-zero rows
484        long r = 0;
485        for (int i = 0; i < n; i++) {
486            ArrayList<C> row = mat.get(i);
487            for (int j = i; j < m; j++) {
488                if (!row.get(j).isZERO()) {
489                    r++;
490                    break;
491                }
492            }
493        }
494        return r;
495    }
496
497
498    /**
499     * Matrix row echelon form construction. Matrix A is replaced by
500     * its row echelon form, an upper triangle matrix with less
501     * non-zero entries. No column swaps and transforms are performed
502     * as with the Gauss-Jordan algorithm.
503     * @param A a n&times;m matrix.
504     * @return A sparse row echelon form of A, matrix A is modified.
505     */
506    public GenMatrix<C> rowEchelonFormSparse(GenMatrix<C> A) {
507        if (A == null) {
508            return null;
509        }
510        GenMatrixRing<C> ring = A.ring;
511        int N = ring.rows;
512        int M = ring.cols;
513        if (N != M) {
514            logger.warn("nosquare matrix");
515        }
516        int i, imax, kmax;
517        C maxA, absA;
518        ArrayList<ArrayList<C>> mat = A.matrix;
519        for (i = N - 1; i > 0; i--) {
520            imax = i;
521            maxA = ring.coFac.getZERO();
522            //System.out.println("matrix row " + A.getRow(i));
523            kmax = -1;
524            // search non-zero entry in row i
525            for (int k = i; k < M; k++) {
526                // absA = fabs(A[i][k])
527                absA = mat.get(i).get(k).abs();
528                if (absA.compareTo(maxA) > 0) {
529                    //System.out.println("absA(" + i +"," + k + ") = " + absA);
530                    maxA = absA;
531                    kmax = k;
532                    break; // first
533                }
534            }
535            if (maxA.isZERO()) {
536                continue;
537            }
538            // reduce upper rows
539            for (int j = imax - 1; j >= 0; j--) {
540                for (int k = kmax; k < M; k++) {
541                    // A[j,k] -= A[j,k] * A[imax,k]
542                    C mjk = mat.get(j).get(k);
543                    if (mjk.isZERO()) {
544                        continue;
545                    }
546                    C mk = mat.get(imax).get(k);
547                    if (mk.isZERO()) {
548                        continue;
549                    }
550                    C a = mk.multiply(mjk);
551                    //System.out.println("mjk(" + j +"," + k + ") = " + mjk + ", mk = " + mk);
552                    mjk = mjk.subtract(a);
553                    mat.get(j).set(k,mjk);
554                }
555            }
556        }
557        return A;
558    }
559
560
561    /**
562     * Matrix fraction free Gauss elimination. Matrix A is replaced by
563     * its fraction free LU decomposition. A contains a copy of both
564     * matrices L-E and U as A=(L-E)+U such that P*A=L*U. TODO: L is
565     * not computed but 0. The permutation matrix is not stored as a
566     * matrix, but in an integer vector P of size N+1 containing
567     * column indexes where the permutation matrix has "1". The last
568     * element P[N]=S+N, where S is the number of row exchanges needed
569     * for determinant computation, det(P)=(-1)^S
570     * @param A a n&times;n matrix.
571     * @return permutation vector P and modified matrix A.
572     */
573    public List<Integer> fractionfreeGaussElimination(GenMatrix<C> A) {
574        if (A == null) {
575            return null;
576        }
577        GenMatrixRing<C> ring = A.ring;
578        int N = ring.rows;
579        int M = ring.cols;
580        int NM = Math.min(N,M);
581        if (N != M) {
582            logger.warn("nosquare matrix");
583        }
584        List<Integer> P = new ArrayList<Integer>(NM + 1);
585        for (int i = 0; i <= NM; i++) {
586            P.add(i); //Unit permutation matrix, P[NM] initialized with NM
587        }
588        int r = 0;
589        C divisor = ring.coFac.getONE();
590        ArrayList<ArrayList<C>> mat = A.matrix;
591        for (int i = 0; i < NM && r < N; i++) {
592            int imax = i;
593            C maxA = ring.coFac.getZERO();
594            for (int k = i; k < N; k++) {
595                // absA = fabs(A[k][i])
596                C absA = mat.get(k).get(i).abs();
597                if (absA.compareTo(maxA) > 0) {
598                    maxA = absA;
599                    imax = k;
600                    break; // first
601                }
602            }
603            if (maxA.isZERO()) {
604                logger.warn("matrix is degenerate at col {}", i);
605                mat.get(i).set(i, ring.coFac.getZERO()); //already zero
606                //continue;
607                P.clear();
608                return P; //failure, matrix is degenerate
609            }
610            if (imax != i) {
611                //pivoting P
612                int j = P.get(i);
613                P.set(i, P.get(imax));
614                P.set(imax, j);
615                //System.out.println("new pivot " + imax); // + ", P = " + P);
616                //pivoting rows of A
617                ArrayList<C> ptr = mat.get(i);
618                mat.set(i, mat.get(imax));
619                mat.set(imax, ptr);
620                //counting pivots starting from NM (for determinant)
621                P.set(NM, P.get(NM) + 1);
622            }
623            //C dd = mat.get(i).get(i).inverse();
624            for (int j = r + 1; j < N; j++) {
625                // A[j][i] /= A[i][i];
626                //C d = mat.get(j).get(i).multiply(dd); //divide(dd);
627                //mat.get(j).set(i, d);
628                for (int k = i + 1; k < M; k++) { //+ 1
629                    // A[j][k] -= A[j][i] * A[i][k];
630                    //C a = mat.get(j).get(i).multiply(mat.get(i).get(k));
631                    //mat.get(j).set(k, mat.get(j).get(k).subtract(a));
632                    //System.out.println("i = " + i + ", r = " + r + ", j = " + j + ", k = " + k);
633                    C a = mat.get(r).get(i).multiply(mat.get(j).get(k));
634                    C b = mat.get(r).get(k).multiply(mat.get(j).get(i));
635                    C d = a.subtract(b).divide(divisor);
636                    //System.out.println(", a = " + a + ", b = " + b + ", d = " + d);
637                    mat.get(j).set(k, d);
638                }
639            }
640            for (int j = r + 1; j < N; j++) { // set L-E = 0
641                mat.get(j).set(i, ring.coFac.getZERO());
642            }
643            divisor = mat.get(r).get(i);
644            r++;
645            //System.out.println("divisor = " + divisor);
646            //System.out.println("mat = " + mat);
647            //System.out.println("row(last) = " + mat.get(NM-1));
648        }
649        return P;
650    }
651
652}