Update the built-in Chacha20-Poly1305 code to an RFC 7539 complaint version.
[tinc] / src / chacha-poly1305 / poly1305.c
1 /*
2 poly1305 implementation using 32 bit * 32 bit = 64 bit multiplication and 64 bit addition
3 public domain
4 */
5
6 #include "poly1305.h"
7
8 #if (USE_UNALIGNED == 1)
9 #define U8TO32(p) \
10         (*((uint32_t *)(p)))
11 #define U32TO8(p, v) \
12         do { \
13                 *((uint32_t *)(p)) = v; \
14         } while (0)
15 #else
16 /* interpret four 8 bit unsigned integers as a 32 bit unsigned integer in little endian */
17 static uint32_t
18 U8TO32(const unsigned char *p) {
19         return
20                 (((uint32_t)(p[0] & 0xff)) |
21                  ((uint32_t)(p[1] & 0xff) <<  8) |
22                  ((uint32_t)(p[2] & 0xff) << 16) |
23                  ((uint32_t)(p[3] & 0xff) << 24));
24 }
25
26 /* store a 32 bit unsigned integer as four 8 bit unsigned integers in little endian */
27 static void
28 U32TO8(unsigned char *p, uint32_t v) {
29         p[0] = (v) & 0xff;
30         p[1] = (v >>  8) & 0xff;
31         p[2] = (v >> 16) & 0xff;
32         p[3] = (v >> 24) & 0xff;
33 }
34 #endif
35
36 void
37 poly1305_init(struct poly1305_context *st, const unsigned char key[32]) {
38         /* r &= 0xffffffc0ffffffc0ffffffc0fffffff */
39         st->r[0] = (U8TO32(&key[ 0])) & 0x3ffffff;
40         st->r[1] = (U8TO32(&key[ 3]) >> 2) & 0x3ffff03;
41         st->r[2] = (U8TO32(&key[ 6]) >> 4) & 0x3ffc0ff;
42         st->r[3] = (U8TO32(&key[ 9]) >> 6) & 0x3f03fff;
43         st->r[4] = (U8TO32(&key[12]) >> 8) & 0x00fffff;
44
45         /* h = 0 */
46         st->h[0] = 0;
47         st->h[1] = 0;
48         st->h[2] = 0;
49         st->h[3] = 0;
50         st->h[4] = 0;
51
52         /* save pad for later */
53         st->pad[0] = U8TO32(&key[16]);
54         st->pad[1] = U8TO32(&key[20]);
55         st->pad[2] = U8TO32(&key[24]);
56         st->pad[3] = U8TO32(&key[28]);
57
58         st->leftover = 0;
59         st->final = 0;
60 }
61
62 static void
63 poly1305_blocks(struct poly1305_context *st, const unsigned char *m, size_t bytes) {
64         const uint32_t hibit = (st->final) ? 0 : (1 << 24); /* 1 << 128 */
65         uint32_t r0, r1, r2, r3, r4;
66         uint32_t s1, s2, s3, s4;
67         uint32_t h0, h1, h2, h3, h4;
68         uint64_t d0, d1, d2, d3, d4;
69         uint32_t c;
70
71         r0 = st->r[0];
72         r1 = st->r[1];
73         r2 = st->r[2];
74         r3 = st->r[3];
75         r4 = st->r[4];
76
77         s1 = r1 * 5;
78         s2 = r2 * 5;
79         s3 = r3 * 5;
80         s4 = r4 * 5;
81
82         h0 = st->h[0];
83         h1 = st->h[1];
84         h2 = st->h[2];
85         h3 = st->h[3];
86         h4 = st->h[4];
87
88         while(bytes >= POLY1305_BLOCK_SIZE) {
89                 /* h += m[i] */
90                 h0 += (U8TO32(m + 0)) & 0x3ffffff;
91                 h1 += (U8TO32(m + 3) >> 2) & 0x3ffffff;
92                 h2 += (U8TO32(m + 6) >> 4) & 0x3ffffff;
93                 h3 += (U8TO32(m + 9) >> 6) & 0x3ffffff;
94                 h4 += (U8TO32(m + 12) >> 8) | hibit;
95
96                 /* h *= r */
97                 d0 = ((uint64_t)h0 * r0) + ((uint64_t)h1 * s4) + ((uint64_t)h2 * s3) + ((uint64_t)h3 * s2) + ((uint64_t)h4 * s1);
98                 d1 = ((uint64_t)h0 * r1) + ((uint64_t)h1 * r0) + ((uint64_t)h2 * s4) + ((uint64_t)h3 * s3) + ((uint64_t)h4 * s2);
99                 d2 = ((uint64_t)h0 * r2) + ((uint64_t)h1 * r1) + ((uint64_t)h2 * r0) + ((uint64_t)h3 * s4) + ((uint64_t)h4 * s3);
100                 d3 = ((uint64_t)h0 * r3) + ((uint64_t)h1 * r2) + ((uint64_t)h2 * r1) + ((uint64_t)h3 * r0) + ((uint64_t)h4 * s4);
101                 d4 = ((uint64_t)h0 * r4) + ((uint64_t)h1 * r3) + ((uint64_t)h2 * r2) + ((uint64_t)h3 * r1) + ((uint64_t)h4 * r0);
102
103                 /* (partial) h %= p */
104                 c = (uint32_t)(d0 >> 26);
105                 h0 = (uint32_t)d0 & 0x3ffffff;
106                 d1 += c;
107                 c = (uint32_t)(d1 >> 26);
108                 h1 = (uint32_t)d1 & 0x3ffffff;
109                 d2 += c;
110                 c = (uint32_t)(d2 >> 26);
111                 h2 = (uint32_t)d2 & 0x3ffffff;
112                 d3 += c;
113                 c = (uint32_t)(d3 >> 26);
114                 h3 = (uint32_t)d3 & 0x3ffffff;
115                 d4 += c;
116                 c = (uint32_t)(d4 >> 26);
117                 h4 = (uint32_t)d4 & 0x3ffffff;
118                 h0 += c * 5;
119                 c = (h0 >> 26);
120                 h0 =           h0 & 0x3ffffff;
121                 h1 += c;
122
123                 m += POLY1305_BLOCK_SIZE;
124                 bytes -= POLY1305_BLOCK_SIZE;
125         }
126
127         st->h[0] = h0;
128         st->h[1] = h1;
129         st->h[2] = h2;
130         st->h[3] = h3;
131         st->h[4] = h4;
132 }
133
134 void
135 poly1305_finish(struct poly1305_context *st, unsigned char mac[16]) {
136         uint32_t h0, h1, h2, h3, h4, c;
137         uint32_t g0, g1, g2, g3, g4;
138         uint64_t f;
139         uint32_t mask;
140
141         /* process the remaining block */
142         if(st->leftover) {
143                 size_t i = st->leftover;
144                 st->buffer[i++] = 1;
145
146                 for(; i < POLY1305_BLOCK_SIZE; i++) {
147                         st->buffer[i] = 0;
148                 }
149
150                 st->final = 1;
151                 poly1305_blocks(st, st->buffer, POLY1305_BLOCK_SIZE);
152         }
153
154         /* fully carry h */
155         h0 = st->h[0];
156         h1 = st->h[1];
157         h2 = st->h[2];
158         h3 = st->h[3];
159         h4 = st->h[4];
160
161         c = h1 >> 26;
162         h1 = h1 & 0x3ffffff;
163         h2 +=     c;
164         c = h2 >> 26;
165         h2 = h2 & 0x3ffffff;
166         h3 +=     c;
167         c = h3 >> 26;
168         h3 = h3 & 0x3ffffff;
169         h4 +=     c;
170         c = h4 >> 26;
171         h4 = h4 & 0x3ffffff;
172         h0 += c * 5;
173         c = h0 >> 26;
174         h0 = h0 & 0x3ffffff;
175         h1 +=     c;
176
177         /* compute h + -p */
178         g0 = h0 + 5;
179         c = g0 >> 26;
180         g0 &= 0x3ffffff;
181         g1 = h1 + c;
182         c = g1 >> 26;
183         g1 &= 0x3ffffff;
184         g2 = h2 + c;
185         c = g2 >> 26;
186         g2 &= 0x3ffffff;
187         g3 = h3 + c;
188         c = g3 >> 26;
189         g3 &= 0x3ffffff;
190         g4 = h4 + c - (1 << 26);
191
192         /* select h if h < p, or h + -p if h >= p */
193         mask = (g4 >> ((sizeof(uint32_t) * 8) - 1)) - 1;
194         g0 &= mask;
195         g1 &= mask;
196         g2 &= mask;
197         g3 &= mask;
198         g4 &= mask;
199         mask = ~mask;
200         h0 = (h0 & mask) | g0;
201         h1 = (h1 & mask) | g1;
202         h2 = (h2 & mask) | g2;
203         h3 = (h3 & mask) | g3;
204         h4 = (h4 & mask) | g4;
205
206         /* h = h % (2^128) */
207         h0 = ((h0) | (h1 << 26)) & 0xffffffff;
208         h1 = ((h1 >>  6) | (h2 << 20)) & 0xffffffff;
209         h2 = ((h2 >> 12) | (h3 << 14)) & 0xffffffff;
210         h3 = ((h3 >> 18) | (h4 <<  8)) & 0xffffffff;
211
212         /* mac = (h + pad) % (2^128) */
213         f = (uint64_t)h0 + st->pad[0]            ;
214         h0 = (uint32_t)f;
215         f = (uint64_t)h1 + st->pad[1] + (f >> 32);
216         h1 = (uint32_t)f;
217         f = (uint64_t)h2 + st->pad[2] + (f >> 32);
218         h2 = (uint32_t)f;
219         f = (uint64_t)h3 + st->pad[3] + (f >> 32);
220         h3 = (uint32_t)f;
221
222         U32TO8(mac +  0, h0);
223         U32TO8(mac +  4, h1);
224         U32TO8(mac +  8, h2);
225         U32TO8(mac + 12, h3);
226
227         /* zero out the state */
228         st->h[0] = 0;
229         st->h[1] = 0;
230         st->h[2] = 0;
231         st->h[3] = 0;
232         st->h[4] = 0;
233         st->r[0] = 0;
234         st->r[1] = 0;
235         st->r[2] = 0;
236         st->r[3] = 0;
237         st->r[4] = 0;
238         st->pad[0] = 0;
239         st->pad[1] = 0;
240         st->pad[2] = 0;
241         st->pad[3] = 0;
242 }
243
244
245 void
246 poly1305_update(struct poly1305_context *st, const unsigned char *m, size_t bytes) {
247         size_t i;
248
249         /* handle leftover */
250         if(st->leftover) {
251                 size_t want = (POLY1305_BLOCK_SIZE - st->leftover);
252
253                 if(want > bytes) {
254                         want = bytes;
255                 }
256
257                 for(i = 0; i < want; i++) {
258                         st->buffer[st->leftover + i] = m[i];
259                 }
260
261                 bytes -= want;
262                 m += want;
263                 st->leftover += want;
264
265                 if(st->leftover < POLY1305_BLOCK_SIZE) {
266                         return;
267                 }
268
269                 poly1305_blocks(st, st->buffer, POLY1305_BLOCK_SIZE);
270                 st->leftover = 0;
271         }
272
273         /* process full blocks */
274         if(bytes >= POLY1305_BLOCK_SIZE) {
275                 size_t want = (bytes & ~(POLY1305_BLOCK_SIZE - 1));
276                 poly1305_blocks(st, m, want);
277                 m += want;
278                 bytes -= want;
279         }
280
281         /* store leftover */
282         if(bytes) {
283 #if (USE_MEMCPY == 1)
284                 memcpy(st->buffer + st->leftover, m, bytes);
285 #else
286
287                 for(i = 0; i < bytes; i++) {
288                         st->buffer[st->leftover + i] = m[i];
289                 }
290
291 #endif
292                 st->leftover += bytes;
293         }
294 }
295
296 void
297 poly1305_auth(unsigned char mac[16], const unsigned char *m, size_t bytes, const unsigned char key[32]) {
298         struct poly1305_context ctx;
299         poly1305_init(&ctx, key);
300         poly1305_update(&ctx, m, bytes);
301         poly1305_finish(&ctx, mac);
302 }