Wipe (some) secrets from memory after use
[tinc] / src / sptps_test.c
1 /*
2     sptps_test.c -- Simple Peer-to-Peer Security test program
3     Copyright (C) 2011-2022 Guus Sliepen <guus@tinc-vpn.org>
4
5     This program is free software; you can redistribute it and/or modify
6     it under the terms of the GNU General Public License as published by
7     the Free Software Foundation; either version 2 of the License, or
8     (at your option) any later version.
9
10     This program is distributed in the hope that it will be useful,
11     but WITHOUT ANY WARRANTY; without even the implied warranty of
12     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13     GNU General Public License for more details.
14
15     You should have received a copy of the GNU General Public License along
16     with this program; if not, write to the Free Software Foundation, Inc.,
17     51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19
20 #include "system.h"
21
22 #ifdef HAVE_LINUX
23 #include <linux/if_tun.h>
24 #endif
25
26 #include "crypto.h"
27 #include "ecdsa.h"
28 #include "meta.h"
29 #include "protocol.h"
30 #include "sptps.h"
31 #include "utils.h"
32 #include "names.h"
33 #include "random.h"
34
35 #ifndef HAVE_WINDOWS
36 #define closesocket(s) close(s)
37 #endif
38
39 // Symbols necessary to link with logger.o
40 bool send_request(struct connection_t *c, const char *msg, ...) {
41         (void)c;
42         (void)msg;
43         return false;
44 }
45
46 list_t connection_list;
47
48 bool send_meta(struct connection_t *c, const void *msg, size_t len) {
49         (void)c;
50         (void)msg;
51         (void)len;
52         return false;
53 }
54
55 bool do_detach = false;
56 struct timeval now;
57
58 static bool special;
59 static bool verbose;
60 static bool readonly;
61 static bool writeonly;
62 static int in = 0;
63 static int out = 1;
64 int addressfamily = AF_UNSPEC;
65
66 static bool send_data(void *handle, uint8_t type, const void *data, size_t len) {
67         (void)type;
68         char *hex = alloca(len * 2 + 1);
69         bin2hex(data, hex, len);
70
71         if(verbose) {
72                 fprintf(stderr, "Sending %lu bytes of data:\n%s\n", (unsigned long)len, hex);
73         }
74
75         const int *sock = handle;
76         const char *p = data;
77
78         while(len) {
79                 ssize_t sent = send(*sock, p, len, 0);
80
81                 if(sent <= 0) {
82                         fprintf(stderr, "Error sending data: %s\n", strerror(errno));
83                         return false;
84                 }
85
86                 p += sent;
87                 len -= sent;
88         }
89
90         return true;
91 }
92
93 static bool receive_record(void *handle, uint8_t type, const void *data, uint16_t len) {
94         (void)handle;
95
96         if(verbose) {
97                 fprintf(stderr, "Received type %d record of %u bytes:\n", type, len);
98         }
99
100         if(writeonly) {
101                 return true;
102         }
103
104         const char *p = data;
105
106         while(len) {
107                 ssize_t written = write(out, p, len);
108
109                 if(written <= 0) {
110                         fprintf(stderr, "Error writing received data: %s\n", strerror(errno));
111                         return false;
112                 }
113
114                 p += written;
115                 len -= written;
116         }
117
118         return true;
119 }
120
121 static struct option const long_options[] = {
122         {"datagram", no_argument, NULL, 'd'},
123         {"quit", no_argument, NULL, 'q'},
124         {"readonly", no_argument, NULL, 'r'},
125         {"writeonly", no_argument, NULL, 'w'},
126         {"packet-loss", required_argument, NULL, 'L'},
127         {"replay-window", required_argument, NULL, 'W'},
128         {"special", no_argument, NULL, 's'},
129         {"verbose", required_argument, NULL, 'v'},
130         {"help", no_argument, NULL, 1},
131         {NULL, 0, NULL, 0}
132 };
133
134 static void usage(void) {
135         static const char *message =
136                 "Usage: %s [options] my_ed25519_key_file his_ed25519_key_file [host] port\n"
137                 "\n"
138                 "Valid options are:\n"
139                 "  -d, --datagram          Enable datagram mode.\n"
140                 "  -q, --quit              Quit when EOF occurs on stdin.\n"
141                 "  -r, --readonly          Only send data from the socket to stdout.\n"
142 #ifdef HAVE_LINUX
143                 "  -t, --tun               Use a tun device instead of stdio.\n"
144 #endif
145                 "  -w, --writeonly         Only send data from stdin to the socket.\n"
146                 "  -L, --packet-loss RATE  Fake packet loss of RATE percent.\n"
147                 "  -R, --replay-window N   Set replay window to N bytes.\n"
148                 "  -s, --special           Enable special handling of lines starting with #, ^ and $.\n"
149                 "  -v, --verbose           Display debug messages.\n"
150                 "  -4                      Use IPv4.\n"
151                 "  -6                      Use IPv6.\n"
152                 "\n"
153                 "Report bugs to tinc@tinc-vpn.org.\n";
154
155         fprintf(stderr, message, program_name);
156 }
157
158 #ifdef HAVE_WINDOWS
159
160 int stdin_sock_fd = -1;
161
162 // Windows does not allow calling select() on anything but sockets. Therefore,
163 // to keep the same code as on other operating systems, we have to put a
164 // separate thread between the stdin and the sptps loop way below. This thread
165 // reads stdin and sends its content to the main thread through a TCP socket,
166 // which can be properly select()'ed.
167 static DWORD WINAPI stdin_reader_thread(LPVOID arg) {
168         struct sockaddr_in sa;
169         socklen_t sa_size = sizeof(sa);
170
171         while(true) {
172                 int peer_fd = accept(stdin_sock_fd, (struct sockaddr *) &sa, &sa_size);
173
174                 if(peer_fd < 0) {
175                         fprintf(stderr, "accept() failed: %s\n", strerror(errno));
176                         continue;
177                 }
178
179                 if(verbose) {
180                         fprintf(stderr, "New connection received from :%d\n", ntohs(sa.sin_port));
181                 }
182
183                 char buf[1024];
184                 ssize_t nread;
185
186                 while((nread = read(STDIN_FILENO, buf, sizeof(buf))) > 0) {
187                         if(verbose) {
188                                 fprintf(stderr, "Read %lld bytes from input\n", nread);
189                         }
190
191                         char *start = buf;
192                         ssize_t nleft = nread;
193
194                         while(nleft) {
195                                 ssize_t nsend = send(peer_fd, start, nleft, 0);
196
197                                 if(nsend < 0) {
198                                         if(sockwouldblock(sockerrno)) {
199                                                 continue;
200                                         }
201
202                                         break;
203                                 }
204
205                                 start += nsend;
206                                 nleft -= nsend;
207                         }
208
209                         if(nleft) {
210                                 fprintf(stderr, "Could not send data: %s\n", strerror(errno));
211                                 break;
212                         }
213
214                         if(verbose) {
215                                 fprintf(stderr, "Sent %lld bytes to peer\n", nread);
216                         }
217                 }
218
219                 closesocket(peer_fd);
220         }
221
222         closesocket(stdin_sock_fd);
223         stdin_sock_fd = -1;
224         return 0;
225 }
226
227 static int start_input_reader(void) {
228         if(stdin_sock_fd != -1) {
229                 fprintf(stderr, "stdin thread can only be started once.\n");
230                 return -1;
231         }
232
233         stdin_sock_fd = socket(AF_INET, SOCK_STREAM, 0);
234
235         if(stdin_sock_fd < 0) {
236                 fprintf(stderr, "Could not create server socket: %s\n", strerror(errno));
237                 return -1;
238         }
239
240         struct sockaddr_in serv_sa;
241
242         memset(&serv_sa, 0, sizeof(serv_sa));
243
244         serv_sa.sin_family = AF_INET;
245
246         serv_sa.sin_addr.s_addr = htonl(0x7f000001); // 127.0.0.1
247
248         int res = bind(stdin_sock_fd, (struct sockaddr *)&serv_sa, sizeof(serv_sa));
249
250         if(res < 0) {
251                 fprintf(stderr, "Could not bind socket: %s\n", strerror(errno));
252                 goto server_err;
253         }
254
255         if(listen(stdin_sock_fd, 1) < 0) {
256                 fprintf(stderr, "Could not listen: %s\n", strerror(errno));
257                 goto server_err;
258         }
259
260         struct sockaddr_in connect_sa;
261
262         socklen_t addr_len = sizeof(connect_sa);
263
264         if(getsockname(stdin_sock_fd, (struct sockaddr *)&connect_sa, &addr_len) < 0) {
265                 fprintf(stderr, "Could not determine the address of the stdin thread socket\n");
266                 goto server_err;
267         }
268
269         if(verbose) {
270                 fprintf(stderr, "stdin thread is listening on :%d\n", ntohs(connect_sa.sin_port));
271         }
272
273         if(!CreateThread(NULL, 0, stdin_reader_thread, NULL, 0, NULL)) {
274                 fprintf(stderr, "Could not start reader thread: %d\n", GetLastError());
275                 goto server_err;
276         }
277
278         int client_fd = socket(AF_INET, SOCK_STREAM, 0);
279
280         if(client_fd < 0) {
281                 fprintf(stderr, "Could not create client socket: %s\n", strerror(errno));
282                 return -1;
283         }
284
285         if(connect(client_fd, (struct sockaddr *)&connect_sa, sizeof(connect_sa)) < 0) {
286                 fprintf(stderr, "Could not connect: %s\n", strerror(errno));
287                 closesocket(client_fd);
288                 return -1;
289         }
290
291         return client_fd;
292
293 server_err:
294
295         if(stdin_sock_fd != -1) {
296                 closesocket(stdin_sock_fd);
297                 stdin_sock_fd = -1;
298         }
299
300         return -1;
301 }
302
303 #endif // HAVE_WINDOWS
304
305 static void print_listening_msg(int sock) {
306         sockaddr_t sa = {0};
307         socklen_t salen = sizeof(sa);
308         int port = 0;
309
310         if(!getsockname(sock, &sa.sa, &salen)) {
311                 port = ntohs(sa.in.sin_port);
312         }
313
314         fprintf(stderr, "Listening on %d...\n", port);
315         fflush(stderr);
316 }
317
318 static int run_test(int argc, char *argv[]) {
319         program_name = argv[0];
320         bool initiator = false;
321         bool datagram = false;
322 #ifdef HAVE_LINUX
323         bool tun = false;
324 #endif
325         int packetloss = 0;
326         int r;
327         int option_index = 0;
328         bool quit = false;
329
330         while((r = getopt_long(argc, argv, "dqrstwL:W:v46", long_options, &option_index)) != EOF) {
331                 switch(r) {
332                 case 0:   /* long option */
333                         break;
334
335                 case 'd': /* datagram mode */
336                         datagram = true;
337                         break;
338
339                 case 'q': /* close connection on EOF from stdin */
340                         quit = true;
341                         break;
342
343                 case 'r': /* read only */
344                         readonly = true;
345                         break;
346
347                 case 't': /* read only */
348 #ifdef HAVE_LINUX
349                         tun = true;
350 #else
351                         fprintf(stderr, "--tun is only supported on Linux.\n");
352                         usage();
353                         return 1;
354 #endif
355                         break;
356
357                 case 'w': /* write only */
358                         writeonly = true;
359                         break;
360
361                 case 'L': /* packet loss rate */
362                         packetloss = atoi(optarg);
363                         break;
364
365                 case 'W': /* replay window size */
366                         sptps_replaywin = atoi(optarg);
367                         break;
368
369                 case 'v': /* be verbose */
370                         verbose = true;
371                         break;
372
373                 case 's': /* special character handling */
374                         special = true;
375                         break;
376
377                 case '?': /* wrong options */
378                         usage();
379                         return 1;
380
381                 case '4': /* IPv4 */
382                         addressfamily = AF_INET;
383                         break;
384
385                 case '6': /* IPv6 */
386                         addressfamily = AF_INET6;
387                         break;
388
389                 case 1: /* help */
390                         usage();
391                         return 0;
392
393                 default:
394                         break;
395                 }
396         }
397
398         argc -= optind - 1;
399         argv += optind - 1;
400
401         if(argc < 4 || argc > 5) {
402                 fprintf(stderr, "Wrong number of arguments.\n");
403                 usage();
404                 return 1;
405         }
406
407         if(argc > 4) {
408                 initiator = true;
409         }
410
411 #ifdef HAVE_LINUX
412
413         if(tun) {
414                 in = out = open("/dev/net/tun", O_RDWR | O_NONBLOCK);
415
416                 if(in < 0) {
417                         fprintf(stderr, "Could not open tun device: %s\n", strerror(errno));
418                         return 1;
419                 }
420
421                 struct ifreq ifr = {
422                         .ifr_flags = IFF_TUN
423                 };
424
425                 if(ioctl(in, TUNSETIFF, &ifr)) {
426                         fprintf(stderr, "Could not configure tun interface: %s\n", strerror(errno));
427                         return 1;
428                 }
429
430                 ifr.ifr_name[IFNAMSIZ - 1] = 0;
431                 fprintf(stderr, "Using tun interface %s\n", ifr.ifr_name);
432         }
433
434 #endif
435
436 #ifdef HAVE_WINDOWS
437         static struct WSAData wsa_state;
438
439         if(WSAStartup(MAKEWORD(2, 2), &wsa_state)) {
440                 return 1;
441         }
442
443 #endif
444
445         struct addrinfo *ai, hint;
446         memset(&hint, 0, sizeof(hint));
447
448         hint.ai_family = addressfamily;
449         hint.ai_socktype = datagram ? SOCK_DGRAM : SOCK_STREAM;
450         hint.ai_protocol = datagram ? IPPROTO_UDP : IPPROTO_TCP;
451         hint.ai_flags = initiator ? 0 : AI_PASSIVE;
452
453         if(getaddrinfo(initiator ? argv[3] : NULL, initiator ? argv[4] : argv[3], &hint, &ai) || !ai) {
454                 fprintf(stderr, "getaddrinfo() failed: %s\n", sockstrerror(sockerrno));
455                 return 1;
456         }
457
458         int sock = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
459
460         if(sock < 0) {
461                 fprintf(stderr, "Could not create socket: %s\n", sockstrerror(sockerrno));
462                 freeaddrinfo(ai);
463                 return 1;
464         }
465
466         int one = 1;
467         setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (void *)&one, sizeof(one));
468
469         if(initiator) {
470                 int res = connect(sock, ai->ai_addr, ai->ai_addrlen);
471
472                 freeaddrinfo(ai);
473                 ai = NULL;
474
475                 if(res) {
476                         fprintf(stderr, "Could not connect to peer: %s\n", sockstrerror(sockerrno));
477                         return 1;
478                 }
479
480                 fprintf(stderr, "Connected\n");
481         } else {
482                 int res = bind(sock, ai->ai_addr, ai->ai_addrlen);
483
484                 freeaddrinfo(ai);
485                 ai = NULL;
486
487                 if(res) {
488                         fprintf(stderr, "Could not bind socket: %s\n", sockstrerror(sockerrno));
489                         return 1;
490                 }
491
492                 if(!datagram) {
493                         if(listen(sock, 1)) {
494                                 fprintf(stderr, "Could not listen on socket: %s\n", sockstrerror(sockerrno));
495                                 return 1;
496                         }
497
498                         print_listening_msg(sock);
499
500                         sock = accept(sock, NULL, NULL);
501
502                         if(sock < 0) {
503                                 fprintf(stderr, "Could not accept connection: %s\n", sockstrerror(sockerrno));
504                                 return 1;
505                         }
506                 } else {
507                         print_listening_msg(sock);
508
509                         char buf[65536];
510                         struct sockaddr addr;
511                         socklen_t addrlen = sizeof(addr);
512
513                         if(recvfrom(sock, buf, sizeof(buf), MSG_PEEK, &addr, &addrlen) <= 0) {
514                                 fprintf(stderr, "Could not read from socket: %s\n", sockstrerror(sockerrno));
515                                 return 1;
516                         }
517
518                         if(connect(sock, &addr, addrlen)) {
519                                 fprintf(stderr, "Could not accept connection: %s\n", sockstrerror(sockerrno));
520                                 return 1;
521                         }
522                 }
523
524                 fprintf(stderr, "Connected\n");
525         }
526
527         FILE *fp = fopen(argv[1], "r");
528
529         if(!fp) {
530                 fprintf(stderr, "Could not open %s: %s\n", argv[1], strerror(errno));
531                 return 1;
532         }
533
534         ecdsa_t *mykey = NULL;
535
536         if(!(mykey = ecdsa_read_pem_private_key(fp))) {
537                 return 1;
538         }
539
540         fclose(fp);
541
542         fp = fopen(argv[2], "r");
543
544         if(!fp) {
545                 fprintf(stderr, "Could not open %s: %s\n", argv[2], strerror(errno));
546                 ecdsa_free(mykey);
547                 return 1;
548         }
549
550         ecdsa_t *hiskey = NULL;
551
552         if(!(hiskey = ecdsa_read_pem_public_key(fp))) {
553                 ecdsa_free(mykey);
554                 return 1;
555         }
556
557         fclose(fp);
558
559         if(verbose) {
560                 fprintf(stderr, "Keys loaded\n");
561         }
562
563         sptps_t s;
564
565         if(!sptps_start(&s, &sock, initiator, datagram, mykey, hiskey, "sptps_test", 10, send_data, receive_record)) {
566                 ecdsa_free(mykey);
567                 ecdsa_free(hiskey);
568                 return 1;
569         }
570
571 #ifdef HAVE_WINDOWS
572
573         if(!readonly) {
574                 in = start_input_reader();
575
576                 if(in < 0) {
577                         fprintf(stderr, "Could not init stdin reader thread\n");
578                         ecdsa_free(mykey);
579                         ecdsa_free(hiskey);
580                         return 1;
581                 }
582         }
583
584 #endif
585
586         int max_fd = MAX(sock, in);
587
588         while(true) {
589                 if(writeonly && readonly) {
590                         break;
591                 }
592
593                 char buf[65535] = "";
594                 size_t readsize = datagram ? 1460u : sizeof(buf);
595
596                 fd_set fds;
597                 FD_ZERO(&fds);
598
599                 if(!readonly && s.instate) {
600                         FD_SET(in, &fds);
601                 }
602
603                 FD_SET(sock, &fds);
604
605                 if(select(max_fd + 1, &fds, NULL, NULL, NULL) <= 0) {
606                         ecdsa_free(mykey);
607                         ecdsa_free(hiskey);
608                         return 1;
609                 }
610
611                 if(FD_ISSET(in, &fds)) {
612 #ifdef HAVE_WINDOWS
613                         ssize_t len = recv(in, buf, readsize, 0);
614 #else
615                         ssize_t len = read(in, buf, readsize);
616 #endif
617
618                         if(len < 0) {
619                                 fprintf(stderr, "Could not read from stdin: %s\n", strerror(errno));
620                                 ecdsa_free(mykey);
621                                 ecdsa_free(hiskey);
622                                 return 1;
623                         }
624
625                         if(len == 0) {
626 #ifdef HAVE_WINDOWS
627                                 shutdown(in, SD_SEND);
628                                 closesocket(in);
629 #endif
630
631                                 if(quit) {
632                                         break;
633                                 }
634
635                                 readonly = true;
636                                 continue;
637                         }
638
639                         if(special && buf[0] == '#') {
640                                 s.outseqno = atoi(buf + 1);
641                         }
642
643                         if(special && buf[0] == '^') {
644                                 sptps_send_record(&s, SPTPS_HANDSHAKE, NULL, 0);
645                         } else if(special && buf[0] == '$') {
646                                 sptps_force_kex(&s);
647
648                                 if(len > 1) {
649                                         sptps_send_record(&s, 0, buf, len);
650                                 }
651                         } else if(!sptps_send_record(&s, buf[0] == '!' ? 1 : 0, buf, (len == 1 && buf[0] == '\n') ? 0 : buf[0] == '*' ? sizeof(buf) : (size_t)len)) {
652                                 ecdsa_free(mykey);
653                                 ecdsa_free(hiskey);
654                                 return 1;
655                         }
656                 }
657
658                 if(FD_ISSET(sock, &fds)) {
659                         ssize_t len = recv(sock, buf, sizeof(buf), 0);
660
661                         if(len < 0) {
662                                 fprintf(stderr, "Could not read from socket: %s\n", sockstrerror(sockerrno));
663                                 ecdsa_free(mykey);
664                                 ecdsa_free(hiskey);
665                                 return 1;
666                         }
667
668                         if(len == 0) {
669                                 fprintf(stderr, "Connection terminated by peer.\n");
670                                 break;
671                         }
672
673                         if(verbose) {
674                                 char *hex = alloca(len * 2 + 1);
675                                 bin2hex(buf, hex, len);
676                                 fprintf(stderr, "Received %ld bytes of data:\n%s\n", (long)len, hex);
677                         }
678
679                         if(packetloss && (int)prng(100) < packetloss) {
680                                 if(verbose) {
681                                         fprintf(stderr, "Dropped.\n");
682                                 }
683
684                                 continue;
685                         }
686
687                         char *bufp = buf;
688
689                         while(len) {
690                                 size_t done = sptps_receive_data(&s, bufp, len);
691
692                                 if(!done) {
693                                         if(!datagram) {
694                                                 ecdsa_free(mykey);
695                                                 ecdsa_free(hiskey);
696                                                 return 1;
697                                         }
698                                 }
699
700                                 bufp += done;
701                                 len -= (ssize_t) done;
702                         }
703                 }
704         }
705
706         bool stopped = sptps_stop(&s);
707
708         ecdsa_free(mykey);
709         ecdsa_free(hiskey);
710         closesocket(sock);
711
712         return !stopped;
713 }
714
715 int main(int argc, char *argv[]) {
716         random_init();
717         crypto_init();
718         prng_init();
719
720         int result = run_test(argc, argv);
721
722         random_exit();
723
724         return result;
725 }