package org.nd4j.linalg.api.ops.impl.reduce.custom;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/reduce/custom/BatchMmul.class */
/**
 * Batched matrix multiplication backed by the native {@code batched_gemm} custom op.
 *
 * <p>For every batch index {@code i} the op multiplies {@code op(A_i)} by {@code op(B_i)}, where
 * {@code op} optionally transposes its operand (controlled by {@link #transposeA} /
 * {@link #transposeB}). One output is produced per (A_i, B_i) pair.
 *
 * <p>Input layout passed to the native op: {@code [alpha, beta, A_0..A_{b-1}, B_0..B_{b-1}]}, where
 * {@code alpha} is a length-{@code b} vector of ones and {@code beta} a length-{@code b} vector of
 * zeros (presumably the per-batch GEMM scaling factors; TODO confirm against libnd4j).
 */
public class BatchMmul extends DynamicCustomOp {
    /** 1 if the A matrices are transposed before multiplication, else 0. */
    protected int transposeA;
    /** 1 if the B matrices are transposed before multiplication, else 0. */
    protected int transposeB;
    /** Number of (A, B) pairs, and therefore number of outputs. */
    protected int batchSize;
    /** Row count of op(A). */
    protected int M;
    /** Column count of op(A), i.e. the shared inner dimension. */
    protected int N;
    /** Column count of op(B). */
    protected int K;

    /**
     * Creates a SameDiff batch mmul from separate A and B matrix arrays.
     *
     * @param sameDiff   owning SameDiff instance
     * @param matricesA  first operands, one per batch entry
     * @param matricesB  second operands, one per batch entry
     * @param transposeA whether to transpose each A matrix
     * @param transposeB whether to transpose each B matrix
     */
    public BatchMmul(SameDiff sameDiff, SDVariable[] matricesA, SDVariable[] matricesB, boolean transposeA, boolean transposeB) {
        this(sameDiff, (SDVariable[]) ArrayUtils.addAll(matricesA, matricesB), transposeA, transposeB);
    }

    /**
     * Creates a SameDiff batch mmul from a single array holding all A matrices followed by all B matrices.
     *
     * @param sameDiff   owning SameDiff instance
     * @param matrices   {@code [A_0..A_{b-1}, B_0..B_{b-1}]}; length must be even and non-zero
     * @param transposeA whether to transpose each A matrix
     * @param transposeB whether to transpose each B matrix
     * @throws IllegalStateException if the matrix count is invalid or the shapes within A or B differ
     */
    public BatchMmul(SameDiff sameDiff, SDVariable[] matrices, boolean transposeA, boolean transposeB) {
        // Validation happens inside the helper so that it runs before any array indexing.
        super((String) null, sameDiff, withScalingVariables(sameDiff, matrices));
        this.batchSize = matrices.length / 2;

        long[] shapeA = matrices[0].getShape();
        for (int i = 1; i < this.batchSize; i++) {
            Preconditions.checkState(Arrays.equals(shapeA, matrices[i].getShape()),
                    "All A matrices must have the same shape: expected %s, got %s at index %s",
                    shapeA, matrices[i].getShape(), i);
        }
        long[] shapeB = matrices[(2 * this.batchSize) - 1].getShape();
        for (int i = this.batchSize; i < 2 * this.batchSize; i++) {
            Preconditions.checkState(Arrays.equals(shapeB, matrices[i].getShape()),
                    "All B matrices must have the same shape: expected %s, got %s at index %s",
                    shapeB, matrices[i].getShape(), i);
        }

        this.transposeA = transposeA ? 1 : 0;
        this.transposeB = transposeB ? 1 : 0;
        this.M = transposeA ? (int) shapeA[1] : (int) shapeA[0];
        this.N = transposeA ? (int) shapeA[0] : (int) shapeA[1];
        this.K = transposeB ? (int) shapeB[0] : (int) shapeB[1];
        addArgs();
    }

    /**
     * Creates an eager (INDArray-based) batch mmul.
     *
     * <p>The same alpha (ones) and beta (zeros) inputs that the SameDiff constructor adds are prepended
     * here as well, so both construction paths feed the native op the same input layout.
     *
     * @param matricesA  first operands, one per batch entry
     * @param matricesB  second operands, one per batch entry; must have the same count as {@code matricesA}
     * @param transposeA whether to transpose each A matrix
     * @param transposeB whether to transpose each B matrix
     * @throws IllegalArgumentException if either array is null/empty or their lengths differ
     */
    public BatchMmul(INDArray[] matricesA, INDArray[] matricesB, boolean transposeA, boolean transposeB) {
        super(withScalingArrays(matricesA, matricesB), (INDArray[]) null);
        this.batchSize = matricesA.length;
        this.transposeA = transposeA ? 1 : 0;
        this.transposeB = transposeB ? 1 : 0;
        long[] shapeA = matricesA[0].shape();
        long[] shapeB = matricesB[0].shape();
        this.M = transposeA ? (int) shapeA[1] : (int) shapeA[0];
        this.N = transposeA ? (int) shapeA[0] : (int) shapeA[1];
        this.K = transposeB ? (int) shapeB[0] : (int) shapeB[1];
        addArgs();
    }

    public BatchMmul() {
    }

