/*
 * Copyright (c) 2021 Intel Corporation
 *
 * SPDX-License-Identifier: Apache-2.0
 */

#include <errno.h>
#include <stddef.h>
#include <stdbool.h>
#include <stdio.h>
#include <sys/bitarray.h>
#include <sys/check.h>
#include <sys/sys_io.h>

/*
 * Number of bits represented by one bundle of bitarray @p ba.
 * The argument is parenthesized so that expressions such as &arr or
 * (p + 1) expand correctly.
 */
#define bundle_bitness(ba)	(sizeof((ba)->bundles[0]) * 8)

/*
 * Location of a bit region [offset, offset + num_bits - 1] inside the
 * bundle array, as filled in by setup_bundle_data().
 */
struct bundle_data {
	/* Index of the bundle holding the first bit of the region */
	size_t sidx;
	/* Index of the bundle holding the last bit of the region */
	size_t eidx;

	/* Bit position of the first region bit within bundle sidx */
	size_t soff;
	/* Bit position of the last region bit within bundle eidx */
	size_t eoff;

	/* Region bits within bundle sidx (the whole region if sidx == eidx) */
	uint32_t smask;
	/* Region bits within bundle eidx */
	uint32_t emask;
};

/*
 * Describe where the region [offset, offset + num_bits - 1] sits in the
 * bundle array.
 *
 * smask selects the region bits from soff upwards in the start bundle.
 * emask selects bits 0..eoff (inclusive) in the end bundle. When both
 * ends land in the same bundle, smask is narrowed to the intersection so
 * that it alone describes the region.
 */
static void setup_bundle_data(sys_bitarray_t *bitarray,
			      struct bundle_data *bd,
			      size_t offset, size_t num_bits)
{
	const size_t bits_per_bundle = bundle_bitness(bitarray);
	const size_t first = offset;
	const size_t last = offset + num_bits - 1;

	bd->sidx = first / bits_per_bundle;
	bd->soff = first % bits_per_bundle;
	bd->eidx = last / bits_per_bundle;
	bd->eoff = last % bits_per_bundle;

	/* Every bit at position soff and above. */
	bd->smask = ~(BIT(bd->soff) - 1);
	/* Every bit from position 0 up to and including eoff. */
	bd->emask = BIT(bd->eoff) | (BIT(bd->eoff) - 1);

	if (bd->eidx == bd->sidx) {
		bd->smask &= bd->emask;
	}
}

/*
 * Find out if the bits in a region is all set or all clear.
 *
 * @param[in]  bitarray  Bitarray struct
 * @param[in]  offset    Starting bit location
 * @param[in]  num_bits  Number of bits in the region
 * @param[in]  match_set True if matching all set bits,
 *                       False if matching all cleared bits
 * @param[out] bd        Data related to matching which can be
 *                       used later to find out where the region
 *                       lies in the bitarray bundles.
 * @param[out] mismatch  Offset to the mismatched bit.
 *                       Can be NULL.
 *
 * @retval     true      If all bits are set or cleared
 * @retval     false     Not all bits are set or cleared
 */
/*
 * Find out if the bits in a region are all set or all clear.
 *
 * @param[in]  bitarray  Bitarray struct
 * @param[in]  offset    Starting bit location
 * @param[in]  num_bits  Number of bits in the region
 * @param[in]  match_set True if matching all set bits,
 *                       False if matching all cleared bits
 * @param[out] bd        Data related to matching which can be
 *                       used later to find out where the region
 *                       lies in the bitarray bundles.
 * @param[out] mismatch  On mismatch, offset of a mismatched bit: the
 *                       lowest mismatched bit of the first bundle found
 *                       to mismatch. Bundles are checked in the order
 *                       start, end, then in-between, so this is not
 *                       necessarily the lowest mismatched bit overall.
 *                       Can be NULL.
 *
 * @retval     true      If all bits are set or cleared
 * @retval     false     Not all bits are set or cleared
 */
