#include "sh.h"

#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "signal.h"
#include "cka.h"

/* Number of message slots in `buffer`; init() also derives its seed length
 * from it (MAGIC % 997). */
#define MAGIC 9999

/* Per-party Signal state: _sh_init sets up state[0] with Signal_init_send and
 * state[1] with Signal_init_recv. */
static Signal_state state[2];

#ifdef PLOT
/*
 * Column markers for the PLOT trace, indexed as arrow[event][party]:
 *   event 0 = message sent      (party = sender,   used by _send)
 *   event 1 = message decrypted (party = receiver, used by _recv)
 * Party 0 is printed in the left (Alice) column. Party 1 is tab-indented into
 * the right (Bob) column under the header printed by _sh_init.
 */
static const char arrow[2][2][9] = {{"-->", "\t\t\t\t\t<--"}, {"<--", "\t\t\t\t\t-->"}};
#endif

/*
 * Initialise both parties' Signal states from `seed`.
 *
 * state[0] is initialised with Signal_init_send and state[1] with
 * Signal_init_recv. Both use the same CKA seeds, which are filled by
 * cka_rand_seed and released with cka_free_seed once the states are set up.
 * With PLOT defined, the Alice/Bob column header of the trace is printed.
 */
void _sh_init(bytes_i seed) {
    bytes cka_seeds[CKA_SEED_N()];
    cka_rand_seed(cka_seeds);
    Signal_init_send(seed, as_const_bytes(cka_seeds), &state[0]);
    Signal_init_recv(seed, as_const_bytes(cka_seeds), &state[1]);
    cka_free_seed(as_const_bytes(cka_seeds));
#ifdef PLOT
    printf("\n\t\tAlice\t\t\t\t\tBob\n\n");
#endif
}
/*
 * Convenience wrapper around _sh_init.
 *
 * Allocates a seed_l-byte seed with alloc_bytes, initialises both parties
 * from it, and frees it again.
 * NOTE(review): the seed holds whatever alloc_bytes leaves in the buffer;
 * confirm whether that is zeroed, random or unspecified.
 */
void sh_init(size_t seed_l) {
    bytes seed_buf = alloc_bytes(seed_l);

    _sh_init(as_const_bytes(&seed_buf));

    free_bytes(as_const_bytes(&seed_buf));
}
/*
 * Default initialisation: seeds both parties with a 29-byte seed
 * (MAGIC % 997 == 9999 % 997 == 29).
 * Declared with (void) so the definition is a real prototype; an empty
 * parameter list is an obsolescent old-style definition.
 */
void init(void) {
    sh_init(MAGIC % 997);
}

/*
 * Fixed-size store of messages in flight, indexed by the value _send returns.
 * _send wraps back to slot 0 when all MAGIC slots have been used.
 */
static struct {
    uint8_t id;       /* sending party, 0 or 1 */
    const_bytes m;    /* copy of the plaintext, checked against the decryption */
    Signal_header h;  /* header produced by Signal_send */
    bytes c;          /* ciphertext produced by Signal_send */
} buffer[MAGIC];

/* Next slot to fill; objects with static storage duration start at zero. */
static size_t n;

/*
 * Encrypt `message` as party `id` (reduced modulo 2) and record it in the next
 * buffer slot.
 *
 * The slot stores the sender id, a copy of the plaintext (compared and freed
 * by _recv), and the header and ciphertext produced by Signal_send.
 * `verbose` > 0 / > 1 is reserved for logging that is still TODO.
 * Returns the slot index; pass it to _recv/recv as send_id.
 *
 * NOTE(review): when the buffer is full, the index wraps to 0 and slots are
 * reused. Anything a reused slot still holds (a message that was never
 * received) is overwritten without being freed.
 */