    /** Validates the SameDiff matrix list and prepends the alpha (ones) and beta (zeros) variables. */
    private static SDVariable[] withScalingVariables(SameDiff sameDiff, SDVariable[] matrices) {
        Preconditions.checkState(matrices != null && matrices.length >= 2,
                "At least two matrices must be provided to batch mmul.");
        Preconditions.checkState(matrices.length % 2 == 0,
                "The number of provided matrices needs to be divisible by two.");
        int batch = matrices.length / 2;
        SDVariable alphas = sameDiff.var(Nd4j.ones(matrices[0].dataType(), batch));
        SDVariable betas = sameDiff.var(Nd4j.zeros(matrices[1].dataType(), batch));
        return ArrayUtils.addAll(new SDVariable[]{alphas, betas}, matrices);
    }

    /** Validates the eager A/B arrays and prepends alpha (ones) and beta (zeros) arrays. */
    private static INDArray[] withScalingArrays(INDArray[] matricesA, INDArray[] matricesB) {
        Preconditions.checkArgument(matricesA != null && matricesB != null,
                "Batch mmul matrix arrays must not be null");
        Preconditions.checkArgument(matricesA.length > 0 && matricesA.length == matricesB.length,
                "Batch mmul requires the same non-zero number of A and B matrices: got %s and %s",
                matricesA.length, matricesB.length);
        DataType dataType = matricesA[0].dataType();
        INDArray alphas = Nd4j.ones(dataType, matricesA.length);
        INDArray betas = Nd4j.zeros(dataType, matricesA.length);
        return ArrayUtils.addAll(new INDArray[]{alphas, betas}, ArrayUtils.addAll(matricesA, matricesB));
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public int getNumOutputs() {
        return this.batchSize;
    }

    /**
     * Adds the integer arguments: transpose flags, the dimensions passed as (M, K, N), the same three
     * again (as leading dimensions), and the batch size. N and K are deliberately passed in this
     * order; presumably libnd4j names them the other way round — TODO confirm.
     */
    public void addArgs() {
        addIArgument(this.transposeA, this.transposeB, this.M, this.K, this.N, this.M, this.K, this.N, this.batchSize);
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ops.CustomOp
    public String opName() {
        return "batched_gemm";
    }

    /**
     * Back-propagates through {@code C_i = op(A_i) x op(B_i)}.
     *
     * <p>With {@code a = op(A)}, {@code b = op(B)} and upstream gradient {@code G}, the gradients are
     * {@code da = G·bᵀ} and {@code db = aᵀ·G}. They are mapped back through the transposes as follows:
     * {@code dA = G·op'(B)} (or {@code op''(B)·Gᵀ} when A is transposed), and {@code dB = op'(A)·G}
     * (or {@code Gᵀ·op''(A)} when B is transposed).
     *
     * <p>One gradient is returned per input, in input order {@code [alpha, beta, A..., B...]}. Alpha
     * and beta are treated as fixed scaling constants and receive zero gradients. The matrix gradients
     * assume alpha == 1, which is what the constructors initialise it to.
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<SDVariable> doDiff(List<SDVariable> grads) {
        SDVariable[] eps = grads.toArray(new SDVariable[0]);
        SDVariable[] inputs = args();
        // Skip alpha/beta at indices 0 and 1.
        int batch = (inputs.length - 2) / 2;
        SDVariable[] matricesA = Arrays.copyOfRange(inputs, 2, 2 + batch);
        SDVariable[] matricesB = Arrays.copyOfRange(inputs, 2 + batch, 2 + 2 * batch);
        boolean tA = this.transposeA == 1;
        boolean tB = this.transposeB == 1;

        SDVariable[] gradA = tA
                ? this.sameDiff.batchMmul(matricesB, eps, tB, true)    // dA = (G·bᵀ)ᵀ = b·Gᵀ
                : this.sameDiff.batchMmul(eps, matricesB, false, !tB); // dA = G·bᵀ
        SDVariable[] gradB = tB
                ? this.sameDiff.batchMmul(eps, matricesA, true, tA)    // dB = (aᵀ·G)ᵀ = Gᵀ·a
                : this.sameDiff.batchMmul(matricesA, eps, !tA, false); // dB = aᵀ·G

        List<SDVariable> result = new ArrayList<>(inputs.length);
        result.add(this.sameDiff.zerosLike(inputs[0]));
        result.add(this.sameDiff.zerosLike(inputs[1]));
        Collections.addAll(result, gradA);
        Collections.addAll(result, gradB);
        return result;
    }

    /**
     * Checks that every input is a floating point type and returns one output type per A matrix.
     * Input layout is assumed to be {@code [alpha, beta, A_0..A_{b-1}, B_0..B_{b-1}]}.
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        for (DataType dataType : dataTypes) {
            Preconditions.checkState(dataType.isFPType(), "Inputs to batch mmul op must all be a floating point type: got %s", dataTypes);
        }
        int numOutputs = Math.max((dataTypes.size() - 2) / 2, 0);
        List<DataType> out = new ArrayList<>(numOutputs);
        for (int i = 0; i < numOutputs; i++) {
            // Output i is produced from A_i, which sits after the two scaling inputs.
            out.add(dataTypes.get(2 + i));
        }
        return out;
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof BatchMmul)) {
            return false;
        }
        BatchMmul other = (BatchMmul) obj;
        return other.canEqual(this) && this.transposeA == other.transposeA && this.transposeB == other.transposeB
                && this.batchSize == other.batchSize && this.M == other.M && this.N == other.N && this.K == other.K;
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof BatchMmul;
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public int hashCode() {
        return (((((((((((1 * 59) + this.transposeA) * 59) + this.transposeB) * 59) + this.batchSize) * 59) + this.M) * 59) + this.N) * 59) + this.K;
    }
}
