package ai.djl.basicmodelzoo.cv.object_detection.ssd;

import ai.djl.MalformedModelException;
import ai.djl.modality.cv.MultiBoxPrior;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * A Single Shot Detection (SSD) block.
 *
 * <p>The block runs a chain of feature blocks. After each feature block it generates anchor boxes
 * with a {@link MultiBoxPrior} and applies a class-prediction convolution and an
 * anchor-prediction convolution to that feature map. Instances are created via {@link #builder()}.
 */
public final class SingleShotDetection extends AbstractBlock {

    /** Encoding version written by this block; {@code loadMetadata} also accepts version 1. */
    private static final byte VERSION = 2;

    private final List<Block> features;
    private final List<Block> classPredictionBlocks;
    private final List<Block> anchorPredictionBlocks;
    private final List<MultiBoxPrior> multiBoxPriors;
    private final int numClasses;

    /** Builder for {@link SingleShotDetection}. Calling {@link #build()} does not modify it. */
    public static class Builder {
        private Block network;
        private List<Block> features;
        private List<List<Float>> sizes;
        private List<List<Float>> ratios;
        private int numClasses;
        private int numFeatures = -1;
        private boolean globalPool = true;

        Builder() {
        }

        /**
         * Sets the anchor sizes, one list per feature block (including the global-pool block when
         * enabled).
         *
         * @param list the per-feature anchor sizes
         * @return this builder
         */
        public Builder setSizes(List<List<Float>> list) {
            this.sizes = list;
            return this;
        }

        /**
         * Sets the anchor aspect ratios, one list per feature block (including the global-pool
         * block when enabled).
         *
         * @param list the per-feature anchor ratios
         * @return this builder
         */
        public Builder setRatios(List<List<Float>> list) {
            this.ratios = list;
            return this;
        }

        /**
         * Sets the number of object classes; the class head predicts {@code numClasses + 1}
         * scores per anchor.
         *
         * @param i the number of classes
         * @return this builder
         */
        public Builder setNumClasses(int i) {
            this.numClasses = i;
            return this;
        }

        /**
         * Sets the base network used as the first feature block when features are generated from
         * {@link #setNumFeatures(int)}.
         *
         * @param block the base network
         * @return this builder
         */
        public Builder setBaseNetwork(Block block) {
            this.network = block;
            return this;
        }

        /**
         * Sets how many down-sampling blocks follow the base network. This is ignored when
         * {@link #optFeatures(List)} is set.
         *
         * @param i the number of down-sampling feature blocks
         * @return this builder
         */
        public Builder setNumFeatures(int i) {
            this.numFeatures = i;
            return this;
        }

        /**
         * Sets an explicit list of feature blocks. The list is copied by {@link #build()} and is
         * not modified.
         *
         * @param list the feature blocks
         * @return this builder
         */
        public Builder optFeatures(List<Block> list) {
            this.features = list;
            return this;
        }

        /**
         * Sets whether a global-average-pool feature block is appended after the other features.
         *
         * @param z {@code true} (the default) to append the global pool block
         * @return this builder
         */
        public Builder optGlobalPool(boolean z) {
            this.globalPool = z;
            return this;
        }

        /**
         * Builds a new {@link SingleShotDetection}.
         *
         * <p>All generated blocks are collected in fresh lists, so repeated calls yield
         * independent blocks. A caller-supplied feature list is never mutated.
         *
         * @return the built block
         * @throws IllegalArgumentException if neither features nor numFeatures is set, if the base
         *     network is missing when needed, or if sizes/ratios are missing, mis-sized or empty
         */
        public SingleShotDetection build() {
            if (features == null && numFeatures < 0) {
                throw new IllegalArgumentException("Either numFeatures or features must be set");
            }
            if (sizes == null || ratios == null) {
                throw new IllegalArgumentException("Sizes and ratios must be set");
            }

            List<Block> allFeatures = new ArrayList<>();
            if (features == null) {
                if (network == null) {
                    throw new IllegalArgumentException(
                            "Base network must be set when features are not provided");
                }
                allFeatures.add(network);
                for (int i = 0; i < numFeatures; i++) {
                    allFeatures.add(getDownSamplingBlock(128));
                }
            } else {
                allFeatures.addAll(features);
            }
            if (globalPool) {
                // Pool each channel to a single value and restore two trailing unit axes.
                allFeatures.add(
                        LambdaBlock.singleton(
                                array -> {
                                    NDArray pooled = Pool.globalAvgPool2d(array);
                                    return pooled.reshape(pooled.getShape().add(1, 1));
                                }));
            }

            int numScales = allFeatures.size();
            if (sizes.size() != ratios.size() || sizes.size() != numScales) {
                throw new IllegalArgumentException("Sizes and ratios must be of size: " + numScales);
            }

            List<Block> classBlocks = new ArrayList<>(numScales);
            List<Block> anchorBlocks = new ArrayList<>(numScales);
            List<MultiBoxPrior> priors = new ArrayList<>(numScales);
            for (int i = 0; i < numScales; i++) {
                List<Float> scaleSizes = sizes.get(i);
                List<Float> scaleRatios = ratios.get(i);
                if (scaleSizes.isEmpty() || scaleRatios.isEmpty()) {
                    throw new IllegalArgumentException(
                            "Sizes and ratios must be non-empty for feature index " + i);
                }
                // Anchors per location: every size with the first ratio, plus the first size
                // with each remaining ratio.
                int numAnchors = scaleSizes.size() + scaleRatios.size() - 1;
                classBlocks.add(getClassPredictionBlock(numAnchors, numClasses));
                anchorBlocks.add(getAnchorPredictionBlock(numAnchors));
                priors.add(MultiBoxPrior.builder().setSizes(scaleSizes).setRatios(scaleRatios).build());
            }
            return new SingleShotDetection(allFeatures, classBlocks, anchorBlocks, priors, numClasses);
        }
    }

    private SingleShotDetection(
            List<Block> features,
            List<Block> classPredictionBlocks,
            List<Block> anchorPredictionBlocks,
            List<MultiBoxPrior> multiBoxPriors,
            int numClasses) {
        super(VERSION);
        this.features = features;
        this.classPredictionBlocks = classPredictionBlocks;
        this.anchorPredictionBlocks = anchorPredictionBlocks;
        this.multiBoxPriors = multiBoxPriors;
        this.numClasses = numClasses;
        // Registration order (features, class heads, anchor heads) is kept from the original
        // implementation because child naming may depend on it.
        for (Block block : features) {
            addChildBlock(block.getClass().getSimpleName(), block);
        }
        for (Block block : classPredictionBlocks) {
            addChildBlock(block.getClass().getSimpleName(), block);
        }
        for (Block block : anchorPredictionBlocks) {
            addChildBlock(block.getClass().getSimpleName(), block);
        }
    }

    /**
     * Runs every feature block in sequence and collects the per-scale predictions.
     *
     * @return three arrays:
     *     <ul>
     *       <li>all anchor boxes, concatenated along axis 1;
     *       <li>class predictions reshaped to {@code (batch, -1, numClasses + 1)};
     *       <li>flattened anchor (box offset) predictions.
     *     </ul>
     */
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList nDList,
            boolean z,
            PairList<String, Object> pairList) {
        int numScales = features.size();
        NDArray[] anchors = new NDArray[numScales];
        NDArray[] classPreds = new NDArray[numScales];
        NDArray[] boxPreds = new NDArray[numScales];

        NDList current = nDList;
        for (int i = 0; i < numScales; i++) {
            current = features.get(i).forward(parameterStore, current, z);
            anchors[i] = multiBoxPriors.get(i).generateAnchorBoxes(current.singletonOrThrow());
            classPreds[i] =
                    classPredictionBlocks.get(i).forward(parameterStore, current, z).singletonOrThrow();
            boxPreds[i] =
                    anchorPredictionBlocks.get(i).forward(parameterStore, current, z).singletonOrThrow();
        }

        NDArray allAnchors = NDArrays.concat(new NDList(anchors), 1);
        NDArray flatClassPreds = concatPredictions(new NDList(classPreds));
        NDArray classOutput = flatClassPreds.reshape(flatClassPreds.size(0), -1, numClasses + 1);
        NDArray boxOutput = concatPredictions(new NDList(boxPreds));
        return new NDList(allAnchors, classOutput, boxOutput);
    }

    /**
     * Flattens each prediction map and concatenates the results along axis 1.
     *
     * <p>Each array is transposed with axes {@code (0, 2, 3, 1)} and reshaped to
     * {@code (size(0), -1)}. This presumably moves channels last for NCHW inputs; confirm against
     * the caller.
     */
    private NDArray concatPredictions(NDList nDList) {
        NDArray[] flattened =
                nDList.stream()
                        .map(array -> array.transpose(0, 2, 3, 1).reshape(array.size(0), -1))
                        .toArray(NDArray[]::new);
        return NDArrays.concat(new NDList(flattened), 1);
    }

    /**
     * Computes the three output shapes by running shape inference through the feature blocks.
     * Temporary arrays of ones are created under a scoped manager that is always closed.
     */
    public Shape[] getOutputShapes(Shape[] shapeArr) {
        try (NDManager manager = NDManager.newBaseManager()) {
            int numScales = features.size();
            Shape[] anchorShapes = new Shape[numScales];
            Shape[] classShapes = new Shape[numScales];
            Shape[] boxShapes = new Shape[numScales];

            Shape[] current = shapeArr;
            for (int i = 0; i < numScales; i++) {
                current = features.get(i).getOutputShapes(current);
                anchorShapes[i] =
                        multiBoxPriors.get(i).generateAnchorBoxes(manager.ones(current[0])).getShape();
                classShapes[i] = classPredictionBlocks.get(i).getOutputShapes(current)[0];
                boxShapes[i] = anchorPredictionBlocks.get(i).getOutputShapes(current)[0];
            }

            Shape anchorShape = new Shape(new long[0]);
            for (Shape shape : anchorShapes) {
                anchorShape = concatShape(anchorShape, shape, 1);
            }

            NDArray flatClassPreds = concatPredictions(onesFor(manager, classShapes));
            Shape classShape =
                    flatClassPreds.reshape(flatClassPreds.size(0), -1, numClasses + 1).getShape();
            Shape boxShape = concatPredictions(onesFor(manager, boxShapes)).getShape();
            return new Shape[] {anchorShape, classShape, boxShape};
        }
    }

    /** Creates one array of ones per shape, owned by {@code manager}. */
    private static NDList onesFor(NDManager manager, Shape[] shapes) {
        NDList list = new NDList();
        for (Shape shape : shapes) {
            list.add(manager.ones(shape));
        }
        return list;
    }

    /**
     * Concatenates two shapes along {@code i}. An empty (0-dimensional) first shape acts as the
     * identity.
     *
     * @throws IllegalArgumentException if the shapes have different ranks
     * @throws UnsupportedOperationException if a non-concatenation dimension differs
     */
    private Shape concatShape(Shape shape, Shape shape2, int i) {
        if (shape.dimension() == 0) {
            return shape2;
        }
        if (shape.dimension() != shape2.dimension()) {
            throw new IllegalArgumentException("Shapes must have same dimensions");
        }
        long[] dims = new long[shape.dimension()];
        for (int d = 0; d < dims.length; d++) {
            if (d == i) {
                dims[d] = shape.get(d) + shape2.get(d);
            } else if (shape.get(d) == shape2.get(d)) {
                dims[d] = shape.get(d);
            } else {
                throw new UnsupportedOperationException(
                        "These shapes cannot be concatenated along axis "
                                + i
                                + ": dimension "
                                + d
                                + " differs ("
                                + shape.get(d)
                                + " vs "
                                + shape2.get(d)
                                + ")");
            }
        }
        return new Shape(dims);
    }

    /**
     * Initializes each feature block and both of its prediction heads. Each feature block's
     * output shapes are propagated as input to the heads and to the next feature block.
     */
    public void initialize(NDManager nDManager, DataType dataType, Shape... shapeArr) {
        beforeInitialize(shapeArr);
        Shape[] current = shapeArr;
        for (int i = 0; i < features.size(); i++) {
            Block feature = features.get(i);
            feature.initialize(nDManager, dataType, current);
            current = feature.getOutputShapes(current);
            classPredictionBlocks.get(i).initialize(nDManager, dataType, current);
            anchorPredictionBlocks.get(i).initialize(nDManager, dataType, current);
        }
    }

    /**
     * Reads the block metadata. The current version reads the stored input shapes, version 1
     * carries nothing extra, and any other version is rejected.
     *
     * @throws MalformedModelException if the encoding version is not supported
     */
    public void loadMetadata(byte b, DataInputStream dataInputStream)
            throws IOException, MalformedModelException {
        if (b == this.version) {
            readInputShapes(dataInputStream);
        } else if (b != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + ((int) b));
        }
    }

    /**
     * Creates a down-sampling feature block. It has two stages of 3x3 convolution (padding 1),
     * batch norm and ReLU, followed by 2x2 max pooling with stride 2 and no padding.
     *
     * @param numFilters the number of convolution filters in each stage
     * @return the down-sampling block
     */
    public static SequentialBlock getDownSamplingBlock(int numFilters) {
        SequentialBlock block = new SequentialBlock();
        for (int stage = 0; stage < 2; stage++) {
            block.add(
                            Conv2d.builder()
                                    .setKernelShape(new Shape(3, 3))
                                    .setFilters(numFilters)
                                    .optPadding(new Shape(1, 1))
                                    .build())
                    .add(BatchNorm.builder().build())
                    .add(Activation::relu);
        }
        block.add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2), new Shape(0, 0)));
        return block;
    }

    /**
     * Creates the class-prediction head: a 3x3 convolution (padding 1) with
     * {@code (numClasses + 1) * numAnchors} filters.
     *
     * @param numAnchors the anchors per spatial location
     * @param numClasses the number of object classes
     * @return the prediction convolution
     */
    public static Conv2d getClassPredictionBlock(int numAnchors, int numClasses) {
        return Conv2d.builder()
                .setKernelShape(new Shape(3, 3))
                .setFilters((numClasses + 1) * numAnchors)
                .optPadding(new Shape(1, 1))
                .build();
    }

    /**
     * Creates the anchor-prediction head: a 3x3 convolution (padding 1) with
     * {@code 4 * numAnchors} filters, i.e. four values per anchor.
     *
     * @param numAnchors the anchors per spatial location
     * @return the prediction convolution
     */
    public static Conv2d getAnchorPredictionBlock(int numAnchors) {
        return Conv2d.builder()
                .setKernelShape(new Shape(3, 3))
                .setFilters(4 * numAnchors)
                .optPadding(new Shape(1, 1))
                .build();
    }

    /**
     * Returns a new builder.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
