fixed extending extern classes (extension methods)

This commit is contained in:
Sergey Chernov 2026-03-15 05:35:29 +03:00
parent 8e0442670d
commit d2a47d34a3
3 changed files with 288 additions and 130 deletions

View File

@ -1120,7 +1120,16 @@ class Compiler(
if (implicitType != null) {
resolutionSink?.referenceMember(name, pos, implicitType)
val ids = resolveImplicitThisMemberIds(name, pos, implicitType)
return ImplicitThisMemberRef(name, pos, ids.fieldId, ids.methodId, implicitType)
val inClassContext = codeContexts.any { ctx -> ctx is CodeContext.ClassBody }
val currentImplicitType = currentImplicitThisTypeName()
val preferredType = when {
inClassContext -> implicitType
// Extension receiver aliases (extern class name -> host runtime class) can fail strict
// variant casts; keep current receiver untyped but preserve non-current receivers.
implicitType == currentImplicitType -> null
else -> implicitType
}
return ImplicitThisMemberRef(name, pos, ids.fieldId, ids.methodId, preferredType)
}
if (classCtx != null && classCtx.classScopeMembers.contains(name)) {
resolutionSink?.referenceMember(name, pos, classCtx.name)
@ -4779,6 +4788,11 @@ class Compiler(
if (payload != null) return payload
}
val receiverClass = resolveReceiverClassForMember(ref.receiver)
classMethodReturnClass(receiverClass, ref.name)?.let { return it }
inferFieldReturnClass(receiverClass, ref.name)?.let { return it }
if (receiverClass != null && isClassScopeCallableMember(receiverClass.className, ref.name)) {
resolveClassByName("${receiverClass.className}.${ref.name}")?.let { return it }
}
return inferMethodCallReturnClass(ref.name)
}

View File

@ -1067,18 +1067,9 @@ class BytecodeCompiler(
updateSlotType(dst, SlotType.OBJ)
return CompiledValue(dst, SlotType.OBJ)
}
if (receiverClass == null && memberName == "negate") {
val zeroId = builder.addConst(BytecodeConst.IntVal(0))
val zeroSlot = allocSlot()
builder.emit(Opcode.CONST_INT, zeroId, zeroSlot)
updateSlotType(zeroSlot, SlotType.INT)
val obj = ensureObjSlot(value)
val dst = allocSlot()
builder.emit(Opcode.SUB_OBJ, zeroSlot, obj.slot, dst)
updateSlotType(dst, SlotType.OBJ)
return CompiledValue(dst, SlotType.OBJ)
}
if (memberName == "negate" && receiverClass in setOf(ObjInt.type, ObjReal.type)) {
if (memberName == "negate" &&
(receiverClass == null || isDelegateClass(receiverClass) || receiverClass in setOf(ObjInt.type, ObjReal.type))
) {
val zeroId = builder.addConst(BytecodeConst.IntVal(0))
val zeroSlot = allocSlot()
builder.emit(Opcode.CONST_INT, zeroId, zeroSlot)
@ -1095,6 +1086,11 @@ class BytecodeCompiler(
)
}
private fun isDelegateClass(receiverClass: ObjClass): Boolean =
receiverClass.className == "Delegate" ||
receiverClass.className == "LazyDelegate" ||
receiverClass.implementingNames.contains("Delegate")
private fun operatorMemberName(op: BinOp): String? = when (op) {
BinOp.PLUS -> "plus"
BinOp.MINUS -> "minus"
@ -1142,11 +1138,7 @@ class BytecodeCompiler(
): CompiledValue? {
val memberName = operatorMemberName(op) ?: return null
val receiverClass = resolveReceiverClass(leftRef)
if (receiverClass == null ||
receiverClass.className == "Delegate" ||
receiverClass.className == "LazyDelegate" ||
receiverClass.implementingNames.contains("Delegate")
) {
if (receiverClass == null || isDelegateClass(receiverClass)) {
val objOpcode = when (op) {
BinOp.PLUS -> Opcode.ADD_OBJ
BinOp.MINUS -> Opcode.SUB_OBJ
@ -4821,9 +4813,26 @@ class BytecodeCompiler(
private data class CallArgs(val base: Int, val count: Int, val planId: Int?)
private fun resolveExtensionCallableSlot(receiverClass: ObjClass, memberName: String): CompiledValue? {
private fun extensionReceiverTypeNames(receiverClass: ObjClass): Set<String> {
val names = LinkedHashSet<String>()
for (cls in receiverClass.mro) {
val candidate = extensionCallableName(cls.className, memberName)
names.add(cls.className)
for ((knownName, knownClass) in nameObjClass) {
if (knownClass !== cls && knownClass.className != cls.className) continue
names.add(knownName)
names.add(knownName.substringAfterLast('.'))
}
}
return names
}
private fun resolveExtensionSlotByReceiverNames(
receiverClass: ObjClass,
memberName: String,
wrapperName: (String, String) -> String
): CompiledValue? {
for (receiverName in extensionReceiverTypeNames(receiverClass)) {
val candidate = wrapperName(receiverName, memberName)
if (allowedScopeNames != null &&
!allowedScopeNames.contains(candidate) &&
!localSlotIndexByName.containsKey(candidate)
@ -4835,32 +4844,39 @@ class BytecodeCompiler(
return null
}
private fun resolveUniqueExtensionWrapperSlot(
memberName: String,
wrapperPrefix: String
): CompiledValue? {
val suffix = "__$memberName"
val candidates = LinkedHashSet<String>()
for (name in localSlotIndexByName.keys) {
if (name.startsWith(wrapperPrefix) && name.endsWith(suffix)) {
candidates.add(name)
}
}
for (name in scopeSlotIndexByName.keys) {
if (name.startsWith(wrapperPrefix) && name.endsWith(suffix)) {
candidates.add(name)
}
}
if (candidates.size != 1) return null
return resolveDirectNameSlot(candidates.first())
}
private fun resolveExtensionCallableSlot(receiverClass: ObjClass, memberName: String): CompiledValue? {
return resolveExtensionSlotByReceiverNames(receiverClass, memberName, ::extensionCallableName)
?: resolveUniqueExtensionWrapperSlot(memberName, "__ext__")
}
private fun resolveExtensionGetterSlot(receiverClass: ObjClass, memberName: String): CompiledValue? {
for (cls in receiverClass.mro) {
val candidate = extensionPropertyGetterName(cls.className, memberName)
if (allowedScopeNames != null &&
!allowedScopeNames.contains(candidate) &&
!localSlotIndexByName.containsKey(candidate)
) {
continue
}
resolveDirectNameSlot(candidate)?.let { return it }
}
return null
return resolveExtensionSlotByReceiverNames(receiverClass, memberName, ::extensionPropertyGetterName)
?: resolveUniqueExtensionWrapperSlot(memberName, "__ext_get__")
}
private fun resolveExtensionSetterSlot(receiverClass: ObjClass, memberName: String): CompiledValue? {
for (cls in receiverClass.mro) {
val candidate = extensionPropertySetterName(cls.className, memberName)
if (allowedScopeNames != null &&
!allowedScopeNames.contains(candidate) &&
!localSlotIndexByName.containsKey(candidate)
) {
continue
}
resolveDirectNameSlot(candidate)?.let { return it }
}
return null
return resolveExtensionSlotByReceiverNames(receiverClass, memberName, ::extensionPropertySetterName)
?: resolveUniqueExtensionWrapperSlot(memberName, "__ext_set__")
}
private fun compileCallArgsWithReceiver(
@ -7468,7 +7484,7 @@ class BytecodeCompiler(
if (targetClass == null) return null
if (targetClass == ObjDynamic.type) return ObjDynamic.type
classFieldTypesByName[targetClass.className]?.get(name)?.let { cls ->
if (cls.className == "Delegate" || cls.className == "LazyDelegate" || cls.implementingNames.contains("Delegate")) {
if (isDelegateClass(cls)) {
return null
}
return cls
@ -7555,8 +7571,8 @@ class BytecodeCompiler(
private fun queueExtensionCallableNames(receiverClass: ObjClass, memberName: String) {
if (!useScopeSlots && globalSlotInfo.isEmpty()) return
for (cls in receiverClass.mro) {
val name = extensionCallableName(cls.className, memberName)
for (receiverName in extensionReceiverTypeNames(receiverClass)) {
val name = extensionCallableName(receiverName, memberName)
if (allowedScopeNames == null || allowedScopeNames.contains(name)) {
pendingScopeNameRefs.add(name)
}
@ -7565,12 +7581,12 @@ class BytecodeCompiler(
private fun queueExtensionPropertyNames(receiverClass: ObjClass, memberName: String) {
if (!useScopeSlots && globalSlotInfo.isEmpty()) return
for (cls in receiverClass.mro) {
val getter = extensionPropertyGetterName(cls.className, memberName)
for (receiverName in extensionReceiverTypeNames(receiverClass)) {
val getter = extensionPropertyGetterName(receiverName, memberName)
if (allowedScopeNames == null || allowedScopeNames.contains(getter)) {
pendingScopeNameRefs.add(getter)
}
val setter = extensionPropertySetterName(cls.className, memberName)
val setter = extensionPropertySetterName(receiverName, memberName)
if (allowedScopeNames == null || allowedScopeNames.contains(setter)) {
pendingScopeNameRefs.add(setter)
}

View File

@ -16,9 +16,10 @@
*/
import kotlinx.coroutines.test.runTest
import net.sergeych.lyng.Script
import net.sergeych.lyng.ScriptError
import net.sergeych.lyng.eval
import net.sergeych.lyng.*
import net.sergeych.lyng.obj.Obj
import net.sergeych.lyng.obj.ObjClass
import net.sergeych.lyng.obj.ObjNull
import kotlin.test.Test
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
@ -27,33 +28,42 @@ class TypesTest {
@Test
fun testTypeCollection1() = runTest {
eval("""
eval(
"""
class Point(x: Real, y: Real)
assert(Point(1,2).x == 1)
assert(Point(1,2).y == 2)
assert(Point(1,2) is Point)
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testTypeCollection2() = runTest {
eval("""
eval(
"""
fun fn1(x: Real, y: Real): Real { x + y }
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testTypeCollection3() = runTest {
eval("""
eval(
"""
class Test(a: Int) {
fun fn1(x: Real, y: Real): Real { x + y }
}
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testExternDeclarations() = runTest {
eval("""
eval(
"""
extern fun foo1(a: String): Void
assertThrows { foo1("1") }
class Test(a: Int) {
@ -69,22 +79,26 @@ class TypesTest {
}
// println("4")
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testUserClassCompareTo() = runTest {
eval("""
eval(
"""
class Point(val a,b)
assertEquals(Point(0,1), Point(0,1) )
assertNotEquals(Point(0,1), Point(1,1) )
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testUserClassCompareTo2() = runTest {
eval("""
eval(
"""
class Point(val a,b) {
var c = 0
}
@ -98,12 +112,14 @@ class TypesTest {
assertEquals(p1, p2)
assertNotEquals(Point(0,1), Point(1,1) )
assertNotEquals(Point(0,1), p3)
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testNumericInference() = runTest {
eval("""
eval(
"""
val x = 1
var y = 2.0
assert( x is Int )
@ -111,11 +127,14 @@ class TypesTest {
assert( x + y is Real )
assert( abs(x+y) is Real )
assert( abs(x/y) is Real )
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testNumericInferenceBug1() = runTest {
eval("""
eval(
"""
fun findSumLimit(f) {
var sum = 0.0
for( n in 1..100 ) {
@ -142,12 +161,14 @@ class TypesTest {
val limit = findSumLimit { n -> 1.0/n/n }
assert( limit != null )
println("Result: "+limit)
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testNullableHints() = runTest {
eval("""
eval(
"""
// nullable, without type os Object?
class N(x=null)
assertEquals(null, N().x)
@ -167,7 +188,8 @@ class TypesTest {
@Test
fun testIsUnionIntersection() = runTest {
eval("""
eval(
"""
class A
class B
class C: A, B
@ -179,7 +201,8 @@ class TypesTest {
val v = 1
assert( v is Int | String | Real )
assert( !(v is String | Bool) )
""".trimIndent())
""".trimIndent()
)
}
@Test
@ -221,51 +244,64 @@ class TypesTest {
@Test
fun testListLiteralInferenceForBounds() = runTest {
eval("""
eval(
"""
fun acceptInts<T: Int>(xs: List<T>) { }
acceptInts([1, 2, 3])
val base = [1, 2]
acceptInts([...base, 3])
""".trimIndent())
eval("""
""".trimIndent()
)
eval(
"""
fun acceptReals<T: Real>(xs: List<T>) { }
acceptReals([1.0, 2.0, 3.0])
val base = [1.0, 2.0]
acceptReals([...base, 3.0])
""".trimIndent())
""".trimIndent()
)
assertFailsWith<net.sergeych.lyng.ScriptError> {
eval("""
eval(
"""
fun acceptInts<T: Int>(xs: List<T>) { }
acceptInts([1, "a"])
""".trimIndent())
""".trimIndent()
)
}
assertFailsWith<net.sergeych.lyng.ScriptError> {
eval("""
eval(
"""
fun acceptReals<T: Real>(xs: List<T>) { }
acceptReals([1.0, "a"])
""".trimIndent())
""".trimIndent()
)
}
}
@Test
fun testMapLiteralInferenceForBounds() = runTest {
eval("""
eval(
"""
fun acceptMap<T: Int>(m: Map<String, T>) { }
acceptMap({ "a": 1, "b": 2 })
val base = { "a": 1 }
acceptMap({ ...base, "b": 3 })
""".trimIndent())
""".trimIndent()
)
assertFailsWith<net.sergeych.lyng.ScriptError> {
eval("""
eval(
"""
fun acceptMap<T: Int>(m: Map<String, T>) { }
acceptMap({ "a": 1, "b": "x" })
""".trimIndent())
""".trimIndent()
)
}
}
@Test
fun testUnionTypeLists() = runTest {
eval("""
eval(
"""
fun fMixed<T>(list: List<T>) {
println(list)
@ -282,12 +318,14 @@ class TypesTest {
}
fMixed([1, "two", true])
fInts([1,2,3])
""")
"""
)
}
@Test
fun testTypeAliases() = runTest {
eval("""
eval(
"""
type Num = Int | Real
type AB = A & B
class A
@ -308,12 +346,14 @@ class TypesTest {
type IntList<T: Int> = List<T>
fun accept<T: Int>(xs: IntList<T>) { }
accept([1,2,3])
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testMultipleReceivers() = runTest {
eval("""
eval(
"""
class R1(shared,r1="r1")
class R2(shared,r2="r2")
@ -340,69 +380,85 @@ class TypesTest {
assertEquals("r1", r1)
}
}
""")
"""
)
}
@Test
fun testLambdaTypes1() = runTest {
val scope = Script.newScope()
// declare: ok
scope.eval("""
scope.eval(
"""
var l1: (Int,String)->String
""".trimIndent())
""".trimIndent()
)
// this should be Lyng compile time exception
assertFailsWith<ScriptError> {
scope.eval("""
scope.eval(
"""
fun test() {
// compiler should detect that l1 us called with arguments that does not match
// declare type (Int,String)->String:
l1()
}
""".trimIndent())
""".trimIndent()
)
}
}
@Test
fun testLambdaTypesEllipsis() = runTest {
val scope = Script.newScope()
scope.eval("""
scope.eval(
"""
var l2: (Int,Object...,String)->Real
var l4: (Int,String...,String)->Real
var l3: (...)->Int
""".trimIndent())
""".trimIndent()
)
assertFailsWith<ScriptError> {
scope.eval("""
scope.eval(
"""
fun testTooFew() {
l2(1)
}
""".trimIndent())
""".trimIndent()
)
}
assertFailsWith<ScriptError> {
scope.eval("""
scope.eval(
"""
fun testWrongHead() {
l2("x", "y")
}
""".trimIndent())
""".trimIndent()
)
}
assertFailsWith<ScriptError> {
scope.eval("""
scope.eval(
"""
fun testWrongEllipsis() {
l4(1, 2, "x")
}
""".trimIndent())
""".trimIndent()
)
}
scope.eval("""
scope.eval(
"""
fun testOk1() { l2(1, "x") }
fun testOk2() { l2(1, 2, 3, "x") }
fun testOk3() { l3() }
fun testOk4() { l3(1, true, "x") }
fun testOk5() { l4(1, "a", "b", "x") }
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testSetTyped() = runTest {
eval("""
eval(
"""
var s = Set<String>()
val typed: Set<String> = s
assertEquals(Set(), typed)
@ -413,12 +469,14 @@ class TypesTest {
assertEquals(Set(), s)
s += ["foo", "bar"]
assertEquals(Set("foo", "bar"), s)
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testListTyped() = runTest {
eval("""
eval(
"""
var l = List<String>()
val typed: List<String> = l
assertEquals(List(), typed)
@ -430,13 +488,15 @@ class TypesTest {
l += ["foo", "bar"]
assertEquals(List("foo", "bar"), l)
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testAliasesInGenerics1() = runTest {
val scope = Script.newScope()
scope.eval("""
scope.eval(
"""
type IntList<T: Int> = List<T>
type IntMap<K,V> = Map<K,V>
type IntSet<T: Int> = Set<T>
@ -459,13 +519,16 @@ class TypesTest {
assertEquals(Set("tag1", "tag2", Buffer("tag3")), x.tags)
x.tags += Buffer("tag4")
assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags)
""")
scope.eval("""
"""
)
scope.eval(
"""
assert(x is X)
x.tags += "42"
assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4"), "42"), x.tags)
""".trimIndent())
""".trimIndent()
)
// now this must fail becaise element type does not match the declared:
assertFailsWith<ScriptError> {
scope.eval(
@ -479,7 +542,8 @@ class TypesTest {
@Test
fun testAliasesInGenericsList1() = runTest {
val scope = Script.newScope()
scope.eval("""
scope.eval(
"""
import lyng.buffer
type Tag = String | Buffer
@ -495,12 +559,15 @@ class TypesTest {
assertEquals(List("tag1", "tag2", Buffer("tag3")), x.tags)
x.tags += ["tag4", Buffer("tag5")]
assertEquals(List("tag1", "tag2", Buffer("tag3"), "tag4", Buffer("tag5")), x.tags)
""")
scope.eval("""
"""
)
scope.eval(
"""
assert(x is X)
x.tags += "42"
assertEquals(List("tag1", "tag2", Buffer("tag3"), "tag4", Buffer("tag5"), "42"), x.tags)
""".trimIndent())
""".trimIndent()
)
assertFailsWith<ScriptError> {
scope.eval(
"""
@ -512,18 +579,21 @@ class TypesTest {
@Test
fun testClassName() = runTest {
eval("""
eval(
"""
class X {
var x = 1
}
assert( X::class is Class)
assertEquals("Class", X::class.name)
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testGenericTypes() = runTest {
eval("""
eval(
"""
fun t<T>(): String =
when(T) {
null -> "%s is Null"(T::class.name)
@ -532,24 +602,28 @@ class TypesTest {
}
assert( Int is Object)
assertEquals( t<Int>(), "Class is Object")
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testGenericNullableTypePredicate() = runTest {
eval("""
eval(
"""
fun isTypeNullable<T>(x: T): Bool = T is nullable
type MaybeInt = Int?
assert(isTypeNullable<Int?>(null))
assert(!isTypeNullable<Int>(1))
assert(MaybeInt is nullable)
assert(!(Int is nullable))
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testWhenNullableTypeCase() = runTest {
eval("""
eval(
"""
fun describe<T>(x: T): String = when(T) {
nullable -> "nullable"
else -> "non-null"
@ -562,11 +636,14 @@ class TypesTest {
assertEquals("non-null", describe<Int>(1))
assertEquals("nullable", describeIs<Int?>(null))
assertEquals("non-null", describeIs<Int>(1))
""".trimIndent())
""".trimIndent()
)
}
@Test fun testIndexer() = runTest {
eval("""
@Test
fun testIndexer() = runTest {
eval(
"""
class Greeter {
override fun getAt(name) = "Hello, %s!"(name)
}
@ -581,11 +658,14 @@ class TypesTest {
assertEquals("How do you do, Bob?",Polite["Bob"])
""".trimIndent())
""".trimIndent()
)
}
@Test fun testIndexer2() = runTest {
eval("""
@Test
fun testIndexer2() = runTest {
eval(
"""
class Foo(bar)
class Greeter {
@ -613,17 +693,65 @@ class TypesTest {
assertEquals("How do you do, Bob?",g2v.bar)
assertEquals("How do you do, Bob?",Greeter2()["Bob"].bar)
""".trimIndent())
""".trimIndent()
)
}
@Test
fun testExternGenerics() = runTest {
eval("""
eval(
"""
extern fun f<T>(x: T): T
extern class Cell<T> {
var value: T
}
""")
"""
)
}
class ObjFoo : Obj() {
override val objClass = klass
var value: Obj = ObjNull
companion object {
val klass = object : ObjClass("ObjFoo") {
override suspend fun callOn(scope: Scope): Obj {
return ObjFoo()
}
}.apply {
addProperty("bar", {
check(thisObj is ObjFoo)
thisAs<ObjFoo>().value
}, {
thisAs<ObjFoo>().value = args.firstAndOnly()
})
}
}
}
@Test
fun testExtensionToExternClasses() = runTest {
val scope = Script.newScope()
scope.addConst("Foo", ObjFoo.klass)
scope.eval(
"""
extern class Foo {
var bar
}
fun Foo.foobar() = "foo" + bar
val f = Foo()
f.bar = 42
assert( f is Foo )
assertEquals(42, f.bar)
assertEquals("foo42", f.foobar())
""".trimIndent()
)
}
@Test