/*
 * Java
 *
 * Copyright 2025 MicroEJ Corp. All rights reserved.
 * Use of this source code is governed by a BSD-style license that can be found with this software.
 */
package ej.microai;

import ej.sni.SNI;

import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;

/**
 * The MLInferenceEngine class provides essential functionality for interacting with
 * Machine Learning models, specifically for running inferences.
 * <p>
 * It allows users to load a trained model from a resource (using {@link MLInferenceEngine#MLInferenceEngine(String)})
 * or from an {@link InputStream} (using {@link MLInferenceEngine#MLInferenceEngine(InputStream)}).
 * <p>
 * Users can set input tensors, execute inference, and retrieve output tensor values.
 * <p>
 * An MLInferenceEngine allocates native resources when it is opened.
 * The MLInferenceEngine should be closed with the {@link #close()} method in order to free the native allocation.
 */
public class MLInferenceEngine implements Closeable {
	/**
	 * Number of bytes to read from the input stream.
	 */
	private static final int READ_CHUNK_SIZE = 4096;
	private final int modelHandle;
	private final int modelBuffer;
	private volatile boolean closed;
	private final int numberInputTensors;
	private final int numberOutputTensors;
	private final InputTensor[] inputTensors;
	private final OutputTensor[] outputTensors;

	/**
	 * Initializes the model from a pre-trained model resource.
	 * <p>
	 * Map the model into a usable data structure, build an interpreter to run the model with and allocate memory for
	 * the model's tensors.
	 *
	 * @param modelPath
	 *            the path of the pre-trained model resource.
	 */
	public MLInferenceEngine(String modelPath) {
		this.modelHandle = MLNatives.initModelFromResource(SNI.toCString(modelPath));
		// No model buffer is allocated by this constructor.
		this.modelBuffer = 0;
		this.closed = false;
		this.numberInputTensors = MLNatives.getInputTensorCount(this.modelHandle);
		this.numberOutputTensors = MLNatives.getOutputTensorCount(this.modelHandle);
		this.inputTensors = new InputTensor[this.numberInputTensors];
		this.outputTensors = new OutputTensor[this.numberOutputTensors];
	}

	/**
	 * Initializes the model from a given {@link InputStream}.
	 * <p>
	 * This method will block until the model is completely retrieved in the native side from the input stream or an
	 * {@link java.io.IOException} occurs.
	 * <p>
	 * The native model buffer is sized from {@link InputStream#available()} at the time of the call, so the stream
	 * must report its full remaining length there. The stream is not closed by this constructor.
	 * <p>
	 * Map the model into a usable data structure, build an interpreter to run the model with and allocate memory for
	 * the model's tensors.
	 *
	 * @param is
	 *            the input stream to read model from.
	 *
	 * @throws IOException
	 *             if an {@link IOException} is encountered by the input stream, or if the stream provides more bytes
	 *             than reported by {@link InputStream#available()}.
	 */
	public MLInferenceEngine(InputStream is) throws IOException {
		// Allocates the space for the model data on the native side.
		int capacity = is.available();
		int buffer = MLNatives.allocateModel(capacity);

		int handle;
		boolean initialized = false;
		try {
			// Load model data on the native side, then initialize the model from it.
			int totalBytesRead = loadModelData(is, buffer, capacity);
			handle = MLNatives.initModelFromBuffer(buffer, totalBytesRead);
			initialized = true;
		} finally {
			// The object is not constructed on failure, so close() can never release this buffer: free it here.
			if (!initialized) {
				MLNatives.freeModel(buffer);
			}
		}

		this.modelBuffer = buffer;
		this.modelHandle = handle;
		this.closed = false;
		this.numberInputTensors = MLNatives.getInputTensorCount(this.modelHandle);
		this.numberOutputTensors = MLNatives.getOutputTensorCount(this.modelHandle);
		this.inputTensors = new InputTensor[this.numberInputTensors];
		this.outputTensors = new OutputTensor[this.numberOutputTensors];
	}

