package org.apache.tvm.contrib;

import org.apache.tvm.Device;
import org.apache.tvm.Function;
import org.apache.tvm.Module;
import org.apache.tvm.NDArray;

/* loaded from: classes6.dex */
/**
 * Thin wrapper around a {@link Module} that caches the packed functions used to
 * drive a graph ("set_input", "run", "get_input", "get_output", "load_params"
 * and, when available, "debug_get_output") and exposes them as typed methods.
 *
 * <p>The function handles are looked up once in the constructor and released
 * together with the module in {@link #release()}.
 */
public class GraphModule {
    /** Error message raised when the optional "debug_get_output" function is unavailable. */
    private static final String PROFILER_REQUIRED_MESSAGE =
            "Please compile runtime with USE_PROFILER = ON";

    private final Device device;
    private final Module module;
    private final Function fsetInput;
    private final Function frun;
    private final Function fgetInput;
    private final Function fgetOutput;
    /** Optional; {@code null} when the module does not provide "debug_get_output". */
    private final Function fdebugGetOutput;
    private final Function floadParams;

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Looks up and caches the functions of {@code module} that this wrapper uses.
     *
     * @param module the module whose functions are wrapped
     * @param device the device that inputs are moved to in {@code setInput}
     */
    public GraphModule(Module module, Device device) {
        this.module = module;
        this.device = device;
        // Lookup order is kept exactly as before: a missing mandatory function
        // still fails at the same point.
        this.fsetInput = module.getFunction("set_input");
        this.frun = module.getFunction("run");
        this.fgetInput = module.getFunction("get_input");
        this.fgetOutput = module.getFunction("get_output");
        this.fdebugGetOutput = lookupOptionalFunction(module, "debug_get_output");
        this.floadParams = module.getFunction("load_params");
    }

    /**
     * Invokes "debug_get_output" with the given index and array.
     *
     * @param i       index passed as the first argument
     * @param nDArray array passed as the second argument (presumably filled by the call)
     * @return the same {@code nDArray} instance that was passed in
     * @throws RuntimeException if the module does not provide "debug_get_output"
     */
    public NDArray debugGetOutput(int i, NDArray nDArray) {
        requireDebugGetOutput().pushArg(i).pushArg(nDArray).invoke();
        return nDArray;
    }

    /**
     * Invokes "debug_get_output" with the given name and array.
     *
     * @param str     name passed as the first argument
     * @param nDArray array passed as the second argument (presumably filled by the call)
     * @return the same {@code nDArray} instance that was passed in
     * @throws RuntimeException if the module does not provide "debug_get_output"
     */
    public NDArray debugGetOutput(String str, NDArray nDArray) {
        requireDebugGetOutput().pushArg(str).pushArg(nDArray).invoke();
        return nDArray;
    }

    /**
     * Looks up an arbitrary function on the wrapped module.
     *
     * @param str name of the function
     * @return whatever {@code Module.getFunction} returns for that name
     */
    public Function getFunction(String str) {
        return this.module.getFunction(str);
    }

    /**
     * Invokes "get_input" with the given index and array.
     *
     * @param i       index passed as the first argument
     * @param nDArray array passed as the second argument (presumably filled by the call)
     * @return the same {@code nDArray} instance that was passed in
     */
    public NDArray getInput(int i, NDArray nDArray) {
        this.fgetInput.pushArg(i).pushArg(nDArray).invoke();
        return nDArray;
    }

    /**
     * Invokes "get_output" with the given index and array.
     *
     * @param i       index passed as the first argument
     * @param nDArray array passed as the second argument (presumably filled by the call)
     * @return the same {@code nDArray} instance that was passed in
     */
    public NDArray getOutput(int i, NDArray nDArray) {
        this.fgetOutput.pushArg(i).pushArg(nDArray).invoke();
        return nDArray;
    }

    /**
     * Invokes "load_params" with the given bytes.
     *
     * @param bArr serialized parameters handed to the module unchanged
     * @return this instance, for call chaining
     */
    public GraphModule loadParams(byte[] bArr) {
        this.floadParams.pushArg(bArr).invoke();
        return this;
    }

    /**
     * Releases every cached function handle and then the wrapped module.
     * The optional "debug_get_output" handle is released only if it was found.
     */
    public void release() {
        this.fsetInput.release();
        this.frun.release();
        this.fgetInput.release();
        this.fgetOutput.release();
        if (this.fdebugGetOutput != null) {
            this.fdebugGetOutput.release();
        }
        this.floadParams.release();
        this.module.release();
    }

    /**
     * Invokes "run" with no arguments.
     *
     * @return this instance, for call chaining
     */
    public GraphModule run() {
        this.frun.invoke();
        return this;
    }

    /**
     * Invokes "set_input" with the given index and array. If the array's device
     * differs from this module's device, a copy on the module's device is passed
     * instead.
     *
     * @param i       index passed as the first argument
     * @param nDArray input value
     * @return this instance, for call chaining
     */
    public GraphModule setInput(int i, NDArray nDArray) {
        this.fsetInput.pushArg(i).pushArg(onModuleDevice(nDArray)).invoke();
        return this;
    }

    /**
     * Invokes "set_input" with the given name and array. If the array's device
     * differs from this module's device, a copy on the module's device is passed
     * instead.
     *
     * @param str     name passed as the first argument
     * @param nDArray input value
     * @return this instance, for call chaining
     */
    public GraphModule setInput(String str, NDArray nDArray) {
        this.fsetInput.pushArg(str).pushArg(onModuleDevice(nDArray)).invoke();
        return this;
    }

    /** Returns the function {@code name} of {@code module}, or {@code null} if the lookup is rejected. */
    private static Function lookupOptionalFunction(Module module, String name) {
        try {
            return module.getFunction(name);
        } catch (IllegalArgumentException ignored) {
            // Deliberately best-effort: the function is optional, and its absence
            // is reported lazily by requireDebugGetOutput() when it is first needed.
            return null;
        }
    }

    /** Returns the cached "debug_get_output" function or throws if it was not found. */
    private Function requireDebugGetOutput() {
        Function function = this.fdebugGetOutput;
        if (function == null) {
            throw new RuntimeException(PROFILER_REQUIRED_MESSAGE);
        }
        return function;
    }

    /** Returns {@code value} itself if it is already on this module's device, else a copy placed there. */
    private NDArray onModuleDevice(NDArray value) {
        if (value.device().equals(this.device)) {
            return value;
        }
        // NOTE(review): the staging array is never released here, matching the
        // previous behaviour; confirm whether "set_input" retains it before
        // adding an explicit release.
        NDArray staged = NDArray.empty(value.shape(), this.device);
        value.copyTo(staged);
        return staged;
    }
}
