Add generic bounds checks and union/intersection types

This commit is contained in:
Sergey Chernov 2026-02-03 15:36:11 +03:00
parent c5bf4e5039
commit 54c6fca0e8
5 changed files with 256 additions and 33 deletions

View File

@ -64,12 +64,38 @@ class Compiler(
) )
private val slotPlanStack = mutableListOf<SlotPlan>() private val slotPlanStack = mutableListOf<SlotPlan>()
private var nextScopeId = 0 private var nextScopeId = 0
private val genericFunctionDeclsStack = mutableListOf<MutableMap<String, GenericFunctionDecl>>(mutableMapOf())
// Track declared local variables count per function for precise capacity hints // Track declared local variables count per function for precise capacity hints
private val localDeclCountStack = mutableListOf<Int>() private val localDeclCountStack = mutableListOf<Int>()
private val currentLocalDeclCount: Int private val currentLocalDeclCount: Int
get() = localDeclCountStack.lastOrNull() ?: 0 get() = localDeclCountStack.lastOrNull() ?: 0
private data class GenericFunctionDecl(
val typeParams: List<TypeDecl.TypeParam>,
val params: List<ArgsDeclaration.Item>,
val pos: Pos
)
private fun pushGenericFunctionScope() {
genericFunctionDeclsStack.add(mutableMapOf())
}
private fun popGenericFunctionScope() {
genericFunctionDeclsStack.removeLast()
}
private fun currentGenericFunctionDecls(): MutableMap<String, GenericFunctionDecl> {
return genericFunctionDeclsStack.last()
}
private fun lookupGenericFunctionDecl(name: String): GenericFunctionDecl? {
for (i in genericFunctionDeclsStack.indices.reversed()) {
genericFunctionDeclsStack[i][name]?.let { return it }
}
return null
}
private inline fun <T> withLocalNames(names: Set<String>, block: () -> T): T { private inline fun <T> withLocalNames(names: Set<String>, block: () -> T): T {
localNamesStack.add(names.toMutableSet()) localNamesStack.add(names.toMutableSet())
return try { return try {
@ -440,6 +466,7 @@ class Compiler(
private fun currentTypeParams(): Set<String> { private fun currentTypeParams(): Set<String> {
val result = mutableSetOf<String>() val result = mutableSetOf<String>()
pendingTypeParamStack.lastOrNull()?.let { result.addAll(it) }
for (ctx in codeContexts.asReversed()) { for (ctx in codeContexts.asReversed()) {
when (ctx) { when (ctx) {
is CodeContext.Function -> result.addAll(ctx.typeParams) is CodeContext.Function -> result.addAll(ctx.typeParams)
@ -450,6 +477,8 @@ class Compiler(
return result return result
} }
private val pendingTypeParamStack = mutableListOf<Set<String>>()
private fun parseTypeParamList(): List<TypeDecl.TypeParam> { private fun parseTypeParamList(): List<TypeDecl.TypeParam> {
if (cc.peekNextNonWhitespace().type != Token.Type.LT) return emptyList() if (cc.peekNextNonWhitespace().type != Token.Type.LT) return emptyList()
val typeParams = mutableListOf<TypeDecl.TypeParam>() val typeParams = mutableListOf<TypeDecl.TypeParam>()
@ -774,6 +803,7 @@ class Compiler(
private suspend fun <T> inCodeContext(context: CodeContext, f: suspend () -> T): T { private suspend fun <T> inCodeContext(context: CodeContext, f: suspend () -> T): T {
codeContexts.add(context) codeContexts.add(context)
pushGenericFunctionScope()
try { try {
val res = f() val res = f()
if (context is CodeContext.ClassBody) { if (context is CodeContext.ClassBody) {
@ -784,6 +814,7 @@ class Compiler(
} }
return res return res
} finally { } finally {
popGenericFunctionScope()
codeContexts.removeLast() codeContexts.removeLast()
} }
} }
@ -2401,6 +2432,36 @@ class Compiler(
} }
private fun parseTypeExpressionWithMini(): Pair<TypeDecl, MiniTypeRef> { private fun parseTypeExpressionWithMini(): Pair<TypeDecl, MiniTypeRef> {
return parseTypeUnionWithMini()
}
private fun parseTypeUnionWithMini(): Pair<TypeDecl, MiniTypeRef> {
var left = parseTypeIntersectionWithMini()
val options = mutableListOf(left)
while (cc.skipTokenOfType(Token.Type.BITOR, isOptional = true)) {
options += parseTypeIntersectionWithMini()
}
if (options.size == 1) return left
val rangeStart = options.first().second.range.start
val rangeEnd = cc.currentPos()
val mini = MiniTypeUnion(MiniRange(rangeStart, rangeEnd), options.map { it.second }, nullable = false)
return TypeDecl.Union(options.map { it.first }, nullable = false) to mini
}
private fun parseTypeIntersectionWithMini(): Pair<TypeDecl, MiniTypeRef> {
var left = parseTypePrimaryWithMini()
val options = mutableListOf(left)
while (cc.skipTokenOfType(Token.Type.BITAND, isOptional = true)) {
options += parseTypePrimaryWithMini()
}
if (options.size == 1) return left
val rangeStart = options.first().second.range.start
val rangeEnd = cc.currentPos()
val mini = MiniTypeIntersection(MiniRange(rangeStart, rangeEnd), options.map { it.second }, nullable = false)
return TypeDecl.Intersection(options.map { it.first }, nullable = false) to mini
}
private fun parseTypePrimaryWithMini(): Pair<TypeDecl, MiniTypeRef> {
parseFunctionTypeWithMini()?.let { return it } parseFunctionTypeWithMini()?.let { return it }
return parseSimpleTypeExpressionWithMini() return parseSimpleTypeExpressionWithMini()
} }
@ -2595,8 +2656,8 @@ class Compiler(
private fun typeDeclToTypeRef(typeDecl: TypeDecl, pos: Pos): ObjRef { private fun typeDeclToTypeRef(typeDecl: TypeDecl, pos: Pos): ObjRef {
return when (typeDecl) { return when (typeDecl) {
TypeDecl.TypeAny, TypeDecl.TypeAny,
TypeDecl.TypeNullableAny, TypeDecl.TypeNullableAny -> ConstRef(Obj.rootObjectType.asReadonly)
is TypeDecl.TypeVar -> ConstRef(Obj.rootObjectType.asReadonly) is TypeDecl.TypeVar -> resolveLocalTypeRef(typeDecl.name, pos) ?: ConstRef(Obj.rootObjectType.asReadonly)
else -> { else -> {
val cls = resolveTypeDeclObjClass(typeDecl) val cls = resolveTypeDeclObjClass(typeDecl)
if (cls != null) return ConstRef(cls.asReadonly) if (cls != null) return ConstRef(cls.asReadonly)
@ -2612,10 +2673,92 @@ class Compiler(
is TypeDecl.Generic -> typeDecl.name is TypeDecl.Generic -> typeDecl.name
is TypeDecl.Function -> "Callable" is TypeDecl.Function -> "Callable"
is TypeDecl.TypeVar -> typeDecl.name is TypeDecl.TypeVar -> typeDecl.name
is TypeDecl.Union -> typeDecl.options.joinToString(" | ") { typeDeclName(it) }
is TypeDecl.Intersection -> typeDecl.options.joinToString(" & ") { typeDeclName(it) }
TypeDecl.TypeAny -> "Object" TypeDecl.TypeAny -> "Object"
TypeDecl.TypeNullableAny -> "Object?" TypeDecl.TypeNullableAny -> "Object?"
} }
private fun inferObjClassFromRef(ref: ObjRef): ObjClass? = when (ref) {
is ConstRef -> ref.constValue as? ObjClass ?: (ref.constValue as? Obj)?.objClass
is LocalVarRef -> nameObjClass[ref.name]
is LocalSlotRef -> nameObjClass[ref.name]
is ListLiteralRef -> ObjList.type
is MapLiteralRef -> ObjMap.type
is RangeRef -> ObjRange.type
is CastRef -> resolveTypeRefClass(ref.castTypeRef())
else -> null
}
private fun resolveTypeRefClass(ref: ObjRef): ObjClass? = when (ref) {
is ConstRef -> ref.constValue as? ObjClass
is LocalSlotRef -> resolveTypeDeclObjClass(TypeDecl.Simple(ref.name, false)) ?: nameObjClass[ref.name]
is LocalVarRef -> resolveTypeDeclObjClass(TypeDecl.Simple(ref.name, false)) ?: nameObjClass[ref.name]
else -> null
}
private fun typeParamBoundSatisfied(argClass: ObjClass, bound: TypeDecl): Boolean = when (bound) {
is TypeDecl.Union -> bound.options.any { typeParamBoundSatisfied(argClass, it) }
is TypeDecl.Intersection -> bound.options.all { typeParamBoundSatisfied(argClass, it) }
is TypeDecl.Simple, is TypeDecl.Generic -> {
val boundClass = resolveTypeDeclObjClass(bound) ?: return false
argClass == boundClass || argClass.allParentsSet.contains(boundClass)
}
else -> true
}
private fun checkGenericBoundsAtCall(
name: String,
args: List<ParsedArgument>,
pos: Pos
) {
val decl = lookupGenericFunctionDecl(name) ?: return
val inferred = mutableMapOf<String, ObjClass>()
val limit = minOf(args.size, decl.params.size)
for (i in 0 until limit) {
val paramType = decl.params[i].type
val argRef = (args[i].value as? ExpressionStatement)?.ref ?: continue
val argClass = inferObjClassFromRef(argRef) ?: continue
if (paramType is TypeDecl.TypeVar) {
inferred[paramType.name] = argClass
}
}
for (tp in decl.typeParams) {
val argClass = inferred[tp.name] ?: continue
val bound = tp.bound ?: continue
if (!typeParamBoundSatisfied(argClass, bound)) {
throw ScriptError(pos, "type argument ${argClass.className} does not satisfy bound ${typeDeclName(bound)}")
}
}
}
private fun bindTypeParamsAtRuntime(
context: Scope,
argsDeclaration: ArgsDeclaration,
typeParams: List<TypeDecl.TypeParam>
) {
if (typeParams.isEmpty()) return
val inferred = mutableMapOf<String, ObjClass>()
for (param in argsDeclaration.params) {
val paramType = param.type
if (paramType is TypeDecl.TypeVar) {
val rec = context.getLocalRecordDirect(param.name) ?: continue
val value = rec.value
if (value is Obj) inferred[paramType.name] = value.objClass
}
}
for (tp in typeParams) {
val cls = inferred[tp.name]
?: tp.defaultType?.let { resolveTypeDeclObjClass(it) }
?: Obj.rootObjectType
context.addConst(tp.name, cls)
val bound = tp.bound ?: continue
if (!typeParamBoundSatisfied(cls, bound)) {
context.raiseError("type argument ${cls.className} does not satisfy bound ${typeDeclName(bound)}")
}
}
}
private fun resolveLocalTypeRef(name: String, pos: Pos): ObjRef? { private fun resolveLocalTypeRef(name: String, pos: Pos): ObjRef? {
val slotLoc = lookupSlotLocation(name, includeModule = true) ?: return null val slotLoc = lookupSlotLocation(name, includeModule = true) ?: return null
captureLocalRef(name, slotLoc, pos)?.let { return it } captureLocalRef(name, slotLoc, pos)?.let { return it }
@ -2828,6 +2971,7 @@ class Compiler(
implicitThisTypeName implicitThisTypeName
) )
} else { } else {
checkGenericBoundsAtCall(left.name, args, left.pos())
CallRef(left, args, detectedBlockArgument, isOptional) CallRef(left, args, detectedBlockArgument, isOptional)
} }
} }
@ -2848,6 +2992,7 @@ class Compiler(
implicitThisTypeName implicitThisTypeName
) )
} else { } else {
checkGenericBoundsAtCall(left.name, args, left.pos())
CallRef(left, args, detectedBlockArgument, isOptional) CallRef(left, args, detectedBlockArgument, isOptional)
} }
} }
@ -3749,11 +3894,18 @@ class Compiler(
val classCtx = codeContexts.lastOrNull() as? CodeContext.ClassBody val classCtx = codeContexts.lastOrNull() as? CodeContext.ClassBody
val typeParamDecls = parseTypeParamList() val typeParamDecls = parseTypeParamList()
classCtx?.typeParamDecls = typeParamDecls classCtx?.typeParamDecls = typeParamDecls
classCtx?.typeParams = typeParamDecls.map { it.name }.toSet() val classTypeParams = typeParamDecls.map { it.name }.toSet()
val constructorArgsDeclaration = classCtx?.typeParams = classTypeParams
pendingTypeParamStack.add(classTypeParams)
val constructorArgsDeclaration: ArgsDeclaration?
try {
constructorArgsDeclaration =
if (cc.skipTokenOfType(Token.Type.LPAREN, isOptional = true)) if (cc.skipTokenOfType(Token.Type.LPAREN, isOptional = true))
parseArgsDeclaration(isClassDeclaration = true) parseArgsDeclaration(isClassDeclaration = true)
else ArgsDeclaration(emptyList(), Token.Type.RPAREN) else ArgsDeclaration(emptyList(), Token.Type.RPAREN)
} finally {
pendingTypeParamStack.removeLast()
}
if (constructorArgsDeclaration != null && constructorArgsDeclaration.endTokenType != Token.Type.RPAREN) if (constructorArgsDeclaration != null && constructorArgsDeclaration.endTokenType != Token.Type.RPAREN)
throw ScriptError( throw ScriptError(
@ -3777,6 +3929,8 @@ class Compiler(
data class BaseSpec(val name: String, val args: List<ParsedArgument>?) data class BaseSpec(val name: String, val args: List<ParsedArgument>?)
val baseSpecs = mutableListOf<BaseSpec>() val baseSpecs = mutableListOf<BaseSpec>()
pendingTypeParamStack.add(classTypeParams)
try {
if (cc.skipTokenOfType(Token.Type.COLON, isOptional = true)) { if (cc.skipTokenOfType(Token.Type.COLON, isOptional = true)) {
do { do {
val (baseDecl, _) = parseSimpleTypeExpressionWithMini() val (baseDecl, _) = parseSimpleTypeExpressionWithMini()
@ -3794,6 +3948,9 @@ class Compiler(
baseSpecs += BaseSpec(baseName, argsList) baseSpecs += BaseSpec(baseName, argsList)
} while (cc.skipTokenOfType(Token.Type.COMMA, isOptional = true)) } while (cc.skipTokenOfType(Token.Type.COMMA, isOptional = true))
} }
} finally {
pendingTypeParamStack.removeLast()
}
cc.skipTokenOfType(Token.Type.NEWLINE, isOptional = true) cc.skipTokenOfType(Token.Type.NEWLINE, isOptional = true)
@ -4414,17 +4571,27 @@ class Compiler(
val typeParamDecls = parseTypeParamList() val typeParamDecls = parseTypeParamList()
val typeParams = typeParamDecls.map { it.name }.toSet() val typeParams = typeParamDecls.map { it.name }.toSet()
pendingTypeParamStack.add(typeParams)
val argsDeclaration: ArgsDeclaration = val argsDeclaration: ArgsDeclaration
val returnTypeMini: MiniTypeRef?
try {
argsDeclaration =
if (cc.peekNextNonWhitespace().type == Token.Type.LPAREN) { if (cc.peekNextNonWhitespace().type == Token.Type.LPAREN) {
cc.nextNonWhitespace() // consume ( cc.nextNonWhitespace() // consume (
parseArgsDeclaration() ?: ArgsDeclaration(emptyList(), Token.Type.RPAREN) parseArgsDeclaration() ?: ArgsDeclaration(emptyList(), Token.Type.RPAREN)
} else ArgsDeclaration(emptyList(), Token.Type.RPAREN) } else ArgsDeclaration(emptyList(), Token.Type.RPAREN)
if (typeParamDecls.isNotEmpty() && declKind != SymbolKind.MEMBER) {
currentGenericFunctionDecls()[name] = GenericFunctionDecl(typeParamDecls, argsDeclaration.params, nameStartPos)
}
// Optional return type // Optional return type
val returnTypeMini: MiniTypeRef? = if (cc.peekNextNonWhitespace().type == Token.Type.COLON) { returnTypeMini = if (cc.peekNextNonWhitespace().type == Token.Type.COLON) {
parseTypeDeclarationWithMini().second parseTypeDeclarationWithMini().second
} else null } else null
} finally {
pendingTypeParamStack.removeLast()
}
var isDelegated = false var isDelegated = false
var delegateExpression: Statement? = null var delegateExpression: Statement? = null
@ -4485,8 +4652,9 @@ class Compiler(
outerLabel?.let { cc.labels.add(it) } outerLabel?.let { cc.labels.add(it) }
val paramNamesList = argsDeclaration.params.map { it.name } val paramNamesList = argsDeclaration.params.map { it.name }
val typeParamNames = typeParamDecls.map { it.name }
val paramNames: Set<String> = paramNamesList.toSet() val paramNames: Set<String> = paramNamesList.toSet()
val paramSlotPlan = buildParamSlotPlan(paramNamesList) val paramSlotPlan = buildParamSlotPlan(paramNamesList + typeParamNames)
val capturePlan = CapturePlan(paramSlotPlan) val capturePlan = CapturePlan(paramSlotPlan)
val rangeParamNames = argsDeclaration.params val rangeParamNames = argsDeclaration.params
.filter { isRangeType(it.type) } .filter { isRangeType(it.type) }
@ -4588,6 +4756,7 @@ class Compiler(
// load params from caller context // load params from caller context
argsDeclaration.assignToContext(context, callerContext.args, defaultAccessType = AccessType.Val) argsDeclaration.assignToContext(context, callerContext.args, defaultAccessType = AccessType.Val)
bindTypeParamsAtRuntime(context, argsDeclaration, typeParamDecls)
if (extTypeName != null) { if (extTypeName != null) {
context.thisObj = callerContext.thisObj context.thisObj = callerContext.thisObj
} }
@ -4897,6 +5066,8 @@ class Compiler(
is TypeDecl.Generic -> type.name is TypeDecl.Generic -> type.name
is TypeDecl.Function -> "Callable" is TypeDecl.Function -> "Callable"
is TypeDecl.TypeVar -> return null is TypeDecl.TypeVar -> return null
is TypeDecl.Union -> return null
is TypeDecl.Intersection -> return null
else -> return null else -> return null
} }
val name = rawName.substringAfterLast('.') val name = rawName.substringAfterLast('.')

View File

@ -31,6 +31,8 @@ sealed class TypeDecl(val isNullable:Boolean = false) {
val nullable: Boolean = false val nullable: Boolean = false
) : TypeDecl(nullable) ) : TypeDecl(nullable)
data class TypeVar(val name: String, val nullable: Boolean = false) : TypeDecl(nullable) data class TypeVar(val name: String, val nullable: Boolean = false) : TypeDecl(nullable)
data class Union(val options: List<TypeDecl>, val nullable: Boolean = false) : TypeDecl(nullable)
data class Intersection(val options: List<TypeDecl>, val nullable: Boolean = false) : TypeDecl(nullable)
data class TypeParam( data class TypeParam(
val name: String, val name: String,
val variance: Variance = Variance.Invariant, val variance: Variance = Variance.Invariant,

View File

@ -1025,6 +1025,8 @@ object DocLookupUtils {
is MiniGenericType -> simpleClassNameOf(t.base) is MiniGenericType -> simpleClassNameOf(t.base)
is MiniFunctionType -> null is MiniFunctionType -> null
is MiniTypeVar -> null is MiniTypeVar -> null
is MiniTypeUnion -> null
is MiniTypeIntersection -> null
} }
fun typeOf(t: MiniTypeRef?): String = when (t) { fun typeOf(t: MiniTypeRef?): String = when (t) {
@ -1035,6 +1037,8 @@ object DocLookupUtils {
r + "(" + t.params.joinToString(", ") { typeOf(it) } + ") -> " + typeOf(t.returnType) + (if (t.nullable) "?" else "") r + "(" + t.params.joinToString(", ") { typeOf(it) } + ") -> " + typeOf(t.returnType) + (if (t.nullable) "?" else "")
} }
is MiniTypeVar -> t.name + (if (t.nullable) "?" else "") is MiniTypeVar -> t.name + (if (t.nullable) "?" else "")
is MiniTypeUnion -> t.options.joinToString(" | ") { typeOf(it) } + (if (t.nullable) "?" else "")
is MiniTypeIntersection -> t.options.joinToString(" & ") { typeOf(it) } + (if (t.nullable) "?" else "")
null -> "" null -> ""
} }

View File

@ -150,6 +150,18 @@ data class MiniTypeVar(
val nullable: Boolean val nullable: Boolean
) : MiniTypeRef ) : MiniTypeRef
data class MiniTypeUnion(
override val range: MiniRange,
val options: List<MiniTypeRef>,
val nullable: Boolean
) : MiniTypeRef
data class MiniTypeIntersection(
override val range: MiniRange,
val options: List<MiniTypeRef>,
val nullable: Boolean
) : MiniTypeRef
// Script and declarations (lean subset; can be extended later) // Script and declarations (lean subset; can be extended later)
sealed interface MiniNamedDecl : MiniNode { sealed interface MiniNamedDecl : MiniNode {
val name: String val name: String

View File

@ -5273,6 +5273,40 @@ class ScriptTest {
assertEquals(ObjFalse, scope.eval("isInt(\"42\")")) assertEquals(ObjFalse, scope.eval("isInt(\"42\")"))
} }
@Test
fun testGenericBoundsAndReifiedTypeParams() = runTest {
val resInt = eval(
"""
fun square<T: Int | Real>(x: T) = x * x
square(2)
""".trimIndent()
)
assertEquals(4L, (resInt as ObjInt).value)
val resReal = eval(
"""
fun square<T: Int | Real>(x: T) = x * x
square(1.5)
""".trimIndent()
)
assertEquals(2.25, (resReal as ObjReal).value, 0.00001)
assertFailsWith<ScriptError> {
eval(
"""
fun square<T: Int | Real>(x: T) = x * x
square("x")
""".trimIndent()
)
}
val reified = eval(
"""
fun sameType<T>(x: T, y: Object) = y is T
sameType(1, "a")
""".trimIndent()
)
assertEquals(false, (reified as ObjBool).value)
}
@Test @Test
fun testFilterBug() = runTest { fun testFilterBug() = runTest {
eval( eval(