Add missing imports and use usePinned to guarantee proper unpinning

This commit is contained in:
Johannes Leupold 2024-08-12 16:15:46 +02:00
parent 6a66654436
commit e785bde585

View File

@ -1,19 +1,40 @@
package com.ionspin.kotlin.crypto.ristretto255 package com.ionspin.kotlin.crypto.ristretto255
import com.ionspin.kotlin.crypto.util.toPtr import com.ionspin.kotlin.crypto.util.toPtr
import kotlinx.cinterop.pin import kotlinx.cinterop.usePinned
import libsodium.crypto_core_ristretto255_add
import libsodium.crypto_core_ristretto255_from_hash
import libsodium.crypto_core_ristretto255_is_valid_point
import libsodium.crypto_core_ristretto255_random
import libsodium.crypto_core_ristretto255_scalar_add
import libsodium.crypto_core_ristretto255_scalar_complement
import libsodium.crypto_core_ristretto255_scalar_invert
import libsodium.crypto_core_ristretto255_scalar_mul
import libsodium.crypto_core_ristretto255_scalar_negate
import libsodium.crypto_core_ristretto255_scalar_random
import libsodium.crypto_core_ristretto255_scalar_reduce
import libsodium.crypto_core_ristretto255_scalar_sub
import libsodium.crypto_core_ristretto255_sub
import libsodium.crypto_scalarmult_ristretto255
import libsodium.crypto_scalarmult_ristretto255_base
actual abstract class Ristretto255LowLevel actual constructor() { actual abstract class Ristretto255LowLevel actual constructor() {
actual fun isValidPoint(encoded: UByteArray): Boolean { actual fun isValidPoint(encoded: UByteArray): Boolean {
return crypto_core_ristretto255_is_valid_point(encoded.pin().toPtr()) == 1 return encoded.usePinned { crypto_core_ristretto255_is_valid_point(it.toPtr()) == 1 }
} }
actual fun addPoints(p: UByteArray, q: UByteArray): UByteArray { actual fun addPoints(p: UByteArray, q: UByteArray): UByteArray {
val result = UByteArray(crypto_core_ristretto255_BYTES) val result = UByteArray(crypto_core_ristretto255_BYTES)
crypto_core_ristretto255_add(result.pin().toPtr(), p.pin().toPtr(), q.pin().toPtr()) result.usePinned { resultPinned ->
.ensureLibsodiumSuccess() p.usePinned { pPinned ->
q.usePinned { qPinned ->
crypto_core_ristretto255_add(resultPinned.toPtr(), pPinned.toPtr(), qPinned.toPtr())
.ensureLibsodiumSuccess()
}
}
}
return result return result
} }
@ -21,8 +42,14 @@ actual abstract class Ristretto255LowLevel actual constructor() {
actual fun subtractPoints(p: UByteArray, q: UByteArray): UByteArray { actual fun subtractPoints(p: UByteArray, q: UByteArray): UByteArray {
val result = UByteArray(crypto_core_ristretto255_BYTES) val result = UByteArray(crypto_core_ristretto255_BYTES)
crypto_core_ristretto255_sub(result.pin().toPtr(), p.pin().toPtr(), q.pin().toPtr()) result.usePinned { resultPinned ->
.ensureLibsodiumSuccess() p.usePinned { pPinned ->
q.usePinned { qPinned ->
crypto_core_ristretto255_sub(resultPinned.toPtr(), pPinned.toPtr(), qPinned.toPtr())
.ensureLibsodiumSuccess()
}
}
}
return result return result
} }
@ -30,23 +57,32 @@ actual abstract class Ristretto255LowLevel actual constructor() {
actual fun encodedPointFromHash(hash: UByteArray): UByteArray { actual fun encodedPointFromHash(hash: UByteArray): UByteArray {
val result = UByteArray(crypto_core_ristretto255_BYTES) val result = UByteArray(crypto_core_ristretto255_BYTES)
crypto_core_ristretto255_from_hash(result.pin().toPtr(), hash.pin().toPtr()) result.usePinned { resultPinned ->
hash.usePinned { hashPinned ->
crypto_core_ristretto255_from_hash(resultPinned.toPtr(), hashPinned.toPtr())
}
}
return result return result
} }
actual fun randomEncodedPoint(): UByteArray = UByteArray(crypto_core_ristretto255_BYTES).also { actual fun randomEncodedPoint(): UByteArray = UByteArray(crypto_core_ristretto255_BYTES).apply {
crypto_core_ristretto255_random(it.pin().toPtr()) usePinned { crypto_core_ristretto255_random(it.toPtr()) }
} }
actual fun randomEncodedScalar(): UByteArray = UByteArray(crypto_core_ristretto255_SCALARBYTES).also { actual fun randomEncodedScalar(): UByteArray = UByteArray(crypto_core_ristretto255_SCALARBYTES).apply {
crypto_core_ristretto255_scalar_random(it.pin().toPtr()) usePinned { crypto_core_ristretto255_scalar_random(it.toPtr()) }
} }
actual fun invert(scalar: UByteArray): UByteArray { actual fun invert(scalar: UByteArray): UByteArray {
val result = UByteArray(crypto_core_ristretto255_SCALARBYTES) val result = UByteArray(crypto_core_ristretto255_SCALARBYTES)
crypto_core_ristretto255_scalar_invert(result.pin().toPtr(), scalar.pin().toPtr()).ensureLibsodiumSuccess() result.usePinned { resultPinned ->
scalar.usePinned { scalarPinned ->
crypto_core_ristretto255_scalar_invert(resultPinned.toPtr(), scalarPinned.toPtr()).ensureLibsodiumSuccess()
}
}
return result return result
} }
@ -54,7 +90,12 @@ actual abstract class Ristretto255LowLevel actual constructor() {
actual fun negate(scalar: UByteArray): UByteArray { actual fun negate(scalar: UByteArray): UByteArray {
val result = UByteArray(crypto_core_ristretto255_SCALARBYTES) val result = UByteArray(crypto_core_ristretto255_SCALARBYTES)
crypto_core_ristretto255_scalar_negate(result.pin().toPtr(), scalar.pin().toPtr()) result.usePinned { resultPinned ->
scalar.usePinned { scalarPinned ->
crypto_core_ristretto255_scalar_negate(resultPinned.toPtr(), scalarPinned.toPtr())
}
}
return result return result
} }
@ -62,7 +103,11 @@ actual abstract class Ristretto255LowLevel actual constructor() {
actual fun complement(scalar: UByteArray): UByteArray { actual fun complement(scalar: UByteArray): UByteArray {
val result = UByteArray(crypto_core_ristretto255_SCALARBYTES) val result = UByteArray(crypto_core_ristretto255_SCALARBYTES)
crypto_core_ristretto255_scalar_complement(result.pin().toPtr(), scalar.pin().toPtr()) result.usePinned { resultPinned ->
scalar.usePinned { scalarPinned ->
crypto_core_ristretto255_scalar_complement(resultPinned.toPtr(), scalarPinned.toPtr())
}
}
return result return result
} }
@ -70,7 +115,13 @@ actual abstract class Ristretto255LowLevel actual constructor() {
actual fun addScalars(x: UByteArray, y: UByteArray): UByteArray { actual fun addScalars(x: UByteArray, y: UByteArray): UByteArray {
val result = UByteArray(crypto_core_ristretto255_SCALARBYTES) val result = UByteArray(crypto_core_ristretto255_SCALARBYTES)
crypto_core_ristretto255_scalar_add(result.pin().toPtr(), x.pin().toPtr(), y.pin().toPtr()) result.usePinned { resultPinned ->
x.usePinned { xPinned ->
y.usePinned { yPinned ->
crypto_core_ristretto255_scalar_add(resultPinned.toPtr(), xPinned.toPtr(), yPinned.toPtr())
}
}
}
return result return result
} }
@ -78,7 +129,13 @@ actual abstract class Ristretto255LowLevel actual constructor() {
actual fun subtractScalars(x: UByteArray, y: UByteArray): UByteArray { actual fun subtractScalars(x: UByteArray, y: UByteArray): UByteArray {
val result = UByteArray(crypto_core_ristretto255_SCALARBYTES) val result = UByteArray(crypto_core_ristretto255_SCALARBYTES)
crypto_core_ristretto255_scalar_sub(result.pin().toPtr(), x.pin().toPtr(), y.pin().toPtr()) result.usePinned { resultPinned ->
x.usePinned { xPinned ->
y.usePinned { yPinned ->
crypto_core_ristretto255_scalar_sub(resultPinned.toPtr(), xPinned.toPtr(), yPinned.toPtr())
}
}
}
return result return result
} }
@ -86,7 +143,13 @@ actual abstract class Ristretto255LowLevel actual constructor() {
actual fun multiplyScalars(x: UByteArray, y: UByteArray): UByteArray { actual fun multiplyScalars(x: UByteArray, y: UByteArray): UByteArray {
val result = UByteArray(crypto_core_ristretto255_SCALARBYTES) val result = UByteArray(crypto_core_ristretto255_SCALARBYTES)
crypto_core_ristretto255_scalar_mul(result.pin().toPtr(), x.pin().toPtr(), y.pin().toPtr()) result.usePinned { resultPinned ->
x.usePinned { xPinned ->
y.usePinned { yPinned ->
crypto_core_ristretto255_scalar_mul(resultPinned.toPtr(), xPinned.toPtr(), yPinned.toPtr())
}
}
}
return result return result
} }
@ -94,7 +157,11 @@ actual abstract class Ristretto255LowLevel actual constructor() {
actual fun reduce(scalar: UByteArray): UByteArray { actual fun reduce(scalar: UByteArray): UByteArray {
val result = UByteArray(crypto_core_ristretto255_SCALARBYTES) val result = UByteArray(crypto_core_ristretto255_SCALARBYTES)
crypto_core_ristretto255_scalar_reduce(result.pin().toPtr(), scalar.pin().toPtr()) result.usePinned { resultPinned ->
scalar.usePinned { scalarPinned ->
crypto_core_ristretto255_scalar_reduce(resultPinned.toPtr(), scalarPinned.toPtr())
}
}
return result return result
} }
@ -102,7 +169,14 @@ actual abstract class Ristretto255LowLevel actual constructor() {
actual fun scalarMultiplication(n: UByteArray, p: UByteArray): UByteArray { actual fun scalarMultiplication(n: UByteArray, p: UByteArray): UByteArray {
val result = UByteArray(crypto_core_ristretto255_BYTES) val result = UByteArray(crypto_core_ristretto255_BYTES)
crypto_scalarmult_ristretto255(result.pin().toPtr(), n.pin().toPtr(), p.pin().toPtr()) result.usePinned { resultPinned ->
n.usePinned { nPinned ->
p.usePinned { pPinned ->
crypto_scalarmult_ristretto255(resultPinned.toPtr(), nPinned.toPtr(), pPinned.toPtr())
}
}
}
return result return result
} }
@ -110,7 +184,11 @@ actual abstract class Ristretto255LowLevel actual constructor() {
actual fun scalarMultiplicationBase(n: UByteArray): UByteArray { actual fun scalarMultiplicationBase(n: UByteArray): UByteArray {
val result = UByteArray(crypto_core_ristretto255_BYTES) val result = UByteArray(crypto_core_ristretto255_BYTES)
crypto_scalarmult_ristretto255_base(result.pin().toPtr(), n.pin().toPtr()) result.usePinned { resultPinned ->
n.usePinned { nPinned ->
crypto_scalarmult_ristretto255_base(resultPinned.toPtr(), nPinned.toPtr())
}
}
return result return result
} }