diff --git a/multiplatform-crypto-libsodium-bindings/src/commonMain/kotlin/com.ionspin.kotlin.crypto/ed25519/Ed25519.kt b/multiplatform-crypto-libsodium-bindings/src/commonMain/kotlin/com.ionspin.kotlin.crypto/ed25519/Ed25519.kt index f8ba0e2..332b506 100644 --- a/multiplatform-crypto-libsodium-bindings/src/commonMain/kotlin/com.ionspin.kotlin.crypto/ed25519/Ed25519.kt +++ b/multiplatform-crypto-libsodium-bindings/src/commonMain/kotlin/com.ionspin.kotlin.crypto/ed25519/Ed25519.kt @@ -84,6 +84,8 @@ object Ed25519 { fun invert(scalar: Scalar): Scalar = Scalar(Ed25519LowLevel.invertScalar(scalar.encoded)) + fun negate(scalar: Scalar): Scalar = Scalar(Ed25519LowLevel.negateScalar(scalar.encoded)) + fun reduce(scalar: Scalar): Scalar = Scalar(Ed25519LowLevel.reduceScalar(scalar.encoded)) fun complement(scalar: Scalar) : Scalar = Scalar(Ed25519LowLevel.complementScalar(scalar.encoded)) @@ -137,7 +139,7 @@ object Ed25519 { operator fun div(y: UInt): Scalar = this / fromUInt(y) operator fun div(y: ULong): Scalar = this / fromULong(y) - operator fun unaryMinus(): Scalar = Scalar(Ed25519LowLevel.negateScalar(this.encoded)) + operator fun unaryMinus(): Scalar = negate(this) operator fun times(p: Point): Point = p.times(this) fun times(p: Point, clamp: Boolean): Point = diff --git a/multiplatform-crypto-libsodium-bindings/src/commonMain/kotlin/com.ionspin.kotlin.crypto/ristretto255/Ristretto255.kt b/multiplatform-crypto-libsodium-bindings/src/commonMain/kotlin/com.ionspin.kotlin.crypto/ristretto255/Ristretto255.kt index 790632a..cd03d4e 100644 --- a/multiplatform-crypto-libsodium-bindings/src/commonMain/kotlin/com.ionspin.kotlin.crypto/ristretto255/Ristretto255.kt +++ b/multiplatform-crypto-libsodium-bindings/src/commonMain/kotlin/com.ionspin.kotlin.crypto/ristretto255/Ristretto255.kt @@ -16,7 +16,7 @@ const val crypto_core_ristretto255_NONREDUCEDSCALARBYTES = 64 const val crypto_scalarmult_ristretto255_BYTES = 32U const val crypto_scalarmult_ristretto255_SCALARBYTES = 32U -expect abstract class Ristretto255LowLevel() { +expect object Ristretto255LowLevel { fun isValidPoint(encoded: UByteArray): Boolean fun addPoints(p: UByteArray, q: UByteArray): UByteArray fun subtractPoints(p: UByteArray, q: UByteArray): UByteArray @@ -34,97 +34,102 @@ expect abstract class Ristretto255LowLevel() { fun scalarMultiplicationBase(n: UByteArray): UByteArray } -object Ristretto255 : Ristretto255LowLevel() { - fun add(p: Point, q: Point): Point = - Point(addPoints(p.encoded, q.encoded)) - - fun subtract(p: Point, q: Point): Point = - Point(subtractPoints(p.encoded, q.encoded)) - - fun pointFromHash(hash: UByteArray): Point = Point(encodedPointFromHash(hash)) - - fun randomPoint(): Point = Point(randomEncodedPoint()) - - fun randomScalar(): Scalar = Scalar(randomEncodedScalar()) - - fun invert(scalar: Scalar): Scalar = - Scalar(invertScalar(scalar.encoded)) - - fun negate(scalar: Scalar): Scalar = - Scalar(negateScalar(scalar.encoded)) - - fun complement(scalar: Scalar): Scalar = - Scalar(complementScalar(scalar.encoded)) - - fun add(x: Scalar, y: Scalar): Scalar = - Scalar(addScalars(x.encoded, y.encoded)) - - fun subtract(x: Scalar, y: Scalar): Scalar = - Scalar(subtractScalars(x.encoded, y.encoded)) - - fun multiply(x: Scalar, y: Scalar): Scalar = - Scalar(multiplyScalars(x.encoded, y.encoded)) - - fun reduce(scalar: Scalar): Scalar = - Scalar(reduceScalar(scalar.encoded)) - - fun scalarMultiplication(p: Point, n: Scalar): Point = - Point(scalarMultiplication(n.encoded, p.encoded)) - - fun scalarMultiplicationBase(n: Scalar): Point = - Point(scalarMultiplicationBase(n.encoded)) +object Ristretto255 { +// fun add(p: Point, q: Point): Point = +// Point(addPoints(p.encoded, q.encoded)) +// +// fun subtract(p: Point, q: Point): Point = +// Point(subtractPoints(p.encoded, q.encoded)) +// +// fun pointFromHash(hash: UByteArray): Point = Point(encodedPointFromHash(hash)) +// +// fun randomPoint(): Point = Point(randomEncodedPoint()) +// +// fun randomScalar(): Scalar = Scalar(randomEncodedScalar()) +// +// fun invert(scalar: Scalar): Scalar = +// Scalar(invertScalar(scalar.encoded)) +// +// fun negate(scalar: Scalar): Scalar = +// Scalar(negateScalar(scalar.encoded)) +// +// fun complement(scalar: Scalar): Scalar = +// Scalar(complementScalar(scalar.encoded)) +// +// fun add(x: Scalar, y: Scalar): Scalar = +// Scalar(addScalars(x.encoded, y.encoded)) +// +// fun subtract(x: Scalar, y: Scalar): Scalar = +// Scalar(subtractScalars(x.encoded, y.encoded)) +// +// fun multiply(x: Scalar, y: Scalar): Scalar = +// Scalar(multiplyScalars(x.encoded, y.encoded)) +// +// fun reduce(scalar: Scalar): Scalar = +// Scalar(reduceScalar(scalar.encoded)) +// +// fun scalarMultiplication(p: Point, n: Scalar): Point = +// Point(scalarMultiplication(n.encoded, p.encoded)) +// +// fun scalarMultiplicationBase(n: Scalar): Point = +// Point(scalarMultiplicationBase(n.encoded)) data class Point(val encoded: UByteArray) { - operator fun plus(q: Point): Point = add(this, q) - operator fun minus(q: Point): Point = subtract(this, q) - operator fun times(n: Scalar): Point = scalarMultiplication(this, n) + companion object { + val IDENTITY: Point = Point(UByteArray(crypto_core_ristretto255_BYTES)) + val BASE: Point = multiplyBase(Scalar.ONE) + + fun fromHash(hash: UByteArray): Point = Point(Ristretto255LowLevel.encodedPointFromHash(hash)) + + fun random(): Point = Point(Ristretto255LowLevel.randomEncodedPoint()) + + fun multiplyBase(n: Scalar): Point = Point(Ristretto255LowLevel.scalarMultiplicationBase(n.encoded)) + + fun fromHex(hex: String): Point = Point(LibsodiumUtil.fromHex(hex)) + + fun isValid(point: Point) : Boolean = Ristretto255LowLevel.isValidPoint(point.encoded) + } + + operator fun plus(q: Point): Point = Point(Ristretto255LowLevel.addPoints(this.encoded, q.encoded)) + operator fun minus(q: Point): Point = Point(Ristretto255LowLevel.subtractPoints(this.encoded, q.encoded)) + + operator fun times(n: Scalar): Point = Point(Ristretto255LowLevel.scalarMultiplication(n.encoded, this.encoded)) fun toHex(): String = LibsodiumUtil.toHex(encoded) override fun equals(other: Any?): Boolean = (other as? Point)?.encoded?.contentEquals(encoded) == true override fun hashCode(): Int = encoded.contentHashCode() - companion object { - val IDENTITY: Point = Point(UByteArray(crypto_core_ristretto255_BYTES)) - val BASE: Point = scalarMultiplicationBase(Scalar.ONE) - fun fromHash(hash: UByteArray): Point = pointFromHash(hash) - - fun random(): Point = randomPoint() - - fun multiplyBase(n: Scalar): Point = scalarMultiplicationBase(n) - - fun fromHex(hex: String): Point = Point(LibsodiumUtil.fromHex(hex)) - } } data class Scalar(val encoded: UByteArray) { - operator fun plus(y: Scalar): Scalar = add(this, y) + operator fun plus(y: Scalar): Scalar = Scalar(Ristretto255LowLevel.addScalars(this.encoded, y.encoded)) operator fun plus(y: UInt): Scalar = this + fromUInt(y) operator fun plus(y: ULong): Scalar = this + fromULong(y) - operator fun minus(y: Scalar): Scalar = subtract(this, y) + operator fun minus(y: Scalar): Scalar = Scalar(Ristretto255LowLevel.subtractScalars(this.encoded, y.encoded)) operator fun minus(y: UInt): Scalar = this - fromUInt(y) operator fun minus(y: ULong): Scalar = this - fromULong(y) - operator fun times(y: Scalar): Scalar = multiply(this, y) + operator fun times(y: Scalar): Scalar = Scalar(Ristretto255LowLevel.multiplyScalars(this.encoded, y.encoded)) operator fun times(y: UInt): Scalar = this * fromUInt(y) operator fun times(y: ULong): Scalar = this * fromULong(y) - operator fun div(y: Scalar): Scalar = multiply(this, invert(y)) + operator fun div(y: Scalar): Scalar = Scalar(Ristretto255LowLevel.multiplyScalars(this.encoded, y.invert().encoded)) operator fun div(y: UInt): Scalar = this / fromUInt(y) operator fun div(y: ULong): Scalar = this / fromULong(y) operator fun unaryMinus(): Scalar = negate(this) - operator fun times(p: Point): Point = scalarMultiplication(p, this) + operator fun times(p: Point): Point = p.times(this) fun reduce(): Scalar = reduce(this) fun invert(): Scalar = invert(this) fun complement(): Scalar = complement(this) - fun multiplyWithBase(): Point = scalarMultiplicationBase(this) + fun multiplyWithBase(): Point = Point.multiplyBase(this) fun toHex(): String = LibsodiumUtil.toHex(encoded) @@ -136,8 +141,20 @@ object Ristretto255 : Ristretto255LowLevel() { val ONE = fromUInt(1U) val TWO = fromUInt(2U) - fun random(): Scalar = randomScalar() + fun random(): Scalar = Scalar(Ristretto255LowLevel.randomEncodedScalar()) + fun invert(scalar: Scalar): Scalar = + Scalar(Ristretto255LowLevel.invertScalar(scalar.encoded)) + + fun negate(scalar: Scalar): Scalar = + Scalar(Ristretto255LowLevel.negateScalar(scalar.encoded)) + + fun reduce(scalar: Scalar): Scalar = + Scalar(Ristretto255LowLevel.reduceScalar(scalar.encoded)) + + fun complement(scalar: Scalar): Scalar = + Scalar(Ristretto255LowLevel.complementScalar(scalar.encoded)) + fun fromUInt(i: UInt): Scalar = fromULong(i.toULong()) fun fromULong(l: ULong): Scalar { @@ -162,7 +179,7 @@ object Ristretto255 : Ristretto255LowLevel() { LibsodiumUtil.fromHex(hex.padEnd(2 * crypto_core_ristretto255_NONREDUCEDSCALARBYTES, '0')) // Scalars are encoded in little-endian order, so the end can be padded with zeroes up to the size of a // non-reduced scalar. After decoding, it is reduced, to obtain a scalar in the canonical range - return Scalar(reduceScalar(encoded)) + return Scalar(Ristretto255LowLevel.reduceScalar(encoded)) } else { val encoded = LibsodiumUtil.fromHex(hex.padEnd(2 * crypto_core_ristretto255_SCALARBYTES, '0')) // Scalars are encoded in little-endian order, so the end can be padded with zeroes up to the size of a diff --git a/multiplatform-crypto-libsodium-bindings/src/commonTest/kotlin/com/ionspin/kotlin/crypto/ristretto255/Ristretto255Test.kt b/multiplatform-crypto-libsodium-bindings/src/commonTest/kotlin/com/ionspin/kotlin/crypto/ristretto255/Ristretto255Test.kt index 2602eac..d1725a2 100644 --- a/multiplatform-crypto-libsodium-bindings/src/commonTest/kotlin/com/ionspin/kotlin/crypto/ristretto255/Ristretto255Test.kt +++ b/multiplatform-crypto-libsodium-bindings/src/commonTest/kotlin/com/ionspin/kotlin/crypto/ristretto255/Ristretto255Test.kt @@ -96,9 +96,9 @@ class Ristretto255Test { - assertTrue { Ristretto255.isValidPoint(p.encoded) } - assertTrue { Ristretto255.isValidPoint(q.encoded) } - assertTrue { Ristretto255.isValidPoint(r.encoded) } + assertTrue { Ristretto255.Point.isValid(p) } + assertTrue { Ristretto255.Point.isValid(q) } + assertTrue { Ristretto255.Point.isValid(r) } } } @@ -117,11 +117,11 @@ class Ristretto255Test { fun testIsValidPoint() = runTest { LibsodiumInitializer.initializeWithCallback { for (hexEncoded in badEncodings) { - assertFalse { Ristretto255.isValidPoint(LibsodiumUtil.fromHex(hexEncoded)) } + assertFalse { Ristretto255.Point.isValid(Ristretto255.Point.fromHex(hexEncoded)) } } for (hexEncoded in basePointSmallMultiples) { - assertTrue { Ristretto255.isValidPoint(LibsodiumUtil.fromHex(hexEncoded)) } + assertTrue { Ristretto255.Point.isValid(Ristretto255.Point.fromHex(hexEncoded)) } } } } @@ -134,8 +134,8 @@ class Ristretto255Test { val b = Ristretto255.Point.BASE val n = Ristretto255.Scalar.fromUInt(i.toUInt() + 1U) - assertEquals(p, Ristretto255.scalarMultiplicationBase(n)) - assertEquals(p, Ristretto255.scalarMultiplication(b, n)) + assertEquals(p, Ristretto255.Point.multiplyBase(n)) + assertEquals(p, b.times(n)) assertEquals(p, n.multiplyWithBase()) for (j in 0..