[MOD/WIP] Refactor code to use CommonIDs instead of UserIds

This commit is contained in:
lda 2023-12-09 21:34:50 +01:00
parent 88c9d10f90
commit 5d2ca5a21b
9 changed files with 129 additions and 78 deletions

View file

@ -323,6 +323,28 @@ ParseServerName(char **str, ServerPart *out)
return 1; return 1;
} }
int
ParseServerPart(char *str, ServerPart *part)
{
/* This is a wrapper behind the internal ParseServerName. */
if (!str || !part)
{
return 0;
}
return ParseServerName(&str, part);
}
void
ServerPartFree(ServerPart part)
{
if (part.hostname)
{
Free(part.hostname);
}
if (part.port)
{
Free(part.port);
}
}
int int
ParseCommonID(char *str, CommonID *id) ParseCommonID(char *str, CommonID *id)
@ -406,14 +428,7 @@ CommonIDFree(CommonID id)
{ {
Free(id.local); Free(id.local);
} }
if (id.server.hostname) ServerPartFree(id.server);
{
Free(id.server.hostname);
}
if (id.server.port)
{
Free(id.server.port);
}
} }
int int
ValidCommonID(char *str, char sigil) ValidCommonID(char *str, char sigil)
@ -432,6 +447,19 @@ ValidCommonID(char *str, char sigil)
return ret; return ret;
} }
char * char *
ParserRecomposeServerPart(ServerPart serverPart)
{
if (serverPart.hostname && serverPart.port)
{
return StrConcat(3, serverPart.hostname, ":", serverPart.port);
}
if (serverPart.hostname)
{
return StrDuplicate(serverPart.hostname);
}
return NULL;
}
char *
ParserRecomposeCommonID(CommonID id) ParserRecomposeCommonID(CommonID id)
{ {
char *ret = Malloc(2); char *ret = Malloc(2);
@ -442,19 +470,33 @@ ParserRecomposeCommonID(CommonID id)
{ {
char *tmp = StrConcat(2, ret, id.local); char *tmp = StrConcat(2, ret, id.local);
Free(ret); Free(ret);
ret = tmp; ret = tmp;
} }
if (id.server.hostname) if (id.server.hostname)
{ {
char *tmp = StrConcat(3, ret, ":", id.server.hostname); char *server = ParserRecomposeServerPart(id.server);
Free(ret); char *tmp = StrConcat(4, "@", ret, ":", server);
ret = tmp;
}
if (id.server.port)
{
char *tmp = StrConcat(3, ret, ":", id.server.port);
Free(ret); Free(ret);
Free(server);
ret = tmp; ret = tmp;
} }
return ret; return ret;
} }
int
ParserServerNameEquals(ServerPart serverPart, char *str)
{
char *idServer;
int ret;
if (!str)
{
return 0;
}
idServer = ParserRecomposeServerPart(serverPart);
ret = StrEquals(idServer, str);
Free(idServer);
return ret;
}

View file

