package deepboof.impl.backward.standard;

import deepboof.backward.DSpatialBatchNorm;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/* loaded from: classes2.dex */
/**
 * Spatial batch normalization on {@link Tensor_F64} with both forward and backward passes.
 *
 * <p>Statistics are kept per channel (dimension 1 of the input) and are accumulated over
 * dimension 0 (the mini-batch) and over every dimension from index 2 onward. All loops walk the
 * input as a contiguous {@code (batch, channel, pixel)} sequence starting at the tensor's
 * {@code startIndex}.</p>
 *
 * <p>When gamma/beta are in use, {@code params} stores them interleaved per channel:
 * {@code [gamma_0, beta_0, gamma_1, beta_1, ...]}, beginning at {@code params.startIndex}.</p>
 *
 * <p>The work tensors ({@code tensorXhat}, {@code tensorDiffX}, {@code tensorDXhat},
 * {@code tensorMean}, {@code tensorStd}, ...) are inherited from the base class and are always
 * indexed from element 0 here.</p>
 */
public class DSpatialBatchNorm_F64 extends BaseDBatchNorm_F64 implements DSpatialBatchNorm<Tensor_F64> {
    /** Number of elements contributing to each channel's statistics: {@code miniBatchSize * numPixels}. */
    double M;
    /** Divisor used when computing the variance: {@code M - 1}. */
    double M_var;
    /** Length of dimension 1 of the most recent input passed to {@link #_forward}. */
    int numChannels;
    /** Value of {@code TensorOps.outerLength(shape, 2)} for the most recent input passed to {@link #_forward}. */
    int numPixels;

    /**
     * @param requiresGammaBeta forwarded unchanged to the base-class constructor; presumably
     *                          selects whether gamma/beta parameters are used (confirm against
     *                          {@code BaseDBatchNorm_F64})
     */
    public DSpatialBatchNorm_F64(boolean requiresGammaBeta) {
        super(requiresGammaBeta);
    }

    /**
     * Writes {@code xhat * gamma + beta} into {@code output} for every element.
     *
     * @param output destination tensor, written sequentially from its {@code startIndex}
     */
    private void applyGammaBeta(Tensor_F64 output) {
        final double[] xhat = this.tensorXhat.d;
        int outIdx = output.startIndex;
        int xhatIdx = 0;
        for (int batch = 0; batch < this.miniBatchSize; batch++) {
            for (int channel = 0; channel < this.numChannels; channel++) {
                // Honor params.startIndex so this agrees with forwardEvaluate().
                final int paramIdx = this.params.startIndex + channel * 2;
                final double gamma = this.params.d[paramIdx];
                final double beta = this.params.d[paramIdx + 1];
                for (int pixel = 0; pixel < this.numPixels; pixel++) {
                    output.d[outIdx++] = (xhat[xhatIdx++] * gamma) + beta;
                }
            }
        }
    }