static bool match_region(sys_bitarray_t *bitarray, size_t offset,
			 size_t num_bits, bool match_set,
			 struct bundle_data *bd,
			 size_t *mismatch)
{
	size_t idx;
	uint32_t bundle;
	uint32_t mismatch_bundle;
	size_t mismatch_bundle_idx;
	size_t mismatch_bit_off;

	setup_bundle_data(bitarray, bd, offset, num_bits);

	if (bd->sidx == bd->eidx) {
		/* Region lies within one bundle; smask already covers it. */
		bundle = bitarray->bundles[bd->sidx];
		if (!match_set) {
			/* Invert so that matching bits are always 1s. */
			bundle = ~bundle;
		}

		if ((bundle & bd->smask) != bd->smask) {
			mismatch_bundle = ~bundle & bd->smask;
			mismatch_bundle_idx = bd->sidx;
			goto mismatch;
		}

		return true;
	}

	/* Region lies in a number of bundles. Need to loop through them. */

	/* Start of bundles */
	bundle = bitarray->bundles[bd->sidx];
	if (!match_set) {
		bundle = ~bundle;
	}

	if ((bundle & bd->smask) != bd->smask) {
		/* Start bundle not matching to mask. */
		mismatch_bundle = ~bundle & bd->smask;
		mismatch_bundle_idx = bd->sidx;
		goto mismatch;
	}

	/* End of bundles */
	bundle = bitarray->bundles[bd->eidx];
	if (!match_set) {
		bundle = ~bundle;
	}

	if ((bundle & bd->emask) != bd->emask) {
		/* End bundle not matching to mask. */
		mismatch_bundle = ~bundle & bd->emask;
		mismatch_bundle_idx = bd->eidx;
		goto mismatch;
	}

	/* In-between bundles: every bit must match. */
	for (idx = bd->sidx + 1; idx < bd->eidx; idx++) {
		/* Note that this is inverted the opposite way from above,
		 * so a matching bundle is 0 and every 1 bit is a mismatch.
		 */
		bundle = bitarray->bundles[idx];
		if (match_set) {
			bundle = ~bundle;
		}

		if (bundle != 0U) {
			/* The 1 bits of bundle are exactly the mismatched
			 * bits. Use bundle itself, not ~bundle: ~bundle
			 * would point at a matching bit, or be 0 when the
			 * whole bundle mismatches.
			 */
			mismatch_bundle = bundle;
			mismatch_bundle_idx = idx;
			goto mismatch;
		}
	}

	/* All bits in region matched. */
	return true;

mismatch:
	if (mismatch != NULL) {
		/* Must have at least 1 bit set to indicate
		 * where the mismatch is.
		 */
		__ASSERT_NO_MSG(mismatch_bundle != 0);

		mismatch_bit_off = find_lsb_set(mismatch_bundle) - 1;
		mismatch_bit_off += mismatch_bundle_idx *
				    bundle_bitness(bitarray);
		*mismatch = mismatch_bit_off;
	}
	return false;
}

/*
 * Set or clear a region of bits.
 *
 * @param bitarray Bitarray struct
 * @param offset   Starting bit location
 * @param num_bits Number of bits in the region
 * @param to_set   True if to set all bits.
 *                 False if to clear all bits.
 * @param bd       Bundle data. Can reuse the output from
 *                 match_region(). NULL if there is no
 *                 prior call to match_region().
 */
/*
 * Set or clear a region of bits. No range checking is done here.
 *
 * @param bitarray Bitarray struct
 * @param offset   Starting bit location
 * @param num_bits Number of bits in the region
 * @param to_set   True to set all bits in the region,
 *                 false to clear them.
 * @param bd       Bundle data for the region, e.g. as produced by a
 *                 prior match_region() call. Pass NULL to have it
 *                 computed from offset/num_bits.
 */
static void set_region(sys_bitarray_t *bitarray, size_t offset,
		       size_t num_bits, bool to_set,
		       struct bundle_data *bd)
{
	struct bundle_data local_bd;
	size_t i;

	if (bd == NULL) {
		setup_bundle_data(bitarray, &local_bd, offset, num_bits);
		bd = &local_bd;
	}

	/* Start bundle: only the bits selected by smask are touched. */
	if (to_set) {
		bitarray->bundles[bd->sidx] |= bd->smask;
	} else {
		bitarray->bundles[bd->sidx] &= ~bd->smask;
	}

