If we link with OpenSSL, use it for Chacha20-Poly1305 as well.
[tinc] / src / sptps.c
1 /*
2     sptps.c -- Simple Peer-to-Peer Security
3     Copyright (C) 2011-2015 Guus Sliepen <guus@tinc-vpn.org>,
4                   2010      Brandon L. Black <blblack@gmail.com>
5
6     This program is free software; you can redistribute it and/or modify
7     it under the terms of the GNU General Public License as published by
8     the Free Software Foundation; either version 2 of the License, or
9     (at your option) any later version.
10
11     This program is distributed in the hope that it will be useful,
12     but WITHOUT ANY WARRANTY; without even the implied warranty of
13     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14     GNU General Public License for more details.
15
16     You should have received a copy of the GNU General Public License along
17     with this program; if not, write to the Free Software Foundation, Inc.,
18     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
19 */
20
21 #include "system.h"
22
23 #include "chacha-poly1305/chacha-poly1305.h"
24 #include "ecdh.h"
25 #include "ecdsa.h"
26 #include "prf.h"
27 #include "sptps.h"
28 #include "random.h"
29 #include "xalloc.h"
30
31 #ifdef HAVE_OPENSSL
32 #include <openssl/evp.h>
33 #endif
34
35 unsigned int sptps_replaywin = 16;
36
37 /*
38    Nonce MUST be exchanged first (done)
39    Signatures MUST be done over both nonces, to guarantee the signature is fresh
40    Otherwise: if ECDHE key of one side is compromised, it can be reused!
41
42    Add explicit tag to beginning of structure to distinguish the client and server when signing. (done)
43
44    Sign all handshake messages up to ECDHE kex with long-term public keys. (done)
45
46    HMACed KEX finished message to prevent downgrade attacks and prove you have the right key material (done by virtue of Ed25519 over the whole ECDHE exchange?)
47
48    Explicit close message needs to be added.
49
50    Maybe do add some alert messages to give helpful error messages? Not more than TLS sends.
51
52    Use counter mode instead of OFB. (done)
53
54    Make sure ECC operations are fixed time (aka prevent side-channel attacks).
55 */
56
57 void sptps_log_quiet(sptps_t *s, int s_errno, const char *format, va_list ap) {
58         (void)s;
59         (void)s_errno;
60         (void)format;
61         (void)ap;
62 }
63
64 void sptps_log_stderr(sptps_t *s, int s_errno, const char *format, va_list ap) {
65         (void)s;
66         (void)s_errno;
67
68         vfprintf(stderr, format, ap);
69         fputc('\n', stderr);
70 }
71
72 void (*sptps_log)(sptps_t *s, int s_errno, const char *format, va_list ap) = sptps_log_stderr;
73
74 // Log an error message.
75 static bool error(sptps_t *s, int s_errno, const char *format, ...) ATTR_FORMAT(printf, 3, 4);
76 static bool error(sptps_t *s, int s_errno, const char *format, ...) {
77         (void)s;
78         (void)s_errno;
79
80         if(format) {
81                 va_list ap;
82                 va_start(ap, format);
83                 sptps_log(s, s_errno, format, ap);
84                 va_end(ap);
85         }
86
87         errno = s_errno;
88         return false;
89 }
90
91 static void warning(sptps_t *s, const char *format, ...) ATTR_FORMAT(printf, 2, 3);
92 static void warning(sptps_t *s, const char *format, ...) {
93         va_list ap;
94         va_start(ap, format);
95         sptps_log(s, 0, format, ap);
96         va_end(ap);
97 }
98
99 static sptps_kex_t *new_sptps_kex(void) {
100         return xzalloc(sizeof(sptps_kex_t));
101 }
102
103 static void free_sptps_kex(sptps_kex_t *kex) {
104         xzfree(kex, sizeof(sptps_kex_t));
105 }
106
107 static sptps_key_t *new_sptps_key(void) {
108         return xzalloc(sizeof(sptps_key_t));
109 }
110
111 static void free_sptps_key(sptps_key_t *key) {
112         xzfree(key, sizeof(sptps_key_t));
113 }
114
115 static bool cipher_init(uint8_t suite, void **ctx, const sptps_key_t *keys, bool key_half) {
116         const uint8_t *key = key_half ? keys->key1 : keys->key0;
117
118         switch(suite) {
119 #ifndef HAVE_OPENSSL
120
121         case SPTPS_CHACHA_POLY1305:
122                 *ctx = chacha_poly1305_init();
123                 return ctx && chacha_poly1305_set_key(*ctx, key);
124
125 #else
126
127         case SPTPS_CHACHA_POLY1305:
128                 *ctx = EVP_CIPHER_CTX_new();
129
130                 if(!ctx) {
131                         return false;
132                 }
133
134                 return EVP_EncryptInit_ex(*ctx, EVP_chacha20_poly1305(), NULL, NULL, NULL)
135                        && EVP_CIPHER_CTX_ctrl(*ctx, EVP_CTRL_AEAD_SET_IVLEN, 12, NULL)
136                        && EVP_EncryptInit_ex(*ctx, NULL, NULL, key, key + 32);
137
138         case SPTPS_AES256_GCM:
139                 *ctx = EVP_CIPHER_CTX_new();
140
141                 if(!ctx) {
142                         return false;
143                 }
144
145                 return EVP_EncryptInit_ex(*ctx, EVP_aes_256_gcm(), NULL, NULL, NULL)
146                        && EVP_CIPHER_CTX_ctrl(*ctx, EVP_CTRL_AEAD_SET_IVLEN, 12, NULL)
147                        && EVP_EncryptInit_ex(*ctx, NULL, NULL, key, key + 32);
148 #endif
149
150         default:
151                 return false;
152         }
153 }
154
155 static void cipher_exit(uint8_t suite, void *ctx) {
156         switch(suite) {
157 #ifndef HAVE_OPENSSL
158
159         case SPTPS_CHACHA_POLY1305:
160                 chacha_poly1305_exit(ctx);
161                 break;
162
163 #else
164
165         case SPTPS_CHACHA_POLY1305:
166         case SPTPS_AES256_GCM:
167                 EVP_CIPHER_CTX_free(ctx);
168                 break;
169 #endif
170
171         default:
172                 break;
173         }
174 }
175
176 static bool cipher_encrypt(uint8_t suite, void *ctx, uint32_t seqno, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) {
177         switch(suite) {
178 #ifndef HAVE_OPENSSL
179
180         case SPTPS_CHACHA_POLY1305:
181                 chacha_poly1305_encrypt(ctx, seqno, in, inlen, out, outlen);
182                 return true;
183
184 #else
185
186         case SPTPS_CHACHA_POLY1305:
187         case SPTPS_AES256_GCM: {
188                 uint8_t nonce[12] = {seqno, seqno >> 8, seqno >> 16, seqno >> 24};
189
190                 if(!EVP_EncryptInit_ex(ctx, NULL, NULL, NULL, nonce)) {
191                         return false;
192                 }
193
194                 int outlen1 = 0, outlen2 = 0;
195
196                 if(!EVP_EncryptUpdate(ctx, out, &outlen1, in, (int)inlen)) {
197                         return false;
198                 }
199
200                 if(!EVP_EncryptFinal_ex(ctx, out + outlen1, &outlen2)) {
201                         return false;
202                 }
203
204                 outlen1 += outlen2;
205
206                 if(!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, 16, out + outlen1)) {
207                         return false;
208                 }
209
210                 outlen1 += 16;
211
212                 if(outlen) {
213                         *outlen = outlen1;
214                 }
215
216                 return true;
217         }
218
219 #endif
220
221         default:
222                 return false;
223         }
224 }
225
226 static bool cipher_decrypt(uint8_t suite, void *ctx, uint32_t seqno, const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen) {
227         switch(suite) {
228 #ifndef HAVE_OPENSSL
229
230         case SPTPS_CHACHA_POLY1305:
231                 return chacha_poly1305_decrypt(ctx, seqno, in, inlen, out, outlen);
232
233 #else
234
235         case SPTPS_CHACHA_POLY1305:
236         case SPTPS_AES256_GCM: {
237                 if(inlen < 16) {
238                         return false;
239                 }
240
241                 inlen -= 16;
242
243                 uint8_t nonce[12] = {seqno, seqno >> 8, seqno >> 16, seqno >> 24};
244
245                 if(!EVP_DecryptInit_ex(ctx, NULL, NULL, NULL, nonce)) {
246                         return false;
247                 }
248
249                 int outlen1 = 0, outlen2 = 0;
250
251                 if(!EVP_DecryptUpdate(ctx, out, &outlen1, in, (int)inlen)) {
252                         return false;
253                 }
254
255                 if(!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, 16, (void *)(in + inlen))) {
256                         return false;
257                 }
258
259                 if(!EVP_DecryptFinal_ex(ctx, out + outlen1, &outlen2)) {
260                         return false;
261                 }
262
263                 if(outlen) {
264                         *outlen = outlen1 + outlen2;
265                 }
266
267                 return true;
268         }
269
270 #endif
271
272         default:
273                 return false;
274         }
275 }
276
277 // Send a record (datagram version, accepts all record types, handles encryption and authentication).
278 static bool send_record_priv_datagram(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
279         uint8_t *buffer = alloca(len + SPTPS_DATAGRAM_OVERHEAD);
280         // Create header with sequence number, length and record type
281         uint32_t seqno = s->outseqno++;
282
283         memcpy(buffer, &seqno, 4);
284         buffer[4] = type;
285         memcpy(buffer + 5, data, len);
286
287         if(s->outstate) {
288                 // If first handshake has finished, encrypt and HMAC
289                 if(!cipher_encrypt(s->cipher_suite, s->outcipher, seqno, buffer + 4, len + 1, buffer + 4, NULL)) {
290                         return error(s, EINVAL, "Failed to encrypt message");
291                 }
292
293                 return s->send_data(s->handle, type, buffer, len + SPTPS_DATAGRAM_OVERHEAD);
294         } else {
295                 // Otherwise send as plaintext
296                 return s->send_data(s->handle, type, buffer, len + SPTPS_DATAGRAM_HEADER);
297         }
298 }
299 // Send a record (private version, accepts all record types, handles encryption and authentication).
300 static bool send_record_priv(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
301         if(s->datagram) {
302                 return send_record_priv_datagram(s, type, data, len);
303         }
304
305         uint8_t *buffer = alloca(len + SPTPS_OVERHEAD);
306
307         // Create header with sequence number, length and record type
308         uint32_t seqno = s->outseqno++;
309         uint16_t netlen = len;
310
311         memcpy(buffer, &netlen, 2);
312         buffer[2] = type;
313         memcpy(buffer + 3, data, len);
314
315         if(s->outstate) {
316                 // If first handshake has finished, encrypt and HMAC
317                 if(!cipher_encrypt(s->cipher_suite, s->outcipher, seqno, buffer + 2, len + 1, buffer + 2, NULL)) {
318                         return error(s, EINVAL, "Failed to encrypt message");
319                 }
320
321                 return s->send_data(s->handle, type, buffer, len + SPTPS_OVERHEAD);
322         } else {
323                 // Otherwise send as plaintext
324                 return s->send_data(s->handle, type, buffer, len + SPTPS_HEADER);
325         }
326 }
327
328 // Send an application record.
329 bool sptps_send_record(sptps_t *s, uint8_t type, const void *data, uint16_t len) {
330         // Sanity checks: application cannot send data before handshake is finished,
331         // and only record types 0..127 are allowed.
332         if(!s->outstate) {
333                 return error(s, EINVAL, "Handshake phase not finished yet");
334         }
335
336         if(type >= SPTPS_HANDSHAKE) {
337                 return error(s, EINVAL, "Invalid application record type");
338         }
339
340         return send_record_priv(s, type, data, len);
341 }
342
343 // Send a Key EXchange record, containing a random nonce and an ECDHE public key.
344 static bool send_kex(sptps_t *s) {
345         // Make room for our KEX message, which we will keep around since send_sig() needs it.
346         if(s->mykex) {
347                 return false;
348         }
349
350         s->mykex = new_sptps_kex();
351
352         // Set version byte to zero.
353         s->mykex->version = SPTPS_VERSION;
354         s->mykex->preferred_suite = s->preferred_suite;
355         s->mykex->cipher_suites = s->cipher_suites;
356
357         // Create a random nonce.
358         randomize(s->mykex->nonce, ECDH_SIZE);
359
360         // Create a new ECDH public key.
361         if(!(s->ecdh = ecdh_generate_public(s->mykex->pubkey))) {
362                 return error(s, EINVAL, "Failed to generate ECDH public key");
363         }
364
365         return send_record_priv(s, SPTPS_HANDSHAKE, s->mykex, sizeof(sptps_kex_t));
366 }
367
368 static size_t sigmsg_len(size_t labellen) {
369         return 1 + 2 * sizeof(sptps_kex_t) + labellen;
370 }
371
372 static void fill_msg(uint8_t *msg, bool initiator, const sptps_kex_t *kex0, const sptps_kex_t *kex1, const sptps_t *s) {
373         *msg = initiator, msg++;
374         memcpy(msg, kex0, sizeof(*kex0)), msg += sizeof(*kex0);
375         memcpy(msg, kex1, sizeof(*kex1)), msg += sizeof(*kex1);
376         memcpy(msg, s->label, s->labellen);
377 }
378
379 // Send a SIGnature record, containing an Ed25519 signature over both KEX records.
380 static bool send_sig(sptps_t *s) {
381         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator, plus label
382         size_t msglen = sigmsg_len(s->labellen);
383         uint8_t *msg = alloca(msglen);
384         fill_msg(msg, s->initiator, s->mykex, s->hiskex, s);
385
386         // Sign the result.
387         size_t siglen = ecdsa_size(s->mykey);
388         uint8_t *sig = alloca(siglen);
389
390         if(!ecdsa_sign(s->mykey, msg, msglen, sig)) {
391                 return error(s, EINVAL, "Failed to sign SIG record");
392         }
393
394         // Send the SIG exchange record.
395         return send_record_priv(s, SPTPS_HANDSHAKE, sig, siglen);
396 }
397
398 // Generate key material from the shared secret created from the ECDHE key exchange.
399 static bool generate_key_material(sptps_t *s, const uint8_t *shared, size_t len) {
400         // Allocate memory for key material
401         s->key = new_sptps_key();
402
403         // Create the HMAC seed, which is "key expansion" + session label + server nonce + client nonce
404         const size_t msglen = sizeof("key expansion") - 1;
405         const size_t seedlen = msglen + s->labellen + ECDH_SIZE * 2;
406         uint8_t *seed = alloca(seedlen);
407
408         uint8_t *ptr = seed;
409         memcpy(ptr, "key expansion", msglen);
410         ptr += msglen;
411
412         memcpy(ptr, (s->initiator ? s->mykex : s->hiskex)->nonce, ECDH_SIZE);
413         ptr += ECDH_SIZE;
414
415         memcpy(ptr, (s->initiator ? s->hiskex : s->mykex)->nonce, ECDH_SIZE);
416         ptr += ECDH_SIZE;
417
418         memcpy(ptr, s->label, s->labellen);
419
420         // Use PRF to generate the key material
421         if(!prf(shared, len, seed, seedlen, s->key->both, sizeof(sptps_key_t))) {
422                 return error(s, EINVAL, "Failed to generate key material");
423         }
424
425         return true;
426 }
427
428 // Send an ACKnowledgement record.
429 static bool send_ack(sptps_t *s) {
430         return send_record_priv(s, SPTPS_HANDSHAKE, "", 0);
431 }
432
433 // Receive an ACKnowledgement record.
434 static bool receive_ack(sptps_t *s, const uint8_t *data, uint16_t len) {
435         (void)data;
436
437         if(len) {
438                 return error(s, EIO, "Invalid ACK record length");
439         }
440
441         if(!cipher_init(s->cipher_suite, &s->incipher, s->key, s->initiator)) {
442                 return error(s, EINVAL, "Failed to initialize cipher");
443         }
444
445         free_sptps_key(s->key);
446         s->key = NULL;
447         s->instate = true;
448
449         return true;
450 }
451
452 static uint8_t select_cipher_suite(uint16_t mask, uint8_t pref1, uint8_t pref2) {
453         // Check if there is a viable preference, if so select the lowest one
454         uint8_t selection = 255;
455
456         if(mask & (1U << pref1)) {
457                 selection = pref1;
458         }
459
460         if(pref2 < selection && (mask & (1U << pref2))) {
461                 selection = pref2;
462         }
463
464         // Otherwise, select the lowest cipher suite both sides support
465         if(selection == 255) {
466                 selection = 0;
467
468                 while(!(mask & 1U)) {
469                         selection++;
470                         mask >>= 1;
471                 }
472         }
473
474         return selection;
475 }
476
477 // Receive a Key EXchange record, respond by sending a SIG record.
478 static bool receive_kex(sptps_t *s, const uint8_t *data, uint16_t len) {
479         // Verify length of the HELLO record
480
481         if(len != sizeof(sptps_kex_t)) {
482                 return error(s, EIO, "Invalid KEX record length");
483         }
484
485         if(*data != SPTPS_VERSION) {
486                 return error(s, EINVAL, "Received incorrect version %d", *data);
487         }
488
489         uint16_t suites;
490         memcpy(&suites, data + 2, 2);
491         suites &= s->cipher_suites;
492
493         if(!suites) {
494                 return error(s, EIO, "No matching cipher suites");
495         }
496
497         s->cipher_suite = select_cipher_suite(suites, s->preferred_suite, data[1] & 0xf);
498
499         // Make a copy of the KEX message, send_sig() and receive_sig() need it
500         if(s->hiskex) {
501                 return error(s, EINVAL, "Received a second KEX message before first has been processed");
502         }
503
504         s->hiskex = new_sptps_kex();
505         memcpy(s->hiskex, data, sizeof(sptps_kex_t));
506
507         if(s->initiator) {
508                 return send_sig(s);
509         } else {
510                 return true;
511         }
512 }
513
514 // Receive a SIGnature record, verify it, if it passed, compute the shared secret and calculate the session keys.
515 static bool receive_sig(sptps_t *s, const uint8_t *data, uint16_t len) {
516         // Verify length of KEX record.
517         if(len != ecdsa_size(s->hiskey)) {
518                 return error(s, EIO, "Invalid KEX record length");
519         }
520
521         // Concatenate both KEX messages, plus tag indicating if it is from the connection originator
522         const size_t msglen = sigmsg_len(s->labellen);
523         uint8_t *msg = alloca(msglen);
524         fill_msg(msg, !s->initiator, s->hiskex, s->mykex, s);
525
526         // Verify signature.
527         if(!ecdsa_verify(s->hiskey, msg, msglen, data)) {
528                 return error(s, EIO, "Failed to verify SIG record");
529         }
530
531         // Compute shared secret.
532         uint8_t shared[ECDH_SHARED_SIZE];
533
534         if(!ecdh_compute_shared(s->ecdh, s->hiskex->pubkey, shared)) {
535                 memzero(shared, sizeof(shared));
536                 return error(s, EINVAL, "Failed to compute ECDH shared secret");
537         }
538
539         s->ecdh = NULL;
540
541         // Generate key material from shared secret.
542         bool generated = generate_key_material(s, shared, sizeof(shared));
543         memzero(shared, sizeof(shared));
544
545         if(!generated) {
546                 return false;
547         }
548
549         if(!s->initiator && !send_sig(s)) {
550                 return false;
551         }
552
553         free_sptps_kex(s->mykex);
554         s->mykex = NULL;
555
556         free_sptps_kex(s->hiskex);
557         s->hiskex = NULL;
558
559         // Send cipher change record
560         if(s->outstate && !send_ack(s)) {
561                 return false;
562         }
563
564         if(!cipher_init(s->cipher_suite, &s->outcipher, s->key, !s->initiator)) {
565                 return error(s, EINVAL, "Failed to initialize cipher");
566         }
567
568         return true;
569 }
570
571 // Force another Key EXchange (for testing purposes).
572 bool sptps_force_kex(sptps_t *s) {
573         if(!s->outstate || s->state != SPTPS_SECONDARY_KEX) {
574                 return error(s, EINVAL, "Cannot force KEX in current state");
575         }
576
577         s->state = SPTPS_KEX;
578         return send_kex(s);
579 }
580
581 // Receive a handshake record.
582 static bool receive_handshake(sptps_t *s, const uint8_t *data, uint16_t len) {
583         // Only a few states to deal with handshaking.
584         switch(s->state) {
585         case SPTPS_SECONDARY_KEX:
586
587                 // We receive a secondary KEX request, first respond by sending our own.
588                 if(!send_kex(s)) {
589                         return false;
590                 }
591
592         // Fall through
593         case SPTPS_KEX:
594
595                 // We have sent our KEX request, we expect our peer to sent one as well.
596                 if(!receive_kex(s, data, len)) {
597                         return false;
598                 }
599
600                 s->state = SPTPS_SIG;
601                 return true;
602
603         case SPTPS_SIG:
604
605                 // If we already sent our secondary public ECDH key, we expect the peer to send his.
606                 if(!receive_sig(s, data, len)) {
607                         return false;
608                 }
609
610                 if(s->outstate) {
611                         s->state = SPTPS_ACK;
612                 } else {
613                         s->outstate = true;
614
615                         if(!receive_ack(s, NULL, 0)) {
616                                 return false;
617                         }
618
619                         s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
620                         s->state = SPTPS_SECONDARY_KEX;
621                 }
622
623                 return true;
624
625         case SPTPS_ACK:
626
627                 // We expect a handshake message to indicate transition to the new keys.
628                 if(!receive_ack(s, data, len)) {
629                         return false;
630                 }
631
632                 s->receive_record(s->handle, SPTPS_HANDSHAKE, NULL, 0);
633                 s->state = SPTPS_SECONDARY_KEX;
634                 return true;
635
636         // TODO: split ACK into a VERify and ACK?
637         default:
638                 return error(s, EIO, "Invalid session state %d", s->state);
639         }
640 }
641
642 static bool sptps_check_seqno(sptps_t *s, uint32_t seqno, bool update_state) {
643         // Replay protection using a sliding window of configurable size.
644         // s->inseqno is expected sequence number
645         // seqno is received sequence number
646         // s->late[] is a circular buffer, a 1 bit means a packet has not been received yet
647         // The circular buffer contains bits for sequence numbers from s->inseqno - s->replaywin * 8 to (but excluding) s->inseqno.
648         if(s->replaywin) {
649                 if(seqno != s->inseqno) {
650                         if(seqno >= s->inseqno + s->replaywin * 8) {
651                                 // Prevent packets that jump far ahead of the queue from causing many others to be dropped.
652                                 bool farfuture = s->farfuture < s->replaywin >> 2;
653
654                                 if(update_state) {
655                                         s->farfuture++;
656                                 }
657
658                                 if(farfuture) {
659                                         return update_state ? error(s, EIO, "Packet is %d seqs in the future, dropped (%u)\n", seqno - s->inseqno, s->farfuture) : false;
660                                 }
661
662                                 // Unless we have seen lots of them, in which case we consider the others lost.
663                                 if(update_state) {
664                                         warning(s, "Lost %d packets\n", seqno - s->inseqno);
665                                 }
666
667                                 if(update_state) {
668                                         // Mark all packets in the replay window as being late.
669                                         memset(s->late, 255, s->replaywin);
670                                 }
671                         } else if(seqno < s->inseqno) {
672                                 // If the sequence number is farther in the past than the bitmap goes, or if the packet was already received, drop it.
673                                 if((s->inseqno >= s->replaywin * 8 && seqno < s->inseqno - s->replaywin * 8) || !(s->late[(seqno / 8) % s->replaywin] & (1 << seqno % 8))) {
674                                         return update_state ? error(s, EIO, "Received late or replayed packet, seqno %d, last received %d\n", seqno, s->inseqno) : false;
675                                 }
676                         } else if(update_state) {
677                                 // We missed some packets. Mark them in the bitmap as being late.
678                                 for(uint32_t i = s->inseqno; i < seqno; i++) {
679                                         s->late[(i / 8) % s->replaywin] |= 1 << i % 8;
680                                 }
681                         }
682                 }
683
684                 if(update_state) {
685                         // Mark the current packet as not being late.
686                         s->late[(seqno / 8) % s->replaywin] &= ~(1 << seqno % 8);
687                         s->farfuture = 0;
688                 }
689         }
690
691         if(update_state) {
692                 if(seqno >= s->inseqno) {
693                         s->inseqno = seqno + 1;
694                 }
695
696                 if(!s->inseqno) {
697                         s->received = 0;
698                 } else {
699                         s->received++;
700                 }
701         }
702
703         return true;
704 }
705
706 // Check datagram for valid HMAC
707 bool sptps_verify_datagram(sptps_t *s, const void *vdata, size_t len) {
708         if(!s->instate || len < 21) {
709                 return error(s, EIO, "Received short packet");
710         }
711
712         const uint8_t *data = vdata;
713         uint32_t seqno;
714         memcpy(&seqno, data, 4);
715
716         if(!sptps_check_seqno(s, seqno, false)) {
717                 return false;
718         }
719
720         uint8_t *buffer = alloca(len);
721         return cipher_decrypt(s->cipher_suite, s->incipher, seqno, data + 4, len - 4, buffer, NULL);
722 }
723
724 // Receive incoming data, datagram version.
725 static bool sptps_receive_data_datagram(sptps_t *s, const uint8_t *data, size_t len) {
726         if(len < (s->instate ? 21 : 5)) {
727                 return error(s, EIO, "Received short packet");
728         }
729
730         uint32_t seqno;
731         memcpy(&seqno, data, 4);
732         data += 4;
733         len -= 4;
734
735         if(!s->instate) {
736                 if(seqno != s->inseqno) {
737                         return error(s, EIO, "Invalid packet seqno: %d != %d", seqno, s->inseqno);
738                 }
739
740                 s->inseqno = seqno + 1;
741
742                 uint8_t type = *(data++);
743                 len--;
744
745                 if(type != SPTPS_HANDSHAKE) {
746                         return error(s, EIO, "Application record received before handshake finished");
747                 }
748
749                 return receive_handshake(s, data, len);
750         }
751
752         // Decrypt
753
754         uint8_t *buffer = alloca(len);
755         size_t outlen;
756
757         if(!cipher_decrypt(s->cipher_suite, s->incipher, seqno, data, len, buffer, &outlen)) {
758                 return error(s, EIO, "Failed to decrypt and verify packet");
759         }
760
761         if(!sptps_check_seqno(s, seqno, true)) {
762                 return false;
763         }
764
765         // Append a NULL byte for safety.
766         buffer[outlen] = 0;
767
768         data = buffer;
769         len = outlen;
770
771         uint8_t type = *(data++);
772         len--;
773
774         if(type < SPTPS_HANDSHAKE) {
775                 if(!s->instate) {
776                         return error(s, EIO, "Application record received before handshake finished");
777                 }
778
779                 if(!s->receive_record(s->handle, type, data, len)) {
780                         return false;
781                 }
782         } else if(type == SPTPS_HANDSHAKE) {
783                 if(!receive_handshake(s, data, len)) {
784                         return false;
785                 }
786         } else {
787                 return error(s, EIO, "Invalid record type %d", type);
788         }
789
790         return true;
791 }
792
793 // Receive incoming data. Check if it contains a complete record, if so, handle it.
794 size_t sptps_receive_data(sptps_t *s, const void *vdata, size_t len) {
795         const uint8_t *data = vdata;
796         size_t total_read = 0;
797
798         if(!s->state) {
799                 return error(s, EIO, "Invalid session state zero");
800         }
801
802         if(s->datagram) {
803                 return sptps_receive_data_datagram(s, data, len) ? len : false;
804         }
805
806         // First read the 2 length bytes.
807         if(s->buflen < 2) {
808                 size_t toread = 2 - s->buflen;
809
810                 if(toread > len) {
811                         toread = len;
812                 }
813
814                 memcpy(s->inbuf + s->buflen, data, toread);
815
816                 total_read += toread;
817                 s->buflen += toread;
818                 len -= toread;
819                 data += toread;
820
821                 // Exit early if we don't have the full length.
822                 if(s->buflen < 2) {
823                         return total_read;
824                 }
825
826                 // Get the length bytes
827
828                 memcpy(&s->reclen, s->inbuf, 2);
829
830                 // If we have the length bytes, ensure our buffer can hold the whole request.
831                 s->inbuf = realloc(s->inbuf, s->reclen + SPTPS_OVERHEAD);
832
833                 if(!s->inbuf) {
834                         return error(s, errno, "%s", strerror(errno));
835                 }
836
837                 // Exit early if we have no more data to process.
838                 if(!len) {
839                         return total_read;
840                 }
841         }
842
843         // Read up to the end of the record.
844         size_t toread = s->reclen + (s->instate ? SPTPS_OVERHEAD : SPTPS_HEADER) - s->buflen;
845
846         if(toread > len) {
847                 toread = len;
848         }
849
850         memcpy(s->inbuf + s->buflen, data, toread);
851         total_read += toread;
852         s->buflen += toread;
853
854         // If we don't have a whole record, exit.
855         if(s->buflen < s->reclen + (s->instate ? SPTPS_OVERHEAD : SPTPS_HEADER)) {
856                 return total_read;
857         }
858
859         // Update sequence number.
860
861         uint32_t seqno = s->inseqno++;
862
863         // Check HMAC and decrypt.
864         if(s->instate) {
865                 if(!cipher_decrypt(s->cipher_suite, s->incipher, seqno, s->inbuf + 2UL, s->reclen + 17UL, s->inbuf + 2UL, NULL)) {
866                         return error(s, EINVAL, "Failed to decrypt and verify record");
867                 }
868         }
869
870         // Append a NULL byte for safety.
871         s->inbuf[s->reclen + SPTPS_HEADER] = 0;
872
873         uint8_t type = s->inbuf[2];
874
875         if(type < SPTPS_HANDSHAKE) {
876                 if(!s->instate) {
877                         return error(s, EIO, "Application record received before handshake finished");
878                 }
879
880                 if(!s->receive_record(s->handle, type, s->inbuf + 3, s->reclen)) {
881                         return false;
882                 }
883         } else if(type == SPTPS_HANDSHAKE) {
884                 if(!receive_handshake(s, s->inbuf + 3, s->reclen)) {
885                         return false;
886                 }
887         } else {
888                 return error(s, EIO, "Invalid record type %d", type);
889         }
890
891         s->buflen = 0;
892
893         return total_read;
894 }
895
896 // Start a SPTPS session.
897 bool sptps_start(sptps_t *s, const sptps_params_t *params) {
898         // Initialise struct sptps
899         memset(s, 0, sizeof(*s));
900
901         s->handle = params->handle;
902         s->initiator = params->initiator;
903         s->datagram = params->datagram;
904         s->mykey = params->mykey;
905         s->hiskey = params->hiskey;
906         s->replaywin = sptps_replaywin;
907         s->cipher_suites = params->cipher_suites ? params->cipher_suites & SPTPS_ALL_CIPHER_SUITES : SPTPS_ALL_CIPHER_SUITES;
908         s->preferred_suite = params->preferred_suite;
909
910         if(s->replaywin) {
911                 s->late = malloc(s->replaywin);
912
913                 if(!s->late) {
914                         return error(s, errno, "%s", strerror(errno));
915                 }
916
917                 memset(s->late, 0, s->replaywin);
918         }
919
920         s->labellen = params->labellen ? params->labellen : strlen(params->label);
921         s->label = malloc(s->labellen);
922
923         if(!s->label) {
924                 return error(s, errno, "%s", strerror(errno));
925         }
926
927         memcpy(s->label, params->label, s->labellen);
928
929         if(!s->datagram) {
930                 s->inbuf = malloc(7);
931
932                 if(!s->inbuf) {
933                         return error(s, errno, "%s", strerror(errno));
934                 }
935
936                 s->buflen = 0;
937         }
938
939
940         s->send_data = params->send_data;
941         s->receive_record = params->receive_record;
942
943         // Do first KEX immediately
944         s->state = SPTPS_KEX;
945         return send_kex(s);
946 }
947
948 // Stop a SPTPS session.
949 bool sptps_stop(sptps_t *s) {
950         // Clean up any resources.
951         cipher_exit(s->cipher_suite, s->incipher);
952         cipher_exit(s->cipher_suite, s->outcipher);
953         ecdh_free(s->ecdh);
954         free(s->inbuf);
955         free_sptps_kex(s->mykex);
956         free_sptps_kex(s->hiskex);
957         free_sptps_key(s->key);
958         free(s->label);
959         free(s->late);
960         memset(s, 0, sizeof(*s));
961         return true;
962 }