#include "bip39.h"

#include "bip39.words.h"
#include "log.h"
#include "util.js.h"

#include "sodium/crypto_hash_sha256.h"

#include <assert.h>
#include <stdio.h>
#include <string.h>

/* Each mnemonic word encodes an 11-bit value, so the dictionary must hold exactly 2^11 = 2048 entries. */
static_assert(k_bip39_words_count == 1 << 11, "Incorrect word count in k_bip39_words.");

/*
 * Encode 32 bytes of entropy as a 24-word BIP39 mnemonic.
 *
 * The 256 entropy bits are followed by an 8-bit checksum (the first byte of
 * SHA-256 of the entropy).  The resulting 264 bits are split into 24 groups
 * of 11 bits, each indexing k_bip39_words.  Words are joined with single
 * spaces, with no trailing space.
 *
 * bytes/bytes_size:     the entropy; bytes_size must be exactly 32.
 * out_words/words_size: destination buffer and its capacity.  When
 *                       words_size > 0 it is set to "" up front and is
 *                       always left NUL-terminated.
 *
 * Returns false if bytes_size is not 32, or if out_words is too small to hold
 * the whole mnemonic (it then holds a truncated prefix).
 */
bool tf_bip39_bytes_to_words(const uint8_t* bytes, size_t bytes_size, char* out_words, size_t words_size)
{
	if (words_size)
	{
		*out_words = '\0';
	}

	if (bytes_size != 32)
	{
		tf_printf("%s expected 32 bytes, got %zu.\n", __func__, bytes_size);
		return false;
	}

	/*
	 * Build "entropy || checksum byte".  The hash is written into the buffer
	 * first so that its first byte can be moved to the end before the
	 * entropy overwrites the front.
	 */
	uint8_t data[33];
	crypto_hash_sha256(data, bytes, bytes_size);
	data[bytes_size] = data[0];
	memcpy(data, bytes, bytes_size);

	/* Entropy bits plus 8 checksum bits: 264 bits, i.e. 24 words of 11 bits each. */
	const size_t bit_count = bytes_size * 8 + 8;
	size_t offset = 0;
	for (size_t bit = 0; bit < bit_count; bit += 11)
	{
		/* Read the next 11 bits, most significant bit first. */
		uint32_t value = 0;
		for (size_t j = 0; j < 11; j++)
		{
			size_t position = bit + j;
			value = (value << 1) | (uint32_t)((data[position / 8] >> (7 - position % 8)) & 1);
		}

		size_t remaining = words_size - offset;
		/* Separate words with a space, but do not leave one after the last word. */
		int written = snprintf(out_words + offset, remaining, bit + 11 < bit_count ? "%s " : "%s", k_bip39_words[value]);
		if (written < 0 || (size_t)written >= remaining)
		{
			/* Stop now: advancing offset past words_size would make the next write go out of bounds. */
			tf_printf("%s: output buffer of %zu bytes is too small.\n", __func__, words_size);
			return false;
		}
		offset += (size_t)written;
	}
	return true;
}

/*
 * Look up a word in the BIP39 dictionary.
 *
 * Returns the position of the first entry of k_bip39_words that compares
 * equal to word, or -1 when word is NULL or not present.
 */
static int _bip39_word_to_index(const char* word)
{
	int found = -1;
	if (word)
	{
		/* Linear scan; stops as soon as a match has been recorded. */
		for (int candidate = 0; found < 0 && candidate < k_bip39_words_count; candidate++)
		{
			if (!strcmp(k_bip39_words[candidate], word))
			{
				found = candidate;
			}
		}
	}
	return found;
}

/*
 * Decode a space-separated BIP39 mnemonic back into its entropy bytes.
 *
 * The mnemonic must contain exactly 24 dictionary words: 24 * 11 = 264 bits,
 * which is 256 bits of entropy followed by an 8-bit checksum.  Words may be
 * separated by any number of spaces; leading and trailing spaces are ignored.
 *
 * words:      NUL-terminated mnemonic.  NULL is rejected.
 * out_bytes:  receives min(33, bytes_size) bytes on success (32 entropy bytes,
 *             then the checksum byte).  It is not written on failure.
 * bytes_size: capacity of out_bytes.
 *
 * Returns true on success.  Returns false if a word is unknown or too long,
 * the word count is not 24, or the checksum does not match.
 */
bool tf_bip39_words_to_bytes(const char* words, uint8_t* out_bytes, size_t bytes_size)
{
	if (!words)
	{
		tf_printf("%s: No words given.\n", __func__);
		return false;
	}

	/* 32 entropy bytes followed by the 1-byte checksum. */
	uint8_t bytes[33];
	size_t out_index = 0;
	/* Bit accumulator: only the low value_bits bits of value are still pending output. */
	uint32_t value = 0;
	uint32_t value_bits = 0;

	const char* cursor = words;
	for (;;)
	{
		/* Skip any run of separators (leading, between words, or trailing). */
		while (*cursor == ' ')
		{
			cursor++;
		}
		if (*cursor == '\0')
		{
			break;
		}

		const char* word_end = cursor;
		while (*word_end != ' ' && *word_end != '\0')
		{
			word_end++;
		}
		size_t length = (size_t)(word_end - cursor);

		/* The input is untrusted: bound the copy and keep room for the terminator. */
		char copy[32] = "";
		if (length >= sizeof(copy))
		{
			tf_printf("%s: Word \"%.*s...\" is too long.\n", __func__, (int)(sizeof(copy) - 1), cursor);
			return false;
		}
		memcpy(copy, cursor, length);
		cursor = word_end;

		int index = _bip39_word_to_index(copy);
		if (index < 0)
		{
			tf_printf("%s: Word \"%s\" not found in dictionary.\n", __func__, copy);
			return false;
		}

		/* All 264 bits have already been produced; any further word makes the mnemonic invalid. */
		if (out_index == sizeof(bytes))
		{
			tf_printf("%s: Too many words (expected 24).\n", __func__);
			return false;
		}

		value = (value << 11) | (uint32_t)index;
		value_bits += 11;
		while (value_bits >= 8 && out_index < sizeof(bytes))
		{
			bytes[out_index++] = (uint8_t)((value >> (value_bits - 8)) & 0xff);
			value_bits -= 8;
		}
	}

	if (out_index != sizeof(bytes))
	{
		tf_printf("%s: Too few words (expected 24).\n", __func__);
		return false;
	}

	/* The trailing byte must equal the first byte of SHA-256 over the 32 entropy bytes. */
	uint8_t hash[32];
	crypto_hash_sha256(hash, bytes, 32);
	if (hash[0] != bytes[32])
	{
		tf_printf("%s: Checksum mismatch (%d vs. %d).\n", __func__, hash[0], bytes[32]);
		return false;
	}

	memcpy(out_bytes, bytes, tf_min(sizeof(bytes), bytes_size));
	return true;
}