stdlib inference bug fixed

This commit is contained in:
Sergey Chernov 2026-04-03 11:04:16 +03:00
parent 88b0bb2147
commit caad7d8ab9
5 changed files with 412 additions and 25 deletions

View File

@ -4346,7 +4346,7 @@ class Compiler(
is MapLiteralRef -> inferMapLiteralTypeDecl(ref)
is ConstRef -> inferTypeDeclFromConst(ref.constValue)
is CallRef -> {
val targetDecl = resolveReceiverTypeDecl(ref.target)
val targetDecl = resolveReceiverTypeDecl(ref.target) ?: seedTypeDeclFromRef(ref.target)
val targetName = when (val target = ref.target) {
is LocalVarRef -> target.name
is FastLocalVarRef -> target.name
@ -4354,8 +4354,9 @@ class Compiler(
else -> null
}
if (targetDecl is TypeDecl.Function) {
return targetDecl.returnType
return inferCallReturnTypeDecl(ref) ?: targetDecl.returnType
}
inferCallReturnTypeDecl(ref)?.let { return it }
if (targetName != null) {
callableReturnTypeDeclByName[targetName]?.let { return it }
(seedTypeDeclByName(targetName) as? TypeDecl.Function)?.let { return it.returnType }
@ -4745,7 +4746,7 @@ class Compiler(
classMethodReturnTypeDecl(targetClass, "getAt")
}
is MethodCallRef -> methodReturnTypeDeclByRef[ref]
is CallRef -> callReturnTypeDeclByRef[ref]
is CallRef -> callReturnTypeDeclByRef[ref] ?: inferCallReturnTypeDecl(ref)
is BinaryOpRef -> inferBinaryOpReturnTypeDecl(ref)
is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverTypeDecl(it.ref) }
else -> null
@ -4811,7 +4812,7 @@ class Compiler(
is ImplicitThisMethodCallRef -> inferMethodCallReturnClass(ref.methodName())
is ThisMethodSlotCallRef -> inferMethodCallReturnClass(ref.methodName())
is QualifiedThisMethodSlotCallRef -> inferMethodCallReturnClass(ref.methodName())
is CallRef -> inferCallReturnClass(ref)
is CallRef -> inferCallReturnTypeDecl(ref)?.let { resolveTypeDeclObjClass(it) } ?: inferCallReturnClass(ref)
is BinaryOpRef -> inferBinaryOpReturnClass(ref)
is FieldRef -> {
val targetClass = resolveReceiverClassForMember(ref.target)
@ -5478,12 +5479,45 @@ class Compiler(
args: List<ParsedArgument>,
pos: Pos
) {
lookupNamedFunctionDecl(target)?.let { decl ->
val hasComplexArgs = args.any { it.name != null } ||
decl.typeParams.isNotEmpty() ||
decl.params.any { it.defaultValue != null || it.isEllipsis }
if (hasComplexArgs) return
if (args.any { it.isSplat }) return
val actual = args.size
val params = decl.params
val ellipsisIndex = params.indexOfFirst { it.isEllipsis }
if (ellipsisIndex < 0) {
val minArgs = params.count { it.defaultValue == null }
val maxArgs = params.size
if (actual < minArgs || actual > maxArgs) {
val message = if (minArgs == maxArgs) {
"expected $maxArgs arguments, got $actual"
} else {
"expected $minArgs..$maxArgs arguments, got $actual"
}
throw ScriptError(pos, message)
}
return
}
val headRequired = (0 until ellipsisIndex).count { params[it].defaultValue == null }
val tailRequired = (ellipsisIndex + 1 until params.size).count { params[it].defaultValue == null }
val minArgs = headRequired + tailRequired
if (actual < minArgs) {
throw ScriptError(pos, "expected at least $minArgs arguments, got $actual")
}
return
}
val seededCallable = lookupNamedCallableRecord(target)
if (seededCallable != null && seededCallable.type == ObjRecord.Type.Fun && seededCallable.value !is ObjExternCallable) {
return
}
val decl = (resolveReceiverTypeDecl(target) as? TypeDecl.Function)
?: seedTypeDeclFromRef(target) as? TypeDecl.Function
?: return
if (args.any { it.isSplat }) return
val actual = args.size
val receiverCount = if (decl.receiver != null) 1 else 0
val paramList = mutableListOf<TypeDecl>()
decl.receiver?.let { paramList += it }
paramList += decl.params
@ -5512,6 +5546,91 @@ class Compiler(
} ?: return null
seedScope?.getLocalRecordDirect(name)?.typeDecl?.let { return it }
return seedScope?.get(name)?.typeDecl
?: importManager.rootScope.getLocalRecordDirect(name)?.typeDecl
?: importManager.rootScope.get(name)?.typeDecl
}
private fun lookupNamedFunctionDecl(target: ObjRef): GenericFunctionDecl? {
val name = when (target) {
is LocalVarRef -> target.name
is LocalSlotRef -> target.name
is FastLocalVarRef -> target.name
else -> null
} ?: return null
return lookupGenericFunctionDecl(name)
}
private fun lookupNamedCallableRecord(target: ObjRef): ObjRecord? {
val name = when (target) {
is LocalVarRef -> target.name
is LocalSlotRef -> target.name
is FastLocalVarRef -> target.name
else -> null
} ?: return null
findSeedScopeRecord(name)?.let { return it }
importManager.rootScope.getLocalRecordDirect(name)?.let { return it }
importManager.rootScope.get(name)?.let { return it }
for (module in importedModules.asReversed()) {
module.scope.get(name)?.let { return it }
}
return null
}
private fun inferCallReturnTypeDecl(ref: CallRef): TypeDecl? {
callReturnTypeDeclByRef[ref]?.let { return it }
val targetDecl = (resolveReceiverTypeDecl(ref.target) ?: seedTypeDeclFromRef(ref.target)) as? TypeDecl.Function
?: return null
val bindings = mutableMapOf<String, TypeDecl>()
val paramList = mutableListOf<TypeDecl>()
targetDecl.receiver?.let { paramList += it }
paramList += targetDecl.params
fun argTypeDecl(arg: ParsedArgument): TypeDecl? {
val stmt = arg.value as? ExpressionStatement ?: return null
val directRef = stmt.ref
return inferTypeDeclFromRef(directRef)
?: inferObjClassFromRef(directRef)?.let { TypeDecl.Simple(it.className, false) }
}
val ellipsisIndex = paramList.indexOfFirst { it is TypeDecl.Ellipsis }
if (ellipsisIndex < 0) {
val limit = minOf(paramList.size, ref.args.size)
for (i in 0 until limit) {
val argType = argTypeDecl(ref.args[i]) ?: continue
collectTypeVarBindings(paramList[i], argType, bindings)
}
} else {
val headCount = ellipsisIndex
val tailCount = paramList.size - ellipsisIndex - 1
val argCount = ref.args.size
val headLimit = minOf(headCount, argCount)
for (i in 0 until headLimit) {
val argType = argTypeDecl(ref.args[i]) ?: continue
collectTypeVarBindings(paramList[i], argType, bindings)
}
val tailStartArg = maxOf(headCount, argCount - tailCount)
for (i in tailStartArg until argCount) {
val paramIndex = paramList.size - (argCount - i)
val argType = argTypeDecl(ref.args[i]) ?: continue
collectTypeVarBindings(paramList[paramIndex], argType, bindings)
}
val ellipsisArgEnd = argCount - tailCount
val ellipsisType = paramList[ellipsisIndex] as TypeDecl.Ellipsis
for (i in headCount until ellipsisArgEnd) {
val argType = if (ref.args[i].isSplat) {
val stmt = ref.args[i].value as? ExpressionStatement
stmt?.ref?.let { inferElementTypeFromSpread(it) }
} else {
argTypeDecl(ref.args[i])
} ?: continue
collectTypeVarBindings(ellipsisType.elementType, argType, bindings)
}
}
val inferred = if (bindings.isEmpty()) targetDecl.returnType
else substituteTypeAliasTypeVars(targetDecl.returnType, bindings)
callReturnTypeDeclByRef[ref] = inferred
return inferred
}
private fun checkFunctionTypeCallTypes(
@ -5519,6 +5638,102 @@ class Compiler(
args: List<ParsedArgument>,
pos: Pos
) {
lookupNamedFunctionDecl(target)?.let { decl ->
val hasComplexArgs = args.any { it.name != null } ||
decl.typeParams.isNotEmpty() ||
decl.params.any { it.defaultValue != null || it.isEllipsis }
if (hasComplexArgs) return
val paramList = decl.params.map { if (it.isEllipsis) TypeDecl.Ellipsis(it.type) else it.type }
if (paramList.isEmpty()) return
val ellipsisIndex = decl.params.indexOfFirst { it.isEllipsis }
fun argTypeDecl(arg: ParsedArgument): TypeDecl? {
val stmt = arg.value as? ExpressionStatement ?: return null
val ref = stmt.ref
return inferTypeDeclFromRef(ref)
?: inferObjClassFromRef(ref)?.let { TypeDecl.Simple(it.className, false) }
}
fun typeDeclSubtypeOf(arg: TypeDecl, param: TypeDecl): Boolean {
if (param == TypeDecl.TypeAny || param == TypeDecl.TypeNullableAny) return true
val (argBase, argNullable) = stripNullable(arg)
val (paramBase, paramNullable) = stripNullable(param)
if (argNullable && !paramNullable) return false
if (paramBase == TypeDecl.TypeAny) return true
if (paramBase is TypeDecl.TypeVar) return true
if (argBase is TypeDecl.TypeVar) return true
if (paramBase is TypeDecl.Simple && (paramBase.name == "Object" || paramBase.name == "Obj")) return true
if (argBase is TypeDecl.Ellipsis) return typeDeclSubtypeOf(argBase.elementType, paramBase)
if (paramBase is TypeDecl.Ellipsis) return typeDeclSubtypeOf(argBase, paramBase.elementType)
return when (argBase) {
is TypeDecl.Union -> argBase.options.all { typeDeclSubtypeOf(it, paramBase) }
is TypeDecl.Intersection -> argBase.options.any { typeDeclSubtypeOf(it, paramBase) }
else -> when (paramBase) {
is TypeDecl.Union -> paramBase.options.any { typeDeclSubtypeOf(argBase, it) }
is TypeDecl.Intersection -> paramBase.options.all { typeDeclSubtypeOf(argBase, it) }
else -> {
val argClass = resolveTypeDeclObjClass(argBase) ?: return false
val paramClass = resolveTypeDeclObjClass(paramBase) ?: return false
argClass == paramClass || argClass.allParentsSet.contains(paramClass)
}
}
}
}
fun fail(argPos: Pos, expected: TypeDecl, got: TypeDecl) {
throw ScriptError(argPos, "argument type ${typeDeclName(got)} does not match ${typeDeclName(expected)}")
}
if (ellipsisIndex < 0) {
val limit = minOf(paramList.size, args.size)
for (i in 0 until limit) {
val arg = args[i]
val argType = argTypeDecl(arg) ?: continue
val paramType = paramList[i]
if (!typeDeclSubtypeOf(argType, paramType)) {
fail(arg.pos, paramType, argType)
}
}
return
}
val headCount = ellipsisIndex
val tailCount = paramList.size - ellipsisIndex - 1
val ellipsisType = paramList[ellipsisIndex] as TypeDecl.Ellipsis
val argCount = args.size
val headLimit = minOf(headCount, argCount)
for (i in 0 until headLimit) {
val arg = args[i]
val argType = argTypeDecl(arg) ?: continue
val paramType = paramList[i]
if (!typeDeclSubtypeOf(argType, paramType)) {
fail(arg.pos, paramType, argType)
}
}
val tailStartArg = maxOf(headCount, argCount - tailCount)
for (i in tailStartArg until argCount) {
val arg = args[i]
val paramType = paramList[paramList.size - (argCount - i)]
val argType = argTypeDecl(arg) ?: continue
if (!typeDeclSubtypeOf(argType, paramType)) {
fail(arg.pos, paramType, argType)
}
}
val ellipsisArgEnd = argCount - tailCount
for (i in headCount until ellipsisArgEnd) {
val arg = args[i]
val argType = if (arg.isSplat) {
val stmt = arg.value as? ExpressionStatement
val ref = stmt?.ref
ref?.let { inferElementTypeFromSpread(it) }
} else {
argTypeDecl(arg)
} ?: continue
if (!typeDeclSubtypeOf(argType, ellipsisType.elementType)) {
fail(arg.pos, ellipsisType.elementType, argType)
}
}
return
}
val seededCallable = lookupNamedCallableRecord(target)
if (seededCallable != null && seededCallable.type == ObjRecord.Type.Fun && seededCallable.value !is ObjExternCallable) {
return
}
val decl = (resolveReceiverTypeDecl(target) as? TypeDecl.Function)
?: seedTypeDeclFromRef(target) as? TypeDecl.Function
?: return
@ -8074,7 +8289,7 @@ class Compiler(
parseArgsDeclaration() ?: ArgsDeclaration(emptyList(), Token.Type.RPAREN)
} else ArgsDeclaration(emptyList(), Token.Type.RPAREN)
if (mergedTypeParamDecls.isNotEmpty() && declKind != SymbolKind.MEMBER) {
if (declKind != SymbolKind.MEMBER) {
currentGenericFunctionDecls()[name] = GenericFunctionDecl(mergedTypeParamDecls, argsDeclaration.params, nameStartPos)
}
@ -8097,6 +8312,9 @@ class Compiler(
if (compileBytecode) {
delegateExpression = wrapFunctionBytecode(delegateExpression, "delegate@$name")
}
if (declKind != SymbolKind.MEMBER) {
currentGenericFunctionDecls().remove(name)
}
}
if (isDelegated && declKind != SymbolKind.MEMBER) {
val plan = slotPlanStack.lastOrNull()
@ -8441,6 +8659,12 @@ class Compiler(
parentIsClassBody = parentIsClassBody,
externCallSignature = externCallSignature,
annotation = annotation,
typeDecl = if (isDelegated) null else TypeDecl.Function(
receiver = receiverTypeDecl,
params = argsDeclaration.params.map { it.type },
returnType = inferredReturnDecl ?: TypeDecl.TypeAny,
nullable = false
),
fnBody = fnBody,
closureBox = closureBox,
captureSlots = captureSlots,

View File

@ -1,5 +1,5 @@
/*
* Copyright 2026 Sergey S. Chernov
* Copyright 2026 Sergey S. Chernov real.sergeych@gmail.com
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -12,19 +12,12 @@
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package net.sergeych.lyng
import net.sergeych.lyng.obj.Obj
import net.sergeych.lyng.obj.ObjClass
import net.sergeych.lyng.obj.ObjExternCallable
import net.sergeych.lyng.obj.ObjExtensionMethodCallable
import net.sergeych.lyng.obj.ObjInstance
import net.sergeych.lyng.obj.ObjRecord
import net.sergeych.lyng.obj.ObjString
import net.sergeych.lyng.obj.ObjUnset
import net.sergeych.lyng.obj.ObjVoid
import net.sergeych.lyng.obj.*
class FunctionClosureBox(
var closure: Scope? = null,
@ -50,6 +43,7 @@ data class FunctionDeclSpec(
val parentIsClassBody: Boolean,
val externCallSignature: CallSignature?,
val annotation: (suspend (Scope, ObjString, Statement) -> Statement)?,
val typeDecl: TypeDecl?,
val fnBody: Statement,
val closureBox: FunctionClosureBox,
val captureSlots: List<CaptureSlot>,
@ -73,7 +67,8 @@ internal suspend fun executeFunctionDecl(
false,
value,
spec.visibility,
callSignature = existing.callSignature
callSignature = existing.callSignature,
typeDecl = spec.typeDecl
)
return value
}
@ -182,12 +177,20 @@ internal suspend fun executeFunctionDecl(
isMutable = false,
visibility = spec.visibility,
declaringClass = null,
type = ObjRecord.Type.Fun
type = ObjRecord.Type.Fun,
typeDecl = spec.typeDecl
)
)
val wrapperName = spec.extensionWrapperName ?: extensionCallableName(typeName, spec.name)
val wrapper = ObjExtensionMethodCallable(spec.name, compiledFnBody)
scope.addItem(wrapperName, false, wrapper, spec.visibility, recordType = ObjRecord.Type.Fun)
scope.addItem(
wrapperName,
false,
wrapper,
spec.visibility,
recordType = ObjRecord.Type.Fun,
typeDecl = spec.typeDecl
)
} ?: run {
val th = scope.thisObj
if (!spec.isStatic && th is ObjClass) {
@ -203,10 +206,18 @@ internal suspend fun executeFunctionDecl(
isClosed = spec.isClosed,
isOverride = spec.isOverride,
type = ObjRecord.Type.Fun,
methodId = spec.memberMethodId
methodId = spec.memberMethodId,
typeDecl = spec.typeDecl
)
val memberValue = cls.members[spec.name]?.value ?: compiledFnBody
scope.addItem(spec.name, false, memberValue, spec.visibility, callSignature = spec.externCallSignature)
scope.addItem(
spec.name,
false,
memberValue,
spec.visibility,
callSignature = spec.externCallSignature,
typeDecl = spec.typeDecl
)
compiledFnBody
} else {
scope.addItem(
@ -215,7 +226,8 @@ internal suspend fun executeFunctionDecl(
compiledFnBody,
spec.visibility,
recordType = ObjRecord.Type.Fun,
callSignature = spec.externCallSignature
callSignature = spec.externCallSignature,
typeDecl = spec.typeDecl
)
}
}

View File

@ -241,4 +241,136 @@ class StdlibTest {
assertThrows(IllegalArgumentException) { Random.next(..) }
""".trimIndent())
}
@Test
fun testInference2() = runTest {
eval(
$$"""
val a = 10
val b = 3.0
val c = floor(a / b)
//assert(c is Real)
c.toInt()
""".trimIndent()
)
}
@Test
fun testStdlibGlobalFunctionInference() = runTest {
eval(
$$"""
val absInt = abs(-5)
assert(absInt is Int)
absInt.toInt()
val absReal = abs(-5.5)
assert(absReal is Real)
absReal.isNaN()
val floorInt = floor(7)
assert(floorInt is Int)
floorInt.toInt()
val floorReal = floor(7.9)
assert(floorReal is Real)
floorReal.toInt()
val ceilInt = ceil(7)
assert(ceilInt is Int)
ceilInt.toInt()
val ceilReal = ceil(7.1)
assert(ceilReal is Real)
ceilReal.toInt()
val roundInt = round(7)
assert(roundInt is Int)
roundInt.toInt()
val roundReal = round(7.4)
assert(roundReal is Real)
roundReal.toInt()
val sinValue = sin(1)
sinValue.isInfinite()
assert(sinValue is Real)
val cosValue = cos(1)
cosValue.isNaN()
assert(cosValue is Real)
val tanValue = tan(1)
tanValue.toInt()
assert(tanValue is Real)
val asinValue = asin(0.5)
asinValue.toInt()
assert(asinValue is Real)
val acosValue = acos(0.5)
acosValue.toInt()
assert(acosValue is Real)
val atanValue = atan(1)
atanValue.toInt()
assert(atanValue is Real)
val sinhValue = sinh(1)
sinhValue.isInfinite()
assert(sinhValue is Real)
val coshValue = cosh(1)
coshValue.isNaN()
assert(coshValue is Real)
val tanhValue = tanh(1)
tanhValue.toInt()
assert(tanhValue is Real)
val asinhValue = asinh(1)
asinhValue.toInt()
assert(asinhValue is Real)
val acoshValue = acosh(2)
acoshValue.toInt()
assert(acoshValue is Real)
val atanhValue = atanh(0.5)
atanhValue.toInt()
assert(atanhValue is Real)
val expValue = exp(1)
expValue.isInfinite()
assert(expValue is Real)
val lnValue = ln(2)
lnValue.isNaN()
assert(lnValue is Real)
val log10Value = log10(100)
log10Value.toInt()
assert(log10Value is Real)
val log2Value = log2(8)
log2Value.toInt()
assert(log2Value is Real)
val powValue = pow(2, 8)
powValue.isInfinite()
assert(powValue is Real)
val sqrtValue = sqrt(9)
sqrtValue.isNaN()
assert(sqrtValue is Real)
val clampedInt = clamp(20, 0..10)
assert(clampedInt is Int)
clampedInt.toInt()
val clampedReal = clamp(2.5, 0.0..10.0)
assert(clampedReal is Real)
clampedReal.toInt()
""".trimIndent()
)
}
}

View File

@ -86,6 +86,7 @@ class ComplexModuleTest {
""".trimIndent()
)
}
@Test
fun testDecimalInferences() = runTest {
eval(

View File

@ -121,9 +121,27 @@ extern class MapEntry<K,V> : Array<Object> {
// - the remaining decimal cases currently use a temporary bridge:
// `Decimal -> Real -> host math -> Decimal`
// - this is temporary and will be replaced with dedicated decimal implementations
extern fun abs(x: Object): Object
extern fun ln(x: Object): Object
extern fun pow(x: Object, y: Object): Object
extern fun abs<T>(x: T): T
extern fun floor<T>(x: T): T
extern fun ceil<T>(x: T): T
extern fun round<T>(x: T): T
extern fun sin(x: Object): Real
extern fun cos(x: Object): Real
extern fun tan(x: Object): Real
extern fun asin(x: Object): Real
extern fun acos(x: Object): Real
extern fun atan(x: Object): Real
extern fun sinh(x: Object): Real
extern fun cosh(x: Object): Real
extern fun tanh(x: Object): Real
extern fun asinh(x: Object): Real
extern fun acosh(x: Object): Real
extern fun atanh(x: Object): Real
extern fun exp(x: Object): Real
extern fun ln(x: Object): Real
extern fun log10(x: Object): Real
extern fun log2(x: Object): Real
extern fun pow(x: Object, y: Object): Real
extern fun sqrt(x: Object): Real
extern fun clamp<T>(value: T, range: Range<T>): T