Add StrEquals(), as equality checking is the most common use for strcmp().

This allows us to get rid of the hideous MATRIX_PATH_PART_EQUALS macro,
and prevents inconsistent usage of strcmp() (for example, !strcmp() vs
strcmp() == 0).

StrEquals() also has sensible behavior for dealing with NULL strings (it
doesn't just segfault like strcmp()).
This commit is contained in:
Jordan Bancino 2023-05-06 22:34:36 +00:00
parent 0e69a12784
commit 0b11b97022
27 changed files with 127 additions and 90 deletions

View file

@ -23,7 +23,7 @@ Milestone: v0.3.0
[~] Client-Server API
[~] 4: Account management
[~] Deactivate
[x] Deactivate
[x] Make sure UserLogin() fails if user is deactivated.
[~] Whoami
[ ] Attach device id to user object

View file

@ -209,15 +209,15 @@ ConfigParseLog(Config * tConfig, HashMap * config)
CONFIG_REQUIRE("output", JSON_STRING);
str = JsonValueAsString(value);
if (strcmp(str, "stdout") == 0)
if (StrEquals(str, "stdout"))
{
tConfig->flags |= CONFIG_LOG_STDOUT;
}
else if (strcmp(str, "file") == 0)
else if (StrEquals(str, "file"))
{
tConfig->flags |= CONFIG_LOG_FILE;
}
else if (strcmp(str, "syslog") == 0)
else if (StrEquals(str, "syslog"))
{
tConfig->flags |= CONFIG_LOG_SYSLOG;
}
@ -229,23 +229,23 @@ ConfigParseLog(Config * tConfig, HashMap * config)
CONFIG_OPTIONAL_STRING(str, "level", "message");
if (strcmp(str, "message") == 0)
if (StrEquals(str, "message"))
{
tConfig->logLevel = LOG_INFO;
}
else if (strcmp(str, "debug") == 0)
else if (StrEquals(str, "debug"))
{
tConfig->logLevel = LOG_DEBUG;
}
else if (strcmp(str, "notice") == 0)
else if (StrEquals(str, "notice"))
{
tConfig->logLevel = LOG_NOTICE;
}
else if (strcmp(str, "warning") == 0)
else if (StrEquals(str, "warning"))
{
tConfig->logLevel = LOG_WARNING;
}
else if (strcmp(str, "error") == 0)
else if (StrEquals(str, "error"))
{
tConfig->logLevel = LOG_ERR;
}
@ -259,7 +259,7 @@ ConfigParseLog(Config * tConfig, HashMap * config)
CONFIG_OPTIONAL_STRING(tConfig->logTimestamp, "timestampFormat", "default");
if (strcmp(tConfig->logTimestamp, "none") == 0)
if (StrEquals(tConfig->logTimestamp, "none"))
{
Free(tConfig->logTimestamp);
tConfig->logTimestamp = NULL;

View file

@ -926,7 +926,7 @@ DbList(Db * db, size_t nArgs,...)
{
int nameOffset = namlen - 5;
if (strcmp(file->d_name + nameOffset, ".json") == 0)
if (StrEquals(file->d_name + nameOffset, ".json"))
{
file->d_name[nameOffset] = '\0';
ArrayAdd(result, StrDuplicate(file->d_name));

View file

@ -24,6 +24,7 @@
#include <HeaderParser.h>
#include <Memory.h>
#include <Str.h>
#include <string.h>
#include <ctype.h>
@ -342,10 +343,10 @@ HeaderParse(Stream * stream, HeaderExpr * expr)
strncpy(expr->data.text + i, word, HEADER_EXPR_MAX - i - 1);
i += strlen(word);
if (strcmp(word, "include") == 0 ||
strcmp(word, "undef") == 0 ||
strcmp(word, "ifdef") == 0 ||
strcmp(word, "ifndef") == 0)
if (StrEquals(word, "include") ||
StrEquals(word, "undef") ||
StrEquals(word, "ifdef") ||
StrEquals(word, "ifndef"))
{
/* Read one more word */
Free(word);
@ -365,10 +366,10 @@ HeaderParse(Stream * stream, HeaderExpr * expr)
Free(word);
}
else if (strcmp(word, "define") == 0 ||
strcmp(word, "if") == 0 ||
strcmp(word, "elif") == 0 ||
strcmp(word, "error") == 0)
else if (StrEquals(word, "define") ||
StrEquals(word, "if") ||
StrEquals(word, "elif") ||
StrEquals(word, "error"))
{
int pC;
@ -412,8 +413,8 @@ HeaderParse(Stream * stream, HeaderExpr * expr)
pC = c;
}
}
else if (strcmp(word, "else") == 0 ||
strcmp(word, "endif") == 0)
else if (StrEquals(word, "else") ||
StrEquals(word, "endif"))
{
/* Read no more words, that's the whole directive */
}
@ -433,7 +434,7 @@ HeaderParse(Stream * stream, HeaderExpr * expr)
StreamUngetc(expr->state.stream, c);
word = HeaderConsumeWord(expr);
if (strcmp(word, "typedef") == 0)
if (StrEquals(word, "typedef"))
{
int block = 0;
int i = 0;
@ -487,7 +488,7 @@ HeaderParse(Stream * stream, HeaderExpr * expr)
}
}
}
else if (strcmp(word, "extern") == 0)
else if (StrEquals(word, "extern"))
{
int wordLimit = sizeof(expr->data.declaration.returnType) - 8;
int wordLen;
@ -509,10 +510,10 @@ HeaderParse(Stream * stream, HeaderExpr * expr)
expr->type = HP_GLOBAL;
strncpy(expr->data.global.type, word, wordLimit);
if (strcmp(word, "struct") == 0 ||
strcmp(word, "enum") == 0 ||
strcmp(word, "const") == 0 ||
strcmp(word, "unsigned") == 0)
if (StrEquals(word, "struct") ||
StrEquals(word, "enum") ||
StrEquals(word, "const") ||
StrEquals(word, "unsigned"))
{
Free(word);
word = HeaderConsumeWord(expr);

View file

@ -31,6 +31,7 @@
#include <Memory.h>
#include <HashMap.h>
#include <Util.h>
#include <Str.h>
#ifndef TELODENDRIA_STRING_CHUNK
#define TELODENDRIA_STRING_CHUNK 64
@ -67,47 +68,47 @@ HttpRequestMethodToString(const HttpRequestMethod method)
HttpRequestMethod
HttpRequestMethodFromString(const char *str)
{
if (strcmp(str, "GET") == 0)
if (StrEquals(str, "GET"))
{
return HTTP_GET;
}
if (strcmp(str, "HEAD") == 0)
if (StrEquals(str, "HEAD"))
{
return HTTP_HEAD;
}
if (strcmp(str, "POST") == 0)
if (StrEquals(str, "POST"))
{
return HTTP_POST;
}
if (strcmp(str, "PUT") == 0)
if (StrEquals(str, "PUT"))
{
return HTTP_PUT;
}
if (strcmp(str, "DELETE") == 0)
if (StrEquals(str, "DELETE"))
{
return HTTP_DELETE;
}
if (strcmp(str, "CONNECT") == 0)
if (StrEquals(str, "CONNECT"))
{
return HTTP_CONNECT;
}
if (strcmp(str, "OPTIONS") == 0)
if (StrEquals(str, "OPTIONS"))
{
return HTTP_OPTIONS;
}
if (strcmp(str, "TRACE") == 0)
if (StrEquals(str, "TRACE"))
{
return HTTP_TRACE;
}
if (strcmp(str, "PATCH") == 0)
if (StrEquals(str, "PATCH"))
{
return HTTP_PATCH;
}
@ -569,7 +570,7 @@ HttpParseHeaders(Stream * fp)
ssize_t i;
size_t len;
if (strcmp(line, "\r\n") == 0 || strcmp(line, "\n") == 0)
if (StrEquals(line, "\r\n") || StrEquals(line, "\n"))
{
break;
}

View file

@ -154,7 +154,7 @@ HttpRouterAdd(HttpRouter * router, char *regPath, HttpRouteFunc * exec)
return 0;
}
if (strcmp(regPath, "/") == 0)
if (StrEquals(regPath, "/"))
{
router->root->exec = exec;
return 1;
@ -213,7 +213,7 @@ HttpRouterRoute(HttpRouter * router, char *path, void *args, void **ret)
node = router->root;
if (strcmp(path, "/") == 0)
if (StrEquals(path, "/"))
{
exec = node->exec;
}

View file

@ -540,7 +540,7 @@ HttpServerWorkerThread(void *args)
requestProtocol = &pathPtr[i + 1];
line[lineLen - 2] = '\0'; /* Get rid of \r and \n */
if (strcmp(requestProtocol, "HTTP/1.1") != 0 && strcmp(requestProtocol, "HTTP/1.0") != 0)
if (!StrEquals(requestProtocol, "HTTP/1.1") && !StrEquals(requestProtocol, "HTTP/1.0"))
{
Free(requestPath);
goto bad_request;

View file

@ -1020,7 +1020,7 @@ JsonTokenSeek(JsonParserState * state)
return;
}
if (!strcmp("true", state->token))
if (StrEquals("true", state->token))
{
state->tokenType = TOKEN_BOOLEAN;
state->tokenLen = 5;
@ -1041,7 +1041,7 @@ JsonTokenSeek(JsonParserState * state)
return;
}
if (!strcmp("false", state->token))
if (StrEquals("false", state->token))
{
state->tokenType = TOKEN_BOOLEAN;
state->tokenLen = 6;
@ -1062,7 +1062,7 @@ JsonTokenSeek(JsonParserState * state)
return;
}
if (!strcmp("null", state->token))
if (StrEquals("null", state->token))
{
state->tokenType = TOKEN_NULL;
}

View file

@ -271,7 +271,7 @@ start:
goto finish;
}
if (!tConfig->logTimestamp || strcmp(tConfig->logTimestamp, "default") != 0)
if (!tConfig->logTimestamp || !StrEquals(tConfig->logTimestamp, "default"))
{
LogConfigTimeStampFormatSet(LogConfigGlobal(), tConfig->logTimestamp);
}

View file

@ -79,6 +79,7 @@ ROUTE_IMPL(RouteDeactivate, path, argp)
{
/* No access token provided, require password */
Array *passwordFlow = ArrayCreate();
ArrayAdd(passwordFlow, UiaStageBuild("m.login.password", NULL));
ArrayAdd(uiaFlows, passwordFlow);
}
@ -110,7 +111,7 @@ ROUTE_IMPL(RouteDeactivate, path, argp)
else
{
/* No access token, we have to get the user off UIA */
char * session = JsonValueAsString(JsonGet(request, 2, "auth", "session"));
char *session = JsonValueAsString(JsonGet(request, 2, "auth", "session"));
DbRef *sessionRef = DbLock(db, 2, "user_interactive", session);
char *userId = JsonValueAsString(HashMapGet(DbJson(sessionRef), "user"));

View file

@ -104,7 +104,7 @@ ROUTE_IMPL(RouteLogin, path, argp)
}
type = JsonValueAsString(val);
if (strcmp(type, "m.login.password") != 0)
if (!StrEquals(type, "m.login.password"))
{
HttpResponseStatus(args->context, HTTP_BAD_REQUEST);
response = MatrixErrorCreate(M_UNRECOGNIZED);
@ -144,7 +144,7 @@ ROUTE_IMPL(RouteLogin, path, argp)
}
type = JsonValueAsString(val);
if (strcmp(type, "m.id.user") != 0)
if (!StrEquals(type, "m.id.user"))
{
HttpResponseStatus(args->context, HTTP_BAD_REQUEST);
response = MatrixErrorCreate(M_UNRECOGNIZED);
@ -174,7 +174,7 @@ ROUTE_IMPL(RouteLogin, path, argp)
break;
}
if (strcmp(userId->server, config->serverName) != 0
if (!StrEquals(userId->server, config->serverName)
|| !UserExists(db, userId->localpart))
{
HttpResponseStatus(args->context, HTTP_FORBIDDEN);

View file

@ -63,7 +63,7 @@ ROUTE_IMPL(RouteLogout, path, argp)
if (ArraySize(path) == 1)
{
if (!MATRIX_PATH_EQUALS(ArrayGet(path, 0), "all"))
if (!StrEquals(ArrayGet(path, 0), "all"))
{
HttpResponseStatus(args->context, HTTP_NOT_FOUND);
response = MatrixErrorCreate(M_NOT_FOUND);

View file

@ -25,6 +25,7 @@
#include <User.h>
#include <Memory.h>
#include <Str.h>
#include <string.h>
#include <signal.h>
@ -61,11 +62,11 @@ ROUTE_IMPL(RouteProcControl, path, argp)
switch (HttpRequestMethodGet(args->context))
{
case HTTP_POST:
if (strcmp(op, "restart") == 0)
if (StrEquals(op, "restart"))
{
raise(SIGUSR1);
}
else if (strcmp(op, "shutdown") == 0)
else if (StrEquals(op, "shutdown"))
{
raise(SIGINT);
}
@ -77,7 +78,7 @@ ROUTE_IMPL(RouteProcControl, path, argp)
}
break;
case HTTP_GET:
if (strcmp(op, "stats") == 0)
if (StrEquals(op, "stats"))
{
response = HashMapCreate();

View file

@ -156,7 +156,7 @@ ROUTE_IMPL(RouteRegister, path, argp)
kind = HashMapGet(HttpRequestParams(args->context), "kind");
/* We don't support guest accounts yet */
if (kind && strcmp(kind, "user") != 0)
if (kind && !StrEquals(kind, "user"))
{
HttpResponseStatus(args->context, HTTP_FORBIDDEN);
response = MatrixErrorCreate(M_INVALID_PARAM);
@ -303,7 +303,7 @@ finish:
else
{
if (HttpRequestMethodGet(args->context) == HTTP_GET &&
MATRIX_PATH_EQUALS(ArrayGet(path, 0), "available"))
StrEquals(ArrayGet(path, 0), "available"))
{
username = HashMapGet(
HttpRequestParams(args->context), "username");

View file

@ -96,7 +96,7 @@ ROUTE_IMPL(RouteRequestToken, path, argp)
goto finish;
}
if (strcmp(type, "email") == 0)
if (StrEquals(type, "email"))
{
val = HashMapGet(request, "email");
if (val && JsonValueType(val) != JSON_STRING)
@ -106,7 +106,7 @@ ROUTE_IMPL(RouteRequestToken, path, argp)
goto finish;
}
}
else if (strcmp(type, "msisdn") == 0)
else if (StrEquals(type, "msisdn"))
{
val = HashMapGet(request, "country");
if (val && JsonValueType(val) != JSON_STRING)

View file

@ -23,6 +23,8 @@
*/
#include <Routes.h>
#include <Str.h>
ROUTE_IMPL(RouteStaticResources, path, argp)
{
RouteArgs *args = argp;
@ -36,7 +38,7 @@ ROUTE_IMPL(RouteStaticResources, path, argp)
return MatrixErrorCreate(M_UNKNOWN);
}
if (strcmp(res, "js") == 0)
if (StrEquals(res, "js"))
{
HttpResponseHeader(args->context, "Content-Type", "text/javascript");
HttpSendHeaders(args->context);
@ -86,7 +88,7 @@ ROUTE_IMPL(RouteStaticResources, path, argp)
);
}
else if (strcmp(res, "css") == 0)
else if (StrEquals(res, "css"))
{
HttpResponseHeader(args->context, "Content-Type", "text/css");
HttpSendHeaders(args->context);

View file

@ -108,7 +108,7 @@ ROUTE_IMPL(RouteUiaFallback, path, argp)
HttpSendHeaders(args->context);
HtmlBegin(stream, "Authentication");
if (strcmp(authType, "m.login.password") == 0)
if (StrEquals(authType, "m.login.password"))
{
HtmlBeginForm(stream, "auth-form");
StreamPuts(stream,
@ -142,7 +142,7 @@ ROUTE_IMPL(RouteUiaFallback, path, argp)
"}", authType, sessionId);
HtmlEndJs(stream);
}
else if (strcmp(authType, "m.login.registration_token") == 0)
else if (StrEquals(authType, "m.login.registration_token"))
{
HtmlBeginForm(stream, "auth-form");
StreamPuts(stream,

View file

@ -134,8 +134,8 @@ ROUTE_IMPL(RouteUserProfile, path, argp)
goto finish;
}
entry = ArrayGet(path, 1);
if (strcmp(entry, "displayname") == 0 ||
strcmp(entry, "avatar_url") == 0)
if (StrEquals(entry, "displayname") ||
StrEquals(entry, "avatar_url"))
{
/* Check if user has privilege to do that action. */
if (strcmp(userId->localpart, UserGetName(user)) == 0)

View file

@ -44,7 +44,7 @@ ROUTE_IMPL(RouteWellKnown, path, argp)
return MatrixErrorCreate(M_UNKNOWN);
}
if (MATRIX_PATH_EQUALS(ArrayGet(path, 0), "client"))
if (StrEquals(ArrayGet(path, 0), "client"))
{
response = MatrixClientWellKnown(config->baseUrl, config->identityServer);
}

View file

@ -276,3 +276,22 @@ StrInt(long i)
return str;
}
int
StrEquals(const char *str1, const char *str2)
{
/* Both strings are NULL, they're equal */
if (!str1 && !str2)
{
return 1;
}
/* One or the other is NULL, they're not equal */
if (!str1 || !str2)
{
return 0;
}
/* Neither are NULL, do a regular string comparison */
return strcmp(str1, str2) == 0;
}

View file

@ -292,7 +292,7 @@ UiaComplete(Array * flows, HttpServerContext * context, Db * db,
char *flowStage = stage->type;
char *completedStage = JsonValueAsString(ArrayGet(completed, j));
if (strcmp(flowStage, completedStage) != 0)
if (!StrEquals(flowStage, completedStage))
{
break;
}
@ -323,7 +323,7 @@ UiaComplete(Array * flows, HttpServerContext * context, Db * db,
{
char *possible = ArrayGet(possibleNext, i);
if (strcmp(authType, possible) == 0)
if (StrEquals(authType, possible))
{
break;
}
@ -336,11 +336,11 @@ UiaComplete(Array * flows, HttpServerContext * context, Db * db,
goto finish;
}
if (strcmp(authType, "m.login.dummy") == 0)
if (StrEquals(authType, "m.login.dummy"))
{
/* Do nothing */
}
else if (strcmp(authType, "m.login.password") == 0)
else if (StrEquals(authType, "m.login.password"))
{
char *password = JsonValueAsString(HashMapGet(auth, "password"));
HashMap *identifier = JsonValueAsObject(HashMapGet(auth, "identifier"));
@ -359,8 +359,8 @@ UiaComplete(Array * flows, HttpServerContext * context, Db * db,
userId = UserIdParse(JsonValueAsString(HashMapGet(identifier, "user")),
config->serverName);
if (!type || strcmp(type, "m.id.user") != 0
|| !userId || strcmp(userId->server, config->serverName) != 0)
if (!type || !StrEquals(type, "m.id.user")
|| !userId || !StrEquals(userId->server, config->serverName))
{
HttpResponseStatus(context, HTTP_UNAUTHORIZED);
ret = BuildResponse(flows, db, response, session, dbRef);
@ -389,7 +389,7 @@ UiaComplete(Array * flows, HttpServerContext * context, Db * db,
UserIdFree(userId);
UserUnlock(user);
}
else if (strcmp(authType, "m.login.registration_token") == 0)
else if (StrEquals(authType, "m.login.registration_token"))
{
RegTokenInfo *tokenInfo;

View file

@ -374,7 +374,7 @@ UserCheckPassword(User * user, char *password)
hashedPwd = Sha256(tmp);
Free(tmp);
result = strcmp(hashedPwd, storedHash) == 0;
result = StrEquals(hashedPwd, storedHash);
Free(hashedPwd);
@ -667,7 +667,7 @@ UserDeleteTokens(User * user, char *exempt)
char *accessToken = JsonValueAsString(HashMapGet(device, "accessToken"));
char *refreshToken = JsonValueAsString(HashMapGet(device, "refreshToken"));
if (exempt && (strcmp(accessToken, exempt) == 0))
if (exempt && (StrEquals(accessToken, exempt)))
{
continue;
}
@ -764,27 +764,27 @@ UserDecodePrivilege(const char *p)
{
return USER_NONE;
}
else if (strcmp(p, "ALL") == 0)
else if (StrEquals(p, "ALL"))
{
return USER_ALL;
}
else if (strcmp(p, "DEACTIVATE") == 0)
else if (StrEquals(p, "DEACTIVATE"))
{
return USER_DEACTIVATE;
}
else if (strcmp(p, "ISSUE_TOKENS") == 0)
else if (StrEquals(p, "ISSUE_TOKENS"))
{
return USER_ISSUE_TOKENS;
}
else if (strcmp(p, "CONFIG") == 0)
else if (StrEquals(p, "CONFIG"))
{
return USER_CONFIG;
}
else if (strcmp(p, "GRANT_PRIVILEGES") == 0)
else if (StrEquals(p, "GRANT_PRIVILEGES"))
{
return USER_GRANT_PRIVILEGES;
}
else if (strcmp(p, "PROC_CONTROL") == 0)
else if (StrEquals(p, "PROC_CONTROL"))
{
return USER_PROC_CONTROL;
}

View file

@ -48,9 +48,6 @@
#include <string.h>
#define MATRIX_PATH_EQUALS(pathPart, str) \
((pathPart != NULL) && (strcmp(pathPart, str) == 0))
/**
* Every route function takes this structure, which contains the data
* it needs to successfully handle an API request.

View file

@ -96,4 +96,18 @@ extern char * StrRandom(size_t);
*/
extern char * StrInt(long);
/**
* Compare two strings and determine whether or not they are equal.
* This is the most common use case of strcmp() in Telodendria, but
* strcmp() doesn't like NULL pointers, so these have to be checked
* explicitly and can cause problems if they aren't. This function,
* on the other hand, makes NULL pointers special cases. If both
* arguments are NULL, then they are considered equal. If only one
* argument is NULL, they are considered not equal. Otherwise, if
* no arguments are NULL, a regular strcmp() takes place and this
* function returns a boolean value indicating whether or not
* strcmp() returned 0.
*/
extern int StrEquals(const char *, const char *);
#endif /* TELODENDRIA_STR_H */

View file

@ -132,7 +132,7 @@ main(int argc, char **argv)
break;
}
if (strcmp(optarg, "-") == 0)
if (StrEquals(optarg, "-"))
{
in = StreamStdin();
}
@ -167,7 +167,7 @@ main(int argc, char **argv)
break;
}
if (strcmp(optarg, "-") == 0)
if (StrEquals(optarg, "-"))
{
out = StreamStdout();
}

View file

@ -118,11 +118,11 @@ main(int argc, char **argv)
if (!uri->port)
{
if (strcmp(uri->proto, "https") == 0)
if (StrEquals(uri->proto, "https"))
{
uri->port = 443;
}
else if (strcmp(uri->proto, "http") == 0)
else if (StrEquals(uri->proto, "http"))
{
uri->port = 80;
}
@ -135,7 +135,7 @@ main(int argc, char **argv)
return 1;
}
if (strcmp(uri->proto, "https") == 0)
if (StrEquals(uri->proto, "https"))
{
requestFlags |= HTTP_FLAG_TLS;
}
@ -166,7 +166,7 @@ main(int argc, char **argv)
data++;
if (strcmp(data, "-") == 0)
if (StrEquals(data, "-"))
{
in = StreamStdin();
}

View file

@ -27,6 +27,7 @@
#include <Array.h>
#include <HashMap.h>
#include <Str.h>
#include <Memory.h>
#include <Json.h>
@ -58,7 +59,7 @@ query(char *select, HashMap * json)
if (keyName[0] == '@')
{
if (strcmp(keyName + 1, "length") == 0)
if (StrEquals(keyName + 1, "length"))
{
switch (JsonValueType(val))
{
@ -73,7 +74,7 @@ query(char *select, HashMap * json)
break;
}
}
else if (JsonValueType(val) == JSON_OBJECT && strcmp(keyName + 1, "keys") == 0)
else if (JsonValueType(val) == JSON_OBJECT && StrEquals(keyName + 1, "keys"))
{
HashMap *obj = JsonValueAsObject(val);
Array *arr = ArrayCreate();
@ -87,7 +88,7 @@ query(char *select, HashMap * json)
val = JsonValueArray(arr);
}
else if (JsonValueType(val) == JSON_STRING && strcmp(keyName + 1, "decode") == 0)
else if (JsonValueType(val) == JSON_STRING && StrEquals(keyName + 1, "decode"))
{
StreamPrintf(StreamStdout(), "%s\n", JsonValueAsString(val));
val = NULL;