Add generic bounds checks and union/intersection types
This commit is contained in:
parent
c5bf4e5039
commit
54c6fca0e8
@ -64,12 +64,38 @@ class Compiler(
|
||||
)
|
||||
private val slotPlanStack = mutableListOf<SlotPlan>()
|
||||
private var nextScopeId = 0
|
||||
private val genericFunctionDeclsStack = mutableListOf<MutableMap<String, GenericFunctionDecl>>(mutableMapOf())
|
||||
|
||||
// Track declared local variables count per function for precise capacity hints
|
||||
private val localDeclCountStack = mutableListOf<Int>()
|
||||
private val currentLocalDeclCount: Int
|
||||
get() = localDeclCountStack.lastOrNull() ?: 0
|
||||
|
||||
private data class GenericFunctionDecl(
|
||||
val typeParams: List<TypeDecl.TypeParam>,
|
||||
val params: List<ArgsDeclaration.Item>,
|
||||
val pos: Pos
|
||||
)
|
||||
|
||||
private fun pushGenericFunctionScope() {
|
||||
genericFunctionDeclsStack.add(mutableMapOf())
|
||||
}
|
||||
|
||||
private fun popGenericFunctionScope() {
|
||||
genericFunctionDeclsStack.removeLast()
|
||||
}
|
||||
|
||||
private fun currentGenericFunctionDecls(): MutableMap<String, GenericFunctionDecl> {
|
||||
return genericFunctionDeclsStack.last()
|
||||
}
|
||||
|
||||
private fun lookupGenericFunctionDecl(name: String): GenericFunctionDecl? {
|
||||
for (i in genericFunctionDeclsStack.indices.reversed()) {
|
||||
genericFunctionDeclsStack[i][name]?.let { return it }
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
private inline fun <T> withLocalNames(names: Set<String>, block: () -> T): T {
|
||||
localNamesStack.add(names.toMutableSet())
|
||||
return try {
|
||||
@ -440,6 +466,7 @@ class Compiler(
|
||||
|
||||
private fun currentTypeParams(): Set<String> {
|
||||
val result = mutableSetOf<String>()
|
||||
pendingTypeParamStack.lastOrNull()?.let { result.addAll(it) }
|
||||
for (ctx in codeContexts.asReversed()) {
|
||||
when (ctx) {
|
||||
is CodeContext.Function -> result.addAll(ctx.typeParams)
|
||||
@ -450,6 +477,8 @@ class Compiler(
|
||||
return result
|
||||
}
|
||||
|
||||
private val pendingTypeParamStack = mutableListOf<Set<String>>()
|
||||
|
||||
private fun parseTypeParamList(): List<TypeDecl.TypeParam> {
|
||||
if (cc.peekNextNonWhitespace().type != Token.Type.LT) return emptyList()
|
||||
val typeParams = mutableListOf<TypeDecl.TypeParam>()
|
||||
@ -774,6 +803,7 @@ class Compiler(
|
||||
|
||||
private suspend fun <T> inCodeContext(context: CodeContext, f: suspend () -> T): T {
|
||||
codeContexts.add(context)
|
||||
pushGenericFunctionScope()
|
||||
try {
|
||||
val res = f()
|
||||
if (context is CodeContext.ClassBody) {
|
||||
@ -784,6 +814,7 @@ class Compiler(
|
||||
}
|
||||
return res
|
||||
} finally {
|
||||
popGenericFunctionScope()
|
||||
codeContexts.removeLast()
|
||||
}
|
||||
}
|
||||
@ -2401,6 +2432,36 @@ class Compiler(
|
||||
}
|
||||
|
||||
private fun parseTypeExpressionWithMini(): Pair<TypeDecl, MiniTypeRef> {
|
||||
return parseTypeUnionWithMini()
|
||||
}
|
||||
|
||||
private fun parseTypeUnionWithMini(): Pair<TypeDecl, MiniTypeRef> {
|
||||
var left = parseTypeIntersectionWithMini()
|
||||
val options = mutableListOf(left)
|
||||
while (cc.skipTokenOfType(Token.Type.BITOR, isOptional = true)) {
|
||||
options += parseTypeIntersectionWithMini()
|
||||
}
|
||||
if (options.size == 1) return left
|
||||
val rangeStart = options.first().second.range.start
|
||||
val rangeEnd = cc.currentPos()
|
||||
val mini = MiniTypeUnion(MiniRange(rangeStart, rangeEnd), options.map { it.second }, nullable = false)
|
||||
return TypeDecl.Union(options.map { it.first }, nullable = false) to mini
|
||||
}
|
||||
|
||||
private fun parseTypeIntersectionWithMini(): Pair<TypeDecl, MiniTypeRef> {
|
||||
var left = parseTypePrimaryWithMini()
|
||||
val options = mutableListOf(left)
|
||||
while (cc.skipTokenOfType(Token.Type.BITAND, isOptional = true)) {
|
||||
options += parseTypePrimaryWithMini()
|
||||
}
|
||||
if (options.size == 1) return left
|
||||
val rangeStart = options.first().second.range.start
|
||||
val rangeEnd = cc.currentPos()
|
||||
val mini = MiniTypeIntersection(MiniRange(rangeStart, rangeEnd), options.map { it.second }, nullable = false)
|
||||
return TypeDecl.Intersection(options.map { it.first }, nullable = false) to mini
|
||||
}
|
||||
|
||||
private fun parseTypePrimaryWithMini(): Pair<TypeDecl, MiniTypeRef> {
|
||||
parseFunctionTypeWithMini()?.let { return it }
|
||||
return parseSimpleTypeExpressionWithMini()
|
||||
}
|
||||
@ -2595,8 +2656,8 @@ class Compiler(
|
||||
private fun typeDeclToTypeRef(typeDecl: TypeDecl, pos: Pos): ObjRef {
|
||||
return when (typeDecl) {
|
||||
TypeDecl.TypeAny,
|
||||
TypeDecl.TypeNullableAny,
|
||||
is TypeDecl.TypeVar -> ConstRef(Obj.rootObjectType.asReadonly)
|
||||
TypeDecl.TypeNullableAny -> ConstRef(Obj.rootObjectType.asReadonly)
|
||||
is TypeDecl.TypeVar -> resolveLocalTypeRef(typeDecl.name, pos) ?: ConstRef(Obj.rootObjectType.asReadonly)
|
||||
else -> {
|
||||
val cls = resolveTypeDeclObjClass(typeDecl)
|
||||
if (cls != null) return ConstRef(cls.asReadonly)
|
||||
@ -2612,10 +2673,92 @@ class Compiler(
|
||||
is TypeDecl.Generic -> typeDecl.name
|
||||
is TypeDecl.Function -> "Callable"
|
||||
is TypeDecl.TypeVar -> typeDecl.name
|
||||
is TypeDecl.Union -> typeDecl.options.joinToString(" | ") { typeDeclName(it) }
|
||||
is TypeDecl.Intersection -> typeDecl.options.joinToString(" & ") { typeDeclName(it) }
|
||||
TypeDecl.TypeAny -> "Object"
|
||||
TypeDecl.TypeNullableAny -> "Object?"
|
||||
}
|
||||
|
||||
private fun inferObjClassFromRef(ref: ObjRef): ObjClass? = when (ref) {
|
||||
is ConstRef -> ref.constValue as? ObjClass ?: (ref.constValue as? Obj)?.objClass
|
||||
is LocalVarRef -> nameObjClass[ref.name]
|
||||
is LocalSlotRef -> nameObjClass[ref.name]
|
||||
is ListLiteralRef -> ObjList.type
|
||||
is MapLiteralRef -> ObjMap.type
|
||||
is RangeRef -> ObjRange.type
|
||||
is CastRef -> resolveTypeRefClass(ref.castTypeRef())
|
||||
else -> null
|
||||
}
|
||||
|
||||
private fun resolveTypeRefClass(ref: ObjRef): ObjClass? = when (ref) {
|
||||
is ConstRef -> ref.constValue as? ObjClass
|
||||
is LocalSlotRef -> resolveTypeDeclObjClass(TypeDecl.Simple(ref.name, false)) ?: nameObjClass[ref.name]
|
||||
is LocalVarRef -> resolveTypeDeclObjClass(TypeDecl.Simple(ref.name, false)) ?: nameObjClass[ref.name]
|
||||
else -> null
|
||||
}
|
||||
|
||||
private fun typeParamBoundSatisfied(argClass: ObjClass, bound: TypeDecl): Boolean = when (bound) {
|
||||
is TypeDecl.Union -> bound.options.any { typeParamBoundSatisfied(argClass, it) }
|
||||
is TypeDecl.Intersection -> bound.options.all { typeParamBoundSatisfied(argClass, it) }
|
||||
is TypeDecl.Simple, is TypeDecl.Generic -> {
|
||||
val boundClass = resolveTypeDeclObjClass(bound) ?: return false
|
||||
argClass == boundClass || argClass.allParentsSet.contains(boundClass)
|
||||
}
|
||||
else -> true
|
||||
}
|
||||
|
||||
private fun checkGenericBoundsAtCall(
|
||||
name: String,
|
||||
args: List<ParsedArgument>,
|
||||
pos: Pos
|
||||
) {
|
||||
val decl = lookupGenericFunctionDecl(name) ?: return
|
||||
val inferred = mutableMapOf<String, ObjClass>()
|
||||
val limit = minOf(args.size, decl.params.size)
|
||||
for (i in 0 until limit) {
|
||||
val paramType = decl.params[i].type
|
||||
val argRef = (args[i].value as? ExpressionStatement)?.ref ?: continue
|
||||
val argClass = inferObjClassFromRef(argRef) ?: continue
|
||||
if (paramType is TypeDecl.TypeVar) {
|
||||
inferred[paramType.name] = argClass
|
||||
}
|
||||
}
|
||||
for (tp in decl.typeParams) {
|
||||
val argClass = inferred[tp.name] ?: continue
|
||||
val bound = tp.bound ?: continue
|
||||
if (!typeParamBoundSatisfied(argClass, bound)) {
|
||||
throw ScriptError(pos, "type argument ${argClass.className} does not satisfy bound ${typeDeclName(bound)}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun bindTypeParamsAtRuntime(
|
||||
context: Scope,
|
||||
argsDeclaration: ArgsDeclaration,
|
||||
typeParams: List<TypeDecl.TypeParam>
|
||||
) {
|
||||
if (typeParams.isEmpty()) return
|
||||
val inferred = mutableMapOf<String, ObjClass>()
|
||||
for (param in argsDeclaration.params) {
|
||||
val paramType = param.type
|
||||
if (paramType is TypeDecl.TypeVar) {
|
||||
val rec = context.getLocalRecordDirect(param.name) ?: continue
|
||||
val value = rec.value
|
||||
if (value is Obj) inferred[paramType.name] = value.objClass
|
||||
}
|
||||
}
|
||||
for (tp in typeParams) {
|
||||
val cls = inferred[tp.name]
|
||||
?: tp.defaultType?.let { resolveTypeDeclObjClass(it) }
|
||||
?: Obj.rootObjectType
|
||||
context.addConst(tp.name, cls)
|
||||
val bound = tp.bound ?: continue
|
||||
if (!typeParamBoundSatisfied(cls, bound)) {
|
||||
context.raiseError("type argument ${cls.className} does not satisfy bound ${typeDeclName(bound)}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun resolveLocalTypeRef(name: String, pos: Pos): ObjRef? {
|
||||
val slotLoc = lookupSlotLocation(name, includeModule = true) ?: return null
|
||||
captureLocalRef(name, slotLoc, pos)?.let { return it }
|
||||
@ -2828,6 +2971,7 @@ class Compiler(
|
||||
implicitThisTypeName
|
||||
)
|
||||
} else {
|
||||
checkGenericBoundsAtCall(left.name, args, left.pos())
|
||||
CallRef(left, args, detectedBlockArgument, isOptional)
|
||||
}
|
||||
}
|
||||
@ -2848,6 +2992,7 @@ class Compiler(
|
||||
implicitThisTypeName
|
||||
)
|
||||
} else {
|
||||
checkGenericBoundsAtCall(left.name, args, left.pos())
|
||||
CallRef(left, args, detectedBlockArgument, isOptional)
|
||||
}
|
||||
}
|
||||
@ -3749,11 +3894,18 @@ class Compiler(
|
||||
val classCtx = codeContexts.lastOrNull() as? CodeContext.ClassBody
|
||||
val typeParamDecls = parseTypeParamList()
|
||||
classCtx?.typeParamDecls = typeParamDecls
|
||||
classCtx?.typeParams = typeParamDecls.map { it.name }.toSet()
|
||||
val constructorArgsDeclaration =
|
||||
val classTypeParams = typeParamDecls.map { it.name }.toSet()
|
||||
classCtx?.typeParams = classTypeParams
|
||||
pendingTypeParamStack.add(classTypeParams)
|
||||
val constructorArgsDeclaration: ArgsDeclaration?
|
||||
try {
|
||||
constructorArgsDeclaration =
|
||||
if (cc.skipTokenOfType(Token.Type.LPAREN, isOptional = true))
|
||||
parseArgsDeclaration(isClassDeclaration = true)
|
||||
else ArgsDeclaration(emptyList(), Token.Type.RPAREN)
|
||||
} finally {
|
||||
pendingTypeParamStack.removeLast()
|
||||
}
|
||||
|
||||
if (constructorArgsDeclaration != null && constructorArgsDeclaration.endTokenType != Token.Type.RPAREN)
|
||||
throw ScriptError(
|
||||
@ -3777,6 +3929,8 @@ class Compiler(
|
||||
data class BaseSpec(val name: String, val args: List<ParsedArgument>?)
|
||||
|
||||
val baseSpecs = mutableListOf<BaseSpec>()
|
||||
pendingTypeParamStack.add(classTypeParams)
|
||||
try {
|
||||
if (cc.skipTokenOfType(Token.Type.COLON, isOptional = true)) {
|
||||
do {
|
||||
val (baseDecl, _) = parseSimpleTypeExpressionWithMini()
|
||||
@ -3794,6 +3948,9 @@ class Compiler(
|
||||
baseSpecs += BaseSpec(baseName, argsList)
|
||||
} while (cc.skipTokenOfType(Token.Type.COMMA, isOptional = true))
|
||||
}
|
||||
} finally {
|
||||
pendingTypeParamStack.removeLast()
|
||||
}
|
||||
|
||||
cc.skipTokenOfType(Token.Type.NEWLINE, isOptional = true)
|
||||
|
||||
@ -4414,17 +4571,27 @@ class Compiler(
|
||||
|
||||
val typeParamDecls = parseTypeParamList()
|
||||
val typeParams = typeParamDecls.map { it.name }.toSet()
|
||||
|
||||
val argsDeclaration: ArgsDeclaration =
|
||||
pendingTypeParamStack.add(typeParams)
|
||||
val argsDeclaration: ArgsDeclaration
|
||||
val returnTypeMini: MiniTypeRef?
|
||||
try {
|
||||
argsDeclaration =
|
||||
if (cc.peekNextNonWhitespace().type == Token.Type.LPAREN) {
|
||||
cc.nextNonWhitespace() // consume (
|
||||
parseArgsDeclaration() ?: ArgsDeclaration(emptyList(), Token.Type.RPAREN)
|
||||
} else ArgsDeclaration(emptyList(), Token.Type.RPAREN)
|
||||
|
||||
if (typeParamDecls.isNotEmpty() && declKind != SymbolKind.MEMBER) {
|
||||
currentGenericFunctionDecls()[name] = GenericFunctionDecl(typeParamDecls, argsDeclaration.params, nameStartPos)
|
||||
}
|
||||
|
||||
// Optional return type
|
||||
val returnTypeMini: MiniTypeRef? = if (cc.peekNextNonWhitespace().type == Token.Type.COLON) {
|
||||
returnTypeMini = if (cc.peekNextNonWhitespace().type == Token.Type.COLON) {
|
||||
parseTypeDeclarationWithMini().second
|
||||
} else null
|
||||
} finally {
|
||||
pendingTypeParamStack.removeLast()
|
||||
}
|
||||
|
||||
var isDelegated = false
|
||||
var delegateExpression: Statement? = null
|
||||
@ -4485,8 +4652,9 @@ class Compiler(
|
||||
outerLabel?.let { cc.labels.add(it) }
|
||||
|
||||
val paramNamesList = argsDeclaration.params.map { it.name }
|
||||
val typeParamNames = typeParamDecls.map { it.name }
|
||||
val paramNames: Set<String> = paramNamesList.toSet()
|
||||
val paramSlotPlan = buildParamSlotPlan(paramNamesList)
|
||||
val paramSlotPlan = buildParamSlotPlan(paramNamesList + typeParamNames)
|
||||
val capturePlan = CapturePlan(paramSlotPlan)
|
||||
val rangeParamNames = argsDeclaration.params
|
||||
.filter { isRangeType(it.type) }
|
||||
@ -4588,6 +4756,7 @@ class Compiler(
|
||||
|
||||
// load params from caller context
|
||||
argsDeclaration.assignToContext(context, callerContext.args, defaultAccessType = AccessType.Val)
|
||||
bindTypeParamsAtRuntime(context, argsDeclaration, typeParamDecls)
|
||||
if (extTypeName != null) {
|
||||
context.thisObj = callerContext.thisObj
|
||||
}
|
||||
@ -4897,6 +5066,8 @@ class Compiler(
|
||||
is TypeDecl.Generic -> type.name
|
||||
is TypeDecl.Function -> "Callable"
|
||||
is TypeDecl.TypeVar -> return null
|
||||
is TypeDecl.Union -> return null
|
||||
is TypeDecl.Intersection -> return null
|
||||
else -> return null
|
||||
}
|
||||
val name = rawName.substringAfterLast('.')
|
||||
|
||||
@ -31,6 +31,8 @@ sealed class TypeDecl(val isNullable:Boolean = false) {
|
||||
val nullable: Boolean = false
|
||||
) : TypeDecl(nullable)
|
||||
data class TypeVar(val name: String, val nullable: Boolean = false) : TypeDecl(nullable)
|
||||
data class Union(val options: List<TypeDecl>, val nullable: Boolean = false) : TypeDecl(nullable)
|
||||
data class Intersection(val options: List<TypeDecl>, val nullable: Boolean = false) : TypeDecl(nullable)
|
||||
data class TypeParam(
|
||||
val name: String,
|
||||
val variance: Variance = Variance.Invariant,
|
||||
|
||||
@ -1025,6 +1025,8 @@ object DocLookupUtils {
|
||||
is MiniGenericType -> simpleClassNameOf(t.base)
|
||||
is MiniFunctionType -> null
|
||||
is MiniTypeVar -> null
|
||||
is MiniTypeUnion -> null
|
||||
is MiniTypeIntersection -> null
|
||||
}
|
||||
|
||||
fun typeOf(t: MiniTypeRef?): String = when (t) {
|
||||
@ -1035,6 +1037,8 @@ object DocLookupUtils {
|
||||
r + "(" + t.params.joinToString(", ") { typeOf(it) } + ") -> " + typeOf(t.returnType) + (if (t.nullable) "?" else "")
|
||||
}
|
||||
is MiniTypeVar -> t.name + (if (t.nullable) "?" else "")
|
||||
is MiniTypeUnion -> t.options.joinToString(" | ") { typeOf(it) } + (if (t.nullable) "?" else "")
|
||||
is MiniTypeIntersection -> t.options.joinToString(" & ") { typeOf(it) } + (if (t.nullable) "?" else "")
|
||||
null -> ""
|
||||
}
|
||||
|
||||
|
||||
@ -150,6 +150,18 @@ data class MiniTypeVar(
|
||||
val nullable: Boolean
|
||||
) : MiniTypeRef
|
||||
|
||||
data class MiniTypeUnion(
|
||||
override val range: MiniRange,
|
||||
val options: List<MiniTypeRef>,
|
||||
val nullable: Boolean
|
||||
) : MiniTypeRef
|
||||
|
||||
data class MiniTypeIntersection(
|
||||
override val range: MiniRange,
|
||||
val options: List<MiniTypeRef>,
|
||||
val nullable: Boolean
|
||||
) : MiniTypeRef
|
||||
|
||||
// Script and declarations (lean subset; can be extended later)
|
||||
sealed interface MiniNamedDecl : MiniNode {
|
||||
val name: String
|
||||
|
||||
@ -5273,6 +5273,40 @@ class ScriptTest {
|
||||
assertEquals(ObjFalse, scope.eval("isInt(\"42\")"))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testGenericBoundsAndReifiedTypeParams() = runTest {
|
||||
val resInt = eval(
|
||||
"""
|
||||
fun square<T: Int | Real>(x: T) = x * x
|
||||
square(2)
|
||||
""".trimIndent()
|
||||
)
|
||||
assertEquals(4L, (resInt as ObjInt).value)
|
||||
val resReal = eval(
|
||||
"""
|
||||
fun square<T: Int | Real>(x: T) = x * x
|
||||
square(1.5)
|
||||
""".trimIndent()
|
||||
)
|
||||
assertEquals(2.25, (resReal as ObjReal).value, 0.00001)
|
||||
assertFailsWith<ScriptError> {
|
||||
eval(
|
||||
"""
|
||||
fun square<T: Int | Real>(x: T) = x * x
|
||||
square("x")
|
||||
""".trimIndent()
|
||||
)
|
||||
}
|
||||
|
||||
val reified = eval(
|
||||
"""
|
||||
fun sameType<T>(x: T, y: Object) = y is T
|
||||
sameType(1, "a")
|
||||
""".trimIndent()
|
||||
)
|
||||
assertEquals(false, (reified as ObjBool).value)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun testFilterBug() = runTest {
|
||||
eval(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user