[aes] Improve performance slightly (#1135)

Authored by: sulyi
This commit is contained in:
Ákos Sülyi 2021-10-02 20:50:39 +02:00 committed by GitHub
parent 9359f3d4f0
commit ff1dec819a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -178,7 +178,7 @@ def aes_encrypt(data, expanded_key):
data = sub_bytes(data)
data = shift_rows(data)
if i != rounds:
data = mix_columns(data)
data = list(iter_mix_columns(data, MIX_COLUMN_MATRIX))
data = xor(data, expanded_key[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES])
return data
@ -197,7 +197,7 @@ def aes_decrypt(data, expanded_key):
for i in range(rounds, 0, -1):
data = xor(data, expanded_key[i * BLOCK_SIZE_BYTES: (i + 1) * BLOCK_SIZE_BYTES])
if i != rounds:
data = mix_columns_inv(data)
data = list(iter_mix_columns(data, MIX_COLUMN_MATRIX_INV))
data = shift_rows_inv(data)
data = sub_bytes_inv(data)
data = xor(data, expanded_key[:BLOCK_SIZE_BYTES])
@ -375,49 +375,23 @@ def xor(data1, data2):
return [x ^ y for x, y in zip(data1, data2)]
def rijndael_mul(a, b):
if a == 0 or b == 0:
return 0
return RIJNDAEL_EXP_TABLE[(RIJNDAEL_LOG_TABLE[a] + RIJNDAEL_LOG_TABLE[b]) % 0xFF]
def mix_column(data, matrix):
data_mixed = []
for row in range(4):
mixed = 0
for column in range(4):
# xor is (+) and (-)
mixed ^= rijndael_mul(data[column], matrix[row][column])
data_mixed.append(mixed)
return data_mixed
def mix_columns(data, matrix=MIX_COLUMN_MATRIX):
data_mixed = []
for i in range(4):
column = data[i * 4: (i + 1) * 4]
data_mixed += mix_column(column, matrix)
return data_mixed
def mix_columns_inv(data):
return mix_columns(data, MIX_COLUMN_MATRIX_INV)
def iter_mix_columns(data, matrix):
for i in (0, 4, 8, 12):
for row in matrix:
mixed = 0
for j in range(4):
# xor is (+) and (-)
mixed ^= (0 if data[i:i + 4][j] == 0 or row[j] == 0 else
RIJNDAEL_EXP_TABLE[(RIJNDAEL_LOG_TABLE[data[i + j]] + RIJNDAEL_LOG_TABLE[row[j]]) % 0xFF])
yield mixed
def shift_rows(data):
data_shifted = []
for column in range(4):
for row in range(4):
data_shifted.append(data[((column + row) & 0b11) * 4 + row])
return data_shifted
return [data[((column + row) & 0b11) * 4 + row] for column in range(4) for row in range(4)]
def shift_rows_inv(data):
data_shifted = []
for column in range(4):
for row in range(4):
data_shifted.append(data[((column - row) & 0b11) * 4 + row])
return data_shifted
return [data[((column - row) & 0b11) * 4 + row] for column in range(4) for row in range(4)]
def shift_block(data):