diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/ClassInstanceDeclStatements.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/ClassInstanceDeclStatements.kt index fac98d0..96c590f 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/ClassInstanceDeclStatements.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/ClassInstanceDeclStatements.kt @@ -33,6 +33,7 @@ class ClassInstanceFieldDeclStatement( val isMutable: Boolean, val visibility: Visibility, val writeVisibility: Visibility?, + val typeDecl: TypeDecl?, val isAbstract: Boolean, val isClosed: Boolean, val isOverride: Boolean, diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt index d6cf578..0fcb75c 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt @@ -212,6 +212,9 @@ class Compiler( scopeSeedNames.add(name) if (record.typeDecl != null && nameTypeDecl[name] == null) { nameTypeDecl[name] = record.typeDecl + if (nameObjClass[name] == null) { + resolveTypeDeclObjClass(record.typeDecl)?.let { nameObjClass[name] = it } + } } val instance = record.value as? ObjInstance if (instance != null && nameObjClass[name] == null) { @@ -291,6 +294,9 @@ class Compiler( scopeSeedNames.add(name) if (record.typeDecl != null && nameTypeDecl[name] == null) { nameTypeDecl[name] = record.typeDecl + if (nameObjClass[name] == null) { + resolveTypeDeclObjClass(record.typeDecl)?.let { nameObjClass[name] = it } + } } if (record.typeDecl != null) { slotTypeDeclByScopeId.getOrPut(plan.id) { mutableMapOf() }[slotIndex] = record.typeDecl @@ -1208,6 +1214,11 @@ class Compiler( for ((name, record) in current.objects) { if (!record.visibility.isPublic) continue if (nameObjClass.containsKey(name)) continue + val declaredClass = record.typeDecl?.let { resolveTypeDeclObjClass(it) } + if (declaredClass != null) { + nameObjClass[name] = declaredClass + continue + } val resolved = when (val raw = record.value) { is FrameSlotRef -> raw.peekValue() ?: raw.read() is RecordSlotRef -> raw.peekValue() ?: raw.read() @@ -2519,6 +2530,9 @@ class Compiler( } else { val rvalue = parseExpressionLevel(level + 1) ?: throw ScriptError(opToken.pos, "Expecting expression") + if (opToken.type == Token.Type.PLUSASSIGN) { + checkCollectionPlusAssignTypes(lvalue!!, rvalue, opToken.pos) + } op.generate(opToken.pos, lvalue!!, rvalue) } if (opToken.type == Token.Type.ASSIGN) { @@ -4197,6 +4211,22 @@ class Compiler( is ListLiteralRef -> inferListLiteralTypeDecl(ref) is MapLiteralRef -> inferMapLiteralTypeDecl(ref) is ConstRef -> inferTypeDeclFromConst(ref.constValue) + is CallRef -> { + inferCallReturnClass(ref)?.let { TypeDecl.Simple(it.className, false) } + ?: run { + val targetName = when (val target = ref.target) { + is LocalVarRef -> target.name + is FastLocalVarRef -> target.name + is LocalSlotRef -> target.name + else -> null + } + if (targetName != null && targetName.firstOrNull()?.isUpperCase() == true) { + TypeDecl.Simple(targetName, false) + } else { + null + } + } + } else -> null } } @@ -4351,6 +4381,66 @@ class Compiler( return TypeDecl.TypeAny } + private fun inferCollectionElementType(typeDecl: TypeDecl): TypeDecl? { + val generic = typeDecl as? TypeDecl.Generic ?: return null + val base = generic.name.substringAfterLast('.') + return when (base) { + "Set", "List", "Iterable", "Collection", "Array" -> generic.args.firstOrNull() + else -> null + } + } + + private fun typeDeclSubtypeOf(arg: TypeDecl, param: TypeDecl): Boolean { + if (param == TypeDecl.TypeAny || param == TypeDecl.TypeNullableAny) return true + val (argBase, argNullable) = stripNullable(arg) + val (paramBase, paramNullable) = stripNullable(param) + if (argNullable && !paramNullable) return false + if (paramBase == TypeDecl.TypeAny) return true + if (paramBase is TypeDecl.TypeVar) return true + if (argBase is TypeDecl.TypeVar) return true + if (paramBase is TypeDecl.Simple && (paramBase.name == "Object" || paramBase.name == "Obj")) return true + if (argBase is TypeDecl.Ellipsis) return typeDeclSubtypeOf(argBase.elementType, paramBase) + if (paramBase is TypeDecl.Ellipsis) return typeDeclSubtypeOf(argBase, paramBase.elementType) + return when (argBase) { + is TypeDecl.Union -> argBase.options.all { typeDeclSubtypeOf(it, paramBase) } + is TypeDecl.Intersection -> argBase.options.any { typeDeclSubtypeOf(it, paramBase) } + else -> when (paramBase) { + is TypeDecl.Union -> paramBase.options.any { typeDeclSubtypeOf(argBase, it) } + is TypeDecl.Intersection -> paramBase.options.all { typeDeclSubtypeOf(argBase, it) } + else -> { + val argClass = resolveTypeDeclObjClass(argBase) ?: return false + val paramClass = resolveTypeDeclObjClass(paramBase) ?: return false + argClass == paramClass || argClass.allParentsSet.contains(paramClass) + } + } + } + } + + private fun checkCollectionPlusAssignTypes(targetRef: ObjRef, valueRef: ObjRef, pos: Pos) { + // Enforce strict compile-time element checks for declared members. + // Local vars can be inferred from literals and are allowed to widen dynamically. + if (targetRef !is FieldRef) return + val targetDeclRaw = resolveReceiverTypeDecl(targetRef) ?: return + val targetDecl = expandTypeAliases(targetDeclRaw, pos) + val targetGeneric = targetDecl as? TypeDecl.Generic ?: return + val targetBase = targetGeneric.name.substringAfterLast('.') + if (targetBase != "Set" && targetBase != "List") return + val elementRaw = targetGeneric.args.firstOrNull() ?: return + val elementDecl = expandTypeAliases(elementRaw, pos) + val valueDeclRaw = inferTypeDeclFromRef(valueRef) ?: return + val valueDecl = expandTypeAliases(valueDeclRaw, pos) + + if (typeDeclSubtypeOf(valueDecl, elementDecl)) return + + val sourceElementDecl = inferCollectionElementType(valueDecl)?.let { expandTypeAliases(it, pos) } + if (sourceElementDecl != null && typeDeclSubtypeOf(sourceElementDecl, elementDecl)) return + + throw ScriptError( + pos, + "argument type ${typeDeclName(valueDecl)} does not match ${typeDeclName(elementDecl)} for '+='" + ) + } + private fun stripNullable(type: TypeDecl): Pair { if (type is TypeDecl.TypeNullableAny) return TypeDecl.TypeAny to true val nullable = type.isNullable @@ -4425,7 +4515,7 @@ class Compiler( is FastLocalVarRef -> nameTypeDecl[ref.name] ?: seedTypeDeclByName(ref.name) is FieldRef -> { val targetDecl = resolveReceiverTypeDecl(ref.target) ?: return null - val targetClass = resolveTypeDeclObjClass(targetDecl) + val targetClass = resolveTypeDeclObjClass(targetDecl) ?: resolveReceiverClassForMember(ref.target) targetClass?.getInstanceMemberOrNull(ref.name, includeAbstract = true)?.typeDecl?.let { return it } classFieldTypesByName[targetClass?.className]?.get(ref.name) ?.let { return TypeDecl.Simple(it.className, false) } @@ -9039,6 +9129,7 @@ class Compiler( isMutable = isMutable, visibility = visibility, writeVisibility = setterVisibility, + typeDecl = if (varTypeDecl == TypeDecl.TypeAny || varTypeDecl == TypeDecl.TypeNullableAny) null else varTypeDecl, isAbstract = isAbstract, isClosed = isClosed, isOverride = isOverride, diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt index 72a042b..d5134de 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt @@ -5174,6 +5174,7 @@ class BytecodeCompiler( isMutable = stmt.isMutable, visibility = stmt.visibility, writeVisibility = stmt.writeVisibility, + typeDecl = stmt.typeDecl, isTransient = stmt.isTransient, isAbstract = stmt.isAbstract, isClosed = stmt.isClosed, @@ -6998,7 +6999,9 @@ class BytecodeCompiler( val slot = resolveSlot(ref) val fromSlot = slot?.let { slotObjClass[it] } fromSlot + ?: slot?.let { typeDeclForSlot(it) }?.let { resolveClassFromTypeDecl(it) } ?: slotTypeByScopeId[ownerScopeId]?.get(ownerSlot) + ?: slotTypeDeclByScopeId[ownerScopeId]?.get(ownerSlot)?.let { resolveClassFromTypeDecl(it) } ?: nameObjClass[ref.name] ?: resolveTypeNameClass(ref.name) ?: slotInitClassByKey[ScopeSlotKey(ownerScopeId, ownerSlot)] @@ -7016,9 +7019,14 @@ class BytecodeCompiler( } val fromSlot = resolveDirectNameSlot(ref.name)?.let { slotObjClass[it.slot] } if (fromSlot != null) return fromSlot + val fromDirectTypeDecl = resolveDirectNameSlot(ref.name) + ?.let { typeDeclForSlot(it.slot) } + ?.let { resolveClassFromTypeDecl(it) } + if (fromDirectTypeDecl != null) return fromDirectTypeDecl val key = localSlotInfoMap.entries.firstOrNull { it.value.name == ref.name }?.key key?.let { slotTypeByScopeId[it.scopeId]?.get(it.slot) + ?: slotTypeDeclByScopeId[it.scopeId]?.get(it.slot)?.let { decl -> resolveClassFromTypeDecl(decl) } ?: slotInitClassByKey[it] } ?: nameObjClass[ref.name] ?: resolveTypeNameClass(ref.name) @@ -7029,9 +7037,14 @@ class BytecodeCompiler( } val fromSlot = resolveDirectNameSlot(ref.name)?.let { slotObjClass[it.slot] } if (fromSlot != null) return fromSlot + val fromDirectTypeDecl = resolveDirectNameSlot(ref.name) + ?.let { typeDeclForSlot(it.slot) } + ?.let { resolveClassFromTypeDecl(it) } + if (fromDirectTypeDecl != null) return fromDirectTypeDecl val key = localSlotInfoMap.entries.firstOrNull { it.value.name == ref.name }?.key key?.let { slotTypeByScopeId[it.scopeId]?.get(it.slot) + ?: slotTypeDeclByScopeId[it.scopeId]?.get(it.slot)?.let { decl -> resolveClassFromTypeDecl(decl) } ?: slotInitClassByKey[it] } ?: nameObjClass[ref.name] ?: resolveTypeNameClass(ref.name) @@ -7073,6 +7086,23 @@ class BytecodeCompiler( } } + private fun resolveClassFromTypeDecl(typeDecl: TypeDecl): ObjClass? { + return when (typeDecl) { + is TypeDecl.Simple -> { + resolveTypeNameClass(typeDecl.name) ?: nameObjClass[typeDecl.name]?.let { cls -> + if (cls == ObjClassType) ObjDynamic.type else cls + } + } + is TypeDecl.Generic -> { + resolveTypeNameClass(typeDecl.name) ?: nameObjClass[typeDecl.name]?.let { cls -> + if (cls == ObjClassType) ObjDynamic.type else cls + } + } + is TypeDecl.Ellipsis -> resolveClassFromTypeDecl(typeDecl.elementType) + else -> null + } + } + private fun isKnownClassReceiver(ref: ObjRef): Boolean { return when (ref) { is LocalVarRef -> { diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeConst.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeConst.kt index 90b8fce..4a546b0 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeConst.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeConst.kt @@ -100,6 +100,7 @@ sealed class BytecodeConst { val isMutable: Boolean, val visibility: Visibility, val writeVisibility: Visibility?, + val typeDecl: TypeDecl?, val isTransient: Boolean, val isAbstract: Boolean, val isClosed: Boolean, diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeStatement.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeStatement.kt index 241e5cc..982a879 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeStatement.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeStatement.kt @@ -348,6 +348,7 @@ class BytecodeStatement private constructor( stmt.isMutable, stmt.visibility, stmt.writeVisibility, + stmt.typeDecl, stmt.isAbstract, stmt.isClosed, stmt.isOverride, diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt index 8f071bb..d9a496c 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt @@ -2750,6 +2750,7 @@ class CmdDeclClassInstanceField(internal val constId: Int, internal val slot: In isClosed = decl.isClosed, isOverride = decl.isOverride, isTransient = decl.isTransient, + typeDecl = decl.typeDecl, type = ObjRecord.Type.Field, fieldId = decl.fieldId ) diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjClass.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjClass.kt index dee2282..a09a22d 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjClass.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjClass.kt @@ -826,6 +826,7 @@ open class ObjClass( type: ObjRecord.Type = ObjRecord.Type.Field, fieldId: Int? = null, methodId: Int? = null, + typeDecl: net.sergeych.lyng.TypeDecl? = null, ): ObjRecord { // Validation of override rules: only for non-system declarations var existing: ObjRecord? = null @@ -921,6 +922,7 @@ open class ObjClass( isOverride = isOverride, isTransient = isTransient, type = type, + typeDecl = typeDecl, memberName = name, fieldId = effectiveFieldId, methodId = effectiveMethodId diff --git a/lynglib/src/commonTest/kotlin/TypesTest.kt b/lynglib/src/commonTest/kotlin/TypesTest.kt index d0275ab..2318feb 100644 --- a/lynglib/src/commonTest/kotlin/TypesTest.kt +++ b/lynglib/src/commonTest/kotlin/TypesTest.kt @@ -20,6 +20,7 @@ import net.sergeych.lyng.Script import net.sergeych.lyng.ScriptError import net.sergeych.lyng.eval import kotlin.test.Test +import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertTrue @@ -433,35 +434,80 @@ class TypesTest { """.trimIndent()) } -// @Test -// fun testAliasesInGenerics1() = runTest { -// val scope = Script.newScope() -// scope.eval(""" -// type IntList = List -// type IntMap = Map -// type IntSet = Set -// type IntPair = Pair -// type IntTriple = Triple -// type IntQuad = Quad -// -// import lyng.buffer -// type Tag = String | Buffer -// -// class X { -// var tags: Set = Set() -// } -// val x = X() -// x.tags += "tag1" -// assertEquals(Set("tag1"), x.tags) -// x.tags += "tag2" -// assertEquals(Set("tag1", "tag2"), x.tags) -// x.tags += Buffer("tag3") -// assertEquals(Set("tag1", "tag2", Buffer("tag3")), x.tags) -// x.tags += Buffer("tag4") -// assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags) -// x.tags += "tag3" -// x.tags += "tag4" -// assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags) -// """) -// } + @Test + fun testAliasesInGenerics1() = runTest { + val scope = Script.newScope() + scope.eval(""" + type IntList = List + type IntMap = Map + type IntSet = Set + type IntPair = Pair + type IntTriple = Triple + type IntQuad = Quad + + import lyng.buffer + type Tag = String | Buffer + + class X { + var tags: Set = Set() + } + val x = X() + x.tags += "tag1" + assertEquals(Set("tag1"), x.tags) + x.tags += "tag2" + assertEquals(Set("tag1", "tag2"), x.tags) + x.tags += Buffer("tag3") + assertEquals(Set("tag1", "tag2", Buffer("tag3")), x.tags) + x.tags += Buffer("tag4") + assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4")), x.tags) + """) + scope.eval(""" + assert(x is X) + x.tags += "42" + assertEquals(Set("tag1", "tag2", Buffer("tag3"), Buffer("tag4"), "42"), x.tags) + + """.trimIndent()) + // now this must fail becaise element type does not match the declared: + assertFailsWith { + scope.eval( + """ + x.tags += 42 + """.trimIndent() + ) + } + } + + @Test + fun testAliasesInGenericsList1() = runTest { + val scope = Script.newScope() + scope.eval(""" + import lyng.buffer + type Tag = String | Buffer + + class X { + var tags: List = List() + } + val x = X() + x.tags += "tag1" + assertEquals(List("tag1"), x.tags) + x.tags += "tag2" + assertEquals(List("tag1", "tag2"), x.tags) + x.tags += Buffer("tag3") + assertEquals(List("tag1", "tag2", Buffer("tag3")), x.tags) + x.tags += ["tag4", Buffer("tag5")] + assertEquals(List("tag1", "tag2", Buffer("tag3"), "tag4", Buffer("tag5")), x.tags) + """) + scope.eval(""" + assert(x is X) + x.tags += "42" + assertEquals(List("tag1", "tag2", Buffer("tag3"), "tag4", Buffer("tag5"), "42"), x.tags) + """.trimIndent()) + assertFailsWith { + scope.eval( + """ + x.tags += 42 + """.trimIndent() + ) + } + } }