/*
 * Decompiled with CFR 0.152.
 */
package edu.mines.jtk.dsp;

import edu.mines.jtk.dsp.LocalCorrelationFilter;
import edu.mines.jtk.util.Array;
import edu.mines.jtk.util.Check;
import java.util.ArrayList;

/**
 * Local prediction filter with spatially varying coefficients.
 * <p>
 * For a 2-D array f and m lags (lag1[j],lag2[j]), this filter computes at
 * every sample (i1,i2) coefficients a[j][i2][i1] that solve the local normal
 * equations
 * <pre>
 *   sum_j r(lag[j]-lag[k]) a[j] = r(lag[k]),   k = 0,...,m-1
 * </pre>
 * where r denotes local autocorrelations of f obtained from a
 * {@link LocalCorrelationFilter}. The equations are solved independently at
 * each sample with conjugate gradients, each solve starting from the
 * solution of the previously visited (adjacent) sample.
 */
public class LocalPredictionFilter {

    /** Maximum number of conjugate-gradient iterations per sample. */
    private static final int CG_MAX_ITERATIONS = 100;

    /**
     * Factor applied to the zero-lag autocorrelation. The zero-lag array fills
     * the diagonal of the normal equations, so this adds 1% of it to that
     * diagonal (presumably to stabilize the solves).
     */
    private static final float ZERO_LAG_SCALE = 1.01f;

    private final LocalCorrelationFilter _lcf;

    /**
     * Constructs a local prediction filter.
     * @param sigma parameter passed to the {@link LocalCorrelationFilter} used
     *  to compute local autocorrelations; must be at least 1.0.
     */
    public LocalPredictionFilter(double sigma) {
        Check.argument(sigma >= 1.0, "sigma>=1.0");
        this._lcf = new LocalCorrelationFilter(sigma);
    }

    /**
     * Computes local prediction coefficients for f and applies them.
     * On return, g[i2][i1] = sum_j a[j][i2][i1]*f[i2-lag2[j]][i1-lag1[j]],
     * where terms whose index into f is out of bounds are omitted.
     * Also prints the average number of CG iterations per sample to stdout.
     * @param lag1 lags in the 1st dimension.
     * @param lag2 lags in the 2nd dimension; same length as lag1.
     * @param f input array; must not be the same array object as g.
     * @param g output array of predicted values; overwritten.
     * @return the coefficients a[j][i2][i1], one 2-D array per lag j.
     */
    public float[][][] apply(int[] lag1, int[] lag2, float[][] f, float[][] g) {
        Check.argument(lag1.length == lag2.length, "lag1.length==lag2.length");
        Check.argument(f != g, "f!=g");
        float[][][] a = computeCoefficients(lag1, lag2, f);
        Array.zero(g);
        accumulate(lag1, lag2, a, f, 1.0f, g);
        return a;
    }

    /**
     * Computes local prediction coefficients for f and applies the
     * corresponding prediction-error filter.
     * On return, g[i2][i1] = f[i2][i1]
     *   - sum_j a[j][i2][i1]*f[i2-lag2[j]][i1-lag1[j]],
     * where terms whose index into f is out of bounds are omitted.
     * Also prints the average number of CG iterations per sample to stdout.
     * @param lag1 lags in the 1st dimension.
     * @param lag2 lags in the 2nd dimension; same length as lag1.
     * @param f input array; must not be the same array object as g.
     * @param g output array of prediction errors; overwritten.
     */
    public void applyPef(int[] lag1, int[] lag2, float[][] f, float[][] g) {
        Check.argument(lag1.length == lag2.length, "lag1.length==lag2.length");
        Check.argument(f != g, "f!=g");
        float[][][] a = computeCoefficients(lag1, lag2, f);
        Array.copy(f, g);
        accumulate(lag1, lag2, a, f, -1.0f, g);
    }

