/*
 * Copyright (C) 2022-2024 Jordan Bancino <@jordan:bancino.net> with
 * other valuable contributors. See CONTRIBUTORS.txt for the full list.
 *
 * Permission is hereby granted, free of charge, to any person
 * obtaining a copy of this software and associated documentation files
 * (the "Software"), to deal in the Software without restriction,
 * including without limitation the rights to use, copy, modify, merge,
 * publish, distribute, sublicense, and/or sell copies of the Software,
 * and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT.  IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <State.h>

#include <Cytoplasm/HashMap.h>
#include <Cytoplasm/Memory.h>
#include <Cytoplasm/Array.h>
#include <Cytoplasm/Str.h>
#include <Cytoplasm/Sha.h>

#include <string.h>

#include <Event.h>
#include <Room.h>

/*
 * Comparator used by the room version 1 state resolution algorithm,
 * suitable for ArraySort(). Events are ordered by ascending "depth";
 * events of equal depth are ordered by descending SHA-1 of their
 * "event_id" (compared as hex strings).
 *
 * Returns -1 if a sorts before b, 1 if a sorts after b, and 0 if they
 * are equivalent. An event whose ID cannot be hashed (for example
 * because it has no "event_id") sorts after hashable events of the
 * same depth instead of crashing the sort.
 */
int
V1Cmp(void *a, void *b)
{
    HashMap *e1 = a, *e2 = b;
    int64_t depth1, depth2;
    char *e1id, *e2id;
    unsigned char *sha1 = NULL, *sha2 = NULL;
    char *str1 = NULL, *str2 = NULL;
    int cmp;

    depth1 = JsonValueAsInteger(JsonGet(e1, 1, "depth"));
    depth2 = JsonValueAsInteger(JsonGet(e2, 1, "depth"));

    if (depth1 != depth2)
    {
        /* Ascending by depth. */
        return (depth1 > depth2) ? 1 : -1;
    }

    e1id = JsonValueAsString(JsonGet(e1, 1, "event_id"));
    e2id = JsonValueAsString(JsonGet(e2, 1, "event_id"));

    /* Only hash IDs that exist, so a malformed event cannot lead to a
     * NULL string reaching strcmp(). */
    if (e1id)
    {
        sha1 = Sha1(e1id);
    }
    if (e2id)
    {
        sha2 = Sha1(e2id);
    }
    if (sha1)
    {
        str1 = ShaToHex(sha1);
    }
    if (sha2)
    {
        str2 = ShaToHex(sha2);
    }

    if (str1 && str2)
    {
        /* Descending by hash: compare with the arguments swapped rather
         * than negating, which could overflow if strcmp() returned
         * INT_MIN. */
        cmp = strcmp(str2, str1);
    }
    else
    {
        /* Hashable IDs first; two unhashable IDs compare equal. */
        cmp = (str1 == NULL) - (str2 == NULL);
    }

    if (str1)
    {
        Free(str1);
    }
    if (str2)
    {
        Free(str2);
    }
    if (sha1)
    {
        Free(sha1);
    }
    if (sha2)
    {
        Free(sha2);
    }

    return (cmp > 0) - (cmp < 0);
}
/*
 * Returns whether the state tuple "type,state_key" has the given event
 * type. The type is everything before the first comma, which is how
 * StateIterate() splits tuples.
 */
static bool
StateV1TupleHasType(char *tuple, char *type)
{
    size_t len = strlen(type);

    return strncmp(tuple, type, len) == 0 && tuple[len] == ',';
}

/*
 * Frees an Array of owned event ID strings, as stored in the conflict
 * map built by StateResolveV1().
 */
static void
StateV1IdsFree(Array *ids)
{
    size_t i;

    if (!ids)
    {
        return;
    }
    for (i = 0; i < ArraySize(ids); i++)
    {
        Free(ArrayGet(ids, i));
    }
    ArrayFree(ids);
}

/*
 * Frees an Array of event objects obtained from RoomEventFetch().
 */
static void
StateV1EventsFree(Array *events)
{
    size_t i;

    for (i = 0; i < ArraySize(events); i++)
    {
        JsonFree(ArrayGet(events, i));
    }
    ArrayFree(events);
}

/*
 * Fetches every event named in ids and appends it to events, which then
 * owns it. Events that cannot be fetched are skipped.
 */