@ -118,18 +118,21 @@ ROUTE_IMPL(RouteAliasDirectory, path, argp)
{ {
HashMap *newAlias; HashMap *newAlias;
char *id; char *id;
char *serverPart;
/* Check for server name. serverPart = ParserRecomposeServerPart(aliasID.server);
* TODO: Take the port into account, that might need a if (!StrEquals(serverPart, config->serverName))
* refactor for it to use a ServerPart */
if (!StrEquals(aliasID.server.hostname, config->serverName))
{ {
msg = "Invalid server name."; msg = "Invalid server name.";
HttpResponseStatus(args->context, HTTP_BAD_REQUEST); HttpResponseStatus(args->context, HTTP_BAD_REQUEST);
response = MatrixErrorCreate(M_INVALID_PARAM, msg); response = MatrixErrorCreate(M_INVALID_PARAM, msg);
Free(serverPart);
goto finish; goto finish;
} }
Free(serverPart);
if (JsonGet(aliases, 2, "alias", alias)) if (JsonGet(aliases, 2, "alias", alias))
{ {
HttpResponseStatus(args->context, HTTP_CONFLICT); HttpResponseStatus(args->context, HTTP_CONFLICT);

View file

@ -64,7 +64,7 @@ ROUTE_IMPL(RouteFilter, path, argp)
HashMap *response = NULL; HashMap *response = NULL;
User *user = NULL; User *user = NULL;
UserId *id = NULL; CommonID *id = NULL;
char *token = NULL; char *token = NULL;
char *serverName = NULL; char *serverName = NULL;
@ -97,7 +97,7 @@ ROUTE_IMPL(RouteFilter, path, argp)
goto finish; goto finish;
} }
if (!StrEquals(id->server, serverName)) if (!ParserServerNameEquals(id->server, serverName))
{ {
msg = "Cannot use /filter for non-local users."; msg = "Cannot use /filter for non-local users.";
HttpResponseStatus(args->context, HTTP_UNAUTHORIZED); HttpResponseStatus(args->context, HTTP_UNAUTHORIZED);
@ -119,7 +119,7 @@ ROUTE_IMPL(RouteFilter, path, argp)
goto finish; goto finish;
} }
if (!StrEquals(id->localpart, UserGetName(user))) if (!StrEquals(id->local, UserGetName(user)))
{ {
msg = "Unauthorized to use /filter."; msg = "Unauthorized to use /filter.";
HttpResponseStatus(args->context, HTTP_UNAUTHORIZED); HttpResponseStatus(args->context, HTTP_UNAUTHORIZED);

View file

@ -49,7 +49,7 @@ ROUTE_IMPL(RouteLogin, path, argp)
LoginRequest loginRequest; LoginRequest loginRequest;
LoginRequestUserIdentifier userIdentifier; LoginRequestUserIdentifier userIdentifier;
UserId *userId = NULL; CommonID *userId = NULL;
Db *db = args->matrixArgs->db; Db *db = args->matrixArgs->db;
@ -160,8 +160,8 @@ ROUTE_IMPL(RouteLogin, path, argp)
break; break;
} }
if (!StrEquals(userId->server, config.serverName) if (!ParserServerNameEquals(userId->server, config.serverName)
|| !UserExists(db, userId->localpart)) || !UserExists(db, userId->local))
{ {
msg = "Unknown user ID."; msg = "Unknown user ID.";
HttpResponseStatus(args->context, HTTP_FORBIDDEN); HttpResponseStatus(args->context, HTTP_FORBIDDEN);
@ -175,7 +175,7 @@ ROUTE_IMPL(RouteLogin, path, argp)
password = loginRequest.password; password = loginRequest.password;
refreshToken = loginRequest.refresh_token; refreshToken = loginRequest.refresh_token;
user = UserLock(db, userId->localpart); user = UserLock(db, userId->local);
if (!user) if (!user)
{ {

View file

@ -40,7 +40,7 @@ ROUTE_IMPL(RouteUserProfile, path, argp)
HashMap *request = NULL; HashMap *request = NULL;
HashMap *response = NULL; HashMap *response = NULL;
UserId *userId = NULL; CommonID *userId = NULL;
User *user = NULL; User *user = NULL;
char *serverName; char *serverName;
@ -73,7 +73,7 @@ ROUTE_IMPL(RouteUserProfile, path, argp)
response = MatrixErrorCreate(M_INVALID_PARAM, msg); response = MatrixErrorCreate(M_INVALID_PARAM, msg);
goto finish; goto finish;
} }
if (strcmp(userId->server, serverName)) if (!ParserServerNameEquals(userId->server, serverName))
{ {
/* TODO: Implement lookup over federation. */ /* TODO: Implement lookup over federation. */
msg = "User profile endpoint currently doesn't support lookup over " msg = "User profile endpoint currently doesn't support lookup over "
@ -87,7 +87,7 @@ ROUTE_IMPL(RouteUserProfile, path, argp)
switch (HttpRequestMethodGet(args->context)) switch (HttpRequestMethodGet(args->context))
{ {
case HTTP_GET: case HTTP_GET:
user = UserLock(db, userId->localpart); user = UserLock(db, userId->local);
if (!user) if (!user)
{ {
msg = "Couldn't lock user."; msg = "Couldn't lock user.";
@ -147,11 +147,11 @@ ROUTE_IMPL(RouteUserProfile, path, argp)
StrEquals(entry, "avatar_url")) StrEquals(entry, "avatar_url"))
{ {
/* Check if user has privilege to do that action. */ /* Check if user has privilege to do that action. */
if (StrEquals(userId->localpart, UserGetName(user))) if (StrEquals(userId->local, UserGetName(user)))
{ {
value = JsonValueAsString(HashMapGet(request, entry)); value = JsonValueAsString(HashMapGet(request, entry));
/* TODO: Make UserSetProfile notify other /* TODO: Make UserSetProfile notify other parties of
* parties of the change */ * the change */
UserSetProfile(user, entry, value); UserSetProfile(user, entry, value);
response = HashMapCreate(); response = HashMapCreate();
goto finish; goto finish;

View file

@ -351,7 +351,7 @@ UiaComplete(Array * flows, HttpServerContext * context, Db * db,
char *password = JsonValueAsString(HashMapGet(auth, "password")); char *password = JsonValueAsString(HashMapGet(auth, "password"));
HashMap *identifier = JsonValueAsObject(HashMapGet(auth, "identifier")); HashMap *identifier = JsonValueAsObject(HashMapGet(auth, "identifier"));
char *type; char *type;
UserId *userId; CommonID *userId;
User *user; User *user;
if (!password || !identifier) if (!password || !identifier)
@ -366,7 +366,8 @@ UiaComplete(Array * flows, HttpServerContext * context, Db * db,
config.serverName); config.serverName);
if (!type || !StrEquals(type, "m.id.user") if (!type || !StrEquals(type, "m.id.user")
|| !userId || !StrEquals(userId->server, config.serverName)) || !userId
|| !ParserServerNameEquals(userId->server, config.serverName))
{ {
HttpResponseStatus(context, HTTP_UNAUTHORIZED); HttpResponseStatus(context, HTTP_UNAUTHORIZED);
ret = BuildResponse(flows, db, response, session, dbRef); ret = BuildResponse(flows, db, response, session, dbRef);
@ -374,7 +375,7 @@ UiaComplete(Array * flows, HttpServerContext * context, Db * db,
goto finish; goto finish;
} }
user = UserLock(db, userId->localpart); user = UserLock(db, userId->local);
if (!user) if (!user)
{ {
HttpResponseStatus(context, HTTP_UNAUTHORIZED); HttpResponseStatus(context, HTTP_UNAUTHORIZED);

View file

@ -884,10 +884,11 @@ finish:
return arr; return arr;
} }
UserId * CommonID *
UserIdParse(char *id, char *defaultServer) UserIdParse(char *id, char *defaultServer)
{ {
UserId *userId; CommonID *userId;
char *server;
if (!id) if (!id)
{ {
@ -900,52 +901,38 @@ UserIdParse(char *id, char *defaultServer)
return NULL; return NULL;
} }
userId = Malloc(sizeof(UserId)); userId = Malloc(sizeof(CommonID));
if (!userId) if (!userId)
{ {
goto finish; goto finish;
} }
memset(userId, 0, sizeof(CommonID));
/* Fully-qualified user ID */ /* Fully-qualified user ID */
if (*id == '@') if (*id == '@')
{ {
/* TODO: Just use the CommonID. */ if (!ParseCommonID(id, userId) || !userId->server.hostname)
CommonID commonID;
commonID.sigil = '\0';
commonID.local = NULL;
commonID.server.hostname = NULL;
commonID.server.port = NULL;
if (!ParseCommonID(id, &commonID) || !commonID.server.hostname)
{ {
Free(userId); UserIdFree(userId);
Free(commonID.local);
if (commonID.server.hostname)
{
Free(commonID.server.hostname);
}
userId = NULL; userId = NULL;
goto finish; goto finish;
} }
userId->localpart = commonID.local;
userId->server = commonID.server.hostname;
if (commonID.server.port)
{
Free(commonID.server.port);
}
} }
else else
{ {
/* Treat it as just a localpart */ /* Treat it as just a localpart */
userId->localpart = StrDuplicate(id); userId->local = StrDuplicate(id);
userId->server = StrDuplicate(defaultServer); ParseServerPart(defaultServer, &userId->server);
} }
if (!UserHistoricalValidate(userId->localpart, userId->server)) server = ParserRecomposeServerPart(userId->server);
if (!UserHistoricalValidate(userId->local, server))
{ {
UserIdFree(userId); UserIdFree(userId);
userId = NULL; userId = NULL;
} }
Free(server);
finish: finish:
Free(id); Free(id);
@ -953,12 +940,11 @@ finish:
} }
void void
UserIdFree(UserId * id) UserIdFree(CommonID * id)
{ {
if (id) if (id)
{ {
Free(id->localpart); CommonIDFree(*id);
Free(id->server);
Free(id); Free(id);
} }
} }

View file

@ -58,6 +58,12 @@ typedef struct CommonID {
* by the [matrix] specification. * by the [matrix] specification.
*/ */
extern int ParseCommonID(char *, CommonID *); extern int ParseCommonID(char *, CommonID *);
/**
* Parses the server part in a common identifier.
*/
extern int ParseServerPart(char *, ServerPart *);
/** /**
* Checks whenever the string is a valid common ID with the correct sigil. * Checks whenever the string is a valid common ID with the correct sigil.
*/ */
@ -66,7 +72,15 @@ extern int ValidCommonID(char *, char);
/** /**
* Frees a CommonID's values. Note that it doesn't free the CommonID itself. * Frees a CommonID's values. Note that it doesn't free the CommonID itself.
*/ */
extern void CommonIDFree(CommonID id); extern void CommonIDFree(CommonID);
/**
* Frees a ServerPart's values. Note that it doesn't free the ServerPart
* itself, and that
* .Fn CommonIDFree
* automatically deals with its server part.
*/
extern void ServerPartFree(ServerPart);
/** /**
* Recompose a Common ID into a string which lives in the heap, and must * Recompose a Common ID into a string which lives in the heap, and must
@ -75,5 +89,17 @@ extern void CommonIDFree(CommonID id);
*/ */
extern char * ParserRecomposeCommonID(CommonID); extern char * ParserRecomposeCommonID(CommonID);
/**
* Recompose a server part into a string which lives in the heap, and must
* therefore be freed with
* .Fn Free .
*/
extern char * ParserRecomposeServerPart(ServerPart);
/**
* Compares whenever a ServerName is equivalent to a server name string.
*/
extern int ParserServerNameEquals(ServerPart, char *);
#endif /* TELODENDRIA_PARSER_H */ #endif /* TELODENDRIA_PARSER_H */

View file

@ -43,6 +43,8 @@
#include <Cytoplasm/Db.h> #include <Cytoplasm/Db.h>
#include <Cytoplasm/Json.h> #include <Cytoplasm/Json.h>
#include <Parser.h>
/** /**
* Many functions here operate on an opaque user structure. * Many functions here operate on an opaque user structure.
*/ */
@ -88,15 +90,6 @@ typedef struct UserLoginInfo
char *refreshToken; char *refreshToken;
} UserLoginInfo; } UserLoginInfo;
/**
* A description of a Matrix user ID.
*/
typedef struct UserId
{
char *localpart;
char *server;
} UserId;
/** /**
* Take a localpart and domain as separate parameters and validate them * Take a localpart and domain as separate parameters and validate them
* against the rules of the Matrix specification. The reasion the * against the rules of the Matrix specification. The reasion the
@ -303,15 +296,15 @@ extern Array *UserEncodePrivileges(int);
extern int UserDecodePrivilege(const char *); extern int UserDecodePrivilege(const char *);
/** /**
* Parse either a localpart or a fully qualified Matrix ID. If the * Parse either a localpart or a fully qualified Matrix common ID. If the
* first argument is a localpart, then the second argument is used as * first argument is a localpart, then the second argument is used as
* the server name. * the server name.
*/ */
extern UserId * UserIdParse(char *, char *); extern CommonID * UserIdParse(char *, char *);
/** /**
* Free the memory associated with the parsed Matrix ID. * Frees the user's common ID and the memory allocated for it.
*/ */
extern void UserIdFree(UserId *); extern void UserIdFree(CommonID *);
#endif /* TELODENDRIA_USER_H */ #endif /* TELODENDRIA_USER_H */