From 3e338f3d539b7dd28a502017b9d4df5b05e3630d Mon Sep 17 00:00:00 2001 From: sergeych Date: Tue, 31 Mar 2026 22:26:58 +0300 Subject: [PATCH] Fix sqrt type inference and add regression coverage --- examples/free_fall.lyng | 3 +- .../kotlin/net/sergeych/lyng/Compiler.kt | 66 +++++++++++-------- .../kotlin/net/sergeych/lyng/Script.kt | 17 ++++- .../LocalRealMemberInferenceRegressionTest.kt | 49 ++++++++++++++ lynglib/src/commonTest/kotlin/ScriptTest.kt | 2 +- lynglib/stdlib/lyng/root.lyng | 2 +- 6 files changed, 107 insertions(+), 32 deletions(-) create mode 100644 lynglib/src/commonTest/kotlin/LocalRealMemberInferenceRegressionTest.kt diff --git a/examples/free_fall.lyng b/examples/free_fall.lyng index 145e27d..a487d4f 100644 --- a/examples/free_fall.lyng +++ b/examples/free_fall.lyng @@ -84,6 +84,7 @@ fun calculateDepth( h = hNew iter++ + println("iter: $iter: $h") } // Не сошлось за maxIter @@ -97,7 +98,7 @@ val d = 0.1 // м (10 см) val depth = calculateDepth(T, m, d) if (depth != null) { - println("Глубина: %.2f м".format(depth)) + println("Глубина: %.2f м"(depth)) } else { println("Расчёт не сошёлся") } diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt index e5a113b..68e980c 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt @@ -1296,23 +1296,11 @@ class Compiler( } } if (seedRecord != null) { - val value = seedRecord.value - if (!nameObjClass.containsKey(name)) { - when (value) { - is ObjClass -> nameObjClass[name] = value - is ObjInstance -> nameObjClass[name] = value.objClass - } - } + seedImportTypeMetadata(name, seedRecord) return ImportBindingResolution(ImportBinding(name, ImportBindingSource.Seed), seedRecord) } if (rootRecord != null) { - val value = rootRecord.value - if (!nameObjClass.containsKey(name)) { - when (value) { - is ObjClass -> nameObjClass[name] = value - is ObjInstance -> nameObjClass[name] = value.objClass - } - } + seedImportTypeMetadata(name, rootRecord) return ImportBindingResolution(ImportBinding(name, ImportBindingSource.Root), rootRecord) } if (moduleMatches.isEmpty()) return null @@ -1327,13 +1315,7 @@ class Compiler( val candidates = byOrigin[origin] ?: mutableListOf() val preferred = candidates.firstOrNull { it.first.scope.packageName == origin } ?: candidates.first() val binding = ImportBinding(name, ImportBindingSource.Module(origin, preferred.first.pos)) - val value = preferred.second.value - if (!nameObjClass.containsKey(name)) { - when (value) { - is ObjClass -> nameObjClass[name] = value - is ObjInstance -> nameObjClass[name] = value.objClass - } - } + seedImportTypeMetadata(name, preferred.second) return ImportBindingResolution(binding, preferred.second) } val moduleNames = moduleMatches.keys.toList() @@ -1341,14 +1323,24 @@ class Compiler( } val (module, record) = moduleMatches.values.first() val binding = ImportBinding(name, ImportBindingSource.Module(module.scope.packageName, module.pos)) - val value = record.value + seedImportTypeMetadata(name, record) + return ImportBindingResolution(binding, record) + } + + private fun seedImportTypeMetadata(name: String, record: ObjRecord) { + if (record.typeDecl != null && nameTypeDecl[name] == null) { + nameTypeDecl[name] = record.typeDecl + } if (!nameObjClass.containsKey(name)) { - when (value) { + record.typeDecl?.let { resolveTypeDeclObjClass(it) }?.let { + nameObjClass[name] = it + return + } + when (val value = record.value) { is ObjClass -> nameObjClass[name] = value is ObjInstance -> nameObjClass[name] = value.objClass } } - return ImportBindingResolution(binding, record) } private fun collectModuleRecordMatches( @@ -4336,14 +4328,19 @@ class Compiler( is MapLiteralRef -> inferMapLiteralTypeDecl(ref) is ConstRef -> inferTypeDeclFromConst(ref.constValue) is CallRef -> { + val targetDecl = resolveReceiverTypeDecl(ref.target) val targetName = when (val target = ref.target) { is LocalVarRef -> target.name is FastLocalVarRef -> target.name is LocalSlotRef -> target.name else -> null } + if (targetDecl is TypeDecl.Function) { + return targetDecl.returnType + } if (targetName != null) { callableReturnTypeDeclByName[targetName]?.let { return it } + (seedTypeDeclByName(targetName) as? TypeDecl.Function)?.let { return it.returnType } } inferCallReturnClass(ref)?.let { TypeDecl.Simple(it.className, false) } ?: run { @@ -4677,6 +4674,17 @@ class Compiler( return null } + private fun lookupLocalTypeDeclByName(name: String): TypeDecl? { + val slotLoc = lookupSlotLocation(name, includeModule = true) ?: return null + return slotTypeDeclByScopeId[slotLoc.scopeId]?.get(slotLoc.slot) + } + + private fun lookupLocalObjClassByName(name: String): ObjClass? { + val slotLoc = lookupSlotLocation(name, includeModule = true) ?: return null + return slotTypeByScopeId[slotLoc.scopeId]?.get(slotLoc.slot) + ?: slotTypeDeclByScopeId[slotLoc.scopeId]?.get(slotLoc.slot)?.let { resolveTypeDeclObjClass(it) } + } + private fun resolveReceiverTypeDecl(ref: ObjRef): TypeDecl? { return when (ref) { is LocalSlotRef -> { @@ -4686,8 +4694,12 @@ class Compiler( ?: nameTypeDecl[ref.name] ?: seedTypeDeclByName(ref.name) } - is LocalVarRef -> nameTypeDecl[ref.name] ?: seedTypeDeclByName(ref.name) - is FastLocalVarRef -> nameTypeDecl[ref.name] ?: seedTypeDeclByName(ref.name) + is LocalVarRef -> nameTypeDecl[ref.name] + ?: lookupLocalTypeDeclByName(ref.name) + ?: seedTypeDeclByName(ref.name) + is FastLocalVarRef -> nameTypeDecl[ref.name] + ?: lookupLocalTypeDeclByName(ref.name) + ?: seedTypeDeclByName(ref.name) is FieldRef -> { val targetDecl = resolveReceiverTypeDecl(ref.target) ?: return null val targetClass = resolveTypeDeclObjClass(targetDecl) ?: resolveReceiverClassForMember(ref.target) @@ -4733,11 +4745,13 @@ class Compiler( is LocalVarRef -> nameObjClass[ref.name] ?.takeIf { it == ObjDynamic.type } ?: nameObjClass[ref.name] + ?: lookupLocalObjClassByName(ref.name) ?: nameTypeDecl[ref.name]?.let { resolveTypeDeclObjClass(it) } ?: resolveClassByName(ref.name) is FastLocalVarRef -> nameObjClass[ref.name] ?.takeIf { it == ObjDynamic.type } ?: nameObjClass[ref.name] + ?: lookupLocalObjClassByName(ref.name) ?: nameTypeDecl[ref.name]?.let { resolveTypeDeclObjClass(it) } ?: resolveClassByName(ref.name) is ClassScopeMemberRef -> { diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Script.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Script.kt index 1648184..a666b2f 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Script.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Script.kt @@ -395,9 +395,20 @@ class Script( requireExactCount(2) decimalAwarePow(args[0], args[1]) } - addFn("sqrt") { - decimalAwareUnaryMath(args.firstAndOnly(), fallback = ::sqrt) - } + addItem( + "sqrt", + false, + ObjExternCallable.fromBridge { + decimalAwareUnaryMath(args.firstAndOnly(), fallback = ::sqrt) + }, + recordType = ObjRecord.Type.Fun, + typeDecl = TypeDecl.Function( + receiver = null, + params = listOf(TypeDecl.TypeAny), + returnType = TypeDecl.Simple("Real", false), + nullable = false + ) + ) addFn("abs") { val x = args.firstAndOnly() if (x is ObjInt) ObjInt(x.value.absoluteValue) diff --git a/lynglib/src/commonTest/kotlin/LocalRealMemberInferenceRegressionTest.kt b/lynglib/src/commonTest/kotlin/LocalRealMemberInferenceRegressionTest.kt new file mode 100644 index 0000000..d193f12 --- /dev/null +++ b/lynglib/src/commonTest/kotlin/LocalRealMemberInferenceRegressionTest.kt @@ -0,0 +1,49 @@ +/* + * Copyright 2026 Sergey S. Chernov real.sergeych@gmail.com + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +import kotlinx.coroutines.test.runTest +import net.sergeych.lyng.eval +import net.sergeych.lyng.obj.ObjBool +import kotlin.test.Test +import kotlin.test.assertEquals + +class LocalRealMemberInferenceRegressionTest { + + @Test + fun sqrtInitializedLocalVarKeepsRealReceiverTypeForMembers() = runTest { + val result = eval( + """ + fun probe(T: Real, c: Real, g: Real): Bool { + fun passthrough(h: Real): Real = h / c + + val term = 1.0 + g * T / c + val sqrtTerm = sqrt(1.0 + 2.0 * g * T / c) + assert(sqrtTerm is Real) + assert(!sqrtTerm.isNaN()) + var h = (c * c / g) * (term - sqrtTerm) + assert(passthrough(h) >= 0.0) + assert(h is Real) + !h.isNaN() + } + + probe(6.0, 340.0, 9.81) + """.trimIndent() + ) + + assertEquals(ObjBool(true), result) + } +} diff --git a/lynglib/src/commonTest/kotlin/ScriptTest.kt b/lynglib/src/commonTest/kotlin/ScriptTest.kt index d5fd142..86aaf68 100644 --- a/lynglib/src/commonTest/kotlin/ScriptTest.kt +++ b/lynglib/src/commonTest/kotlin/ScriptTest.kt @@ -5576,7 +5576,7 @@ class ScriptTest { val depth = calculateDepth(T, m, d) if (depth != null) { - println("Глубина: %.2f м".format(depth)) + println("Глубина: %.2f м"(depth)) } else { println("Расчёт не сошёлся") } diff --git a/lynglib/stdlib/lyng/root.lyng b/lynglib/stdlib/lyng/root.lyng index bd812a4..f92320c 100644 --- a/lynglib/stdlib/lyng/root.lyng +++ b/lynglib/stdlib/lyng/root.lyng @@ -87,7 +87,7 @@ extern class MapEntry : Array { extern fun abs(x: Object): Object extern fun ln(x: Object): Object extern fun pow(x: Object, y: Object): Object -extern fun sqrt(x: Object): Object +extern fun sqrt(x: Object): Real extern fun clamp(value: T, range: Range): T class SeededRandom {