/*
 * C
 *
 * Copyright 2025 MicroEJ Corp. All rights reserved.
 * Use of this source code is governed by a BSD-style license that can be found with this software.
 */

#ifndef LLML_IMPL
#define LLML_IMPL

#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

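/*
 * @brief Include the SNI (Simple Native Interface) header, which provides the Java-compatible types
 * (jint, jbyte, jfloat, jboolean) used by the native method declarations below.
 */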
#include "sni.h"

// --------------------------------------------------------------------------------
// Native function name redefinition
// --------------------------------------------------------------------------------
/* Model lifecycle: loading, native memory management, reset and inference. */
#define LLML_IMPL_init_model_from_resource         Java_ej_microai_MLNatives_initModelFromResource
#define LLML_IMPL_init_model_from_buffer           Java_ej_microai_MLNatives_initModelFromBuffer
#define LLML_IMPL_allocate_model                   Java_ej_microai_MLNatives_allocateModel
#define LLML_IMPL_free_model                       Java_ej_microai_MLNatives_freeModel
#define LLML_IMPL_load_model_chunk                 Java_ej_microai_MLNatives_loadModelChunk
#define LLML_IMPL_clean                            Java_ej_microai_MLNatives_clean
#define LLML_IMPL_reset                            Java_ej_microai_MLNatives_reset
#define LLML_IMPL_run                              Java_ej_microai_MLNatives_run

/* Input tensor metadata. */
#define LLML_IMPL_get_input_tensor_count           Java_ej_microai_MLNatives_getInputTensorCount
#define LLML_IMPL_get_input_num_dimensions         Java_ej_microai_MLNatives_getInputNumDimensions
#define LLML_IMPL_get_input_shape                  Java_ej_microai_MLNatives_getInputShape
#define LLML_IMPL_get_input_data_type              Java_ej_microai_MLNatives_getInputDataType
#define LLML_IMPL_get_input_num_elements           Java_ej_microai_MLNatives_getInputNumElements
#define LLML_IMPL_get_input_num_bytes              Java_ej_microai_MLNatives_getInputNumBytes
#define LLML_IMPL_get_input_zero_point             Java_ej_microai_MLNatives_getInputZeroPoint
#define LLML_IMPL_get_input_scale                  Java_ej_microai_MLNatives_getInputScale
#define LLML_IMPL_input_quantized                  Java_ej_microai_MLNatives_inputQuantized

/* Output tensor metadata. */
#define LLML_IMPL_get_output_tensor_count          Java_ej_microai_MLNatives_getOutputTensorCount
#define LLML_IMPL_get_output_num_dimensions        Java_ej_microai_MLNatives_getOutputNumDimensions
#define LLML_IMPL_get_output_shape                 Java_ej_microai_MLNatives_getOutputShape
#define LLML_IMPL_get_output_data_type             Java_ej_microai_MLNatives_getOutputDataType
#define LLML_IMPL_get_output_num_elements          Java_ej_microai_MLNatives_getOutputNumElements
#define LLML_IMPL_get_output_num_bytes             Java_ej_microai_MLNatives_getOutputNumBytes
#define LLML_IMPL_get_output_zero_point            Java_ej_microai_MLNatives_getOutputZeroPoint
#define LLML_IMPL_get_output_scale                 Java_ej_microai_MLNatives_getOutputScale
#define LLML_IMPL_output_quantized                 Java_ej_microai_MLNatives_outputQuantized

/* Input data transfer (Java array -> input tensor). */
#define LLML_IMPL_set_input_data_as_byte_array     Java_ej_microai_MLNatives_setInputDataAsByteArray
#define LLML_IMPL_set_input_data_as_int_array      Java_ej_microai_MLNatives_setInputDataAsIntArray
#define LLML_IMPL_set_input_data_as_float_array    Java_ej_microai_MLNatives_setInputDataAsFloatArray

/* Output data transfer (output tensor -> Java array). */
#define LLML_IMPL_get_output_data_as_byte_array    Java_ej_microai_MLNatives_getOutputDataAsByteArray
#define LLML_IMPL_get_output_data_as_integer_array Java_ej_microai_MLNatives_getOutputDataAsIntegerArray
#define LLML_IMPL_get_output_data_as_float_array   Java_ej_microai_MLNatives_getOutputDataAsFloatArray

// --------------------------------------------------------------------------------
// Typedefs and Structures
// --------------------------------------------------------------------------------

/**
 * @brief Element data type of a model input or output tensor.
 *
 * The numeric values are the ones returned by LLML_IMPL_get_input_data_type() and
 * LLML_IMPL_get_output_data_type(). They are part of the native API contract, so they are
 * spelled out explicitly and must not be renumbered.
 */
typedef enum {
	LLML_TENSOR_TYPE_UNKNOWN = 0, /* unknown type */
	LLML_TENSOR_TYPE_INT8 = 1,    /* signed 8-bit integer */
	LLML_TENSOR_TYPE_UINT8 = 2,   /* unsigned 8-bit integer */
	LLML_TENSOR_TYPE_FLOAT32 = 3, /* 32-bit float */
	LLML_TENSOR_TYPE_INT32 = 4,   /* signed 32-bit integer */
	LLML_TENSOR_TYPE_UINT32 = 5   /* unsigned 32-bit integer */
} LLML_Tensor_Type;

// --------------------------------------------------------------------------------
// Functions that must be implemented
// --------------------------------------------------------------------------------

/**
 * @brief Initializes a model from a resource: maps the model into a usable data structure, builds an
 * interpreter to run it, and allocates memory for the model's tensors.
 *
 * @param[in] modelPath the model path.
 *            NOTE(review): presumably a NUL-terminated byte string; confirm termination and encoding
 *            with the Java side.
 *
 * @return the model interpreter handle, passed as <code>modelHandle</code> to the other LLML_IMPL_*
 *         functions.
 *         NOTE(review): the value returned on failure is not documented here; confirm with the
 *         implementation.
 */
jint LLML_IMPL_init_model_from_resource(jbyte *modelPath);

/**
 * @brief Initializes a model from an in-memory buffer: maps the model into a usable data structure,
 * builds an interpreter to run it, and allocates memory for the model's tensors.
 *
 * @param[in] modelPointer the native address of the model data, carried in a jint.
 *            Presumably an address returned by LLML_IMPL_allocate_model() and filled with
 *            LLML_IMPL_load_model_chunk(); confirm.
 *            NOTE(review): carrying an address in a jint assumes native pointers fit in 32 bits.
 * @param[in] modelByteSize the model size in bytes.
 *
 * @return the model interpreter handle, passed as <code>modelHandle</code> to the other LLML_IMPL_*
 *         functions.
 *         NOTE(review): the value returned on failure is not documented here; confirm with the
 *         implementation.
 */
jint LLML_IMPL_init_model_from_buffer(jint modelPointer, jint modelByteSize);

/**
 * @brief Allocates a native buffer to store the model data on the native side.
 *
 * @param[in] modelByteSize the model size in bytes.
 *
 * @return the native address of the allocated buffer, carried in a jint.
 *         NOTE(review): the value returned when allocation fails is not documented here; confirm.
 */
jint LLML_IMPL_allocate_model(jint modelByteSize);

/**
 * @brief Frees the native buffer allocated to store the model data.
 *
 * @param[in] modelPointer the model buffer address, presumably as returned by
 *            LLML_IMPL_allocate_model(); confirm.
 *
 * NOTE(review): if a model was initialized from this buffer with LLML_IMPL_init_model_from_buffer(),
 * confirm whether LLML_IMPL_clean() must be called before freeing it.
 */
void LLML_IMPL_free_model(jint modelPointer);

/**
 * @brief Loads <code>length</code> bytes of model data into the native model buffer at the specified index.
 *
 * @param[in] data the data to be loaded.
 * @param[in] address the model buffer address, presumably as returned by LLML_IMPL_allocate_model(); confirm.
 * @param[in] index the position in the native model buffer where the data must be loaded
 *            (presumably a byte offset; confirm).
 * @param[in] length the number of bytes to be loaded.
 *
 * NOTE(review): <code>address</code> is declared as <code>int</code>, whereas every other address/handle
 * in this API is a <code>jint</code>. It is left unchanged here because altering the prototype could
 * conflict with existing definitions on targets where jint is not <code>int</code>.
 */
void LLML_IMPL_load_model_chunk(jbyte *data, int address, jint index, jint length);

/**
 * @brief Deletes the model and frees the allocated native resources.
 *
 * @param[in] modelHandle the model interpreter handle.
 *
 * NOTE(review): the handle is presumably invalid after this call; confirm.
 */
void LLML_IMPL_clean(jint modelHandle);

/**
 * @brief Resets the state of the model interpreter.
 *
 * @param[in] modelHandle the model interpreter handle.
 *
 * @return the execution status: 0 if the execution succeeds, 1 otherwise.
 */
jint LLML_IMPL_reset(jint modelHandle);

/**
 * @brief Runs an inference on the model.
 *
 * @param[in] modelHandle the model interpreter handle.
 *
 * @return the execution status: 0 if the execution succeeds, 1 otherwise.
 */
jint LLML_IMPL_run(jint modelHandle);

/**
 * @brief Gets the number of input tensors of the model.
 *
 * @param[in] modelHandle the model interpreter handle.
 *
 * @return the number of input tensors. Valid input tensor indexes are presumably 0 to count - 1; confirm.
 */
jint LLML_IMPL_get_input_tensor_count(jint modelHandle);

/**
 * @brief Gets the number of output tensors of the model.
 *
 * @param[in] modelHandle the model interpreter handle.
 *
 * @return the number of output tensors. Valid output tensor indexes are presumably 0 to count - 1; confirm.
 */
jint LLML_IMPL_get_output_tensor_count(jint modelHandle);

/**
 * @brief Gets the number of dimensions of the input tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return the number of dimensions of the input tensor.
 */
jint LLML_IMPL_get_input_num_dimensions(jint modelHandle, jint index);

/**
 * @brief Gets the shape of the input tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 * @param[out] sizes receives the size of each dimension of the input tensor.
 *             Presumably it must hold at least LLML_IMPL_get_input_num_dimensions() elements; confirm.
 */
void LLML_IMPL_get_input_shape(jint modelHandle, jint index, jint *sizes);

/**
 * @brief Gets the data type of the input tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return the data type, as an LLML_Tensor_Type value:
 *     - 0 (LLML_TENSOR_TYPE_UNKNOWN): unknown type
 *     - 1 (LLML_TENSOR_TYPE_INT8): int8
 *     - 2 (LLML_TENSOR_TYPE_UINT8): uint8
 *     - 3 (LLML_TENSOR_TYPE_FLOAT32): float32
 *     - 4 (LLML_TENSOR_TYPE_INT32): int32
 *     - 5 (LLML_TENSOR_TYPE_UINT32): uint32
 */
jint LLML_IMPL_get_input_data_type(jint modelHandle, jint index);

/**
 * @brief Gets the number of elements of the input tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return the number of elements of the input tensor.
 */
jint LLML_IMPL_get_input_num_elements(jint modelHandle, jint index);

/**
 * @brief Gets the number of bytes of the input tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return the number of bytes of the input tensor.
 */
jint LLML_IMPL_get_input_num_bytes(jint modelHandle, jint index);

/**
 * @brief Gets the quantization zero point of the input tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return the zero point. Presumably only meaningful when LLML_IMPL_input_quantized() returns true; confirm.
 */
jint LLML_IMPL_get_input_zero_point(jint modelHandle, jint index);

/**
 * @brief Gets the quantization scale of the input tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return the scale. Presumably only meaningful when LLML_IMPL_input_quantized() returns true; confirm.
 */
jfloat LLML_IMPL_get_input_scale(jint modelHandle, jint index);

/**
 * @brief Tells whether the input tensor specified by index is quantized.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return true if the index is valid and the input tensor is quantized, false otherwise.
 */
jboolean LLML_IMPL_input_quantized(jint modelHandle, jint index);

/**
 * @brief Gets the number of dimensions of the output tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return the number of dimensions of the output tensor.
 */
jint LLML_IMPL_get_output_num_dimensions(jint modelHandle, jint index);

/**
 * @brief Gets the shape of the output tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 * @param[out] sizes receives the size of each dimension of the output tensor.
 *             Presumably it must hold at least LLML_IMPL_get_output_num_dimensions() elements; confirm.
 */
void LLML_IMPL_get_output_shape(jint modelHandle, jint index, jint *sizes);

/**
 * @brief Gets the data type of the output tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return the data type, as an LLML_Tensor_Type value:
 *     - 0 (LLML_TENSOR_TYPE_UNKNOWN): unknown type
 *     - 1 (LLML_TENSOR_TYPE_INT8): int8
 *     - 2 (LLML_TENSOR_TYPE_UINT8): uint8
 *     - 3 (LLML_TENSOR_TYPE_FLOAT32): float32
 *     - 4 (LLML_TENSOR_TYPE_INT32): int32
 *     - 5 (LLML_TENSOR_TYPE_UINT32): uint32
 */
jint LLML_IMPL_get_output_data_type(jint modelHandle, jint index);

/**
 * @brief Gets the number of elements of the output tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return the number of elements of the output tensor.
 */
jint LLML_IMPL_get_output_num_elements(jint modelHandle, jint index);

/**
 * @brief Gets the number of bytes of the output tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return the number of bytes of the output tensor.
 */
jint LLML_IMPL_get_output_num_bytes(jint modelHandle, jint index);

/**
 * @brief Gets the quantization zero point of the output tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return the zero point. Presumably only meaningful when LLML_IMPL_output_quantized() returns true; confirm.
 */
jint LLML_IMPL_get_output_zero_point(jint modelHandle, jint index);

/**
 * @brief Gets the quantization scale of the output tensor specified by index.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return the scale. Presumably only meaningful when LLML_IMPL_output_quantized() returns true; confirm.
 */
jfloat LLML_IMPL_get_output_scale(jint modelHandle, jint index);

/**
 * @brief Tells whether the output tensor specified by index is quantized.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 *
 * @return true if the index is valid and the output tensor is quantized, false otherwise.
 */
jboolean LLML_IMPL_output_quantized(jint modelHandle, jint index);

/**
 * @brief Sets the input tensor specified by index from an array of signed or unsigned bytes.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 * @param[in] inputData the data to be loaded into the model.
 *            NOTE(review): the length is not passed; presumably the array must hold at least
 *            LLML_IMPL_get_input_num_elements() elements. Confirm.
 */
void LLML_IMPL_set_input_data_as_byte_array(jint modelHandle, jint index, jbyte *inputData);

/**
 * @brief Sets the input tensor specified by index from an array of signed or unsigned integers.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 * @param[in] inputData the data to be loaded into the model.
 *            NOTE(review): the length is not passed; presumably the array must hold at least
 *            LLML_IMPL_get_input_num_elements() elements. Confirm.
 */
void LLML_IMPL_set_input_data_as_int_array(jint modelHandle, jint index, jint *inputData);

/**
 * @brief Sets the input tensor specified by index from an array of floats.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 * @param[in] inputData the data to be loaded into the model.
 *            NOTE(review): the length is not passed; presumably the array must hold at least
 *            LLML_IMPL_get_input_num_elements() elements. Confirm.
 */
void LLML_IMPL_set_input_data_as_float_array(jint modelHandle, jint index, jfloat *inputData);

/**
 * @brief Copies the data of the output tensor specified by index into a caller-provided array of
 * signed or unsigned bytes.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 * @param[out] outputData receives the inference result.
 *             NOTE(review): earlier docs said the result is "null" for an invalid index, but this
 *             function returns void and writes into outputData. Confirm what the implementation does
 *             on an invalid index. The array presumably needs at least
 *             LLML_IMPL_get_output_num_elements() elements.
 */
void LLML_IMPL_get_output_data_as_byte_array(jint modelHandle, jint index, jbyte *outputData);

/**
 * @brief Copies the data of the output tensor specified by index into a caller-provided array of
 * signed or unsigned integers.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 * @param[out] outputData receives the inference result.
 *             NOTE(review): earlier docs said the result is "null" for an invalid index, but this
 *             function returns void and writes into outputData. Confirm what the implementation does
 *             on an invalid index. The array presumably needs at least
 *             LLML_IMPL_get_output_num_elements() elements.
 */
void LLML_IMPL_get_output_data_as_integer_array(jint modelHandle, jint index, jint *outputData);

/**
 * @brief Copies the data of the output tensor specified by index into a caller-provided array of floats.
 *
 * @param[in] modelHandle the model interpreter handle.
 * @param[in] index the tensor index.
 * @param[out] outputData receives the inference result.
 *             NOTE(review): earlier docs said the result is "null" for an invalid index, but this
 *             function returns void and writes into outputData. Confirm what the implementation does
 *             on an invalid index. The array presumably needs at least
 *             LLML_IMPL_get_output_num_elements() elements.
 */
void LLML_IMPL_get_output_data_as_float_array(jint modelHandle, jint index, jfloat *outputData);

#ifdef __cplusplus
}
#endif
#endif