size_t _send(uint8_t id, const char message[], int verbose) {
    if (n >= MAGIC) {
        /* Newline-terminated so the warning does not run into the next line
         * of output (e.g. the PLOT trace below). */
        printf("Warn: Buffer overflow.\n");
        n = 0;
    }
    id = id % 2; /* only parties 0 and 1 exist */
    buffer[n].id = id;
    if (verbose > 1) {
        // @TODO: add logging print_state(id);
    }
    /* View the NUL-terminated string as bytes (terminator excluded) and keep a
     * copy for the plaintext check in _recv. */
    const_bytes m = {.p = (const uint8_t *)message, .l = strlen(message)};
    buffer[n].m = to_const_bytes(copy_bytes(&m));
    Signal_send(&state[id], &m, &buffer[n].h, &buffer[n].c);
    if (verbose > 0) {
        if (verbose > 1) {
            // @TODO: add logging print_state(id);
        }
        // @TODO: add logging print_buffer(n);
    }
#ifdef PLOT
    printf("%s message %zu: (%zu, %zu | %zu) %s\n", arrow[0][id], n, buffer[n].h.t, buffer[n].h.i, buffer[n].h.l, message);
#endif
    return n++;
}
/* Quiet form of _send (verbosity 0). Returns the buffer slot holding the message. */
size_t send(uint8_t id, char message[]) {
    const int quiet = 0;
    return _send(id, message, quiet);
}

/*
 * Deliver the message in slot `send_id` to the other party and verify it.
 *
 * The receiver is the party opposite to the slot's sender. After Signal_recv,
 * the slot's ciphertext, header and plaintext copy are released whether or not
 * decryption succeeded, so the slot must not be received twice.
 *
 * On success, the decrypted bytes are asserted equal to the plaintext recorded
 * by _send, and 0 is returned. On a non-zero Signal_recv result, the code is
 * printed; with verbose == 0 the process exits with it, otherwise it is returned.
 */
int _recv(size_t send_id, int verbose) {
    uint8_t id = 1 - buffer[send_id].id;
    if (verbose > 1) {
        // @TODO: add logging print_state(id);
    }
    bytes m = {0};
    /* NOTE(review): presumably Signal_recv always sets this flag. The
     * initializer only avoids reading an indeterminate value if it can return
     * early; confirm the correct default against the Signal implementation. */
    bool use_cka_free = false;
    int result = Signal_recv(&state[id], as_const_bytes(&buffer[send_id].c), &buffer[send_id].h, &use_cka_free, &m);
    free_bytes(as_const_bytes(&buffer[send_id].c));
    Signal_free_header(&buffer[send_id].h, use_cka_free);
    /* NOTE(review): the PLOT traces below read h.t / h.i / h.l after
     * Signal_free_header; this assumes those are plain value fields that
     * survive the free. */
    if (result != 0) {
        /* The slot is consumed even on failure (ciphertext and header are
         * already freed), so release the plaintext copy too instead of
         * leaking it. */
        free_bytes(&buffer[send_id].m);
#ifdef PLOT
        printf("%s decrypt %zu: (%zu, %zu | %zu) ???\n", arrow[1][id], send_id, buffer[send_id].h.t, buffer[send_id].h.i, buffer[send_id].h.l);
#endif
        printf("Error: %d.\n", result);
        if (verbose == 0) {
            exit(result);
        } else if (verbose > 1) {
            // @TODO: add logging print_state(id);
        }
        return result;
    }
    assert(m.l == buffer[send_id].m.l && strncmp((const char *)m.p, (const char *)buffer[send_id].m.p, m.l) == 0);
    free_bytes(&buffer[send_id].m);
    if (verbose > 1) {
        // @TODO: add logging print_state(id);
    }
#ifdef PLOT
    printf("%s decrypt %zu: (%zu, %zu | %zu) %.*s\n", arrow[1][id], send_id, buffer[send_id].h.t, buffer[send_id].h.i, buffer[send_id].h.l, (int)m.l, m.p);
#endif
    free_bytes(as_const_bytes(&m));
    return 0;
}
/* Quiet form of _recv (verbosity 0): a non-zero Signal_recv result terminates
 * the process via exit(). Returns 0 on success. */
int recv(size_t send_id) {
    const int quiet = 0;
    return _recv(send_id, quiet);
}

/*
 * Release both parties' Signal states.
 * Messages still held in `buffer` (sent but never received) are not released
 * yet; see the TODO below.
 * Declared with (void) so the definition is a real prototype.
 */
void free_all(void) {
    Signal_free(&state[0]);
    Signal_free(&state[1]);
    // @TODO: free buffer
}
