Improve base64 encoding/decoding, add URL-safe variant.
[tinc] / src / protocol_key.c
index af103c6..0eeddb8 100644 (file)
@@ -158,11 +158,12 @@ static bool req_key_ext_h(connection_t *c, const char *request, node_t *from, in
                                logger(DEBUG_ALWAYS, LOG_DEBUG, "Got REQ_KEY from %s while we already started a SPTPS session!", from->name);
 
                        char buf[MAX_STRING_SIZE];
-                       if(sscanf(request, "%*d %*s %*s %*d " MAX_STRING, buf) != 1) {
+                       int len;
+
+                       if(sscanf(request, "%*d %*s %*s %*d " MAX_STRING, buf) != 1 || !(len = b64decode(buf, buf, strlen(buf)))) {
                                logger(DEBUG_ALWAYS, LOG_ERR, "Got bad %s from %s (%s): %s", "REQ_SPTPS_START", from->name, from->hostname, "invalid SPTPS data");
                                return true;
                        }
-                       int len = b64decode(buf, buf, strlen(buf));
 
                        char label[25 + strlen(from->name) + strlen(myself->name)];
                        snprintf(label, sizeof label, "tinc UDP key expansion %s %s", from->name, myself->name);
@@ -182,11 +183,11 @@ static bool req_key_ext_h(connection_t *c, const char *request, node_t *from, in
                        }
 
                        char buf[MAX_STRING_SIZE];
-                       if(sscanf(request, "%*d %*s %*s %*d " MAX_STRING, buf) != 1) {
+                       int len;
+                       if(sscanf(request, "%*d %*s %*s %*d " MAX_STRING, buf) != 1 || !(len = b64decode(buf, buf, strlen(buf)))) {
                                logger(DEBUG_ALWAYS, LOG_ERR, "Got bad %s from %s (%s): %s", "REQ_SPTPS", from->name, from->hostname, "invalid SPTPS data");
                                return true;
                        }
-                       int len = b64decode(buf, buf, strlen(buf));
                        sptps_receive_data(&from->sptps, buf, len);
                        return true;
                }
@@ -375,7 +376,7 @@ bool ans_key_h(connection_t *c, const char *request) {
                char buf[strlen(key)];
                int len = b64decode(key, buf, strlen(key));
 
-               if(!sptps_receive_data(&from->sptps, buf, len))
+               if(!len || !sptps_receive_data(&from->sptps, buf, len))
                        logger(DEBUG_ALWAYS, LOG_ERR, "Error processing SPTPS data from %s (%s)", from->name, from->hostname);
 
                if(from->status.validkey) {