static void
StateV1EventsFetch(Room *room, Array *ids, Array *events)
{
    size_t i;

    for (i = 0; i < ArraySize(ids); i++)
    {
        HashMap *event = RoomEventFetch(room, ArrayGet(ids, i));

        if (event)
        {
            ArrayAdd(events, event);
        }
    }
}

/*
 * Records that one of the states being resolved maps tuple to event_id.
 * Tuples that every state seen so far agrees on live in R (value: an
 * owned event ID string). Tuples on which two states disagree are moved
 * into conflicts (value: an Array of distinct, owned event ID strings).
 */
static void
StateV1Track(HashMap *R, HashMap *conflicts, char *tuple, char *event_id)
{
    Array *ids;
    char *existing;
    size_t i;

    if (!event_id)
    {
        return;
    }

    ids = HashMapGet(conflicts, tuple);
    if (ids)
    {
        /* Already conflicted: add this candidate unless it is known. */
        for (i = 0; i < ArraySize(ids); i++)
        {
            if (StrEquals(ArrayGet(ids, i), event_id))
            {
                return;
            }
        }
        ArrayAdd(ids, StrDuplicate(event_id));
        return;
    }

    existing = HashMapGet(R, tuple);
    if (!existing)
    {
        HashMapSet(R, tuple, StrDuplicate(event_id));
        return;
    }
    if (StrEquals(existing, event_id))
    {
        /* The states agree on this tuple; it is not a conflict. */
        return;
    }

    /* The states disagree. Take the tuple out of R, keeping ownership
     * of its event ID, and record both candidates as a conflict. */
    HashMapDelete(R, tuple);
    ids = ArrayCreate();
    ArrayAdd(ids, existing);
    ArrayAdd(ids, StrDuplicate(event_id));
    HashMapSet(conflicts, tuple, ids);
}

/*
 * Tries to apply event to R. If the event parses as a v1 PDU and
 * RoomAuthoriseEventV1() allows it against R, R is updated with it.
 * Returns whether R was updated.
 */
static bool
StateV1TryApply(Room *room, HashMap *R, HashMap *event)
{
    PduV1 pdu;
    char *msg = NULL;
    bool applied = false;

    /* Zeroed so PduV1Free() never sees uninitialised pointers in case
     * PduV1FromJson() fails before filling every field. */
    memset(&pdu, 0, sizeof(pdu));
    if (PduV1FromJson(event, &pdu, &msg) && RoomAuthoriseEventV1(room, pdu, R))
    {
        StateSet(R, pdu.type, pdu.state_key, pdu.event_id);
        applied = true;
    }
    (void) msg;
    PduV1Free(&pdu);

    return applied;
}

/*
 * Resolves every conflict on tuples of the given type (steps 2-4 of the
 * v1 algorithm). All candidates of that type are sorted with V1Cmp().
 * The first is added to R unconditionally. Each following candidate is
 * added while it passes authorisation against R; the first failure
 * stops the process. The tuples are then removed from conflicts.
 */
static void
StateV1ResolveType(Room *room, HashMap *R, HashMap *conflicts, char *type)
{
    Array *events = ArrayCreate();
    Array *tuples = ArrayCreate();
    char *tuple;
    Array *ids;
    size_t i;

    /* Collect the tuples first; conflicts is only modified after this
     * iteration has finished. */
    while (HashMapIterate(conflicts, &tuple, (void **) &ids))
    {
        if (StateV1TupleHasType(tuple, type))
        {
            StateV1EventsFetch(room, ids, events);
            ArrayAdd(tuples, StrDuplicate(tuple));
        }
    }

    /* No conflict of this type: the step is skipped. */
    if (ArraySize(events) > 0)
    {
        HashMap *first;

        if (ArraySize(events) > 1)
        {
            ArraySort(events, V1Cmp);
        }

        first = ArrayGet(events, 0);
        StateSet(
            R,
            JsonValueAsString(JsonGet(first, 1, "type")),
            JsonValueAsString(JsonGet(first, 1, "state_key")),
            JsonValueAsString(JsonGet(first, 1, "event_id")));

        for (i = 1; i < ArraySize(events); i++)
        {
            if (!StateV1TryApply(room, R, ArrayGet(events, i)))
            {
                break;
            }
        }
    }

    for (i = 0; i < ArraySize(tuples); i++)
    {
        char *resolved = ArrayGet(tuples, i);

        StateV1IdsFree(HashMapDelete(conflicts, resolved));
        Free(resolved);
    }
    ArrayFree(tuples);
    StateV1EventsFree(events);
}

