From 9b386bda33e94c79d6b9a1db911d394c26592e71 Mon Sep 17 00:00:00 2001 From: Frank Denis Date: Fri, 9 Oct 2020 23:19:27 +0200 Subject: [PATCH] std/crypto: add a vectorized ChaCha20 implementation Brings a 30% speed boost on x86_64 even though we still process only one block at a time for now. Only enabled on x86_64 since the non-vectorized implementation seems to currently perform better on some architectures (at least on aarch64). But the non-vectorized implementation still gets a little speed boost as well (~17%) with these changes. --- lib/std/crypto/chacha20.zig | 418 +++++++++++++++++++++++++----------- 1 file changed, 296 insertions(+), 122 deletions(-) diff --git a/lib/std/crypto/chacha20.zig b/lib/std/crypto/chacha20.zig index 915f81b9f..6840dfb52 100644 --- a/lib/std/crypto/chacha20.zig +++ b/lib/std/crypto/chacha20.zig @@ -10,120 +10,315 @@ const mem = std.mem; const assert = std.debug.assert; const testing = std.testing; const maxInt = std.math.maxInt; +const Vector = std.meta.Vector; const Poly1305 = std.crypto.onetimeauth.Poly1305; -const QuarterRound = struct { - a: usize, - b: usize, - c: usize, - d: usize, +// Vectorized implementation of the core function +const ChaCha20VecImpl = struct { + const Lane = Vector(4, u32); + const BlockVec = [4]Lane; + + fn initContext(key: [8]u32, d: [4]u32) BlockVec { + const c = "expand 32-byte k"; + const constant_le = comptime Lane{ + mem.readIntLittle(u32, c[0..4]), + mem.readIntLittle(u32, c[4..8]), + mem.readIntLittle(u32, c[8..12]), + mem.readIntLittle(u32, c[12..16]), + }; + return BlockVec{ + constant_le, + Lane{ key[0], key[1], key[2], key[3] }, + Lane{ key[4], key[5], key[6], key[7] }, + Lane{ d[0], d[1], d[2], d[3] }, + }; + } + + inline fn chacha20Core(x: *BlockVec, input: BlockVec) void { + const rot8 = Vector(16, i32){ 3, 0, 1, 2, 7, 4, 5, 6, 11, 8, 9, 10, 15, 12, 13, 14 }; + const rot16 = Vector(16, i32){ 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13 }; + + x.* = input; + + var r: usize = 0; + while (r < 20) : (r += 2) { + x[0] +%= x[1]; + x[3] ^= x[0]; + x[3] = @bitCast(Vector(4, u32), @shuffle(u8, @bitCast(Vector(16, u8), x[3]), undefined, rot16)); + + x[2] +%= x[3]; + x[1] ^= x[2]; + + var t1 = x[1]; + x[1] <<= @splat(4, @as(u5, 12)); + t1 >>= @splat(4, @as(u5, 20)); + x[1] ^= t1; + + x[0] +%= x[1]; + x[3] ^= x[0]; + x[0] = @shuffle(u32, x[0], undefined, Vector(4, i32){ 3, 0, 1, 2 }); + x[3] = @bitCast(Vector(4, u32), @shuffle(u8, @bitCast(Vector(16, u8), x[3]), undefined, rot8)); + + x[2] +%= x[3]; + x[3] = @shuffle(u32, x[3], undefined, Vector(4, i32){ 2, 3, 0, 1 }); + x[1] ^= x[2]; + x[2] = @shuffle(u32, x[2], undefined, Vector(4, i32){ 1, 2, 3, 0 }); + + t1 = x[1]; + x[1] <<= @splat(4, @as(u5, 7)); + t1 >>= @splat(4, @as(u5, 25)); + x[1] ^= t1; + + x[0] +%= x[1]; + x[3] ^= x[0]; + x[3] = @bitCast(Vector(4, u32), @shuffle(u8, @bitCast(Vector(16, u8), x[3]), undefined, rot16)); + + x[2] +%= x[3]; + x[1] ^= x[2]; + + t1 = x[1]; + x[1] <<= @splat(4, @as(u5, 12)); + t1 >>= @splat(4, @as(u5, 20)); + x[1] ^= t1; + + x[0] +%= x[1]; + x[3] ^= x[0]; + x[0] = @shuffle(u32, x[0], undefined, Vector(4, i32){ 1, 2, 3, 0 }); + x[3] = @bitCast(Vector(4, u32), @shuffle(u8, @bitCast(Vector(16, u8), x[3]), undefined, rot8)); + + x[2] +%= x[3]; + x[3] = @shuffle(u32, x[3], undefined, Vector(4, i32){ 2, 3, 0, 1 }); + x[1] ^= x[2]; + x[2] = @shuffle(u32, x[2], undefined, Vector(4, i32){ 3, 0, 1, 2 }); + + t1 = x[1]; + x[1] <<= @splat(4, @as(u5, 7)); + t1 >>= @splat(4, @as(u5, 25)); + x[1] ^= t1; + } + } + + inline fn hashToBytes(out: *[64]u8, x: BlockVec) void { + var i: usize = 0; + while (i < 4) : (i += 1) { + mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i][0]); + mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i][1]); + mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i][2]); + mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i][3]); + } + } + + inline fn contextFeedback(x: *BlockVec, ctx: BlockVec) void { + x[0] +%= ctx[0]; + x[1] +%= ctx[1]; + x[2] +%= ctx[2]; + x[3] +%= ctx[3]; + } + + fn chaCha20Internal(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void { + var ctx = initContext(key, counter); + var x: BlockVec = undefined; + var buf: [64]u8 = undefined; + var i: usize = 0; + while (i + 64 <= in.len) : (i += 64) { + chacha20Core(x[0..], ctx); + contextFeedback(&x, ctx); + hashToBytes(buf[0..], x); + + var xout = out[i..]; + const xin = in[i..]; + var j: usize = 0; + while (j < 64) : (j += 1) { + xout[j] = xin[j]; + } + j = 0; + while (j < 64) : (j += 1) { + xout[j] ^= buf[j]; + } + ctx[3][0] += 1; + } + if (i < in.len) { + chacha20Core(x[0..], ctx); + contextFeedback(&x, ctx); + hashToBytes(buf[0..], x); + + var xout = out[i..]; + const xin = in[i..]; + var j: usize = 0; + while (j < in.len % 64) : (j += 1) { + xout[j] = xin[j] ^ buf[j]; + } + } + } + + fn hchacha20(input: [16]u8, key: [32]u8) [32]u8 { + var c: [4]u32 = undefined; + for (c) |_, i| { + c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]); + } + const ctx = initContext(keyToWords(key), c); + var x: BlockVec = undefined; + chacha20Core(x[0..], ctx); + var out: [32]u8 = undefined; + mem.writeIntLittle(u32, out[0..4], x[0][0]); + mem.writeIntLittle(u32, out[4..8], x[0][1]); + mem.writeIntLittle(u32, out[8..12], x[0][2]); + mem.writeIntLittle(u32, out[12..16], x[0][3]); + mem.writeIntLittle(u32, out[16..20], x[3][0]); + mem.writeIntLittle(u32, out[20..24], x[3][1]); + mem.writeIntLittle(u32, out[24..28], x[3][2]); + mem.writeIntLittle(u32, out[28..32], x[3][3]); + return out; + } }; -fn Rp(a: usize, b: usize, c: usize, d: usize) QuarterRound { - return QuarterRound{ - .a = a, - .b = b, - .c = c, - .d = d, - }; -} +// Non-vectorized implementation of the core function +const ChaCha20NonVecImpl = struct { + const BlockVec = [16]u32; -fn initContext(key: [8]u32, d: [4]u32) [16]u32 { - var ctx: [16]u32 = undefined; - const c = "expand 32-byte k"; - const constant_le = comptime [_]u32{ - mem.readIntLittle(u32, c[0..4]), - mem.readIntLittle(u32, c[4..8]), - mem.readIntLittle(u32, c[8..12]), - mem.readIntLittle(u32, c[12..16]), - }; - mem.copy(u32, ctx[0..], constant_le[0..4]); - mem.copy(u32, ctx[4..12], key[0..8]); - mem.copy(u32, ctx[12..16], d[0..4]); + fn initContext(key: [8]u32, d: [4]u32) BlockVec { + const c = "expand 32-byte k"; + const constant_le = comptime [4]u32{ + mem.readIntLittle(u32, c[0..4]), + mem.readIntLittle(u32, c[4..8]), + mem.readIntLittle(u32, c[8..12]), + mem.readIntLittle(u32, c[12..16]), + }; + return BlockVec{ + constant_le[0], constant_le[1], constant_le[2], constant_le[3], + key[0], key[1], key[2], key[3], + key[4], key[5], key[6], key[7], + d[0], d[1], d[2], d[3], + }; + } - return ctx; -} - -// The chacha family of ciphers are based on the salsa family. -inline fn chacha20Core(x: []u32, input: [16]u32) void { - for (x) |_, i| - x[i] = input[i]; - - const rounds = comptime [_]QuarterRound{ - Rp(0, 4, 8, 12), - Rp(1, 5, 9, 13), - Rp(2, 6, 10, 14), - Rp(3, 7, 11, 15), - Rp(0, 5, 10, 15), - Rp(1, 6, 11, 12), - Rp(2, 7, 8, 13), - Rp(3, 4, 9, 14), + const QuarterRound = struct { + a: usize, + b: usize, + c: usize, + d: usize, }; - comptime var j: usize = 0; - inline while (j < 20) : (j += 2) { - // two-round cycles - inline for (rounds) |r| { - x[r.a] +%= x[r.b]; - x[r.d] = std.math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 16)); - x[r.c] +%= x[r.d]; - x[r.b] = std.math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 12)); - x[r.a] +%= x[r.b]; - x[r.d] = std.math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 8)); - x[r.c] +%= x[r.d]; - x[r.b] = std.math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 7)); + fn Rp(a: usize, b: usize, c: usize, d: usize) QuarterRound { + return QuarterRound{ + .a = a, + .b = b, + .c = c, + .d = d, + }; + } + + inline fn chacha20Core(x: *BlockVec, input: BlockVec) void { + x.* = input; + + const rounds = comptime [_]QuarterRound{ + Rp(0, 4, 8, 12), + Rp(1, 5, 9, 13), + Rp(2, 6, 10, 14), + Rp(3, 7, 11, 15), + Rp(0, 5, 10, 15), + Rp(1, 6, 11, 12), + Rp(2, 7, 8, 13), + Rp(3, 4, 9, 14), + }; + + comptime var j: usize = 0; + inline while (j < 20) : (j += 2) { + inline for (rounds) |r| { + x[r.a] +%= x[r.b]; + x[r.d] = std.math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 16)); + x[r.c] +%= x[r.d]; + x[r.b] = std.math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 12)); + x[r.a] +%= x[r.b]; + x[r.d] = std.math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 8)); + x[r.c] +%= x[r.d]; + x[r.b] = std.math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 7)); + } } } -} -fn hashToBytes(out: []u8, x: [16]u32) void { - for (x) |_, i| { - mem.writeIntLittle(u32, out[4 * i ..][0..4], x[i]); + inline fn hashToBytes(out: *[64]u8, x: BlockVec) void { + var i: usize = 0; + while (i < 4) : (i += 1) { + mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i * 4 + 0]); + mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i * 4 + 1]); + mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i * 4 + 2]); + mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i * 4 + 3]); + } } -} -fn chaCha20_internal(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void { - var ctx = initContext(key, counter); - var remaining: usize = if (in.len > out.len) in.len else out.len; - var cursor: usize = 0; - - while (true) { - var x: [16]u32 = undefined; - var buf: [64]u8 = undefined; - chacha20Core(x[0..], ctx); - for (x) |_, i| { + inline fn contextFeedback(x: *BlockVec, ctx: BlockVec) void { + var i: usize = 0; + while (i < 16) : (i += 1) { x[i] +%= ctx[i]; } - hashToBytes(buf[0..], x); - if (remaining < 64) { - var i: usize = 0; - while (i < remaining) : (i += 1) - out[cursor + i] = in[cursor + i] ^ buf[i]; - return; - } - - var i: usize = 0; - while (i < 64) : (i += 1) - out[cursor + i] = in[cursor + i] ^ buf[i]; - - cursor += 64; - remaining -= 64; - - ctx[12] += 1; } -} + + fn chaCha20Internal(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void { + var ctx = initContext(key, counter); + var x: BlockVec = undefined; + var buf: [64]u8 = undefined; + var i: usize = 0; + while (i + 64 <= in.len) : (i += 64) { + chacha20Core(x[0..], ctx); + contextFeedback(&x, ctx); + hashToBytes(buf[0..], x); + + var xout = out[i..]; + const xin = in[i..]; + var j: usize = 0; + while (j < 64) : (j += 1) { + xout[j] = xin[j]; + } + j = 0; + while (j < 64) : (j += 1) { + xout[j] ^= buf[j]; + } + ctx[12] += 1; + } + if (i < in.len) { + chacha20Core(x[0..], ctx); + contextFeedback(&x, ctx); + hashToBytes(buf[0..], x); + + var xout = out[i..]; + const xin = in[i..]; + var j: usize = 0; + while (j < in.len % 64) : (j += 1) { + xout[j] = xin[j] ^ buf[j]; + } + } + } + + fn hchacha20(input: [16]u8, key: [32]u8) [32]u8 { + var c: [4]u32 = undefined; + for (c) |_, i| { + c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]); + } + const ctx = initContext(keyToWords(key), c); + var x: BlockVec = undefined; + chacha20Core(x[0..], ctx); + var out: [32]u8 = undefined; + mem.writeIntLittle(u32, out[0..4], x[0]); + mem.writeIntLittle(u32, out[4..8], x[1]); + mem.writeIntLittle(u32, out[8..12], x[2]); + mem.writeIntLittle(u32, out[12..16], x[3]); + mem.writeIntLittle(u32, out[16..20], x[12]); + mem.writeIntLittle(u32, out[20..24], x[13]); + mem.writeIntLittle(u32, out[24..28], x[14]); + mem.writeIntLittle(u32, out[28..32], x[15]); + return out; + } +}; + +const ChaCha20Impl = if (std.Target.current.cpu.arch == .x86_64) ChaCha20VecImpl else ChaCha20NonVecImpl; fn keyToWords(key: [32]u8) [8]u32 { var k: [8]u32 = undefined; - k[0] = mem.readIntLittle(u32, key[0..4]); - k[1] = mem.readIntLittle(u32, key[4..8]); - k[2] = mem.readIntLittle(u32, key[8..12]); - k[3] = mem.readIntLittle(u32, key[12..16]); - k[4] = mem.readIntLittle(u32, key[16..20]); - k[5] = mem.readIntLittle(u32, key[20..24]); - k[6] = mem.readIntLittle(u32, key[24..28]); - k[7] = mem.readIntLittle(u32, key[28..32]); - + var i: usize = 0; + while (i < 8) : (i += 1) { + k[i] = mem.readIntLittle(u32, key[i * 4 ..][0..4]); + } return k; } @@ -145,7 +340,7 @@ pub const ChaCha20IETF = struct { c[1] = mem.readIntLittle(u32, nonce[0..4]); c[2] = mem.readIntLittle(u32, nonce[4..8]); c[3] = mem.readIntLittle(u32, nonce[8..12]); - chaCha20_internal(out, in, keyToWords(key), c); + ChaCha20Impl.chaCha20Internal(out, in, keyToWords(key), c); } }; @@ -171,7 +366,7 @@ pub const ChaCha20With64BitNonce = struct { // first partial big block if (((@intCast(u64, maxInt(u32) - @truncate(u32, counter)) + 1) << 6) < in.len) { - chaCha20_internal(out[cursor..big_block], in[cursor..big_block], k, c); + ChaCha20Impl.chaCha20Internal(out[cursor..big_block], in[cursor..big_block], k, c); cursor = big_block - cursor; c[1] += 1; if (comptime @sizeOf(usize) > 4) { @@ -179,14 +374,14 @@ pub const ChaCha20With64BitNonce = struct { var remaining_blocks: u32 = @intCast(u32, (in.len / big_block)); var i: u32 = 0; while (remaining_blocks > 0) : (remaining_blocks -= 1) { - chaCha20_internal(out[cursor .. cursor + big_block], in[cursor .. cursor + big_block], k, c); - c[1] += 1; // upper 32-bit of counter, generic chaCha20_internal() doesn't know about this. + ChaCha20Impl.chaCha20Internal(out[cursor .. cursor + big_block], in[cursor .. cursor + big_block], k, c); + c[1] += 1; // upper 32-bit of counter, generic chaCha20Internal() doesn't know about this. cursor += big_block; } } } - chaCha20_internal(out[cursor..], in[cursor..], k, c); + ChaCha20Impl.chaCha20Internal(out[cursor..], in[cursor..], k, c); } }; @@ -533,33 +728,12 @@ fn chacha20poly1305Open(dst: []u8, ciphertextAndTag: []const u8, data: []const u return try chacha20poly1305OpenDetached(dst, ciphertextAndTag[0..ciphertextLen], ciphertextAndTag[ciphertextLen..][0..chacha20poly1305_tag_size], data, key, nonce); } -fn hchacha20(input: [16]u8, key: [32]u8) [32]u8 { - var c: [4]u32 = undefined; - for (c) |_, i| { - c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]); - } - const ctx = initContext(keyToWords(key), c); - var x: [16]u32 = undefined; - chacha20Core(x[0..], ctx); - var out: [32]u8 = undefined; - mem.writeIntLittle(u32, out[0..4], x[0]); - mem.writeIntLittle(u32, out[4..8], x[1]); - mem.writeIntLittle(u32, out[8..12], x[2]); - mem.writeIntLittle(u32, out[12..16], x[3]); - mem.writeIntLittle(u32, out[16..20], x[12]); - mem.writeIntLittle(u32, out[20..24], x[13]); - mem.writeIntLittle(u32, out[24..28], x[14]); - mem.writeIntLittle(u32, out[28..32], x[15]); - - return out; -} - fn extend(key: [32]u8, nonce: [24]u8) struct { key: [32]u8, nonce: [12]u8 } { var subnonce: [12]u8 = undefined; mem.set(u8, subnonce[0..4], 0); mem.copy(u8, subnonce[4..], nonce[16..24]); return .{ - .key = hchacha20(nonce[0..16].*, key), + .key = ChaCha20Impl.hchacha20(nonce[0..16].*, key), .nonce = subnonce, }; }