/*
 * Java
 *
 * Copyright 2025 MicroEJ Corp. All rights reserved.
 * Use of this source code is governed by a BSD-style license that can be found with this software.
 */
package ej.microai;

/**
 * Native function declarations used by the {@code ej.microai} package.
 * <p>
 * Every method is {@code static native}: the implementations live on the native side. Models are loaded either from a
 * resource ({@link #initModelFromResource(byte[], int)}) or from a native buffer filled chunk by chunk
 * ({@link #allocateModel(int)}, {@link #loadModelChunk(byte[], long, int, int)},
 * {@link #initModelFromBuffer(long, int, int)}). Both initializers return a model interpreter handle that the other
 * methods take as their first argument.
 */
final class MLNatives {

	private MLNatives() {
		// Prevent instantiation: this class only groups native declarations.
	}

	/**
	 * Initializes the model from a resource. This maps the model into a usable data structure, builds an interpreter
	 * to run the model and allocates memory for the model's tensors.
	 *
	 * @param modelPath
	 *            the model path, as bytes (NOTE(review): encoding and termination expected by the native side to be
	 *            confirmed).
	 * @param inferenceMemoryPoolSize
	 *            memory allocated to load the tensors and perform inferences.
	 *
	 * @return the model interpreter handle.
	 */
	static native long initModelFromResource(byte[] modelPath, int inferenceMemoryPoolSize);

	/**
	 * Allocates native space to store the model data on the native side.
	 *
	 * @param modelByteSize
	 *            the model size in bytes.
	 *
	 * @return the model native pointer address.
	 */
	static native long allocateModel(int modelByteSize);

	/**
	 * Frees the native space allocated by {@link #allocateModel(int)} to store the model data.
	 *
	 * @param modelPointer
	 *            the model pointer address.
	 */
	static native void freeModel(long modelPointer);

	/**
	 * Loads <code>length</code> bytes of model data into the native model buffer at the specified index.
	 *
	 * @param data
	 *            the data to be loaded.
	 * @param address
	 *            the model pointer address.
	 * @param index
	 *            the index in the native model buffer where the data must be loaded.
	 * @param length
	 *            the number of bytes to be loaded.
	 */
	static native void loadModelChunk(byte[] data, long address, int index, int length);

	/**
	 * Initializes the model from a native buffer. This maps the model into a usable data structure, builds an
	 * interpreter to run the model and allocates memory for the model's tensors.
	 *
	 * @param modelPointer
	 *            the model pointer address.
	 * @param modelByteSize
	 *            the model size in bytes.
	 * @param inferenceMemoryPoolSize
	 *            memory allocated to load the tensors and perform inferences.
	 *
	 * @return the model interpreter handle.
	 */
	static native long initModelFromBuffer(long modelPointer, int modelByteSize, int inferenceMemoryPoolSize);

	/**
	 * Resets the state of the model interpreter.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 *
	 * @return the execution status: 0 if the execution succeeds, 1 otherwise.
	 */
	static native int reset(long modelHandle);

	/**
	 * Runs an inference on the model.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 *
	 * @return the execution status: 0 if the execution succeeds, 1 otherwise.
	 */
	static native int run(long modelHandle);

	/**
	 * Deletes the model and frees the allocated native resources.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 */
	static native void clean(long modelHandle);

	/**
	 * Gets the number of input tensors.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 *
	 * @return the number of input tensors.
	 */
	static native int getInputTensorCount(long modelHandle);

	/**
	 * Gets the number of output tensors.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 *
	 * @return the number of output tensors.
	 */
	static native int getOutputTensorCount(long modelHandle);

	/**
	 * Sets the input data, as an array of signed or unsigned bytes, to the input tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 * @param inputData
	 *            the data to be loaded into the model.
	 */
	static native void setInputDataAsByteArray(long modelHandle, int index, byte[] inputData);

	/**
	 * Sets the input data, as an array of signed or unsigned integers, to the input tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 * @param inputData
	 *            the data to be loaded into the model.
	 */
	static native void setInputDataAsIntArray(long modelHandle, int index, int[] inputData);

	/**
	 * Sets the input data, as an array of floats, to the input tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 * @param inputData
	 *            the data to be loaded into the model.
	 */
	static native void setInputDataAsFloatArray(long modelHandle, int index, float[] inputData);

	/**
	 * Gets the quantization status of the output tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return <code>true</code> if the index is valid and the output tensor is quantized, <code>false</code>
	 *         otherwise.
	 */
	static native boolean outputQuantized(long modelHandle, int index);