	if (bd->sidx == bd->eidx) {
		/* Single bundle: smask already spans the whole region. */
		return;
	}

	/* End bundle: only the bits selected by emask are touched. */
	if (to_set) {
		bitarray->bundles[bd->eidx] |= bd->emask;
	} else {
		bitarray->bundles[bd->eidx] &= ~bd->emask;
	}

	/* Bundles strictly between start and end are overwritten whole. */
	for (i = bd->sidx + 1; i < bd->eidx; i++) {
		bitarray->bundles[i] = to_set ? ~0U : 0U;
	}
}

/*
 * Set a single bit while holding the bitarray lock.
 *
 * Returns 0 on success, or -EINVAL if @p bit is outside the array.
 */
int sys_bitarray_set_bit(sys_bitarray_t *bitarray, size_t bit)
{
	int ret = -EINVAL;
	k_spinlock_key_t key = k_spin_lock(&bitarray->lock);

	__ASSERT_NO_MSG(bitarray->num_bits > 0);

	if (bit < bitarray->num_bits) {
		size_t word = bit / bundle_bitness(bitarray);
		size_t shift = bit % bundle_bitness(bitarray);

		bitarray->bundles[word] |= BIT(shift);
		ret = 0;
	}

	k_spin_unlock(&bitarray->lock, key);
	return ret;
}

/*
 * Clear a single bit while holding the bitarray lock.
 *
 * Returns 0 on success, or -EINVAL if @p bit is outside the array.
 */
int sys_bitarray_clear_bit(sys_bitarray_t *bitarray, size_t bit)
{
	int ret = -EINVAL;
	k_spinlock_key_t key = k_spin_lock(&bitarray->lock);

	__ASSERT_NO_MSG(bitarray->num_bits > 0);

	if (bit < bitarray->num_bits) {
		size_t word = bit / bundle_bitness(bitarray);
		size_t shift = bit % bundle_bitness(bitarray);

		bitarray->bundles[word] &= ~BIT(shift);
		ret = 0;
	}

	k_spin_unlock(&bitarray->lock, key);
	return ret;
}

/*
 * Read a single bit into *@p val (1 if set, 0 if clear) while holding
 * the bitarray lock.
 *
 * Returns 0 on success, or -EINVAL if @p val is NULL (subject to the
 * CHECKIF guard) or @p bit is outside the array.
 */
int sys_bitarray_test_bit(sys_bitarray_t *bitarray, size_t bit, int *val)
{
	k_spinlock_key_t key = k_spin_lock(&bitarray->lock);
	int ret = -EINVAL;
	size_t word, shift;

	__ASSERT_NO_MSG(bitarray->num_bits > 0);

	CHECKIF(val == NULL) {
		goto unlock;
	}

	if (bit >= bitarray->num_bits) {
		goto unlock;
	}

	word = bit / bundle_bitness(bitarray);
	shift = bit % bundle_bitness(bitarray);

	*val = ((bitarray->bundles[word] & BIT(shift)) != 0) ? 1 : 0;
	ret = 0;

unlock:
	k_spin_unlock(&bitarray->lock, key);
	return ret;
}

/*
 * While holding the bitarray lock, report the current value of @p bit
 * through *@p prev_val (1 if set, 0 if clear) and then set the bit.
 *
 * Returns 0 on success, or -EINVAL if @p prev_val is NULL (subject to
 * the CHECKIF guard) or @p bit is outside the array.
 */
int sys_bitarray_test_and_set_bit(sys_bitarray_t *bitarray, size_t bit, int *prev_val)
{
	k_spinlock_key_t key = k_spin_lock(&bitarray->lock);
	int ret = -EINVAL;
	size_t word, shift;
	bool was_set;

	__ASSERT_NO_MSG(bitarray->num_bits > 0);

	CHECKIF(prev_val == NULL) {
		goto unlock;
	}

	if (bit >= bitarray->num_bits) {
		goto unlock;
	}

	word = bit / bundle_bitness(bitarray);
	shift = bit % bundle_bitness(bitarray);

	was_set = (bitarray->bundles[word] & BIT(shift)) != 0;
	*prev_val = was_set ? 1 : 0;

	bitarray->bundles[word] |= BIT(shift);
	ret = 0;

unlock:
	k_spin_unlock(&bitarray->lock, key);
	return ret;
}

