std/crypto: add a vectorized ChaCha20 implementation

Brings a 30% speed boost on x86_64 even though we still process only
one block at a time for now.

Only enabled on x86_64 since the non-vectorized implementation seems
to currently perform better on some architectures (at least on aarch64).

But the non-vectorized implementation still gets a little speed boost
as well (~17%) with these changes.
This commit is contained in:
Frank Denis 2020-10-09 23:19:27 +02:00
parent 53c63bdb73
commit 9b386bda33

View File

@ -10,120 +10,315 @@ const mem = std.mem;
const assert = std.debug.assert;
const testing = std.testing;
const maxInt = std.math.maxInt;
const Vector = std.meta.Vector;
const Poly1305 = std.crypto.onetimeauth.Poly1305;
const QuarterRound = struct {
a: usize,
b: usize,
c: usize,
d: usize,
// Vectorized implementation of the core function
const ChaCha20VecImpl = struct {
const Lane = Vector(4, u32);
const BlockVec = [4]Lane;
fn initContext(key: [8]u32, d: [4]u32) BlockVec {
const c = "expand 32-byte k";
const constant_le = comptime Lane{
mem.readIntLittle(u32, c[0..4]),
mem.readIntLittle(u32, c[4..8]),
mem.readIntLittle(u32, c[8..12]),
mem.readIntLittle(u32, c[12..16]),
};
return BlockVec{
constant_le,
Lane{ key[0], key[1], key[2], key[3] },
Lane{ key[4], key[5], key[6], key[7] },
Lane{ d[0], d[1], d[2], d[3] },
};
}
inline fn chacha20Core(x: *BlockVec, input: BlockVec) void {
const rot8 = Vector(16, i32){ 3, 0, 1, 2, 7, 4, 5, 6, 11, 8, 9, 10, 15, 12, 13, 14 };
const rot16 = Vector(16, i32){ 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9, 14, 15, 12, 13 };
x.* = input;
var r: usize = 0;
while (r < 20) : (r += 2) {
x[0] +%= x[1];
x[3] ^= x[0];
x[3] = @bitCast(Vector(4, u32), @shuffle(u8, @bitCast(Vector(16, u8), x[3]), undefined, rot16));
x[2] +%= x[3];
x[1] ^= x[2];
var t1 = x[1];
x[1] <<= @splat(4, @as(u5, 12));
t1 >>= @splat(4, @as(u5, 20));
x[1] ^= t1;
x[0] +%= x[1];
x[3] ^= x[0];
x[0] = @shuffle(u32, x[0], undefined, Vector(4, i32){ 3, 0, 1, 2 });
x[3] = @bitCast(Vector(4, u32), @shuffle(u8, @bitCast(Vector(16, u8), x[3]), undefined, rot8));
x[2] +%= x[3];
x[3] = @shuffle(u32, x[3], undefined, Vector(4, i32){ 2, 3, 0, 1 });
x[1] ^= x[2];
x[2] = @shuffle(u32, x[2], undefined, Vector(4, i32){ 1, 2, 3, 0 });
t1 = x[1];
x[1] <<= @splat(4, @as(u5, 7));
t1 >>= @splat(4, @as(u5, 25));
x[1] ^= t1;
x[0] +%= x[1];
x[3] ^= x[0];
x[3] = @bitCast(Vector(4, u32), @shuffle(u8, @bitCast(Vector(16, u8), x[3]), undefined, rot16));
x[2] +%= x[3];
x[1] ^= x[2];
t1 = x[1];
x[1] <<= @splat(4, @as(u5, 12));
t1 >>= @splat(4, @as(u5, 20));
x[1] ^= t1;
x[0] +%= x[1];
x[3] ^= x[0];
x[0] = @shuffle(u32, x[0], undefined, Vector(4, i32){ 1, 2, 3, 0 });
x[3] = @bitCast(Vector(4, u32), @shuffle(u8, @bitCast(Vector(16, u8), x[3]), undefined, rot8));
x[2] +%= x[3];
x[3] = @shuffle(u32, x[3], undefined, Vector(4, i32){ 2, 3, 0, 1 });
x[1] ^= x[2];
x[2] = @shuffle(u32, x[2], undefined, Vector(4, i32){ 3, 0, 1, 2 });
t1 = x[1];
x[1] <<= @splat(4, @as(u5, 7));
t1 >>= @splat(4, @as(u5, 25));
x[1] ^= t1;
}
}
inline fn hashToBytes(out: *[64]u8, x: BlockVec) void {
var i: usize = 0;
while (i < 4) : (i += 1) {
mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i][0]);
mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i][1]);
mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i][2]);
mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i][3]);
}
}
inline fn contextFeedback(x: *BlockVec, ctx: BlockVec) void {
x[0] +%= ctx[0];
x[1] +%= ctx[1];
x[2] +%= ctx[2];
x[3] +%= ctx[3];
}
fn chaCha20Internal(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void {
var ctx = initContext(key, counter);
var x: BlockVec = undefined;
var buf: [64]u8 = undefined;
var i: usize = 0;
while (i + 64 <= in.len) : (i += 64) {
chacha20Core(x[0..], ctx);
contextFeedback(&x, ctx);
hashToBytes(buf[0..], x);
var xout = out[i..];
const xin = in[i..];
var j: usize = 0;
while (j < 64) : (j += 1) {
xout[j] = xin[j];
}
j = 0;
while (j < 64) : (j += 1) {
xout[j] ^= buf[j];
}
ctx[3][0] += 1;
}
if (i < in.len) {
chacha20Core(x[0..], ctx);
contextFeedback(&x, ctx);
hashToBytes(buf[0..], x);
var xout = out[i..];
const xin = in[i..];
var j: usize = 0;
while (j < in.len % 64) : (j += 1) {
xout[j] = xin[j] ^ buf[j];
}
}
}
fn hchacha20(input: [16]u8, key: [32]u8) [32]u8 {
var c: [4]u32 = undefined;
for (c) |_, i| {
c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]);
}
const ctx = initContext(keyToWords(key), c);
var x: BlockVec = undefined;
chacha20Core(x[0..], ctx);
var out: [32]u8 = undefined;
mem.writeIntLittle(u32, out[0..4], x[0][0]);
mem.writeIntLittle(u32, out[4..8], x[0][1]);
mem.writeIntLittle(u32, out[8..12], x[0][2]);
mem.writeIntLittle(u32, out[12..16], x[0][3]);
mem.writeIntLittle(u32, out[16..20], x[3][0]);
mem.writeIntLittle(u32, out[20..24], x[3][1]);
mem.writeIntLittle(u32, out[24..28], x[3][2]);
mem.writeIntLittle(u32, out[28..32], x[3][3]);
return out;
}
};
fn Rp(a: usize, b: usize, c: usize, d: usize) QuarterRound {
return QuarterRound{
.a = a,
.b = b,
.c = c,
.d = d,
};
}
// Non-vectorized implementation of the core function
const ChaCha20NonVecImpl = struct {
const BlockVec = [16]u32;
fn initContext(key: [8]u32, d: [4]u32) [16]u32 {
var ctx: [16]u32 = undefined;
const c = "expand 32-byte k";
const constant_le = comptime [_]u32{
mem.readIntLittle(u32, c[0..4]),
mem.readIntLittle(u32, c[4..8]),
mem.readIntLittle(u32, c[8..12]),
mem.readIntLittle(u32, c[12..16]),
};
mem.copy(u32, ctx[0..], constant_le[0..4]);
mem.copy(u32, ctx[4..12], key[0..8]);
mem.copy(u32, ctx[12..16], d[0..4]);
fn initContext(key: [8]u32, d: [4]u32) BlockVec {
const c = "expand 32-byte k";
const constant_le = comptime [4]u32{
mem.readIntLittle(u32, c[0..4]),
mem.readIntLittle(u32, c[4..8]),
mem.readIntLittle(u32, c[8..12]),
mem.readIntLittle(u32, c[12..16]),
};
return BlockVec{
constant_le[0], constant_le[1], constant_le[2], constant_le[3],
key[0], key[1], key[2], key[3],
key[4], key[5], key[6], key[7],
d[0], d[1], d[2], d[3],
};
}
return ctx;
}
// The chacha family of ciphers are based on the salsa family.
inline fn chacha20Core(x: []u32, input: [16]u32) void {
for (x) |_, i|
x[i] = input[i];
const rounds = comptime [_]QuarterRound{
Rp(0, 4, 8, 12),
Rp(1, 5, 9, 13),
Rp(2, 6, 10, 14),
Rp(3, 7, 11, 15),
Rp(0, 5, 10, 15),
Rp(1, 6, 11, 12),
Rp(2, 7, 8, 13),
Rp(3, 4, 9, 14),
const QuarterRound = struct {
a: usize,
b: usize,
c: usize,
d: usize,
};
comptime var j: usize = 0;
inline while (j < 20) : (j += 2) {
// two-round cycles
inline for (rounds) |r| {
x[r.a] +%= x[r.b];
x[r.d] = std.math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 16));
x[r.c] +%= x[r.d];
x[r.b] = std.math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 12));
x[r.a] +%= x[r.b];
x[r.d] = std.math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 8));
x[r.c] +%= x[r.d];
x[r.b] = std.math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 7));
fn Rp(a: usize, b: usize, c: usize, d: usize) QuarterRound {
return QuarterRound{
.a = a,
.b = b,
.c = c,
.d = d,
};
}
inline fn chacha20Core(x: *BlockVec, input: BlockVec) void {
x.* = input;
const rounds = comptime [_]QuarterRound{
Rp(0, 4, 8, 12),
Rp(1, 5, 9, 13),
Rp(2, 6, 10, 14),
Rp(3, 7, 11, 15),
Rp(0, 5, 10, 15),
Rp(1, 6, 11, 12),
Rp(2, 7, 8, 13),
Rp(3, 4, 9, 14),
};
comptime var j: usize = 0;
inline while (j < 20) : (j += 2) {
inline for (rounds) |r| {
x[r.a] +%= x[r.b];
x[r.d] = std.math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 16));
x[r.c] +%= x[r.d];
x[r.b] = std.math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 12));
x[r.a] +%= x[r.b];
x[r.d] = std.math.rotl(u32, x[r.d] ^ x[r.a], @as(u32, 8));
x[r.c] +%= x[r.d];
x[r.b] = std.math.rotl(u32, x[r.b] ^ x[r.c], @as(u32, 7));
}
}
}
}
fn hashToBytes(out: []u8, x: [16]u32) void {
for (x) |_, i| {
mem.writeIntLittle(u32, out[4 * i ..][0..4], x[i]);
inline fn hashToBytes(out: *[64]u8, x: BlockVec) void {
var i: usize = 0;
while (i < 4) : (i += 1) {
mem.writeIntLittle(u32, out[16 * i + 0 ..][0..4], x[i * 4 + 0]);
mem.writeIntLittle(u32, out[16 * i + 4 ..][0..4], x[i * 4 + 1]);
mem.writeIntLittle(u32, out[16 * i + 8 ..][0..4], x[i * 4 + 2]);
mem.writeIntLittle(u32, out[16 * i + 12 ..][0..4], x[i * 4 + 3]);
}
}
}
fn chaCha20_internal(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void {
var ctx = initContext(key, counter);
var remaining: usize = if (in.len > out.len) in.len else out.len;
var cursor: usize = 0;
while (true) {
var x: [16]u32 = undefined;
var buf: [64]u8 = undefined;
chacha20Core(x[0..], ctx);
for (x) |_, i| {
inline fn contextFeedback(x: *BlockVec, ctx: BlockVec) void {
var i: usize = 0;
while (i < 16) : (i += 1) {
x[i] +%= ctx[i];
}
hashToBytes(buf[0..], x);
if (remaining < 64) {
var i: usize = 0;
while (i < remaining) : (i += 1)
out[cursor + i] = in[cursor + i] ^ buf[i];
return;
}
var i: usize = 0;
while (i < 64) : (i += 1)
out[cursor + i] = in[cursor + i] ^ buf[i];
cursor += 64;
remaining -= 64;
ctx[12] += 1;
}
}
fn chaCha20Internal(out: []u8, in: []const u8, key: [8]u32, counter: [4]u32) void {
var ctx = initContext(key, counter);
var x: BlockVec = undefined;
var buf: [64]u8 = undefined;
var i: usize = 0;
while (i + 64 <= in.len) : (i += 64) {
chacha20Core(x[0..], ctx);
contextFeedback(&x, ctx);
hashToBytes(buf[0..], x);
var xout = out[i..];
const xin = in[i..];
var j: usize = 0;
while (j < 64) : (j += 1) {
xout[j] = xin[j];
}
j = 0;
while (j < 64) : (j += 1) {
xout[j] ^= buf[j];
}
ctx[12] += 1;
}
if (i < in.len) {
chacha20Core(x[0..], ctx);
contextFeedback(&x, ctx);
hashToBytes(buf[0..], x);
var xout = out[i..];
const xin = in[i..];
var j: usize = 0;
while (j < in.len % 64) : (j += 1) {
xout[j] = xin[j] ^ buf[j];
}
}
}
fn hchacha20(input: [16]u8, key: [32]u8) [32]u8 {
var c: [4]u32 = undefined;
for (c) |_, i| {
c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]);
}
const ctx = initContext(keyToWords(key), c);
var x: BlockVec = undefined;
chacha20Core(x[0..], ctx);
var out: [32]u8 = undefined;
mem.writeIntLittle(u32, out[0..4], x[0]);
mem.writeIntLittle(u32, out[4..8], x[1]);
mem.writeIntLittle(u32, out[8..12], x[2]);
mem.writeIntLittle(u32, out[12..16], x[3]);
mem.writeIntLittle(u32, out[16..20], x[12]);
mem.writeIntLittle(u32, out[20..24], x[13]);
mem.writeIntLittle(u32, out[24..28], x[14]);
mem.writeIntLittle(u32, out[28..32], x[15]);
return out;
}
};
const ChaCha20Impl = if (std.Target.current.cpu.arch == .x86_64) ChaCha20VecImpl else ChaCha20NonVecImpl;
fn keyToWords(key: [32]u8) [8]u32 {
var k: [8]u32 = undefined;
k[0] = mem.readIntLittle(u32, key[0..4]);
k[1] = mem.readIntLittle(u32, key[4..8]);
k[2] = mem.readIntLittle(u32, key[8..12]);
k[3] = mem.readIntLittle(u32, key[12..16]);
k[4] = mem.readIntLittle(u32, key[16..20]);
k[5] = mem.readIntLittle(u32, key[20..24]);
k[6] = mem.readIntLittle(u32, key[24..28]);
k[7] = mem.readIntLittle(u32, key[28..32]);
var i: usize = 0;
while (i < 8) : (i += 1) {
k[i] = mem.readIntLittle(u32, key[i * 4 ..][0..4]);
}
return k;
}
@ -145,7 +340,7 @@ pub const ChaCha20IETF = struct {
c[1] = mem.readIntLittle(u32, nonce[0..4]);
c[2] = mem.readIntLittle(u32, nonce[4..8]);
c[3] = mem.readIntLittle(u32, nonce[8..12]);
chaCha20_internal(out, in, keyToWords(key), c);
ChaCha20Impl.chaCha20Internal(out, in, keyToWords(key), c);
}
};
@ -171,7 +366,7 @@ pub const ChaCha20With64BitNonce = struct {
// first partial big block
if (((@intCast(u64, maxInt(u32) - @truncate(u32, counter)) + 1) << 6) < in.len) {
chaCha20_internal(out[cursor..big_block], in[cursor..big_block], k, c);
ChaCha20Impl.chaCha20Internal(out[cursor..big_block], in[cursor..big_block], k, c);
cursor = big_block - cursor;
c[1] += 1;
if (comptime @sizeOf(usize) > 4) {
@ -179,14 +374,14 @@ pub const ChaCha20With64BitNonce = struct {
var remaining_blocks: u32 = @intCast(u32, (in.len / big_block));
var i: u32 = 0;
while (remaining_blocks > 0) : (remaining_blocks -= 1) {
chaCha20_internal(out[cursor .. cursor + big_block], in[cursor .. cursor + big_block], k, c);
c[1] += 1; // upper 32-bit of counter, generic chaCha20_internal() doesn't know about this.
ChaCha20Impl.chaCha20Internal(out[cursor .. cursor + big_block], in[cursor .. cursor + big_block], k, c);
c[1] += 1; // upper 32-bit of counter, generic chaCha20Internal() doesn't know about this.
cursor += big_block;
}
}
}
chaCha20_internal(out[cursor..], in[cursor..], k, c);
ChaCha20Impl.chaCha20Internal(out[cursor..], in[cursor..], k, c);
}
};
@ -533,33 +728,12 @@ fn chacha20poly1305Open(dst: []u8, ciphertextAndTag: []const u8, data: []const u
return try chacha20poly1305OpenDetached(dst, ciphertextAndTag[0..ciphertextLen], ciphertextAndTag[ciphertextLen..][0..chacha20poly1305_tag_size], data, key, nonce);
}
fn hchacha20(input: [16]u8, key: [32]u8) [32]u8 {
var c: [4]u32 = undefined;
for (c) |_, i| {
c[i] = mem.readIntLittle(u32, input[4 * i ..][0..4]);
}
const ctx = initContext(keyToWords(key), c);
var x: [16]u32 = undefined;
chacha20Core(x[0..], ctx);
var out: [32]u8 = undefined;
mem.writeIntLittle(u32, out[0..4], x[0]);
mem.writeIntLittle(u32, out[4..8], x[1]);
mem.writeIntLittle(u32, out[8..12], x[2]);
mem.writeIntLittle(u32, out[12..16], x[3]);
mem.writeIntLittle(u32, out[16..20], x[12]);
mem.writeIntLittle(u32, out[20..24], x[13]);
mem.writeIntLittle(u32, out[24..28], x[14]);
mem.writeIntLittle(u32, out[28..32], x[15]);
return out;
}
fn extend(key: [32]u8, nonce: [24]u8) struct { key: [32]u8, nonce: [12]u8 } {
var subnonce: [12]u8 = undefined;
mem.set(u8, subnonce[0..4], 0);
mem.copy(u8, subnonce[4..], nonce[16..24]);
return .{
.key = hchacha20(nonce[0..16].*, key),
.key = ChaCha20Impl.hchacha20(nonce[0..16].*, key),
.nonce = subnonce,
};
}