Fix sqrt type inference and add regression coverage
This commit is contained in:
parent
c097464750
commit
3e338f3d53
@ -84,6 +84,7 @@ fun calculateDepth(
|
|||||||
|
|
||||||
h = hNew
|
h = hNew
|
||||||
iter++
|
iter++
|
||||||
|
println("iter: $iter: $h")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Не сошлось за maxIter
|
// Не сошлось за maxIter
|
||||||
@ -97,7 +98,7 @@ val d = 0.1 // м (10 см)
|
|||||||
|
|
||||||
val depth = calculateDepth(T, m, d)
|
val depth = calculateDepth(T, m, d)
|
||||||
if (depth != null) {
|
if (depth != null) {
|
||||||
println("Глубина: %.2f м".format(depth))
|
println("Глубина: %.2f м"(depth))
|
||||||
} else {
|
} else {
|
||||||
println("Расчёт не сошёлся")
|
println("Расчёт не сошёлся")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1296,23 +1296,11 @@ class Compiler(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (seedRecord != null) {
|
if (seedRecord != null) {
|
||||||
val value = seedRecord.value
|
seedImportTypeMetadata(name, seedRecord)
|
||||||
if (!nameObjClass.containsKey(name)) {
|
|
||||||
when (value) {
|
|
||||||
is ObjClass -> nameObjClass[name] = value
|
|
||||||
is ObjInstance -> nameObjClass[name] = value.objClass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ImportBindingResolution(ImportBinding(name, ImportBindingSource.Seed), seedRecord)
|
return ImportBindingResolution(ImportBinding(name, ImportBindingSource.Seed), seedRecord)
|
||||||
}
|
}
|
||||||
if (rootRecord != null) {
|
if (rootRecord != null) {
|
||||||
val value = rootRecord.value
|
seedImportTypeMetadata(name, rootRecord)
|
||||||
if (!nameObjClass.containsKey(name)) {
|
|
||||||
when (value) {
|
|
||||||
is ObjClass -> nameObjClass[name] = value
|
|
||||||
is ObjInstance -> nameObjClass[name] = value.objClass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ImportBindingResolution(ImportBinding(name, ImportBindingSource.Root), rootRecord)
|
return ImportBindingResolution(ImportBinding(name, ImportBindingSource.Root), rootRecord)
|
||||||
}
|
}
|
||||||
if (moduleMatches.isEmpty()) return null
|
if (moduleMatches.isEmpty()) return null
|
||||||
@ -1327,13 +1315,7 @@ class Compiler(
|
|||||||
val candidates = byOrigin[origin] ?: mutableListOf()
|
val candidates = byOrigin[origin] ?: mutableListOf()
|
||||||
val preferred = candidates.firstOrNull { it.first.scope.packageName == origin } ?: candidates.first()
|
val preferred = candidates.firstOrNull { it.first.scope.packageName == origin } ?: candidates.first()
|
||||||
val binding = ImportBinding(name, ImportBindingSource.Module(origin, preferred.first.pos))
|
val binding = ImportBinding(name, ImportBindingSource.Module(origin, preferred.first.pos))
|
||||||
val value = preferred.second.value
|
seedImportTypeMetadata(name, preferred.second)
|
||||||
if (!nameObjClass.containsKey(name)) {
|
|
||||||
when (value) {
|
|
||||||
is ObjClass -> nameObjClass[name] = value
|
|
||||||
is ObjInstance -> nameObjClass[name] = value.objClass
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ImportBindingResolution(binding, preferred.second)
|
return ImportBindingResolution(binding, preferred.second)
|
||||||
}
|
}
|
||||||
val moduleNames = moduleMatches.keys.toList()
|
val moduleNames = moduleMatches.keys.toList()
|
||||||
@ -1341,14 +1323,24 @@ class Compiler(
|
|||||||
}
|
}
|
||||||
val (module, record) = moduleMatches.values.first()
|
val (module, record) = moduleMatches.values.first()
|
||||||
val binding = ImportBinding(name, ImportBindingSource.Module(module.scope.packageName, module.pos))
|
val binding = ImportBinding(name, ImportBindingSource.Module(module.scope.packageName, module.pos))
|
||||||
val value = record.value
|
seedImportTypeMetadata(name, record)
|
||||||
|
return ImportBindingResolution(binding, record)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun seedImportTypeMetadata(name: String, record: ObjRecord) {
|
||||||
|
if (record.typeDecl != null && nameTypeDecl[name] == null) {
|
||||||
|
nameTypeDecl[name] = record.typeDecl
|
||||||
|
}
|
||||||
if (!nameObjClass.containsKey(name)) {
|
if (!nameObjClass.containsKey(name)) {
|
||||||
when (value) {
|
record.typeDecl?.let { resolveTypeDeclObjClass(it) }?.let {
|
||||||
|
nameObjClass[name] = it
|
||||||
|
return
|
||||||
|
}
|
||||||
|
when (val value = record.value) {
|
||||||
is ObjClass -> nameObjClass[name] = value
|
is ObjClass -> nameObjClass[name] = value
|
||||||
is ObjInstance -> nameObjClass[name] = value.objClass
|
is ObjInstance -> nameObjClass[name] = value.objClass
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ImportBindingResolution(binding, record)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun collectModuleRecordMatches(
|
private fun collectModuleRecordMatches(
|
||||||
@ -4336,14 +4328,19 @@ class Compiler(
|
|||||||
is MapLiteralRef -> inferMapLiteralTypeDecl(ref)
|
is MapLiteralRef -> inferMapLiteralTypeDecl(ref)
|
||||||
is ConstRef -> inferTypeDeclFromConst(ref.constValue)
|
is ConstRef -> inferTypeDeclFromConst(ref.constValue)
|
||||||
is CallRef -> {
|
is CallRef -> {
|
||||||
|
val targetDecl = resolveReceiverTypeDecl(ref.target)
|
||||||
val targetName = when (val target = ref.target) {
|
val targetName = when (val target = ref.target) {
|
||||||
is LocalVarRef -> target.name
|
is LocalVarRef -> target.name
|
||||||
is FastLocalVarRef -> target.name
|
is FastLocalVarRef -> target.name
|
||||||
is LocalSlotRef -> target.name
|
is LocalSlotRef -> target.name
|
||||||
else -> null
|
else -> null
|
||||||
}
|
}
|
||||||
|
if (targetDecl is TypeDecl.Function) {
|
||||||
|
return targetDecl.returnType
|
||||||
|
}
|
||||||
if (targetName != null) {
|
if (targetName != null) {
|
||||||
callableReturnTypeDeclByName[targetName]?.let { return it }
|
callableReturnTypeDeclByName[targetName]?.let { return it }
|
||||||
|
(seedTypeDeclByName(targetName) as? TypeDecl.Function)?.let { return it.returnType }
|
||||||
}
|
}
|
||||||
inferCallReturnClass(ref)?.let { TypeDecl.Simple(it.className, false) }
|
inferCallReturnClass(ref)?.let { TypeDecl.Simple(it.className, false) }
|
||||||
?: run {
|
?: run {
|
||||||
@ -4677,6 +4674,17 @@ class Compiler(
|
|||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun lookupLocalTypeDeclByName(name: String): TypeDecl? {
|
||||||
|
val slotLoc = lookupSlotLocation(name, includeModule = true) ?: return null
|
||||||
|
return slotTypeDeclByScopeId[slotLoc.scopeId]?.get(slotLoc.slot)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun lookupLocalObjClassByName(name: String): ObjClass? {
|
||||||
|
val slotLoc = lookupSlotLocation(name, includeModule = true) ?: return null
|
||||||
|
return slotTypeByScopeId[slotLoc.scopeId]?.get(slotLoc.slot)
|
||||||
|
?: slotTypeDeclByScopeId[slotLoc.scopeId]?.get(slotLoc.slot)?.let { resolveTypeDeclObjClass(it) }
|
||||||
|
}
|
||||||
|
|
||||||
private fun resolveReceiverTypeDecl(ref: ObjRef): TypeDecl? {
|
private fun resolveReceiverTypeDecl(ref: ObjRef): TypeDecl? {
|
||||||
return when (ref) {
|
return when (ref) {
|
||||||
is LocalSlotRef -> {
|
is LocalSlotRef -> {
|
||||||
@ -4686,8 +4694,12 @@ class Compiler(
|
|||||||
?: nameTypeDecl[ref.name]
|
?: nameTypeDecl[ref.name]
|
||||||
?: seedTypeDeclByName(ref.name)
|
?: seedTypeDeclByName(ref.name)
|
||||||
}
|
}
|
||||||
is LocalVarRef -> nameTypeDecl[ref.name] ?: seedTypeDeclByName(ref.name)
|
is LocalVarRef -> nameTypeDecl[ref.name]
|
||||||
is FastLocalVarRef -> nameTypeDecl[ref.name] ?: seedTypeDeclByName(ref.name)
|
?: lookupLocalTypeDeclByName(ref.name)
|
||||||
|
?: seedTypeDeclByName(ref.name)
|
||||||
|
is FastLocalVarRef -> nameTypeDecl[ref.name]
|
||||||
|
?: lookupLocalTypeDeclByName(ref.name)
|
||||||
|
?: seedTypeDeclByName(ref.name)
|
||||||
is FieldRef -> {
|
is FieldRef -> {
|
||||||
val targetDecl = resolveReceiverTypeDecl(ref.target) ?: return null
|
val targetDecl = resolveReceiverTypeDecl(ref.target) ?: return null
|
||||||
val targetClass = resolveTypeDeclObjClass(targetDecl) ?: resolveReceiverClassForMember(ref.target)
|
val targetClass = resolveTypeDeclObjClass(targetDecl) ?: resolveReceiverClassForMember(ref.target)
|
||||||
@ -4733,11 +4745,13 @@ class Compiler(
|
|||||||
is LocalVarRef -> nameObjClass[ref.name]
|
is LocalVarRef -> nameObjClass[ref.name]
|
||||||
?.takeIf { it == ObjDynamic.type }
|
?.takeIf { it == ObjDynamic.type }
|
||||||
?: nameObjClass[ref.name]
|
?: nameObjClass[ref.name]
|
||||||
|
?: lookupLocalObjClassByName(ref.name)
|
||||||
?: nameTypeDecl[ref.name]?.let { resolveTypeDeclObjClass(it) }
|
?: nameTypeDecl[ref.name]?.let { resolveTypeDeclObjClass(it) }
|
||||||
?: resolveClassByName(ref.name)
|
?: resolveClassByName(ref.name)
|
||||||
is FastLocalVarRef -> nameObjClass[ref.name]
|
is FastLocalVarRef -> nameObjClass[ref.name]
|
||||||
?.takeIf { it == ObjDynamic.type }
|
?.takeIf { it == ObjDynamic.type }
|
||||||
?: nameObjClass[ref.name]
|
?: nameObjClass[ref.name]
|
||||||
|
?: lookupLocalObjClassByName(ref.name)
|
||||||
?: nameTypeDecl[ref.name]?.let { resolveTypeDeclObjClass(it) }
|
?: nameTypeDecl[ref.name]?.let { resolveTypeDeclObjClass(it) }
|
||||||
?: resolveClassByName(ref.name)
|
?: resolveClassByName(ref.name)
|
||||||
is ClassScopeMemberRef -> {
|
is ClassScopeMemberRef -> {
|
||||||
|
|||||||
@ -395,9 +395,20 @@ class Script(
|
|||||||
requireExactCount(2)
|
requireExactCount(2)
|
||||||
decimalAwarePow(args[0], args[1])
|
decimalAwarePow(args[0], args[1])
|
||||||
}
|
}
|
||||||
addFn("sqrt") {
|
addItem(
|
||||||
decimalAwareUnaryMath(args.firstAndOnly(), fallback = ::sqrt)
|
"sqrt",
|
||||||
}
|
false,
|
||||||
|
ObjExternCallable.fromBridge {
|
||||||
|
decimalAwareUnaryMath(args.firstAndOnly(), fallback = ::sqrt)
|
||||||
|
},
|
||||||
|
recordType = ObjRecord.Type.Fun,
|
||||||
|
typeDecl = TypeDecl.Function(
|
||||||
|
receiver = null,
|
||||||
|
params = listOf(TypeDecl.TypeAny),
|
||||||
|
returnType = TypeDecl.Simple("Real", false),
|
||||||
|
nullable = false
|
||||||
|
)
|
||||||
|
)
|
||||||
addFn("abs") {
|
addFn("abs") {
|
||||||
val x = args.firstAndOnly()
|
val x = args.firstAndOnly()
|
||||||
if (x is ObjInt) ObjInt(x.value.absoluteValue)
|
if (x is ObjInt) ObjInt(x.value.absoluteValue)
|
||||||
|
|||||||
@ -0,0 +1,49 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2026 Sergey S. Chernov real.sergeych@gmail.com
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
import kotlinx.coroutines.test.runTest
|
||||||
|
import net.sergeych.lyng.eval
|
||||||
|
import net.sergeych.lyng.obj.ObjBool
|
||||||
|
import kotlin.test.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class LocalRealMemberInferenceRegressionTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun sqrtInitializedLocalVarKeepsRealReceiverTypeForMembers() = runTest {
|
||||||
|
val result = eval(
|
||||||
|
"""
|
||||||
|
fun probe(T: Real, c: Real, g: Real): Bool {
|
||||||
|
fun passthrough(h: Real): Real = h / c
|
||||||
|
|
||||||
|
val term = 1.0 + g * T / c
|
||||||
|
val sqrtTerm = sqrt(1.0 + 2.0 * g * T / c)
|
||||||
|
assert(sqrtTerm is Real)
|
||||||
|
assert(!sqrtTerm.isNaN())
|
||||||
|
var h = (c * c / g) * (term - sqrtTerm)
|
||||||
|
assert(passthrough(h) >= 0.0)
|
||||||
|
assert(h is Real)
|
||||||
|
!h.isNaN()
|
||||||
|
}
|
||||||
|
|
||||||
|
probe(6.0, 340.0, 9.81)
|
||||||
|
""".trimIndent()
|
||||||
|
)
|
||||||
|
|
||||||
|
assertEquals(ObjBool(true), result)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -5576,7 +5576,7 @@ class ScriptTest {
|
|||||||
|
|
||||||
val depth = calculateDepth(T, m, d)
|
val depth = calculateDepth(T, m, d)
|
||||||
if (depth != null) {
|
if (depth != null) {
|
||||||
println("Глубина: %.2f м".format(depth))
|
println("Глубина: %.2f м"(depth))
|
||||||
} else {
|
} else {
|
||||||
println("Расчёт не сошёлся")
|
println("Расчёт не сошёлся")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -87,7 +87,7 @@ extern class MapEntry<K,V> : Array<Object> {
|
|||||||
extern fun abs(x: Object): Object
|
extern fun abs(x: Object): Object
|
||||||
extern fun ln(x: Object): Object
|
extern fun ln(x: Object): Object
|
||||||
extern fun pow(x: Object, y: Object): Object
|
extern fun pow(x: Object, y: Object): Object
|
||||||
extern fun sqrt(x: Object): Object
|
extern fun sqrt(x: Object): Real
|
||||||
extern fun clamp<T>(value: T, range: Range<T>): T
|
extern fun clamp<T>(value: T, range: Range<T>): T
|
||||||
|
|
||||||
class SeededRandom {
|
class SeededRandom {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user