package ai.djl.nn.transformer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Activation;
import ai.djl.nn.transformer.BertBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Arrays;

/* loaded from: input_file:WEB-INF/lib/api-0.19.0.jar:ai/djl/nn/transformer/BertPretrainingBlock.class */
/**
 * Block that wires a {@link BertBlock} together with a {@link BertMaskedLanguageModelBlock} and a
 * {@link BertNextSentenceBlock}.
 *
 * <p>The block expects four inputs, in this order: {@code tokenIds}, {@code typeIds}, {@code
 * sequenceMasks} and {@code maskedIndices}. Its forward pass returns two arrays: the output of the
 * next-sentence block followed by the output of the masked-language-model block.
 */
public class BertPretrainingBlock extends AbstractBlock {
    private final BertBlock bertBlock;
    private final BertMaskedLanguageModelBlock mlBlock;
    // Registered in the field initializer so that it stays the first child block added, ahead of
    // the two blocks registered in the constructor body.
    private final BertNextSentenceBlock nsBlock =
            addChildBlock("BertNextSentenceBlock", new BertNextSentenceBlock());

    /**
     * Creates the block and registers its child blocks.
     *
     * @param builder the builder used to create the underlying {@link BertBlock}
     */
    public BertPretrainingBlock(BertBlock.Builder builder) {
        this.bertBlock = addChildBlock("Bert", builder.build());
        this.mlBlock =
                addChildBlock(
                        "BertMaskedLanguageModelBlock",
                        new BertMaskedLanguageModelBlock(this.bertBlock, Activation::gelu));
    }

    /**
     * Sets the input names and initializes the three child blocks.
     *
     * @param nDManager the manager handed to every child block's {@code initialize}
     * @param dataType the data type handed to every child block's {@code initialize}
     * @param shapeArr the input shapes, in the order given by the input names
     */
    @Override // ai.djl.nn.AbstractBaseBlock
    public void initializeChildBlocks(NDManager nDManager, DataType dataType, Shape... shapeArr) {
        this.inputNames = Arrays.asList("tokenIds", "typeIds", "sequenceMasks", "maskedIndices");
        this.bertBlock.initialize(nDManager, dataType, shapeArr);

        // The two outputs of the bert block feed the masked-language-model block (index 0)
        // and the next-sentence block (index 1).
        Shape[] bertOutputShapes = this.bertBlock.getOutputShapes(shapeArr);
        Shape embeddedSequenceShape = bertOutputShapes[0];
        Shape pooledOutputShape = bertOutputShapes[1];
        Shape embeddingTableShape =
                new Shape(
                        this.bertBlock.getTokenDictionarySize(),
                        this.bertBlock.getEmbeddingSize());

        // NOTE(review): inputNames places "maskedIndices" at index 3, yet shapeArr[2]
        // ("sequenceMasks") is what gets passed here - confirm this is intended.
        this.mlBlock.initialize(
                nDManager, dataType, embeddedSequenceShape, embeddingTableShape, shapeArr[2]);
        this.nsBlock.initialize(nDManager, dataType, pooledOutputShape);
    }

    /**
     * Runs the bert block on the first three inputs and then applies the next-sentence and
     * masked-language-model blocks to its outputs.
     *
     * @param parameterStore the parameter store passed to every child block
     * @param nDList the inputs: tokenIds, typeIds, sequenceMasks, maskedIndices
     * @param z flag forwarded unchanged to every child {@code forward} call and to the token
     *     embedding lookup (presumably the training flag - confirm against the base class)
     * @param pairList additional parameters; not used by this block
     * @return a list holding the next-sentence block output followed by the
     *     masked-language-model block output
     */
    @Override // ai.djl.nn.AbstractBaseBlock
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList nDList,
            boolean z,
            PairList<String, Object> pairList) {
        NDArray tokenIds = nDList.get(0);
        NDArray typeIds = nDList.get(1);
        NDArray sequenceMasks = nDList.get(2);
        NDArray maskedIndices = nDList.get(3);

        // The sub-manager is closed on every exit path; only the list passed to ret() is
        // handed back to the caller.
        try (NDManager scope = NDManager.subManagerOf(tokenIds)) {
            scope.tempAttachAll(nDList);

            NDList bertResult =
                    this.bertBlock.forward(
                            parameterStore, new NDList(tokenIds, typeIds, sequenceMasks), z);
            NDArray embeddedSequence = bertResult.get(0);
            NDArray pooledOutput = bertResult.get(1);

            NDArray nextSentenceOutput =
                    this.nsBlock
                            .forward(parameterStore, new NDList(pooledOutput), z)
                            .singletonOrThrow();

            // The masked-language-model block receives the token embedding parameter value
            // as its third input.
            NDArray embeddingTable =
                    this.bertBlock
                            .getTokenEmbedding()
                            .getValue(parameterStore, embeddedSequence.getDevice(), z);
            NDArray maskedLanguageOutput =
                    this.mlBlock
                            .forward(
                                    parameterStore,
                                    new NDList(embeddedSequence, maskedIndices, embeddingTable),
                                    z)
                            .singletonOrThrow();

            return scope.ret(new NDList(nextSentenceOutput, maskedLanguageOutput));
        }
    }

    /**
     * Computes the two output shapes from the input shapes.
     *
     * @param shapeArr the input shapes; indices 0 and 3 are read
     * @return {@code (d0, 2)} and {@code (d0, m1, tokenDictionarySize)}, where {@code d0} is the
     *     first dimension of input 0 and {@code m1} is the second dimension of input 3
     */
    @Override // ai.djl.nn.Block
    public Shape[] getOutputShapes(Shape[] shapeArr) {
        long leadingDim = shapeArr[0].get(0);
        long maskedCount = shapeArr[3].get(1);
        return new Shape[] {
            new Shape(leadingDim, 2),
            new Shape(leadingDim, maskedCount, this.bertBlock.getTokenDictionarySize())
        };
    }
}