	/**
	 * Streams the content of {@code is} into the native model buffer, chunk by chunk, until end of stream.
	 *
	 * @param is
	 *            the stream to read the model from.
	 * @param buffer
	 *            the native model buffer returned by {@code MLNatives.allocateModel}.
	 * @param capacity
	 *            the number of bytes requested when allocating {@code buffer}.
	 * @return the total number of bytes written to the native buffer.
	 * @throws IOException
	 *             if reading fails, or if the stream provides more than {@code capacity} bytes.
	 */
	private static int loadModelData(InputStream is, int buffer, int capacity) throws IOException {
		byte[] modelData = new byte[READ_CHUNK_SIZE];
		int totalBytesRead = 0;
		int bytesRead;
		while ((bytesRead = is.read(modelData, 0, READ_CHUNK_SIZE)) != -1) {
			// available() is only an estimate: never write past what was allocated on the native side.
			if (bytesRead > capacity - totalBytesRead) {
				throw new IOException("Model data exceeds the " + capacity + " bytes reported by available()");
			}
			MLNatives.loadModelChunk(modelData, buffer, totalBytesRead, bytesRead);
			totalBytesRead += bytesRead;
		}
		return totalBytesRead;
	}

	/**
	 * Resets the state of the model interpreter.
	 * <p>
	 * This method does not check whether the engine has been closed; it must not be called after {@link #close()}.
	 *
	 * @return 0 if the execution succeeds, 1 otherwise.
	 */
	public int reset() {
		return MLNatives.reset(this.modelHandle);
	}

	/**
	 * Runs an inference on the model.
	 * <p>
	 * This method does not check whether the engine has been closed; it must not be called after {@link #close()}.
	 *
	 * @return 0 if the execution succeeds, 1 otherwise.
	 */
	public int run() {
		return MLNatives.run(this.modelHandle);
	}

	/**
	 * Gets the input tensor specified by the index.
	 * <p>
	 * The tensor object is created on first access and the same instance is returned on subsequent calls.
	 *
	 * @param index
	 *            the index of the tensor, usually <code>index = 0</code> except for models that have more than one
	 *            tensor as input.
	 *
	 * @throws IllegalArgumentException
	 *             if index is negative or not lower than {@link #getInputTensorCount()}.
	 *
	 * @return the input tensor specified by index.
	 */
	public InputTensor getInputTensor(int index) {
		if (index < 0 || index >= this.numberInputTensors) {
			throw new IllegalArgumentException("Input index " + index + " out of range (length is "
					+ this.numberInputTensors + ")");
		}
		if (this.inputTensors[index] == null) {
			this.inputTensors[index] = new InputTensor(this.modelHandle, index);
		}
		return this.inputTensors[index];
	}

	/**
	 * Gets the output tensor specified by the index.
	 * <p>
	 * The tensor object is created on first access and the same instance is returned on subsequent calls.
	 *
	 * @param index
	 *            the index of the tensor, usually <code>index = 0</code> except for models that have more than one
	 *            tensor as output.
	 *
	 * @throws IllegalArgumentException
	 *             if index is negative or not lower than {@link #getOutputTensorCount()}.
	 *
	 * @return the output tensor specified by index.
	 */
	public OutputTensor getOutputTensor(int index) {
		if (index < 0 || index >= this.numberOutputTensors) {
			throw new IllegalArgumentException("Output index " + index + " out of range (length is "
					+ this.numberOutputTensors + ")");
		}
		if (this.outputTensors[index] == null) {
			this.outputTensors[index] = new OutputTensor(this.modelHandle, index);
		}
		return this.outputTensors[index];
	}

	/**
	 * Gets the number of input Tensors.
	 *
	 * @return the number of input Tensors.
	 */
	public int getInputTensorCount() {
		return this.numberInputTensors;
	}

	/**
	 * Gets the number of output Tensors.
	 *
	 * @return the number of output Tensors.
	 */
	public int getOutputTensorCount() {
		return this.numberOutputTensors;
	}

	/**
	 * Returns whether this model has been closed.
	 *
	 * @return whether this model has been closed.
	 */
	public boolean isClosed() {
		return this.closed;
	}

	/**
	 * Closes this model and its associated resources.
	 * <p>
	 * This method releases the native resources allocated when opening this model.
	 * Calling this method on a model that has already been closed has no effect.
	 */
	@Override
	public synchronized void close() {
		// Guard against releasing the same native handle/buffer twice.
		if (this.closed) {
			return;
		}
		this.closed = true;
		MLNatives.clean(this.modelHandle);
		MLNatives.freeModel(this.modelBuffer);
	}
}
