Update the built-in Chacha20-Poly1305 code to an RFC 7539 complaint version.
[tinc] / src / chacha-poly1305 / chacha.c
1 /*
2 chacha-merged.c version 20080118
3 D. J. Bernstein
4 Public domain.
5 */
6
7 #include "chacha.h"
8
9 #define U8C(v) (v##U)
10 #define U32C(v) (v##U)
11
12 #define U8V(v) ((unsigned char)(v) & U8C(0xFF))
13 #define U32V(v) ((uint32_t)(v) & U32C(0xFFFFFFFF))
14
15 #define ROTL32(v, n) \
16         (U32V((v) << (n)) | ((v) >> (32 - (n))))
17
18 #if (USE_UNALIGNED == 1)
19 #define U8TO32_LITTLE(p) \
20         (*((uint32_t *)(p)))
21 #define U32TO8_LITTLE(p, v) \
22         do { \
23                 *((uint32_t *)(p)) = v; \
24         } while (0)
25 #else
26 #define U8TO32_LITTLE(p) \
27         (((uint32_t)((p)[0])      ) | \
28          ((uint32_t)((p)[1]) <<  8) | \
29          ((uint32_t)((p)[2]) << 16) | \
30          ((uint32_t)((p)[3]) << 24))
31 #define U32TO8_LITTLE(p, v) \
32         do { \
33                 (p)[0] = U8V((v)      ); \
34                 (p)[1] = U8V((v) >>  8); \
35                 (p)[2] = U8V((v) >> 16); \
36                 (p)[3] = U8V((v) >> 24); \
37         } while (0)
38 #endif
39
40 #define ROTATE(v,c) (ROTL32(v,c))
41 #define XOR(v,w) ((v) ^ (w))
42 #define PLUS(v,w) (U32V((v) + (w)))
43 #define PLUSONE(v) (PLUS((v),1))
44
45 #define QUARTERROUND(a,b,c,d) \
46         a = PLUS(a,b); d = ROTATE(XOR(d,a),16); \
47         c = PLUS(c,d); b = ROTATE(XOR(b,c),12); \
48         a = PLUS(a,b); d = ROTATE(XOR(d,a), 8); \
49         c = PLUS(c,d); b = ROTATE(XOR(b,c), 7);
50
51 static const char sigma[16] = "expand 32-byte k";
52 static const char tau[16] = "expand 16-byte k";
53
54 void
55 chacha_keysetup(struct chacha_ctx *x, const unsigned char *k, uint32_t kbits) {
56         const char *constants;
57
58         x->input[4] = U8TO32_LITTLE(k + 0);
59         x->input[5] = U8TO32_LITTLE(k + 4);
60         x->input[6] = U8TO32_LITTLE(k + 8);
61         x->input[7] = U8TO32_LITTLE(k + 12);
62
63         if(kbits == 256) {  /* recommended */
64                 k += 16;
65                 constants = sigma;
66         } else { /* kbits == 128 */
67                 constants = tau;
68         }
69
70         x->input[8] = U8TO32_LITTLE(k + 0);
71         x->input[9] = U8TO32_LITTLE(k + 4);
72         x->input[10] = U8TO32_LITTLE(k + 8);
73         x->input[11] = U8TO32_LITTLE(k + 12);
74         x->input[0] = U8TO32_LITTLE(constants + 0);
75         x->input[1] = U8TO32_LITTLE(constants + 4);
76         x->input[2] = U8TO32_LITTLE(constants + 8);
77         x->input[3] = U8TO32_LITTLE(constants + 12);
78 }
79
80 void
81 chacha_ivsetup(struct chacha_ctx *x, const unsigned char *iv, const unsigned char *counter) {
82         x->input[12] = counter == NULL ? 0 : U8TO32_LITTLE(counter + 0);
83         //x->input[13] = counter == NULL ? 0 : U8TO32_LITTLE(counter + 4);
84         x->input[13] = U8TO32_LITTLE(iv + 0);
85         x->input[14] = U8TO32_LITTLE(iv + 4);
86         x->input[15] = U8TO32_LITTLE(iv + 8);
87 }
88
89 void
90 chacha_encrypt_bytes(struct chacha_ctx *x, const unsigned char *m, unsigned char *c, uint32_t bytes) {
91         uint32_t x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15;
92         uint32_t j0, j1, j2, j3, j4, j5, j6, j7, j8, j9, j10, j11, j12, j13, j14, j15;
93         unsigned char *ctarget = NULL;
94         unsigned char tmp[64];
95         uint32_t i;
96
97         if(!bytes) {
98                 return;
99         }
100
101         j0 = x->input[0];
102         j1 = x->input[1];
103         j2 = x->input[2];
104         j3 = x->input[3];
105         j4 = x->input[4];
106         j5 = x->input[5];
107         j6 = x->input[6];
108         j7 = x->input[7];
109         j8 = x->input[8];
110         j9 = x->input[9];
111         j10 = x->input[10];
112         j11 = x->input[11];
113         j12 = x->input[12];
114         j13 = x->input[13];
115         j14 = x->input[14];
116         j15 = x->input[15];
117
118         for(;;) {
119                 if(bytes < 64) {
120 #if (USE_MEMCPY == 1)
121                         memcpy(tmp, m, bytes);
122 #else
123
124                         for(i = 0; i < bytes; ++i) {
125                                 tmp[i] = m[i];
126                         }
127
128 #endif
129                         m = tmp;
130                         ctarget = c;
131                         c = tmp;
132                 }
133
134                 x0 = j0;
135                 x1 = j1;
136                 x2 = j2;
137                 x3 = j3;
138                 x4 = j4;
139                 x5 = j5;
140                 x6 = j6;
141                 x7 = j7;
142                 x8 = j8;
143                 x9 = j9;
144                 x10 = j10;
145                 x11 = j11;
146                 x12 = j12;
147                 x13 = j13;
148                 x14 = j14;
149                 x15 = j15;
150
151                 for(i = 20; i > 0; i -= 2) {
152                         QUARTERROUND(x0, x4, x8, x12)
153                         QUARTERROUND(x1, x5, x9, x13)
154                         QUARTERROUND(x2, x6, x10, x14)
155                         QUARTERROUND(x3, x7, x11, x15)
156                         QUARTERROUND(x0, x5, x10, x15)
157                         QUARTERROUND(x1, x6, x11, x12)
158                         QUARTERROUND(x2, x7, x8, x13)
159                         QUARTERROUND(x3, x4, x9, x14)
160                 }
161
162                 x0 = PLUS(x0, j0);
163                 x1 = PLUS(x1, j1);
164                 x2 = PLUS(x2, j2);
165                 x3 = PLUS(x3, j3);
166                 x4 = PLUS(x4, j4);
167                 x5 = PLUS(x5, j5);
168                 x6 = PLUS(x6, j6);
169                 x7 = PLUS(x7, j7);
170                 x8 = PLUS(x8, j8);
171                 x9 = PLUS(x9, j9);
172                 x10 = PLUS(x10, j10);
173                 x11 = PLUS(x11, j11);
174                 x12 = PLUS(x12, j12);
175                 x13 = PLUS(x13, j13);
176                 x14 = PLUS(x14, j14);
177                 x15 = PLUS(x15, j15);
178
179                 x0 = XOR(x0, U8TO32_LITTLE(m + 0));
180                 x1 = XOR(x1, U8TO32_LITTLE(m + 4));
181                 x2 = XOR(x2, U8TO32_LITTLE(m + 8));
182                 x3 = XOR(x3, U8TO32_LITTLE(m + 12));
183                 x4 = XOR(x4, U8TO32_LITTLE(m + 16));
184                 x5 = XOR(x5, U8TO32_LITTLE(m + 20));
185                 x6 = XOR(x6, U8TO32_LITTLE(m + 24));
186                 x7 = XOR(x7, U8TO32_LITTLE(m + 28));
187                 x8 = XOR(x8, U8TO32_LITTLE(m + 32));
188                 x9 = XOR(x9, U8TO32_LITTLE(m + 36));
189                 x10 = XOR(x10, U8TO32_LITTLE(m + 40));
190                 x11 = XOR(x11, U8TO32_LITTLE(m + 44));
191                 x12 = XOR(x12, U8TO32_LITTLE(m + 48));
192                 x13 = XOR(x13, U8TO32_LITTLE(m + 52));
193                 x14 = XOR(x14, U8TO32_LITTLE(m + 56));
194                 x15 = XOR(x15, U8TO32_LITTLE(m + 60));
195
196                 j12 = PLUSONE(j12);
197
198                 if(!j12) {
199                         j13 = PLUSONE(j13);
200                         /* stopping at 2^70 bytes per nonce is user's responsibility */
201                 }
202
203                 U32TO8_LITTLE(c + 0, x0);
204                 U32TO8_LITTLE(c + 4, x1);
205                 U32TO8_LITTLE(c + 8, x2);
206                 U32TO8_LITTLE(c + 12, x3);
207                 U32TO8_LITTLE(c + 16, x4);
208                 U32TO8_LITTLE(c + 20, x5);
209                 U32TO8_LITTLE(c + 24, x6);
210                 U32TO8_LITTLE(c + 28, x7);
211                 U32TO8_LITTLE(c + 32, x8);
212                 U32TO8_LITTLE(c + 36, x9);
213                 U32TO8_LITTLE(c + 40, x10);
214                 U32TO8_LITTLE(c + 44, x11);
215                 U32TO8_LITTLE(c + 48, x12);
216                 U32TO8_LITTLE(c + 52, x13);
217                 U32TO8_LITTLE(c + 56, x14);
218                 U32TO8_LITTLE(c + 60, x15);
219
220                 if(bytes <= 64) {
221                         if(bytes < 64) {
222 #if (USE_MEMCPY == 1)
223                                 memcpy(ctarget, c, bytes);
224 #else
225
226                                 for(i = 0; i < bytes; ++i) {
227                                         ctarget[i] = c[i];
228                                 }
229
230 #endif
231                         }
232
233                         x->input[12] = j12;
234                         x->input[13] = j13;
235                         return;
236                 }
237
238                 bytes -= 64;
239                 c += 64;
240                 m += 64;
241         }
242 }