/*
 * Copyright (C) 2022-2023 Jordan Bancino <@jordan:bancino.net>
 *
 * Permission is hereby granted, free of charge, to any person
 * obtaining a copy of this software and associated documentation files
 * (the "Software"), to deal in the Software without restriction,
 * including without limitation the rights to use, copy, modify, merge,
 * publish, distribute, sublicense, and/or sell copies of the Software,
 * and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT.  IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include <Tls.h>

#if TLS_IMPL == TLS_OPENSSL

#include <Memory.h>
#include <Log.h>

#include <openssl/ssl.h>
#include <openssl/err.h>

typedef struct OpenSSLCookie
{
    int fd;
    const SSL_METHOD *method;
    SSL_CTX *ctx;
    SSL *ssl;
} OpenSSLCookie;

void *
TlsInitClient(int fd, const char *serverName)
{
    OpenSSLCookie *cookie;
    char errorStr[256];

    /*
     * TODO: Seems odd that this isn't needed to make the
     * connection... we should figure out how to verify the
     * certificate matches the server we think we're
     * connecting to.
     */
    (void) serverName;

    cookie = Malloc(sizeof(OpenSSLCookie));
    if (!cookie)
    {
        return NULL;
    }

    memset(cookie, 0, sizeof(OpenSSLCookie));

    cookie->method = TLS_client_method();
    cookie->ctx = SSL_CTX_new(cookie->method);
    if (!cookie->ctx)
    {
        goto error;
    }

    cookie->ssl = SSL_new(cookie->ctx);
    if (!cookie->ssl)
    {
        goto error;
    }

    if (!SSL_set_fd(cookie->ssl, fd))
    {
        goto error;
    }

    if (SSL_connect(cookie->ssl) <= 0)
    {
        goto error;
    }

    return cookie;

error:
    Log(LOG_ERR, "TlsClientInit(): %s", ERR_error_string(ERR_get_error(), errorStr));

    if (cookie->ssl)
    {
        SSL_shutdown(cookie->ssl);
        SSL_free(cookie->ssl);
    }

    close(cookie->fd);

    if (cookie->ctx)
    {
        SSL_CTX_free(cookie->ctx);
    }

    return NULL;
}

void *
TlsInitServer(int fd, const char *crt, const char *key)
{
    OpenSSLCookie *cookie;
    char errorStr[256];

    cookie = Malloc(sizeof(OpenSSLCookie));
    if (!cookie)
    {
        return NULL;
    }

    memset(cookie, 0, sizeof(OpenSSLCookie));

    cookie->method = TLS_server_method();
    cookie->ctx = SSL_CTX_new(cookie->method);
    if (!cookie->ctx)
    {
        Log(LOG_ERR, "TlsServerInit(): Unable to create SSL Context.");
        goto error;
    }

    if (SSL_CTX_use_certificate_file(cookie->ctx, crt, SSL_FILETYPE_PEM) <= 0)
    {
        Log(LOG_ERR, "TlsServerInit(): Unable to set certificate file.");
        goto error;
    }

    if (SSL_CTX_use_PrivateKey_file(cookie->ctx, key, SSL_FILETYPE_PEM) <= 0)
    {
        Log(LOG_ERR, "TlsServerInit(): Unable to set key file.");
        goto error;
    }

    cookie->ssl = SSL_new(cookie->ctx);
    if (!cookie->ssl)
    {
        Log(LOG_ERR, "TlsServerInit(): Unable to create SSL object.");
        goto error;
    }

    if (!SSL_set_fd(cookie->ssl, fd))
    {
        Log(LOG_ERR, "TlsServerInit(): Unable to set file descriptor.");
        goto error;
    }

    if (SSL_accept(cookie->ssl) <= 0)
    {
        Log(LOG_ERR, "TlsServerInit(): Unable to accept connection.");
        goto error;
    }

    return cookie;

error:
    Log(LOG_ERR, "TlsServerInit(): %s", ERR_error_string(ERR_get_error(), errorStr));

    if (cookie->ssl)
    {
        SSL_shutdown(cookie->ssl);
        SSL_free(cookie->ssl);
    }

    close(cookie->fd);

    if (cookie->ctx)
    {
        SSL_CTX_free(cookie->ctx);
    }

    return NULL;
}

ssize_t
TlsRead(void *cookie, void *buf, size_t nBytes)
{
    OpenSSLCookie *ssl = cookie;
    int ret = SSL_read(ssl->ssl, buf, nBytes);

    if (ret <= 0)
    {
        ret = -1;
        switch (SSL_get_error(ssl->ssl, ret))
        {
            case SSL_ERROR_WANT_READ:
            case SSL_ERROR_WANT_WRITE:
            case SSL_ERROR_WANT_CONNECT:
            case SSL_ERROR_WANT_ACCEPT:
            case SSL_ERROR_WANT_X509_LOOKUP:
                errno = EAGAIN;
                break;
            case SSL_ERROR_ZERO_RETURN:
                ret = 0;
                break;
            default:
                errno = EIO;
                break;
        }
    }

    return ret;
}

ssize_t
TlsWrite(void *cookie, void *buf, size_t nBytes)
{
    OpenSSLCookie *ssl = cookie;
    int ret = SSL_write(ssl->ssl, buf, nBytes);

    if (ret <= 0)
    {
        ret = -1;
        switch (SSL_get_error(ssl->ssl, ret))
        {
            case SSL_ERROR_WANT_READ:
            case SSL_ERROR_WANT_WRITE:
            case SSL_ERROR_WANT_CONNECT:
            case SSL_ERROR_WANT_ACCEPT:
            case SSL_ERROR_WANT_X509_LOOKUP:
                errno = EAGAIN;
                break;
            case SSL_ERROR_ZERO_RETURN:
                ret = 0;
                break;
            default:
                errno = EIO;
                break;
        }
    }

    return ret;
}

int
TlsClose(void *cookie)
{
    OpenSSLCookie *ssl = cookie;

    SSL_shutdown(ssl->ssl);
    SSL_free(ssl->ssl);
    close(ssl->fd);
    SSL_CTX_free(ssl->ctx);

    Free(ssl);

    return 0;
}

#endif