    /**
     * Solves the local normal equations at every sample of f and returns the
     * coefficients a[j][i2][i1]. Prints the average CG iteration count.
     */
    private float[][][] computeCoefficients(int[] lag1, int[] lag2, float[][] f) {
        int m = lag1.length;

        // Gather autocorrelation arrays for the matrix (rkj) and right-hand
        // side (rk0). The cache returns one shared array for lags l and -l,
        // so rkj[k][j] and rkj[j][k] are the same array: the matrix is
        // symmetric, as CG requires.
        R2Cache rcache = new R2Cache(f);
        float[][][][] rkj = new float[m][m][][];
        float[][][] rk0 = new float[m][][];
        for (int k = 0; k < m; ++k) {
            for (int j = 0; j < m; ++j) {
                rkj[k][j] = rcache.get(lag1[j] - lag1[k], lag2[j] - lag2[k]);
            }
            rk0[k] = rcache.get(lag1[k], lag2[k]);
        }

        int n1 = f[0].length;
        int n2 = f.length;
        double[][] rkjt = new double[m][m];
        double[] rk0t = new double[m];
        // Not reset between samples: each solve starts from the previous
        // sample's solution.
        double[] at = new double[m];
        float[][][] a = new float[m][n2][n1];
        CgSolver cgs = new CgSolver(m, CG_MAX_ITERATIONS);
        double niter = 0.0;
        for (int i2 = 0; i2 < n2; ++i2) {
            // Alternate the i1 direction per row so that consecutively solved
            // samples are always adjacent, including across row boundaries.
            boolean forward = i2 % 2 == 0;
            int i1b = forward ? 0 : n1 - 1;
            int i1e = forward ? n1 : -1;
            int i1s = forward ? 1 : -1;
            for (int i1 = i1b; i1 != i1e; i1 += i1s) {
                for (int k = 0; k < m; ++k) {
                    for (int j = 0; j < m; ++j) {
                        rkjt[k][j] = rkj[k][j][i2][i1];
                    }
                    rk0t[k] = rk0[k][i2][i1];
                }
                niter += cgs.solve(rkjt, rk0t, at);
                for (int j = 0; j < m; ++j) {
                    a[j][i2][i1] = (float) at[j];
                }
            }
        }
        System.out.println("Average number of CG iterations = " + (niter / n1 / n2));
        return a;
    }

    /**
     * Adds sign*a[j][i2][i1]*f[i2-lag2[j]][i1-lag1[j]] to g[i2][i1] for every
     * lag j and every sample whose index into f is in bounds.
     */
    private static void accumulate(
        int[] lag1, int[] lag2, float[][][] a, float[][] f, float sign, float[][] g)
    {
        int n1 = f[0].length;
        int n2 = f.length;
        for (int j = 0; j < lag1.length; ++j) {
            int j1 = lag1[j];
            int j2 = lag2[j];
            float[][] aj = a[j];
            // Restrict (i1,i2) so that (i1-j1,i2-j2) stays inside f.
            int i1min = Math.max(0, j1);
            int i1max = Math.min(n1, n1 + j1);
            int i2min = Math.max(0, j2);
            int i2max = Math.min(n2, n2 + j2);
            for (int i2 = i2min; i2 < i2max; ++i2) {
                float[] gi = g[i2];
                float[] ai = aj[i2];
                float[] fi = f[i2 - j2];
                for (int i1 = i1min; i1 < i1max; ++i1) {
                    // With sign = +1 or -1 this is bit-identical to
                    // g + a*f or g - a*f respectively.
                    gi[i1] += sign * ai[i1] * fi[i1 - j1];
                }
            }
        }
    }

    /**
     * Lazily computes and caches local autocorrelations of one array, keyed by
     * lag. Lags (l1,l2) and (-l1,-l2) share a single cached array.
     * NOTE(review): this sharing assumes the local autocorrelation produced by
     * LocalCorrelationFilter is symmetric under lag negation; confirm.
     */
    private class R2Cache {
        private final float[][] _f;
        private final ArrayList<R2> _rlist = new ArrayList<R2>();

        R2Cache(float[][] f) {
            this._f = f;
        }

