#include <botan/ffi.h>

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/*
 * Return the minimum key length, in bytes, accepted by the already
 * initialized cipher `ctx`, as reported by botan_cipher_get_keyspec.
 *
 * Aborts if Botan reports an error. The call is deliberately NOT wrapped in
 * assert(): assert() expressions are removed under NDEBUG, which would skip
 * the query and return an uninitialized value.
 *
 * NOTE(review): names beginning with '_' + uppercase letter are reserved for
 * the C implementation; renaming requires updating every caller.
 */
size_t _AEAD_K_L(botan_cipher_t ctx) {
    size_t min_len = 0, max_len = 0, mod_len = 0;
    if (botan_cipher_get_keyspec(ctx, &min_len, &max_len, &mod_len) != 0) {
        abort();
    }
    return min_len;
}

size_t AEAD_K_L() {
    botan_cipher_t ctx;
    assert(botan_cipher_init(&ctx, AEAD_MODE, 0) == 0);
    size_t k_l = _AEAD_K_L(ctx);
    assert(botan_cipher_destroy(ctx) == 0);
    return k_l;
}

void aead_enc(const uint8_t *k, bytes_i ads, size_t n_ads, bytes_i m, bytes_O c) {
    botan_cipher_t ctx;
    assert(botan_cipher_init(&ctx, AEAD_MODE, BOTAN_CIPHER_INIT_FLAG_ENCRYPT) == 0);
    assert(botan_cipher_set_key(ctx, k, _AEAD_K_L(ctx)) == 0);
    size_t l_ads = 0;
    for (size_t n = 0; n < n_ads; ++n) {
        l_ads += ads[n].l;
    }
    bytes ad = alloc_bytes(l_ads);
    for (size_t n = 0, l = 0; n < n_ads; ++n) {
        memcpy(ad.p + l, ads[n].p, ads[n].l);
        l += ads[n].l;
    }
    assert(botan_cipher_set_associated_data(ctx, ad.p, ad.l) == 0);
    free_bytes(as_const_bytes(&ad));
    size_t iv_l;
    assert(botan_cipher_get_default_nonce_length(ctx, &iv_l) == 0);
    botan_rng_t rng;
    assert(botan_rng_init(&rng, NULL) == 0);
    bytes iv = alloc_bytes(iv_l);
    assert(botan_rng_get(rng, iv.p, iv.l) == 0);
    assert(botan_rng_destroy(rng) == 0);
    assert(botan_cipher_start(ctx, iv.p, iv.l) == 0);
    size_t l;
    assert(botan_cipher_output_length(ctx, m->l, &l) == 0);
    *c = alloc_bytes(iv_l + l); // @TODO: use sparse output
    memcpy(c->p, iv.p, iv_l);
    free_bytes(as_const_bytes(&iv));
    size_t c_l, m_l;
    assert(botan_cipher_update(ctx, BOTAN_CIPHER_UPDATE_FLAG_FINAL, c->p + iv_l, l, &c_l, m->p, m->l, &m_l) == 0);
    assert(c_l == c->l - iv_l);
    assert(m_l == m->l);
    assert(botan_cipher_destroy(ctx) == 0);
}
int aead_dec(const uint8_t *k, bytes_i ads, size_t n_ads, bytes_i c, bytes_O m) {
    botan_cipher_t ctx;
    assert(botan_cipher_init(&ctx, AEAD_MODE, BOTAN_CIPHER_INIT_FLAG_DECRYPT) == 0);
    assert(botan_cipher_set_key(ctx, k, _AEAD_K_L(ctx)) == 0);
    size_t l_ads = 0;
    for (size_t n = 0; n < n_ads; ++n) {
        l_ads += ads[n].l;
    }
    bytes ad = alloc_bytes(l_ads);
    for (size_t n = 0, l = 0; n < n_ads; ++n) {
        memcpy(ad.p + l, ads[n].p, ads[n].l);
        l += ads[n].l;
    }
    assert(botan_cipher_set_associated_data(ctx, ad.p, ad.l) == 0);
    free_bytes(as_const_bytes(&ad));
    size_t iv_l;
    assert(botan_cipher_get_default_nonce_length(ctx, &iv_l) == 0);
    assert(botan_cipher_start(ctx, c->p, iv_l) == 0);
    size_t l;
    assert(botan_cipher_output_length(ctx, c->l - iv_l, &l) == 0);
    *m = alloc_bytes(l);
    size_t m_l, c_l;
    int result = botan_cipher_update(ctx, BOTAN_CIPHER_UPDATE_FLAG_FINAL, m->p, m->l, &m_l, c->p + iv_l, c->l - iv_l, &c_l);
    assert(m_l == m->l);
    assert(c_l == c->l - iv_l);
    if (result != 0) {
        free_bytes(as_const_bytes(m));
        assert(botan_cipher_destroy(ctx) == 0);
        return 1;
    }
    assert(botan_cipher_destroy(ctx) == 0);
    return 0;
}
