Added key expansion

This commit is contained in:
Ugljesa Jovanovic 2019-09-17 23:29:44 +02:00 committed by Ugljesa Jovanovic
parent d72d55ef71
commit 90fd7adcc4
No known key found for this signature in database
GPG Key ID: 33A5F353387711A5
2 changed files with 179 additions and 38 deletions

View File

@ -4,7 +4,7 @@ package com.ionspin.kotlin.crypto.symmetric
* Created by Ugljesa Jovanovic (jovanovic.ugljesa@gmail.com) on 07/Sep/2019
*/
@ExperimentalUnsignedTypes
class Aes {
class Aes(val aesKey: AesKey) {
companion object {
val sBox: UByteArray =
ubyteArrayOf(
@ -50,23 +50,10 @@ class Aes {
// @formatter:on
)
val rcon: UByteArray = ubyteArrayOf(0x8DU, 0x01U, 0x02U, 0x04U, 0x08U, 0x10U, 0x20U, 0x40U, 0x80U, 0x1BU, 0x36U)
}
sealed class AesKey(val key: String, val keyLength: Int) {
class Aes128Key(key: String) : AesKey(key, 128)
class Aes192Key(key: String) : AesKey(key, 192)
class Aes256Key(key: String) : AesKey(key, 256)
init {
checkKeyLength(key, keyLength)
}
fun checkKeyLength(key: String, expectedLength: Int) {
if ((key.length / 2) != expectedLength) {
throw RuntimeException("Invalid key length")
}
}
}
val state: UByteArray = UByteArray(16) { 0U }
@ -74,15 +61,21 @@ class Aes {
Array<UByte>(4) { 0U }
}.toTypedArray()
val expandedKey: Array<Array<UByte>> = expandKey()
fun subBytes() {
stateMatrix.forEachIndexed { indexRow, row ->
row.forEachIndexed { indexColumn, element ->
stateMatrix[indexRow][indexColumn] = getSBoxValue(element)
}
}
}
fun getSBoxValue(element: UByte): UByte {
val firstDigit = (element / 16U).toInt()
val secondDigit = (element % 16U).toInt()
val substitutionValue = sBox[firstDigit * 16 + secondDigit]
stateMatrix[indexRow][indexColumn] = substitutionValue
}
}
return sBox[firstDigit * 16 + secondDigit]
}
fun shiftRows() {
@ -98,10 +91,16 @@ class Aes {
}.toTypedArray()
for (c in 0..3) {
stateMixed[0][c] = (2U gfm stateMatrix[0][c]) xor galoisFieldMultiply(3U, stateMatrix[1][c]) xor stateMatrix[2][c] xor stateMatrix[3][c]
stateMixed[1][c] = stateMatrix[0][c] xor (2U gfm stateMatrix[1][c]) xor (3U gfm stateMatrix[2][c]) xor stateMatrix[3][c]
stateMixed[2][c] = stateMatrix[0][c] xor stateMatrix[1][c] xor (2U gfm stateMatrix[2][c]) xor (3U gfm stateMatrix[3][c])
stateMixed[3][c] = 3U gfm stateMatrix[0][c] xor stateMatrix[1][c] xor stateMatrix[2][c] xor (2U gfm stateMatrix[3][c])
stateMixed[0][c] = (2U gfm stateMatrix[0][c]) xor galoisFieldMultiply(
3U,
stateMatrix[1][c]
) xor stateMatrix[2][c] xor stateMatrix[3][c]
stateMixed[1][c] =
stateMatrix[0][c] xor (2U gfm stateMatrix[1][c]) xor (3U gfm stateMatrix[2][c]) xor stateMatrix[3][c]
stateMixed[2][c] =
stateMatrix[0][c] xor stateMatrix[1][c] xor (2U gfm stateMatrix[2][c]) xor (3U gfm stateMatrix[3][c])
stateMixed[3][c] =
3U gfm stateMatrix[0][c] xor stateMatrix[1][c] xor stateMatrix[2][c] xor (2U gfm stateMatrix[3][c])
}
stateMixed.copyInto(stateMatrix)
}
@ -130,11 +129,73 @@ class Aes {
return result.toUByte()
}
fun addRoundKey() {
}
infix fun UInt.gfm(second: UByte): UByte {
return galoisFieldMultiply(this.toUByte(), second)
}
fun expandKey(key: AesKey) {
fun expandKey(): Array<Array<UByte>> {
val expandedKey = (0 until 4 * (aesKey.numberOfRounds + 1)).map {
Array<UByte>(4) { 0U }
}.toTypedArray()
// First round
for (i in 0 until aesKey.numberOf32BitWords) {
expandedKey[i][0] = aesKey.keyArray[i * 4 + 0]
expandedKey[i][1] = aesKey.keyArray[i * 4 + 1]
expandedKey[i][2] = aesKey.keyArray[i * 4 + 2]
expandedKey[i][3] = aesKey.keyArray[i * 4 + 3]
}
for (i in aesKey.numberOf32BitWords until 4 * (aesKey.numberOfRounds + 1)) {
val temp = expandedKey[i - 1].copyOf()
if (i % aesKey.numberOf32BitWords == 0) {
//RotWord
val tempTemp = temp[0]
temp[0] = temp[1]
temp[1] = temp[2]
temp[2] = temp[3]
temp[3] = tempTemp
//SubWord
temp[0] = getSBoxValue(temp[0])
temp[1] = getSBoxValue(temp[1])
temp[2] = getSBoxValue(temp[2])
temp[3] = getSBoxValue(temp[3])
temp[0] = temp[0] xor rcon[i / aesKey.numberOf32BitWords]
} else if (aesKey is AesKey.Aes256Key && i % aesKey.numberOf32BitWords == 4) {
temp[0] = getSBoxValue(temp[0])
temp[1] = getSBoxValue(temp[1])
temp[2] = getSBoxValue(temp[2])
temp[3] = getSBoxValue(temp[3])
}
expandedKey[i] = expandedKey[i - aesKey.numberOf32BitWords].mapIndexed { index, it ->
it xor temp[index]
}.toTypedArray()
}
return expandedKey
}
}
sealed class AesKey(val key: String, val keyLength: Int, val numberOfRounds: Int) {
val keyArray: Array<UByte> = key.chunked(2).map { it.toUByte(16) }.toTypedArray()
val numberOf32BitWords = keyLength / 32
class Aes128Key(key: String) : AesKey(key, 128, 10)
class Aes192Key(key: String) : AesKey(key, 192, 12)
class Aes256Key(key: String) : AesKey(key, 256, 14)
init {
checkKeyLength(key, keyLength)
}
fun checkKeyLength(key: String, expectedLength: Int) {
if ((key.length / 2) != expectedLength / 8) {
throw RuntimeException("Invalid key length")
}
}
}

View File

@ -17,7 +17,7 @@ class AesTest {
ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray(),
ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray()
)
val aes = Aes()
val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c"))
fakeState.copyInto(aes.stateMatrix)
aes.subBytes()
aes.stateMatrix.forEach {
@ -42,7 +42,7 @@ class AesTest {
ubyteArrayOf(2U, 3U, 0U, 1U).toTypedArray(),
ubyteArrayOf(3U, 0U, 1U, 2U).toTypedArray()
)
val aes = Aes()
val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c"))
fakeState.copyInto(aes.stateMatrix)
aes.shiftRows()
aes.stateMatrix.forEach {
@ -59,7 +59,7 @@ class AesTest {
assertTrue {
val a = 0x57U
val b = 0x83U
val aes = Aes()
val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c"))
val c = aes.galoisFieldMultiply(a.toUByte(), b.toUByte())
c == 0xC1U.toUByte()
}
@ -67,7 +67,7 @@ class AesTest {
assertTrue {
val a = 0x57U
val b = 0x13U
val aes = Aes()
val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c"))
val c = aes.galoisFieldMultiply(a.toUByte(), b.toUByte())
c == 0xFEU.toUByte()
}
@ -92,7 +92,7 @@ class AesTest {
ubyteArrayOf(0xbcU, 0x9dU, 0x01U, 0xc6U).toTypedArray()
)
val aes = Aes()
val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c"))
fakeState.copyInto(aes.stateMatrix)
aes.mixColumns()
aes.stateMatrix.forEach {
@ -103,4 +103,84 @@ class AesTest {
}
}
@Test
fun testKeyExpansion() {
assertTrue {
val key = "2b7e151628aed2a6abf7158809cf4f3c"
val expectedExpandedKey = uintArrayOf(
// @formatter:off
0x2b7e1516U, 0x28aed2a6U, 0xabf71588U, 0x09cf4f3cU, 0xa0fafe17U, 0x88542cb1U,
0x23a33939U, 0x2a6c7605U, 0xf2c295f2U, 0x7a96b943U, 0x5935807aU, 0x7359f67fU,
0x3d80477dU, 0x4716fe3eU, 0x1e237e44U, 0x6d7a883bU, 0xef44a541U, 0xa8525b7fU,
0xb671253bU, 0xdb0bad00U, 0xd4d1c6f8U, 0x7c839d87U, 0xcaf2b8bcU, 0x11f915bcU,
0x6d88a37aU, 0x110b3efdU, 0xdbf98641U, 0xca0093fdU, 0x4e54f70eU, 0x5f5fc9f3U,
0x84a64fb2U, 0x4ea6dc4fU, 0xead27321U, 0xb58dbad2U, 0x312bf560U, 0x7f8d292fU,
0xac7766f3U, 0x19fadc21U, 0x28d12941U, 0x575c006eU, 0xd014f9a8U, 0xc9ee2589U,
0xe13f0cc8U, 0xb6630ca6U
// @formatter:on
).toTypedArray()
val aes = Aes(AesKey.Aes128Key(key))
val result = aes.expandedKey.map {
it.foldIndexed(0U) { index, acc, uByte ->
acc + (uByte.toUInt() shl (24 - index * 8))
}
}.toTypedArray()
expectedExpandedKey.contentEquals(result)
}
assertTrue {
val key = "8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b"
val expectedExpandedKey = uintArrayOf(
// @formatter:off
0x8e73b0f7U, 0xda0e6452U, 0xc810f32bU, 0x809079e5U, 0x62f8ead2U, 0x522c6b7bU,
0xfe0c91f7U, 0x2402f5a5U, 0xec12068eU, 0x6c827f6bU, 0x0e7a95b9U, 0x5c56fec2U, 0x4db7b4bdU, 0x69b54118U,
0x85a74796U, 0xe92538fdU, 0xe75fad44U, 0xbb095386U, 0x485af057U, 0x21efb14fU, 0xa448f6d9U, 0x4d6dce24U,
0xaa326360U, 0x113b30e6U, 0xa25e7ed5U, 0x83b1cf9aU, 0x27f93943U, 0x6a94f767U, 0xc0a69407U, 0xd19da4e1U,
0xec1786ebU, 0x6fa64971U, 0x485f7032U, 0x22cb8755U, 0xe26d1352U, 0x33f0b7b3U, 0x40beeb28U, 0x2f18a259U,
0x6747d26bU, 0x458c553eU, 0xa7e1466cU, 0x9411f1dfU, 0x821f750aU, 0xad07d753U, 0xca400538U, 0x8fcc5006U,
0x282d166aU, 0xbc3ce7b5U, 0xe98ba06fU, 0x448c773cU, 0x8ecc7204U, 0x01002202U
// @formatter:on
).toTypedArray()
val aes = Aes(AesKey.Aes192Key(key))
val result = aes.expandedKey.map {
it.foldIndexed(0U) { index, acc, uByte ->
acc + (uByte.toUInt() shl (24 - index * 8))
}
}.toTypedArray()
expectedExpandedKey.contentEquals(result)
}
assertTrue {
val key = "603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4"
val expectedExpandedKey = uintArrayOf(
// @formatter:off
0x603deb10U, 0x15ca71beU, 0x2b73aef0U, 0x857d7781U, 0x1f352c07U, 0x3b6108d7U, 0x2d9810a3U, 0x0914dff4U,
0x9ba35411U, 0x8e6925afU, 0xa51a8b5fU, 0x2067fcdeU, 0xa8b09c1aU, 0x93d194cdU, 0xbe49846eU, 0xb75d5b9aU,
0xd59aecb8U, 0x5bf3c917U, 0xfee94248U, 0xde8ebe96U, 0xb5a9328aU, 0x2678a647U, 0x98312229U, 0x2f6c79b3U,
0x812c81adU, 0xdadf48baU, 0x24360af2U, 0xfab8b464U, 0x98c5bfc9U, 0xbebd198eU, 0x268c3ba7U, 0x09e04214U,
0x68007bacU, 0xb2df3316U, 0x96e939e4U, 0x6c518d80U, 0xc814e204U, 0x76a9fb8aU, 0x5025c02dU, 0x59c58239U,
0xde136967U, 0x6ccc5a71U, 0xfa256395U, 0x9674ee15U, 0x5886ca5dU, 0x2e2f31d7U, 0x7e0af1faU, 0x27cf73c3U,
0x749c47abU, 0x18501ddaU, 0xe2757e4fU, 0x7401905aU, 0xcafaaae3U, 0xe4d59b34U, 0x9adf6aceU, 0xbd10190dU,
0xfe4890d1U, 0xe6188d0bU, 0x046df344U, 0x706c631eU
// @formatter:on
).toTypedArray()
val aes = Aes(AesKey.Aes256Key(key))
val result = aes.expandedKey.map {
it.foldIndexed(0U) { index, acc, uByte ->
acc + (uByte.toUInt() shl (24 - index * 8))
}
}.toTypedArray()
expectedExpandedKey.contentEquals(result)
}
}
}