diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt index b362d6d..23a066f 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt @@ -1120,7 +1120,16 @@ class Compiler( if (implicitType != null) { resolutionSink?.referenceMember(name, pos, implicitType) val ids = resolveImplicitThisMemberIds(name, pos, implicitType) - return ImplicitThisMemberRef(name, pos, ids.fieldId, ids.methodId, implicitType) + val inClassContext = codeContexts.any { ctx -> ctx is CodeContext.ClassBody } + val currentImplicitType = currentImplicitThisTypeName() + val preferredType = when { + inClassContext -> implicitType + // Extension receiver aliases (extern class name -> host runtime class) can fail strict + // variant casts; keep current receiver untyped but preserve non-current receivers. + implicitType == currentImplicitType -> null + else -> implicitType + } + return ImplicitThisMemberRef(name, pos, ids.fieldId, ids.methodId, preferredType) } if (classCtx != null && classCtx.classScopeMembers.contains(name)) { resolutionSink?.referenceMember(name, pos, classCtx.name) @@ -4779,6 +4788,11 @@ class Compiler( if (payload != null) return payload } val receiverClass = resolveReceiverClassForMember(ref.receiver) + classMethodReturnClass(receiverClass, ref.name)?.let { return it } + inferFieldReturnClass(receiverClass, ref.name)?.let { return it } + if (receiverClass != null && isClassScopeCallableMember(receiverClass.className, ref.name)) { + resolveClassByName("${receiverClass.className}.${ref.name}")?.let { return it } + } return inferMethodCallReturnClass(ref.name) } diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt index da3700b..9358fd2 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt @@ -1067,18 +1067,9 @@ class BytecodeCompiler( updateSlotType(dst, SlotType.OBJ) return CompiledValue(dst, SlotType.OBJ) } - if (receiverClass == null && memberName == "negate") { - val zeroId = builder.addConst(BytecodeConst.IntVal(0)) - val zeroSlot = allocSlot() - builder.emit(Opcode.CONST_INT, zeroId, zeroSlot) - updateSlotType(zeroSlot, SlotType.INT) - val obj = ensureObjSlot(value) - val dst = allocSlot() - builder.emit(Opcode.SUB_OBJ, zeroSlot, obj.slot, dst) - updateSlotType(dst, SlotType.OBJ) - return CompiledValue(dst, SlotType.OBJ) - } - if (memberName == "negate" && receiverClass in setOf(ObjInt.type, ObjReal.type)) { + if (memberName == "negate" && + (receiverClass == null || isDelegateClass(receiverClass) || receiverClass in setOf(ObjInt.type, ObjReal.type)) + ) { val zeroId = builder.addConst(BytecodeConst.IntVal(0)) val zeroSlot = allocSlot() builder.emit(Opcode.CONST_INT, zeroId, zeroSlot) @@ -1095,6 +1086,11 @@ class BytecodeCompiler( ) } + private fun isDelegateClass(receiverClass: ObjClass): Boolean = + receiverClass.className == "Delegate" || + receiverClass.className == "LazyDelegate" || + receiverClass.implementingNames.contains("Delegate") + private fun operatorMemberName(op: BinOp): String? = when (op) { BinOp.PLUS -> "plus" BinOp.MINUS -> "minus" @@ -1142,11 +1138,7 @@ class BytecodeCompiler( ): CompiledValue? { val memberName = operatorMemberName(op) ?: return null val receiverClass = resolveReceiverClass(leftRef) - if (receiverClass == null || - receiverClass.className == "Delegate" || - receiverClass.className == "LazyDelegate" || - receiverClass.implementingNames.contains("Delegate") - ) { + if (receiverClass == null || isDelegateClass(receiverClass)) { val objOpcode = when (op) { BinOp.PLUS -> Opcode.ADD_OBJ BinOp.MINUS -> Opcode.SUB_OBJ @@ -4821,9 +4813,26 @@ class BytecodeCompiler( private data class CallArgs(val base: Int, val count: Int, val planId: Int?) - private fun resolveExtensionCallableSlot(receiverClass: ObjClass, memberName: String): CompiledValue? { + private fun extensionReceiverTypeNames(receiverClass: ObjClass): Set { + val names = LinkedHashSet() for (cls in receiverClass.mro) { - val candidate = extensionCallableName(cls.className, memberName) + names.add(cls.className) + for ((knownName, knownClass) in nameObjClass) { + if (knownClass !== cls && knownClass.className != cls.className) continue + names.add(knownName) + names.add(knownName.substringAfterLast('.')) + } + } + return names + } + + private fun resolveExtensionSlotByReceiverNames( + receiverClass: ObjClass, + memberName: String, + wrapperName: (String, String) -> String + ): CompiledValue? { + for (receiverName in extensionReceiverTypeNames(receiverClass)) { + val candidate = wrapperName(receiverName, memberName) if (allowedScopeNames != null && !allowedScopeNames.contains(candidate) && !localSlotIndexByName.containsKey(candidate) @@ -4835,32 +4844,39 @@ class BytecodeCompiler( return null } + private fun resolveUniqueExtensionWrapperSlot( + memberName: String, + wrapperPrefix: String + ): CompiledValue? { + val suffix = "__$memberName" + val candidates = LinkedHashSet() + for (name in localSlotIndexByName.keys) { + if (name.startsWith(wrapperPrefix) && name.endsWith(suffix)) { + candidates.add(name) + } + } + for (name in scopeSlotIndexByName.keys) { + if (name.startsWith(wrapperPrefix) && name.endsWith(suffix)) { + candidates.add(name) + } + } + if (candidates.size != 1) return null + return resolveDirectNameSlot(candidates.first()) + } + + private fun resolveExtensionCallableSlot(receiverClass: ObjClass, memberName: String): CompiledValue? { + return resolveExtensionSlotByReceiverNames(receiverClass, memberName, ::extensionCallableName) + ?: resolveUniqueExtensionWrapperSlot(memberName, "__ext__") + } + private fun resolveExtensionGetterSlot(receiverClass: ObjClass, memberName: String): CompiledValue? { - for (cls in receiverClass.mro) { - val candidate = extensionPropertyGetterName(cls.className, memberName) - if (allowedScopeNames != null && - !allowedScopeNames.contains(candidate) && - !localSlotIndexByName.containsKey(candidate) - ) { - continue - } - resolveDirectNameSlot(candidate)?.let { return it } - } - return null + return resolveExtensionSlotByReceiverNames(receiverClass, memberName, ::extensionPropertyGetterName) + ?: resolveUniqueExtensionWrapperSlot(memberName, "__ext_get__") } private fun resolveExtensionSetterSlot(receiverClass: ObjClass, memberName: String): CompiledValue? { - for (cls in receiverClass.mro) { - val candidate = extensionPropertySetterName(cls.className, memberName) - if (allowedScopeNames != null && - !allowedScopeNames.contains(candidate) && - !localSlotIndexByName.containsKey(candidate) - ) { - continue - } - resolveDirectNameSlot(candidate)?.let { return it } - } - return null + return resolveExtensionSlotByReceiverNames(receiverClass, memberName, ::extensionPropertySetterName) + ?: resolveUniqueExtensionWrapperSlot(memberName, "__ext_set__") } private fun compileCallArgsWithReceiver( @@ -7468,7 +7484,7 @@ class BytecodeCompiler( if (targetClass == null) return null if (targetClass == ObjDynamic.type) return ObjDynamic.type classFieldTypesByName[targetClass.className]?.get(name)?.let { cls -> - if (cls.className == "Delegate" || cls.className == "LazyDelegate" || cls.implementingNames.contains("Delegate")) { + if (isDelegateClass(cls)) { return null } return cls @@ -7555,8 +7571,8 @@ class BytecodeCompiler( private fun queueExtensionCallableNames(receiverClass: ObjClass, memberName: String) { if (!useScopeSlots && globalSlotInfo.isEmpty()) return - for (cls in receiverClass.mro) { - val name = extensionCallableName(cls.className, memberName) + for (receiverName in extensionReceiverTypeNames(receiverClass)) { + val name = extensionCallableName(receiverName, memberName) if (allowedScopeNames == null || allowedScopeNames.contains(name)) { pendingScopeNameRefs.add(name) } @@ -7565,12 +7581,12 @@ class BytecodeCompiler( private fun queueExtensionPropertyNames(receiverClass: ObjClass, memberName: String) { if (!useScopeSlots && globalSlotInfo.isEmpty()) return - for (cls in receiverClass.mro) { - val getter = extensionPropertyGetterName(cls.className, memberName) + for (receiverName in extensionReceiverTypeNames(receiverClass)) { + val getter = extensionPropertyGetterName(receiverName, memberName) if (allowedScopeNames == null || allowedScopeNames.contains(getter)) { pendingScopeNameRefs.add(getter) } - val setter = extensionPropertySetterName(cls.className, memberName) + val setter = extensionPropertySetterName(receiverName, memberName) if (allowedScopeNames == null || allowedScopeNames.contains(setter)) { pendingScopeNameRefs.add(setter) } diff --git a/lynglib/src/commonTest/kotlin/TypesTest.kt b/lynglib/src/commonTest/kotlin/TypesTest.kt index 418b2f0..08d50e7 100644 --- a/lynglib/src/commonTest/kotlin/TypesTest.kt +++ b/lynglib/src/commonTest/kotlin/TypesTest.kt @@ -16,9 +16,10 @@ */ import kotlinx.coroutines.test.runTest -import net.sergeych.lyng.Script -import net.sergeych.lyng.ScriptError -import net.sergeych.lyng.eval +import net.sergeych.lyng.* +import net.sergeych.lyng.obj.Obj +import net.sergeych.lyng.obj.ObjClass +import net.sergeych.lyng.obj.ObjNull import kotlin.test.Test import kotlin.test.assertFailsWith import kotlin.test.assertTrue @@ -27,33 +28,42 @@ class TypesTest { @Test fun testTypeCollection1() = runTest { - eval(""" + eval( + """ class Point(x: Real, y: Real) assert(Point(1,2).x == 1) assert(Point(1,2).y == 2) assert(Point(1,2) is Point) - """.trimIndent()) + """.trimIndent() + ) } + @Test fun testTypeCollection2() = runTest { - eval(""" + eval( + """ fun fn1(x: Real, y: Real): Real { x + y } - """.trimIndent()) + """.trimIndent() + ) } + @Test fun testTypeCollection3() = runTest { - eval(""" + eval( + """ class Test(a: Int) { fun fn1(x: Real, y: Real): Real { x + y } } - """.trimIndent()) + """.trimIndent() + ) } @Test fun testExternDeclarations() = runTest { - eval(""" + eval( + """ extern fun foo1(a: String): Void assertThrows { foo1("1") } class Test(a: Int) { @@ -69,22 +79,26 @@ class TypesTest { } // println("4") - """.trimIndent()) + """.trimIndent() + ) } @Test fun testUserClassCompareTo() = runTest { - eval(""" + eval( + """ class Point(val a,b) assertEquals(Point(0,1), Point(0,1) ) assertNotEquals(Point(0,1), Point(1,1) ) - """.trimIndent()) + """.trimIndent() + ) } @Test fun testUserClassCompareTo2() = runTest { - eval(""" + eval( + """ class Point(val a,b) { var c = 0 } @@ -98,12 +112,14 @@ class TypesTest { assertEquals(p1, p2) assertNotEquals(Point(0,1), Point(1,1) ) assertNotEquals(Point(0,1), p3) - """.trimIndent()) + """.trimIndent() + ) } @Test fun testNumericInference() = runTest { - eval(""" + eval( + """ val x = 1 var y = 2.0 assert( x is Int ) @@ -111,11 +127,14 @@ class TypesTest { assert( x + y is Real ) assert( abs(x+y) is Real ) assert( abs(x/y) is Real ) - """.trimIndent()) + """.trimIndent() + ) } + @Test fun testNumericInferenceBug1() = runTest { - eval(""" + eval( + """ fun findSumLimit(f) { var sum = 0.0 for( n in 1..100 ) { @@ -142,12 +161,14 @@ class TypesTest { val limit = findSumLimit { n -> 1.0/n/n } assert( limit != null ) println("Result: "+limit) - """.trimIndent()) + """.trimIndent() + ) } @Test fun testNullableHints() = runTest { - eval(""" + eval( + """ // nullable, without type os Object? class N(x=null) assertEquals(null, N().x) @@ -167,7 +188,8 @@ class TypesTest { @Test fun testIsUnionIntersection() = runTest { - eval(""" + eval( + """ class A class B class C: A, B @@ -179,7 +201,8 @@ class TypesTest { val v = 1 assert( v is Int | String | Real ) assert( !(v is String | Bool) ) - """.trimIndent()) + """.trimIndent() + ) } @Test @@ -221,51 +244,64 @@ class TypesTest { @Test fun testListLiteralInferenceForBounds() = runTest { - eval(""" + eval( + """ fun acceptInts(xs: List) { } acceptInts([1, 2, 3]) val base = [1, 2] acceptInts([...base, 3]) - """.trimIndent()) - eval(""" + """.trimIndent() + ) + eval( + """ fun acceptReals(xs: List) { } acceptReals([1.0, 2.0, 3.0]) val base = [1.0, 2.0] acceptReals([...base, 3.0]) - """.trimIndent()) + """.trimIndent() + ) assertFailsWith { - eval(""" + eval( + """ fun acceptInts(xs: List) { } acceptInts([1, "a"]) - """.trimIndent()) + """.trimIndent() + ) } assertFailsWith { - eval(""" + eval( + """ fun acceptReals(xs: List) { } acceptReals([1.0, "a"]) - """.trimIndent()) + """.trimIndent() + ) } } @Test fun testMapLiteralInferenceForBounds() = runTest { - eval(""" + eval( + """ fun acceptMap(m: Map) { } acceptMap({ "a": 1, "b": 2 }) val base = { "a": 1 } acceptMap({ ...base, "b": 3 }) - """.trimIndent()) + """.trimIndent() + ) assertFailsWith { - eval(""" + eval( + """ fun acceptMap(m: Map) { } acceptMap({ "a": 1, "b": "x" }) - """.trimIndent()) + """.trimIndent() + ) } } @Test fun testUnionTypeLists() = runTest { - eval(""" + eval( + """ fun fMixed(list: List) { println(list) @@ -282,12 +318,14 @@ class TypesTest { } fMixed([1, "two", true]) fInts([1,2,3]) - """) + """ + ) } @Test fun testTypeAliases() = runTest { - eval(""" + eval( + """ type Num = Int | Real type AB = A & B class A @@ -308,12 +346,14 @@ class TypesTest { type IntList = List fun accept(xs: IntList) { } accept([1,2,3]) - """.trimIndent()) + """.trimIndent() + ) } @Test fun testMultipleReceivers() = runTest { - eval(""" + eval( + """ class R1(shared,r1="r1") class R2(shared,r2="r2") @@ -340,69 +380,85 @@ class TypesTest { assertEquals("r1", r1) } } - """) + """ + ) } @Test fun testLambdaTypes1() = runTest { val scope = Script.newScope() // declare: ok - scope.eval(""" + scope.eval( + """ var l1: (Int,String)->String - """.trimIndent()) + """.trimIndent() + ) // this should be Lyng compile time exception assertFailsWith { - scope.eval(""" + scope.eval( + """ fun test() { // compiler should detect that l1 us called with arguments that does not match // declare type (Int,String)->String: l1() } - """.trimIndent()) + """.trimIndent() + ) } } @Test fun testLambdaTypesEllipsis() = runTest { val scope = Script.newScope() - scope.eval(""" + scope.eval( + """ var l2: (Int,Object...,String)->Real var l4: (Int,String...,String)->Real var l3: (...)->Int - """.trimIndent()) + """.trimIndent() + ) assertFailsWith { - scope.eval(""" + scope.eval( + """ fun testTooFew() { l2(1) } - """.trimIndent()) + """.trimIndent() + ) } assertFailsWith { - scope.eval(""" + scope.eval( + """ fun testWrongHead() { l2("x", "y") } - """.trimIndent()) + """.trimIndent() + ) } assertFailsWith { - scope.eval(""" + scope.eval( + """ fun testWrongEllipsis() { l4(1, 2, "x") } - """.trimIndent()) + """.trimIndent() + ) } - scope.eval(""" + scope.eval( + """ fun testOk1() { l2(1, "x") } fun testOk2() { l2(1, 2, 3, "x") } fun testOk3() { l3() } fun testOk4() { l3(1, true, "x") } fun testOk5() { l4(1, "a", "b", "x") } - """.trimIndent()) + """.trimIndent() + ) } @Test fun testSetTyped() = runTest { - eval(""" + eval( + """ var s = Set() val typed: Set = s assertEquals(Set(), typed) @@ -413,12 +469,14 @@ class TypesTest { assertEquals(Set(), s) s += ["foo", "bar"] assertEquals(Set("foo", "bar"), s) - """.trimIndent()) + """.trimIndent() + ) } @Test fun testListTyped() = runTest { - eval(""" + eval( + """ var l = List() val typed: List = l assertEquals(List(), typed) @@ -430,13 +488,15 @@ class TypesTest { l += ["foo", "bar"] assertEquals(List("foo", "bar"), l) - """.trimIndent()) + """.trimIndent() + ) } @Test fun testAliasesInGenerics1() = runTest { val scope = Script.newScope() - scope.eval(""" + scope.eval( + """ type IntList = List type IntMap = Map type IntSet = Set @@ -459,13 +519,16 @@ class TypesTest { assertEquals(Set("tag1", "tag2", Buffer("tag3")), x.tags) x.tags += Buffer("tag4") assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags) - """) - scope.eval(""" + """ + ) + scope.eval( + """ assert(x is X) x.tags += "42" assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4"), "42"), x.tags) - """.trimIndent()) + """.trimIndent() + ) // now this must fail becaise element type does not match the declared: assertFailsWith { scope.eval( @@ -479,7 +542,8 @@ class TypesTest { @Test fun testAliasesInGenericsList1() = runTest { val scope = Script.newScope() - scope.eval(""" + scope.eval( + """ import lyng.buffer type Tag = String | Buffer @@ -495,12 +559,15 @@ class TypesTest { assertEquals(List("tag1", "tag2", Buffer("tag3")), x.tags) x.tags += ["tag4", Buffer("tag5")] assertEquals(List("tag1", "tag2", Buffer("tag3"), "tag4", Buffer("tag5")), x.tags) - """) - scope.eval(""" + """ + ) + scope.eval( + """ assert(x is X) x.tags += "42" assertEquals(List("tag1", "tag2", Buffer("tag3"), "tag4", Buffer("tag5"), "42"), x.tags) - """.trimIndent()) + """.trimIndent() + ) assertFailsWith { scope.eval( """ @@ -512,18 +579,21 @@ class TypesTest { @Test fun testClassName() = runTest { - eval(""" + eval( + """ class X { var x = 1 } assert( X::class is Class) assertEquals("Class", X::class.name) - """.trimIndent()) + """.trimIndent() + ) } @Test fun testGenericTypes() = runTest { - eval(""" + eval( + """ fun t(): String = when(T) { null -> "%s is Null"(T::class.name) @@ -532,24 +602,28 @@ class TypesTest { } assert( Int is Object) assertEquals( t(), "Class is Object") - """.trimIndent()) + """.trimIndent() + ) } @Test fun testGenericNullableTypePredicate() = runTest { - eval(""" + eval( + """ fun isTypeNullable(x: T): Bool = T is nullable type MaybeInt = Int? assert(isTypeNullable(null)) assert(!isTypeNullable(1)) assert(MaybeInt is nullable) assert(!(Int is nullable)) - """.trimIndent()) + """.trimIndent() + ) } @Test fun testWhenNullableTypeCase() = runTest { - eval(""" + eval( + """ fun describe(x: T): String = when(T) { nullable -> "nullable" else -> "non-null" @@ -562,11 +636,14 @@ class TypesTest { assertEquals("non-null", describe(1)) assertEquals("nullable", describeIs(null)) assertEquals("non-null", describeIs(1)) - """.trimIndent()) + """.trimIndent() + ) } - @Test fun testIndexer() = runTest { - eval(""" + @Test + fun testIndexer() = runTest { + eval( + """ class Greeter { override fun getAt(name) = "Hello, %s!"(name) } @@ -581,11 +658,14 @@ class TypesTest { assertEquals("How do you do, Bob?",Polite["Bob"]) - """.trimIndent()) + """.trimIndent() + ) } - @Test fun testIndexer2() = runTest { - eval(""" + @Test + fun testIndexer2() = runTest { + eval( + """ class Foo(bar) class Greeter { @@ -613,17 +693,65 @@ class TypesTest { assertEquals("How do you do, Bob?",g2v.bar) assertEquals("How do you do, Bob?",Greeter2()["Bob"].bar) - """.trimIndent()) + """.trimIndent() + ) } @Test fun testExternGenerics() = runTest { - eval(""" + eval( + """ extern fun f(x: T): T extern class Cell { var value: T } - """) + """ + ) + } + + class ObjFoo : Obj() { + + override val objClass = klass + + var value: Obj = ObjNull + + companion object { + val klass = object : ObjClass("ObjFoo") { + override suspend fun callOn(scope: Scope): Obj { + return ObjFoo() + } + }.apply { + addProperty("bar", { + check(thisObj is ObjFoo) + thisAs().value + }, { + thisAs().value = args.firstAndOnly() + }) + } + } + } + + @Test + fun testExtensionToExternClasses() = runTest { + val scope = Script.newScope() + scope.addConst("Foo", ObjFoo.klass) + scope.eval( + """ + extern class Foo { + var bar + } + + fun Foo.foobar() = "foo" + bar + + val f = Foo() + f.bar = 42 + assert( f is Foo ) + assertEquals(42, f.bar) + + assertEquals("foo42", f.foobar()) + + """.trimIndent() + ) } @Test