Added key expansion

This commit is contained in:
Ugljesa Jovanovic 2019-09-17 23:29:44 +02:00 committed by Ugljesa Jovanovic
parent d72d55ef71
commit 90fd7adcc4
No known key found for this signature in database
GPG Key ID: 33A5F353387711A5
2 changed files with 179 additions and 38 deletions

View File

@ -4,7 +4,7 @@ package com.ionspin.kotlin.crypto.symmetric
* Created by Ugljesa Jovanovic (jovanovic.ugljesa@gmail.com) on 07/Sep/2019 * Created by Ugljesa Jovanovic (jovanovic.ugljesa@gmail.com) on 07/Sep/2019
*/ */
@ExperimentalUnsignedTypes @ExperimentalUnsignedTypes
class Aes { class Aes(val aesKey: AesKey) {
companion object { companion object {
val sBox: UByteArray = val sBox: UByteArray =
ubyteArrayOf( ubyteArrayOf(
@ -50,23 +50,10 @@ class Aes {
// @formatter:on // @formatter:on
) )
val rcon: UByteArray = ubyteArrayOf(0x8DU, 0x01U, 0x02U, 0x04U, 0x08U, 0x10U, 0x20U, 0x40U, 0x80U, 0x1BU, 0x36U)
} }
sealed class AesKey(val key: String, val keyLength: Int) {
class Aes128Key(key: String) : AesKey(key, 128)
class Aes192Key(key: String) : AesKey(key, 192)
class Aes256Key(key: String) : AesKey(key, 256)
init {
checkKeyLength(key, keyLength)
}
fun checkKeyLength(key: String, expectedLength: Int) {
if ((key.length / 2) != expectedLength) {
throw RuntimeException("Invalid key length")
}
}
}
val state: UByteArray = UByteArray(16) { 0U } val state: UByteArray = UByteArray(16) { 0U }
@ -74,17 +61,23 @@ class Aes {
Array<UByte>(4) { 0U } Array<UByte>(4) { 0U }
}.toTypedArray() }.toTypedArray()
val expandedKey: Array<Array<UByte>> = expandKey()
fun subBytes() { fun subBytes() {
stateMatrix.forEachIndexed { indexRow, row -> stateMatrix.forEachIndexed { indexRow, row ->
row.forEachIndexed { indexColumn, element -> row.forEachIndexed { indexColumn, element ->
val firstDigit = (element / 16U).toInt() stateMatrix[indexRow][indexColumn] = getSBoxValue(element)
val secondDigit = (element % 16U).toInt()
val substitutionValue = sBox[firstDigit * 16 + secondDigit]
stateMatrix[indexRow][indexColumn] = substitutionValue
} }
} }
} }
fun getSBoxValue(element: UByte): UByte {
val firstDigit = (element / 16U).toInt()
val secondDigit = (element % 16U).toInt()
return sBox[firstDigit * 16 + secondDigit]
}
fun shiftRows() { fun shiftRows() {
stateMatrix[0] = arrayOf(stateMatrix[0][0], stateMatrix[0][1], stateMatrix[0][2], stateMatrix[0][3]) stateMatrix[0] = arrayOf(stateMatrix[0][0], stateMatrix[0][1], stateMatrix[0][2], stateMatrix[0][3])
stateMatrix[1] = arrayOf(stateMatrix[1][1], stateMatrix[1][2], stateMatrix[1][3], stateMatrix[1][0]) stateMatrix[1] = arrayOf(stateMatrix[1][1], stateMatrix[1][2], stateMatrix[1][3], stateMatrix[1][0])
@ -93,29 +86,35 @@ class Aes {
} }
fun mixColumns() { fun mixColumns() {
val stateMixed : Array<Array<UByte>> = (0 until 4).map { val stateMixed: Array<Array<UByte>> = (0 until 4).map {
Array<UByte>(4) { 0U } Array<UByte>(4) { 0U }
}.toTypedArray() }.toTypedArray()
for (c in 0 .. 3) { for (c in 0..3) {
stateMixed[0][c] = (2U gfm stateMatrix[0][c]) xor galoisFieldMultiply(3U, stateMatrix[1][c]) xor stateMatrix[2][c] xor stateMatrix[3][c] stateMixed[0][c] = (2U gfm stateMatrix[0][c]) xor galoisFieldMultiply(
stateMixed[1][c] = stateMatrix[0][c] xor (2U gfm stateMatrix[1][c]) xor (3U gfm stateMatrix[2][c]) xor stateMatrix[3][c] 3U,
stateMixed[2][c] = stateMatrix[0][c] xor stateMatrix[1][c] xor (2U gfm stateMatrix[2][c]) xor (3U gfm stateMatrix[3][c]) stateMatrix[1][c]
stateMixed[3][c] = 3U gfm stateMatrix[0][c] xor stateMatrix[1][c] xor stateMatrix[2][c] xor (2U gfm stateMatrix[3][c]) ) xor stateMatrix[2][c] xor stateMatrix[3][c]
stateMixed[1][c] =
stateMatrix[0][c] xor (2U gfm stateMatrix[1][c]) xor (3U gfm stateMatrix[2][c]) xor stateMatrix[3][c]
stateMixed[2][c] =
stateMatrix[0][c] xor stateMatrix[1][c] xor (2U gfm stateMatrix[2][c]) xor (3U gfm stateMatrix[3][c])
stateMixed[3][c] =
3U gfm stateMatrix[0][c] xor stateMatrix[1][c] xor stateMatrix[2][c] xor (2U gfm stateMatrix[3][c])
} }
stateMixed.copyInto(stateMatrix) stateMixed.copyInto(stateMatrix)
} }
fun galoisFieldAdd(first : UByte, second : UByte) : UByte { fun galoisFieldAdd(first: UByte, second: UByte): UByte {
return first xor second return first xor second
} }
fun galoisFieldMultiply(first : UByte, second : UByte) : UByte { fun galoisFieldMultiply(first: UByte, second: UByte): UByte {
var result : UInt = 0U var result: UInt = 0U
var firstInt = first.toUInt() var firstInt = first.toUInt()
var secondInt = second.toUInt() var secondInt = second.toUInt()
var carry : UInt = 0U var carry: UInt = 0U
for (i in 0 .. 7) { for (i in 0..7) {
if (secondInt and 0x01U == 1U) { if (secondInt and 0x01U == 1U) {
result = result xor firstInt result = result xor firstInt
} }
@ -130,11 +129,73 @@ class Aes {
return result.toUByte() return result.toUByte()
} }
infix fun UInt.gfm(second : UByte) : UByte { fun addRoundKey() {
}
infix fun UInt.gfm(second: UByte): UByte {
return galoisFieldMultiply(this.toUByte(), second) return galoisFieldMultiply(this.toUByte(), second)
} }
fun expandKey(key: AesKey) { fun expandKey(): Array<Array<UByte>> {
val expandedKey = (0 until 4 * (aesKey.numberOfRounds + 1)).map {
Array<UByte>(4) { 0U }
}.toTypedArray()
// First round
for (i in 0 until aesKey.numberOf32BitWords) {
expandedKey[i][0] = aesKey.keyArray[i * 4 + 0]
expandedKey[i][1] = aesKey.keyArray[i * 4 + 1]
expandedKey[i][2] = aesKey.keyArray[i * 4 + 2]
expandedKey[i][3] = aesKey.keyArray[i * 4 + 3]
}
for (i in aesKey.numberOf32BitWords until 4 * (aesKey.numberOfRounds + 1)) {
val temp = expandedKey[i - 1].copyOf()
if (i % aesKey.numberOf32BitWords == 0) {
//RotWord
val tempTemp = temp[0]
temp[0] = temp[1]
temp[1] = temp[2]
temp[2] = temp[3]
temp[3] = tempTemp
//SubWord
temp[0] = getSBoxValue(temp[0])
temp[1] = getSBoxValue(temp[1])
temp[2] = getSBoxValue(temp[2])
temp[3] = getSBoxValue(temp[3])
temp[0] = temp[0] xor rcon[i / aesKey.numberOf32BitWords]
} else if (aesKey is AesKey.Aes256Key && i % aesKey.numberOf32BitWords == 4) {
temp[0] = getSBoxValue(temp[0])
temp[1] = getSBoxValue(temp[1])
temp[2] = getSBoxValue(temp[2])
temp[3] = getSBoxValue(temp[3])
}
expandedKey[i] = expandedKey[i - aesKey.numberOf32BitWords].mapIndexed { index, it ->
it xor temp[index]
}.toTypedArray()
}
return expandedKey
}
}
sealed class AesKey(val key: String, val keyLength: Int, val numberOfRounds: Int) {
val keyArray: Array<UByte> = key.chunked(2).map { it.toUByte(16) }.toTypedArray()
val numberOf32BitWords = keyLength / 32
class Aes128Key(key: String) : AesKey(key, 128, 10)
class Aes192Key(key: String) : AesKey(key, 192, 12)
class Aes256Key(key: String) : AesKey(key, 256, 14)
init {
checkKeyLength(key, keyLength)
}
fun checkKeyLength(key: String, expectedLength: Int) {
if ((key.length / 2) != expectedLength / 8) {
throw RuntimeException("Invalid key length")
}
} }
} }

View File

@ -17,7 +17,7 @@ class AesTest {
ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray(), ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray(),
ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray() ubyteArrayOf(0U, 0U, 0U, 0U).toTypedArray()
) )
val aes = Aes() val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c"))
fakeState.copyInto(aes.stateMatrix) fakeState.copyInto(aes.stateMatrix)
aes.subBytes() aes.subBytes()
aes.stateMatrix.forEach { aes.stateMatrix.forEach {
@ -42,7 +42,7 @@ class AesTest {
ubyteArrayOf(2U, 3U, 0U, 1U).toTypedArray(), ubyteArrayOf(2U, 3U, 0U, 1U).toTypedArray(),
ubyteArrayOf(3U, 0U, 1U, 2U).toTypedArray() ubyteArrayOf(3U, 0U, 1U, 2U).toTypedArray()
) )
val aes = Aes() val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c"))
fakeState.copyInto(aes.stateMatrix) fakeState.copyInto(aes.stateMatrix)
aes.shiftRows() aes.shiftRows()
aes.stateMatrix.forEach { aes.stateMatrix.forEach {
@ -59,7 +59,7 @@ class AesTest {
assertTrue { assertTrue {
val a = 0x57U val a = 0x57U
val b = 0x83U val b = 0x83U
val aes = Aes() val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c"))
val c = aes.galoisFieldMultiply(a.toUByte(), b.toUByte()) val c = aes.galoisFieldMultiply(a.toUByte(), b.toUByte())
c == 0xC1U.toUByte() c == 0xC1U.toUByte()
} }
@ -67,7 +67,7 @@ class AesTest {
assertTrue { assertTrue {
val a = 0x57U val a = 0x57U
val b = 0x13U val b = 0x13U
val aes = Aes() val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c"))
val c = aes.galoisFieldMultiply(a.toUByte(), b.toUByte()) val c = aes.galoisFieldMultiply(a.toUByte(), b.toUByte())
c == 0xFEU.toUByte() c == 0xFEU.toUByte()
} }
@ -92,7 +92,7 @@ class AesTest {
ubyteArrayOf(0xbcU, 0x9dU, 0x01U, 0xc6U).toTypedArray() ubyteArrayOf(0xbcU, 0x9dU, 0x01U, 0xc6U).toTypedArray()
) )
val aes = Aes() val aes = Aes(AesKey.Aes128Key("2b7e151628aed2a6abf7158809cf4f3c"))
fakeState.copyInto(aes.stateMatrix) fakeState.copyInto(aes.stateMatrix)
aes.mixColumns() aes.mixColumns()
aes.stateMatrix.forEach { aes.stateMatrix.forEach {
@ -103,4 +103,84 @@ class AesTest {
} }
} }
@Test
fun testKeyExpansion() {
assertTrue {
val key = "2b7e151628aed2a6abf7158809cf4f3c"
val expectedExpandedKey = uintArrayOf(
// @formatter:off
0x2b7e1516U, 0x28aed2a6U, 0xabf71588U, 0x09cf4f3cU, 0xa0fafe17U, 0x88542cb1U,
0x23a33939U, 0x2a6c7605U, 0xf2c295f2U, 0x7a96b943U, 0x5935807aU, 0x7359f67fU,
0x3d80477dU, 0x4716fe3eU, 0x1e237e44U, 0x6d7a883bU, 0xef44a541U, 0xa8525b7fU,
0xb671253bU, 0xdb0bad00U, 0xd4d1c6f8U, 0x7c839d87U, 0xcaf2b8bcU, 0x11f915bcU,
0x6d88a37aU, 0x110b3efdU, 0xdbf98641U, 0xca0093fdU, 0x4e54f70eU, 0x5f5fc9f3U,
0x84a64fb2U, 0x4ea6dc4fU, 0xead27321U, 0xb58dbad2U, 0x312bf560U, 0x7f8d292fU,
0xac7766f3U, 0x19fadc21U, 0x28d12941U, 0x575c006eU, 0xd014f9a8U, 0xc9ee2589U,
0xe13f0cc8U, 0xb6630ca6U
// @formatter:on
).toTypedArray()
val aes = Aes(AesKey.Aes128Key(key))
val result = aes.expandedKey.map {
it.foldIndexed(0U) { index, acc, uByte ->
acc + (uByte.toUInt() shl (24 - index * 8))
}
}.toTypedArray()
expectedExpandedKey.contentEquals(result)
}
assertTrue {
val key = "8e73b0f7da0e6452c810f32b809079e562f8ead2522c6b7b"
val expectedExpandedKey = uintArrayOf(
// @formatter:off
0x8e73b0f7U, 0xda0e6452U, 0xc810f32bU, 0x809079e5U, 0x62f8ead2U, 0x522c6b7bU,
0xfe0c91f7U, 0x2402f5a5U, 0xec12068eU, 0x6c827f6bU, 0x0e7a95b9U, 0x5c56fec2U, 0x4db7b4bdU, 0x69b54118U,
0x85a74796U, 0xe92538fdU, 0xe75fad44U, 0xbb095386U, 0x485af057U, 0x21efb14fU, 0xa448f6d9U, 0x4d6dce24U,
0xaa326360U, 0x113b30e6U, 0xa25e7ed5U, 0x83b1cf9aU, 0x27f93943U, 0x6a94f767U, 0xc0a69407U, 0xd19da4e1U,
0xec1786ebU, 0x6fa64971U, 0x485f7032U, 0x22cb8755U, 0xe26d1352U, 0x33f0b7b3U, 0x40beeb28U, 0x2f18a259U,
0x6747d26bU, 0x458c553eU, 0xa7e1466cU, 0x9411f1dfU, 0x821f750aU, 0xad07d753U, 0xca400538U, 0x8fcc5006U,
0x282d166aU, 0xbc3ce7b5U, 0xe98ba06fU, 0x448c773cU, 0x8ecc7204U, 0x01002202U
// @formatter:on
).toTypedArray()
val aes = Aes(AesKey.Aes192Key(key))
val result = aes.expandedKey.map {
it.foldIndexed(0U) { index, acc, uByte ->
acc + (uByte.toUInt() shl (24 - index * 8))
}
}.toTypedArray()
expectedExpandedKey.contentEquals(result)
}
assertTrue {
val key = "603deb1015ca71be2b73aef0857d77811f352c073b6108d72d9810a30914dff4"
val expectedExpandedKey = uintArrayOf(
// @formatter:off
0x603deb10U, 0x15ca71beU, 0x2b73aef0U, 0x857d7781U, 0x1f352c07U, 0x3b6108d7U, 0x2d9810a3U, 0x0914dff4U,
0x9ba35411U, 0x8e6925afU, 0xa51a8b5fU, 0x2067fcdeU, 0xa8b09c1aU, 0x93d194cdU, 0xbe49846eU, 0xb75d5b9aU,
0xd59aecb8U, 0x5bf3c917U, 0xfee94248U, 0xde8ebe96U, 0xb5a9328aU, 0x2678a647U, 0x98312229U, 0x2f6c79b3U,
0x812c81adU, 0xdadf48baU, 0x24360af2U, 0xfab8b464U, 0x98c5bfc9U, 0xbebd198eU, 0x268c3ba7U, 0x09e04214U,
0x68007bacU, 0xb2df3316U, 0x96e939e4U, 0x6c518d80U, 0xc814e204U, 0x76a9fb8aU, 0x5025c02dU, 0x59c58239U,
0xde136967U, 0x6ccc5a71U, 0xfa256395U, 0x9674ee15U, 0x5886ca5dU, 0x2e2f31d7U, 0x7e0af1faU, 0x27cf73c3U,
0x749c47abU, 0x18501ddaU, 0xe2757e4fU, 0x7401905aU, 0xcafaaae3U, 0xe4d59b34U, 0x9adf6aceU, 0xbd10190dU,
0xfe4890d1U, 0xe6188d0bU, 0x046df344U, 0x706c631eU
// @formatter:on
).toTypedArray()
val aes = Aes(AesKey.Aes256Key(key))
val result = aes.expandedKey.map {
it.foldIndexed(0U) { index, acc, uByte ->
acc + (uByte.toUInt() shl (24 - index * 8))
}
}.toTypedArray()
expectedExpandedKey.contentEquals(result)
}
}
} }