diff --git a/src/Parser.c b/src/Parser.c index bef70af..3dda8a0 100644 --- a/src/Parser.c +++ b/src/Parser.c @@ -24,12 +24,18 @@ #include -#include #include +#include +#include #include +#include #include + +/* Iterate through a char **. */ +#define Iterate(s) (*(*s)++) + /* Parse an extended localpart */ static int ParseUserLocalpart(char **str, char **out) @@ -45,21 +51,187 @@ ParseUserLocalpart(char **str, char **out) /* An extended localpart contains every ASCII printable character, * except an ':'. */ start = *str; - while (isascii((c = (*(*str)++))) && c != ':' && c) + while (isascii((c = Iterate(str))) && c != ':' && c) { /* Do nothing */ } - length = (size_t) (*str - start) -1; - if (c != ':' || length < 1) + length = (size_t) (*str - start) - 1; + if (length < 1) { *str = start; return 0; } + if (c == ':') + { + --(*str); + } *out = Malloc(length + 1); memcpy(*out, start, length); - *out[length] = '\0'; + (*out)[length] = '\0'; + return 1; +} +/* Parses an IPv4 address. */ +static int +ParseIPv4(char **str, char **out) +{ + /* Be *very* careful with this buffer */ + char buffer[4]; + char *start; + size_t length; + char c; + + int digit = 0; + int digits = 0; + + memset(buffer, '\0', 4); + start = *str; + + /* An IPv4 address is made of 4 blocks between 1-3 digits, like so: + * (1-3)*DIGIT.(1-3)*DIGIT.(1-3)*DIGIT.(1-3)*DIGIT */ + while ((isdigit(c = Iterate(str)) || c == '.') && c && digits < 4) + { + if (isdigit(c)) + { + digit++; + continue; + } + if (digit < 1 || digit > 3) + { + /* Current digit is too long for the spec! */ + *str = start; + return 0; + } + memcpy(buffer, *str - digit - 1, digit); + if (atoi(buffer) > 255) + { + /* Current digit is too large for the spec! */ + *str = start; + return 0; + } + memset(buffer, '\0', 4); + digit = 0; + digits++; /* We have parsed a digit. */ + } + if (c == '.' || digits != 3) + { + *str = start; + return 0; + } + length = (size_t) (*str - start) - 1; + *out = Malloc(length + 1); + memcpy(*out, start, length); + (*str)--; + return 1; +} +static int +ParseIPv6(char **str, char **out) +{ + /* TODO */ + (void) str; + (void) out; + return 0; +} +static int +ParseHostname(char **str, char **out) +{ + char *start; + size_t length = 0; + char c; + + start = *str; + while ((c = Iterate(str)) && + (isalnum(c) || c == '.' || c == '-') && + ++length < 256) + { + /* Do nothing. */ + } + if (length < 1 || length > 255) + { + *str = start; + return 0; + } + length = (size_t) (*str - start) - 1; + *out = Malloc(length + 1); + memcpy(*out, start, length); + (*str)--; + return 1; +} + +static int +ParseServerName(char **str, ServerPart *out) +{ + char c; + char *start; + char *startPort; + size_t chars = 0; + + char *host = NULL; + char *port = NULL; + + if (!str || !out) + { + return 0; + } + + start = *str; + + if (!host) + { + /* If we can parse an IPv4 address, use that. */ + ParseIPv4(str, &host); + } + if (!host) + { + /* If we can parse an IPv6 address, use that. */ + ParseIPv6(str, &host); + } + if (!host) + { + /* If we can parse an hostname, use that. */ + ParseHostname(str, &host); + } + if (!host) + { + /* Can't parse a valid server name. */ + return 0; + } + /* Now, there's only 2 options: a ':', or the end(everything else.) */ + if (**str != ':') + { + /* We're done. */ + out->hostname = host; + out->port = NULL; + return 1; + } + /* TODO: Separate this out */ + startPort = ++(*str); + while(isdigit(c = Iterate(str)) && c && ++chars < 5) + { + /* Do nothing. */ + } + if (chars < 1 || chars > 5) + { + *str = start; + Free(host); + host = NULL; + return 0; + } + + port = Malloc(chars + 1); + memcpy(port, startPort, chars); + port[chars] = '\0'; + if (atol(port) > 65535) + { + Free(port); + Free(host); + *str = start; + return 0; + } + + out->hostname = host; + out->port = port; return 1; } @@ -89,17 +261,48 @@ ParseCommonID(char *str, CommonID *id) return 0; } id->sigil = sigil; + id->local = NULL; + id->server.hostname = NULL; + id->server.port = NULL; - switch(sigil) + switch (sigil) { case '@': if (!ParseUserLocalpart(&str, &id->local)) { return 0; } + if (*str++ != ':') + { + Free(id->local); + id->local = NULL; + return 0; + } /* TODO: Match whenever str is valid. */ - id->server = StrDuplicate(str); + if (!ParseServerName(&str, &id->server)) + { + Free(id->local); + id->local = NULL; + return 0; + } break; } return 1; } + +void +CommonIDFree(CommonID id) +{ + if (id.local) + { + Free(id.local); + } + if (id.server.hostname) + { + Free(id.server.hostname); + } + if (id.server.port) + { + Free(id.server.port); + } +} diff --git a/src/User.c b/src/User.c index 0f0f892..20a4289 100644 --- a/src/User.c +++ b/src/User.c @@ -909,29 +909,30 @@ UserIdParse(char *id, char *defaultServer) /* Fully-qualified user ID */ if (*id == '@') { + /* TODO: Just use the CommonID. */ CommonID commonID; commonID.sigil = '\0'; commonID.local = NULL; - commonID.server = NULL; - if (!ParseCommonID(id, &commonID) || !commonID.server) + commonID.server.hostname = NULL; + commonID.server.port = NULL; + if (!ParseCommonID(id, &commonID) || !commonID.server.hostname) { Free(userId); Free(commonID.local); - Free(commonID.server); - userId = NULL; - goto finish; - } - if (*commonID.server == '\0') - { - Free(userId); - Free(commonID.local); - Free(commonID.server); + if (commonID.server.hostname) + { + Free(commonID.server.hostname); + } userId = NULL; goto finish; } userId->localpart = commonID.local; - userId->server = commonID.server; + userId->server = commonID.server.hostname; + if (commonID.server.port) + { + Free(commonID.server.port); + } } else { diff --git a/src/include/Parser.h b/src/include/Parser.h index 3f7fb7e..788f92f 100644 --- a/src/include/Parser.h +++ b/src/include/Parser.h @@ -36,6 +36,13 @@ * Matrix specification */ +/** + * The host[:port] format in a servername. + */ +typedef struct ServerPart { + char *hostname; + char *port; +} ServerPart; /** * A common identifier in the form '&local[:server]', where * & determines the *type* of the identifier. @@ -43,7 +50,7 @@ typedef struct CommonID { char sigil; char *local; - char *server; /* Might be NULL for some sigils(e.g: room IDs >= v3) */ + ServerPart server; } CommonID; /**