/*
 * Resolves one remaining conflict (step 5 of the v1 algorithm). The
 * candidate with the highest depth, then lowest SHA-1 of its event ID,
 * that passes authorisation against R is added to R. If none passes,
 * R gets no entry for the tuple.
 */
static void
StateV1ResolveOther(Room *room, HashMap *R, Array *ids)
{
    Array *events = ArrayCreate();
    size_t i;

    StateV1EventsFetch(room, ids, events);
    if (ArraySize(events) > 1)
    {
        ArraySort(events, V1Cmp);
    }

    /* V1Cmp() sorts by ascending depth, then descending SHA-1, so the
     * preferred candidates are at the end. */
    for (i = ArraySize(events); i > 0; i--)
    {
        if (StateV1TryApply(room, R, ArrayGet(events, i - 1)))
        {
            break;
        }
    }

    StateV1EventsFree(events);
}

/*
 * Resolves a set of states with the room version 1 algorithm.
 *
 * states is an Array of state maps ("type,state_key" -> event ID);
 * NULL entries are ignored and the inputs are not modified. Returns a
 * newly allocated state map owned by the caller (free it with
 * StateFree()), or NULL if allocation fails.
 */
static HashMap *
StateResolveV1(Room * room, Array * states)
{
    /* Authorisation-relevant types, resolved in this order first. */
    static char *authTypes[] = {
        "m.room.power_levels",
        "m.room.join_rules",
        "m.room.member"
    };
    HashMap *R = HashMapCreate();
    HashMap *conflicts = HashMapCreate();
    char *tuple, *event_id;
    Array *ids;
    size_t i;

    if (!R || !conflicts)
    {
        if (R)
        {
            HashMapFree(R);
        }
        if (conflicts)
        {
            HashMapFree(conflicts);
        }
        return NULL;
    }

    /* Step 1: R receives every tuple the states agree on; tuples with
     * differing event IDs become conflicts. */
    for (i = 0; i < ArraySize(states); i++)
    {
        HashMap *state = ArrayGet(states, i);

        if (!state)
        {
            continue;
        }
        while (HashMapIterate(state, &tuple, (void **) &event_id))
        {
            StateV1Track(R, conflicts, tuple, event_id);
        }
    }

    /* Steps 2-4: power levels, join rules, then membership. */
    for (i = 0; i < sizeof(authTypes) / sizeof(authTypes[0]); i++)
    {
        StateV1ResolveType(room, R, conflicts, authTypes[i]);
    }

    /* Step 5: every other conflict, releasing each one as we go. */
    while (HashMapIterate(conflicts, &tuple, (void **) &ids))
    {
        StateV1ResolveOther(room, R, ids);
        StateV1IdsFree(ids);
    }
    HashMapFree(conflicts);

    return R;
}

/*
 * Placeholder for the state resolution algorithm used by room versions
 * other than 1. It is not implemented yet: the input is ignored and
 * NULL is always returned.
 */
static HashMap *
StateResolveV2(Array * states)
{
    HashMap *resolved = NULL;

    /* Unused until v2 resolution is implemented. */
    (void) states;

    return resolved;
}

/*
 * Resolves the given set of states with the algorithm that matches the
 * room's version: version 1 uses StateResolveV1(), every other version
 * goes to StateResolveV2().
 */
static HashMap *
StateFromPrevs(Room *room, Array *states)
{
    if (RoomVersionGet(room) == 1)
    {
        return StateResolveV1(room, states);
    }

    return StateResolveV2(states);
}

/*
 * Computes the room state before the given event: the resolution of the
 * states after each of the event's "prev_events".
 *
 * Results are cached in the database under
 * rooms/<room_id>/state/<event_id>. A readable cached state is returned
 * directly; otherwise the state is computed recursively and stored.
 * Returns a newly allocated state map owned by the caller, or NULL on
 * error or when resolution yields nothing for the room version.
 */
