#include "../cka.h"

#include "../kem.h"

/* Number of entries in a CKA seed as produced by cka_rand_seed():
 * slot 0 is a KEM_PK_L()-sized buffer and slot 1 a KEM_SK_L()-sized buffer.
 * Declared with (void) so the definition is a true prototype (an empty
 * parameter list is an obsolescent, non-prototype form before C23). */
size_t CKA_SEED_N(void) {
    return 2;
}

/* Generate a fresh CKA seed: one KEM key pair stored in seeds[0..1].
 *
 * seeds[0] is allocated with KEM_PK_L() bytes and seeds[1] with KEM_SK_L()
 * bytes, then kem_gen() fills both. `seeds` must have room for
 * CKA_SEED_N() entries. cka_init_send() later takes over seeds[0] and
 * cka_init_recv() takes over seeds[1].
 *
 * NOTE(review): neither the alloc_bytes() results nor any status from
 * kem_gen() are checked here — confirm alloc_bytes() aborts on failure and
 * kem_gen() cannot fail.
 */
void cka_rand_seed(bytes_O seeds) {
    seeds[0] = alloc_bytes(KEM_PK_L());  /* public-key-sized buffer */
    seeds[1] = alloc_bytes(KEM_SK_L());  /* secret-key-sized buffer */
    kem_gen(seeds[0].p, seeds[1].p);
}

/* Initialise the sending state from a seed.
 *
 * k_send takes over seeds[0] (the KEM_PK_L()-sized buffer filled by
 * cka_rand_seed()). Per the note in cka_free_seed(), the seed buffers are
 * handed over rather than copied, so seeds[0] is presumably not freed
 * separately afterwards. */
void cka_init_send(bytes_i seeds, bytes_O k_send) {
    k_send[0] = __WARN__to_bytes(seeds[0]);
}
/* Initialise the receiving state from a seed.
 *
 * k_recv takes over seeds[1] (the KEM_SK_L()-sized buffer filled by
 * cka_rand_seed()). Per the note in cka_free_seed(), the seed buffers are
 * handed over rather than copied, so seeds[1] is presumably not freed
 * separately afterwards. */
void cka_init_recv(bytes_i seeds, bytes_O k_recv) {
    k_recv[0] = __WARN__to_bytes(seeds[1]);
}

/* Release a seed produced by cka_rand_seed().
 *
 * Intentionally a no-op: both seed entries have already been handed over,
 * seeds[0] to k_send by cka_init_send() and seeds[1] to k_recv by
 * cka_init_recv(), so nothing remains for this function to free. */
void cka_free_seed(bytes_I seeds) {
    (void)seeds; /* nothing to release; also silences -Wunused-parameter */
}

/* Number of entries in a CKA message as produced by cka_send():
 * slot 0 is the KEM ciphertext (KEM_CT_L() bytes) and slot 1 a fresh
 * KEM_PK_L()-sized buffer that the receiver adopts as its next k_send.
 * Declared with (void) so the definition is a true prototype (an empty
 * parameter list is an obsolescent, non-prototype form before C23). */
size_t CKA_CT_N(void) {
    return 2;
}

/* Sending step of the KEM-based CKA.
 *
 *   1. Encapsulate under the current k_send (a KEM_PK_L()-sized buffer):
 *      cts[0] receives the ciphertext (KEM_CT_L() bytes) and k the shared
 *      key (KEM_K_L() bytes).
 *   2. Generate a fresh key pair: the KEM_PK_L()-sized half is written to
 *      cts[1] and the KEM_SK_L()-sized half becomes the new k_recv.
 *
 * `cts` must have room for CKA_CT_N() entries. Every output is freshly
 * allocated; k_send is only read, not released here. `id` is unused.
 * NOTE(review): any status returned by kem_enc()/kem_gen() is ignored —
 * confirm they cannot fail.
 */
void cka_send(uint8_t id, bytes_i k_send, bytes_O cts, bytes_O k, bytes_O k_recv) {
    (void)id; /* not needed by this construction */

    /* step 1: encapsulation */
    cts[0] = alloc_bytes(KEM_CT_L());
    k[0] = alloc_bytes(KEM_K_L());
    kem_enc(k_send[0].p, cts[0].p, k[0].p);

    /* step 2: fresh key pair for the next epoch */
    cts[1] = alloc_bytes(KEM_PK_L());
    k_recv[0] = alloc_bytes(KEM_SK_L());
    kem_gen(cts[1].p, k_recv[0].p);
}

/* Receiving step of the KEM-based CKA.
 *
 * Decapsulates cts[0] with the current k_recv into a freshly allocated
 * shared key k (KEM_K_L() bytes), then adopts cts[1] as the new k_send.
 * cts[1] is handed over rather than copied, which is why cka_free_ct()
 * leaves that entry alone. `id` is unused.
 * NOTE(review): any status returned by kem_dec() is ignored — confirm it
 * cannot fail or that failure is handled implicitly.
 */
void cka_recv(uint8_t id, bytes_i k_recv, bytes_i cts, bytes_O k, bytes_O k_send) {
    (void)id; /* not needed by this construction */

    k[0] = alloc_bytes(KEM_K_L());
    kem_dec(k_recv[0].p, cts[0].p, k[0].p);

    k_send[0] = __WARN__to_bytes(cts[1]);
}

/* Release a CKA message.
 *
 * Only entry 0 (the KEM ciphertext) is freed. Entry 1 is left untouched
 * because cka_recv() hands it over as the new k_send. */
void cka_free_ct(bytes_I cts) {
    free_bytes(cts); /* == &cts[0], the ciphertext */
}
