Move poly1305_get_tag() into poly1305.c, hide poly1305_init().
[tinc] / src / chacha-poly1305 / poly1305.c
1 /*
2 poly1305 implementation using 32 bit * 32 bit = 64 bit multiplication and 64 bit addition
3 public domain
4 */
5
6 #include "poly1305.h"
7
8 /* use memcpy() to copy blocks of memory (typically faster) */
9 #define USE_MEMCPY          1
10 /* use unaligned little-endian load/store (can be faster) */
11 #define USE_UNALIGNED       0
12
13 struct poly1305_context {
14         uint32_t r[5];
15         uint32_t h[5];
16         uint32_t pad[4];
17         size_t leftover;
18         unsigned char buffer[POLY1305_BLOCK_SIZE];
19         unsigned char final;
20 };
21
22 #if (USE_UNALIGNED == 1)
23 #define U8TO32(p) \
24         (*((uint32_t *)(p)))
25 #define U32TO8(p, v) \
26         do { \
27                 *((uint32_t *)(p)) = v; \
28         } while (0)
29 #else
30 /* interpret four 8 bit unsigned integers as a 32 bit unsigned integer in little endian */
31 static uint32_t
32 U8TO32(const unsigned char *p) {
33         return
34                 (((uint32_t)(p[0] & 0xff)) |
35                  ((uint32_t)(p[1] & 0xff) <<  8) |
36                  ((uint32_t)(p[2] & 0xff) << 16) |
37                  ((uint32_t)(p[3] & 0xff) << 24));
38 }
39
40 /* store a 32 bit unsigned integer as four 8 bit unsigned integers in little endian */
41 static void
42 U32TO8(unsigned char *p, uint32_t v) {
43         p[0] = (v) & 0xff;
44         p[1] = (v >>  8) & 0xff;
45         p[2] = (v >> 16) & 0xff;
46         p[3] = (v >> 24) & 0xff;
47 }
48 #endif
49
50 static void
51 poly1305_init(struct poly1305_context *st, const unsigned char key[32]) {
52         /* r &= 0xffffffc0ffffffc0ffffffc0fffffff */
53         st->r[0] = (U8TO32(&key[ 0])) & 0x3ffffff;
54         st->r[1] = (U8TO32(&key[ 3]) >> 2) & 0x3ffff03;
55         st->r[2] = (U8TO32(&key[ 6]) >> 4) & 0x3ffc0ff;
56         st->r[3] = (U8TO32(&key[ 9]) >> 6) & 0x3f03fff;
57         st->r[4] = (U8TO32(&key[12]) >> 8) & 0x00fffff;
58
59         /* h = 0 */
60         st->h[0] = 0;
61         st->h[1] = 0;
62         st->h[2] = 0;
63         st->h[3] = 0;
64         st->h[4] = 0;
65
66         /* save pad for later */
67         st->pad[0] = U8TO32(&key[16]);
68         st->pad[1] = U8TO32(&key[20]);
69         st->pad[2] = U8TO32(&key[24]);
70         st->pad[3] = U8TO32(&key[28]);
71
72         st->leftover = 0;
73         st->final = 0;
74 }
75
76 static void
77 poly1305_blocks(struct poly1305_context *st, const unsigned char *m, size_t bytes) {
78         const uint32_t hibit = (st->final) ? 0 : (1 << 24); /* 1 << 128 */
79         uint32_t r0, r1, r2, r3, r4;
80         uint32_t s1, s2, s3, s4;
81         uint32_t h0, h1, h2, h3, h4;
82         uint64_t d0, d1, d2, d3, d4;
83         uint32_t c;
84
85         r0 = st->r[0];
86         r1 = st->r[1];
87         r2 = st->r[2];
88         r3 = st->r[3];
89         r4 = st->r[4];
90
91         s1 = r1 * 5;
92         s2 = r2 * 5;
93         s3 = r3 * 5;
94         s4 = r4 * 5;
95
96         h0 = st->h[0];
97         h1 = st->h[1];
98         h2 = st->h[2];
99         h3 = st->h[3];
100         h4 = st->h[4];
101
102         while(bytes >= POLY1305_BLOCK_SIZE) {
103                 /* h += m[i] */
104                 h0 += (U8TO32(m + 0)) & 0x3ffffff;
105                 h1 += (U8TO32(m + 3) >> 2) & 0x3ffffff;
106                 h2 += (U8TO32(m + 6) >> 4) & 0x3ffffff;
107                 h3 += (U8TO32(m + 9) >> 6) & 0x3ffffff;
108                 h4 += (U8TO32(m + 12) >> 8) | hibit;
109
110                 /* h *= r */
111                 d0 = ((uint64_t)h0 * r0) + ((uint64_t)h1 * s4) + ((uint64_t)h2 * s3) + ((uint64_t)h3 * s2) + ((uint64_t)h4 * s1);
112                 d1 = ((uint64_t)h0 * r1) + ((uint64_t)h1 * r0) + ((uint64_t)h2 * s4) + ((uint64_t)h3 * s3) + ((uint64_t)h4 * s2);
113                 d2 = ((uint64_t)h0 * r2) + ((uint64_t)h1 * r1) + ((uint64_t)h2 * r0) + ((uint64_t)h3 * s4) + ((uint64_t)h4 * s3);
114                 d3 = ((uint64_t)h0 * r3) + ((uint64_t)h1 * r2) + ((uint64_t)h2 * r1) + ((uint64_t)h3 * r0) + ((uint64_t)h4 * s4);
115                 d4 = ((uint64_t)h0 * r4) + ((uint64_t)h1 * r3) + ((uint64_t)h2 * r2) + ((uint64_t)h3 * r1) + ((uint64_t)h4 * r0);
116
117                 /* (partial) h %= p */
118                 c = (uint32_t)(d0 >> 26);
119                 h0 = (uint32_t)d0 & 0x3ffffff;
120                 d1 += c;
121                 c = (uint32_t)(d1 >> 26);
122                 h1 = (uint32_t)d1 & 0x3ffffff;
123                 d2 += c;
124                 c = (uint32_t)(d2 >> 26);
125                 h2 = (uint32_t)d2 & 0x3ffffff;
126                 d3 += c;
127                 c = (uint32_t)(d3 >> 26);
128                 h3 = (uint32_t)d3 & 0x3ffffff;
129                 d4 += c;
130                 c = (uint32_t)(d4 >> 26);
131                 h4 = (uint32_t)d4 & 0x3ffffff;
132                 h0 += c * 5;
133                 c = (h0 >> 26);
134                 h0 =           h0 & 0x3ffffff;
135                 h1 += c;
136
137                 m += POLY1305_BLOCK_SIZE;
138                 bytes -= POLY1305_BLOCK_SIZE;
139         }
140
141         st->h[0] = h0;
142         st->h[1] = h1;
143         st->h[2] = h2;
144         st->h[3] = h3;
145         st->h[4] = h4;
146 }
147
148 static void
149 poly1305_finish(struct poly1305_context *st, unsigned char mac[16]) {
150         uint32_t h0, h1, h2, h3, h4, c;
151         uint32_t g0, g1, g2, g3, g4;
152         uint64_t f;
153         uint32_t mask;
154
155         /* process the remaining block */
156         if(st->leftover) {
157                 size_t i = st->leftover;
158                 st->buffer[i++] = 1;
159
160                 for(; i < POLY1305_BLOCK_SIZE; i++) {
161                         st->buffer[i] = 0;
162                 }
163
164                 st->final = 1;
165                 poly1305_blocks(st, st->buffer, POLY1305_BLOCK_SIZE);
166         }
167
168         /* fully carry h */
169         h0 = st->h[0];
170         h1 = st->h[1];
171         h2 = st->h[2];
172         h3 = st->h[3];
173         h4 = st->h[4];
174
175         c = h1 >> 26;
176         h1 = h1 & 0x3ffffff;
177         h2 +=     c;
178         c = h2 >> 26;
179         h2 = h2 & 0x3ffffff;
180         h3 +=     c;
181         c = h3 >> 26;
182         h3 = h3 & 0x3ffffff;
183         h4 +=     c;
184         c = h4 >> 26;
185         h4 = h4 & 0x3ffffff;
186         h0 += c * 5;
187         c = h0 >> 26;
188         h0 = h0 & 0x3ffffff;
189         h1 +=     c;
190
191         /* compute h + -p */
192         g0 = h0 + 5;
193         c = g0 >> 26;
194         g0 &= 0x3ffffff;
195         g1 = h1 + c;
196         c = g1 >> 26;
197         g1 &= 0x3ffffff;
198         g2 = h2 + c;
199         c = g2 >> 26;
200         g2 &= 0x3ffffff;
201         g3 = h3 + c;
202         c = g3 >> 26;
203         g3 &= 0x3ffffff;
204         g4 = h4 + c - (1 << 26);
205
206         /* select h if h < p, or h + -p if h >= p */
207         mask = (g4 >> ((sizeof(uint32_t) * 8) - 1)) - 1;
208         g0 &= mask;
209         g1 &= mask;
210         g2 &= mask;
211         g3 &= mask;
212         g4 &= mask;
213         mask = ~mask;
214         h0 = (h0 & mask) | g0;
215         h1 = (h1 & mask) | g1;
216         h2 = (h2 & mask) | g2;
217         h3 = (h3 & mask) | g3;
218         h4 = (h4 & mask) | g4;
219
220         /* h = h % (2^128) */
221         h0 = ((h0) | (h1 << 26)) & 0xffffffff;
222         h1 = ((h1 >>  6) | (h2 << 20)) & 0xffffffff;
223         h2 = ((h2 >> 12) | (h3 << 14)) & 0xffffffff;
224         h3 = ((h3 >> 18) | (h4 <<  8)) & 0xffffffff;
225
226         /* mac = (h + pad) % (2^128) */
227         f = (uint64_t)h0 + st->pad[0]            ;
228         h0 = (uint32_t)f;
229         f = (uint64_t)h1 + st->pad[1] + (f >> 32);
230         h1 = (uint32_t)f;
231         f = (uint64_t)h2 + st->pad[2] + (f >> 32);
232         h2 = (uint32_t)f;
233         f = (uint64_t)h3 + st->pad[3] + (f >> 32);
234         h3 = (uint32_t)f;
235
236         U32TO8(mac +  0, h0);
237         U32TO8(mac +  4, h1);
238         U32TO8(mac +  8, h2);
239         U32TO8(mac + 12, h3);
240
241         /* zero out the state */
242         st->h[0] = 0;
243         st->h[1] = 0;
244         st->h[2] = 0;
245         st->h[3] = 0;
246         st->h[4] = 0;
247         st->r[0] = 0;
248         st->r[1] = 0;
249         st->r[2] = 0;
250         st->r[3] = 0;
251         st->r[4] = 0;
252         st->pad[0] = 0;
253         st->pad[1] = 0;
254         st->pad[2] = 0;
255         st->pad[3] = 0;
256 }
257
258 static void
259 poly1305_update(struct poly1305_context *st, const unsigned char *m, size_t bytes) {
260         size_t i;
261
262         /* handle leftover */
263         if(st->leftover) {
264                 size_t want = (POLY1305_BLOCK_SIZE - st->leftover);
265
266                 if(want > bytes) {
267                         want = bytes;
268                 }
269
270                 for(i = 0; i < want; i++) {
271                         st->buffer[st->leftover + i] = m[i];
272                 }
273
274                 bytes -= want;
275                 m += want;
276                 st->leftover += want;
277
278                 if(st->leftover < POLY1305_BLOCK_SIZE) {
279                         return;
280                 }
281
282                 poly1305_blocks(st, st->buffer, POLY1305_BLOCK_SIZE);
283                 st->leftover = 0;
284         }
285
286         /* process full blocks */
287         if(bytes >= POLY1305_BLOCK_SIZE) {
288                 size_t want = (bytes & ~(POLY1305_BLOCK_SIZE - 1));
289                 poly1305_blocks(st, m, want);
290                 m += want;
291                 bytes -= want;
292         }
293
294         /* store leftover */
295         if(bytes) {
296 #if (USE_MEMCPY == 1)
297                 memcpy(st->buffer + st->leftover, m, bytes);
298 #else
299
300                 for(i = 0; i < bytes; i++) {
301                         st->buffer[st->leftover + i] = m[i];
302                 }
303
304 #endif
305                 st->leftover += bytes;
306         }
307 }
308
309 /**
310  * Poly1305 tag generation. This concatenates a string according to the rules
311  * outlined in RFC 7539 and calculates the tag.
312  *
313  * \param key 32 byte secret one-time key for poly1305
314  * \param ct ciphertext
315  * \param ct_len ciphertext length in bytes
316  * \param tag pointer to 16 bytes for tag storage
317  */
318 void
319 poly1305_get_tag(const unsigned char key[32], const void *ct, int ct_len, unsigned char tag[16]) {
320         struct poly1305_context ctx;
321         unsigned left_over;
322         uint64_t len;
323         unsigned char pad[16];
324
325         poly1305_init(&ctx, key);
326         memset(&pad, 0, sizeof(pad));
327
328         /* payload and padding */
329         poly1305_update(&ctx, ct, ct_len);
330         left_over = ct_len % 16;
331
332         if(left_over) {
333                 poly1305_update(&ctx, pad, 16 - left_over);
334         }
335
336         /* lengths */
337         len = 0;
338         poly1305_update(&ctx, (unsigned char *)&len, 8);
339         len = ct_len;
340         poly1305_update(&ctx, (unsigned char *)&len, 8);
341         poly1305_finish(&ctx, tag);
342 }