HashMap *
StateResolve(Room * room, HashMap * event)
{
    Array *states;
    Array *prevEvents;
    HashMap *ret_state;
    char *room_id, *event_id;
    bool cacheable;
    Db *db;
    size_t i;

    if (!room || !event)
    {
        return NULL;
    }

    db = RoomGetDB(room);
    room_id = JsonValueAsString(HashMapGet(event, "room_id"));
    event_id = JsonValueAsString(HashMapGet(event, "event_id"));

    /* The cache entry's path is built from both IDs; without them (or a
     * database) the state is computed but neither read nor stored. */
    cacheable = db && room_id && event_id;

    if (cacheable && DbExists(db, 4, "rooms", room_id, "state", event_id))
    {
        DbRef *ref = DbLock(db, 4,
            "rooms", room_id, "state", event_id
        );

        ret_state = ref ? StateDeserialise(DbJson(ref)) : NULL;
        if (ref)
        {
            DbUnlock(db, ref);
        }
        if (ret_state)
        {
            return ret_state;
        }

        /* If a DB error stops us from getting an existing state,
         * recompute it. */
    }

    states = ArrayCreate();
    if (!states)
    {
        return NULL;
    }
    prevEvents = JsonValueAsArray(HashMapGet(event, "prev_events"));

    for (i = 0; i < ArraySize(prevEvents); i++)
    {
        HashMap *prevEvent =
            RoomEventFetch(room, JsonValueAsString(ArrayGet(prevEvents, i)));
        HashMap *state;

        if (!prevEvent)
        {
            continue;
        }

        /* The state before prevEvent, plus prevEvent itself when it is a
         * state event, is the state after prevEvent. */
        state = StateResolve(room, prevEvent);
        if (HashMapGet(prevEvent, "state_key"))
        {
            StateSet(
                state,
                JsonValueAsString(HashMapGet(prevEvent, "type")),
                JsonValueAsString(HashMapGet(prevEvent, "state_key")),
                JsonValueAsString(HashMapGet(prevEvent, "event_id")));
        }

        ArrayAdd(states, state);
        JsonFree(prevEvent);
    }

    ret_state = StateFromPrevs(room, states);

    for (i = 0; i < ArraySize(states); i++)
    {
        StateFree(ArrayGet(states, i));
    }
    ArrayFree(states);

    if (ret_state && cacheable)
    {
        HashMap *json = StateSerialise(ret_state);
        DbRef *ref = DbCreate(db, 4, "rooms", room_id, "state", event_id);

        /* NOTE(review): DbCreate() is assumed to return NULL when the
         * entry already exists (e.g. it was present above but could not
         * be read). Fall back to locking it, so the recomputed state
         * still replaces the unreadable copy. */
        if (!ref)
        {
            ref = DbLock(db, 4, "rooms", room_id, "state", event_id);
        }
        if (ref)
        {
            DbJsonSet(ref, json);
            DbUnlock(db, ref);
        }
        JsonFree(json);
    }

    return ret_state;
}
/*
 * Computes the room's current state: the resolution of the states
 * after each of the events returned by RoomPrevEventsGet().
 * Returns a newly allocated state map owned by the caller, or NULL on
 * error or when resolution yields nothing for the room version.
 */
HashMap *
StateCurrent(Room *room)
{
    Array *prevEvents;
    Array *states;
    HashMap *ret;
    size_t i;

    if (!room)
    {
        return NULL;
    }

    prevEvents = RoomPrevEventsGet(room);
    states = ArrayCreate();
    if (!states)
    {
        return NULL;
    }

    for (i = 0; i < ArraySize(prevEvents); i++)
    {
        HashMap *event =
            RoomEventFetch(room, JsonValueAsString(ArrayGet(prevEvents, i)));
        HashMap *state;

        if (!event)
        {
            continue;
        }

        /* StateResolve() yields the state before the event; a state
         * event also contributes itself to the state after it. */
        state = StateResolve(room, event);
        if (HashMapGet(event, "state_key"))
        {
            StateSet(
                state,
                JsonValueAsString(HashMapGet(event, "type")),
                JsonValueAsString(HashMapGet(event, "state_key")),
                JsonValueAsString(HashMapGet(event, "event_id")));
        }

        ArrayAdd(states, state);

        /* The fetched event is ours to release; StateSet() stores its
         * own copy of the event ID. */
        JsonFree(event);
    }

    ret = StateFromPrevs(room, states);

    for (i = 0; i < ArraySize(states); i++)
    {
        StateFree(ArrayGet(states, i));
    }
    ArrayFree(states);

    return ret;
}
/*
 * Advances an iteration over a state map whose keys are
 * "type,state_key" tuples, splitting each key at its first comma.
 *
 * On success, *type and *key are newly allocated strings the caller
 * must Free(), and *event is the stored value (not copied). Returns
 * false when iteration is finished or if any argument is NULL.
 */
