Wipe (some) secrets from memory after use
[tinc] / src / gcrypt / rsagen.c
1 /*
2     rsagen.c -- RSA key generation and export
3     Copyright (C) 2008-2022 Guus Sliepen <guus@tinc-vpn.org>
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "../system.h"
21
22 #include <gcrypt.h>
23 #include <assert.h>
24
25 #include "rsa.h"
26 #include "pem.h"
27 #include "../rsagen.h"
28 #include "../xalloc.h"
29 #include "../utils.h"
30
31 // ASN.1 tags.
32 typedef enum {
33         TAG_INTEGER = 2,
34         TAG_SEQUENCE = 16,
35 } asn1_tag_t;
36
37 static size_t der_tag_len(size_t n) {
38         if(n < 128) {
39                 return 2;
40         }
41
42         if(n < 256) {
43                 return 3;
44         }
45
46         if(n < 65536) {
47                 return 4;
48         }
49
50         abort();
51 }
52
53 static uint8_t *der_store_tag(uint8_t *p, asn1_tag_t tag, size_t n) {
54         if(tag == TAG_SEQUENCE) {
55                 tag |= 0x20;
56         }
57
58         *p++ = tag;
59
60         if(n < 128) {
61                 *p++ = n;
62         } else if(n < 256) {
63                 *p++ = 0x81;
64                 *p++ = n;
65         } else if(n < 65536) {
66                 *p++ = 0x82;
67                 *p++ = n >> 8;
68                 *p++ = n & 0xff;
69         } else {
70                 abort();
71         }
72
73         return p;
74 }
75
76 static size_t der_fill(uint8_t *derbuf, bool is_private, const gcry_mpi_t mpi[], size_t num_mpi) {
77         size_t needed = 0;
78         size_t lengths[16] = {0};
79
80         assert(num_mpi > 0 && num_mpi < sizeof(lengths) / sizeof(*lengths));
81
82         if(is_private) {
83                 // Add space for the version number.
84                 needed += der_tag_len(1) + 1;
85         }
86
87         for(size_t i = 0; i < num_mpi; ++i) {
88                 gcry_mpi_print(GCRYMPI_FMT_STD, NULL, 0, &lengths[i], mpi[i]);
89                 needed += der_tag_len(lengths[i]) + lengths[i];
90         }
91
92         const size_t derlen = der_tag_len(needed) + needed;
93
94         uint8_t *der = derbuf;
95         der = der_store_tag(der, TAG_SEQUENCE, needed);
96
97         if(is_private) {
98                 // Private key requires storing version number.
99                 der = der_store_tag(der, TAG_INTEGER, 1);
100                 *der++ = 0;
101         }
102
103         for(size_t i = 0; i < num_mpi; ++i) {
104                 const size_t len = lengths[i];
105                 der = der_store_tag(der, TAG_INTEGER, len);
106                 gcry_mpi_print(GCRYMPI_FMT_STD, der, len, NULL, mpi[i]);
107                 der += len;
108         }
109
110         assert((size_t)(der - derbuf) == derlen);
111         return derlen;
112 }
113
114 bool rsa_write_pem_public_key(rsa_t *rsa, FILE *fp) {
115         uint8_t derbuf[8096];
116
117         gcry_mpi_t params[] = {
118                 rsa->n,
119                 rsa->e,
120         };
121
122         size_t derlen = der_fill(derbuf, false, params, sizeof(params) / sizeof(*params));
123
124         return pem_encode(fp, "RSA PUBLIC KEY", derbuf, derlen);
125 }
126
127 // Calculate p/q primes from n/e/d.
128 static void get_p_q(gcry_mpi_t *p,
129                     gcry_mpi_t *q,
130                     const gcry_mpi_t n,
131                     const gcry_mpi_t e,
132                     const gcry_mpi_t d) {
133         const size_t nbits = gcry_mpi_get_nbits(n);
134
135         gcry_mpi_t k = gcry_mpi_new(nbits);
136         gcry_mpi_mul(k, e, d);
137         gcry_mpi_sub_ui(k, k, 1);
138
139         size_t t = 0;
140
141         while(!gcry_mpi_test_bit(k, t)) {
142                 ++t;
143         }
144
145         gcry_mpi_t g = gcry_mpi_new(nbits);
146         gcry_mpi_t gk = gcry_mpi_new(0);
147         gcry_mpi_t sq = gcry_mpi_new(0);
148         gcry_mpi_t rem = gcry_mpi_new(0);
149         gcry_mpi_t gcd = gcry_mpi_new(0);
150
151         while(true) {
152                 gcry_mpi_t kt = gcry_mpi_copy(k);
153                 gcry_mpi_randomize(g, nbits, GCRY_STRONG_RANDOM);
154
155                 size_t i;
156
157                 for(i = 0; i < t; ++i) {
158                         gcry_mpi_rshift(kt, kt, 1);
159                         gcry_mpi_powm(gk, g, kt, n);
160
161                         if(gcry_mpi_cmp_ui(gk, 1) != 0) {
162                                 gcry_mpi_mul(sq, gk, gk);
163                                 gcry_mpi_mod(rem, sq, n);
164
165                                 if(gcry_mpi_cmp_ui(rem, 1) == 0) {
166                                         break;
167                                 }
168                         }
169                 }
170
171                 gcry_mpi_release(kt);
172
173                 if(i < t) {
174                         gcry_mpi_sub_ui(gk, gk, 1);
175                         gcry_mpi_gcd(gcd, gk, n);
176
177                         if(gcry_mpi_cmp_ui(gcd, 1) != 0) {
178                                 break;
179                         }
180                 }
181         }
182
183         gcry_mpi_release(k);
184         gcry_mpi_release(g);
185         gcry_mpi_release(gk);
186         gcry_mpi_release(sq);
187         gcry_mpi_release(rem);
188
189         *p = gcd;
190         *q = gcry_mpi_new(0);
191
192         gcry_mpi_div(*q, NULL, n, *p, 0);
193 }
194
195 bool rsa_write_pem_private_key(rsa_t *rsa, FILE *fp) {
196         gcry_mpi_t params[] = {
197                 rsa->n,
198                 rsa->e,
199                 rsa->d,
200                 NULL, // p
201                 NULL, // q
202                 gcry_mpi_new(0), // d mod (p-1)
203                 gcry_mpi_new(0), // d mod (q-1)
204                 gcry_mpi_new(0), // u = p^-1 mod q
205         };
206
207         // Indexes into params.
208         const size_t d = 2;
209         const size_t p = 3;
210         const size_t q = 4;
211         const size_t dp = 5;
212         const size_t dq = 6;
213         const size_t u = 7;
214
215         // Calculate p and q.
216         get_p_q(&params[p], &params[q], rsa->n, rsa->e, rsa->d);
217
218         // Swap p and q if q > p.
219         if(gcry_mpi_cmp(params[q], params[p]) > 0) {
220                 gcry_mpi_swap(params[p], params[q]);
221         }
222
223         // Calculate u.
224         gcry_mpi_invm(params[u], params[p], params[q]);
225
226         // Calculate d mod (p - 1).
227         gcry_mpi_sub_ui(params[dp], params[p], 1);
228         gcry_mpi_mod(params[dp], params[d], params[dp]);
229
230         // Calculate d mod (q - 1).
231         gcry_mpi_sub_ui(params[dq], params[q], 1);
232         gcry_mpi_mod(params[dq], params[d], params[dq]);
233
234         uint8_t derbuf[8096];
235         const size_t nparams = sizeof(params) / sizeof(*params);
236         size_t derlen = der_fill(derbuf, true, params, nparams);
237
238         gcry_mpi_release(params[p]);
239         gcry_mpi_release(params[q]);
240         gcry_mpi_release(params[dp]);
241         gcry_mpi_release(params[dq]);
242         gcry_mpi_release(params[u]);
243
244         bool success = pem_encode(fp, "RSA PRIVATE KEY", derbuf, derlen);
245         memzero(derbuf, sizeof(derbuf));
246         return success;
247 }
248
249 static gcry_mpi_t find_mpi(const gcry_sexp_t rsa, const char *token) {
250         gcry_sexp_t sexp = gcry_sexp_find_token(rsa, token, 1);
251
252         if(!sexp) {
253                 fprintf(stderr, "Token %s not found in RSA S-expression.\n", token);
254                 return NULL;
255         }
256
257         gcry_mpi_t mpi = gcry_sexp_nth_mpi(sexp, 1, GCRYMPI_FMT_USG);
258         gcry_sexp_release(sexp);
259         return mpi;
260 }
261
262 rsa_t *rsa_generate(size_t bits, unsigned long exponent) {
263         gcry_sexp_t s_params;
264         gcry_error_t err = gcry_sexp_build(&s_params, NULL,
265                                            "(genkey"
266                                            "  (rsa"
267                                            "    (nbits %u)"
268                                            "    (rsa-use-e %u)))",
269                                            bits,
270                                            exponent);
271
272         if(err) {
273                 fprintf(stderr, "Error building keygen S-expression: %s.\n", gcry_strerror(err));
274                 return NULL;
275         }
276
277         gcry_sexp_t s_key;
278         err = gcry_pk_genkey(&s_key, s_params);
279         gcry_sexp_release(s_params);
280
281         if(err) {
282                 fprintf(stderr, "Error generating RSA key pair: %s.\n", gcry_strerror(err));
283                 return NULL;
284         }
285
286         // `gcry_sexp_extract_param` can replace everything below
287         // with a single line, but it's not available on CentOS 7.
288         gcry_sexp_t s_priv = gcry_sexp_find_token(s_key, "private-key", 0);
289
290         if(!s_priv) {
291                 fprintf(stderr, "Private key not found in gcrypt result.\n");
292                 gcry_sexp_release(s_key);
293                 return NULL;
294         }
295
296         gcry_sexp_t s_rsa = gcry_sexp_find_token(s_priv, "rsa", 0);
297
298         if(!s_rsa) {
299                 fprintf(stderr, "RSA not found in gcrypt result.\n");
300                 gcry_sexp_release(s_priv);
301                 gcry_sexp_release(s_key);
302                 return NULL;
303         }
304
305         rsa_t *rsa = xzalloc(sizeof(*rsa));
306
307         rsa->n = find_mpi(s_rsa, "n");
308         rsa->e = find_mpi(s_rsa, "e");
309         rsa->d = find_mpi(s_rsa, "d");
310
311         gcry_sexp_release(s_rsa);
312         gcry_sexp_release(s_priv);
313         gcry_sexp_release(s_key);
314
315         if(rsa->n && rsa->e && rsa->d) {
316                 return rsa;
317         }
318
319         rsa_free(rsa);
320         return NULL;
321 }