Fix nullable let member inference

This commit is contained in:
Sergey Chernov 2026-04-07 09:33:40 +03:00
parent 15617f6998
commit 2f145a0ea7
3 changed files with 155 additions and 20 deletions

View File

@ -2783,7 +2783,7 @@ class Compiler(
Token.Type.LPAREN -> {
cc.next()
if (shouldTreatAsClassScopeCall(left, next.value)) {
val parsed = parseArgs(null, implicitItTypeNameForMemberLambda(left, next.value))
val parsed = parseArgs(null, implicitItTypeForMemberLambda(left, next.value))
val args = parsed.first
val tailBlock = parsed.second
isCall = true
@ -2800,7 +2800,7 @@ class Compiler(
val receiverType = if (next.value == "apply" || next.value == "run") {
inferReceiverTypeFromRef(left)
} else null
val parsed = parseArgs(receiverType, implicitItTypeNameForMemberLambda(left, next.value))
val parsed = parseArgs(receiverType, implicitItTypeForMemberLambda(left, next.value))
val args = parsed.first
val tailBlock = parsed.second
if (left is LocalVarRef && left.name == "scope") {
@ -2879,7 +2879,7 @@ class Compiler(
val receiverType = if (next.value == "apply" || next.value == "run") {
inferReceiverTypeFromRef(left)
} else null
val itType = implicitItTypeNameForMemberLambda(left, next.value)
val itType = implicitItTypeForMemberLambda(left, next.value)
val lambda = parseLambdaExpression(receiverType, implicitItType = itType)
val argPos = next.pos
val args = listOf(ParsedArgument(ExpressionStatement(lambda, argPos), next.pos))
@ -3287,7 +3287,7 @@ class Compiler(
private suspend fun parseLambdaExpression(
expectedReceiverType: String? = null,
wrapAsExtensionCallable: Boolean = false,
implicitItType: String? = null
implicitItType: TypeDecl? = null
): ObjRef {
// lambda args are different:
val startPos = cc.currentPos()
@ -3304,14 +3304,15 @@ class Compiler(
val slotParamNames = if (hasImplicitIt) paramNames + "it" else paramNames
val paramSlotPlan = buildParamSlotPlan(slotParamNames)
if (implicitItType != null) {
val cls = resolveClassByName(implicitItType)
?: resolveTypeDeclObjClass(TypeDecl.Simple(implicitItType, false))
val cls = resolveTypeDeclObjClass(implicitItType)
val itSlot = paramSlotPlan.slots["it"]?.index
if (cls != null && itSlot != null) {
if (itSlot != null) {
if (cls != null) {
val paramTypeMap = slotTypeByScopeId.getOrPut(paramSlotPlan.id) { mutableMapOf() }
paramTypeMap[itSlot] = cls
}
val paramTypeDeclMap = slotTypeDeclByScopeId.getOrPut(paramSlotPlan.id) { mutableMapOf() }
paramTypeDeclMap[itSlot] = TypeDecl.Simple(implicitItType, false)
paramTypeDeclMap[itSlot] = implicitItType
}
}
@ -4745,6 +4746,13 @@ class Compiler(
return null
}
private fun classMethodReturnTypeDecl(typeName: String?, name: String): TypeDecl? {
if (typeName == null) return null
classMethodReturnTypeDeclByName[typeName]?.get(name)?.let { return it }
classMethodReturnTypeByName[typeName]?.get(name)?.let { return TypeDecl.Simple(it.className, false) }
return null
}
private fun classMethodReturnClass(targetClass: ObjClass?, name: String): ObjClass? {
if (targetClass == null) return null
if (targetClass == ObjDynamic.type) return ObjDynamic.type
@ -4858,6 +4866,26 @@ class Compiler(
classMethodReturnTypeDecl(targetClass, "getAt")
}
is MethodCallRef -> methodReturnTypeDeclByRef[ref] ?: inferMethodCallReturnTypeDecl(ref)
is ImplicitThisMethodCallRef -> {
val typeName = ref.preferredThisTypeName() ?: currentImplicitThisTypeName()
val receiverDecl = typeName?.let { TypeDecl.Simple(it, false) }
inferMethodCallReturnTypeDecl(ref.methodName(), receiverDecl, ref.arguments())
?: classMethodReturnTypeDecl(typeName, ref.methodName())
?: typeName?.let { resolveClassByName(it) }?.let { classMethodReturnTypeDecl(it, ref.methodName()) }
}
is ThisMethodSlotCallRef -> {
val typeName = currentImplicitThisTypeName()
val receiverDecl = typeName?.let { TypeDecl.Simple(it, false) }
inferMethodCallReturnTypeDecl(ref.methodName(), receiverDecl, ref.arguments())
?: classMethodReturnTypeDecl(typeName, ref.methodName())
?: typeName?.let { resolveClassByName(it) }?.let { classMethodReturnTypeDecl(it, ref.methodName()) }
}
is QualifiedThisMethodSlotCallRef -> {
val receiverDecl = TypeDecl.Simple(ref.receiverTypeName(), false)
inferMethodCallReturnTypeDecl(ref.methodName(), receiverDecl, ref.arguments())
?: classMethodReturnTypeDecl(ref.receiverTypeName(), ref.methodName())
?: resolveClassByName(ref.receiverTypeName())?.let { classMethodReturnTypeDecl(it, ref.methodName()) }
}
is CallRef -> callReturnTypeDeclByRef[ref] ?: inferCallReturnTypeDecl(ref)
is BinaryOpRef -> inferBinaryOpReturnTypeDecl(ref)
is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) }
@ -5122,6 +5150,7 @@ class Compiler(
private fun inferMethodCallReturnTypeDecl(ref: MethodCallRef): TypeDecl? {
methodReturnTypeDeclByRef[ref]?.let { return it }
val inferred = inferMethodCallReturnTypeDecl(ref.name, resolveReceiverTypeDecl(ref.receiver), ref.args)
?: classMethodReturnTypeDecl(resolveReceiverClassForMember(ref.receiver), ref.name)
if (inferred != null) {
methodReturnTypeDeclByRef[ref] = inferred
}
@ -5272,21 +5301,17 @@ class Compiler(
}
}
private fun implicitItTypeNameForMemberLambda(receiver: ObjRef, memberName: String): String? {
private fun implicitItTypeForMemberLambda(receiver: ObjRef, memberName: String): TypeDecl? {
if (memberName == "fill" && isListTypeRef(receiver)) {
return "Int"
return TypeDecl.Simple("Int", false)
}
if (memberName == "let" || memberName == "also") {
return inferReceiverTypeFromRef(receiver)
val receiverType = inferTypeDeclFromRef(receiver) ?: resolveReceiverTypeDecl(receiver)
return receiverType?.let { makeTypeDeclNonNullable(it) }
}
val typeDecl = when (memberName) {
return when (memberName) {
"forEach", "map" -> inferIterableElementTypeDecl(receiver)
else -> null
} ?: return null
return when (typeDecl) {
is TypeDecl.Simple -> typeDecl.name.substringAfterLast('.')
is TypeDecl.Generic -> typeDecl.name.substringAfterLast('.')
else -> resolveTypeDeclObjClass(typeDecl)?.className
}
}
@ -6293,7 +6318,7 @@ class Compiler(
*/
private suspend fun parseArgs(
expectedTailBlockReceiver: String? = null,
implicitItType: String? = null
implicitItType: TypeDecl? = null
): Pair<List<ParsedArgument>, Boolean> {
val args = mutableListOf<ParsedArgument>()

View File

@ -1102,4 +1102,101 @@ class OOTest {
""".trimIndent())
}
@Test
fun testExtendingObjectWithExternals2() = runTest {
val s = EvalSession()
s.eval("""
import lyng.serialization
object Storage {
extern val spaceUsed: Int
extern val spaceAvailable: Int
/*
Return packed binary data or null
*/
extern fun getPacked(key: String): Buffer?
/*
Upsert packed binary data
*/
extern fun putPacked(key: String,value: Buffer)
/*
Delete data.
@return true if data were actually deleted, false means
there were no data for the key.
*/
extern fun delete(key: String): Bool
override fun putAt(key: String,value: Object) {
putPacked(key, Lynon.encode(value).toBuffer())
}
override fun getAt(key: String): Object? =
getPacked(key)?.let { Lynon.decode(it.toBitInput()) }
}
""".trimIndent()
)
val scope = s.getScope() as ModuleScope
scope.bindObject("Storage") {
init { _ ->
data = mutableMapOf<String, ObjBuffer>()
}
addVal("spaceUsed") {
val storage = (thisObj as ObjInstance).data as MutableMap<String, ObjBuffer>
ObjInt(storage.values.sumOf { it.size }.toLong())
}
addVal("spaceAvailable") {
val storage = (thisObj as ObjInstance).data as MutableMap<String, ObjBuffer>
val capacity = 1_024
ObjInt((capacity - storage.values.sumOf { it.size }).toLong())
}
addFun("getPacked") {
val storage = (thisObj as ObjInstance).data as MutableMap<String, ObjBuffer>
val key = (args.list[0] as ObjString).value
storage[key] ?: ObjNull
}
addFun("putPacked") {
val storage = (thisObj as ObjInstance).data as MutableMap<String, ObjBuffer>
val key = (args.list[0] as ObjString).value
val value = args.list[1] as ObjBuffer
storage[key] = value
ObjVoid
}
addFun("delete") {
val storage = (thisObj as ObjInstance).data as MutableMap<String, ObjBuffer>
val key = (args.list[0] as ObjString).value
ObjBool(storage.remove(key) != null)
}
}
s.eval("""
assertEquals(0, Storage.spaceUsed)
assertEquals(1024, Storage.spaceAvailable)
val missing: String? = Storage["missing"]
assertEquals(null, missing)
Storage["name"] = "alice"
Storage["count"] = 42
val name: String? = Storage["name"]
val count: Int? = Storage["count"]
assertEquals("alice", name)
assertEquals(42, count)
assert(Storage.spaceUsed > 0)
assert(Storage.spaceAvailable < 1024)
val wrappedName: String? = Storage.getAt("name")
assertEquals("alice", wrappedName)
Storage.putAt("flag", true)
val flag: Bool? = Storage["flag"]
assertEquals(true, flag)
assert(Storage.delete("name"))
val deletedName: String? = Storage["name"]
assertEquals(null, deletedName)
assert(!Storage.delete("name"))
""".trimIndent())
}
}

View File

@ -109,6 +109,19 @@ class ScriptTest {
assertTrue(res is ObjString && res.value.isNotEmpty())
}
@Test
fun testNullableLetKeepsReceiverMemberType() = runTest {
Script.newScope().eval(
"""
import lyng.serialization
val packed: Buffer? = Lynon.encode("alice").toBuffer()
val decoded = packed?.let { Lynon.decode(it.toBitInput()) }
assertEquals("alice", decoded)
""".trimIndent()
)
}
@Test
fun testNoInfiniteRecursionOnUnknownInNestedClosure() = runTest {
val scope = Script.newScope()