bool
StateIterate(HashMap *state, char **type, char **key, void **event)
{
    char *tuple;
    char *comma;
    bool ret;

    if (!state || !type || !key || !event)
    {
        return false;
    }

    ret = HashMapIterate(state, &tuple, event);
    if (ret)
    {
        tuple = StrDuplicate(tuple);
        comma = tuple ? strchr(tuple, ',') : NULL;
        if (comma)
        {
            *comma = '\0';
            *key = StrDuplicate(comma + 1);
        }
        else
        {
            /* Malformed tuple (or failed copy): report it with an empty
             * state key instead of dereferencing a NULL separator. */
            *key = StrDuplicate("");
        }
        *type = tuple;
    }

    return ret;
}
/*
 * Looks up the event ID stored for (type, key) in a state map. The
 * returned string still belongs to the map and must not be freed.
 * Yields NULL if any argument is NULL or the tuple is absent.
 */
char *
StateGet(HashMap *state, char *type, char *key)
{
    char *tuple;
    char *eventId = NULL;

    if (state && type && key)
    {
        tuple = StrConcat(3, type, ",", key);
        eventId = HashMapGet(state, tuple);
        Free(tuple);
    }

    return eventId;
}
/*
 * Stores a private copy of event as the event ID for (type, key) in a
 * state map, freeing any ID previously stored there. A NULL event just
 * removes the tuple. Nothing happens if state, type, or key is NULL.
 */
void
StateSet(HashMap *state, char *type, char *key, char *event)
{
    char *tuple;
    char *previous;

    if (state && type && key)
    {
        tuple = StrConcat(3, type, ",", key);

        previous = HashMapDelete(state, tuple);
        if (previous)
        {
            Free(previous);
        }

        if (event)
        {
            HashMapSet(state, tuple, StrDuplicate(event));
        }

        Free(tuple);
    }
}
/*
 * Releases a state map together with every event ID string stored in
 * it. A NULL map is ignored.
 */
void
StateFree(HashMap *state)
{
    char *tuple;
    char *eventId;

    if (state)
    {
        while (HashMapIterate(state, &tuple, (void **) &eventId))
        {
            Free(eventId);
        }
        HashMapFree(state);
    }
}
/*
 * Rebuilds a flat state map from its nested JSON form
 * { type: { state_key: event_id } } (the shape StateSerialise()
 * emits). Keys become "type,state_key" and event IDs are copied.
 * Returns NULL if json_state is NULL.
 */
HashMap *
StateDeserialise(HashMap *json_state)
{
    HashMap *flat;
    char *evType;
    JsonValue *byKey;

    if (!json_state)
    {
        return NULL;
    }

    flat = HashMapCreate();

    while (HashMapIterate(json_state, &evType, (void **) &byKey))
    {
        HashMap *keys = JsonValueAsObject(byKey);
        char *stateKey;
        JsonValue *idValue;

        while (HashMapIterate(keys, &stateKey, (void **) &idValue))
        {
            char *eventId = JsonValueAsString(idValue);
            char *tuple = StrConcat(3, evType, ",", stateKey);

            HashMapSet(flat, tuple, StrDuplicate(eventId));
            Free(tuple);
        }
    }

    return flat;
}
/*
 * Converts a flat state map ("type,state_key" -> event ID) into the
 * nested JSON object { type: { state_key: event_id } }; this is the
 * form StateResolve() writes to the database. Returns NULL if rawState
 * is NULL.
 */
HashMap *
StateSerialise(HashMap *rawState)
{
    HashMap *nested;
    char *evType, *stateKey, *eventId;

    if (rawState)
    {
        nested = HashMapCreate();
        while (StateIterate(rawState, &evType, &stateKey, (void **) &eventId))
        {
            JsonSet(nested, JsonValueString(eventId), 2, evType, stateKey);
            Free(evType);
            Free(stateKey);
        }
        return nested;
    }

    return NULL;
}