/*
 * Java
 *
 * Copyright 2025 MicroEJ Corp. All rights reserved.
 * Use of this source code is governed by a BSD-style license that can be found with this software.
 */
package ej.microai;

/**
 * The <code>Tensor</code> class offers services to deal with MicroAI tensors.
 * <p>
 * <code>Tensor</code> are the input/output of your {@link MLInferenceEngine}.
 */
public class Tensor {

    private final int modelHandle;
    private final int tensorIndex;
    private final boolean isInputTensor;

    /**
     * Tensor constructor.
     * <p>
     * NOTE(review): the tensor only stores the model handle; presumably the underlying model must stay alive for as
     * long as this tensor is used — confirm against {@link MLInferenceEngine}.
     *
     * @param modelHandle
     *            the model handle of the tensor.
     * @param index
     *            the index of the tensor, usually <code>index = 0</code> except for models that have more than one
     *            tensor as output.
     * @param isInputTensor
     *            true if it is an input tensor, false if it is an output tensor.
     */
    Tensor(int modelHandle, int index, boolean isInputTensor) {
        this.modelHandle = modelHandle;
        this.tensorIndex = index;
        this.isInputTensor = isInputTensor;
    }

    /**
     * Gets the tensor data type.
     *
     * @return the tensor data type (see {@link Tensor.DataType}).
     */
    public int getDataType() {
        return this.isInputTensor ? MLNatives.getInputDataType(this.modelHandle, this.tensorIndex)
                : MLNatives.getOutputDataType(this.modelHandle, this.tensorIndex);
    }

    /**
     * Gets the number of bytes of the tensor.
     *
     * @return number of bytes of the tensor.
     */
    public int getNumberBytes() {
        return this.isInputTensor ? MLNatives.getInputNumBytes(this.modelHandle, this.tensorIndex)
                : MLNatives.getOutputNumBytes(this.modelHandle, this.tensorIndex);
    }

    /**
     * Gets the number of dimensions of the tensor.
     *
     * @return number of dimensions of the tensor.
     */
    public int getNumberDimensions() {
        return this.isInputTensor ? MLNatives.getInputNumDimensions(this.modelHandle, this.tensorIndex)
                : MLNatives.getOutputNumDimensions(this.modelHandle, this.tensorIndex);
    }

    /**
     * Gets the number of elements of the tensor.
     *
     * @return number of elements of the tensor.
     */
    public int getNumberElements() {
        return this.isInputTensor ? MLNatives.getInputNumElements(this.modelHandle, this.tensorIndex)
                : MLNatives.getOutputNumElements(this.modelHandle, this.tensorIndex);
    }

    /**
     * Gets parameters of asymmetric quantization for the tensor.
     * <p>
     * Real values can be quantized using: <code>quantized_value = real_value/scale + zero_point</code>.
     *
     * @return the quantization parameters (see {@link QuantizationParameters}).
     */
    public QuantizationParameters getQuantizationParams() {
        if (this.isInputTensor) {
            return new QuantizationParameters(MLNatives.getInputZeroPoint(this.modelHandle, this.tensorIndex),
                    MLNatives.getInputScale(this.modelHandle, this.tensorIndex));
        }
        return new QuantizationParameters(MLNatives.getOutputZeroPoint(this.modelHandle, this.tensorIndex),
                MLNatives.getOutputScale(this.modelHandle, this.tensorIndex));
    }

    /**
     * Gets the tensor shape.
     * <p>
     * Fills <code>sizes</code> with the size of each dimension: the <code>n-th</code> element of the array receives
     * the size of the <code>n-th</code> dimension of the tensor. The array must hold at least
     * {@link #getNumberDimensions()} elements.
     *
     * @param sizes
     *            the array to fill with the size of each dimension of the tensor.
     * @throws NullPointerException
     *             if <code>sizes</code> is <code>null</code>.
     * @throws IllegalArgumentException
     *             if <code>sizes</code> is shorter than the number of dimensions of the tensor.
     */
    public void getShape(int[] sizes) {
        // Each dimension's size goes to its own array entry, so the array needs room for all of them. Reject a null
        // or too-short array here rather than relying on the native side to check bounds.
        int numberDimensions = getNumberDimensions();
        if (sizes.length < numberDimensions) {
            throw new IllegalArgumentException("sizes.length (" + sizes.length
                    + ") is smaller than the number of dimensions (" + numberDimensions + ")");
        }
        if (this.isInputTensor) {
            MLNatives.getInputShape(this.modelHandle, this.tensorIndex, sizes);
        } else {
            MLNatives.getOutputShape(this.modelHandle, this.tensorIndex, sizes);
        }
    }

    /**
     * Gets the tensor shape as a newly allocated array.
     * <p>
     * The returned array has exactly {@link #getNumberDimensions()} elements, where the <code>n-th</code> element is
     * the size of the <code>n-th</code> dimension of the tensor.
     *
     * @return a new array containing the size of each dimension of the tensor.
     */
    public int[] getShape() {
        int[] sizes = new int[getNumberDimensions()];
        getShape(sizes);
        return sizes;
    }

    /**
     * Gets the quantization status of the tensor.
     *
     * @return {@code true} if the tensor is quantized, {@code false} otherwise.
     */
    public boolean isQuantized() {
        return this.isInputTensor ? MLNatives.inputQuantized(this.modelHandle, this.tensorIndex)
                : MLNatives.outputQuantized(this.modelHandle, this.tensorIndex);
    }

    /**
     * The <code>Tensor.DataType</code> class enumerates the MicroAI data types.
     */
    public static final class DataType {

        /**
         * Unknown type.
         */
        public static final int UNKNOWN = 0;

        /**
         * 8-bit signed integer type.
         */
        public static final int INT8 = 1;

        /**
         * 8-bit unsigned integer type.
         */
        public static final int UINT8 = 2;

        /**
         * 32-bit float type.
         */
        public static final int FLOAT32 = 3;

        /**
         * 32-bit signed integer type.
         */
        public static final int INT32 = 4;

        /**
         * 32-bit unsigned integer type.
         */
        public static final int UINT32 = 5;

        private DataType() {
            // Constants holder: prevent instantiation.
        }

    }

    /**
     * The <code>QuantizationParameters</code> class represents the quantization parameters used by a
     * {@link MLInferenceEngine}.
     * <p>
     * <code>scale</code> and <code>zeroPoint</code> values are <code>0</code> when the tensor is not quantized.
     * <p>
     * Real values can be quantized using: <code>quantized_value = real_value/scale + zero_point</code>.
     * <p>
     * Quantized values can be converted back to float using:
     * <code>real_value = scale * (quantized_value - zero_point)</code>.
     */
    public static class QuantizationParameters {
        /*
         * Scale of quantization.
         */
        private final float scale;
        /*
         * Origin (zero point) of quantization.
         */
        private final int zeroPoint;

        /**
         * Creates an instance of QuantizationParameters.
         *
         * @param zeroPoint
         *            zero point of quantization
         * @param scale
         *            scale of quantization
         */
        public QuantizationParameters(int zeroPoint, float scale) {
            this.scale = scale;
            this.zeroPoint = zeroPoint;
        }

        /**
         * Gets the scale value of the quantization parameters.
         *
         * @return the scale value of the quantization parameters.
         */
        public float getScale() {
            return this.scale;
        }

        /**
         * Gets the zero point value of the quantization parameters.
         *
         * @return the zero point value of the quantization parameters.
         */
        public int getZeroPoint() {
            return this.zeroPoint;
        }

    }
}
