From de5969972ed120ea0069f7b7d3225edb1fdaf360 Mon Sep 17 00:00:00 2001 From: "Daniel C. Jones" Date: Mon, 22 Jan 2018 11:26:54 -0800 Subject: [PATCH] Optimize SHA1 implementation. (#47) --- src/common.jl | 8 ++------ src/constants.jl | 2 +- src/sha3.jl | 35 ++++++++++++++++++++--------------- src/types.jl | 14 +++++++++----- test/perf.jl | 2 +- 5 files changed, 33 insertions(+), 28 deletions(-) diff --git a/src/common.jl b/src/common.jl index 1877ae6..26050b6 100644 --- a/src/common.jl +++ b/src/common.jl @@ -13,9 +13,7 @@ function update!(context::T, data::U) where {T<:SHA_CTX, usedspace = context.bytecount % blocklen(T) while len - data_idx + usedspace >= blocklen(T) # Fill up as much of the buffer as we can with the data given us - for i in 1:(blocklen(T) - usedspace) - context.buffer[usedspace + i] = data[data_idx + i] - end + copy!(context.buffer, usedspace + 1, data, data_idx + 1, blocklen(T) - usedspace) transform!(context) context.bytecount += blocklen(T) - usedspace @@ -25,9 +23,7 @@ function update!(context::T, data::U) where {T<:SHA_CTX, # There is less than a complete block left, but we need to save the leftovers into context.buffer: if len > data_idx - for i = 1:(len - data_idx) - context.buffer[usedspace + i] = data[data_idx + i] - end + copy!(context.buffer, usedspace + 1, data, data_idx + 1, len - data_idx) context.bytecount += len - data_idx end end diff --git a/src/constants.jl b/src/constants.jl index fb5d5ef..37c152e 100644 --- a/src/constants.jl +++ b/src/constants.jl @@ -123,7 +123,7 @@ const SHA3_ROTC = UInt64[ ] # Permutation indices for SHA3 rounds (+1'ed so as to work with julia's 1-based indexing) -const SHA3_PILN = UInt64[ +const SHA3_PILN = Int[ 11, 8, 12, 18, 19, 4, 6, 17, 9, 22, 25, 5, 16, 24, 20, 14, 13, 3, 21, 15, 23, 10, 7, 2 ] diff --git a/src/sha3.jl b/src/sha3.jl index 9a76ee7..89e15e7 100644 --- a/src/sha3.jl +++ b/src/sha3.jl @@ -4,43 +4,48 @@ function transform!(context::T) where {T<:SHA3_CTX} for idx in 1:div(blocklen(T),8) context.state[idx] = context.state[idx] ⊻ unsafe_load(pbuf, idx) end - bc = Vector{UInt64}(uninitialized, 5) + bc = context.bc + state = context.state # We always assume 24 rounds - for round in 0:23 + @inbounds for round in 0:23 # Theta function for i in 1:5 - bc[i] = context.state[i] ⊻ context.state[i + 5] ⊻ context.state[i + 10] ⊻ context.state[i + 15] ⊻ context.state[i + 20] + bc[i] = state[i] ⊻ state[i + 5] ⊻ state[i + 10] ⊻ state[i + 15] ⊻ state[i + 20] end - for i in 1:5 - temp = bc[mod1(i + 4, 5)] ⊻ L64(1, bc[mod1(i + 1, 5)]) - for j in 0:5:20 - context.state[i + j] = context.state[i + j] ⊻ temp + for i in 0:4 + temp = bc[rem(i + 4, 5) + 1] ⊻ L64(1, bc[rem(i + 1, 5) + 1]) + j = 0 + while j <= 20 + state[Int(i + j + 1)] = state[i + j + 1] ⊻ temp + j += 5 end end # Rho Pi - temp = context.state[2] + temp = state[2] for i in 1:24 j = SHA3_PILN[i] - bc[1] = context.state[j] - context.state[j] = L64(SHA3_ROTC[i], temp) + bc[1] = state[j] + state[j] = L64(SHA3_ROTC[i], temp) temp = bc[1] end # Chi - for j in 0:5:20 + j = 0 + while j <= 20 for i in 1:5 - bc[i] = context.state[i + j] + bc[i] = state[i + j] end - for i in 1:5 - context.state[j + i] = context.state[j + i] ⊻ (~bc[mod1(i + 1, 5)] & bc[mod1(i + 2, 5)]) + for i in 0:4 + state[j + i + 1] = state[j + i + 1] ⊻ (~bc[rem(i + 1, 5) + 1] & bc[rem(i + 2, 5) + 1]) end + j += 5 end # Iota - context.state[1] = context.state[1] ⊻ SHA3_ROUND_CONSTS[round+1] + state[1] = state[1] ⊻ SHA3_ROUND_CONSTS[round+1] end return context.state diff --git a/src/types.jl b/src/types.jl index a9a7292..b52023b 100644 --- a/src/types.jl +++ b/src/types.jl @@ -51,21 +51,25 @@ mutable struct SHA3_224_CTX <: SHA3_CTX state::Array{UInt64,1} bytecount::UInt128 buffer::Array{UInt8,1} + bc::Array{UInt64,1} end mutable struct SHA3_256_CTX <: SHA3_CTX state::Array{UInt64,1} bytecount::UInt128 buffer::Array{UInt8,1} + bc::Array{UInt64,1} end mutable struct SHA3_384_CTX <: SHA3_CTX state::Array{UInt64,1} bytecount::UInt128 buffer::Array{UInt8,1} + bc::Array{UInt64,1} end mutable struct SHA3_512_CTX <: SHA3_CTX state::Array{UInt64,1} bytecount::UInt128 buffer::Array{UInt8,1} + bc::Array{UInt64,1} end # Define constants via functions so as not to bloat context objects. Yay dispatch! @@ -111,10 +115,10 @@ SHA2_256_CTX() = SHA2_256_CTX(copy(SHA2_256_initial_hash_value), 0, zeros(UInt8, SHA2_384_CTX() = SHA2_384_CTX(copy(SHA2_384_initial_hash_value), 0, zeros(UInt8, blocklen(SHA2_384_CTX))) SHA2_512_CTX() = SHA2_512_CTX(copy(SHA2_512_initial_hash_value), 0, zeros(UInt8, blocklen(SHA2_512_CTX))) -SHA3_224_CTX() = SHA3_224_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_224_CTX))) -SHA3_256_CTX() = SHA3_256_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_256_CTX))) -SHA3_384_CTX() = SHA3_384_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_384_CTX))) -SHA3_512_CTX() = SHA3_512_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_512_CTX))) +SHA3_224_CTX() = SHA3_224_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_224_CTX)), Vector{UInt64}(5)) +SHA3_256_CTX() = SHA3_256_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_256_CTX)), Vector{UInt64}(5)) +SHA3_384_CTX() = SHA3_384_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_384_CTX)), Vector{UInt64}(5)) +SHA3_512_CTX() = SHA3_512_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_512_CTX)), Vector{UInt64}(5)) # Nickname'd outer constructor methods for SHA2 const SHA224_CTX = SHA2_224_CTX @@ -129,7 +133,7 @@ SHA1_CTX() = SHA1_CTX(copy(SHA1_initial_hash_value), 0, zeros(UInt8, blocklen(SH # Copy functions copy(ctx::T) where {T<:SHA1_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer), copy(ctx.W)) copy(ctx::T) where {T<:SHA2_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer)) -copy(ctx::T) where {T<:SHA3_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer)) +copy(ctx::T) where {T<:SHA3_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer), Array{UInt64}(5)) # Make printing these types a little friendlier diff --git a/test/perf.jl b/test/perf.jl index c130288..a366a6f 100644 --- a/test/perf.jl +++ b/test/perf.jl @@ -12,7 +12,7 @@ function do_tests(filepath) print("read: ") @time begin const fh = open(filepath, "r") - const bytes = readbytes(fh) + const bytes = read(fh) end gc()