diff --git a/examples/sqlite_serialization.lyng b/examples/sqlite_serialization.lyng index de9e47c..98bdf99 100644 --- a/examples/sqlite_serialization.lyng +++ b/examples/sqlite_serialization.lyng @@ -45,7 +45,7 @@ val restored = openSqlite(":memory:").transaction { tx -> assertEquals(21, restored.state.count) assertEquals("updated", restored.note) restored -} as Item +} println("Restored item:") println(" id=" + restored.id) diff --git a/lyng/src/commonMain/kotlin/Common.kt b/lyng/src/commonMain/kotlin/Common.kt index 7984098..f38f28d 100644 --- a/lyng/src/commonMain/kotlin/Common.kt +++ b/lyng/src/commonMain/kotlin/Common.kt @@ -32,6 +32,7 @@ import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.withContext import net.sergeych.lyng.EvalSession +import net.sergeych.lyng.ExecutionError import net.sergeych.lyng.LyngVersion import net.sergeych.lyng.Pos import net.sergeych.lyng.Scope @@ -153,8 +154,8 @@ val baseScopeDefer = globalDefer { baseCliImportManagerDefer.await().copy().apply { invalidateCliModuleCaches() }.newStdScope().apply { - installCliDeclarations() installCliBuiltins() + installCliDeclarations() addConst("ARGV", ObjList(mutableListOf())) } } @@ -364,8 +365,8 @@ private fun registerLocalCliModules(manager: ImportManager, modules: List): Scope = newStdScope().apply { - installCliDeclarations() installCliBuiltins() + installCliDeclarations() addConst("ARGV", ObjList(argv.map { ObjString(it) }.toMutableList())) } @@ -547,6 +548,15 @@ suspend fun executeSource(source: Source, initialScope: Scope? = null) { evalOnCliDispatcher(session, source) } catch (e: CliExitRequested) { requestedExitCode = e.code + } catch (e: ExecutionError) { + val cliExit = generateSequence(e) { it.cause } + .filterIsInstance() + .firstOrNull() + if (cliExit != null) { + requestedExitCode = cliExit.code + } else { + throw e + } } } finally { shutdownHooks.uninstall() diff --git a/lyngio/src/jvmTest/kotlin/net/sergeych/lyng/io/db/sqlite/LyngSqliteModuleTest.kt b/lyngio/src/jvmTest/kotlin/net/sergeych/lyng/io/db/sqlite/LyngSqliteModuleTest.kt index f5a995e..e26ce07 100644 --- a/lyngio/src/jvmTest/kotlin/net/sergeych/lyng/io/db/sqlite/LyngSqliteModuleTest.kt +++ b/lyngio/src/jvmTest/kotlin/net/sergeych/lyng/io/db/sqlite/LyngSqliteModuleTest.kt @@ -137,6 +137,47 @@ class LyngSqliteModuleTest { assertEquals(2L, result.value) } + @Test + fun testTransactionGenericReturnTypeFlowsToOuterVal() = runTest { + val scope = Script.newScope() + createSqliteModule(scope.importManager) + + val code = """ + import lyng.io.db + import lyng.io.db.sqlite + + class Payload(name: String, count: Int) + class Item(id: Int, title: String, @DbJson meta: Payload, @DbLynon state: Payload) { + var note: String = "" + } + + val restored = openSqlite(":memory:").transaction { tx -> + tx.execute("create table item(id integer not null, title text not null, meta text not null, state blob not null, note text not null)") + val item = Item(1, "first", Payload("json", 10), Payload("bin", 20)) + item.note = "created" + tx.execute("insert into item(@cols(?1)) values(@vals(?1))", item) + item.title = "second" + item.meta = Payload("json2", 11) + item.state = Payload("bin2", 21) + item.note = "updated" + tx.execute("update item set @set(?1 except: \"id\") where id = ?2", item, item.id) + val restored = tx.select("select * from item where id = ?", 1).decodeAs().first + assertEquals("second", restored.title) + assertEquals("json2", restored.meta.name) + assertEquals(11, restored.meta.count) + assertEquals("bin2", restored.state.name) + assertEquals(21, restored.state.count) + assertEquals("updated", restored.note) + restored + } + + restored.id + """.trimIndent() + + val result = Compiler.compile(Source("", code), scope.importManager).execute(scope) as ObjInt + assertEquals(1L, result.value) + } + @Test fun testDecodeAsProjectsJsonColumnIntoObjectField() = runTest { val scope = Script.newScope() diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt index 65007f3..7bb9a3c 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt @@ -189,6 +189,7 @@ class Compiler( private val callableReturnTypeDeclByName: MutableMap = mutableMapOf() private val callSignatureByName: MutableMap = mutableMapOf() private val lambdaReturnTypeByRef: MutableMap = mutableMapOf() + private val lambdaTypeDeclByRef: MutableMap = mutableMapOf() private val exactLambdaRefByScopeId: MutableMap> = mutableMapOf() private val lambdaCaptureEntriesByRef: MutableMap> = mutableMapOf() @@ -682,6 +683,7 @@ class Compiler( return CompileClassInfo( name, cls.logicalPackageName, + emptyList(), fieldIds, methodIds, nextFieldId, @@ -1750,6 +1752,7 @@ class Compiler( private data class CompileClassInfo( val name: String, val packageName: String?, + val typeParams: List, val fieldIds: Map, val methodIds: Map, val nextFieldId: Int, @@ -3496,10 +3499,12 @@ class Compiler( paramTypeDeclMap[slot] = typeDecl } + val lambdaParamTypeDecls = mutableListOf() + if (argsDeclaration != null) { val expectedParams = expectedCallableType?.params.orEmpty() argsDeclaration.params.forEachIndexed { index, param -> - val effectiveType = if ((param.type == TypeDecl.TypeAny || param.type == TypeDecl.TypeNullableAny) && + val rawType = if ((param.type == TypeDecl.TypeAny || param.type == TypeDecl.TypeNullableAny) && index < expectedParams.size ) { expectedParams[index] @@ -3510,15 +3515,20 @@ class Compiler( } else { param.type } + val effectiveType = if (param.isEllipsis) TypeDecl.Ellipsis(rawType) else rawType + lambdaParamTypeDecls += effectiveType if (effectiveType != TypeDecl.TypeAny && effectiveType != TypeDecl.TypeNullableAny) { - seedLambdaParamType(param.name, effectiveType) + seedLambdaParamType(param.name, rawType) } } } else { val effectiveImplicitItType = implicitItType ?: expectedCallableType?.params?.singleOrNull() if (effectiveImplicitItType != null) { + lambdaParamTypeDecls += effectiveImplicitItType seedLambdaParamType("it", effectiveImplicitItType) + } else { + lambdaParamTypeDecls += TypeDecl.Ellipsis(TypeDecl.TypeAny) } } @@ -3580,6 +3590,7 @@ class Compiler( } else { emptyList() } + val inferredReturnDecl = inferReturnTypeDeclFromStatement(body) val returnClass = inferReturnClassFromStatement(body) val paramKnownClasses = mutableMapOf() argsDeclaration?.params?.forEach { param -> @@ -3784,6 +3795,13 @@ class Compiler( returnLabels = returnLabels, pos = startPos ) + val lambdaTypeDecl = TypeDecl.Function( + receiver = null, + params = lambdaParamTypeDecls.toList(), + returnType = inferredReturnDecl ?: returnClass?.let { TypeDecl.Simple(it.className, false) } ?: TypeDecl.TypeAny, + nullable = false + ) + lambdaTypeDeclByRef[ref] = lambdaTypeDecl if (returnClass != null) { lambdaReturnTypeByRef[ref] = returnClass } @@ -4747,6 +4765,7 @@ class Compiler( private fun inferTypeDeclFromRef(ref: ObjRef): TypeDecl? { resolveReceiverTypeDecl(ref)?.let { return it } return when (ref) { + is ValueFnRef -> lambdaTypeDeclByRef[ref] is ListLiteralRef -> inferListLiteralTypeDecl(ref) is MapLiteralRef -> inferMapLiteralTypeDecl(ref) is ConstRef -> inferTypeDeclFromConst(ref.constValue) @@ -5057,6 +5076,37 @@ class Compiler( return null } + private fun substituteReceiverTypeParams(receiverType: TypeDecl?, ownerClassName: String?, memberType: TypeDecl?): TypeDecl? { + if (receiverType !is TypeDecl.Generic || ownerClassName == null || memberType == null) return memberType + val info = resolveCompileClassInfo(ownerClassName) ?: return memberType + if (info.typeParams.isEmpty()) return memberType + val bindings = LinkedHashMap(info.typeParams.size) + for ((index, typeParamName) in info.typeParams.withIndex()) { + val argType = receiverType.args.getOrNull(index) ?: continue + bindings[typeParamName] = argType + } + if (bindings.isEmpty()) return memberType + return substituteTypeAliasTypeVars(memberType, bindings) + } + + private fun inferExtensionPropertyTypeDecl(receiverDecl: TypeDecl?, receiverClass: ObjClass?, memberName: String): TypeDecl? { + if (receiverClass == null) return null + for (cls in receiverClass.mro) { + val wrapperName = extensionPropertyGetterName(cls.className, memberName) + val resolved = resolveImportBinding(wrapperName, Pos.builtIn) ?: continue + registerImportBinding(wrapperName, resolved.binding, Pos.builtIn) + val wrapperType = resolved.record.typeDecl as? TypeDecl.Function ?: continue + val bindings = mutableMapOf() + val receiverParam = wrapperType.params.firstOrNull() ?: wrapperType.receiver + if (receiverParam != null && receiverDecl != null) { + collectTypeVarBindings(receiverParam, receiverDecl, bindings) + } + return if (bindings.isEmpty()) wrapperType.returnType + else substituteTypeAliasTypeVars(wrapperType.returnType, bindings) + } + return null + } + private fun classMemberTypeDecl(targetClass: ObjClass?, name: String): TypeDecl? { if (targetClass == null) return null if (targetClass == ObjDynamic.type) return TypeDecl.TypeAny @@ -5192,7 +5242,12 @@ class Compiler( is FieldRef -> { val targetDecl = resolveReceiverTypeDecl(ref.target) ?: return null val targetClass = resolveTypeDeclObjClass(targetDecl) ?: resolveReceiverClassForMember(ref.target) - classMemberTypeDecl(targetClass, ref.name)?.let { return it } + classMemberTypeDecl(targetClass, ref.name)?.let { declared -> + val ownerClassName = targetClass?.getInstanceMemberOrNull(ref.name, includeAbstract = true) + ?.declaringClass?.className ?: targetClass?.className + return substituteReceiverTypeParams(targetDecl, ownerClassName, declared) + } + inferExtensionPropertyTypeDecl(targetDecl, targetClass, ref.name)?.let { return it } classFieldTypesByName[targetClass?.className]?.get(ref.name) ?.let { return TypeDecl.Simple(it.className, false) } when (targetDecl) { @@ -5529,8 +5584,22 @@ class Compiler( private fun inferMethodCallReturnTypeDecl(ref: MethodCallRef): TypeDecl? { methodReturnTypeDeclByRef[ref]?.let { return it } - val inferred = inferMethodCallReturnTypeDecl(ref.name, resolveReceiverTypeDecl(ref.receiver), ref.args) - ?: classMethodReturnTypeDecl(resolveReceiverClassForMember(ref.receiver), ref.name) + val receiverDecl = resolveReceiverTypeDecl(ref.receiver) + val inferred = inferMethodCallReturnTypeDecl(ref.name, receiverDecl, ref.args) + ?: inferDeclaredMethodCallReturnTypeDecl( + ref.name, + receiverDecl, + resolveReceiverClassForMember(ref.receiver), + ref.args, + ref.explicitTypeArgs + ) + ?: run { + val receiverClass = resolveReceiverClassForMember(ref.receiver) + val declared = classMethodReturnTypeDecl(receiverClass, ref.name) + val ownerClassName = receiverClass?.getInstanceMemberOrNull(ref.name, includeAbstract = true) + ?.declaringClass?.className ?: receiverClass?.className + substituteReceiverTypeParams(receiverDecl, ownerClassName, declared) + } if (inferred != null) { methodReturnTypeDeclByRef[ref] = inferred } @@ -5635,6 +5704,108 @@ class Compiler( } } + private fun inferDeclaredMethodCallReturnTypeDecl( + name: String, + receiverDecl: TypeDecl?, + receiverClass: ObjClass?, + args: List, + explicitTypeArgs: List? = null + ): TypeDecl? { + if (receiverClass == null) return null + val ownerClassName = receiverClass.getInstanceMemberOrNull(name, includeAbstract = true) + ?.declaringClass?.className ?: receiverClass.className + val memberType = substituteReceiverTypeParams( + receiverDecl, + ownerClassName, + classMemberTypeDecl(receiverClass, name) + ) as? TypeDecl.Function ?: return null + + fun argTypeDecl(arg: ParsedArgument): TypeDecl? { + val stmt = arg.value as? ExpressionStatement ?: return null + val directRef = stmt.ref + return inferTypeDeclFromRef(directRef) + ?: inferObjClassFromRef(directRef)?.let { TypeDecl.Simple(it.className, false) } + } + + val bindings = mutableMapOf() + collectExplicitMethodTypeBindings(memberType, explicitTypeArgs, bindings) + memberType.receiver?.let { declaredReceiver -> + receiverDecl?.let { collectTypeVarBindings(declaredReceiver, it, bindings) } + } + + val paramList = memberType.params + val ellipsisIndex = paramList.indexOfFirst { it is TypeDecl.Ellipsis } + if (ellipsisIndex < 0) { + val limit = minOf(paramList.size, args.size) + for (i in 0 until limit) { + val argType = argTypeDecl(args[i]) ?: continue + collectTypeVarBindings(paramList[i], argType, bindings) + } + } else { + val headCount = ellipsisIndex + val tailCount = paramList.size - ellipsisIndex - 1 + val argCount = args.size + val headLimit = minOf(headCount, argCount) + for (i in 0 until headLimit) { + val argType = argTypeDecl(args[i]) ?: continue + collectTypeVarBindings(paramList[i], argType, bindings) + } + val tailStartArg = maxOf(headCount, argCount - tailCount) + for (i in tailStartArg until argCount) { + val paramIndex = paramList.size - (argCount - i) + val argType = argTypeDecl(args[i]) ?: continue + collectTypeVarBindings(paramList[paramIndex], argType, bindings) + } + val ellipsisArgEnd = argCount - tailCount + val ellipsisType = paramList[ellipsisIndex] as TypeDecl.Ellipsis + for (i in headCount until ellipsisArgEnd) { + val argType = if (args[i].isSplat) { + val stmt = args[i].value as? ExpressionStatement + stmt?.ref?.let { inferElementTypeFromSpread(it) } + } else { + argTypeDecl(args[i]) + } ?: continue + collectTypeVarBindings(ellipsisType.elementType, argType, bindings) + } + } + + return if (bindings.isEmpty()) memberType.returnType + else substituteTypeAliasTypeVars(memberType.returnType, bindings) + } + + private fun collectExplicitMethodTypeBindings( + memberType: TypeDecl.Function, + explicitTypeArgs: List?, + out: MutableMap + ) { + if (explicitTypeArgs.isNullOrEmpty()) return + val typeVars = LinkedHashSet() + memberType.receiver?.let { collectTypeVarNamesInOrder(it, typeVars) } + memberType.params.forEach { collectTypeVarNamesInOrder(it, typeVars) } + collectTypeVarNamesInOrder(memberType.returnType, typeVars) + val names = typeVars.toList() + val limit = minOf(names.size, explicitTypeArgs.size) + for (i in 0 until limit) { + out[names[i]] = explicitTypeArgs[i] + } + } + + private fun collectTypeVarNamesInOrder(type: TypeDecl, out: MutableSet) { + when (type) { + is TypeDecl.TypeVar -> out += type.name + is TypeDecl.Generic -> type.args.forEach { collectTypeVarNamesInOrder(it, out) } + is TypeDecl.Function -> { + type.receiver?.let { collectTypeVarNamesInOrder(it, out) } + type.params.forEach { collectTypeVarNamesInOrder(it, out) } + collectTypeVarNamesInOrder(type.returnType, out) + } + is TypeDecl.Ellipsis -> collectTypeVarNamesInOrder(type.elementType, out) + is TypeDecl.Union -> type.options.forEach { collectTypeVarNamesInOrder(it, out) } + is TypeDecl.Intersection -> type.options.forEach { collectTypeVarNamesInOrder(it, out) } + else -> {} + } + } + private fun inferCallableReturnTypeDeclFromArgument(arg: ParsedArgument): TypeDecl? { val stmt = arg.value as? ExpressionStatement ?: return null val ref = stmt.ref @@ -6104,6 +6275,7 @@ class Compiler( args: List, pos: Pos ) { + if (shouldSkipStaticCallableChecks(target)) return lookupNamedFunctionDecl(target)?.let { decl -> val hasComplexArgs = args.any { it.name != null } || decl.typeParams.isNotEmpty() || @@ -6201,6 +6373,22 @@ class Compiler( return null } + private fun shouldSkipStaticCallableChecks(target: ObjRef): Boolean { + resolveExactLambdaRef(target)?.let { lambda -> + if (lambda.argsDeclaration == null) return true + } + val name = when (target) { + is LocalVarRef -> target.name + is LocalSlotRef -> target.name + is FastLocalVarRef -> target.name + else -> null + } + if (name != null && callSignatureForName(name) != null) { + return true + } + return lookupNamedCallableRecord(target)?.callSignature != null + } + private fun inferCallReturnTypeDecl(ref: CallRef): TypeDecl? { callReturnTypeDeclByRef[ref]?.let { return it } val targetDecl = (resolveReceiverTypeDecl(ref.target) ?: seedTypeDeclFromRef(ref.target)) as? TypeDecl.Function @@ -6263,6 +6451,7 @@ class Compiler( args: List, pos: Pos ) { + if (shouldSkipStaticCallableChecks(target)) return lookupNamedFunctionDecl(target)?.let { decl -> val hasComplexArgs = args.any { it.name != null } || decl.typeParams.isNotEmpty() || @@ -6356,7 +6545,7 @@ class Compiler( return } val seededCallable = lookupNamedCallableRecord(target) - if (seededCallable != null && seededCallable.type == ObjRecord.Type.Fun && seededCallable.value !is ObjExternCallable) { + if (seededCallable != null && seededCallable.type == ObjRecord.Type.Fun) { return } val decl = (resolveReceiverTypeDecl(target) as? TypeDecl.Function) @@ -6470,6 +6659,17 @@ class Compiler( } } } + is TypeDecl.Function -> { + if (argType is TypeDecl.Function && paramType.params.size == argType.params.size) { + if (paramType.receiver != null && argType.receiver != null) { + collectTypeVarBindings(paramType.receiver, argType.receiver, out) + } + for (i in paramType.params.indices) { + collectTypeVarBindings(paramType.params[i], argType.params[i], out) + } + collectTypeVarBindings(paramType.returnType, argType.returnType, out) + } + } is TypeDecl.Union -> { if (argType is TypeDecl.Union) { val limit = minOf(paramType.options.size, argType.options.size) @@ -7820,6 +8020,7 @@ class Compiler( compileClassInfos[qualifiedName] = CompileClassInfo( name = qualifiedName, packageName = packageName, + typeParams = emptyList(), fieldIds = fieldIds, methodIds = methodIds, nextFieldId = fieldIds.size, @@ -7943,6 +8144,7 @@ class Compiler( compileClassInfos[className] = CompileClassInfo( name = className, packageName = packageName, + typeParams = emptyList(), fieldIds = ctx.memberFieldIds.toMap(), methodIds = ctx.memberMethodIds.toMap(), nextFieldId = ctx.nextFieldId, @@ -7985,6 +8187,7 @@ class Compiler( compileClassInfos[className] = CompileClassInfo( name = className, packageName = packageName, + typeParams = emptyList(), fieldIds = baseIds.fieldIds, methodIds = baseIds.methodIds, nextFieldId = baseIds.nextFieldId, @@ -8250,6 +8453,7 @@ class Compiler( compileClassInfos[qualifiedName] = CompileClassInfo( name = qualifiedName, packageName = packageName, + typeParams = typeParamDecls.map { it.name }, fieldIds = ctx.memberFieldIds.toMap(), methodIds = ctx.memberMethodIds.toMap(), nextFieldId = ctx.nextFieldId, @@ -8267,6 +8471,7 @@ class Compiler( compileClassInfos[qualifiedName] = CompileClassInfo( name = qualifiedName, packageName = packageName, + typeParams = typeParamDecls.map { it.name }, fieldIds = ctx.memberFieldIds.toMap(), methodIds = ctx.memberMethodIds.toMap(), nextFieldId = ctx.nextFieldId, @@ -8355,6 +8560,7 @@ class Compiler( compileClassInfos[qualifiedName] = CompileClassInfo( name = qualifiedName, packageName = packageName, + typeParams = typeParamDecls.map { it.name }, fieldIds = ctx.memberFieldIds.toMap(), methodIds = ctx.memberMethodIds.toMap(), nextFieldId = ctx.nextFieldId, @@ -9539,7 +9745,8 @@ class Compiler( resolutionSink?.exitScope(cc.currentPos()) } val rawFnStatements = parsedFnStatements?.let { unwrapBytecodeDeep(it) } - val inferredReturnClass = returnTypeDecl?.let { resolveTypeDeclObjClass(it) } + val inferredReturnDecl = returnTypeDecl ?: inferReturnTypeDeclFromStatement(rawFnStatements) + val inferredReturnClass = inferredReturnDecl?.let { resolveTypeDeclObjClass(it) } ?: inferReturnClassFromStatement(rawFnStatements) if (parentContext is CodeContext.ClassBody && !isStatic && extTypeName == null) { val ownerClassName = parentContext.name @@ -9547,15 +9754,12 @@ class Compiler( val memberTypeDecl = TypeDecl.Function( receiver = receiverTypeDecl, params = argsDeclaration.params.map { it.type }, - returnType = returnTypeDecl - ?: inferredReturnClass?.let { TypeDecl.Simple(it.className, false) } - ?: TypeDecl.TypeAny, + returnType = inferredReturnDecl ?: TypeDecl.TypeAny, nullable = false ) classMemberTypeDeclByName .getOrPut(ownerClassName) { mutableMapOf() }[name] = memberTypeDecl - val returnDecl = returnTypeDecl - ?: inferredReturnClass?.let { TypeDecl.Simple(it.className, false) } + val returnDecl = inferredReturnDecl if (returnDecl != null) { classMethodReturnTypeDeclByName .getOrPut(ownerClassName) { mutableMapOf() }[name] = returnDecl @@ -9576,7 +9780,6 @@ class Compiler( inferredReturnClass } } - val inferredReturnDecl = returnTypeDecl ?: inferredReturnClass?.let { TypeDecl.Simple(it.className, false) } if (declKind != SymbolKind.MEMBER && inferredReturnDecl != null) { callableReturnTypeDeclByName[name] = inferredReturnDecl } @@ -9881,6 +10084,40 @@ class Compiler( } } + private fun inferReturnTypeDeclFromStatement(stmt: Statement?): TypeDecl? { + if (stmt == null) return null + val unwrapped = unwrapBytecodeDeep(stmt) + return when (unwrapped) { + is ExpressionStatement -> inferTypeDeclFromInitializer(unwrapped) + is ReturnStatement -> unwrapped.resultExpr?.let { inferTypeDeclFromInitializer(it) } + is VarDeclStatement -> unwrapped.typeDecl ?: unwrapped.initializer?.let { inferTypeDeclFromInitializer(it) } + is BlockStatement -> { + val stmts = unwrapped.statements() + val returnTypes = stmts.mapNotNull { s -> + (s as? ReturnStatement)?.resultExpr?.let { inferTypeDeclFromInitializer(it) } + } + if (returnTypes.isNotEmpty()) { + val first = returnTypes.first() + if (returnTypes.all { typeDeclKey(it) == typeDeclKey(first) }) first else null + } else { + inferReturnTypeDeclFromStatement(stmts.lastOrNull()) + } + } + is InlineBlockStatement -> inferReturnTypeDeclFromStatement(unwrapped.statements().lastOrNull()) + is IfStatement -> { + val ifType = inferReturnTypeDeclFromStatement(unwrapped.ifBody) + val elseType = unwrapped.elseBody?.let { inferReturnTypeDeclFromStatement(it) } + when { + ifType == null -> elseType + elseType == null -> ifType + typeDeclKey(ifType) == typeDeclKey(elseType) -> ifType + else -> null + } + } + else -> null + } + } + private fun unwrapDirectRef(initializer: Statement?): ObjRef? { var initStmt = initializer while (initStmt is BytecodeStatement) { @@ -9961,8 +10198,11 @@ class Compiler( inferMethodCallReturnClass(directRef) } is FieldRef -> { - val targetClass = resolveReceiverClassForMember(directRef.target) - inferFieldReturnClass(targetClass, directRef.name) + resolveReceiverTypeDecl(directRef)?.let { resolveTypeDeclObjClass(it) } + ?: run { + val targetClass = resolveReceiverClassForMember(directRef.target) + inferFieldReturnClass(targetClass, directRef.name) + } } is ImplicitThisMemberRef -> resolveReceiverClassForMember(directRef) is CallRef -> { @@ -10367,6 +10607,7 @@ class Compiler( nameStartPos = nameToken.pos } val receiverNormalization = normalizeReceiverTypeDecl(receiverTypeDecl, emptySet()) + receiverTypeDecl = receiverNormalization.first val implicitTypeParams = receiverNormalization.second if (implicitTypeParams.isNotEmpty()) pendingTypeParamStack.add(implicitTypeParams) try { @@ -10599,6 +10840,7 @@ class Compiler( val directRef = unwrapDirectRef(initialExpression) val declClass = resolveTypeDeclObjClass(varTypeDecl) val initFromExpr = resolveInitializerObjClass(initialExpression) + val inferredInitTypeDecl = directRef?.let { inferTypeDeclFromRef(it) } val isNullLiteral = (directRef as? ConstRef)?.constValue == ObjNull val initObjClass = if (declClass != null && isNullLiteral) declClass else initFromExpr ?: declClass if (varTypeDecl !is TypeDecl.TypeAny && varTypeDecl !is TypeDecl.TypeNullableAny) { @@ -10606,16 +10848,29 @@ class Compiler( slotTypeDeclByScopeId.getOrPut(scopeId) { mutableMapOf() }[slotIndex] = varTypeDecl } nameTypeDecl[name] = varTypeDecl - } - if (directRef is ValueFnRef) { - val returnClass = lambdaReturnTypeByRef[directRef] - if (returnClass != null) { + } else { + val inferredFunctionType = inferredInitTypeDecl as? TypeDecl.Function + if (inferredFunctionType != null) { if (slotIndex != null && scopeId != null) { - callableReturnTypeByScopeId.getOrPut(scopeId) { mutableMapOf() }[slotIndex] = returnClass + slotTypeDeclByScopeId.getOrPut(scopeId) { mutableMapOf() }[slotIndex] = inferredFunctionType } - callableReturnTypeByName[name] = returnClass + nameTypeDecl[name] = inferredFunctionType } } + val inferredCallableReturnClass = when { + directRef is ValueFnRef -> lambdaReturnTypeByRef[directRef] + directRef != null -> (inferredInitTypeDecl as? TypeDecl.Function) + ?.returnType + ?.let { resolveTypeDeclObjClass(it) } + else -> null + } + if (inferredCallableReturnClass != null) { + if (slotIndex != null && scopeId != null) { + callableReturnTypeByScopeId.getOrPut(scopeId) { mutableMapOf() }[slotIndex] = + inferredCallableReturnClass + } + callableReturnTypeByName[name] = inferredCallableReturnClass + } if (directRef is MethodCallRef && directRef.name == "encode") { val payloadClass = inferEncodedPayloadClass(directRef.args) if (payloadClass != null) { @@ -10879,6 +11134,24 @@ class Compiler( if (declarationAnnotationSpecs.isNotEmpty()) { throw ScriptError(start, "declaration annotations are not supported on extension properties") } + val getterTypeDecl = receiverTypeDecl?.let { recv -> + TypeDecl.Function( + receiver = null, + params = listOf(recv), + returnType = varTypeDecl, + nullable = false + ) + } + val setterTypeDecl = if (setter != null && receiverTypeDecl != null) { + TypeDecl.Function( + receiver = null, + params = listOf(receiverTypeDecl, varTypeDecl), + returnType = TypeDecl.Simple("void", false), + nullable = false + ) + } else { + null + } declareLocalName(extensionPropertyGetterName(extTypeName, name), isMutable = false) if (setter != null) { declareLocalName(extensionPropertySetterName(extTypeName, name), isMutable = false) @@ -10895,6 +11168,8 @@ class Compiler( property = prop, visibility = visibility, setterVisibility = setterVisibility, + getterTypeDecl = getterTypeDecl, + setterTypeDecl = setterTypeDecl, startPos = start ) } diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/ExtensionPropertyDeclStatement.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/ExtensionPropertyDeclStatement.kt index 59bb2e9..8690d6f 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/ExtensionPropertyDeclStatement.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/ExtensionPropertyDeclStatement.kt @@ -24,6 +24,8 @@ class ExtensionPropertyDeclStatement( val property: ObjProperty, val visibility: Visibility, val setterVisibility: Visibility?, + val getterTypeDecl: TypeDecl?, + val setterTypeDecl: TypeDecl?, private val startPos: Pos, ) : Statement() { override val pos: Pos = startPos diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt index 7b3168d..74ae496 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt @@ -8003,7 +8003,9 @@ class BytecodeCompiler( stmt.extTypeName, stmt.property, stmt.visibility, - stmt.setterVisibility + stmt.setterVisibility, + stmt.getterTypeDecl, + stmt.setterTypeDecl ) ) val slot = allocSlot() @@ -8644,29 +8646,73 @@ class BytecodeCompiler( } private fun inferCallReturnClass(ref: CallRef): ObjClass? { + fun exactLambdaReturnClass(slot: Int): ObjClass? = + exactLambdaRefBySlot[slot]?.inferredReturnClass + + fun callableReturnClassFromSlot(slot: Int): ObjClass? { + exactLambdaReturnClass(slot)?.let { return it } + typeDeclForSlot(slot)?.let { decl -> + val functionDecl = decl as? TypeDecl.Function + if (functionDecl != null) { + resolveClassFromTypeDecl(functionDecl.returnType)?.let { return it } + } + } + return null + } + + fun callableResultClassOrNull( + directReturnClass: ObjClass?, + directTypeDecl: TypeDecl?, + nameClass: ObjClass?, + typeNameFallback: String? + ): ObjClass? { + if (directReturnClass != null) return directReturnClass + if (directTypeDecl is TypeDecl.Function) { + return null + } + if (nameClass == ObjClassType) { + return typeNameFallback?.let { resolveTypeNameClass(it) } ?: ObjDynamic.type + } + if (nameClass == Statement.type) { + return null + } + return nameClass ?: typeNameFallback?.let { resolveTypeNameClass(it) } + } + return when (val target = ref.target) { is LocalSlotRef -> { - callableReturnTypeByScopeId[target.scopeId]?.get(target.slot) - ?: run { - val nameClass = nameObjClass[target.name] - if (nameClass == ObjClassType) { - resolveTypeNameClass(target.name) ?: ObjDynamic.type - } else { - nameClass ?: resolveTypeNameClass(target.name) - } - } + val mappedSlot = resolveLocalSlotByRefOrName(target) + callableResultClassOrNull( + directReturnClass = mappedSlot?.let { callableReturnClassFromSlot(it) } + ?: exactLambdaRefByScopeId[target.scopeId]?.get(target.slot)?.inferredReturnClass + ?: callableReturnTypeByScopeId[target.scopeId]?.get(target.slot), + directTypeDecl = mappedSlot?.let { typeDeclForSlot(it) } + ?: slotTypeDeclByScopeId[target.scopeId]?.get(target.slot), + nameClass = nameObjClass[target.name], + typeNameFallback = target.name + ) } is LocalVarRef -> { - callableReturnTypeByName[target.name] - ?: run { - val nameClass = nameObjClass[target.name] - if (nameClass == ObjClassType) { - resolveTypeNameClass(target.name) ?: ObjDynamic.type - } else { - nameClass ?: resolveTypeNameClass(target.name) - } - } + val directSlot = resolveDirectNameSlot(target.name)?.slot + callableResultClassOrNull( + directReturnClass = directSlot?.let { callableReturnClassFromSlot(it) } + ?: callableReturnTypeByName[target.name], + directTypeDecl = directSlot?.let { typeDeclForSlot(it) }, + nameClass = nameObjClass[target.name], + typeNameFallback = target.name + ) } + is FastLocalVarRef -> { + val directSlot = resolveDirectNameSlot(target.name)?.slot + callableResultClassOrNull( + directReturnClass = directSlot?.let { callableReturnClassFromSlot(it) } + ?: callableReturnTypeByName[target.name], + directTypeDecl = directSlot?.let { typeDeclForSlot(it) }, + nameClass = nameObjClass[target.name], + typeNameFallback = target.name + ) + } + is BoundLocalVarRef -> callableReturnClassFromSlot(target.slotIndex()) is ConstRef -> target.constValue as? ObjClass else -> null } diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeConst.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeConst.kt index 7c9dbb1..4432119 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeConst.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeConst.kt @@ -66,6 +66,8 @@ sealed class BytecodeConst { val property: ObjProperty, val visibility: Visibility, val setterVisibility: Visibility?, + val getterTypeDecl: TypeDecl?, + val setterTypeDecl: TypeDecl?, ) : BytecodeConst() data class LocalDecl( val name: String, diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt index b203a44..18c0170 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt @@ -3268,7 +3268,14 @@ class CmdDeclExtProperty(internal val constId: Int, internal val slot: Int) : Cm ) val getterName = extensionPropertyGetterName(decl.extTypeName, decl.property.name) val getterWrapper = ObjExtensionPropertyGetterCallable(decl.property.name, decl.property) - frame.ensureScope().addItem(getterName, false, getterWrapper, decl.visibility, recordType = ObjRecord.Type.Fun) + frame.ensureScope().addItem( + getterName, + false, + getterWrapper, + decl.visibility, + recordType = ObjRecord.Type.Fun, + typeDecl = decl.getterTypeDecl + ) val getterLocal = resolveLocalSlotIndex(frame.fn, getterName, preferCapture = false) if (getterLocal != null) { frame.setObjUnchecked(frame.fn.scopeSlotCount + getterLocal, getterWrapper) @@ -3277,7 +3284,14 @@ class CmdDeclExtProperty(internal val constId: Int, internal val slot: Int) : Cm val setterName = extensionPropertySetterName(decl.extTypeName, decl.property.name) val setterWrapper = ObjExtensionPropertySetterCallable(decl.property.name, decl.property) frame.ensureScope() - .addItem(setterName, false, setterWrapper, decl.visibility, recordType = ObjRecord.Type.Fun) + .addItem( + setterName, + false, + setterWrapper, + decl.visibility, + recordType = ObjRecord.Type.Fun, + typeDecl = decl.setterTypeDecl + ) val setterLocal = resolveLocalSlotIndex(frame.fn, setterName, preferCapture = false) if (setterLocal != null) { frame.setObjUnchecked(frame.fn.scopeSlotCount + setterLocal, setterWrapper) diff --git a/lynglib/src/commonTest/kotlin/TypeInferenceTest.kt b/lynglib/src/commonTest/kotlin/TypeInferenceTest.kt index 5b40fe8..766a78e 100644 --- a/lynglib/src/commonTest/kotlin/TypeInferenceTest.kt +++ b/lynglib/src/commonTest/kotlin/TypeInferenceTest.kt @@ -104,4 +104,39 @@ class TypeInferenceTest { Pool(2).closeAll() """.trimIndent()) } + + @Test + fun testIterableFirstPreservesElementTypeForBlockReturnInference() = runBlocking { + eval(""" + class Item(title: String) + + fun restored() { + val values = [Item("ok")] + values.first + } + + val item = restored() + assertEquals("ok", item.title) + """.trimIndent()) + } + + @Test + fun testCallableLocalInitializedFromFunctionCallPreservesReturnType() = runBlocking { + eval(""" + fun makeAdder(base) { + return { x -> x + base + 0.5 } + } + + fun run() { + val add = makeAdder(2) + val value = add(3) + 4 + assert(value is Real) + value + } + + val result = run() + assert(result is Real) + assertEquals(9.5, result) + """.trimIndent()) + } }