/*
 * While holding the bitarray lock, report the current value of @p bit
 * through *@p prev_val (1 if set, 0 if clear) and then clear the bit.
 *
 * Returns 0 on success, or -EINVAL if @p prev_val is NULL (subject to
 * the CHECKIF guard) or @p bit is outside the array.
 */
int sys_bitarray_test_and_clear_bit(sys_bitarray_t *bitarray, size_t bit, int *prev_val)
{
	k_spinlock_key_t key = k_spin_lock(&bitarray->lock);
	int ret = -EINVAL;
	size_t word, shift;
	bool was_set;

	__ASSERT_NO_MSG(bitarray->num_bits > 0);

	CHECKIF(prev_val == NULL) {
		goto unlock;
	}

	if (bit >= bitarray->num_bits) {
		goto unlock;
	}

	word = bit / bundle_bitness(bitarray);
	shift = bit % bundle_bitness(bitarray);

	was_set = (bitarray->bundles[word] & BIT(shift)) != 0;
	*prev_val = was_set ? 1 : 0;

	bitarray->bundles[word] &= ~BIT(shift);
	ret = 0;

unlock:
	k_spin_unlock(&bitarray->lock, key);
	return ret;
}

/*
 * Find a run of @p num_bits clear bits, set them all, and report the
 * first bit of the run through *@p offset. The search is done while
 * holding the bitarray lock.
 *
 * Returns 0 on success; -EINVAL on a NULL @p offset (subject to the
 * CHECKIF guard), or if @p num_bits is 0 or larger than the array;
 * -ENOSPC if no run of @p num_bits clear bits exists.
 */
int sys_bitarray_alloc(sys_bitarray_t *bitarray, size_t num_bits,
		       size_t *offset)
{
	k_spinlock_key_t key;
	struct bundle_data bd;
	size_t bundle_idx;
	size_t bit_idx;
	size_t last_start;
	size_t mismatch;
	int ret;

	key = k_spin_lock(&bitarray->lock);

	__ASSERT_NO_MSG(bitarray->num_bits > 0);

	CHECKIF(offset == NULL) {
		ret = -EINVAL;
		goto out;
	}

	if ((num_bits == 0) || (num_bits > bitarray->num_bits)) {
		ret = -EINVAL;
		goto out;
	}

	/* Find the first non-allocated bit by looking at bundles
	 * instead of individual bits: skip fully-allocated bundles
	 * whole, then start at the lowest clear bit of the first
	 * bundle that has one.
	 */
	bit_idx = 0;
	for (bundle_idx = 0; bundle_idx < bitarray->num_bundles;
	     bundle_idx++) {
		uint32_t bundle = bitarray->bundles[bundle_idx];

		if (~bundle == 0U) {
			/* bundle is all 1s => all allocated, skip */
			bit_idx += bundle_bitness(bitarray);
			continue;
		}

		if (bundle != 0U) {
			/* Find the first free bit in bundle if not all free */
			bit_idx += find_lsb_set(~bundle) - 1;
		}

		break;
	}

	/* Highest start offset at which num_bits still fit in the array. */
	last_start = bitarray->num_bits - num_bits;
	ret = -ENOSPC;
	while (bit_idx <= last_start) {
		if (match_region(bitarray, bit_idx, num_bits, false,
				 &bd, &mismatch)) {
			set_region(bitarray, bit_idx, num_bits, true, &bd);

			*offset = bit_idx;
			ret = 0;
			break;
		}

		/* Fast-forward to the bit just after the mismatched bit.
		 * The mismatch lies inside [bit_idx, bit_idx + num_bits - 1],
		 * so any start at or before it would still contain it.
		 */
		bit_idx = mismatch + 1;
	}

out:
	k_spin_unlock(&bitarray->lock, key);
	return ret;
}

/*
 * Release a region previously obtained from sys_bitarray_alloc() by
 * clearing bits [offset, offset + num_bits - 1] while holding the
 * bitarray lock.
 *
 * Returns 0 on success; -EINVAL if the region is empty or does not fit
 * in the array; -EFAULT if any bit in the region is not currently set
 * (in which case nothing is modified).
 */