    /**
     * Computes the per-channel mean and standard deviation of {@code input} and fills
     * {@code tensorDiffX} (input minus mean) and {@code tensorXhat} (normalized input).
     *
     * @param input tensor to gather statistics from, read sequentially from its {@code startIndex}
     */
    private void computeStatisticsAndNormalize(Tensor_F64 input) {
        this.tensorMean.zero();
        this.tensorStd.zero();
        this.tensorXhat.zero();

        final double[] mean = this.tensorMean.d;
        final double[] std = this.tensorStd.d;
        final double[] diffX = this.tensorDiffX.d;
        final double[] xhat = this.tensorXhat.d;

        // Pass 1: per-channel sum, then divide by M to obtain the mean.
        int inIdx = input.startIndex;
        for (int batch = 0; batch < this.miniBatchSize; batch++) {
            for (int channel = 0; channel < this.numChannels; channel++) {
                double sum = 0.0d;
                for (int pixel = 0; pixel < this.numPixels; pixel++) {
                    sum += input.d[inIdx++];
                }
                mean[channel] += sum;
            }
        }
        for (int channel = 0; channel < this.numChannels; channel++) {
            mean[channel] /= this.M;
        }

        // Pass 2: store (x - mean) and accumulate the per-channel sum of squared differences.
        inIdx = input.startIndex;
        int diffIdx = 0;
        for (int batch = 0; batch < this.miniBatchSize; batch++) {
            for (int channel = 0; channel < this.numChannels; channel++) {
                final double channelMean = mean[channel];
                double sumSq = 0.0d;
                for (int pixel = 0; pixel < this.numPixels; pixel++) {
                    final double diff = input.d[inIdx++] - channelMean;
                    diffX[diffIdx++] = diff;
                    sumSq += diff * diff;
                }
                std[channel] += sumSq;
            }
        }
        // Variance uses the M - 1 divisor; EPS is added before the square root.
        for (int channel = 0; channel < this.numChannels; channel++) {
            std[channel] = Math.sqrt((std[channel] / this.M_var) + this.EPS);
        }

        // Pass 3: normalize.
        int idx = 0;
        for (int batch = 0; batch < this.miniBatchSize; batch++) {
            for (int channel = 0; channel < this.numChannels; channel++) {
                final double channelStd = std[channel];
                for (int pixel = 0; pixel < this.numPixels; pixel++) {
                    xhat[idx] = diffX[idx] / channelStd;
                    idx++;
                }
            }
        }
    }

    /**
     * Forward pass used in learning mode: normalizes with statistics computed from {@code input}
     * itself, then either applies gamma/beta or copies the normalized values to {@code output}.
     */
    private void forwardLearning(Tensor_F64 input, Tensor_F64 output) {
        computeStatisticsAndNormalize(input);
        if (this.requiresGammaBeta) {
            applyGammaBeta(output);
        } else {
            output.setTo(this.tensorXhat);
        }
    }

    /**
     * Fills {@code tensorDMean} per channel with
     * {@code -sum(dXhat) / std - 2 * dVar * sum(diffX) / M_var}.
     * Uses {@code tensorTmp} as scratch space for {@code sum(diffX)}; expects
     * {@code tensorDVar} to have been computed already.
     */
    private void partialMean() {
        this.tensorDMean.zero();
        this.tensorTmp.zero();

        final double[] diffX = this.tensorDiffX.d;
        final double[] dXhat = this.tensorDXhat.d;
        final double[] sumDiff = this.tensorTmp.d;
        final double[] dMean = this.tensorDMean.d;

        int idx = 0;
        for (int batch = 0; batch < this.miniBatchSize; batch++) {
            for (int channel = 0; channel < this.numChannels; channel++) {
                double diffTotal = 0.0d;
                double negDXhatTotal = 0.0d;
                for (int pixel = 0; pixel < this.numPixels; pixel++) {
                    diffTotal += diffX[idx];
                    negDXhatTotal -= dXhat[idx];
                    idx++;
                }
                sumDiff[channel] += diffTotal;
                dMean[channel] += negDXhatTotal;
            }
        }
        for (int channel = 0; channel < this.numChannels; channel++) {
            dMean[channel] /= this.tensorStd.d[channel];
            dMean[channel] -= ((this.tensorDVar.d[channel] * 2.0d) * sumDiff[channel]) / this.M_var;
        }
    }

    /**
     * Accumulates the gradient of the gamma/beta parameters into {@code gradientParams},
     * interleaved per channel as {@code [sum(xhat * dout), sum(dout)]}.
     *
     * @param gradientParams destination; zeroed first, then written from its {@code startIndex}
     * @param dout           incoming gradient, read sequentially from its {@code startIndex}
     */
    private void partialParameters(Tensor_F64 gradientParams, Tensor_F64 dout) {
        gradientParams.zero();

        final double[] xhat = this.tensorXhat.d;
        final double[] grad = gradientParams.d;
        int doutIdx = dout.startIndex;
        int xhatIdx = 0;
        for (int batch = 0; batch < this.miniBatchSize; batch++) {
            // Restart at the first channel's slot for every mini-batch; honor startIndex so the
            // layout matches how forwardEvaluate() reads params.
            int gradIdx = gradientParams.startIndex;
            for (int channel = 0; channel < this.numChannels; channel++) {
                double gammaTotal = 0.0d;
                double betaTotal = 0.0d;
                for (int pixel = 0; pixel < this.numPixels; pixel++) {
                    final double g = dout.d[doutIdx++];
                    gammaTotal += xhat[xhatIdx++] * g;
                    betaTotal += g;
                }
                grad[gradIdx++] += gammaTotal;
                grad[gradIdx++] += betaTotal;
            }
        }
    }