	/**
	 * Gets the quantization status of the input tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return <code>true</code> if the index is valid and the input tensor is quantized, <code>false</code>
	 *         otherwise.
	 */
	static native boolean inputQuantized(long modelHandle, int index);

	/**
	 * Gets the data type of the output tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return the data type code:
	 *         <ul>
	 *         <li>0: unknown type</li>
	 *         <li>1: int8</li>
	 *         <li>2: uint8</li>
	 *         <li>3: float32</li>
	 *         <li>4: int32</li>
	 *         <li>5: uint32</li>
	 *         </ul>
	 */
	static native int getOutputDataType(long modelHandle, int index);

	/**
	 * Gets the data type of the input tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return the data type code:
	 *         <ul>
	 *         <li>0: unknown type</li>
	 *         <li>1: int8</li>
	 *         <li>2: uint8</li>
	 *         <li>3: float32</li>
	 *         <li>4: int32</li>
	 *         <li>5: uint32</li>
	 *         </ul>
	 */
	static native int getInputDataType(long modelHandle, int index);

	/**
	 * Gets the scale of the data of the output tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return the scale.
	 */
	static native float getOutputScale(long modelHandle, int index);

	/**
	 * Gets the scale of the data of the input tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return the scale.
	 */
	static native float getInputScale(long modelHandle, int index);

	/**
	 * Gets the zero point of the data of the output tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return the zero point.
	 */
	static native int getOutputZeroPoint(long modelHandle, int index);

	/**
	 * Gets the zero point of the data of the input tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return the zero point.
	 */
	static native int getInputZeroPoint(long modelHandle, int index);

	/**
	 * Gets the number of bytes of the output tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return the number of bytes of the output tensor.
	 */
	static native int getOutputNumBytes(long modelHandle, int index);

	/**
	 * Gets the number of bytes of the input tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return the number of bytes of the input tensor.
	 */
	static native int getInputNumBytes(long modelHandle, int index);

	/**
	 * Gets the number of elements of the input tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return the number of elements of the input tensor.
	 */
	static native int getInputNumElements(long modelHandle, int index);

	/**
	 * Gets the number of elements of the output tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return the number of elements of the output tensor.
	 */
	static native int getOutputNumElements(long modelHandle, int index);

	/**
	 * Gets the number of dimensions of the input tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return the number of dimensions of the input tensor.
	 */
	static native int getInputNumDimensions(long modelHandle, int index);

	/**
	 * Gets the number of dimensions of the output tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 *
	 * @return the number of dimensions of the output tensor.
	 */
	static native int getOutputNumDimensions(long modelHandle, int index);

	/**
	 * Gets the shape of the input tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 * @param sizes
	 *            the array that receives the size of each dimension of the input tensor (presumably sized with
	 *            {@link #getInputNumDimensions(long, int)}; TODO confirm against the native implementation).
	 */
	static native void getInputShape(long modelHandle, int index, int[] sizes);

	/**
	 * Gets the shape of the output tensor specified by index.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 * @param sizes
	 *            the array that receives the size of each dimension of the output tensor (presumably sized with
	 *            {@link #getOutputNumDimensions(long, int)}; TODO confirm against the native implementation).
	 */
	static native void getOutputShape(long modelHandle, int index, int[] sizes);

	/**
	 * Copies the data of the output tensor specified by index into the given array.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 * @param outputData
	 *            the array that receives the inference result as signed or unsigned bytes. The caller's reference
	 *            cannot be replaced by the native side; NOTE(review): the array contents for an invalid index are
	 *            defined by the native implementation — confirm.
	 */
	static native void getOutputDataAsByteArray(long modelHandle, int index, byte[] outputData);

	/**
	 * Copies the data of the output tensor specified by index into the given array.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 * @param outputData
	 *            the array that receives the inference result as signed or unsigned integers. The caller's reference
	 *            cannot be replaced by the native side; NOTE(review): the array contents for an invalid index are
	 *            defined by the native implementation — confirm.
	 */
	static native void getOutputDataAsIntegerArray(long modelHandle, int index, int[] outputData);

	/**
	 * Copies the data of the output tensor specified by index into the given array.
	 *
	 * @param modelHandle
	 *            the model interpreter handle.
	 * @param index
	 *            the tensor index.
	 * @param outputData
	 *            the array that receives the inference result as floats. The caller's reference cannot be replaced
	 *            by the native side; NOTE(review): the array contents for an invalid index are defined by the native
	 *            implementation — confirm.
	 */
	static native void getOutputDataAsFloatArray(long modelHandle, int index, float[] outputData);

}
