package ai.djl.modality.nlp;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

/* loaded from: input_file:WEB-INF/lib/api-0.19.0.jar:ai/djl/modality/nlp/EncoderDecoder.class */
/**
 * A block that chains an {@link Encoder} and a {@link Decoder}.
 *
 * <p>The block has two named inputs, {@code "encoderInput"} and {@code "decoderInput"}. Training
 * goes through {@link #forward(ParameterStore, NDList, NDList, PairList)}, which feeds the
 * encoder's states to the decoder alongside the labels. Prediction through {@link
 * #forwardInternal(ParameterStore, NDList, boolean, PairList)} is not implemented and always
 * throws.
 */
public class EncoderDecoder extends AbstractBlock {

    private static final byte VERSION = 1;

    protected Encoder encoder;
    protected Decoder decoder;

    /**
     * Creates an encoder-decoder block and registers both parts as child blocks.
     *
     * @param encoder the encoder, registered as the child block {@code "Encoder"}
     * @param decoder the decoder, registered as the child block {@code "Decoder"}
     */
    public EncoderDecoder(Encoder encoder, Decoder decoder) {
        super(VERSION);
        this.encoder = (Encoder) addChildBlock("Encoder", encoder);
        this.decoder = (Decoder) addChildBlock("Decoder", decoder);
        this.inputNames = Arrays.asList("encoderInput", "decoderInput");
    }

    /**
     * Describes the block's inputs by pairing the input names with the shapes recorded at
     * initialization.
     *
     * @return the input names paired with their shapes
     * @throws IllegalStateException if the block has not been initialized yet
     */
    @Override
    public PairList<String, Shape> describeInput() {
        if (!isInitialized()) {
            throw new IllegalStateException("Parameter of this block are not initialised");
        }
        return new PairList<>(this.inputNames, Arrays.asList(this.inputShapes));
    }

    /**
     * Label-free forward pass; not supported by this block.
     *
     * @param parameterStore the parameter store
     * @param inputs the inputs
     * @param training whether this is a training pass
     * @param params optional extra parameters
     * @return never returns normally
     * @throws IllegalArgumentException if {@code training} is {@code true}, because training
     *     requires the overload that takes labels
     * @throws UnsupportedOperationException if {@code training} is {@code false}, because
     *     prediction is not implemented
     */
    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        if (training) {
            throw new IllegalArgumentException("You must use forward with labels when training");
        }
        throw new UnsupportedOperationException(
                "EncoderDecoder prediction has not been implemented yet");
    }

    /**
     * Runs a training forward pass.
     *
     * <p>The encoder is run on {@code data}. The encoder's states (obtained via {@link
     * Encoder#getStates(NDList)}) are appended after the elements of {@code labels}, and the
     * resulting list is fed to the decoder. Both child blocks are always invoked with
     * {@code training = true}.
     *
     * @param parameterStore the parameter store
     * @param data the encoder inputs
     * @param labels the decoder inputs; this list is not modified
     * @param params optional extra parameters passed to both child blocks
     * @return the decoder's output
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore,
            NDList data,
            NDList labels,
            PairList<String, Object> params) {
        NDList encoderOutputs = this.encoder.forward(parameterStore, data, true, params);
        // Build the decoder input in a fresh list: appending to `labels` directly would silently
        // mutate the caller's argument, which the caller may still use after this call.
        NDList decoderInputs = new NDList();
        decoderInputs.addAll(labels);
        decoderInputs.addAll(this.encoder.getStates(encoderOutputs));
        return this.decoder.forward(parameterStore, decoderInputs, true, params);
    }

    /**
     * Initializes the block and both child blocks.
     *
     * @param manager the manager used to create parameters
     * @param dataType the parameter data type
     * @param inputShapes the input shapes; index 0 is used for the encoder and index 1 for the
     *     decoder, so at least two shapes are expected
     */
    @Override
    public void initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
        beforeInitialize(inputShapes);
        this.encoder.initialize(manager, dataType, inputShapes[0]);
        this.decoder.initialize(manager, dataType, inputShapes[1]);
    }

    /**
     * Returns the output shapes, which are those of the decoder for the decoder input shape.
     *
     * @param inputShapes the input shapes; only index 1 (the decoder input) is used
     * @return the decoder's output shapes
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return this.decoder.getOutputShapes(new Shape[] {inputShapes[1]});
    }

    /**
     * Writes the encoder's parameters followed by the decoder's parameters.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        this.encoder.saveParameters(os);
        this.decoder.saveParameters(os);
    }

    /**
     * Reads the encoder's parameters followed by the decoder's parameters, in the same order
     * that {@link #saveParameters(DataOutputStream)} writes them.
     *
     * @param manager the manager used to create the loaded arrays
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the stored parameters are malformed
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        this.encoder.loadParameters(manager, is);
        this.decoder.loadParameters(manager, is);
    }
}