    /**
     * Fills {@code tensorDVar} per channel with {@code sum(dXhat * diffX) / (-2 * std^3)}.
     */
    private void partialVariance() {
        this.tensorDVar.zero();

        final double[] dXhat = this.tensorDXhat.d;
        final double[] diffX = this.tensorDiffX.d;
        final double[] dVar = this.tensorDVar.d;

        int idx = 0;
        for (int batch = 0; batch < this.miniBatchSize; batch++) {
            for (int channel = 0; channel < this.numChannels; channel++) {
                double total = 0.0d;
                for (int pixel = 0; pixel < this.numPixels; pixel++) {
                    total += dXhat[idx] * diffX[idx];
                    idx++;
                }
                dVar[channel] += total;
            }
        }
        for (int channel = 0; channel < this.numChannels; channel++) {
            final double std = this.tensorStd.d[channel];
            dVar[channel] /= ((std * std) * std) * (-2.0d);
        }
    }

    /**
     * Writes the gradient with respect to the input:
     * {@code dXhat / std + 2 * dVar * diffX / M_var + dMean / M}.
     *
     * @param gradientInput destination tensor, written sequentially from its {@code startIndex}
     */
    private void partialX(Tensor_F64 gradientInput) {
        final double[] dXhat = this.tensorDXhat.d;
        final double[] diffX = this.tensorDiffX.d;

        int outIdx = gradientInput.startIndex;
        int idx = 0;
        for (int batch = 0; batch < this.miniBatchSize; batch++) {
            for (int channel = 0; channel < this.numChannels; channel++) {
                final double std = this.tensorStd.d[channel];
                final double dVar = this.tensorDVar.d[channel];
                final double dMean = this.tensorDMean.d[channel];
                for (int pixel = 0; pixel < this.numPixels; pixel++) {
                    gradientInput.d[outIdx++] = (dXhat[idx] / std)
                            + (((2.0d * dVar) * diffX[idx]) / this.M_var)
                            + (dMean / this.M);
                    idx++;
                }
            }
        }
    }

    /**
     * Fills {@code tensorDXhat} with {@code dout * gamma}.
     *
     * @param dout incoming gradient, read sequentially from its {@code startIndex}
     */
    private void partialXHat(Tensor_F64 dout) {
        final double[] dXhat = this.tensorDXhat.d;

        int doutIdx = dout.startIndex;
        int idx = 0;
        for (int batch = 0; batch < this.miniBatchSize; batch++) {
            for (int channel = 0; channel < this.numChannels; channel++) {
                // Honor params.startIndex so this agrees with forwardEvaluate().
                final double gamma = this.params.d[this.params.startIndex + channel * 2];
                for (int pixel = 0; pixel < this.numPixels; pixel++) {
                    dXhat[idx++] = dout.d[doutIdx++] * gamma;
                }
            }
        }
    }

