From 40de53f6884468ca50eeb598454a59b11adfded0 Mon Sep 17 00:00:00 2001 From: sergeych Date: Thu, 5 Feb 2026 20:14:09 +0300 Subject: [PATCH] Define type expression checks for unions --- AGENTS.md | 1 + .../kotlin/net/sergeych/lyng/Compiler.kt | 77 +++++--- .../net/sergeych/lyng/bytecode/CmdRuntime.kt | 28 ++- .../kotlin/net/sergeych/lyng/obj/ObjRef.kt | 20 +- .../net/sergeych/lyng/obj/ObjTypeExpr.kt | 175 +++++++++++++++++- lynglib/src/commonTest/kotlin/TypesTest.kt | 24 ++- notes/new_lyng_type_system_spec.md | 12 ++ 7 files changed, 300 insertions(+), 37 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index cdff595..2c375fb 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -11,6 +11,7 @@ - Nullability is Kotlin-style: `T` non-null, `T?` nullable, `!!` asserts non-null. - `void` is a singleton of class `Void` (syntax sugar for return type). - Object members are always allowed even on unknown types; non-Object members require explicit casts. Remove `inspect` from Object and use `toInspectString()` instead. +- Type expression checks: `x is T` is value instance check; `T1 is T2` is type-subset; `A in T` means `A` is subset of `T`; `==` is structural type equality. - Do not reintroduce bytecode fallback opcodes (e.g., `GET_NAME`, `EVAL_*`, `CALL_FALLBACK`) or runtime name-resolution fallbacks; all symbol resolution must stay compile-time only. ## Bytecode frame-first migration plan diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt index b73d9dc..b3f1e1e 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt @@ -3706,7 +3706,7 @@ class Compiler( typeParams: List ) { if (typeParams.isEmpty()) return - val inferred = mutableMapOf() + val inferred = mutableMapOf() for (param in argsDeclaration.params) { val rec = context.getLocalRecordDirect(param.name) ?: continue val value = rec.value @@ -3715,13 +3715,17 @@ class Compiler( } } for (tp in typeParams) { - val cls = inferred[tp.name] - ?: tp.defaultType?.let { resolveTypeDeclObjClass(it) } - ?: Obj.rootObjectType - context.addConst(tp.name, cls) + val inferredType = inferred[tp.name] ?: tp.defaultType ?: TypeDecl.TypeAny + val normalized = normalizeRuntimeTypeDecl(inferredType) + val cls = resolveTypeDeclObjClass(normalized) + if (cls != null && !normalized.isNullable && normalized !is TypeDecl.Union && normalized !is TypeDecl.Intersection) { + context.addConst(tp.name, cls) + } else { + context.addConst(tp.name, net.sergeych.lyng.obj.ObjTypeExpr(normalized)) + } val bound = tp.bound ?: continue - if (!typeParamBoundSatisfied(cls, bound)) { - context.raiseError("type argument ${cls.className} does not satisfy bound ${typeDeclName(bound)}") + if (!typeDeclSatisfiesBound(normalized, bound)) { + context.raiseError("type argument ${typeDeclName(normalized)} does not satisfy bound ${typeDeclName(bound)}") } } } @@ -3729,42 +3733,69 @@ class Compiler( private fun collectRuntimeTypeVarBindings( paramType: TypeDecl, value: Obj, - inferred: MutableMap + inferred: MutableMap ) { when (paramType) { is TypeDecl.TypeVar -> { if (value !== ObjNull) { - inferred[paramType.name] = value.objClass + inferred[paramType.name] = inferRuntimeTypeDecl(value) } } is TypeDecl.Generic -> { val base = paramType.name.substringAfterLast('.') val arg = paramType.args.firstOrNull() if (base == "List" && arg is TypeDecl.TypeVar && value is ObjList) { - val elementClass = inferListElementClass(value) - inferred[arg.name] = elementClass + val elementType = inferListElementTypeDecl(value) + inferred[arg.name] = elementType } } else -> {} } } - private fun inferListElementClass(list: ObjList): ObjClass { - var elemClass: ObjClass? = null + private fun inferRuntimeTypeDecl(value: Obj): TypeDecl { + return when (value) { + is ObjInt -> TypeDecl.Simple("Int", false) + is ObjReal -> TypeDecl.Simple("Real", false) + is ObjString -> TypeDecl.Simple("String", false) + is ObjBool -> TypeDecl.Simple("Bool", false) + is ObjChar -> TypeDecl.Simple("Char", false) + is ObjNull -> TypeDecl.TypeNullableAny + is ObjList -> TypeDecl.Generic("List", listOf(inferListElementTypeDecl(value)), false) + is ObjMap -> TypeDecl.Generic("Map", listOf(TypeDecl.TypeAny, TypeDecl.TypeAny), false) + is ObjClass -> TypeDecl.Simple(value.className, false) + else -> TypeDecl.Simple(value.objClass.className, false) + } + } + + private fun inferListElementTypeDecl(list: ObjList): TypeDecl { + var nullable = false + val options = mutableListOf() + val seen = mutableSetOf() for (elem in list.list) { if (elem === ObjNull) { - elemClass = Obj.rootObjectType - break - } - val cls = elem.objClass - if (elemClass == null) { - elemClass = cls - } else if (elemClass != cls) { - elemClass = Obj.rootObjectType - break + nullable = true + continue } + val elemType = inferRuntimeTypeDecl(elem) + val base = stripNullable(elemType).first + val key = typeDeclKey(base) + if (seen.add(key)) options += base + } + val base = when { + options.isEmpty() -> TypeDecl.TypeAny + options.size == 1 -> options[0] + else -> TypeDecl.Union(options, nullable = false) + } + return if (nullable) makeTypeDeclNullable(base) else base + } + + private fun normalizeRuntimeTypeDecl(type: TypeDecl): TypeDecl { + return when (type) { + is TypeDecl.Union -> TypeDecl.Union(type.options.distinctBy { typeDeclKey(it) }, type.isNullable) + is TypeDecl.Intersection -> TypeDecl.Intersection(type.options.distinctBy { typeDeclKey(it) }, type.isNullable) + else -> type } - return elemClass ?: Obj.rootObjectType } private fun resolveLocalTypeRef(name: String, pos: Pos): ObjRef? { diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt index d1cd994..1db8b39 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt @@ -207,9 +207,14 @@ class CmdCheckIs(internal val objSlot: Int, internal val typeSlot: Int, internal override suspend fun perform(frame: CmdFrame) { val obj = frame.slotToObj(objSlot) val typeObj = frame.slotToObj(typeSlot) - val result = when (typeObj) { - is ObjTypeExpr -> matchesTypeDecl(frame.ensureScope(), obj, typeObj.typeDecl) - is ObjClass -> obj.isInstanceOf(typeObj) + val result = when { + (obj is ObjTypeExpr || obj is ObjClass) && (typeObj is ObjTypeExpr || typeObj is ObjClass) -> { + val leftDecl = typeDeclFromObj(frame.ensureScope(), obj) ?: return frame.setBool(dst, false) + val rightDecl = typeDeclFromObj(frame.ensureScope(), typeObj) ?: return frame.setBool(dst, false) + typeDeclIsSubtype(frame.ensureScope(), leftDecl, rightDecl) + } + typeObj is ObjTypeExpr -> matchesTypeDecl(frame.ensureScope(), obj, typeObj.typeDecl) + typeObj is ObjClass -> obj.isInstanceOf(typeObj) else -> false } frame.setBool(dst, result) @@ -1020,7 +1025,22 @@ class CmdModObj(internal val a: Int, internal val b: Int, internal val dst: Int) class CmdContainsObj(internal val target: Int, internal val value: Int, internal val dst: Int) : Cmd() { override suspend fun perform(frame: CmdFrame) { - frame.setBool(dst, frame.slotToObj(target).contains(frame.ensureScope(), frame.slotToObj(value))) + val targetObj = frame.slotToObj(target) + val valueObj = frame.slotToObj(value) + val result = if ((targetObj is ObjTypeExpr || targetObj is ObjClass) && + (valueObj is ObjTypeExpr || valueObj is ObjClass) + ) { + val leftDecl = typeDeclFromObj(frame.ensureScope(), valueObj) + val rightDecl = typeDeclFromObj(frame.ensureScope(), targetObj) + if (leftDecl != null && rightDecl != null) { + typeDeclIsSubtype(frame.ensureScope(), leftDecl, rightDecl) + } else { + false + } + } else { + targetObj.contains(frame.ensureScope(), valueObj) + } + frame.setBool(dst, result) return } } diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjRef.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjRef.kt index 401517c..3f17917 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjRef.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjRef.kt @@ -152,9 +152,23 @@ class BinaryOpRef(internal val op: BinOp, internal val left: ObjRef, internal va val a = left.evalValue(scope) val b = right.evalValue(scope) if (op == BinOp.IS || op == BinOp.NOTIS) { - if (b is ObjTypeExpr) { - val result = matchesTypeDecl(scope, a, b.typeDecl) - return if (op == BinOp.NOTIS) ObjBool(!result) else ObjBool(result) + val result = when { + (a is ObjTypeExpr || a is ObjClass) && (b is ObjTypeExpr || b is ObjClass) -> { + val leftDecl = typeDeclFromObj(scope, a) ?: return ObjBool(false) + val rightDecl = typeDeclFromObj(scope, b) ?: return ObjBool(false) + typeDeclIsSubtype(scope, leftDecl, rightDecl) + } + b is ObjTypeExpr -> matchesTypeDecl(scope, a, b.typeDecl) + else -> a.isInstanceOf(b) + } + return if (op == BinOp.NOTIS) ObjBool(!result) else ObjBool(result) + } + if (op == BinOp.IN || op == BinOp.NOTIN) { + if ((b is ObjTypeExpr || b is ObjClass) && (a is ObjTypeExpr || a is ObjClass)) { + val leftDecl = typeDeclFromObj(scope, a) ?: return ObjBool(op == BinOp.NOTIN) + val rightDecl = typeDeclFromObj(scope, b) ?: return ObjBool(op == BinOp.NOTIN) + val result = typeDeclIsSubtype(scope, leftDecl, rightDecl) + return if (op == BinOp.NOTIN) ObjBool(!result) else ObjBool(result) } } diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjTypeExpr.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjTypeExpr.kt index fb0afe1..61b7e7d 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjTypeExpr.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjTypeExpr.kt @@ -22,7 +22,20 @@ import net.sergeych.lyng.TypeDecl /** * Runtime wrapper for a type expression (including unions/intersections) used by `is` checks. */ -class ObjTypeExpr(val typeDecl: TypeDecl) : Obj() +class ObjTypeExpr(val typeDecl: TypeDecl) : Obj() { + override suspend fun equals(scope: Scope, other: Obj): Boolean { + val otherDecl = typeDeclFromObj(scope, other) ?: return false + val leftKey = typeDeclKey(normalizeTypeDecl(scope, typeDecl)) + val rightKey = typeDeclKey(normalizeTypeDecl(scope, otherDecl)) + return leftKey == rightKey + } + + override suspend fun contains(scope: Scope, other: Obj): Boolean { + val leftDecl = typeDeclFromObj(scope, other) ?: return false + val rightDecl = normalizeTypeDecl(scope, typeDecl) + return typeDeclIsSubtype(scope, leftDecl, rightDecl) + } +} internal fun matchesTypeDecl(scope: Scope, value: Obj, typeDecl: TypeDecl): Boolean { if (value === ObjNull) { @@ -54,3 +67,163 @@ internal fun matchesTypeDecl(scope: Scope, value: Obj, typeDecl: TypeDecl): Bool is TypeDecl.Intersection -> typeDecl.options.all { matchesTypeDecl(scope, value, it) } } } + +internal fun typeDeclFromObj(scope: Scope, value: Obj): TypeDecl? { + return when (value) { + is ObjTypeExpr -> normalizeTypeDecl(scope, value.typeDecl) + is ObjClass -> TypeDecl.Simple(value.className, false) + else -> null + } +} + +internal fun typeDeclIsSubtype(scope: Scope, left: TypeDecl, right: TypeDecl): Boolean { + val lNorm = normalizeTypeDecl(scope, left) + val rNorm = normalizeTypeDecl(scope, right) + val lNullable = lNorm.isNullable || lNorm is TypeDecl.TypeNullableAny + val rNullable = rNorm.isNullable || rNorm is TypeDecl.TypeNullableAny + if (lNullable && !rNullable) return false + val l = stripNullable(lNorm) + val r = stripNullable(rNorm) + if (r == TypeDecl.TypeAny || r == TypeDecl.TypeNullableAny) return true + if (l == TypeDecl.TypeAny) return r == TypeDecl.TypeAny || r == TypeDecl.TypeNullableAny + if (l == TypeDecl.TypeNullableAny) return r == TypeDecl.TypeNullableAny + return when (l) { + is TypeDecl.Union -> l.options.all { typeDeclIsSubtype(scope, it, r) } + is TypeDecl.Intersection -> l.options.any { typeDeclIsSubtype(scope, it, r) } + else -> when (r) { + is TypeDecl.Union -> r.options.any { typeDeclIsSubtype(scope, l, it) } + is TypeDecl.Intersection -> r.options.all { typeDeclIsSubtype(scope, l, it) } + is TypeDecl.Simple, is TypeDecl.Generic, is TypeDecl.Function -> { + val leftClass = resolveTypeDeclClass(scope, l) ?: return false + val rightClass = resolveTypeDeclClass(scope, r) ?: return false + leftClass == rightClass || leftClass.allParentsSet.contains(rightClass) + } + else -> false + } + } +} + +private fun normalizeTypeDecl(scope: Scope, decl: TypeDecl): TypeDecl { + val resolved = if (decl is TypeDecl.TypeVar) { + val bound = scope[decl.name]?.value + when (bound) { + is ObjTypeExpr -> bound.typeDecl + is ObjClass -> TypeDecl.Simple(bound.className, decl.isNullable) + else -> decl + } + } else decl + return when (resolved) { + is TypeDecl.Union -> normalizeUnion(scope, resolved) + is TypeDecl.Intersection -> normalizeIntersection(scope, resolved) + else -> resolved + } +} + +private fun normalizeUnion(scope: Scope, decl: TypeDecl.Union): TypeDecl { + val options = mutableListOf() + var nullable = decl.isNullable + for (opt in decl.options) { + val norm = normalizeTypeDecl(scope, opt) + if (norm is TypeDecl.TypeNullableAny) nullable = true + val base = stripNullable(norm) + if (base == TypeDecl.TypeAny) return if (nullable) TypeDecl.TypeNullableAny else TypeDecl.TypeAny + if (base is TypeDecl.Union) { + options.addAll(base.options) + } else { + options += base + } + nullable = nullable || norm.isNullable + } + val unique = options.distinctBy { typeDeclKey(it) }.sortedBy { typeDeclKey(it) } + val base = if (unique.size == 1) unique[0] else TypeDecl.Union(unique, nullable = false) + return if (nullable) makeNullable(base) else base +} + +private fun normalizeIntersection(scope: Scope, decl: TypeDecl.Intersection): TypeDecl { + val options = mutableListOf() + var nullable = decl.isNullable + for (opt in decl.options) { + val norm = normalizeTypeDecl(scope, opt) + val base = stripNullable(norm) + if (base == TypeDecl.TypeAny) { + nullable = nullable || norm.isNullable + continue + } + if (base is TypeDecl.Intersection) { + options.addAll(base.options) + } else { + options += base + } + nullable = nullable || norm.isNullable + } + val unique = options.distinctBy { typeDeclKey(it) }.sortedBy { typeDeclKey(it) } + val base = when { + unique.isEmpty() -> TypeDecl.TypeAny + unique.size == 1 -> unique[0] + else -> TypeDecl.Intersection(unique, nullable = false) + } + return if (nullable) makeNullable(base) else base +} + +private fun stripNullable(type: TypeDecl): TypeDecl { + return if (!type.isNullable && type !is TypeDecl.TypeNullableAny) { + type + } else { + when (type) { + is TypeDecl.Function -> type.copy(nullable = false) + is TypeDecl.TypeVar -> type.copy(nullable = false) + is TypeDecl.Union -> type.copy(nullable = false) + is TypeDecl.Intersection -> type.copy(nullable = false) + is TypeDecl.Simple -> TypeDecl.Simple(type.name, false) + is TypeDecl.Generic -> TypeDecl.Generic(type.name, type.args, false) + else -> TypeDecl.TypeAny + } + } +} + +private fun makeNullable(type: TypeDecl): TypeDecl { + return when (type) { + TypeDecl.TypeAny -> TypeDecl.TypeNullableAny + TypeDecl.TypeNullableAny -> type + is TypeDecl.Function -> type.copy(nullable = true) + is TypeDecl.TypeVar -> type.copy(nullable = true) + is TypeDecl.Union -> type.copy(nullable = true) + is TypeDecl.Intersection -> type.copy(nullable = true) + is TypeDecl.Simple -> TypeDecl.Simple(type.name, true) + is TypeDecl.Generic -> TypeDecl.Generic(type.name, type.args, true) + } +} + +private fun typeDeclKey(type: TypeDecl): String = when (type) { + TypeDecl.TypeAny -> "Any" + TypeDecl.TypeNullableAny -> "Any?" + is TypeDecl.Simple -> "S:${type.name}" + is TypeDecl.Generic -> "G:${type.name}<${type.args.joinToString(",") { typeDeclKey(it) }}>" + is TypeDecl.Function -> "F:(${type.params.joinToString(",") { typeDeclKey(it) }})->${typeDeclKey(type.returnType)}" + is TypeDecl.TypeVar -> "V:${type.name}" + is TypeDecl.Union -> "U:${type.options.joinToString("|") { typeDeclKey(it) }}" + is TypeDecl.Intersection -> "I:${type.options.joinToString("&") { typeDeclKey(it) }}" +} + +private fun resolveTypeDeclClass(scope: Scope, type: TypeDecl): ObjClass? { + return when (type) { + is TypeDecl.Simple -> { + val direct = scope[type.name]?.value as? ObjClass + direct ?: scope[type.name.substringAfterLast('.')]?.value as? ObjClass + } + is TypeDecl.Generic -> { + val direct = scope[type.name]?.value as? ObjClass + direct ?: scope[type.name.substringAfterLast('.')]?.value as? ObjClass + } + is TypeDecl.Function -> scope["Callable"]?.value as? ObjClass + is TypeDecl.TypeVar -> { + val bound = scope[type.name]?.value + when (bound) { + is ObjClass -> bound + is ObjTypeExpr -> resolveTypeDeclClass(scope, bound.typeDecl) + else -> null + } + } + else -> null + } +} diff --git a/lynglib/src/commonTest/kotlin/TypesTest.kt b/lynglib/src/commonTest/kotlin/TypesTest.kt index 02f5923..3c4dea0 100644 --- a/lynglib/src/commonTest/kotlin/TypesTest.kt +++ b/lynglib/src/commonTest/kotlin/TypesTest.kt @@ -202,15 +202,24 @@ class TypesTest { } @Test - fun testUnioTypeLists() = runTest { + fun testUnionTypeLists() = runTest { eval(""" - - fun f(list: List) { + + fun fMixed(list: List) { println(list) println(T) + assert( T is Int | String | Bool ) + assert( !(T is Int) ) + assert( Int in T ) + assert( String in T ) } - f([1, "two", true]) - f([1,2,3]) + fun fInts(list: List) { + assert( T is Int ) + assert( Int in T ) + assert( !(String in T) ) + } + fMixed([1, "two", true]) + fInts([1,2,3]) """) } @@ -226,8 +235,11 @@ class TypesTest { R2("t").apply { assertEquals("r2", r2) assertEquals("t", shared) - assertEquals("r1", this@R1.r1) + assertEquals("s", this@R1.shared) // actually we have now this of union type R1 & R2! +// println(this::class) + assert( this@R2 is R2 ) + assert( this@R1 is R1 ) } } """) diff --git a/notes/new_lyng_type_system_spec.md b/notes/new_lyng_type_system_spec.md index a50e6ea..10ac78f 100644 --- a/notes/new_lyng_type_system_spec.md +++ b/notes/new_lyng_type_system_spec.md @@ -234,6 +234,18 @@ Object methods: - keep `toString()` as Object method - if we need extra metadata later, use explicit helpers like `Object.getHashCode(obj)` +- Type expression checks (unions/intersections): + - Value check: `x is T` is runtime instance check (as usual). + - Type check: `T1 is T2` means type-subset (all values of `T1` fit in `T2`). + - Exact equality uses `==` and is structural (normalized unions/intersections). + - Includes uses `in`: `A in T` means `A` is a subset of `T`. + - Examples (T = A | B): + - `T == A` is false + - `T is A` is false + - `A in T` is true + - `B in T` is true + - `T is A | B` is true + - Builtin classes inheritance: Are Int/String final? If so, is "class T: String, Int" forbidden (and thus Int & String is unsatisfiable but still allowed)? What keyword we did used for final vals/vars/funs? "closed"? Anyway I am uncertain whether to make Int or String closed, it is a discussion subject. But if we have some closed independent classes A, B, is a compile time error.