diff --git a/src/Parser.c b/src/Parser.c index d46da35..62ba03a 100644 --- a/src/Parser.c +++ b/src/Parser.c @@ -323,6 +323,28 @@ ParseServerName(char **str, ServerPart *out) return 1; } +int +ParseServerPart(char *str, ServerPart *part) +{ + /* This is a wrapper behind the internal ParseServerName. */ + if (!str || !part) + { + return 0; + } + return ParseServerName(&str, part); +} +void +ServerPartFree(ServerPart part) +{ + if (part.hostname) + { + Free(part.hostname); + } + if (part.port) + { + Free(part.port); + } +} int ParseCommonID(char *str, CommonID *id) @@ -406,14 +428,7 @@ CommonIDFree(CommonID id) { Free(id.local); } - if (id.server.hostname) - { - Free(id.server.hostname); - } - if (id.server.port) - { - Free(id.server.port); - } + ServerPartFree(id.server); } int ValidCommonID(char *str, char sigil) @@ -432,6 +447,19 @@ ValidCommonID(char *str, char sigil) return ret; } char * +ParserRecomposeServerPart(ServerPart serverPart) +{ + if (serverPart.hostname && serverPart.port) + { + return StrConcat(3, serverPart.hostname, ":", serverPart.port); + } + if (serverPart.hostname) + { + return StrDuplicate(serverPart.hostname); + } + return NULL; +} +char * ParserRecomposeCommonID(CommonID id) { char *ret = Malloc(2); @@ -442,19 +470,33 @@ ParserRecomposeCommonID(CommonID id) { char *tmp = StrConcat(2, ret, id.local); Free(ret); + ret = tmp; } if (id.server.hostname) { - char *tmp = StrConcat(3, ret, ":", id.server.hostname); - Free(ret); - ret = tmp; - } - if (id.server.port) - { - char *tmp = StrConcat(3, ret, ":", id.server.port); + char *server = ParserRecomposeServerPart(id.server); + char *tmp = StrConcat(4, "@", ret, ":", server); Free(ret); + Free(server); + ret = tmp; } return ret; } +int +ParserServerNameEquals(ServerPart serverPart, char *str) +{ + char *idServer; + int ret; + if (!str) + { + return 0; + } + idServer = ParserRecomposeServerPart(serverPart); + + ret = StrEquals(idServer, str); + Free(idServer); + + return ret; +} diff --git a/src/Routes/RouteAliasDirectory.c b/src/Routes/RouteAliasDirectory.c index 4c3938e..137e1ce 100644 --- a/src/Routes/RouteAliasDirectory.c +++ b/src/Routes/RouteAliasDirectory.c @@ -118,18 +118,21 @@ ROUTE_IMPL(RouteAliasDirectory, path, argp) { HashMap *newAlias; char *id; + char *serverPart; - /* Check for server name. - * TODO: Take the port into account, that might need a - * refactor for it to use a ServerPart */ - if (!StrEquals(aliasID.server.hostname, config->serverName)) + serverPart = ParserRecomposeServerPart(aliasID.server); + if (!StrEquals(serverPart, config->serverName)) { msg = "Invalid server name."; HttpResponseStatus(args->context, HTTP_BAD_REQUEST); response = MatrixErrorCreate(M_INVALID_PARAM, msg); + + Free(serverPart); goto finish; } + Free(serverPart); + if (JsonGet(aliases, 2, "alias", alias)) { HttpResponseStatus(args->context, HTTP_CONFLICT); diff --git a/src/Routes/RouteFilter.c b/src/Routes/RouteFilter.c index c2ead37..c8f4fa8 100644 --- a/src/Routes/RouteFilter.c +++ b/src/Routes/RouteFilter.c @@ -64,7 +64,7 @@ ROUTE_IMPL(RouteFilter, path, argp) HashMap *response = NULL; User *user = NULL; - UserId *id = NULL; + CommonID *id = NULL; char *token = NULL; char *serverName = NULL; @@ -97,7 +97,7 @@ ROUTE_IMPL(RouteFilter, path, argp) goto finish; } - if (!StrEquals(id->server, serverName)) + if (!ParserServerNameEquals(id->server, serverName)) { msg = "Cannot use /filter for non-local users."; HttpResponseStatus(args->context, HTTP_UNAUTHORIZED); @@ -119,7 +119,7 @@ ROUTE_IMPL(RouteFilter, path, argp) goto finish; } - if (!StrEquals(id->localpart, UserGetName(user))) + if (!StrEquals(id->local, UserGetName(user))) { msg = "Unauthorized to use /filter."; HttpResponseStatus(args->context, HTTP_UNAUTHORIZED); diff --git a/src/Routes/RouteLogin.c b/src/Routes/RouteLogin.c index 158629d..6234f15 100644 --- a/src/Routes/RouteLogin.c +++ b/src/Routes/RouteLogin.c @@ -49,7 +49,7 @@ ROUTE_IMPL(RouteLogin, path, argp) LoginRequest loginRequest; LoginRequestUserIdentifier userIdentifier; - UserId *userId = NULL; + CommonID *userId = NULL; Db *db = args->matrixArgs->db; @@ -160,8 +160,8 @@ ROUTE_IMPL(RouteLogin, path, argp) break; } - if (!StrEquals(userId->server, config.serverName) - || !UserExists(db, userId->localpart)) + if (!ParserServerNameEquals(userId->server, config.serverName) + || !UserExists(db, userId->local)) { msg = "Unknown user ID."; HttpResponseStatus(args->context, HTTP_FORBIDDEN); @@ -175,7 +175,7 @@ ROUTE_IMPL(RouteLogin, path, argp) password = loginRequest.password; refreshToken = loginRequest.refresh_token; - user = UserLock(db, userId->localpart); + user = UserLock(db, userId->local); if (!user) { diff --git a/src/Routes/RouteUserProfile.c b/src/Routes/RouteUserProfile.c index 5c1af7b..4c438b0 100644 --- a/src/Routes/RouteUserProfile.c +++ b/src/Routes/RouteUserProfile.c @@ -40,7 +40,7 @@ ROUTE_IMPL(RouteUserProfile, path, argp) HashMap *request = NULL; HashMap *response = NULL; - UserId *userId = NULL; + CommonID *userId = NULL; User *user = NULL; char *serverName; @@ -73,7 +73,7 @@ ROUTE_IMPL(RouteUserProfile, path, argp) response = MatrixErrorCreate(M_INVALID_PARAM, msg); goto finish; } - if (strcmp(userId->server, serverName)) + if (!ParserServerNameEquals(userId->server, serverName)) { /* TODO: Implement lookup over federation. */ msg = "User profile endpoint currently doesn't support lookup over " @@ -87,7 +87,7 @@ ROUTE_IMPL(RouteUserProfile, path, argp) switch (HttpRequestMethodGet(args->context)) { case HTTP_GET: - user = UserLock(db, userId->localpart); + user = UserLock(db, userId->local); if (!user) { msg = "Couldn't lock user."; @@ -147,11 +147,11 @@ ROUTE_IMPL(RouteUserProfile, path, argp) StrEquals(entry, "avatar_url")) { /* Check if user has privilege to do that action. */ - if (StrEquals(userId->localpart, UserGetName(user))) + if (StrEquals(userId->local, UserGetName(user))) { value = JsonValueAsString(HashMapGet(request, entry)); - /* TODO: Make UserSetProfile notify other - * parties of the change */ + /* TODO: Make UserSetProfile notify other parties of + * the change */ UserSetProfile(user, entry, value); response = HashMapCreate(); goto finish; diff --git a/src/Uia.c b/src/Uia.c index 2c0abb8..ff00b06 100644 --- a/src/Uia.c +++ b/src/Uia.c @@ -351,7 +351,7 @@ UiaComplete(Array * flows, HttpServerContext * context, Db * db, char *password = JsonValueAsString(HashMapGet(auth, "password")); HashMap *identifier = JsonValueAsObject(HashMapGet(auth, "identifier")); char *type; - UserId *userId; + CommonID *userId; User *user; if (!password || !identifier) @@ -366,7 +366,8 @@ UiaComplete(Array * flows, HttpServerContext * context, Db * db, config.serverName); if (!type || !StrEquals(type, "m.id.user") - || !userId || !StrEquals(userId->server, config.serverName)) + || !userId + || !ParserServerNameEquals(userId->server, config.serverName)) { HttpResponseStatus(context, HTTP_UNAUTHORIZED); ret = BuildResponse(flows, db, response, session, dbRef); @@ -374,7 +375,7 @@ UiaComplete(Array * flows, HttpServerContext * context, Db * db, goto finish; } - user = UserLock(db, userId->localpart); + user = UserLock(db, userId->local); if (!user) { HttpResponseStatus(context, HTTP_UNAUTHORIZED); diff --git a/src/User.c b/src/User.c index 20a4289..1fc872d 100644 --- a/src/User.c +++ b/src/User.c @@ -884,10 +884,11 @@ finish: return arr; } -UserId * +CommonID * UserIdParse(char *id, char *defaultServer) { - UserId *userId; + CommonID *userId; + char *server; if (!id) { @@ -900,52 +901,38 @@ UserIdParse(char *id, char *defaultServer) return NULL; } - userId = Malloc(sizeof(UserId)); + userId = Malloc(sizeof(CommonID)); if (!userId) { goto finish; } + memset(userId, 0, sizeof(CommonID)); /* Fully-qualified user ID */ if (*id == '@') { - /* TODO: Just use the CommonID. */ - CommonID commonID; - commonID.sigil = '\0'; - commonID.local = NULL; - commonID.server.hostname = NULL; - commonID.server.port = NULL; - if (!ParseCommonID(id, &commonID) || !commonID.server.hostname) + if (!ParseCommonID(id, userId) || !userId->server.hostname) { - Free(userId); - Free(commonID.local); - if (commonID.server.hostname) - { - Free(commonID.server.hostname); - } + UserIdFree(userId); + userId = NULL; goto finish; } - - userId->localpart = commonID.local; - userId->server = commonID.server.hostname; - if (commonID.server.port) - { - Free(commonID.server.port); - } } else { /* Treat it as just a localpart */ - userId->localpart = StrDuplicate(id); - userId->server = StrDuplicate(defaultServer); + userId->local = StrDuplicate(id); + ParseServerPart(defaultServer, &userId->server); } - if (!UserHistoricalValidate(userId->localpart, userId->server)) + server = ParserRecomposeServerPart(userId->server); + if (!UserHistoricalValidate(userId->local, server)) { UserIdFree(userId); userId = NULL; } + Free(server); finish: Free(id); @@ -953,12 +940,11 @@ finish: } void -UserIdFree(UserId * id) +UserIdFree(CommonID * id) { if (id) { - Free(id->localpart); - Free(id->server); + CommonIDFree(*id); Free(id); } } diff --git a/src/include/Parser.h b/src/include/Parser.h index 405caf1..bd57dc0 100644 --- a/src/include/Parser.h +++ b/src/include/Parser.h @@ -58,6 +58,12 @@ typedef struct CommonID { * by the [matrix] specification. */ extern int ParseCommonID(char *, CommonID *); + +/** + * Parses the server part in a common identifier. + */ +extern int ParseServerPart(char *, ServerPart *); + /** * Checks whenever the string is a valid common ID with the correct sigil. */ @@ -66,7 +72,15 @@ extern int ValidCommonID(char *, char); /** * Frees a CommonID's values. Note that it doesn't free the CommonID itself. */ -extern void CommonIDFree(CommonID id); +extern void CommonIDFree(CommonID); + +/** + * Frees a ServerPart's values. Note that it doesn't free the ServerPart + * itself, and that + * .Fn CommonIDFree + * automatically deals with its server part. + */ +extern void ServerPartFree(ServerPart); /** * Recompose a Common ID into a string which lives in the heap, and must @@ -75,5 +89,17 @@ extern void CommonIDFree(CommonID id); */ extern char * ParserRecomposeCommonID(CommonID); +/** + * Recompose a server part into a string which lives in the heap, and must + * therefore be freed with + * .Fn Free . + */ +extern char * ParserRecomposeServerPart(ServerPart); + +/** + * Compares whenever a ServerName is equivalent to a server name string. + */ +extern int ParserServerNameEquals(ServerPart, char *); + #endif /* TELODENDRIA_PARSER_H */ diff --git a/src/include/User.h b/src/include/User.h index 6eb5530..f84ac94 100644 --- a/src/include/User.h +++ b/src/include/User.h @@ -43,6 +43,8 @@ #include #include +#include + /** * Many functions here operate on an opaque user structure. */ @@ -88,15 +90,6 @@ typedef struct UserLoginInfo char *refreshToken; } UserLoginInfo; -/** - * A description of a Matrix user ID. - */ -typedef struct UserId -{ - char *localpart; - char *server; -} UserId; - /** * Take a localpart and domain as separate parameters and validate them * against the rules of the Matrix specification. The reasion the @@ -303,15 +296,15 @@ extern Array *UserEncodePrivileges(int); extern int UserDecodePrivilege(const char *); /** - * Parse either a localpart or a fully qualified Matrix ID. If the + * Parse either a localpart or a fully qualified Matrix common ID. If the * first argument is a localpart, then the second argument is used as * the server name. */ -extern UserId * UserIdParse(char *, char *); +extern CommonID * UserIdParse(char *, char *); /** - * Free the memory associated with the parsed Matrix ID. + * Frees the user's common ID and the memory allocated for it. */ -extern void UserIdFree(UserId *); +extern void UserIdFree(CommonID *); #endif /* TELODENDRIA_USER_H */