int sys_bitarray_free(sys_bitarray_t *bitarray, size_t num_bits,
		      size_t offset)
{
	struct bundle_data bd;
	const size_t last_bit = offset + num_bits - 1;
	int ret = -EINVAL;
	k_spinlock_key_t key = k_spin_lock(&bitarray->lock);

	__ASSERT_NO_MSG(bitarray->num_bits > 0);

	if ((num_bits != 0)
	    && (num_bits <= bitarray->num_bits)
	    && (offset < bitarray->num_bits)
	    && (last_bit < bitarray->num_bits)) {
		/* Only clear the region if every bit in it is allocated. */
		if (match_region(bitarray, offset, num_bits, true, &bd, NULL)) {
			set_region(bitarray, offset, num_bits, false, &bd);
			ret = 0;
		} else {
			ret = -EFAULT;
		}
	}

	k_spin_unlock(&bitarray->lock, key);
	return ret;
}

/*
 * While holding the bitarray lock, report whether every bit in
 * [offset, offset + num_bits - 1] is set (to_set == true) or clear
 * (to_set == false). An empty or out-of-range region reports false.
 */
static bool is_region_set_clear(sys_bitarray_t *bitarray, size_t num_bits,
				size_t offset, bool to_set)
{
	struct bundle_data bd;
	const size_t last_bit = offset + num_bits - 1;
	bool matched = false;
	k_spinlock_key_t key = k_spin_lock(&bitarray->lock);

	__ASSERT_NO_MSG(bitarray->num_bits > 0);

	if ((num_bits != 0)
	    && (num_bits <= bitarray->num_bits)
	    && (offset < bitarray->num_bits)
	    && (last_bit < bitarray->num_bits)) {
		matched = match_region(bitarray, offset, num_bits, to_set,
				       &bd, NULL);
	}

	k_spin_unlock(&bitarray->lock, key);
	return matched;
}

/* True if every bit in the given region is set; see is_region_set_clear(). */
bool sys_bitarray_is_region_set(sys_bitarray_t *bitarray, size_t num_bits,
				size_t offset)
{
	const bool want_set = true;

	return is_region_set_clear(bitarray, num_bits, offset, want_set);
}

/* True if every bit in the given region is clear; see is_region_set_clear(). */
bool sys_bitarray_is_region_cleared(sys_bitarray_t *bitarray, size_t num_bits,
				    size_t offset)
{
	const bool want_set = false;

	return is_region_set_clear(bitarray, num_bits, offset, want_set);
}

/*
 * While holding the bitarray lock, set (to_set == true) or clear
 * (to_set == false) every bit in [offset, offset + num_bits - 1],
 * regardless of the bits' current values.
 *
 * Returns 0 on success, or -EINVAL if the region is empty or does not
 * fit in the array.
 */
static int set_clear_region(sys_bitarray_t *bitarray, size_t num_bits,
			    size_t offset, bool to_set)
{
	const size_t last_bit = offset + num_bits - 1;
	int ret = -EINVAL;
	k_spinlock_key_t key = k_spin_lock(&bitarray->lock);

	__ASSERT_NO_MSG(bitarray->num_bits > 0);

	if ((num_bits != 0)
	    && (num_bits <= bitarray->num_bits)
	    && (offset < bitarray->num_bits)
	    && (last_bit < bitarray->num_bits)) {
		set_region(bitarray, offset, num_bits, to_set, NULL);
		ret = 0;
	}

	k_spin_unlock(&bitarray->lock, key);
	return ret;
}

/* Set every bit in the given region; see set_clear_region(). */
int sys_bitarray_set_region(sys_bitarray_t *bitarray, size_t num_bits,
			    size_t offset)
{
	const bool set_bits = true;

	return set_clear_region(bitarray, num_bits, offset, set_bits);
}

/* Clear every bit in the given region; see set_clear_region(). */
int sys_bitarray_clear_region(sys_bitarray_t *bitarray, size_t num_bits,
			      size_t offset)
{
	const bool set_bits = false;

	return set_clear_region(bitarray, num_bits, offset, set_bits);
}
