/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.engine.rust;

import ai.djl.engine.rust.RsNDArray;
import ai.djl.engine.rust.RsNDManager;
import ai.djl.engine.rust.RustLibrary;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicReference;

public class RsSymbolBlock
extends AbstractSymbolBlock
implements AutoCloseable {
    private AtomicReference<Long> handle;
    private String uid;
    private RsNDManager manager;

    public RsSymbolBlock(RsNDManager manager, long handle) {
        this.handle = new AtomicReference<Long>(handle);
        this.manager = manager;
        this.inputNames = Arrays.asList(RustLibrary.getInputNames(handle));
        this.uid = String.valueOf(handle);
        manager.attachInternal(this.uid, new AutoCloseable[]{this});
    }

    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        if (this.inputNames.size() != inputs.size()) {
            throw new IllegalArgumentException("Input size mismatch, requires: " + this.inputNames);
        }
        try (RsNDManager sub = (RsNDManager)this.manager.newSubManager();){
            long[] inputHandles = new long[inputs.size()];
            for (int i = 0; i < inputs.size(); ++i) {
                inputHandles[i] = (Long)sub.from((NDArray)inputs.get(i)).getHandle();
            }
            long outputHandle = RustLibrary.runInference(this.handle.get(), inputHandles);
            RsNDArray output = new RsNDArray(this.manager, outputHandle, inputs.head().getDataType());
            output.attach(inputs.head().getManager());
            NDList nDList = new NDList(new NDArray[]{output});
            return nDList;
        }
    }

    @Override
    public void close() {
        Long pointer = this.handle.getAndSet(null);
        if (pointer != null) {
            this.manager.detachInternal(this.uid);
            this.manager = null;
        }
    }

    public Long getHandle() {
        Long reference = this.handle.get();
        if (reference == null) {
            throw new IllegalStateException("Rust model handle has been released!");
        }
        return reference;
    }

    public ParameterList getDirectParameters() {
        throw new UnsupportedOperationException("Not yet supported");
    }
}

