Fix callable return inference regressions

This commit is contained in:
Sergey Chernov 2026-04-25 19:07:10 +03:00
parent eba7158330
commit 79b015ee56
9 changed files with 470 additions and 45 deletions

View File

@ -45,7 +45,7 @@ val restored = openSqlite(":memory:").transaction { tx ->
assertEquals(21, restored.state.count) assertEquals(21, restored.state.count)
assertEquals("updated", restored.note) assertEquals("updated", restored.note)
restored restored
} as Item }
println("Restored item:") println("Restored item:")
println(" id=" + restored.id) println(" id=" + restored.id)

View File

@ -32,6 +32,7 @@ import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import net.sergeych.lyng.EvalSession import net.sergeych.lyng.EvalSession
import net.sergeych.lyng.ExecutionError
import net.sergeych.lyng.LyngVersion import net.sergeych.lyng.LyngVersion
import net.sergeych.lyng.Pos import net.sergeych.lyng.Pos
import net.sergeych.lyng.Scope import net.sergeych.lyng.Scope
@ -153,8 +154,8 @@ val baseScopeDefer = globalDefer {
baseCliImportManagerDefer.await().copy().apply { baseCliImportManagerDefer.await().copy().apply {
invalidateCliModuleCaches() invalidateCliModuleCaches()
}.newStdScope().apply { }.newStdScope().apply {
installCliDeclarations()
installCliBuiltins() installCliBuiltins()
installCliDeclarations()
addConst("ARGV", ObjList(mutableListOf())) addConst("ARGV", ObjList(mutableListOf()))
} }
} }
@ -364,8 +365,8 @@ private fun registerLocalCliModules(manager: ImportManager, modules: List<LocalC
private suspend fun ImportManager.newCliScope(argv: List<String>): Scope = private suspend fun ImportManager.newCliScope(argv: List<String>): Scope =
newStdScope().apply { newStdScope().apply {
installCliDeclarations()
installCliBuiltins() installCliBuiltins()
installCliDeclarations()
addConst("ARGV", ObjList(argv.map { ObjString(it) }.toMutableList())) addConst("ARGV", ObjList(argv.map { ObjString(it) }.toMutableList()))
} }
@ -547,6 +548,15 @@ suspend fun executeSource(source: Source, initialScope: Scope? = null) {
evalOnCliDispatcher(session, source) evalOnCliDispatcher(session, source)
} catch (e: CliExitRequested) { } catch (e: CliExitRequested) {
requestedExitCode = e.code requestedExitCode = e.code
} catch (e: ExecutionError) {
val cliExit = generateSequence<Throwable>(e) { it.cause }
.filterIsInstance<CliExitRequested>()
.firstOrNull()
if (cliExit != null) {
requestedExitCode = cliExit.code
} else {
throw e
}
} }
} finally { } finally {
shutdownHooks.uninstall() shutdownHooks.uninstall()

View File

@ -137,6 +137,47 @@ class LyngSqliteModuleTest {
assertEquals(2L, result.value) assertEquals(2L, result.value)
} }
@Test
fun testTransactionGenericReturnTypeFlowsToOuterVal() = runTest {
val scope = Script.newScope()
createSqliteModule(scope.importManager)
val code = """
import lyng.io.db
import lyng.io.db.sqlite
class Payload(name: String, count: Int)
class Item(id: Int, title: String, @DbJson meta: Payload, @DbLynon state: Payload) {
var note: String = ""
}
val restored = openSqlite(":memory:").transaction { tx ->
tx.execute("create table item(id integer not null, title text not null, meta text not null, state blob not null, note text not null)")
val item = Item(1, "first", Payload("json", 10), Payload("bin", 20))
item.note = "created"
tx.execute("insert into item(@cols(?1)) values(@vals(?1))", item)
item.title = "second"
item.meta = Payload("json2", 11)
item.state = Payload("bin2", 21)
item.note = "updated"
tx.execute("update item set @set(?1 except: \"id\") where id = ?2", item, item.id)
val restored = tx.select("select * from item where id = ?", 1).decodeAs<Item>().first
assertEquals("second", restored.title)
assertEquals("json2", restored.meta.name)
assertEquals(11, restored.meta.count)
assertEquals("bin2", restored.state.name)
assertEquals(21, restored.state.count)
assertEquals("updated", restored.note)
restored
}
restored.id
""".trimIndent()
val result = Compiler.compile(Source("<sqlite-transaction-return-inference>", code), scope.importManager).execute(scope) as ObjInt
assertEquals(1L, result.value)
}
@Test @Test
fun testDecodeAsProjectsJsonColumnIntoObjectField() = runTest { fun testDecodeAsProjectsJsonColumnIntoObjectField() = runTest {
val scope = Script.newScope() val scope = Script.newScope()

View File

@ -189,6 +189,7 @@ class Compiler(
private val callableReturnTypeDeclByName: MutableMap<String, TypeDecl> = mutableMapOf() private val callableReturnTypeDeclByName: MutableMap<String, TypeDecl> = mutableMapOf()
private val callSignatureByName: MutableMap<String, CallSignature> = mutableMapOf() private val callSignatureByName: MutableMap<String, CallSignature> = mutableMapOf()
private val lambdaReturnTypeByRef: MutableMap<ObjRef, ObjClass> = mutableMapOf() private val lambdaReturnTypeByRef: MutableMap<ObjRef, ObjClass> = mutableMapOf()
private val lambdaTypeDeclByRef: MutableMap<ObjRef, TypeDecl.Function> = mutableMapOf()
private val exactLambdaRefByScopeId: MutableMap<Int, MutableMap<Int, LambdaFnRef>> = mutableMapOf() private val exactLambdaRefByScopeId: MutableMap<Int, MutableMap<Int, LambdaFnRef>> = mutableMapOf()
private val lambdaCaptureEntriesByRef: MutableMap<ValueFnRef, List<net.sergeych.lyng.bytecode.LambdaCaptureEntry>> = private val lambdaCaptureEntriesByRef: MutableMap<ValueFnRef, List<net.sergeych.lyng.bytecode.LambdaCaptureEntry>> =
mutableMapOf() mutableMapOf()
@ -682,6 +683,7 @@ class Compiler(
return CompileClassInfo( return CompileClassInfo(
name, name,
cls.logicalPackageName, cls.logicalPackageName,
emptyList(),
fieldIds, fieldIds,
methodIds, methodIds,
nextFieldId, nextFieldId,
@ -1750,6 +1752,7 @@ class Compiler(
private data class CompileClassInfo( private data class CompileClassInfo(
val name: String, val name: String,
val packageName: String?, val packageName: String?,
val typeParams: List<String>,
val fieldIds: Map<String, Int>, val fieldIds: Map<String, Int>,
val methodIds: Map<String, Int>, val methodIds: Map<String, Int>,
val nextFieldId: Int, val nextFieldId: Int,
@ -3496,10 +3499,12 @@ class Compiler(
paramTypeDeclMap[slot] = typeDecl paramTypeDeclMap[slot] = typeDecl
} }
val lambdaParamTypeDecls = mutableListOf<TypeDecl>()
if (argsDeclaration != null) { if (argsDeclaration != null) {
val expectedParams = expectedCallableType?.params.orEmpty() val expectedParams = expectedCallableType?.params.orEmpty()
argsDeclaration.params.forEachIndexed { index, param -> argsDeclaration.params.forEachIndexed { index, param ->
val effectiveType = if ((param.type == TypeDecl.TypeAny || param.type == TypeDecl.TypeNullableAny) && val rawType = if ((param.type == TypeDecl.TypeAny || param.type == TypeDecl.TypeNullableAny) &&
index < expectedParams.size index < expectedParams.size
) { ) {
expectedParams[index] expectedParams[index]
@ -3510,15 +3515,20 @@ class Compiler(
} else { } else {
param.type param.type
} }
val effectiveType = if (param.isEllipsis) TypeDecl.Ellipsis(rawType) else rawType
lambdaParamTypeDecls += effectiveType
if (effectiveType != TypeDecl.TypeAny && effectiveType != TypeDecl.TypeNullableAny) { if (effectiveType != TypeDecl.TypeAny && effectiveType != TypeDecl.TypeNullableAny) {
seedLambdaParamType(param.name, effectiveType) seedLambdaParamType(param.name, rawType)
} }
} }
} else { } else {
val effectiveImplicitItType = implicitItType val effectiveImplicitItType = implicitItType
?: expectedCallableType?.params?.singleOrNull() ?: expectedCallableType?.params?.singleOrNull()
if (effectiveImplicitItType != null) { if (effectiveImplicitItType != null) {
lambdaParamTypeDecls += effectiveImplicitItType
seedLambdaParamType("it", effectiveImplicitItType) seedLambdaParamType("it", effectiveImplicitItType)
} else {
lambdaParamTypeDecls += TypeDecl.Ellipsis(TypeDecl.TypeAny)
} }
} }
@ -3580,6 +3590,7 @@ class Compiler(
} else { } else {
emptyList() emptyList()
} }
val inferredReturnDecl = inferReturnTypeDeclFromStatement(body)
val returnClass = inferReturnClassFromStatement(body) val returnClass = inferReturnClassFromStatement(body)
val paramKnownClasses = mutableMapOf<String, ObjClass>() val paramKnownClasses = mutableMapOf<String, ObjClass>()
argsDeclaration?.params?.forEach { param -> argsDeclaration?.params?.forEach { param ->
@ -3784,6 +3795,13 @@ class Compiler(
returnLabels = returnLabels, returnLabels = returnLabels,
pos = startPos pos = startPos
) )
val lambdaTypeDecl = TypeDecl.Function(
receiver = null,
params = lambdaParamTypeDecls.toList(),
returnType = inferredReturnDecl ?: returnClass?.let { TypeDecl.Simple(it.className, false) } ?: TypeDecl.TypeAny,
nullable = false
)
lambdaTypeDeclByRef[ref] = lambdaTypeDecl
if (returnClass != null) { if (returnClass != null) {
lambdaReturnTypeByRef[ref] = returnClass lambdaReturnTypeByRef[ref] = returnClass
} }
@ -4747,6 +4765,7 @@ class Compiler(
private fun inferTypeDeclFromRef(ref: ObjRef): TypeDecl? { private fun inferTypeDeclFromRef(ref: ObjRef): TypeDecl? {
resolveReceiverTypeDecl(ref)?.let { return it } resolveReceiverTypeDecl(ref)?.let { return it }
return when (ref) { return when (ref) {
is ValueFnRef -> lambdaTypeDeclByRef[ref]
is ListLiteralRef -> inferListLiteralTypeDecl(ref) is ListLiteralRef -> inferListLiteralTypeDecl(ref)
is MapLiteralRef -> inferMapLiteralTypeDecl(ref) is MapLiteralRef -> inferMapLiteralTypeDecl(ref)
is ConstRef -> inferTypeDeclFromConst(ref.constValue) is ConstRef -> inferTypeDeclFromConst(ref.constValue)
@ -5057,6 +5076,37 @@ class Compiler(
return null return null
} }
private fun substituteReceiverTypeParams(receiverType: TypeDecl?, ownerClassName: String?, memberType: TypeDecl?): TypeDecl? {
if (receiverType !is TypeDecl.Generic || ownerClassName == null || memberType == null) return memberType
val info = resolveCompileClassInfo(ownerClassName) ?: return memberType
if (info.typeParams.isEmpty()) return memberType
val bindings = LinkedHashMap<String, TypeDecl>(info.typeParams.size)
for ((index, typeParamName) in info.typeParams.withIndex()) {
val argType = receiverType.args.getOrNull(index) ?: continue
bindings[typeParamName] = argType
}
if (bindings.isEmpty()) return memberType
return substituteTypeAliasTypeVars(memberType, bindings)
}
private fun inferExtensionPropertyTypeDecl(receiverDecl: TypeDecl?, receiverClass: ObjClass?, memberName: String): TypeDecl? {
if (receiverClass == null) return null
for (cls in receiverClass.mro) {
val wrapperName = extensionPropertyGetterName(cls.className, memberName)
val resolved = resolveImportBinding(wrapperName, Pos.builtIn) ?: continue
registerImportBinding(wrapperName, resolved.binding, Pos.builtIn)
val wrapperType = resolved.record.typeDecl as? TypeDecl.Function ?: continue
val bindings = mutableMapOf<String, TypeDecl>()
val receiverParam = wrapperType.params.firstOrNull() ?: wrapperType.receiver
if (receiverParam != null && receiverDecl != null) {
collectTypeVarBindings(receiverParam, receiverDecl, bindings)
}
return if (bindings.isEmpty()) wrapperType.returnType
else substituteTypeAliasTypeVars(wrapperType.returnType, bindings)
}
return null
}
private fun classMemberTypeDecl(targetClass: ObjClass?, name: String): TypeDecl? { private fun classMemberTypeDecl(targetClass: ObjClass?, name: String): TypeDecl? {
if (targetClass == null) return null if (targetClass == null) return null
if (targetClass == ObjDynamic.type) return TypeDecl.TypeAny if (targetClass == ObjDynamic.type) return TypeDecl.TypeAny
@ -5192,7 +5242,12 @@ class Compiler(
is FieldRef -> { is FieldRef -> {
val targetDecl = resolveReceiverTypeDecl(ref.target) ?: return null val targetDecl = resolveReceiverTypeDecl(ref.target) ?: return null
val targetClass = resolveTypeDeclObjClass(targetDecl) ?: resolveReceiverClassForMember(ref.target) val targetClass = resolveTypeDeclObjClass(targetDecl) ?: resolveReceiverClassForMember(ref.target)
classMemberTypeDecl(targetClass, ref.name)?.let { return it } classMemberTypeDecl(targetClass, ref.name)?.let { declared ->
val ownerClassName = targetClass?.getInstanceMemberOrNull(ref.name, includeAbstract = true)
?.declaringClass?.className ?: targetClass?.className
return substituteReceiverTypeParams(targetDecl, ownerClassName, declared)
}
inferExtensionPropertyTypeDecl(targetDecl, targetClass, ref.name)?.let { return it }
classFieldTypesByName[targetClass?.className]?.get(ref.name) classFieldTypesByName[targetClass?.className]?.get(ref.name)
?.let { return TypeDecl.Simple(it.className, false) } ?.let { return TypeDecl.Simple(it.className, false) }
when (targetDecl) { when (targetDecl) {
@ -5529,8 +5584,22 @@ class Compiler(
private fun inferMethodCallReturnTypeDecl(ref: MethodCallRef): TypeDecl? { private fun inferMethodCallReturnTypeDecl(ref: MethodCallRef): TypeDecl? {
methodReturnTypeDeclByRef[ref]?.let { return it } methodReturnTypeDeclByRef[ref]?.let { return it }
val inferred = inferMethodCallReturnTypeDecl(ref.name, resolveReceiverTypeDecl(ref.receiver), ref.args) val receiverDecl = resolveReceiverTypeDecl(ref.receiver)
?: classMethodReturnTypeDecl(resolveReceiverClassForMember(ref.receiver), ref.name) val inferred = inferMethodCallReturnTypeDecl(ref.name, receiverDecl, ref.args)
?: inferDeclaredMethodCallReturnTypeDecl(
ref.name,
receiverDecl,
resolveReceiverClassForMember(ref.receiver),
ref.args,
ref.explicitTypeArgs
)
?: run {
val receiverClass = resolveReceiverClassForMember(ref.receiver)
val declared = classMethodReturnTypeDecl(receiverClass, ref.name)
val ownerClassName = receiverClass?.getInstanceMemberOrNull(ref.name, includeAbstract = true)
?.declaringClass?.className ?: receiverClass?.className
substituteReceiverTypeParams(receiverDecl, ownerClassName, declared)
}
if (inferred != null) { if (inferred != null) {
methodReturnTypeDeclByRef[ref] = inferred methodReturnTypeDeclByRef[ref] = inferred
} }
@ -5635,6 +5704,108 @@ class Compiler(
} }
} }
private fun inferDeclaredMethodCallReturnTypeDecl(
name: String,
receiverDecl: TypeDecl?,
receiverClass: ObjClass?,
args: List<ParsedArgument>,
explicitTypeArgs: List<TypeDecl>? = null
): TypeDecl? {
if (receiverClass == null) return null
val ownerClassName = receiverClass.getInstanceMemberOrNull(name, includeAbstract = true)
?.declaringClass?.className ?: receiverClass.className
val memberType = substituteReceiverTypeParams(
receiverDecl,
ownerClassName,
classMemberTypeDecl(receiverClass, name)
) as? TypeDecl.Function ?: return null
fun argTypeDecl(arg: ParsedArgument): TypeDecl? {
val stmt = arg.value as? ExpressionStatement ?: return null
val directRef = stmt.ref
return inferTypeDeclFromRef(directRef)
?: inferObjClassFromRef(directRef)?.let { TypeDecl.Simple(it.className, false) }
}
val bindings = mutableMapOf<String, TypeDecl>()
collectExplicitMethodTypeBindings(memberType, explicitTypeArgs, bindings)
memberType.receiver?.let { declaredReceiver ->
receiverDecl?.let { collectTypeVarBindings(declaredReceiver, it, bindings) }
}
val paramList = memberType.params
val ellipsisIndex = paramList.indexOfFirst { it is TypeDecl.Ellipsis }
if (ellipsisIndex < 0) {
val limit = minOf(paramList.size, args.size)
for (i in 0 until limit) {
val argType = argTypeDecl(args[i]) ?: continue
collectTypeVarBindings(paramList[i], argType, bindings)
}
} else {
val headCount = ellipsisIndex
val tailCount = paramList.size - ellipsisIndex - 1
val argCount = args.size
val headLimit = minOf(headCount, argCount)
for (i in 0 until headLimit) {
val argType = argTypeDecl(args[i]) ?: continue
collectTypeVarBindings(paramList[i], argType, bindings)
}
val tailStartArg = maxOf(headCount, argCount - tailCount)
for (i in tailStartArg until argCount) {
val paramIndex = paramList.size - (argCount - i)
val argType = argTypeDecl(args[i]) ?: continue
collectTypeVarBindings(paramList[paramIndex], argType, bindings)
}
val ellipsisArgEnd = argCount - tailCount
val ellipsisType = paramList[ellipsisIndex] as TypeDecl.Ellipsis
for (i in headCount until ellipsisArgEnd) {
val argType = if (args[i].isSplat) {
val stmt = args[i].value as? ExpressionStatement
stmt?.ref?.let { inferElementTypeFromSpread(it) }
} else {
argTypeDecl(args[i])
} ?: continue
collectTypeVarBindings(ellipsisType.elementType, argType, bindings)
}
}
return if (bindings.isEmpty()) memberType.returnType
else substituteTypeAliasTypeVars(memberType.returnType, bindings)
}
private fun collectExplicitMethodTypeBindings(
memberType: TypeDecl.Function,
explicitTypeArgs: List<TypeDecl>?,
out: MutableMap<String, TypeDecl>
) {
if (explicitTypeArgs.isNullOrEmpty()) return
val typeVars = LinkedHashSet<String>()
memberType.receiver?.let { collectTypeVarNamesInOrder(it, typeVars) }
memberType.params.forEach { collectTypeVarNamesInOrder(it, typeVars) }
collectTypeVarNamesInOrder(memberType.returnType, typeVars)
val names = typeVars.toList()
val limit = minOf(names.size, explicitTypeArgs.size)
for (i in 0 until limit) {
out[names[i]] = explicitTypeArgs[i]
}
}
private fun collectTypeVarNamesInOrder(type: TypeDecl, out: MutableSet<String>) {
when (type) {
is TypeDecl.TypeVar -> out += type.name
is TypeDecl.Generic -> type.args.forEach { collectTypeVarNamesInOrder(it, out) }
is TypeDecl.Function -> {
type.receiver?.let { collectTypeVarNamesInOrder(it, out) }
type.params.forEach { collectTypeVarNamesInOrder(it, out) }
collectTypeVarNamesInOrder(type.returnType, out)
}
is TypeDecl.Ellipsis -> collectTypeVarNamesInOrder(type.elementType, out)
is TypeDecl.Union -> type.options.forEach { collectTypeVarNamesInOrder(it, out) }
is TypeDecl.Intersection -> type.options.forEach { collectTypeVarNamesInOrder(it, out) }
else -> {}
}
}
private fun inferCallableReturnTypeDeclFromArgument(arg: ParsedArgument): TypeDecl? { private fun inferCallableReturnTypeDeclFromArgument(arg: ParsedArgument): TypeDecl? {
val stmt = arg.value as? ExpressionStatement ?: return null val stmt = arg.value as? ExpressionStatement ?: return null
val ref = stmt.ref val ref = stmt.ref
@ -6104,6 +6275,7 @@ class Compiler(
args: List<ParsedArgument>, args: List<ParsedArgument>,
pos: Pos pos: Pos
) { ) {
if (shouldSkipStaticCallableChecks(target)) return
lookupNamedFunctionDecl(target)?.let { decl -> lookupNamedFunctionDecl(target)?.let { decl ->
val hasComplexArgs = args.any { it.name != null } || val hasComplexArgs = args.any { it.name != null } ||
decl.typeParams.isNotEmpty() || decl.typeParams.isNotEmpty() ||
@ -6201,6 +6373,22 @@ class Compiler(
return null return null
} }
private fun shouldSkipStaticCallableChecks(target: ObjRef): Boolean {
resolveExactLambdaRef(target)?.let { lambda ->
if (lambda.argsDeclaration == null) return true
}
val name = when (target) {
is LocalVarRef -> target.name
is LocalSlotRef -> target.name
is FastLocalVarRef -> target.name
else -> null
}
if (name != null && callSignatureForName(name) != null) {
return true
}
return lookupNamedCallableRecord(target)?.callSignature != null
}
private fun inferCallReturnTypeDecl(ref: CallRef): TypeDecl? { private fun inferCallReturnTypeDecl(ref: CallRef): TypeDecl? {
callReturnTypeDeclByRef[ref]?.let { return it } callReturnTypeDeclByRef[ref]?.let { return it }
val targetDecl = (resolveReceiverTypeDecl(ref.target) ?: seedTypeDeclFromRef(ref.target)) as? TypeDecl.Function val targetDecl = (resolveReceiverTypeDecl(ref.target) ?: seedTypeDeclFromRef(ref.target)) as? TypeDecl.Function
@ -6263,6 +6451,7 @@ class Compiler(
args: List<ParsedArgument>, args: List<ParsedArgument>,
pos: Pos pos: Pos
) { ) {
if (shouldSkipStaticCallableChecks(target)) return
lookupNamedFunctionDecl(target)?.let { decl -> lookupNamedFunctionDecl(target)?.let { decl ->
val hasComplexArgs = args.any { it.name != null } || val hasComplexArgs = args.any { it.name != null } ||
decl.typeParams.isNotEmpty() || decl.typeParams.isNotEmpty() ||
@ -6356,7 +6545,7 @@ class Compiler(
return return
} }
val seededCallable = lookupNamedCallableRecord(target) val seededCallable = lookupNamedCallableRecord(target)
if (seededCallable != null && seededCallable.type == ObjRecord.Type.Fun && seededCallable.value !is ObjExternCallable) { if (seededCallable != null && seededCallable.type == ObjRecord.Type.Fun) {
return return
} }
val decl = (resolveReceiverTypeDecl(target) as? TypeDecl.Function) val decl = (resolveReceiverTypeDecl(target) as? TypeDecl.Function)
@ -6470,6 +6659,17 @@ class Compiler(
} }
} }
} }
is TypeDecl.Function -> {
if (argType is TypeDecl.Function && paramType.params.size == argType.params.size) {
if (paramType.receiver != null && argType.receiver != null) {
collectTypeVarBindings(paramType.receiver, argType.receiver, out)
}
for (i in paramType.params.indices) {
collectTypeVarBindings(paramType.params[i], argType.params[i], out)
}
collectTypeVarBindings(paramType.returnType, argType.returnType, out)
}
}
is TypeDecl.Union -> { is TypeDecl.Union -> {
if (argType is TypeDecl.Union) { if (argType is TypeDecl.Union) {
val limit = minOf(paramType.options.size, argType.options.size) val limit = minOf(paramType.options.size, argType.options.size)
@ -7820,6 +8020,7 @@ class Compiler(
compileClassInfos[qualifiedName] = CompileClassInfo( compileClassInfos[qualifiedName] = CompileClassInfo(
name = qualifiedName, name = qualifiedName,
packageName = packageName, packageName = packageName,
typeParams = emptyList(),
fieldIds = fieldIds, fieldIds = fieldIds,
methodIds = methodIds, methodIds = methodIds,
nextFieldId = fieldIds.size, nextFieldId = fieldIds.size,
@ -7943,6 +8144,7 @@ class Compiler(
compileClassInfos[className] = CompileClassInfo( compileClassInfos[className] = CompileClassInfo(
name = className, name = className,
packageName = packageName, packageName = packageName,
typeParams = emptyList(),
fieldIds = ctx.memberFieldIds.toMap(), fieldIds = ctx.memberFieldIds.toMap(),
methodIds = ctx.memberMethodIds.toMap(), methodIds = ctx.memberMethodIds.toMap(),
nextFieldId = ctx.nextFieldId, nextFieldId = ctx.nextFieldId,
@ -7985,6 +8187,7 @@ class Compiler(
compileClassInfos[className] = CompileClassInfo( compileClassInfos[className] = CompileClassInfo(
name = className, name = className,
packageName = packageName, packageName = packageName,
typeParams = emptyList(),
fieldIds = baseIds.fieldIds, fieldIds = baseIds.fieldIds,
methodIds = baseIds.methodIds, methodIds = baseIds.methodIds,
nextFieldId = baseIds.nextFieldId, nextFieldId = baseIds.nextFieldId,
@ -8250,6 +8453,7 @@ class Compiler(
compileClassInfos[qualifiedName] = CompileClassInfo( compileClassInfos[qualifiedName] = CompileClassInfo(
name = qualifiedName, name = qualifiedName,
packageName = packageName, packageName = packageName,
typeParams = typeParamDecls.map { it.name },
fieldIds = ctx.memberFieldIds.toMap(), fieldIds = ctx.memberFieldIds.toMap(),
methodIds = ctx.memberMethodIds.toMap(), methodIds = ctx.memberMethodIds.toMap(),
nextFieldId = ctx.nextFieldId, nextFieldId = ctx.nextFieldId,
@ -8267,6 +8471,7 @@ class Compiler(
compileClassInfos[qualifiedName] = CompileClassInfo( compileClassInfos[qualifiedName] = CompileClassInfo(
name = qualifiedName, name = qualifiedName,
packageName = packageName, packageName = packageName,
typeParams = typeParamDecls.map { it.name },
fieldIds = ctx.memberFieldIds.toMap(), fieldIds = ctx.memberFieldIds.toMap(),
methodIds = ctx.memberMethodIds.toMap(), methodIds = ctx.memberMethodIds.toMap(),
nextFieldId = ctx.nextFieldId, nextFieldId = ctx.nextFieldId,
@ -8355,6 +8560,7 @@ class Compiler(
compileClassInfos[qualifiedName] = CompileClassInfo( compileClassInfos[qualifiedName] = CompileClassInfo(
name = qualifiedName, name = qualifiedName,
packageName = packageName, packageName = packageName,
typeParams = typeParamDecls.map { it.name },
fieldIds = ctx.memberFieldIds.toMap(), fieldIds = ctx.memberFieldIds.toMap(),
methodIds = ctx.memberMethodIds.toMap(), methodIds = ctx.memberMethodIds.toMap(),
nextFieldId = ctx.nextFieldId, nextFieldId = ctx.nextFieldId,
@ -9539,7 +9745,8 @@ class Compiler(
resolutionSink?.exitScope(cc.currentPos()) resolutionSink?.exitScope(cc.currentPos())
} }
val rawFnStatements = parsedFnStatements?.let { unwrapBytecodeDeep(it) } val rawFnStatements = parsedFnStatements?.let { unwrapBytecodeDeep(it) }
val inferredReturnClass = returnTypeDecl?.let { resolveTypeDeclObjClass(it) } val inferredReturnDecl = returnTypeDecl ?: inferReturnTypeDeclFromStatement(rawFnStatements)
val inferredReturnClass = inferredReturnDecl?.let { resolveTypeDeclObjClass(it) }
?: inferReturnClassFromStatement(rawFnStatements) ?: inferReturnClassFromStatement(rawFnStatements)
if (parentContext is CodeContext.ClassBody && !isStatic && extTypeName == null) { if (parentContext is CodeContext.ClassBody && !isStatic && extTypeName == null) {
val ownerClassName = parentContext.name val ownerClassName = parentContext.name
@ -9547,15 +9754,12 @@ class Compiler(
val memberTypeDecl = TypeDecl.Function( val memberTypeDecl = TypeDecl.Function(
receiver = receiverTypeDecl, receiver = receiverTypeDecl,
params = argsDeclaration.params.map { it.type }, params = argsDeclaration.params.map { it.type },
returnType = returnTypeDecl returnType = inferredReturnDecl ?: TypeDecl.TypeAny,
?: inferredReturnClass?.let { TypeDecl.Simple(it.className, false) }
?: TypeDecl.TypeAny,
nullable = false nullable = false
) )
classMemberTypeDeclByName classMemberTypeDeclByName
.getOrPut(ownerClassName) { mutableMapOf() }[name] = memberTypeDecl .getOrPut(ownerClassName) { mutableMapOf() }[name] = memberTypeDecl
val returnDecl = returnTypeDecl val returnDecl = inferredReturnDecl
?: inferredReturnClass?.let { TypeDecl.Simple(it.className, false) }
if (returnDecl != null) { if (returnDecl != null) {
classMethodReturnTypeDeclByName classMethodReturnTypeDeclByName
.getOrPut(ownerClassName) { mutableMapOf() }[name] = returnDecl .getOrPut(ownerClassName) { mutableMapOf() }[name] = returnDecl
@ -9576,7 +9780,6 @@ class Compiler(
inferredReturnClass inferredReturnClass
} }
} }
val inferredReturnDecl = returnTypeDecl ?: inferredReturnClass?.let { TypeDecl.Simple(it.className, false) }
if (declKind != SymbolKind.MEMBER && inferredReturnDecl != null) { if (declKind != SymbolKind.MEMBER && inferredReturnDecl != null) {
callableReturnTypeDeclByName[name] = inferredReturnDecl callableReturnTypeDeclByName[name] = inferredReturnDecl
} }
@ -9881,6 +10084,40 @@ class Compiler(
} }
} }
private fun inferReturnTypeDeclFromStatement(stmt: Statement?): TypeDecl? {
if (stmt == null) return null
val unwrapped = unwrapBytecodeDeep(stmt)
return when (unwrapped) {
is ExpressionStatement -> inferTypeDeclFromInitializer(unwrapped)
is ReturnStatement -> unwrapped.resultExpr?.let { inferTypeDeclFromInitializer(it) }
is VarDeclStatement -> unwrapped.typeDecl ?: unwrapped.initializer?.let { inferTypeDeclFromInitializer(it) }
is BlockStatement -> {
val stmts = unwrapped.statements()
val returnTypes = stmts.mapNotNull { s ->
(s as? ReturnStatement)?.resultExpr?.let { inferTypeDeclFromInitializer(it) }
}
if (returnTypes.isNotEmpty()) {
val first = returnTypes.first()
if (returnTypes.all { typeDeclKey(it) == typeDeclKey(first) }) first else null
} else {
inferReturnTypeDeclFromStatement(stmts.lastOrNull())
}
}
is InlineBlockStatement -> inferReturnTypeDeclFromStatement(unwrapped.statements().lastOrNull())
is IfStatement -> {
val ifType = inferReturnTypeDeclFromStatement(unwrapped.ifBody)
val elseType = unwrapped.elseBody?.let { inferReturnTypeDeclFromStatement(it) }
when {
ifType == null -> elseType
elseType == null -> ifType
typeDeclKey(ifType) == typeDeclKey(elseType) -> ifType
else -> null
}
}
else -> null
}
}
private fun unwrapDirectRef(initializer: Statement?): ObjRef? { private fun unwrapDirectRef(initializer: Statement?): ObjRef? {
var initStmt = initializer var initStmt = initializer
while (initStmt is BytecodeStatement) { while (initStmt is BytecodeStatement) {
@ -9961,8 +10198,11 @@ class Compiler(
inferMethodCallReturnClass(directRef) inferMethodCallReturnClass(directRef)
} }
is FieldRef -> { is FieldRef -> {
val targetClass = resolveReceiverClassForMember(directRef.target) resolveReceiverTypeDecl(directRef)?.let { resolveTypeDeclObjClass(it) }
inferFieldReturnClass(targetClass, directRef.name) ?: run {
val targetClass = resolveReceiverClassForMember(directRef.target)
inferFieldReturnClass(targetClass, directRef.name)
}
} }
is ImplicitThisMemberRef -> resolveReceiverClassForMember(directRef) is ImplicitThisMemberRef -> resolveReceiverClassForMember(directRef)
is CallRef -> { is CallRef -> {
@ -10367,6 +10607,7 @@ class Compiler(
nameStartPos = nameToken.pos nameStartPos = nameToken.pos
} }
val receiverNormalization = normalizeReceiverTypeDecl(receiverTypeDecl, emptySet()) val receiverNormalization = normalizeReceiverTypeDecl(receiverTypeDecl, emptySet())
receiverTypeDecl = receiverNormalization.first
val implicitTypeParams = receiverNormalization.second val implicitTypeParams = receiverNormalization.second
if (implicitTypeParams.isNotEmpty()) pendingTypeParamStack.add(implicitTypeParams) if (implicitTypeParams.isNotEmpty()) pendingTypeParamStack.add(implicitTypeParams)
try { try {
@ -10599,6 +10840,7 @@ class Compiler(
val directRef = unwrapDirectRef(initialExpression) val directRef = unwrapDirectRef(initialExpression)
val declClass = resolveTypeDeclObjClass(varTypeDecl) val declClass = resolveTypeDeclObjClass(varTypeDecl)
val initFromExpr = resolveInitializerObjClass(initialExpression) val initFromExpr = resolveInitializerObjClass(initialExpression)
val inferredInitTypeDecl = directRef?.let { inferTypeDeclFromRef(it) }
val isNullLiteral = (directRef as? ConstRef)?.constValue == ObjNull val isNullLiteral = (directRef as? ConstRef)?.constValue == ObjNull
val initObjClass = if (declClass != null && isNullLiteral) declClass else initFromExpr ?: declClass val initObjClass = if (declClass != null && isNullLiteral) declClass else initFromExpr ?: declClass
if (varTypeDecl !is TypeDecl.TypeAny && varTypeDecl !is TypeDecl.TypeNullableAny) { if (varTypeDecl !is TypeDecl.TypeAny && varTypeDecl !is TypeDecl.TypeNullableAny) {
@ -10606,16 +10848,29 @@ class Compiler(
slotTypeDeclByScopeId.getOrPut(scopeId) { mutableMapOf() }[slotIndex] = varTypeDecl slotTypeDeclByScopeId.getOrPut(scopeId) { mutableMapOf() }[slotIndex] = varTypeDecl
} }
nameTypeDecl[name] = varTypeDecl nameTypeDecl[name] = varTypeDecl
} } else {
if (directRef is ValueFnRef) { val inferredFunctionType = inferredInitTypeDecl as? TypeDecl.Function
val returnClass = lambdaReturnTypeByRef[directRef] if (inferredFunctionType != null) {
if (returnClass != null) {
if (slotIndex != null && scopeId != null) { if (slotIndex != null && scopeId != null) {
callableReturnTypeByScopeId.getOrPut(scopeId) { mutableMapOf() }[slotIndex] = returnClass slotTypeDeclByScopeId.getOrPut(scopeId) { mutableMapOf() }[slotIndex] = inferredFunctionType
} }
callableReturnTypeByName[name] = returnClass nameTypeDecl[name] = inferredFunctionType
} }
} }
val inferredCallableReturnClass = when {
directRef is ValueFnRef -> lambdaReturnTypeByRef[directRef]
directRef != null -> (inferredInitTypeDecl as? TypeDecl.Function)
?.returnType
?.let { resolveTypeDeclObjClass(it) }
else -> null
}
if (inferredCallableReturnClass != null) {
if (slotIndex != null && scopeId != null) {
callableReturnTypeByScopeId.getOrPut(scopeId) { mutableMapOf() }[slotIndex] =
inferredCallableReturnClass
}
callableReturnTypeByName[name] = inferredCallableReturnClass
}
if (directRef is MethodCallRef && directRef.name == "encode") { if (directRef is MethodCallRef && directRef.name == "encode") {
val payloadClass = inferEncodedPayloadClass(directRef.args) val payloadClass = inferEncodedPayloadClass(directRef.args)
if (payloadClass != null) { if (payloadClass != null) {
@ -10879,6 +11134,24 @@ class Compiler(
if (declarationAnnotationSpecs.isNotEmpty()) { if (declarationAnnotationSpecs.isNotEmpty()) {
throw ScriptError(start, "declaration annotations are not supported on extension properties") throw ScriptError(start, "declaration annotations are not supported on extension properties")
} }
val getterTypeDecl = receiverTypeDecl?.let { recv ->
TypeDecl.Function(
receiver = null,
params = listOf(recv),
returnType = varTypeDecl,
nullable = false
)
}
val setterTypeDecl = if (setter != null && receiverTypeDecl != null) {
TypeDecl.Function(
receiver = null,
params = listOf(receiverTypeDecl, varTypeDecl),
returnType = TypeDecl.Simple("void", false),
nullable = false
)
} else {
null
}
declareLocalName(extensionPropertyGetterName(extTypeName, name), isMutable = false) declareLocalName(extensionPropertyGetterName(extTypeName, name), isMutable = false)
if (setter != null) { if (setter != null) {
declareLocalName(extensionPropertySetterName(extTypeName, name), isMutable = false) declareLocalName(extensionPropertySetterName(extTypeName, name), isMutable = false)
@ -10895,6 +11168,8 @@ class Compiler(
property = prop, property = prop,
visibility = visibility, visibility = visibility,
setterVisibility = setterVisibility, setterVisibility = setterVisibility,
getterTypeDecl = getterTypeDecl,
setterTypeDecl = setterTypeDecl,
startPos = start startPos = start
) )
} }

View File

@ -24,6 +24,8 @@ class ExtensionPropertyDeclStatement(
val property: ObjProperty, val property: ObjProperty,
val visibility: Visibility, val visibility: Visibility,
val setterVisibility: Visibility?, val setterVisibility: Visibility?,
val getterTypeDecl: TypeDecl?,
val setterTypeDecl: TypeDecl?,
private val startPos: Pos, private val startPos: Pos,
) : Statement() { ) : Statement() {
override val pos: Pos = startPos override val pos: Pos = startPos

View File

@ -8003,7 +8003,9 @@ class BytecodeCompiler(
stmt.extTypeName, stmt.extTypeName,
stmt.property, stmt.property,
stmt.visibility, stmt.visibility,
stmt.setterVisibility stmt.setterVisibility,
stmt.getterTypeDecl,
stmt.setterTypeDecl
) )
) )
val slot = allocSlot() val slot = allocSlot()
@ -8644,29 +8646,73 @@ class BytecodeCompiler(
} }
private fun inferCallReturnClass(ref: CallRef): ObjClass? { private fun inferCallReturnClass(ref: CallRef): ObjClass? {
fun exactLambdaReturnClass(slot: Int): ObjClass? =
exactLambdaRefBySlot[slot]?.inferredReturnClass
fun callableReturnClassFromSlot(slot: Int): ObjClass? {
exactLambdaReturnClass(slot)?.let { return it }
typeDeclForSlot(slot)?.let { decl ->
val functionDecl = decl as? TypeDecl.Function
if (functionDecl != null) {
resolveClassFromTypeDecl(functionDecl.returnType)?.let { return it }
}
}
return null
}
fun callableResultClassOrNull(
directReturnClass: ObjClass?,
directTypeDecl: TypeDecl?,
nameClass: ObjClass?,
typeNameFallback: String?
): ObjClass? {
if (directReturnClass != null) return directReturnClass
if (directTypeDecl is TypeDecl.Function) {
return null
}
if (nameClass == ObjClassType) {
return typeNameFallback?.let { resolveTypeNameClass(it) } ?: ObjDynamic.type
}
if (nameClass == Statement.type) {
return null
}
return nameClass ?: typeNameFallback?.let { resolveTypeNameClass(it) }
}
return when (val target = ref.target) { return when (val target = ref.target) {
is LocalSlotRef -> { is LocalSlotRef -> {
callableReturnTypeByScopeId[target.scopeId]?.get(target.slot) val mappedSlot = resolveLocalSlotByRefOrName(target)
?: run { callableResultClassOrNull(
val nameClass = nameObjClass[target.name] directReturnClass = mappedSlot?.let { callableReturnClassFromSlot(it) }
if (nameClass == ObjClassType) { ?: exactLambdaRefByScopeId[target.scopeId]?.get(target.slot)?.inferredReturnClass
resolveTypeNameClass(target.name) ?: ObjDynamic.type ?: callableReturnTypeByScopeId[target.scopeId]?.get(target.slot),
} else { directTypeDecl = mappedSlot?.let { typeDeclForSlot(it) }
nameClass ?: resolveTypeNameClass(target.name) ?: slotTypeDeclByScopeId[target.scopeId]?.get(target.slot),
} nameClass = nameObjClass[target.name],
} typeNameFallback = target.name
)
} }
is LocalVarRef -> { is LocalVarRef -> {
callableReturnTypeByName[target.name] val directSlot = resolveDirectNameSlot(target.name)?.slot
?: run { callableResultClassOrNull(
val nameClass = nameObjClass[target.name] directReturnClass = directSlot?.let { callableReturnClassFromSlot(it) }
if (nameClass == ObjClassType) { ?: callableReturnTypeByName[target.name],
resolveTypeNameClass(target.name) ?: ObjDynamic.type directTypeDecl = directSlot?.let { typeDeclForSlot(it) },
} else { nameClass = nameObjClass[target.name],
nameClass ?: resolveTypeNameClass(target.name) typeNameFallback = target.name
} )
}
} }
is FastLocalVarRef -> {
val directSlot = resolveDirectNameSlot(target.name)?.slot
callableResultClassOrNull(
directReturnClass = directSlot?.let { callableReturnClassFromSlot(it) }
?: callableReturnTypeByName[target.name],
directTypeDecl = directSlot?.let { typeDeclForSlot(it) },
nameClass = nameObjClass[target.name],
typeNameFallback = target.name
)
}
is BoundLocalVarRef -> callableReturnClassFromSlot(target.slotIndex())
is ConstRef -> target.constValue as? ObjClass is ConstRef -> target.constValue as? ObjClass
else -> null else -> null
} }

View File

@ -66,6 +66,8 @@ sealed class BytecodeConst {
val property: ObjProperty, val property: ObjProperty,
val visibility: Visibility, val visibility: Visibility,
val setterVisibility: Visibility?, val setterVisibility: Visibility?,
val getterTypeDecl: TypeDecl?,
val setterTypeDecl: TypeDecl?,
) : BytecodeConst() ) : BytecodeConst()
data class LocalDecl( data class LocalDecl(
val name: String, val name: String,

View File

@ -3268,7 +3268,14 @@ class CmdDeclExtProperty(internal val constId: Int, internal val slot: Int) : Cm
) )
val getterName = extensionPropertyGetterName(decl.extTypeName, decl.property.name) val getterName = extensionPropertyGetterName(decl.extTypeName, decl.property.name)
val getterWrapper = ObjExtensionPropertyGetterCallable(decl.property.name, decl.property) val getterWrapper = ObjExtensionPropertyGetterCallable(decl.property.name, decl.property)
frame.ensureScope().addItem(getterName, false, getterWrapper, decl.visibility, recordType = ObjRecord.Type.Fun) frame.ensureScope().addItem(
getterName,
false,
getterWrapper,
decl.visibility,
recordType = ObjRecord.Type.Fun,
typeDecl = decl.getterTypeDecl
)
val getterLocal = resolveLocalSlotIndex(frame.fn, getterName, preferCapture = false) val getterLocal = resolveLocalSlotIndex(frame.fn, getterName, preferCapture = false)
if (getterLocal != null) { if (getterLocal != null) {
frame.setObjUnchecked(frame.fn.scopeSlotCount + getterLocal, getterWrapper) frame.setObjUnchecked(frame.fn.scopeSlotCount + getterLocal, getterWrapper)
@ -3277,7 +3284,14 @@ class CmdDeclExtProperty(internal val constId: Int, internal val slot: Int) : Cm
val setterName = extensionPropertySetterName(decl.extTypeName, decl.property.name) val setterName = extensionPropertySetterName(decl.extTypeName, decl.property.name)
val setterWrapper = ObjExtensionPropertySetterCallable(decl.property.name, decl.property) val setterWrapper = ObjExtensionPropertySetterCallable(decl.property.name, decl.property)
frame.ensureScope() frame.ensureScope()
.addItem(setterName, false, setterWrapper, decl.visibility, recordType = ObjRecord.Type.Fun) .addItem(
setterName,
false,
setterWrapper,
decl.visibility,
recordType = ObjRecord.Type.Fun,
typeDecl = decl.setterTypeDecl
)
val setterLocal = resolveLocalSlotIndex(frame.fn, setterName, preferCapture = false) val setterLocal = resolveLocalSlotIndex(frame.fn, setterName, preferCapture = false)
if (setterLocal != null) { if (setterLocal != null) {
frame.setObjUnchecked(frame.fn.scopeSlotCount + setterLocal, setterWrapper) frame.setObjUnchecked(frame.fn.scopeSlotCount + setterLocal, setterWrapper)

View File

@ -104,4 +104,39 @@ class TypeInferenceTest {
Pool(2).closeAll() Pool(2).closeAll()
""".trimIndent()) """.trimIndent())
} }
@Test
fun testIterableFirstPreservesElementTypeForBlockReturnInference() = runBlocking<Unit> {
eval("""
class Item(title: String)
fun restored() {
val values = [Item("ok")]
values.first
}
val item = restored()
assertEquals("ok", item.title)
""".trimIndent())
}
@Test
fun testCallableLocalInitializedFromFunctionCallPreservesReturnType() = runBlocking<Unit> {
eval("""
fun makeAdder(base) {
return { x -> x + base + 0.5 }
}
fun run() {
val add = makeAdder(2)
val value = add(3) + 4
assert(value is Real)
value
}
val result = run()
assert(result is Real)
assertEquals(9.5, result)
""".trimIndent())
}
} }