From caad7d8ab9674c4e64668a7753c2f0ff66de80ab Mon Sep 17 00:00:00 2001 From: sergeych Date: Fri, 3 Apr 2026 11:04:16 +0300 Subject: [PATCH] stdlib inference bug fixed --- .../kotlin/net/sergeych/lyng/Compiler.kt | 236 +++++++++++++++++- .../sergeych/lyng/FunctionDeclStatement.kt | 44 ++-- lynglib/src/commonTest/kotlin/StdlibTest.kt | 132 ++++++++++ .../net/sergeych/lyng/ComplexModuleTest.kt | 1 + lynglib/stdlib/lyng/root.lyng | 24 +- 5 files changed, 412 insertions(+), 25 deletions(-) diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt index abe3156..e0df2ea 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt @@ -4346,7 +4346,7 @@ class Compiler( is MapLiteralRef -> inferMapLiteralTypeDecl(ref) is ConstRef -> inferTypeDeclFromConst(ref.constValue) is CallRef -> { - val targetDecl = resolveReceiverTypeDecl(ref.target) + val targetDecl = resolveReceiverTypeDecl(ref.target) ?: seedTypeDeclFromRef(ref.target) val targetName = when (val target = ref.target) { is LocalVarRef -> target.name is FastLocalVarRef -> target.name @@ -4354,8 +4354,9 @@ class Compiler( else -> null } if (targetDecl is TypeDecl.Function) { - return targetDecl.returnType + return inferCallReturnTypeDecl(ref) ?: targetDecl.returnType } + inferCallReturnTypeDecl(ref)?.let { return it } if (targetName != null) { callableReturnTypeDeclByName[targetName]?.let { return it } (seedTypeDeclByName(targetName) as? TypeDecl.Function)?.let { return it.returnType } @@ -4745,7 +4746,7 @@ class Compiler( classMethodReturnTypeDecl(targetClass, "getAt") } is MethodCallRef -> methodReturnTypeDeclByRef[ref] - is CallRef -> callReturnTypeDeclByRef[ref] + is CallRef -> callReturnTypeDeclByRef[ref] ?: inferCallReturnTypeDecl(ref) is BinaryOpRef -> inferBinaryOpReturnTypeDecl(ref) is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) } else -> null @@ -4811,7 +4812,7 @@ class Compiler( is ImplicitThisMethodCallRef -> inferMethodCallReturnClass(ref.methodName()) is ThisMethodSlotCallRef -> inferMethodCallReturnClass(ref.methodName()) is QualifiedThisMethodSlotCallRef -> inferMethodCallReturnClass(ref.methodName()) - is CallRef -> inferCallReturnClass(ref) + is CallRef -> inferCallReturnTypeDecl(ref)?.let { resolveTypeDeclObjClass(it) } ?: inferCallReturnClass(ref) is BinaryOpRef -> inferBinaryOpReturnClass(ref) is FieldRef -> { val targetClass = resolveReceiverClassForMember(ref.target) @@ -5478,12 +5479,45 @@ class Compiler( args: List, pos: Pos ) { + lookupNamedFunctionDecl(target)?.let { decl -> + val hasComplexArgs = args.any { it.name != null } || + decl.typeParams.isNotEmpty() || + decl.params.any { it.defaultValue != null || it.isEllipsis } + if (hasComplexArgs) return + if (args.any { it.isSplat }) return + val actual = args.size + val params = decl.params + val ellipsisIndex = params.indexOfFirst { it.isEllipsis } + if (ellipsisIndex < 0) { + val minArgs = params.count { it.defaultValue == null } + val maxArgs = params.size + if (actual < minArgs || actual > maxArgs) { + val message = if (minArgs == maxArgs) { + "expected $maxArgs arguments, got $actual" + } else { + "expected $minArgs..$maxArgs arguments, got $actual" + } + throw ScriptError(pos, message) + } + return + } + val headRequired = (0 until ellipsisIndex).count { params[it].defaultValue == null } + val tailRequired = (ellipsisIndex + 1 until params.size).count { params[it].defaultValue == null } + val minArgs = headRequired + tailRequired + if (actual < minArgs) { + throw ScriptError(pos, "expected at least $minArgs arguments, got $actual") + } + return + } + val seededCallable = lookupNamedCallableRecord(target) + if (seededCallable != null && seededCallable.type == ObjRecord.Type.Fun && seededCallable.value !is ObjExternCallable) { + return + } val decl = (resolveReceiverTypeDecl(target) as? TypeDecl.Function) ?: seedTypeDeclFromRef(target) as? TypeDecl.Function ?: return if (args.any { it.isSplat }) return val actual = args.size - val receiverCount = if (decl.receiver != null) 1 else 0 val paramList = mutableListOf() decl.receiver?.let { paramList += it } paramList += decl.params @@ -5512,6 +5546,91 @@ class Compiler( } ?: return null seedScope?.getLocalRecordDirect(name)?.typeDecl?.let { return it } return seedScope?.get(name)?.typeDecl + ?: importManager.rootScope.getLocalRecordDirect(name)?.typeDecl + ?: importManager.rootScope.get(name)?.typeDecl + } + + private fun lookupNamedFunctionDecl(target: ObjRef): GenericFunctionDecl? { + val name = when (target) { + is LocalVarRef -> target.name + is LocalSlotRef -> target.name + is FastLocalVarRef -> target.name + else -> null + } ?: return null + return lookupGenericFunctionDecl(name) + } + + private fun lookupNamedCallableRecord(target: ObjRef): ObjRecord? { + val name = when (target) { + is LocalVarRef -> target.name + is LocalSlotRef -> target.name + is FastLocalVarRef -> target.name + else -> null + } ?: return null + findSeedScopeRecord(name)?.let { return it } + importManager.rootScope.getLocalRecordDirect(name)?.let { return it } + importManager.rootScope.get(name)?.let { return it } + for (module in importedModules.asReversed()) { + module.scope.get(name)?.let { return it } + } + return null + } + + private fun inferCallReturnTypeDecl(ref: CallRef): TypeDecl? { + callReturnTypeDeclByRef[ref]?.let { return it } + val targetDecl = (resolveReceiverTypeDecl(ref.target) ?: seedTypeDeclFromRef(ref.target)) as? TypeDecl.Function + ?: return null + val bindings = mutableMapOf() + val paramList = mutableListOf() + targetDecl.receiver?.let { paramList += it } + paramList += targetDecl.params + + fun argTypeDecl(arg: ParsedArgument): TypeDecl? { + val stmt = arg.value as? ExpressionStatement ?: return null + val directRef = stmt.ref + return inferTypeDeclFromRef(directRef) + ?: inferObjClassFromRef(directRef)?.let { TypeDecl.Simple(it.className, false) } + } + + val ellipsisIndex = paramList.indexOfFirst { it is TypeDecl.Ellipsis } + if (ellipsisIndex < 0) { + val limit = minOf(paramList.size, ref.args.size) + for (i in 0 until limit) { + val argType = argTypeDecl(ref.args[i]) ?: continue + collectTypeVarBindings(paramList[i], argType, bindings) + } + } else { + val headCount = ellipsisIndex + val tailCount = paramList.size - ellipsisIndex - 1 + val argCount = ref.args.size + val headLimit = minOf(headCount, argCount) + for (i in 0 until headLimit) { + val argType = argTypeDecl(ref.args[i]) ?: continue + collectTypeVarBindings(paramList[i], argType, bindings) + } + val tailStartArg = maxOf(headCount, argCount - tailCount) + for (i in tailStartArg until argCount) { + val paramIndex = paramList.size - (argCount - i) + val argType = argTypeDecl(ref.args[i]) ?: continue + collectTypeVarBindings(paramList[paramIndex], argType, bindings) + } + val ellipsisArgEnd = argCount - tailCount + val ellipsisType = paramList[ellipsisIndex] as TypeDecl.Ellipsis + for (i in headCount until ellipsisArgEnd) { + val argType = if (ref.args[i].isSplat) { + val stmt = ref.args[i].value as? ExpressionStatement + stmt?.ref?.let { inferElementTypeFromSpread(it) } + } else { + argTypeDecl(ref.args[i]) + } ?: continue + collectTypeVarBindings(ellipsisType.elementType, argType, bindings) + } + } + + val inferred = if (bindings.isEmpty()) targetDecl.returnType + else substituteTypeAliasTypeVars(targetDecl.returnType, bindings) + callReturnTypeDeclByRef[ref] = inferred + return inferred } private fun checkFunctionTypeCallTypes( @@ -5519,6 +5638,102 @@ class Compiler( args: List, pos: Pos ) { + lookupNamedFunctionDecl(target)?.let { decl -> + val hasComplexArgs = args.any { it.name != null } || + decl.typeParams.isNotEmpty() || + decl.params.any { it.defaultValue != null || it.isEllipsis } + if (hasComplexArgs) return + val paramList = decl.params.map { if (it.isEllipsis) TypeDecl.Ellipsis(it.type) else it.type } + if (paramList.isEmpty()) return + val ellipsisIndex = decl.params.indexOfFirst { it.isEllipsis } + fun argTypeDecl(arg: ParsedArgument): TypeDecl? { + val stmt = arg.value as? ExpressionStatement ?: return null + val ref = stmt.ref + return inferTypeDeclFromRef(ref) + ?: inferObjClassFromRef(ref)?.let { TypeDecl.Simple(it.className, false) } + } + fun typeDeclSubtypeOf(arg: TypeDecl, param: TypeDecl): Boolean { + if (param == TypeDecl.TypeAny || param == TypeDecl.TypeNullableAny) return true + val (argBase, argNullable) = stripNullable(arg) + val (paramBase, paramNullable) = stripNullable(param) + if (argNullable && !paramNullable) return false + if (paramBase == TypeDecl.TypeAny) return true + if (paramBase is TypeDecl.TypeVar) return true + if (argBase is TypeDecl.TypeVar) return true + if (paramBase is TypeDecl.Simple && (paramBase.name == "Object" || paramBase.name == "Obj")) return true + if (argBase is TypeDecl.Ellipsis) return typeDeclSubtypeOf(argBase.elementType, paramBase) + if (paramBase is TypeDecl.Ellipsis) return typeDeclSubtypeOf(argBase, paramBase.elementType) + return when (argBase) { + is TypeDecl.Union -> argBase.options.all { typeDeclSubtypeOf(it, paramBase) } + is TypeDecl.Intersection -> argBase.options.any { typeDeclSubtypeOf(it, paramBase) } + else -> when (paramBase) { + is TypeDecl.Union -> paramBase.options.any { typeDeclSubtypeOf(argBase, it) } + is TypeDecl.Intersection -> paramBase.options.all { typeDeclSubtypeOf(argBase, it) } + else -> { + val argClass = resolveTypeDeclObjClass(argBase) ?: return false + val paramClass = resolveTypeDeclObjClass(paramBase) ?: return false + argClass == paramClass || argClass.allParentsSet.contains(paramClass) + } + } + } + } + fun fail(argPos: Pos, expected: TypeDecl, got: TypeDecl) { + throw ScriptError(argPos, "argument type ${typeDeclName(got)} does not match ${typeDeclName(expected)}") + } + if (ellipsisIndex < 0) { + val limit = minOf(paramList.size, args.size) + for (i in 0 until limit) { + val arg = args[i] + val argType = argTypeDecl(arg) ?: continue + val paramType = paramList[i] + if (!typeDeclSubtypeOf(argType, paramType)) { + fail(arg.pos, paramType, argType) + } + } + return + } + val headCount = ellipsisIndex + val tailCount = paramList.size - ellipsisIndex - 1 + val ellipsisType = paramList[ellipsisIndex] as TypeDecl.Ellipsis + val argCount = args.size + val headLimit = minOf(headCount, argCount) + for (i in 0 until headLimit) { + val arg = args[i] + val argType = argTypeDecl(arg) ?: continue + val paramType = paramList[i] + if (!typeDeclSubtypeOf(argType, paramType)) { + fail(arg.pos, paramType, argType) + } + } + val tailStartArg = maxOf(headCount, argCount - tailCount) + for (i in tailStartArg until argCount) { + val arg = args[i] + val paramType = paramList[paramList.size - (argCount - i)] + val argType = argTypeDecl(arg) ?: continue + if (!typeDeclSubtypeOf(argType, paramType)) { + fail(arg.pos, paramType, argType) + } + } + val ellipsisArgEnd = argCount - tailCount + for (i in headCount until ellipsisArgEnd) { + val arg = args[i] + val argType = if (arg.isSplat) { + val stmt = arg.value as? ExpressionStatement + val ref = stmt?.ref + ref?.let { inferElementTypeFromSpread(it) } + } else { + argTypeDecl(arg) + } ?: continue + if (!typeDeclSubtypeOf(argType, ellipsisType.elementType)) { + fail(arg.pos, ellipsisType.elementType, argType) + } + } + return + } + val seededCallable = lookupNamedCallableRecord(target) + if (seededCallable != null && seededCallable.type == ObjRecord.Type.Fun && seededCallable.value !is ObjExternCallable) { + return + } val decl = (resolveReceiverTypeDecl(target) as? TypeDecl.Function) ?: seedTypeDeclFromRef(target) as? TypeDecl.Function ?: return @@ -8074,7 +8289,7 @@ class Compiler( parseArgsDeclaration() ?: ArgsDeclaration(emptyList(), Token.Type.RPAREN) } else ArgsDeclaration(emptyList(), Token.Type.RPAREN) - if (mergedTypeParamDecls.isNotEmpty() && declKind != SymbolKind.MEMBER) { + if (declKind != SymbolKind.MEMBER) { currentGenericFunctionDecls()[name] = GenericFunctionDecl(mergedTypeParamDecls, argsDeclaration.params, nameStartPos) } @@ -8097,6 +8312,9 @@ class Compiler( if (compileBytecode) { delegateExpression = wrapFunctionBytecode(delegateExpression, "delegate@$name") } + if (declKind != SymbolKind.MEMBER) { + currentGenericFunctionDecls().remove(name) + } } if (isDelegated && declKind != SymbolKind.MEMBER) { val plan = slotPlanStack.lastOrNull() @@ -8441,6 +8659,12 @@ class Compiler( parentIsClassBody = parentIsClassBody, externCallSignature = externCallSignature, annotation = annotation, + typeDecl = if (isDelegated) null else TypeDecl.Function( + receiver = receiverTypeDecl, + params = argsDeclaration.params.map { it.type }, + returnType = inferredReturnDecl ?: TypeDecl.TypeAny, + nullable = false + ), fnBody = fnBody, closureBox = closureBox, captureSlots = captureSlots, diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/FunctionDeclStatement.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/FunctionDeclStatement.kt index 4805a42..615d936 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/FunctionDeclStatement.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/FunctionDeclStatement.kt @@ -1,5 +1,5 @@ /* - * Copyright 2026 Sergey S. Chernov + * Copyright 2026 Sergey S. Chernov real.sergeych@gmail.com * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -12,19 +12,12 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. + * */ package net.sergeych.lyng -import net.sergeych.lyng.obj.Obj -import net.sergeych.lyng.obj.ObjClass -import net.sergeych.lyng.obj.ObjExternCallable -import net.sergeych.lyng.obj.ObjExtensionMethodCallable -import net.sergeych.lyng.obj.ObjInstance -import net.sergeych.lyng.obj.ObjRecord -import net.sergeych.lyng.obj.ObjString -import net.sergeych.lyng.obj.ObjUnset -import net.sergeych.lyng.obj.ObjVoid +import net.sergeych.lyng.obj.* class FunctionClosureBox( var closure: Scope? = null, @@ -50,6 +43,7 @@ data class FunctionDeclSpec( val parentIsClassBody: Boolean, val externCallSignature: CallSignature?, val annotation: (suspend (Scope, ObjString, Statement) -> Statement)?, + val typeDecl: TypeDecl?, val fnBody: Statement, val closureBox: FunctionClosureBox, val captureSlots: List, @@ -73,7 +67,8 @@ internal suspend fun executeFunctionDecl( false, value, spec.visibility, - callSignature = existing.callSignature + callSignature = existing.callSignature, + typeDecl = spec.typeDecl ) return value } @@ -182,12 +177,20 @@ internal suspend fun executeFunctionDecl( isMutable = false, visibility = spec.visibility, declaringClass = null, - type = ObjRecord.Type.Fun + type = ObjRecord.Type.Fun, + typeDecl = spec.typeDecl ) ) val wrapperName = spec.extensionWrapperName ?: extensionCallableName(typeName, spec.name) val wrapper = ObjExtensionMethodCallable(spec.name, compiledFnBody) - scope.addItem(wrapperName, false, wrapper, spec.visibility, recordType = ObjRecord.Type.Fun) + scope.addItem( + wrapperName, + false, + wrapper, + spec.visibility, + recordType = ObjRecord.Type.Fun, + typeDecl = spec.typeDecl + ) } ?: run { val th = scope.thisObj if (!spec.isStatic && th is ObjClass) { @@ -203,10 +206,18 @@ internal suspend fun executeFunctionDecl( isClosed = spec.isClosed, isOverride = spec.isOverride, type = ObjRecord.Type.Fun, - methodId = spec.memberMethodId + methodId = spec.memberMethodId, + typeDecl = spec.typeDecl ) val memberValue = cls.members[spec.name]?.value ?: compiledFnBody - scope.addItem(spec.name, false, memberValue, spec.visibility, callSignature = spec.externCallSignature) + scope.addItem( + spec.name, + false, + memberValue, + spec.visibility, + callSignature = spec.externCallSignature, + typeDecl = spec.typeDecl + ) compiledFnBody } else { scope.addItem( @@ -215,7 +226,8 @@ internal suspend fun executeFunctionDecl( compiledFnBody, spec.visibility, recordType = ObjRecord.Type.Fun, - callSignature = spec.externCallSignature + callSignature = spec.externCallSignature, + typeDecl = spec.typeDecl ) } } diff --git a/lynglib/src/commonTest/kotlin/StdlibTest.kt b/lynglib/src/commonTest/kotlin/StdlibTest.kt index 3291371..c2e5eb6 100644 --- a/lynglib/src/commonTest/kotlin/StdlibTest.kt +++ b/lynglib/src/commonTest/kotlin/StdlibTest.kt @@ -241,4 +241,136 @@ class StdlibTest { assertThrows(IllegalArgumentException) { Random.next(..) } """.trimIndent()) } + + @Test + fun testInference2() = runTest { + eval( + $$""" + val a = 10 + val b = 3.0 + val c = floor(a / b) + //assert(c is Real) + c.toInt() + """.trimIndent() + ) + } + + @Test + fun testStdlibGlobalFunctionInference() = runTest { + eval( + $$""" + val absInt = abs(-5) + assert(absInt is Int) + absInt.toInt() + + val absReal = abs(-5.5) + assert(absReal is Real) + absReal.isNaN() + + val floorInt = floor(7) + assert(floorInt is Int) + floorInt.toInt() + + val floorReal = floor(7.9) + assert(floorReal is Real) + floorReal.toInt() + + val ceilInt = ceil(7) + assert(ceilInt is Int) + ceilInt.toInt() + + val ceilReal = ceil(7.1) + assert(ceilReal is Real) + ceilReal.toInt() + + val roundInt = round(7) + assert(roundInt is Int) + roundInt.toInt() + + val roundReal = round(7.4) + assert(roundReal is Real) + roundReal.toInt() + + val sinValue = sin(1) + sinValue.isInfinite() + assert(sinValue is Real) + + val cosValue = cos(1) + cosValue.isNaN() + assert(cosValue is Real) + + val tanValue = tan(1) + tanValue.toInt() + assert(tanValue is Real) + + val asinValue = asin(0.5) + asinValue.toInt() + assert(asinValue is Real) + + val acosValue = acos(0.5) + acosValue.toInt() + assert(acosValue is Real) + + val atanValue = atan(1) + atanValue.toInt() + assert(atanValue is Real) + + val sinhValue = sinh(1) + sinhValue.isInfinite() + assert(sinhValue is Real) + + val coshValue = cosh(1) + coshValue.isNaN() + assert(coshValue is Real) + + val tanhValue = tanh(1) + tanhValue.toInt() + assert(tanhValue is Real) + + val asinhValue = asinh(1) + asinhValue.toInt() + assert(asinhValue is Real) + + val acoshValue = acosh(2) + acoshValue.toInt() + assert(acoshValue is Real) + + val atanhValue = atanh(0.5) + atanhValue.toInt() + assert(atanhValue is Real) + + val expValue = exp(1) + expValue.isInfinite() + assert(expValue is Real) + + val lnValue = ln(2) + lnValue.isNaN() + assert(lnValue is Real) + + val log10Value = log10(100) + log10Value.toInt() + assert(log10Value is Real) + + val log2Value = log2(8) + log2Value.toInt() + assert(log2Value is Real) + + val powValue = pow(2, 8) + powValue.isInfinite() + assert(powValue is Real) + + val sqrtValue = sqrt(9) + sqrtValue.isNaN() + assert(sqrtValue is Real) + + val clampedInt = clamp(20, 0..10) + assert(clampedInt is Int) + clampedInt.toInt() + + val clampedReal = clamp(2.5, 0.0..10.0) + assert(clampedReal is Real) + clampedReal.toInt() + """.trimIndent() + ) + } } diff --git a/lynglib/src/commonTest/kotlin/net/sergeych/lyng/ComplexModuleTest.kt b/lynglib/src/commonTest/kotlin/net/sergeych/lyng/ComplexModuleTest.kt index c38461e..3e17fb4 100644 --- a/lynglib/src/commonTest/kotlin/net/sergeych/lyng/ComplexModuleTest.kt +++ b/lynglib/src/commonTest/kotlin/net/sergeych/lyng/ComplexModuleTest.kt @@ -86,6 +86,7 @@ class ComplexModuleTest { """.trimIndent() ) } + @Test fun testDecimalInferences() = runTest { eval( diff --git a/lynglib/stdlib/lyng/root.lyng b/lynglib/stdlib/lyng/root.lyng index 0ad4b21..b447087 100644 --- a/lynglib/stdlib/lyng/root.lyng +++ b/lynglib/stdlib/lyng/root.lyng @@ -121,9 +121,27 @@ extern class MapEntry : Array { // - the remaining decimal cases currently use a temporary bridge: // `Decimal -> Real -> host math -> Decimal` // - this is temporary and will be replaced with dedicated decimal implementations -extern fun abs(x: Object): Object -extern fun ln(x: Object): Object -extern fun pow(x: Object, y: Object): Object +extern fun abs(x: T): T +extern fun floor(x: T): T +extern fun ceil(x: T): T +extern fun round(x: T): T +extern fun sin(x: Object): Real +extern fun cos(x: Object): Real +extern fun tan(x: Object): Real +extern fun asin(x: Object): Real +extern fun acos(x: Object): Real +extern fun atan(x: Object): Real +extern fun sinh(x: Object): Real +extern fun cosh(x: Object): Real +extern fun tanh(x: Object): Real +extern fun asinh(x: Object): Real +extern fun acosh(x: Object): Real +extern fun atanh(x: Object): Real +extern fun exp(x: Object): Real +extern fun ln(x: Object): Real +extern fun log10(x: Object): Real +extern fun log2(x: Object): Real +extern fun pow(x: Object, y: Object): Real extern fun sqrt(x: Object): Real extern fun clamp(value: T, range: Range): T