        /**
         * Returns the cached array for lag (l1,l2) or (-l1,-l2), computing
         * and caching it for (l1,l2) if neither is present.
         */
        float[][] get(int l1, int l2) {
            for (R2 r2 : this._rlist) {
                boolean same = l1 == r2.l1 && l2 == r2.l2;
                boolean opposite = -l1 == r2.l1 && -l2 == r2.l2;
                if (same || opposite) {
                    return r2.r;
                }
            }
            R2 r2 = new R2(l1, l2, this._f);
            this._rlist.add(r2);
            return r2.r;
        }
    }

    /** A local autocorrelation array r of f for one lag (l1,l2). */
    private class R2 {
        final int l1;
        final int l2;
        final float[][] r;

        R2(int l1, int l2, float[][] f) {
            int n1 = f[0].length;
            int n2 = f.length;
            this.l1 = l1;
            this.l2 = l2;
            this.r = new float[n2][n1];
            // Presumably stores in r the local correlation of f with itself
            // at lag (l1,l2); see LocalCorrelationFilter.apply.
            LocalPredictionFilter.this._lcf.apply(l1, l2, f, f, this.r);
            if (l1 == 0 && l2 == 0) {
                for (float[] ri : this.r) {
                    for (int i1 = 0; i1 < n1; ++i1) {
                        ri[i1] *= ZERO_LAG_SCALE;
                    }
                }
            }
        }
    }

    /**
     * Conjugate-gradient solver for small dense m-by-m systems A x = b.
     * CG assumes A is symmetric positive-definite; this is not checked.
     * Work arrays are allocated once and reused across calls.
     */
    private static class CgSolver {
        /** Relative tolerance: iterate while |r|^2 > TINY*|b|^2. */
        private static final double TINY = Math.ulp(1.0f);
        private final int m;
        private final int maxiter;
        private final double[] p;
        private final double[] q;
        private final double[] r;

        CgSolver(int m, int maxiter) {
            this.m = m;
            this.maxiter = maxiter;
            this.p = new double[m];
            this.q = new double[m];
            this.r = new double[m];
        }

        /**
         * Solves a x = b by conjugate gradients, starting from the values
         * already in x and overwriting x with the result. Prints a message if
         * the tolerance is not reached.
         * @return the number of CG iterations performed.
         */
        int solve(double[][] a, double[] b, double[] x) {
            double rr = 0.0;
            double bb = 0.0;
            for (int i = 0; i < m; ++i) {
                double bi = b[i];
                double ri = bi - dot(a[i], x);
                r[i] = ri;
                bb += bi * bi;
                rr += ri * ri;
            }
            double small = bb * TINY;
            double rp = 0.0;
            int niter = 0;
            for (; niter < maxiter && rr > small; ++niter) {
                if (niter == 0) {
                    System.arraycopy(r, 0, p, 0, m);
                } else {
                    double beta = rr / rp;
                    for (int i = 0; i < m; ++i) {
                        p[i] = r[i] + beta * p[i];
                    }
                }
                double pq = 0.0;
                for (int i = 0; i < m; ++i) {
                    q[i] = dot(a[i], p);
                    pq += p[i] * q[i];
                }
                double alpha = rr / pq;
                // CG breakdown (pq zero or NaN): stop instead of writing a
                // non-finite value into x. A NaN x would make every later
                // residual NaN, so later solves warm-started from x would
                // silently return NaN without reporting failure.
                if (Double.isNaN(alpha) || Double.isInfinite(alpha)) {
                    break;
                }
                rp = rr;
                rr = 0.0;
                for (int i = 0; i < m; ++i) {
                    x[i] += alpha * p[i];
                    r[i] -= alpha * q[i];
                    rr += r[i] * r[i];
                }
            }
            if (rr > small) {
                System.out.println("CgSolver.solve: failed to converge");
            }
            return niter;
        }

        /** Returns the sum over i < m of u[i]*v[i], accumulated in index order. */
        private double dot(double[] u, double[] v) {
            double s = 0.0;
            for (int i = 0; i < m; ++i) {
                s += u[i] * v[i];
            }
            return s;
        }
    }
}

