Type-aware += checks for declared collection members; preserve field type metadata

This commit is contained in:
Sergey Chernov 2026-03-12 19:41:42 +03:00
parent c9021eb9cf
commit 9417f8f0cc
8 changed files with 205 additions and 32 deletions

View File

@ -33,6 +33,7 @@ class ClassInstanceFieldDeclStatement(
val isMutable: Boolean,
val visibility: Visibility,
val writeVisibility: Visibility?,
val typeDecl: TypeDecl?,
val isAbstract: Boolean,
val isClosed: Boolean,
val isOverride: Boolean,

View File

@ -212,6 +212,9 @@ class Compiler(
scopeSeedNames.add(name)
if (record.typeDecl != null && nameTypeDecl[name] == null) {
nameTypeDecl[name] = record.typeDecl
if (nameObjClass[name] == null) {
resolveTypeDeclObjClass(record.typeDecl)?.let { nameObjClass[name] = it }
}
}
val instance = record.value as? ObjInstance
if (instance != null && nameObjClass[name] == null) {
@ -291,6 +294,9 @@ class Compiler(
scopeSeedNames.add(name)
if (record.typeDecl != null && nameTypeDecl[name] == null) {
nameTypeDecl[name] = record.typeDecl
if (nameObjClass[name] == null) {
resolveTypeDeclObjClass(record.typeDecl)?.let { nameObjClass[name] = it }
}
}
if (record.typeDecl != null) {
slotTypeDeclByScopeId.getOrPut(plan.id) { mutableMapOf() }[slotIndex] = record.typeDecl
@ -1208,6 +1214,11 @@ class Compiler(
for ((name, record) in current.objects) {
if (!record.visibility.isPublic) continue
if (nameObjClass.containsKey(name)) continue
val declaredClass = record.typeDecl?.let { resolveTypeDeclObjClass(it) }
if (declaredClass != null) {
nameObjClass[name] = declaredClass
continue
}
val resolved = when (val raw = record.value) {
is FrameSlotRef -> raw.peekValue() ?: raw.read()
is RecordSlotRef -> raw.peekValue() ?: raw.read()
@ -2519,6 +2530,9 @@ class Compiler(
} else {
val rvalue = parseExpressionLevel(level + 1)
?: throw ScriptError(opToken.pos, "Expecting expression")
if (opToken.type == Token.Type.PLUSASSIGN) {
checkCollectionPlusAssignTypes(lvalue!!, rvalue, opToken.pos)
}
op.generate(opToken.pos, lvalue!!, rvalue)
}
if (opToken.type == Token.Type.ASSIGN) {
@ -4197,6 +4211,22 @@ class Compiler(
is ListLiteralRef -> inferListLiteralTypeDecl(ref)
is MapLiteralRef -> inferMapLiteralTypeDecl(ref)
is ConstRef -> inferTypeDeclFromConst(ref.constValue)
is CallRef -> {
inferCallReturnClass(ref)?.let { TypeDecl.Simple(it.className, false) }
?: run {
val targetName = when (val target = ref.target) {
is LocalVarRef -> target.name
is FastLocalVarRef -> target.name
is LocalSlotRef -> target.name
else -> null
}
if (targetName != null && targetName.firstOrNull()?.isUpperCase() == true) {
TypeDecl.Simple(targetName, false)
} else {
null
}
}
}
else -> null
}
}
@ -4351,6 +4381,66 @@ class Compiler(
return TypeDecl.TypeAny
}
private fun inferCollectionElementType(typeDecl: TypeDecl): TypeDecl? {
val generic = typeDecl as? TypeDecl.Generic ?: return null
val base = generic.name.substringAfterLast('.')
return when (base) {
"Set", "List", "Iterable", "Collection", "Array" -> generic.args.firstOrNull()
else -> null
}
}
private fun typeDeclSubtypeOf(arg: TypeDecl, param: TypeDecl): Boolean {
if (param == TypeDecl.TypeAny || param == TypeDecl.TypeNullableAny) return true
val (argBase, argNullable) = stripNullable(arg)
val (paramBase, paramNullable) = stripNullable(param)
if (argNullable && !paramNullable) return false
if (paramBase == TypeDecl.TypeAny) return true
if (paramBase is TypeDecl.TypeVar) return true
if (argBase is TypeDecl.TypeVar) return true
if (paramBase is TypeDecl.Simple && (paramBase.name == "Object" || paramBase.name == "Obj")) return true
if (argBase is TypeDecl.Ellipsis) return typeDeclSubtypeOf(argBase.elementType, paramBase)
if (paramBase is TypeDecl.Ellipsis) return typeDeclSubtypeOf(argBase, paramBase.elementType)
return when (argBase) {
is TypeDecl.Union -> argBase.options.all { typeDeclSubtypeOf(it, paramBase) }
is TypeDecl.Intersection -> argBase.options.any { typeDeclSubtypeOf(it, paramBase) }
else -> when (paramBase) {
is TypeDecl.Union -> paramBase.options.any { typeDeclSubtypeOf(argBase, it) }
is TypeDecl.Intersection -> paramBase.options.all { typeDeclSubtypeOf(argBase, it) }
else -> {
val argClass = resolveTypeDeclObjClass(argBase) ?: return false
val paramClass = resolveTypeDeclObjClass(paramBase) ?: return false
argClass == paramClass || argClass.allParentsSet.contains(paramClass)
}
}
}
}
private fun checkCollectionPlusAssignTypes(targetRef: ObjRef, valueRef: ObjRef, pos: Pos) {
// Enforce strict compile-time element checks for declared members.
// Local vars can be inferred from literals and are allowed to widen dynamically.
if (targetRef !is FieldRef) return
val targetDeclRaw = resolveReceiverTypeDecl(targetRef) ?: return
val targetDecl = expandTypeAliases(targetDeclRaw, pos)
val targetGeneric = targetDecl as? TypeDecl.Generic ?: return
val targetBase = targetGeneric.name.substringAfterLast('.')
if (targetBase != "Set" && targetBase != "List") return
val elementRaw = targetGeneric.args.firstOrNull() ?: return
val elementDecl = expandTypeAliases(elementRaw, pos)
val valueDeclRaw = inferTypeDeclFromRef(valueRef) ?: return
val valueDecl = expandTypeAliases(valueDeclRaw, pos)
if (typeDeclSubtypeOf(valueDecl, elementDecl)) return
val sourceElementDecl = inferCollectionElementType(valueDecl)?.let { expandTypeAliases(it, pos) }
if (sourceElementDecl != null && typeDeclSubtypeOf(sourceElementDecl, elementDecl)) return
throw ScriptError(
pos,
"argument type ${typeDeclName(valueDecl)} does not match ${typeDeclName(elementDecl)} for '+='"
)
}
private fun stripNullable(type: TypeDecl): Pair<TypeDecl, Boolean> {
if (type is TypeDecl.TypeNullableAny) return TypeDecl.TypeAny to true
val nullable = type.isNullable
@ -4425,7 +4515,7 @@ class Compiler(
is FastLocalVarRef -> nameTypeDecl[ref.name] ?: seedTypeDeclByName(ref.name)
is FieldRef -> {
val targetDecl = resolveReceiverTypeDecl(ref.target) ?: return null
val targetClass = resolveTypeDeclObjClass(targetDecl)
val targetClass = resolveTypeDeclObjClass(targetDecl) ?: resolveReceiverClassForMember(ref.target)
targetClass?.getInstanceMemberOrNull(ref.name, includeAbstract = true)?.typeDecl?.let { return it }
classFieldTypesByName[targetClass?.className]?.get(ref.name)
?.let { return TypeDecl.Simple(it.className, false) }
@ -9039,6 +9129,7 @@ class Compiler(
isMutable = isMutable,
visibility = visibility,
writeVisibility = setterVisibility,
typeDecl = if (varTypeDecl == TypeDecl.TypeAny || varTypeDecl == TypeDecl.TypeNullableAny) null else varTypeDecl,
isAbstract = isAbstract,
isClosed = isClosed,
isOverride = isOverride,

View File

@ -5174,6 +5174,7 @@ class BytecodeCompiler(
isMutable = stmt.isMutable,
visibility = stmt.visibility,
writeVisibility = stmt.writeVisibility,
typeDecl = stmt.typeDecl,
isTransient = stmt.isTransient,
isAbstract = stmt.isAbstract,
isClosed = stmt.isClosed,
@ -6998,7 +6999,9 @@ class BytecodeCompiler(
val slot = resolveSlot(ref)
val fromSlot = slot?.let { slotObjClass[it] }
fromSlot
?: slot?.let { typeDeclForSlot(it) }?.let { resolveClassFromTypeDecl(it) }
?: slotTypeByScopeId[ownerScopeId]?.get(ownerSlot)
?: slotTypeDeclByScopeId[ownerScopeId]?.get(ownerSlot)?.let { resolveClassFromTypeDecl(it) }
?: nameObjClass[ref.name]
?: resolveTypeNameClass(ref.name)
?: slotInitClassByKey[ScopeSlotKey(ownerScopeId, ownerSlot)]
@ -7016,9 +7019,14 @@ class BytecodeCompiler(
}
val fromSlot = resolveDirectNameSlot(ref.name)?.let { slotObjClass[it.slot] }
if (fromSlot != null) return fromSlot
val fromDirectTypeDecl = resolveDirectNameSlot(ref.name)
?.let { typeDeclForSlot(it.slot) }
?.let { resolveClassFromTypeDecl(it) }
if (fromDirectTypeDecl != null) return fromDirectTypeDecl
val key = localSlotInfoMap.entries.firstOrNull { it.value.name == ref.name }?.key
key?.let {
slotTypeByScopeId[it.scopeId]?.get(it.slot)
?: slotTypeDeclByScopeId[it.scopeId]?.get(it.slot)?.let { decl -> resolveClassFromTypeDecl(decl) }
?: slotInitClassByKey[it]
} ?: nameObjClass[ref.name]
?: resolveTypeNameClass(ref.name)
@ -7029,9 +7037,14 @@ class BytecodeCompiler(
}
val fromSlot = resolveDirectNameSlot(ref.name)?.let { slotObjClass[it.slot] }
if (fromSlot != null) return fromSlot
val fromDirectTypeDecl = resolveDirectNameSlot(ref.name)
?.let { typeDeclForSlot(it.slot) }
?.let { resolveClassFromTypeDecl(it) }
if (fromDirectTypeDecl != null) return fromDirectTypeDecl
val key = localSlotInfoMap.entries.firstOrNull { it.value.name == ref.name }?.key
key?.let {
slotTypeByScopeId[it.scopeId]?.get(it.slot)
?: slotTypeDeclByScopeId[it.scopeId]?.get(it.slot)?.let { decl -> resolveClassFromTypeDecl(decl) }
?: slotInitClassByKey[it]
} ?: nameObjClass[ref.name]
?: resolveTypeNameClass(ref.name)
@ -7073,6 +7086,23 @@ class BytecodeCompiler(
}
}
private fun resolveClassFromTypeDecl(typeDecl: TypeDecl): ObjClass? {
return when (typeDecl) {
is TypeDecl.Simple -> {
resolveTypeNameClass(typeDecl.name) ?: nameObjClass[typeDecl.name]?.let { cls ->
if (cls == ObjClassType) ObjDynamic.type else cls
}
}
is TypeDecl.Generic -> {
resolveTypeNameClass(typeDecl.name) ?: nameObjClass[typeDecl.name]?.let { cls ->
if (cls == ObjClassType) ObjDynamic.type else cls
}
}
is TypeDecl.Ellipsis -> resolveClassFromTypeDecl(typeDecl.elementType)
else -> null
}
}
private fun isKnownClassReceiver(ref: ObjRef): Boolean {
return when (ref) {
is LocalVarRef -> {

View File

@ -100,6 +100,7 @@ sealed class BytecodeConst {
val isMutable: Boolean,
val visibility: Visibility,
val writeVisibility: Visibility?,
val typeDecl: TypeDecl?,
val isTransient: Boolean,
val isAbstract: Boolean,
val isClosed: Boolean,

View File

@ -348,6 +348,7 @@ class BytecodeStatement private constructor(
stmt.isMutable,
stmt.visibility,
stmt.writeVisibility,
stmt.typeDecl,
stmt.isAbstract,
stmt.isClosed,
stmt.isOverride,

View File

@ -2750,6 +2750,7 @@ class CmdDeclClassInstanceField(internal val constId: Int, internal val slot: In
isClosed = decl.isClosed,
isOverride = decl.isOverride,
isTransient = decl.isTransient,
typeDecl = decl.typeDecl,
type = ObjRecord.Type.Field,
fieldId = decl.fieldId
)

View File

@ -826,6 +826,7 @@ open class ObjClass(
type: ObjRecord.Type = ObjRecord.Type.Field,
fieldId: Int? = null,
methodId: Int? = null,
typeDecl: net.sergeych.lyng.TypeDecl? = null,
): ObjRecord {
// Validation of override rules: only for non-system declarations
var existing: ObjRecord? = null
@ -921,6 +922,7 @@ open class ObjClass(
isOverride = isOverride,
isTransient = isTransient,
type = type,
typeDecl = typeDecl,
memberName = name,
fieldId = effectiveFieldId,
methodId = effectiveMethodId

View File

@ -20,6 +20,7 @@ import net.sergeych.lyng.Script
import net.sergeych.lyng.ScriptError
import net.sergeych.lyng.eval
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue
@ -433,35 +434,80 @@ class TypesTest {
""".trimIndent())
}
// @Test
// fun testAliasesInGenerics1() = runTest {
// val scope = Script.newScope()
// scope.eval("""
// type IntList<T: Int> = List<T>
// type IntMap<K,V> = Map<K,V>
// type IntSet<T: Int> = Set<T>
// type IntPair<T: Int> = Pair<T,T>
// type IntTriple<T: Int> = Triple<T,T,T>
// type IntQuad<T: Int> = Quad<T,T,T,T>
//
// import lyng.buffer
// type Tag = String | Buffer
//
// class X {
// var tags: Set<Tag> = Set()
// }
// val x = X()
// x.tags += "tag1"
// assertEquals(Set("tag1"), x.tags)
// x.tags += "tag2"
// assertEquals(Set("tag1", "tag2"), x.tags)
// x.tags += Buffer("tag3")
// assertEquals(Set("tag1", "tag2", Buffer("tag3")), x.tags)
// x.tags += Buffer("tag4")
// assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags)
// x.tags += "tag3"
// x.tags += "tag4"
// assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags)
// """)
// }
@Test
fun testAliasesInGenerics1() = runTest {
val scope = Script.newScope()
scope.eval("""
type IntList<T: Int> = List<T>
type IntMap<K,V> = Map<K,V>
type IntSet<T: Int> = Set<T>
type IntPair<T: Int> = Pair<T,T>
type IntTriple<T: Int> = Triple<T,T,T>
type IntQuad<T: Int> = Quad<T,T,T,T>
import lyng.buffer
type Tag = String | Buffer
class X {
var tags: Set<Tag> = Set()
}
val x = X()
x.tags += "tag1"
assertEquals(Set("tag1"), x.tags)
x.tags += "tag2"
assertEquals(Set("tag1", "tag2"), x.tags)
x.tags += Buffer("tag3")
assertEquals(Set("tag1", "tag2", Buffer("tag3")), x.tags)
x.tags += Buffer("tag4")
assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags)
""")
scope.eval("""
assert(x is X)
x.tags += "42"
assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4"), "42"), x.tags)
""".trimIndent())
// now this must fail becaise element type does not match the declared:
assertFailsWith<ScriptError> {
scope.eval(
"""
x.tags += 42
""".trimIndent()
)
}
}
@Test
fun testAliasesInGenericsList1() = runTest {
val scope = Script.newScope()
scope.eval("""
import lyng.buffer
type Tag = String | Buffer
class X {
var tags: List<Tag> = List()
}
val x = X()
x.tags += "tag1"
assertEquals(List("tag1"), x.tags)
x.tags += "tag2"
assertEquals(List("tag1", "tag2"), x.tags)
x.tags += Buffer("tag3")
assertEquals(List("tag1", "tag2", Buffer("tag3")), x.tags)
x.tags += ["tag4", Buffer("tag5")]
assertEquals(List("tag1", "tag2", Buffer("tag3"), "tag4", Buffer("tag5")), x.tags)
""")
scope.eval("""
assert(x is X)
x.tags += "42"
assertEquals(List("tag1", "tag2", Buffer("tag3"), "tag4", Buffer("tag5"), "42"), x.tags)
""".trimIndent())
assertFailsWith<ScriptError> {
scope.eval(
"""
x.tags += 42
""".trimIndent()
)
}
}
}