package org.nd4j.linalg.api.ops.impl.layers.recurrent;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.shade.guava.primitives.Booleans;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayerBp.class */
/**
 * SameDiff custom op registered under the name {@code lstmLayer_bp}.
 *
 * <p>The op always declares three outputs and up to five further optional ones, one for each of:
 * bias present, {@code maxTSLength} supplied, {@code yLast} supplied, {@code cLast} supplied and
 * peephole weights present (see {@link #getNumOutputs()}). All outputs share the data type of the
 * op's second input.
 *
 * <p>NOTE(review): an instance created through the no-arg constructor has neither configuration
 * nor weights; methods that read them will fail with a {@link NullPointerException} until those
 * fields are populated by some other means.
 */
public class LSTMLayerBp extends DynamicCustomOp {
    /** Op name handed to the superclass and reported by {@link #opName()}. */
    private static final String OP_NAME = "lstmLayer_bp";

    private LSTMLayerConfig configuration;
    private LSTMLayerWeights weights;
    private SDVariable cLast;
    private SDVariable yLast;
    private SDVariable maxTSLength;

    /**
     * Creates the op, wires its inputs and fills the integer, floating point and boolean argument
     * lists from the configuration.
     *
     * <p>Inputs are handed to {@code wrapFilterNull} in this order: {@code sDVariable}, weights,
     * recurrent weights, bias, {@code sDVariable4}, {@code sDVariable3}, {@code sDVariable2},
     * peephole weights, {@code sDVariable5}, {@code sDVariable6}, {@code sDVariable7}.
     *
     * @param sameDiff         SameDiff instance passed to the superclass; must not be null
     * @param sDVariable       the layer input ("x"); must not be null
     * @param sDVariable2      stored as {@code cLast}; may be null
     * @param sDVariable3      stored as {@code yLast}; may be null
     * @param sDVariable4      stored as {@code maxTSLength}; may be null
     * @param lSTMLayerWeights supplies the weights, recurrent weights, bias and peephole weights;
     *                         must not be null
     * @param lSTMLayerConfig  layer configuration; must not be null
     * @param sDVariable5      optional trailing input (presumably an incoming gradient - TODO confirm)
     * @param sDVariable6      optional trailing input (presumably an incoming gradient - TODO confirm)
     * @param sDVariable7      optional trailing input (presumably an incoming gradient - TODO confirm)
     * @throws NullPointerException  if any argument marked non-null is null
     * @throws IllegalStateException if the configuration requests none of the full sequence, the
     *                               last H or the last C
     */
    public LSTMLayerBp(@NonNull SameDiff sameDiff, @NonNull SDVariable sDVariable, SDVariable sDVariable2, SDVariable sDVariable3, SDVariable sDVariable4, @NonNull LSTMLayerWeights lSTMLayerWeights, @NonNull LSTMLayerConfig lSTMLayerConfig, SDVariable sDVariable5, SDVariable sDVariable6, SDVariable sDVariable7) {
        // Null checks live in checkedInputs(...) because nothing may precede super(...), and the
        // input array cannot be built without dereferencing the weights.
        super(OP_NAME, sameDiff, checkedInputs(sameDiff, sDVariable, sDVariable2, sDVariable3, sDVariable4, lSTMLayerWeights, lSTMLayerConfig, sDVariable5, sDVariable6, sDVariable7));
        this.configuration = lSTMLayerConfig;
        this.weights = lSTMLayerWeights;
        this.cLast = sDVariable2;
        this.yLast = sDVariable3;
        this.maxTSLength = sDVariable4;
        addIArgument(iArgs());
        addTArgument(tArgs());
        addBArgument(bArgs(lSTMLayerWeights, sDVariable4, sDVariable3, sDVariable2));
        Preconditions.checkState(this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(), "You have to specify at least one output you want to return. Use isRetLastC, isRetLast and isRetFullSequence  methods  in LSTMLayerConfig builder to specify them");
    }

    /** No-arg constructor; leaves every field unset. */
    public LSTMLayerBp() {
    }

