Add union type checks and list literal inference

This commit is contained in:
Sergey Chernov 2026-02-05 19:45:07 +03:00
parent 9db5b12c31
commit c31343a040
7 changed files with 421 additions and 18 deletions

View File

@ -1505,6 +1505,7 @@ class Compiler(
is BinaryOpRef -> containsUnsupportedRef(ref.left) || containsUnsupportedRef(ref.right)
is UnaryOpRef -> containsUnsupportedRef(ref.a)
is CastRef -> containsUnsupportedRef(ref.castValueRef()) || containsUnsupportedRef(ref.castTypeRef())
is net.sergeych.lyng.obj.TypeDeclRef -> false
is AssignRef -> {
val target = ref.target as? LocalSlotRef
(target?.isDelegated == true) || containsUnsupportedRef(ref.value)
@ -1601,6 +1602,7 @@ class Compiler(
is BinaryOpRef -> containsDelegatedRefs(ref.left) || containsDelegatedRefs(ref.right)
is UnaryOpRef -> containsDelegatedRefs(ref.a)
is CastRef -> containsDelegatedRefs(ref.castValueRef()) || containsDelegatedRefs(ref.castTypeRef())
is net.sergeych.lyng.obj.TypeDeclRef -> false
is AssignRef -> {
val target = ref.target as? LocalSlotRef
(target?.isDelegated == true) || containsDelegatedRefs(ref.value)
@ -1830,6 +1832,14 @@ class Compiler(
} else {
CastRef(lvalue!!, typeRef, true, opToken.pos)
}
} else if (opToken.type == Token.Type.IS || opToken.type == Token.Type.NOTIS) {
val (typeDecl, _) = parseTypeExpressionWithMini()
val typeRef = net.sergeych.lyng.obj.TypeDeclRef(typeDecl, opToken.pos)
if (opToken.type == Token.Type.IS) {
BinaryOpRef(BinOp.IS, lvalue!!, typeRef)
} else {
BinaryOpRef(BinOp.NOTIS, lvalue!!, typeRef)
}
} else {
val rvalue = parseExpressionLevel(level + 1)
?: throw ScriptError(opToken.pos, "Expecting expression")
@ -3127,6 +3137,114 @@ class Compiler(
TypeDecl.TypeNullableAny -> "Object?"
}
private fun inferTypeDeclFromInitializer(stmt: Statement): TypeDecl? {
val directRef = unwrapDirectRef(stmt) ?: return null
return inferTypeDeclFromRef(directRef)
}
private fun inferTypeDeclFromRef(ref: ObjRef): TypeDecl? {
resolveReceiverTypeDecl(ref)?.let { return it }
return when (ref) {
is ListLiteralRef -> inferListLiteralTypeDecl(ref)
is ConstRef -> inferTypeDeclFromConst(ref.constValue)
else -> null
}
}
private fun inferTypeDeclFromConst(value: Obj): TypeDecl? = when (value) {
is ObjInt -> TypeDecl.Simple("Int", false)
is ObjReal -> TypeDecl.Simple("Real", false)
is ObjString -> TypeDecl.Simple("String", false)
is ObjBool -> TypeDecl.Simple("Bool", false)
is ObjChar -> TypeDecl.Simple("Char", false)
is ObjNull -> TypeDecl.TypeNullableAny
is ObjList -> TypeDecl.Generic("List", listOf(TypeDecl.TypeAny), false)
is ObjMap -> TypeDecl.Generic("Map", listOf(TypeDecl.TypeAny, TypeDecl.TypeAny), false)
else -> null
}
private fun inferListLiteralTypeDecl(ref: ListLiteralRef): TypeDecl {
val elementType = inferListLiteralElementType(ref.entries())
return TypeDecl.Generic("List", listOf(elementType), false)
}
private fun inferListLiteralElementType(entries: List<ListEntry>): TypeDecl {
var nullable = false
val collected = mutableListOf<TypeDecl>()
val seen = mutableSetOf<String>()
fun addType(type: TypeDecl) {
val (base, isNullable) = stripNullable(type)
nullable = nullable || isNullable
if (base == TypeDecl.TypeAny) {
collected.clear()
collected += base
seen.clear()
seen += typeDeclKey(base)
return
}
val key = typeDeclKey(base)
if (seen.add(key)) {
collected += base
}
}
for (entry in entries) {
val type = when (entry) {
is ListEntry.Element -> inferTypeDeclFromRef(entry.ref)
is ListEntry.Spread -> inferElementTypeFromSpread(entry.ref)
} ?: return if (nullable) TypeDecl.TypeNullableAny else TypeDecl.TypeAny
addType(type)
if (collected.size == 1 && collected[0] == TypeDecl.TypeAny) break
}
if (collected.isEmpty()) return TypeDecl.TypeAny
val base = if (collected.size == 1) {
collected[0]
} else {
TypeDecl.Union(collected.toList(), nullable = false)
}
return if (nullable) makeTypeDeclNullable(base) else base
}
private fun inferElementTypeFromSpread(ref: ObjRef): TypeDecl? {
val listType = inferTypeDeclFromRef(ref) ?: return null
if (listType == TypeDecl.TypeAny || listType == TypeDecl.TypeNullableAny) return listType
if (listType is TypeDecl.Generic) {
val base = listType.name.substringAfterLast('.')
if (base == "List" || base == "Array" || base == "Iterable") {
return listType.args.firstOrNull() ?: TypeDecl.TypeAny
}
}
return TypeDecl.TypeAny
}
private fun stripNullable(type: TypeDecl): Pair<TypeDecl, Boolean> {
if (type is TypeDecl.TypeNullableAny) return TypeDecl.TypeAny to true
val nullable = type.isNullable
val base = if (!nullable) type else when (type) {
is TypeDecl.Function -> type.copy(nullable = false)
is TypeDecl.TypeVar -> type.copy(nullable = false)
is TypeDecl.Union -> type.copy(nullable = false)
is TypeDecl.Intersection -> type.copy(nullable = false)
is TypeDecl.Simple -> TypeDecl.Simple(type.name, false)
is TypeDecl.Generic -> TypeDecl.Generic(type.name, type.args, false)
else -> type
}
return base to nullable
}
private fun typeDeclKey(type: TypeDecl): String = when (type) {
TypeDecl.TypeAny -> "Any"
TypeDecl.TypeNullableAny -> "Any?"
is TypeDecl.Simple -> "S:${type.name}"
is TypeDecl.Generic -> "G:${type.name}<${type.args.joinToString(",") { typeDeclKey(it) }}>"
is TypeDecl.Function -> "F:(${type.params.joinToString(",") { typeDeclKey(it) }})->${typeDeclKey(type.returnType)}"
is TypeDecl.TypeVar -> "V:${type.name}"
is TypeDecl.Union -> "U:${type.options.joinToString("|") { typeDeclKey(it) }}"
is TypeDecl.Intersection -> "I:${type.options.joinToString("&") { typeDeclKey(it) }}"
}
private fun inferObjClassFromRef(ref: ObjRef): ObjClass? = when (ref) {
is ConstRef -> ref.constValue as? ObjClass ?: (ref.constValue as? Obj)?.objClass
is LocalVarRef -> nameObjClass[ref.name] ?: resolveClassByName(ref.name)
@ -3482,21 +3600,102 @@ class Compiler(
pos: Pos
) {
val decl = lookupGenericFunctionDecl(name) ?: return
val inferred = mutableMapOf<String, ObjClass>()
val inferred = mutableMapOf<String, TypeDecl>()
val limit = minOf(args.size, decl.params.size)
for (i in 0 until limit) {
val paramType = decl.params[i].type
val argRef = (args[i].value as? ExpressionStatement)?.ref ?: continue
val argClass = inferObjClassFromRef(argRef) ?: continue
if (paramType is TypeDecl.TypeVar) {
inferred[paramType.name] = argClass
}
val argTypeDecl = inferTypeDeclFromRef(argRef)
?: inferObjClassFromRef(argRef)?.let { TypeDecl.Simple(it.className, false) }
?: continue
collectTypeVarBindings(paramType, argTypeDecl, inferred)
}
for (tp in decl.typeParams) {
val argClass = inferred[tp.name] ?: continue
val argType = inferred[tp.name] ?: continue
val bound = tp.bound ?: continue
if (!typeParamBoundSatisfied(argClass, bound)) {
throw ScriptError(pos, "type argument ${argClass.className} does not satisfy bound ${typeDeclName(bound)}")
if (!typeDeclSatisfiesBound(argType, bound)) {
throw ScriptError(pos, "type argument ${typeDeclName(argType)} does not satisfy bound ${typeDeclName(bound)}")
}
}
}
private fun collectTypeVarBindings(
paramType: TypeDecl,
argType: TypeDecl,
out: MutableMap<String, TypeDecl>
) {
when (paramType) {
is TypeDecl.TypeVar -> {
val current = out[paramType.name]
out[paramType.name] = mergeTypeDecls(current, argType)
}
is TypeDecl.Generic -> {
if (argType is TypeDecl.Generic && argType.name == paramType.name &&
argType.args.size == paramType.args.size
) {
for (i in paramType.args.indices) {
collectTypeVarBindings(paramType.args[i], argType.args[i], out)
}
}
}
is TypeDecl.Union -> {
if (argType is TypeDecl.Union) {
val limit = minOf(paramType.options.size, argType.options.size)
for (i in 0 until limit) {
collectTypeVarBindings(paramType.options[i], argType.options[i], out)
}
}
}
is TypeDecl.Intersection -> {
if (argType is TypeDecl.Intersection) {
val limit = minOf(paramType.options.size, argType.options.size)
for (i in 0 until limit) {
collectTypeVarBindings(paramType.options[i], argType.options[i], out)
}
}
}
else -> {}
}
}
private fun mergeTypeDecls(a: TypeDecl?, b: TypeDecl): TypeDecl {
if (a == null) return b
if (a == TypeDecl.TypeAny || b == TypeDecl.TypeAny) {
return if (a.isNullable || b.isNullable) TypeDecl.TypeNullableAny else TypeDecl.TypeAny
}
if (a == TypeDecl.TypeNullableAny || b == TypeDecl.TypeNullableAny) return TypeDecl.TypeNullableAny
val (aBase, aNullable) = stripNullable(a)
val (bBase, bNullable) = stripNullable(b)
if (typeDeclKey(aBase) == typeDeclKey(bBase)) {
return if (aNullable || bNullable) makeTypeDeclNullable(aBase) else aBase
}
val options = mutableListOf<TypeDecl>()
val seen = mutableSetOf<String>()
fun addOpt(t: TypeDecl) {
val key = typeDeclKey(t)
if (seen.add(key)) options += t
}
val nullable = aNullable || bNullable
if (aBase is TypeDecl.Union) aBase.options.forEach { addOpt(it) } else addOpt(aBase)
if (bBase is TypeDecl.Union) bBase.options.forEach { addOpt(it) } else addOpt(bBase)
val merged = TypeDecl.Union(options, nullable = false)
return if (nullable) makeTypeDeclNullable(merged) else merged
}
private fun typeDeclSatisfiesBound(argType: TypeDecl, bound: TypeDecl): Boolean {
return when (argType) {
TypeDecl.TypeAny, TypeDecl.TypeNullableAny -> true
is TypeDecl.Union -> argType.options.all { typeDeclSatisfiesBound(it, bound) }
is TypeDecl.Intersection -> argType.options.all { typeDeclSatisfiesBound(it, bound) }
else -> when (bound) {
is TypeDecl.Union -> bound.options.any { typeDeclSatisfiesBound(argType, it) }
is TypeDecl.Intersection -> bound.options.all { typeDeclSatisfiesBound(argType, it) }
is TypeDecl.Simple, is TypeDecl.Generic, is TypeDecl.Function -> {
val argClass = resolveTypeDeclObjClass(argType) ?: return false
val boundClass = resolveTypeDeclObjClass(bound) ?: return false
argClass == boundClass || argClass.allParentsSet.contains(boundClass)
}
else -> true
}
}
}
@ -3509,11 +3708,10 @@ class Compiler(
if (typeParams.isEmpty()) return
val inferred = mutableMapOf<String, ObjClass>()
for (param in argsDeclaration.params) {
val paramType = param.type
if (paramType is TypeDecl.TypeVar) {
val rec = context.getLocalRecordDirect(param.name) ?: continue
val value = rec.value
if (value is Obj) inferred[paramType.name] = value.objClass
val rec = context.getLocalRecordDirect(param.name) ?: continue
val value = rec.value
if (value is Obj) {
collectRuntimeTypeVarBindings(param.type, value, inferred)
}
}
for (tp in typeParams) {
@ -3528,6 +3726,47 @@ class Compiler(
}
}
private fun collectRuntimeTypeVarBindings(
paramType: TypeDecl,
value: Obj,
inferred: MutableMap<String, ObjClass>
) {
when (paramType) {
is TypeDecl.TypeVar -> {
if (value !== ObjNull) {
inferred[paramType.name] = value.objClass
}
}
is TypeDecl.Generic -> {
val base = paramType.name.substringAfterLast('.')
val arg = paramType.args.firstOrNull()
if (base == "List" && arg is TypeDecl.TypeVar && value is ObjList) {
val elementClass = inferListElementClass(value)
inferred[arg.name] = elementClass
}
}
else -> {}
}
}
private fun inferListElementClass(list: ObjList): ObjClass {
var elemClass: ObjClass? = null
for (elem in list.list) {
if (elem === ObjNull) {
elemClass = Obj.rootObjectType
break
}
val cls = elem.objClass
if (elemClass == null) {
elemClass = cls
} else if (elemClass != cls) {
elemClass = Obj.rootObjectType
break
}
}
return elemClass ?: Obj.rootObjectType
}
private fun resolveLocalTypeRef(name: String, pos: Pos): ObjRef? {
val slotLoc = lookupSlotLocation(name, includeModule = true) ?: return null
captureLocalRef(name, slotLoc, pos)?.let { return it }
@ -4248,7 +4487,9 @@ class Compiler(
Token.Type.IS,
Token.Type.NOTIS -> {
val negated = t.type == Token.Type.NOTIS
val caseType = parseExpression() ?: throw ScriptError(cc.currentPos(), "type expected")
val (typeDecl, _) = parseTypeExpressionWithMini()
val typeRef = net.sergeych.lyng.obj.TypeDeclRef(typeDecl, t.pos)
val caseType = ExpressionStatement(typeRef, t.pos)
currentConditions += WhenIsCondition(caseType, negated, t.pos)
}
@ -6476,7 +6717,7 @@ class Compiler(
// Optional explicit type annotation
cc.skipWsTokens()
val (varTypeDecl, varTypeMini) = if (cc.peekNextNonWhitespace().type == Token.Type.COLON) {
var (varTypeDecl, varTypeMini) = if (cc.peekNextNonWhitespace().type == Token.Type.COLON) {
parseTypeDeclarationWithMini()
} else {
TypeDecl.TypeAny to null
@ -6624,6 +6865,13 @@ class Compiler(
else parseStatement(true)
?: throw ScriptError(effectiveEqToken!!.pos, "Expected initializer expression")
if (varTypeDecl == TypeDecl.TypeAny && initialExpression != null) {
val inferred = inferTypeDeclFromInitializer(initialExpression)
if (inferred != null) {
varTypeDecl = inferred
}
}
if (isDelegate && initialExpression != null) {
ensureDelegateType(initialExpression)
if (isMutable && resolveInitializerObjClass(initialExpression) == ObjLazyDelegate.type) {

View File

@ -49,7 +49,11 @@ class WhenIsCondition(
override val pos: Pos,
) : WhenCondition(expr, pos) {
override suspend fun matches(scope: Scope, value: Obj): Boolean {
val result = value.isInstanceOf(expr.execute(scope))
val typeExpr = expr.execute(scope)
val result = when (typeExpr) {
is net.sergeych.lyng.obj.ObjTypeExpr -> net.sergeych.lyng.obj.matchesTypeDecl(scope, value, typeExpr.typeDecl)
else -> value.isInstanceOf(typeExpr)
}
return if (negated) !result else result
}
}

View File

@ -198,6 +198,7 @@ class BytecodeCompiler(
private fun compileRef(ref: ObjRef): CompiledValue? {
return when (ref) {
is ConstRef -> compileConst(ref.constValue)
is TypeDeclRef -> compileConst(ObjTypeExpr(ref.decl()))
is IncDecRef -> compileIncDec(ref, true)
is CastRef -> compileCast(ref)
is LocalSlotRef -> {

View File

@ -207,8 +207,12 @@ class CmdCheckIs(internal val objSlot: Int, internal val typeSlot: Int, internal
override suspend fun perform(frame: CmdFrame) {
val obj = frame.slotToObj(objSlot)
val typeObj = frame.slotToObj(typeSlot)
val clazz = typeObj as? ObjClass
frame.setBool(dst, clazz != null && obj.isInstanceOf(clazz))
val result = when (typeObj) {
is ObjTypeExpr -> matchesTypeDecl(frame.ensureScope(), obj, typeObj.typeDecl)
is ObjClass -> obj.isInstanceOf(typeObj)
else -> false
}
frame.setBool(dst, result)
return
}
}

View File

@ -151,6 +151,12 @@ class BinaryOpRef(internal val op: BinOp, internal val left: ObjRef, internal va
override suspend fun evalValue(scope: Scope): Obj {
val a = left.evalValue(scope)
val b = right.evalValue(scope)
if (op == BinOp.IS || op == BinOp.NOTIS) {
if (b is ObjTypeExpr) {
val result = matchesTypeDecl(scope, a, b.typeDecl)
return if (op == BinOp.NOTIS) ObjBool(!result) else ObjBool(result)
}
}
// Primitive fast paths for common cases (guarded by PerfFlags.PRIMITIVE_FASTOPS)
if (PerfFlags.PRIMITIVE_FASTOPS) {
@ -461,6 +467,24 @@ class CastRef(
}
}
/** Type expression reference used for `is` checks (including unions/intersections). */
class TypeDeclRef(private val typeDecl: TypeDecl, private val atPos: Pos) : ObjRef {
internal fun decl(): TypeDecl = typeDecl
internal fun pos(): Pos = atPos
override fun forEachVariable(block: (String) -> Unit) {}
override fun forEachVariableWithPos(block: (String, Pos) -> Unit) {}
override suspend fun get(scope: Scope): ObjRecord {
return evalValue(scope).asReadonly
}
override suspend fun evalValue(scope: Scope): Obj {
return ObjTypeExpr(typeDecl)
}
}
/** Qualified `this@Type`: resolves to a view of current `this` starting dispatch from the ancestor Type. */
class QualifiedThisRef(val typeName: String, private val atPos: Pos) : ObjRef {
internal fun pos(): Pos = atPos

View File

@ -0,0 +1,56 @@
/*
* Copyright 2026 Sergey S. Chernov
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.sergeych.lyng.obj
import net.sergeych.lyng.Scope
import net.sergeych.lyng.TypeDecl
/**
* Runtime wrapper for a type expression (including unions/intersections) used by `is` checks.
*/
class ObjTypeExpr(val typeDecl: TypeDecl) : Obj()
internal fun matchesTypeDecl(scope: Scope, value: Obj, typeDecl: TypeDecl): Boolean {
if (value === ObjNull) {
return typeDecl.isNullable || typeDecl is TypeDecl.TypeNullableAny
}
fun resolveClassFromScope(typeName: String): ObjClass? {
val direct = scope[typeName]?.value as? ObjClass
if (direct != null) return direct
val simple = typeName.substringAfterLast('.')
return scope[simple]?.value as? ObjClass
}
return when (typeDecl) {
TypeDecl.TypeAny -> true
TypeDecl.TypeNullableAny -> true
is TypeDecl.TypeVar -> {
val cls = (scope[typeDecl.name]?.value as? ObjClass)
if (cls != null) value.isInstanceOf(cls) else value.isInstanceOf(typeDecl.name)
}
is TypeDecl.Simple -> {
val cls = resolveClassFromScope(typeDecl.name)
if (cls != null) value.isInstanceOf(cls) else value.isInstanceOf(typeDecl.name.substringAfterLast('.'))
}
is TypeDecl.Generic -> {
val cls = resolveClassFromScope(typeDecl.name)
if (cls != null) value.isInstanceOf(cls) else value.isInstanceOf(typeDecl.name.substringAfterLast('.'))
}
is TypeDecl.Function -> value.isInstanceOf("Callable")
is TypeDecl.Union -> typeDecl.options.any { matchesTypeDecl(scope, value, it) }
is TypeDecl.Intersection -> typeDecl.options.all { matchesTypeDecl(scope, value, it) }
}
}

View File

@ -18,6 +18,7 @@
import kotlinx.coroutines.test.runTest
import net.sergeych.lyng.eval
import kotlin.test.Test
import kotlin.test.assertFailsWith
class TypesTest {
@ -166,4 +167,69 @@ class TypesTest {
""".trimIndent()
)
}
@Test
fun testIsUnionIntersection() = runTest {
eval("""
class A
class B
class C: A, B
val c = C()
assert( c is A | B )
assert( c is A & B )
assert( !(c is A & String) )
val v = 1
assert( v is Int | String | Real )
assert( !(v is String | Bool) )
""".trimIndent())
}
@Test
fun testListLiteralInferenceForBounds() = runTest {
eval("""
fun acceptInts<T: Int>(xs: List<T>) { }
acceptInts([1, 2, 3])
val base = [1, 2]
acceptInts([...base, 3])
""".trimIndent())
assertFailsWith<net.sergeych.lyng.ScriptError> {
eval("""
fun acceptInts<T: Int>(xs: List<T>) { }
acceptInts([1, "a"])
""".trimIndent())
}
}
@Test
fun testUnioTypeLists() = runTest {
eval("""
fun f<T>(list: List<T>) {
println(list)
println(T)
}
f([1, "two", true])
f([1,2,3])
""")
}
@Test
fun multipleReceivers() = runTest {
eval("""
class R1(shared,r1="r1")
class R2(shared,r2="r2")
R1("s").apply {
assertEquals("r1", r1)
assertEquals("s", shared)
R2("t").apply {
assertEquals("r2", r2)
assertEquals("t", shared)
assertEquals("r1", this@R1.r1)
// actually we have now this of union type R1 & R2!
}
}
""")
}
}