package org.nd4j.jita.allocator.tad;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.TadDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/allocator/tad/DeviceTADManager.class */
/**
 * TAD manager that caches the {@link Pair} returned by
 * {@link BasicTADManager#getTADOnlyShapeInfo(INDArray, int[])} separately for every device
 * reported by {@code Nd4j.getAffinityManager()}.
 *
 * <p>Newly created buffers are handed to {@link AtomicAllocator#moveToConstant(DataBuffer)}
 * before being cached. The first buffer is skipped when it is the array's own shape-info
 * buffer. Each per-device cache is a {@link ConcurrentHashMap} keyed by {@link TadDescriptor}.
 * {@link #purgeBuffers()} replaces the whole list of caches in one assignment.
 */
public class DeviceTADManager extends BasicTADManager {
    private static final Logger log = LoggerFactory.getLogger(DeviceTADManager.class);

    /**
     * One cache per device, indexed by the device id returned by {@code AtomicAllocator}.
     * Declared volatile so that the list built by {@link #purgeBuffers()} is fully populated
     * before any other thread can observe it.
     */
    protected volatile List<Map<TadDescriptor, Pair<DataBuffer, DataBuffer>>> tadCache;

    public DeviceTADManager() {
        this.tadCache = newDeviceCaches();
    }

    /**
     * Builds one empty concurrent cache per device currently reported by the affinity manager.
     *
     * @return a new list whose index {@code i} holds the cache for device {@code i}
     */
    private static List<Map<TadDescriptor, Pair<DataBuffer, DataBuffer>>> newDeviceCaches() {
        int numberOfDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        List<Map<TadDescriptor, Pair<DataBuffer, DataBuffer>>> caches = new ArrayList<>(numberOfDevices);
        for (int i = 0; i < numberOfDevices; i++) {
            caches.add(new ConcurrentHashMap<>());
        }
        return caches;
    }

    /**
     * Drops every cached TAD entry on all devices, then delegates to the base class.
     *
     * <p>The replacement list is fully built before it is published. Threads that already hold
     * a reference to an old per-device map simply finish against that map.
     */
    @Override // org.nd4j.jita.allocator.tad.BasicTADManager
    public void purgeBuffers() {
        log.info("Purging TAD buffers...");
        List<Map<TadDescriptor, Pair<DataBuffer, DataBuffer>>> fresh = newDeviceCaches();
        for (int i = 0; i < fresh.size(); i++) {
            log.info("Resetting device: [{}]", i);
        }
        this.tadCache = fresh;
        super.purgeBuffers();
    }

    /**
     * Returns the TAD shape-info / offsets pair for {@code iNDArray} along {@code iArr} on the
     * current device. A cached pair is reused when one exists.
     *
     * <p>NOTE: when {@code iArr} has more than one element it is sorted <em>in place</em>, so
     * the caller's array is modified. This is kept deliberately; callers may rely on the
     * dimension array matching the returned TAD.
     *
     * @param iNDArray the array the TAD is computed for
     * @param iArr     the dimensions (may be {@code null})
     * @return the cached or newly created pair; the second element may be {@code null}
     */
    @Override // org.nd4j.jita.allocator.tad.BasicTADManager
    public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray iNDArray, int[] iArr) {
        if (iArr != null && iArr.length > 1) {
            // Presumably so that permutations of the same dimensions share one cache entry.
            Arrays.sort(iArr);
        }
        int deviceId = AtomicAllocator.getInstance().getDeviceId();

        // Read the per-device map exactly once. A concurrent purgeBuffers() swaps the list, and
        // re-reading it between insert and lookup could otherwise yield null.
        Map<TadDescriptor, Pair<DataBuffer, DataBuffer>> deviceCache = this.tadCache.get(deviceId);
        TadDescriptor descriptor = new TadDescriptor(iNDArray, iArr);

        Pair<DataBuffer, DataBuffer> cached = deviceCache.get(descriptor);
        if (cached != null) {
            log.trace("Using TAD from cache...");
            return cached;
        }

        log.trace("Creating new TAD...");
        Pair<DataBuffer, DataBuffer> created = super.getTADOnlyShapeInfo(iNDArray, iArr);
        // The array's own shape-info buffer is not moved; only separately created buffers are.
        if (created.getFirst() != iNDArray.shapeInfoDataBuffer()) {
            AtomicAllocator.getInstance().moveToConstant(created.getFirst());
        }
        if (created.getSecond() != null) {
            AtomicAllocator.getInstance().moveToConstant(created.getSecond());
        }

        // Atomic on ConcurrentHashMap: if another thread cached this descriptor first, reuse its
        // entry so all callers share one pair and the byte counter is not incremented twice.
        Pair<DataBuffer, DataBuffer> winner = deviceCache.putIfAbsent(descriptor, created);
        if (winner != null) {
            return winner;
        }

        // Per-element sizes (4 for the first buffer, 8 for the second) are kept from the
        // original accounting. NOTE(review): presumably int shape info and long offsets; confirm.
        this.bytes.addAndGet(created.getFirst().length() * 4);
        if (created.getSecond() != null) {
            this.bytes.addAndGet(created.getSecond().length() * 8);
        }
        return created;
    }
}