    /**
     * Validates the mandatory constructor arguments and assembles the op's input array.
     *
     * <p>Runs as an argument expression of {@code super(...)}, so a null argument is reported with
     * a descriptive message before the weights are dereferenced and before the superclass
     * constructor is entered.
     */
    private static SDVariable[] checkedInputs(SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration, SDVariable extra1, SDVariable extra2, SDVariable extra3) {
        if (sameDiff == null) {
            throw new NullPointerException("sameDiff is marked non-null but is null");
        }
        if (x == null) {
            throw new NullPointerException("x is marked non-null but is null");
        }
        if (weights == null) {
            throw new NullPointerException("weights is marked non-null but is null");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        // Optional inputs may be null here; presumably wrapFilterNull drops them - TODO confirm.
        return (SDVariable[]) wrapFilterNull(x, weights.getWeights(), weights.getRWeights(), weights.getBias(), maxTSLength, yLast, cLast, weights.getPeepholeWeights(), extra1, extra2, extra3);
    }

    /**
     * Counts the outputs this op declares: three unconditional ones plus one for each optional
     * item that is present (bias, {@code maxTSLength}, {@code yLast}, {@code cLast}, peephole
     * weights). Shared by {@link #getNumOutputs()} and {@link #calculateOutputDataTypes(List)} so
     * the two cannot disagree.
     */
    private int countOutputs() {
        boolean[] present = new boolean[8];
        present[0] = true;
        present[1] = true;
        present[2] = true;
        present[3] = this.weights.hasBias();
        present[4] = this.maxTSLength != null;
        present[5] = this.yLast != null;
        present[6] = this.cLast != null;
        present[7] = this.weights.hasPH();
        return Booleans.countTrue(present);
    }

    /**
     * Returns one data type per declared output; every entry equals the type of input 1.
     *
     * @param list data types of the op inputs; index 1 is read
     * @return a new mutable list with {@link #getNumOutputs()}-many copies of that type
     * @throws IllegalStateException if the type at index 1 is not a floating point type
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        DataType dataType = list.get(1);
        Preconditions.checkState(dataType.isFPType(), "Input type 1 must be a floating point type, got %s", dataType);
        int outputCount = countOutputs();
        List<DataType> outputTypes = new ArrayList<>(outputCount);
        for (int i = 0; i < outputCount; i++) {
            outputTypes.add(dataType);
        }
        return outputTypes;
    }

    /** @return the op name, {@code lstmLayer_bp} */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ops.CustomOp
    public String opName() {
        return OP_NAME;
    }

    /** @return the configuration's property map, obtained via {@code toProperties(true, true)} */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Object> propertiesForFunction() {
        return this.configuration.toProperties(true, true);
    }

    /**
     * Builds the integer arguments: the ordinals of, in order, the data format, the direction
     * mode, the gate activation, the output activation and the cell activation.
     *
     * <p>Because ordinals are passed, the declaration order of those enums determines the values
     * sent to the op.
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.linalg.api.ops.CustomOp
    public long[] iArgs() {
        return new long[]{this.configuration.getLstmdataformat().ordinal(), this.configuration.getDirectionMode().ordinal(), this.configuration.getGateAct().ordinal(), this.configuration.getOutAct().ordinal(), this.configuration.getCellAct().ordinal()};
    }

    /** @return the floating point arguments: a single element holding the configured cell clip */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.linalg.api.ops.CustomOp
    public double[] tArgs() {
        return new double[]{this.configuration.getCellClip()};
    }

    /**
     * Builds the eight boolean arguments, in order: has bias, {@code t} present, {@code t2}
     * present, {@code t3} present, has peephole weights, return full sequence, return last H,
     * return last C.
     *
     * <p>The constructor passes {@code maxTSLength}, {@code yLast} and {@code cLast} as
     * {@code t}, {@code t2} and {@code t3} respectively.
     *
     * @param lSTMLayerWeights weights queried for bias and peephole presence
     * @param t                first optional value; only its nullness is used
     * @param t2               second optional value; only its nullness is used
     * @param t3               third optional value; only its nullness is used
     */
    protected <T> boolean[] bArgs(LSTMLayerWeights lSTMLayerWeights, T t, T t2, T t3) {
        boolean[] flags = new boolean[8];
        flags[0] = lSTMLayerWeights.hasBias();
        flags[1] = t != null;
        flags[2] = t2 != null;
        flags[3] = t3 != null;
        flags[4] = lSTMLayerWeights.hasPH();
        flags[5] = this.configuration.isRetFullSequence();
        flags[6] = this.configuration.isRetLastH();
        flags[7] = this.configuration.isRetLastC();
        return flags;
    }

    /** @return always {@code true} */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public boolean isConfigProperties() {
        return true;
    }

    /** @return the name of the field holding the configuration, {@code "configuration"} */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String configFieldName() {
        return "configuration";
    }

    /**
     * @return the number of declared outputs: three, plus one for each of bias,
     *         {@code maxTSLength}, {@code yLast}, {@code cLast} and peephole weights that is present
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public int getNumOutputs() {
        return countOutputs();
    }

    /** @return the layer configuration, or null if created via the no-arg constructor */
    public LSTMLayerConfig getConfiguration() {
        return this.configuration;
    }

    /** @return the layer weights, or null if created via the no-arg constructor */
    public LSTMLayerWeights getWeights() {
        return this.weights;
    }
}
