diff --git a/multiplatform-crypto-libsodium-bindings/src/commonMain/kotlin/com.ionspin.kotlin.crypto/ed25519/Ed25519.kt b/multiplatform-crypto-libsodium-bindings/src/commonMain/kotlin/com.ionspin.kotlin.crypto/ed25519/Ed25519.kt index 92d2bfe..f8ba0e2 100644 --- a/multiplatform-crypto-libsodium-bindings/src/commonMain/kotlin/com.ionspin.kotlin.crypto/ed25519/Ed25519.kt +++ b/multiplatform-crypto-libsodium-bindings/src/commonMain/kotlin/com.ionspin.kotlin.crypto/ed25519/Ed25519.kt @@ -17,7 +17,7 @@ const val crypto_core_ed25519_NONREDUCEDSCALARBYTES = 64 const val crypto_scalarmult_ed25519_BYTES = 32U const val crypto_scalarmult_ed25519_SCALARBYTES = 32U -expect abstract class Ed25519LowLevel() { +expect object Ed25519LowLevel { fun isValidPoint(encoded: UByteArray): Boolean fun addPoints(p: UByteArray, q: UByteArray): UByteArray fun subtractPoints(p: UByteArray, q: UByteArray): UByteArray @@ -37,123 +37,56 @@ expect abstract class Ed25519LowLevel() { fun scalarMultiplicationBaseNoClamp(n: UByteArray): UByteArray } -object Ed25519 : Ed25519LowLevel() { - fun add(p: Point, q: Point): Point = - Point(addPoints(p.encoded, q.encoded)) +object Ed25519 { - fun subtract(p: Point, q: Point): Point = - Point(subtractPoints(p.encoded, q.encoded)) + data class Point(private val encoded: UByteArray) { - fun pointFromUniform(uniform: UByteArray): Point = Point(encodedPointFromUniform(uniform)) + companion object { + val IDENTITY: Point = Point(UByteArray(crypto_core_ed25519_BYTES)) + val BASE: Point = multiplyBaseNoClamp(Scalar.ONE) - fun randomPoint(): Point = Point(randomEncodedPoint()) + fun fromUniform(uniform: UByteArray): Point = Point(Ed25519LowLevel.encodedPointFromUniform(uniform)) - fun randomScalar(): Scalar = Scalar(randomEncodedScalar()) + fun random(): Point = Point(Ed25519LowLevel.randomEncodedPoint()) - fun invert(scalar: Scalar): Scalar = - Scalar(invertScalar(scalar.encoded)) + fun multiplyBase(n: Scalar): Point = Point(Ed25519LowLevel.scalarMultiplicationBase(n.encoded)) - fun negate(scalar: Scalar): Scalar = - Scalar(negateScalar(scalar.encoded)) + fun multiplyBaseNoClamp(n: Scalar): Point = + Point(Ed25519LowLevel.scalarMultiplicationBaseNoClamp(n.encoded)) - fun complement(scalar: Scalar): Scalar = - Scalar(complementScalar(scalar.encoded)) + fun fromHex(hex: String): Point = Point(LibsodiumUtil.fromHex(hex)) - fun add(x: Scalar, y: Scalar): Scalar = - Scalar(addScalars(x.encoded, y.encoded)) + fun isValid(point: Point) : Boolean = Ed25519LowLevel.isValidPoint(point.encoded) + } - fun subtract(x: Scalar, y: Scalar): Scalar = - Scalar(subtractScalars(x.encoded, y.encoded)) + operator fun plus(q: Point): Point = Point(Ed25519LowLevel.addPoints(this.encoded, q.encoded)) + operator fun minus(q: Point): Point = Point(Ed25519LowLevel.subtractPoints(this.encoded, q.encoded)) - fun multiply(x: Scalar, y: Scalar): Scalar = - Scalar(multiplyScalars(x.encoded, y.encoded)) - - fun reduce(scalar: Scalar): Scalar = - Scalar(reduceScalar(scalar.encoded)) - - fun scalarMultiplication(p: Point, n: Scalar): Point = - Point(scalarMultiplication(n.encoded, p.encoded)) - - fun scalarMultiplicationNoClamp(p: Point, n: Scalar): Point = - Point(scalarMultiplicationNoClamp(n.encoded, p.encoded)) - - fun scalarMultiplicationBase(n: Scalar): Point = - Point(scalarMultiplicationBase(n.encoded)) - - fun scalarMultiplicationBaseNoClamp(n: Scalar): Point = - Point(scalarMultiplicationBaseNoClamp(n.encoded)) - - data class Point(val encoded: UByteArray) { - operator fun plus(q: Point): Point = add(this, q) - operator fun minus(q: Point): Point = subtract(this, q) - - operator fun times(n: Scalar): Point = scalarMultiplication(this, n) + operator fun times(n: Scalar): Point = Point(Ed25519LowLevel.scalarMultiplication(n.encoded, this.encoded)) fun times(n: Scalar, clamp: Boolean): Point = - if (clamp) scalarMultiplication(this, n) else scalarMultiplicationNoClamp(this, n) + if (clamp) times(n) else Point(Ed25519LowLevel.scalarMultiplicationNoClamp(n.encoded, this.encoded)) fun toHex(): String = LibsodiumUtil.toHex(encoded) override fun equals(other: Any?): Boolean = (other as? Point)?.encoded?.contentEquals(encoded) == true override fun hashCode(): Int = encoded.contentHashCode() - - companion object { - val IDENTITY: Point = Point(UByteArray(crypto_core_ed25519_BYTES)) - val BASE: Point = scalarMultiplicationBaseNoClamp(Scalar.ONE) - - fun fromUniform(uniform: UByteArray): Point = pointFromUniform(uniform) - - fun random(): Point = randomPoint() - - fun multiplyBase(n: Scalar): Point = scalarMultiplicationBase(n) - - fun multiplyBaseNoClamp(n: Scalar): Point = scalarMultiplicationBaseNoClamp(n) - - fun fromHex(hex: String): Point = Point(LibsodiumUtil.fromHex(hex)) - } } data class Scalar(val encoded: UByteArray) { - operator fun plus(y: Scalar): Scalar = add(this, y) - operator fun plus(y: UInt): Scalar = this + fromUInt(y) - operator fun plus(y: ULong): Scalar = this + fromULong(y) - - operator fun minus(y: Scalar): Scalar = subtract(this, y) - operator fun minus(y: UInt): Scalar = this - fromUInt(y) - operator fun minus(y: ULong): Scalar = this - fromULong(y) - - operator fun times(y: Scalar): Scalar = multiply(this, y) - operator fun times(y: UInt): Scalar = this * fromUInt(y) - operator fun times(y: ULong): Scalar = this * fromULong(y) - - operator fun div(y: Scalar): Scalar = multiply(this, invert(y)) - operator fun div(y: UInt): Scalar = this / fromUInt(y) - operator fun div(y: ULong): Scalar = this / fromULong(y) - - operator fun unaryMinus(): Scalar = negate(this) - - operator fun times(p: Point): Point = scalarMultiplication(p, this) - fun times(p: Point, clamp: Boolean): Point = - if (clamp) scalarMultiplication(p, this) else scalarMultiplicationNoClamp(p, this) - - fun reduce(): Scalar = reduce(this) - fun invert(): Scalar = invert(this) - fun complement(): Scalar = complement(this) - - fun multiplyWithBase(): Point = scalarMultiplicationBase(this) - - fun multiplyWithBaseNoClamp(): Point = scalarMultiplicationBaseNoClamp(this) - - fun toHex(): String = LibsodiumUtil.toHex(encoded) - - override fun equals(other: Any?): Boolean = (other as? Scalar)?.encoded?.contentEquals(encoded) == true - override fun hashCode(): Int = encoded.contentHashCode() companion object { val ZERO = fromUInt(0U) val ONE = fromUInt(1U) val TWO = fromUInt(2U) - fun random(): Scalar = randomScalar() + + fun random(): Scalar = Scalar(Ed25519LowLevel.randomEncodedScalar()) + + fun invert(scalar: Scalar): Scalar = Scalar(Ed25519LowLevel.invertScalar(scalar.encoded)) + + fun reduce(scalar: Scalar): Scalar = Scalar(Ed25519LowLevel.reduceScalar(scalar.encoded)) + + fun complement(scalar: Scalar) : Scalar = Scalar(Ed25519LowLevel.complementScalar(scalar.encoded)) fun fromUInt(i: UInt): Scalar = fromULong(i.toULong()) @@ -178,7 +111,7 @@ object Ed25519 : Ed25519LowLevel() { val encoded = LibsodiumUtil.fromHex(hex.padEnd(2 * crypto_core_ed25519_NONREDUCEDSCALARBYTES, '0')) // Scalars are encoded in little-endian order, so the end can be padded with zeroes up to the size of a // non-reduced scalar. After decoding, it is reduced, to obtain a scalar in the canonical range - return Scalar(reduceScalar(encoded)) + return Scalar(Ed25519LowLevel.reduceScalar(encoded)) } else { val encoded = LibsodiumUtil.fromHex(hex.padEnd(2 * crypto_core_ed25519_SCALARBYTES, '0')) // Scalars are encoded in little-endian order, so the end can be padded with zeroes up to the size of a @@ -187,5 +120,42 @@ object Ed25519 : Ed25519LowLevel() { } } } + + operator fun plus(y: Scalar): Scalar = Scalar(Ed25519LowLevel.addScalars(this.encoded, y.encoded)) + operator fun plus(y: UInt): Scalar = this + fromUInt(y) + operator fun plus(y: ULong): Scalar = this + fromULong(y) + + operator fun minus(y: Scalar): Scalar = Scalar(Ed25519LowLevel.subtractScalars(this.encoded, y.encoded)) + operator fun minus(y: UInt): Scalar = this - fromUInt(y) + operator fun minus(y: ULong): Scalar = this - fromULong(y) + + operator fun times(y: Scalar): Scalar = Scalar(Ed25519LowLevel.multiplyScalars(this.encoded, y.encoded)) + operator fun times(y: UInt): Scalar = this * fromUInt(y) + operator fun times(y: ULong): Scalar = this * fromULong(y) + + operator fun div(y: Scalar): Scalar = times(invert(y)) + operator fun div(y: UInt): Scalar = this / fromUInt(y) + operator fun div(y: ULong): Scalar = this / fromULong(y) + + operator fun unaryMinus(): Scalar = Scalar(Ed25519LowLevel.negateScalar(this.encoded)) + + operator fun times(p: Point): Point = p.times(this) + fun times(p: Point, clamp: Boolean): Point = + if (clamp) p.times(this) else p.times(this, clamp) + + fun reduce(): Scalar = reduce(this) + fun invert(): Scalar = invert(this) + fun complement(): Scalar = complement(this) + + fun multiplyWithBase(): Point = Point.multiplyBase(this) + + fun multiplyWithBaseNoClamp(): Point = Point.multiplyBaseNoClamp(this) + + fun toHex(): String = LibsodiumUtil.toHex(encoded) + + override fun equals(other: Any?): Boolean = (other as? Scalar)?.encoded?.contentEquals(encoded) == true + override fun hashCode(): Int = encoded.contentHashCode() + + } } \ No newline at end of file diff --git a/multiplatform-crypto-libsodium-bindings/src/commonTest/kotlin/com/ionspin/kotlin/crypto/ed25519/Ed25519Test.kt b/multiplatform-crypto-libsodium-bindings/src/commonTest/kotlin/com/ionspin/kotlin/crypto/ed25519/Ed25519Test.kt index e7d3e6c..28e2cee 100644 --- a/multiplatform-crypto-libsodium-bindings/src/commonTest/kotlin/com/ionspin/kotlin/crypto/ed25519/Ed25519Test.kt +++ b/multiplatform-crypto-libsodium-bindings/src/commonTest/kotlin/com/ionspin/kotlin/crypto/ed25519/Ed25519Test.kt @@ -98,9 +98,9 @@ class Ed25519Test { assertNotEquals(q, r) assertNotEquals(r, p) - assertTrue { Ed25519.isValidPoint(p.encoded) } - assertTrue { Ed25519.isValidPoint(q.encoded) } - assertTrue { Ed25519.isValidPoint(r.encoded) } + assertTrue { Ed25519.Point.isValid(p) } + assertTrue { Ed25519.Point.isValid(q) } + assertTrue { Ed25519.Point.isValid(r) } } } @@ -119,15 +119,15 @@ class Ed25519Test { fun testIsValidPoint() = runTest { LibsodiumInitializer.initializeWithCallback { for (hexEncoded in badEncodings) { - assertFalse { Ed25519.isValidPoint(LibsodiumUtil.fromHex(hexEncoded)) } + assertFalse { Ed25519.Point.isValid(Ed25519.Point.fromHex(hexEncoded)) } } for (hexEncoded in basePointSmallMultiplesNoClamp) { - assertTrue { Ed25519.isValidPoint(LibsodiumUtil.fromHex(hexEncoded)) } + assertTrue { Ed25519.Point.isValid(Ed25519.Point.fromHex(hexEncoded)) } } for (hexEncoded in basePointSmallMultiplesClamped) { - assertTrue { Ed25519.isValidPoint(LibsodiumUtil.fromHex(hexEncoded)) } + assertTrue { Ed25519.Point.isValid(Ed25519.Point.fromHex(hexEncoded)) } } } } @@ -140,8 +140,8 @@ class Ed25519Test { val b = Ed25519.Point.BASE val n = Ed25519.Scalar.fromUInt(i.toUInt() + 1U) - assertEquals(p, Ed25519.scalarMultiplicationBaseNoClamp(n)) - assertEquals(p, Ed25519.scalarMultiplicationNoClamp(b, n)) + assertEquals(p, Ed25519.Point.multiplyBaseNoClamp(n)) + assertEquals(p, b.times(n, false)) assertEquals(p, n.multiplyWithBaseNoClamp()) for (j in 0..