diff --git a/src/commonMain/kotlin/net/sergeych/crypto2/EncryptedKVStorage.kt b/src/commonMain/kotlin/net/sergeych/crypto2/EncryptedKVStorage.kt index 5c461c5..bf0b179 100644 --- a/src/commonMain/kotlin/net/sergeych/crypto2/EncryptedKVStorage.kt +++ b/src/commonMain/kotlin/net/sergeych/crypto2/EncryptedKVStorage.kt @@ -15,11 +15,19 @@ import kotlin.random.nextUBytes * * Keys are stored encrypted and used hashed so it is not possible to * retrieve them without knowing the encryption key. + * + * @param plainStore where to store encrypted data + * @param encryptionKey key to decrypt existing/encrypt new data. Can cause [DecryptionFailedException] + * if the key is wrong and the storage is already initialized with a new key and same [prefix] + * @param prefix prefix for keys to distinguish from other data in [plainStore] + * @param removeExisting if true, removes all existing data in [plainStore] if the [encryptionKey] can't + * decrypt existing encrypted data */ class EncryptedKVStorage( private val plainStore: KVStorage, private var encryptionKey: SymmetricKey, - private val prefix: String = "EKVS_" + private val prefix: String = "EKVS_", + removeExisting: Boolean ) : KVStorage { private val op = ProtectedOp() @@ -29,10 +37,21 @@ class EncryptedKVStorage( init { var encryptedSeed by plainStore.optStored("$prefix#seed") - seed = encryptedSeed?.let { encryptionKey.decrypt(it) } - ?: Random.nextUBytes(32).also { - encryptedSeed = encryptionKey.encrypt(it) - } + seed = try { + encryptedSeed?.let { encryptionKey.decrypt(it) } + ?: Random.nextUBytes(32).also { + encryptedSeed = encryptionKey.encrypt(it) + } + } catch (x: DecryptionFailedException) { + if (removeExisting) { + plainStore.keys.filter { it.startsWith(prefix) }.forEach { + plainStore.delete(it) + } + Random.nextUBytes(32).also { + encryptedSeed = encryptionKey.encrypt(it) + } + } else throw x + } } private fun mkkey(key: String): String = diff --git a/src/commonTest/kotlin/StorageTest.kt b/src/commonTest/kotlin/StorageTest.kt index ab8a5ed..929fca4 100644 --- a/src/commonTest/kotlin/StorageTest.kt +++ b/src/commonTest/kotlin/StorageTest.kt @@ -1,11 +1,13 @@ import kotlinx.coroutines.test.runTest import net.sergeych.bintools.* import net.sergeych.bipack.decodeFromBipack +import net.sergeych.crypto2.DecryptionFailedException import net.sergeych.crypto2.EncryptedKVStorage import net.sergeych.crypto2.SymmetricKey import net.sergeych.crypto2.initCrypto import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import kotlin.test.assertNull class StorageTest { @@ -15,7 +17,7 @@ class StorageTest { initCrypto() val plain = MemoryKVStorage() val key = SymmetricKey.new() - val storage = EncryptedKVStorage(plain, key) + val storage = EncryptedKVStorage(plain, key, removeExisting = false) var hello by storage.optStored() assertNull(hello) @@ -42,8 +44,9 @@ class StorageTest { assertEquals("bar", bar) assertEquals("bazz", bazz) } + fun setup(s: KVStorage, k: SymmetricKey): EncryptedKVStorage { - val x = EncryptedKVStorage(s, k) + val x = EncryptedKVStorage(s, k, removeExisting = false) var foo by x.stored("1") var bar by x.stored("2") var bazz by x.stored("3") @@ -52,6 +55,7 @@ class StorageTest { bazz = "bazz" return x } + val k1 = SymmetricKey.new() val k2 = SymmetricKey.new() val plain = MemoryKVStorage() @@ -62,6 +66,18 @@ class StorageTest { // val s2 = EncryptedKVStorage(plain, k2) // test(s2) } + + @Test + fun testDeleteExisting() = runTest { + initCrypto() + val plain = MemoryKVStorage() + val c1 = EncryptedKVStorage(plain, SymmetricKey.new(), removeExisting = false) // 1 + c1.write("hello", "world") + assertFailsWith { + val c2 = EncryptedKVStorage(plain, SymmetricKey.new(), removeExisting = false) // 2 + } + val c2 = EncryptedKVStorage(plain, SymmetricKey.new(), removeExisting = true) // 2 + } } fun KVStorage.dump() {