    /**
     * Backward pass. Relies on {@code tensorXhat}, {@code tensorDiffX} and {@code tensorStd}
     * as left behind by the learning-mode forward pass.
     *
     * @param input              forward input; only its shape is used, to size {@code tensorDXhat}
     * @param dout               incoming gradient
     * @param gradientInput      receives the gradient with respect to the input
     * @param gradientParameters when gamma/beta are in use, element 0 receives their gradient;
     *                           otherwise the list is not read
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // deepboof.impl.backward.standard.BaseDFunction
    public void _backwards(Tensor_F64 input, Tensor_F64 dout, Tensor_F64 gradientInput,
                           List<Tensor_F64> gradientParameters) {
        this.tensorDXhat.reshape(input.shape);
        if (this.requiresGammaBeta) {
            partialXHat(dout);
        } else {
            this.tensorDXhat.setTo(dout);
        }
        // Order matters: partialMean() reads tensorDVar, partialX() reads both.
        partialVariance();
        partialMean();
        partialX(gradientInput);
        if (this.requiresGammaBeta) {
            partialParameters(gradientParameters.get(0), dout);
        }
    }

    /**
     * Forward pass. Caches the channel count, per-channel element count and divisors, then
     * dispatches on {@code learningMode}.
     *
     * @param input  tensor to normalize; dimension 0 must be longer than 1
     * @param output receives the normalized result
     * @throws IllegalArgumentException if {@code input.length(0) <= 1}
     */
    @Override // deepboof.impl.forward.standard.BaseFunction
    public void _forward(Tensor_F64 input, Tensor_F64 output) {
        if (input.length(0) <= 1) {
            throw new IllegalArgumentException("There must be more than 1 minibatch");
        }
        this.tensorDiffX.reshape(input.shape);
        this.tensorXhat.reshape(input.shape);

        this.numChannels = input.length(1);
        this.numPixels = TensorOps.outerLength(input.shape, 2);
        // Multiply in double precision so the product cannot wrap around.
        this.M = (double) this.miniBatchSize * this.numPixels;
        this.M_var = this.M - 1.0d;

        if (this.learningMode) {
            forwardLearning(input, output);
        } else {
            forwardEvaluate(input, output);
        }
    }

    /**
     * Returns a one-element shape holding the first entry of {@code shapeInput}.
     * NOTE(review): presumably {@code shapeInput} excludes the mini-batch dimension, making
     * entry 0 the channel count; confirm against the base class.
     */
    @Override // deepboof.impl.backward.standard.BaseDBatchNorm_F64
    protected int[] createShapeVariables(int[] shapeInput) {
        return new int[]{shapeInput[0]};
    }

    /**
     * Forward pass used outside learning mode: normalizes {@code input} with the values currently
     * held in {@code tensorMean} and {@code tensorStd}, without recomputing them.
     *
     * <p>The per-channel element count is taken from {@code TensorOps.outerLength(shape, 2)},
     * the same way {@link #_forward} derives it, so the two modes agree for inputs that are
     * not exactly 4-dimensional.</p>
     *
     * @param input  tensor to normalize, read sequentially from its {@code startIndex}
     * @param output destination, written sequentially from its {@code startIndex}
     */
    public void forwardEvaluate(Tensor_F64 input, Tensor_F64 output) {
        final int channels = input.length(1);
        final int pixelsPerChannel = TensorOps.outerLength(input.shape, 2);

        int inIdx = input.startIndex;
        int outIdx = output.startIndex;

        if (!hasGammaBeta()) {
            for (int batch = 0; batch < this.miniBatchSize; batch++) {
                for (int channel = 0; channel < channels; channel++) {
                    final double mean = this.tensorMean.d[channel];
                    final double std = this.tensorStd.d[channel];
                    final int end = inIdx + pixelsPerChannel;
                    while (inIdx < end) {
                        output.d[outIdx++] = (input.d[inIdx++] - mean) / std;
                    }
                }
            }
            return;
        }

        for (int batch = 0; batch < this.miniBatchSize; batch++) {
            int paramIdx = this.params.startIndex;
            for (int channel = 0; channel < channels; channel++) {
                final double mean = this.tensorMean.d[channel];
                final double std = this.tensorStd.d[channel];
                final double gamma = this.params.d[paramIdx++];
                final double beta = this.params.d[paramIdx++];
                // Constant for the whole channel, so compute it once.
                final double scale = gamma / std;
                final int end = inIdx + pixelsPerChannel;
                while (inIdx < end) {
                    output.d[outIdx++] = ((input.d[inIdx++] - mean) * scale) + beta;
                }
            }
        }
    }
}
