Fix reading broken BER in gcrypt/rsa.c
[tinc] / src / gcrypt / rsa.c
index 1b15164..04aa358 100644 (file)
@@ -23,6 +23,7 @@
 
 #include "pem.h"
 
+#include "asn1.h"
 #include "rsa.h"
 #include "../logger.h"
 #include "../rsa.h"
@@ -71,7 +72,7 @@ static size_t ber_read_len(unsigned char **p, size_t *buflen) {
                        return 0;
                }
 
-               while(len--) {
+               for(; len; --len) {
                        result = (size_t)(result << 8);
                        result |= *(*p)++;
                        (*buflen)--;
@@ -84,20 +85,11 @@ static size_t ber_read_len(unsigned char **p, size_t *buflen) {
        }
 }
 
-
-static bool ber_read_sequence(unsigned char **p, size_t *buflen, size_t *result) {
+static bool ber_skip_sequence(unsigned char **p, size_t *buflen) {
        int tag = ber_read_id(p, buflen);
-       size_t len = ber_read_len(p, buflen);
 
-       if(tag == 0x10) {
-               if(result) {
-                       *result = len;
-               }
-
-               return true;
-       } else {
-               return false;
-       }
+       return tag == TAG_SEQUENCE &&
+              ber_read_len(p, buflen) > 0;
 }
 
 static bool ber_read_mpi(unsigned char **p, size_t *buflen, gcry_mpi_t *mpi) {
@@ -130,7 +122,7 @@ rsa_t *rsa_set_hex_public_key(const char *n, const char *e) {
 
        if(err) {
                logger(DEBUG_ALWAYS, LOG_ERR, "Error while reading RSA public key: %s", gcry_strerror(errno));
-               free(rsa);
+               rsa_free(rsa);
                return false;
        }
 
@@ -152,8 +144,8 @@ rsa_t *rsa_set_hex_private_key(const char *n, const char *e, const char *d) {
 
        if(err) {
                logger(DEBUG_ALWAYS, LOG_ERR, "Error while reading RSA public key: %s", gcry_strerror(errno));
-               free(rsa);
-               return false;
+               rsa_free(rsa);
+               return NULL;
        }
 
        return rsa;
@@ -172,12 +164,12 @@ rsa_t *rsa_read_pem_public_key(FILE *fp) {
 
        rsa_t *rsa = xzalloc(sizeof(rsa_t));
 
-       if(!ber_read_sequence(&derp, &derlen, NULL)
+       if(!ber_skip_sequence(&derp, &derlen)
                        || !ber_read_mpi(&derp, &derlen, &rsa->n)
                        || !ber_read_mpi(&derp, &derlen, &rsa->e)
                        || derlen) {
                logger(DEBUG_ALWAYS, LOG_ERR, "Error while decoding RSA public key");
-               free(rsa);
+               rsa_free(rsa);
                return NULL;
        }
 
@@ -195,7 +187,7 @@ rsa_t *rsa_read_pem_private_key(FILE *fp) {
 
        rsa_t *rsa = xzalloc(sizeof(rsa_t));
 
-       if(!ber_read_sequence(&derp, &derlen, NULL)
+       if(!ber_skip_sequence(&derp, &derlen)
                        || !ber_read_mpi(&derp, &derlen, NULL)
                        || !ber_read_mpi(&derp, &derlen, &rsa->n)
                        || !ber_read_mpi(&derp, &derlen, &rsa->e)
@@ -207,10 +199,11 @@ rsa_t *rsa_read_pem_private_key(FILE *fp) {
                        || !ber_read_mpi(&derp, &derlen, NULL) // u
                        || derlen) {
                logger(DEBUG_ALWAYS, LOG_ERR, "Error while decoding RSA private key");
-               free(rsa);
-               return NULL;
+               rsa_free(rsa);
+               rsa = NULL;
        }
 
+       memzero(derbuf, sizeof(derbuf));
        return rsa;
 }
 
@@ -218,19 +211,27 @@ size_t rsa_size(const rsa_t *rsa) {
        return (gcry_mpi_get_nbits(rsa->n) + 7) / 8;
 }
 
+static bool check(gcry_error_t err) {
+       if(err) {
+               logger(DEBUG_ALWAYS, LOG_ERR, "gcrypt error %s/%s", gcry_strsource(err), gcry_strerror(err));
+       }
+
+       return !err;
+}
+
 /* Well, libgcrypt has functions to handle RSA keys, but they suck.
  * So we just use libgcrypt's mpi functions, and do the math ourselves.
  */
 
-// TODO: get rid of this macro, properly clean up gcry_ structures after use
-#define check(foo) { gcry_error_t err = (foo); if(err) {logger(DEBUG_ALWAYS, LOG_ERR, "gcrypt error %s/%s at %s:%d", gcry_strsource(err), gcry_strerror(err), __FILE__, __LINE__); return false; }}
+static bool rsa_powm(const gcry_mpi_t ed, const gcry_mpi_t n, const void *in, size_t len, void *out) {
+       gcry_mpi_t inmpi = NULL;
 
-bool rsa_public_encrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
-       gcry_mpi_t inmpi;
-       check(gcry_mpi_scan(&inmpi, GCRYMPI_FMT_USG, in, len, NULL));
+       if(!check(gcry_mpi_scan(&inmpi, GCRYMPI_FMT_USG, in, len, NULL))) {
+               return false;
+       }
 
-       gcry_mpi_t outmpi = gcry_mpi_new(len * 8);
-       gcry_mpi_powm(outmpi, inmpi, rsa->e, rsa->n);
+       gcry_mpi_t outmpi = gcry_mpi_snew(len * 8);
+       gcry_mpi_powm(outmpi, inmpi, ed, n);
 
        size_t out_bytes = (gcry_mpi_get_nbits(outmpi) + 7) / 8;
        size_t pad = len - MIN(out_bytes, len);
@@ -240,28 +241,20 @@ bool rsa_public_encrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
                *pout++ = 0;
        }
 
-       check(gcry_mpi_print(GCRYMPI_FMT_USG, pout, len, NULL, outmpi));
+       bool ok = check(gcry_mpi_print(GCRYMPI_FMT_USG, pout, len, NULL, outmpi));
 
-       return true;
-}
+       gcry_mpi_release(outmpi);
+       gcry_mpi_release(inmpi);
 
-bool rsa_private_decrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
-       gcry_mpi_t inmpi;
-       check(gcry_mpi_scan(&inmpi, GCRYMPI_FMT_USG, in, len, NULL));
-
-       gcry_mpi_t outmpi = gcry_mpi_new(len * 8);
-       gcry_mpi_powm(outmpi, inmpi, rsa->d, rsa->n);
-
-       size_t pad = len - (gcry_mpi_get_nbits(outmpi) + 7) / 8;
-       unsigned char *pout = out;
-
-       for(; pad; --pad) {
-               *pout++ = 0;
-       }
+       return ok;
+}
 
-       check(gcry_mpi_print(GCRYMPI_FMT_USG, pout, len, NULL, outmpi));
+bool rsa_public_encrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
+       return rsa_powm(rsa->e, rsa->n, in, len, out);
+}
 
-       return true;
+bool rsa_private_decrypt(rsa_t *rsa, const void *in, size_t len, void *out) {
+       return rsa_powm(rsa->d, rsa->n, in, len, out);
 }
 
 void rsa_free(rsa_t *rsa) {