From 2f145a0ea70910446a69bea203a437de164ba472 Mon Sep 17 00:00:00 2001 From: sergeych Date: Tue, 7 Apr 2026 09:33:40 +0300 Subject: [PATCH] Fix nullable let member inference --- .../kotlin/net/sergeych/lyng/Compiler.kt | 65 +++++++++---- lynglib/src/commonTest/kotlin/OOTest.kt | 97 +++++++++++++++++++ lynglib/src/commonTest/kotlin/ScriptTest.kt | 13 +++ 3 files changed, 155 insertions(+), 20 deletions(-) diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt index cc71606..7259d0c 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt @@ -2783,7 +2783,7 @@ class Compiler( Token.Type.LPAREN -> { cc.next() if (shouldTreatAsClassScopeCall(left, next.value)) { - val parsed = parseArgs(null, implicitItTypeNameForMemberLambda(left, next.value)) + val parsed = parseArgs(null, implicitItTypeForMemberLambda(left, next.value)) val args = parsed.first val tailBlock = parsed.second isCall = true @@ -2800,7 +2800,7 @@ class Compiler( val receiverType = if (next.value == "apply" || next.value == "run") { inferReceiverTypeFromRef(left) } else null - val parsed = parseArgs(receiverType, implicitItTypeNameForMemberLambda(left, next.value)) + val parsed = parseArgs(receiverType, implicitItTypeForMemberLambda(left, next.value)) val args = parsed.first val tailBlock = parsed.second if (left is LocalVarRef && left.name == "scope") { @@ -2879,7 +2879,7 @@ class Compiler( val receiverType = if (next.value == "apply" || next.value == "run") { inferReceiverTypeFromRef(left) } else null - val itType = implicitItTypeNameForMemberLambda(left, next.value) + val itType = implicitItTypeForMemberLambda(left, next.value) val lambda = parseLambdaExpression(receiverType, implicitItType = itType) val argPos = next.pos val args = listOf(ParsedArgument(ExpressionStatement(lambda, argPos), next.pos)) @@ -3287,7 +3287,7 @@ class Compiler( private suspend fun parseLambdaExpression( expectedReceiverType: String? = null, wrapAsExtensionCallable: Boolean = false, - implicitItType: String? = null + implicitItType: TypeDecl? = null ): ObjRef { // lambda args are different: val startPos = cc.currentPos() @@ -3304,14 +3304,15 @@ class Compiler( val slotParamNames = if (hasImplicitIt) paramNames + "it" else paramNames val paramSlotPlan = buildParamSlotPlan(slotParamNames) if (implicitItType != null) { - val cls = resolveClassByName(implicitItType) - ?: resolveTypeDeclObjClass(TypeDecl.Simple(implicitItType, false)) + val cls = resolveTypeDeclObjClass(implicitItType) val itSlot = paramSlotPlan.slots["it"]?.index - if (cls != null && itSlot != null) { - val paramTypeMap = slotTypeByScopeId.getOrPut(paramSlotPlan.id) { mutableMapOf() } - paramTypeMap[itSlot] = cls + if (itSlot != null) { + if (cls != null) { + val paramTypeMap = slotTypeByScopeId.getOrPut(paramSlotPlan.id) { mutableMapOf() } + paramTypeMap[itSlot] = cls + } val paramTypeDeclMap = slotTypeDeclByScopeId.getOrPut(paramSlotPlan.id) { mutableMapOf() } - paramTypeDeclMap[itSlot] = TypeDecl.Simple(implicitItType, false) + paramTypeDeclMap[itSlot] = implicitItType } } @@ -4745,6 +4746,13 @@ class Compiler( return null } + private fun classMethodReturnTypeDecl(typeName: String?, name: String): TypeDecl? { + if (typeName == null) return null + classMethodReturnTypeDeclByName[typeName]?.get(name)?.let { return it } + classMethodReturnTypeByName[typeName]?.get(name)?.let { return TypeDecl.Simple(it.className, false) } + return null + } + private fun classMethodReturnClass(targetClass: ObjClass?, name: String): ObjClass? { if (targetClass == null) return null if (targetClass == ObjDynamic.type) return ObjDynamic.type @@ -4858,6 +4866,26 @@ class Compiler( classMethodReturnTypeDecl(targetClass, "getAt") } is MethodCallRef -> methodReturnTypeDeclByRef[ref] ?: inferMethodCallReturnTypeDecl(ref) + is ImplicitThisMethodCallRef -> { + val typeName = ref.preferredThisTypeName() ?: currentImplicitThisTypeName() + val receiverDecl = typeName?.let { TypeDecl.Simple(it, false) } + inferMethodCallReturnTypeDecl(ref.methodName(), receiverDecl, ref.arguments()) + ?: classMethodReturnTypeDecl(typeName, ref.methodName()) + ?: typeName?.let { resolveClassByName(it) }?.let { classMethodReturnTypeDecl(it, ref.methodName()) } + } + is ThisMethodSlotCallRef -> { + val typeName = currentImplicitThisTypeName() + val receiverDecl = typeName?.let { TypeDecl.Simple(it, false) } + inferMethodCallReturnTypeDecl(ref.methodName(), receiverDecl, ref.arguments()) + ?: classMethodReturnTypeDecl(typeName, ref.methodName()) + ?: typeName?.let { resolveClassByName(it) }?.let { classMethodReturnTypeDecl(it, ref.methodName()) } + } + is QualifiedThisMethodSlotCallRef -> { + val receiverDecl = TypeDecl.Simple(ref.receiverTypeName(), false) + inferMethodCallReturnTypeDecl(ref.methodName(), receiverDecl, ref.arguments()) + ?: classMethodReturnTypeDecl(ref.receiverTypeName(), ref.methodName()) + ?: resolveClassByName(ref.receiverTypeName())?.let { classMethodReturnTypeDecl(it, ref.methodName()) } + } is CallRef -> callReturnTypeDeclByRef[ref] ?: inferCallReturnTypeDecl(ref) is BinaryOpRef -> inferBinaryOpReturnTypeDecl(ref) is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) } @@ -5122,6 +5150,7 @@ class Compiler( private fun inferMethodCallReturnTypeDecl(ref: MethodCallRef): TypeDecl? { methodReturnTypeDeclByRef[ref]?.let { return it } val inferred = inferMethodCallReturnTypeDecl(ref.name, resolveReceiverTypeDecl(ref.receiver), ref.args) + ?: classMethodReturnTypeDecl(resolveReceiverClassForMember(ref.receiver), ref.name) if (inferred != null) { methodReturnTypeDeclByRef[ref] = inferred } @@ -5272,21 +5301,17 @@ class Compiler( } } - private fun implicitItTypeNameForMemberLambda(receiver: ObjRef, memberName: String): String? { + private fun implicitItTypeForMemberLambda(receiver: ObjRef, memberName: String): TypeDecl? { if (memberName == "fill" && isListTypeRef(receiver)) { - return "Int" + return TypeDecl.Simple("Int", false) } if (memberName == "let" || memberName == "also") { - return inferReceiverTypeFromRef(receiver) + val receiverType = inferTypeDeclFromRef(receiver) ?: resolveReceiverTypeDecl(receiver) + return receiverType?.let { makeTypeDeclNonNullable(it) } } - val typeDecl = when (memberName) { + return when (memberName) { "forEach", "map" -> inferIterableElementTypeDecl(receiver) else -> null - } ?: return null - return when (typeDecl) { - is TypeDecl.Simple -> typeDecl.name.substringAfterLast('.') - is TypeDecl.Generic -> typeDecl.name.substringAfterLast('.') - else -> resolveTypeDeclObjClass(typeDecl)?.className } } @@ -6293,7 +6318,7 @@ class Compiler( */ private suspend fun parseArgs( expectedTailBlockReceiver: String? = null, - implicitItType: String? = null + implicitItType: TypeDecl? = null ): Pair, Boolean> { val args = mutableListOf() diff --git a/lynglib/src/commonTest/kotlin/OOTest.kt b/lynglib/src/commonTest/kotlin/OOTest.kt index 9e0c0e3..d3db727 100644 --- a/lynglib/src/commonTest/kotlin/OOTest.kt +++ b/lynglib/src/commonTest/kotlin/OOTest.kt @@ -1102,4 +1102,101 @@ class OOTest { """.trimIndent()) } + @Test + fun testExtendingObjectWithExternals2() = runTest { + val s = EvalSession() + s.eval(""" + import lyng.serialization + object Storage { + extern val spaceUsed: Int + extern val spaceAvailable: Int + + /* + Return packed binary data or null + */ + extern fun getPacked(key: String): Buffer? + + /* + Upsert packed binary data + */ + extern fun putPacked(key: String,value: Buffer) + + /* + Delete data. + @return true if data were actually deleted, false means + there were no data for the key. + */ + extern fun delete(key: String): Bool + + override fun putAt(key: String,value: Object) { + putPacked(key, Lynon.encode(value).toBuffer()) + } + + override fun getAt(key: String): Object? = + getPacked(key)?.let { Lynon.decode(it.toBitInput()) } + } + + """.trimIndent() + ) + val scope = s.getScope() as ModuleScope + scope.bindObject("Storage") { + init { _ -> + data = mutableMapOf() + } + addVal("spaceUsed") { + val storage = (thisObj as ObjInstance).data as MutableMap + ObjInt(storage.values.sumOf { it.size }.toLong()) + } + addVal("spaceAvailable") { + val storage = (thisObj as ObjInstance).data as MutableMap + val capacity = 1_024 + ObjInt((capacity - storage.values.sumOf { it.size }).toLong()) + } + addFun("getPacked") { + val storage = (thisObj as ObjInstance).data as MutableMap + val key = (args.list[0] as ObjString).value + storage[key] ?: ObjNull + } + addFun("putPacked") { + val storage = (thisObj as ObjInstance).data as MutableMap + val key = (args.list[0] as ObjString).value + val value = args.list[1] as ObjBuffer + storage[key] = value + ObjVoid + } + addFun("delete") { + val storage = (thisObj as ObjInstance).data as MutableMap + val key = (args.list[0] as ObjString).value + ObjBool(storage.remove(key) != null) + } + } + s.eval(""" + assertEquals(0, Storage.spaceUsed) + assertEquals(1024, Storage.spaceAvailable) + val missing: String? = Storage["missing"] + assertEquals(null, missing) + + Storage["name"] = "alice" + Storage["count"] = 42 + + val name: String? = Storage["name"] + val count: Int? = Storage["count"] + assertEquals("alice", name) + assertEquals(42, count) + assert(Storage.spaceUsed > 0) + assert(Storage.spaceAvailable < 1024) + + val wrappedName: String? = Storage.getAt("name") + assertEquals("alice", wrappedName) + Storage.putAt("flag", true) + val flag: Bool? = Storage["flag"] + assertEquals(true, flag) + + assert(Storage.delete("name")) + val deletedName: String? = Storage["name"] + assertEquals(null, deletedName) + assert(!Storage.delete("name")) + """.trimIndent()) + } + } diff --git a/lynglib/src/commonTest/kotlin/ScriptTest.kt b/lynglib/src/commonTest/kotlin/ScriptTest.kt index c7ed52a..714240e 100644 --- a/lynglib/src/commonTest/kotlin/ScriptTest.kt +++ b/lynglib/src/commonTest/kotlin/ScriptTest.kt @@ -109,6 +109,19 @@ class ScriptTest { assertTrue(res is ObjString && res.value.isNotEmpty()) } + @Test + fun testNullableLetKeepsReceiverMemberType() = runTest { + Script.newScope().eval( + """ + import lyng.serialization + + val packed: Buffer? = Lynon.encode("alice").toBuffer() + val decoded = packed?.let { Lynon.decode(it.toBitInput()) } + assertEquals("alice", decoded) + """.trimIndent() + ) + } + @Test fun testNoInfiniteRecursionOnUnknownInNestedClosure() = runTest { val scope = Script.newScope()