diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt index dda200f..f5e06c4 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt @@ -64,12 +64,38 @@ class Compiler( ) private val slotPlanStack = mutableListOf() private var nextScopeId = 0 + private val genericFunctionDeclsStack = mutableListOf>(mutableMapOf()) // Track declared local variables count per function for precise capacity hints private val localDeclCountStack = mutableListOf() private val currentLocalDeclCount: Int get() = localDeclCountStack.lastOrNull() ?: 0 + private data class GenericFunctionDecl( + val typeParams: List, + val params: List, + val pos: Pos + ) + + private fun pushGenericFunctionScope() { + genericFunctionDeclsStack.add(mutableMapOf()) + } + + private fun popGenericFunctionScope() { + genericFunctionDeclsStack.removeLast() + } + + private fun currentGenericFunctionDecls(): MutableMap { + return genericFunctionDeclsStack.last() + } + + private fun lookupGenericFunctionDecl(name: String): GenericFunctionDecl? { + for (i in genericFunctionDeclsStack.indices.reversed()) { + genericFunctionDeclsStack[i][name]?.let { return it } + } + return null + } + private inline fun withLocalNames(names: Set, block: () -> T): T { localNamesStack.add(names.toMutableSet()) return try { @@ -440,6 +466,7 @@ class Compiler( private fun currentTypeParams(): Set { val result = mutableSetOf() + pendingTypeParamStack.lastOrNull()?.let { result.addAll(it) } for (ctx in codeContexts.asReversed()) { when (ctx) { is CodeContext.Function -> result.addAll(ctx.typeParams) @@ -450,6 +477,8 @@ class Compiler( return result } + private val pendingTypeParamStack = mutableListOf>() + private fun parseTypeParamList(): List { if (cc.peekNextNonWhitespace().type != Token.Type.LT) return emptyList() val typeParams = mutableListOf() @@ -774,6 +803,7 @@ class Compiler( private suspend fun inCodeContext(context: CodeContext, f: suspend () -> T): T { codeContexts.add(context) + pushGenericFunctionScope() try { val res = f() if (context is CodeContext.ClassBody) { @@ -784,6 +814,7 @@ class Compiler( } return res } finally { + popGenericFunctionScope() codeContexts.removeLast() } } @@ -2401,6 +2432,36 @@ class Compiler( } private fun parseTypeExpressionWithMini(): Pair { + return parseTypeUnionWithMini() + } + + private fun parseTypeUnionWithMini(): Pair { + var left = parseTypeIntersectionWithMini() + val options = mutableListOf(left) + while (cc.skipTokenOfType(Token.Type.BITOR, isOptional = true)) { + options += parseTypeIntersectionWithMini() + } + if (options.size == 1) return left + val rangeStart = options.first().second.range.start + val rangeEnd = cc.currentPos() + val mini = MiniTypeUnion(MiniRange(rangeStart, rangeEnd), options.map { it.second }, nullable = false) + return TypeDecl.Union(options.map { it.first }, nullable = false) to mini + } + + private fun parseTypeIntersectionWithMini(): Pair { + var left = parseTypePrimaryWithMini() + val options = mutableListOf(left) + while (cc.skipTokenOfType(Token.Type.BITAND, isOptional = true)) { + options += parseTypePrimaryWithMini() + } + if (options.size == 1) return left + val rangeStart = options.first().second.range.start + val rangeEnd = cc.currentPos() + val mini = MiniTypeIntersection(MiniRange(rangeStart, rangeEnd), options.map { it.second }, nullable = false) + return TypeDecl.Intersection(options.map { it.first }, nullable = false) to mini + } + + private fun parseTypePrimaryWithMini(): Pair { parseFunctionTypeWithMini()?.let { return it } return parseSimpleTypeExpressionWithMini() } @@ -2595,8 +2656,8 @@ class Compiler( private fun typeDeclToTypeRef(typeDecl: TypeDecl, pos: Pos): ObjRef { return when (typeDecl) { TypeDecl.TypeAny, - TypeDecl.TypeNullableAny, - is TypeDecl.TypeVar -> ConstRef(Obj.rootObjectType.asReadonly) + TypeDecl.TypeNullableAny -> ConstRef(Obj.rootObjectType.asReadonly) + is TypeDecl.TypeVar -> resolveLocalTypeRef(typeDecl.name, pos) ?: ConstRef(Obj.rootObjectType.asReadonly) else -> { val cls = resolveTypeDeclObjClass(typeDecl) if (cls != null) return ConstRef(cls.asReadonly) @@ -2612,10 +2673,92 @@ class Compiler( is TypeDecl.Generic -> typeDecl.name is TypeDecl.Function -> "Callable" is TypeDecl.TypeVar -> typeDecl.name + is TypeDecl.Union -> typeDecl.options.joinToString(" | ") { typeDeclName(it) } + is TypeDecl.Intersection -> typeDecl.options.joinToString(" & ") { typeDeclName(it) } TypeDecl.TypeAny -> "Object" TypeDecl.TypeNullableAny -> "Object?" } + private fun inferObjClassFromRef(ref: ObjRef): ObjClass? = when (ref) { + is ConstRef -> ref.constValue as? ObjClass ?: (ref.constValue as? Obj)?.objClass + is LocalVarRef -> nameObjClass[ref.name] + is LocalSlotRef -> nameObjClass[ref.name] + is ListLiteralRef -> ObjList.type + is MapLiteralRef -> ObjMap.type + is RangeRef -> ObjRange.type + is CastRef -> resolveTypeRefClass(ref.castTypeRef()) + else -> null + } + + private fun resolveTypeRefClass(ref: ObjRef): ObjClass? = when (ref) { + is ConstRef -> ref.constValue as? ObjClass + is LocalSlotRef -> resolveTypeDeclObjClass(TypeDecl.Simple(ref.name, false)) ?: nameObjClass[ref.name] + is LocalVarRef -> resolveTypeDeclObjClass(TypeDecl.Simple(ref.name, false)) ?: nameObjClass[ref.name] + else -> null + } + + private fun typeParamBoundSatisfied(argClass: ObjClass, bound: TypeDecl): Boolean = when (bound) { + is TypeDecl.Union -> bound.options.any { typeParamBoundSatisfied(argClass, it) } + is TypeDecl.Intersection -> bound.options.all { typeParamBoundSatisfied(argClass, it) } + is TypeDecl.Simple, is TypeDecl.Generic -> { + val boundClass = resolveTypeDeclObjClass(bound) ?: return false + argClass == boundClass || argClass.allParentsSet.contains(boundClass) + } + else -> true + } + + private fun checkGenericBoundsAtCall( + name: String, + args: List, + pos: Pos + ) { + val decl = lookupGenericFunctionDecl(name) ?: return + val inferred = mutableMapOf() + val limit = minOf(args.size, decl.params.size) + for (i in 0 until limit) { + val paramType = decl.params[i].type + val argRef = (args[i].value as? ExpressionStatement)?.ref ?: continue + val argClass = inferObjClassFromRef(argRef) ?: continue + if (paramType is TypeDecl.TypeVar) { + inferred[paramType.name] = argClass + } + } + for (tp in decl.typeParams) { + val argClass = inferred[tp.name] ?: continue + val bound = tp.bound ?: continue + if (!typeParamBoundSatisfied(argClass, bound)) { + throw ScriptError(pos, "type argument ${argClass.className} does not satisfy bound ${typeDeclName(bound)}") + } + } + } + + private fun bindTypeParamsAtRuntime( + context: Scope, + argsDeclaration: ArgsDeclaration, + typeParams: List + ) { + if (typeParams.isEmpty()) return + val inferred = mutableMapOf() + for (param in argsDeclaration.params) { + val paramType = param.type + if (paramType is TypeDecl.TypeVar) { + val rec = context.getLocalRecordDirect(param.name) ?: continue + val value = rec.value + if (value is Obj) inferred[paramType.name] = value.objClass + } + } + for (tp in typeParams) { + val cls = inferred[tp.name] + ?: tp.defaultType?.let { resolveTypeDeclObjClass(it) } + ?: Obj.rootObjectType + context.addConst(tp.name, cls) + val bound = tp.bound ?: continue + if (!typeParamBoundSatisfied(cls, bound)) { + context.raiseError("type argument ${cls.className} does not satisfy bound ${typeDeclName(bound)}") + } + } + } + private fun resolveLocalTypeRef(name: String, pos: Pos): ObjRef? { val slotLoc = lookupSlotLocation(name, includeModule = true) ?: return null captureLocalRef(name, slotLoc, pos)?.let { return it } @@ -2828,6 +2971,7 @@ class Compiler( implicitThisTypeName ) } else { + checkGenericBoundsAtCall(left.name, args, left.pos()) CallRef(left, args, detectedBlockArgument, isOptional) } } @@ -2848,6 +2992,7 @@ class Compiler( implicitThisTypeName ) } else { + checkGenericBoundsAtCall(left.name, args, left.pos()) CallRef(left, args, detectedBlockArgument, isOptional) } } @@ -3749,11 +3894,18 @@ class Compiler( val classCtx = codeContexts.lastOrNull() as? CodeContext.ClassBody val typeParamDecls = parseTypeParamList() classCtx?.typeParamDecls = typeParamDecls - classCtx?.typeParams = typeParamDecls.map { it.name }.toSet() - val constructorArgsDeclaration = - if (cc.skipTokenOfType(Token.Type.LPAREN, isOptional = true)) - parseArgsDeclaration(isClassDeclaration = true) - else ArgsDeclaration(emptyList(), Token.Type.RPAREN) + val classTypeParams = typeParamDecls.map { it.name }.toSet() + classCtx?.typeParams = classTypeParams + pendingTypeParamStack.add(classTypeParams) + val constructorArgsDeclaration: ArgsDeclaration? + try { + constructorArgsDeclaration = + if (cc.skipTokenOfType(Token.Type.LPAREN, isOptional = true)) + parseArgsDeclaration(isClassDeclaration = true) + else ArgsDeclaration(emptyList(), Token.Type.RPAREN) + } finally { + pendingTypeParamStack.removeLast() + } if (constructorArgsDeclaration != null && constructorArgsDeclaration.endTokenType != Token.Type.RPAREN) throw ScriptError( @@ -3777,22 +3929,27 @@ class Compiler( data class BaseSpec(val name: String, val args: List?) val baseSpecs = mutableListOf() - if (cc.skipTokenOfType(Token.Type.COLON, isOptional = true)) { - do { - val (baseDecl, _) = parseSimpleTypeExpressionWithMini() - val baseName = when (baseDecl) { - is TypeDecl.Simple -> baseDecl.name - is TypeDecl.Generic -> baseDecl.name - else -> throw ScriptError(cc.currentPos(), "base class name expected") - } - var argsList: List? = null - // Optional constructor args of the base — parse and ignore for now (MVP), just to consume tokens - if (cc.skipTokenOfType(Token.Type.LPAREN, isOptional = true)) { - // Parse args without consuming any following block so that a class body can follow safely - argsList = parseArgsNoTailBlock() - } - baseSpecs += BaseSpec(baseName, argsList) - } while (cc.skipTokenOfType(Token.Type.COMMA, isOptional = true)) + pendingTypeParamStack.add(classTypeParams) + try { + if (cc.skipTokenOfType(Token.Type.COLON, isOptional = true)) { + do { + val (baseDecl, _) = parseSimpleTypeExpressionWithMini() + val baseName = when (baseDecl) { + is TypeDecl.Simple -> baseDecl.name + is TypeDecl.Generic -> baseDecl.name + else -> throw ScriptError(cc.currentPos(), "base class name expected") + } + var argsList: List? = null + // Optional constructor args of the base — parse and ignore for now (MVP), just to consume tokens + if (cc.skipTokenOfType(Token.Type.LPAREN, isOptional = true)) { + // Parse args without consuming any following block so that a class body can follow safely + argsList = parseArgsNoTailBlock() + } + baseSpecs += BaseSpec(baseName, argsList) + } while (cc.skipTokenOfType(Token.Type.COMMA, isOptional = true)) + } + } finally { + pendingTypeParamStack.removeLast() } cc.skipTokenOfType(Token.Type.NEWLINE, isOptional = true) @@ -4414,17 +4571,27 @@ class Compiler( val typeParamDecls = parseTypeParamList() val typeParams = typeParamDecls.map { it.name }.toSet() + pendingTypeParamStack.add(typeParams) + val argsDeclaration: ArgsDeclaration + val returnTypeMini: MiniTypeRef? + try { + argsDeclaration = + if (cc.peekNextNonWhitespace().type == Token.Type.LPAREN) { + cc.nextNonWhitespace() // consume ( + parseArgsDeclaration() ?: ArgsDeclaration(emptyList(), Token.Type.RPAREN) + } else ArgsDeclaration(emptyList(), Token.Type.RPAREN) - val argsDeclaration: ArgsDeclaration = - if (cc.peekNextNonWhitespace().type == Token.Type.LPAREN) { - cc.nextNonWhitespace() // consume ( - parseArgsDeclaration() ?: ArgsDeclaration(emptyList(), Token.Type.RPAREN) - } else ArgsDeclaration(emptyList(), Token.Type.RPAREN) + if (typeParamDecls.isNotEmpty() && declKind != SymbolKind.MEMBER) { + currentGenericFunctionDecls()[name] = GenericFunctionDecl(typeParamDecls, argsDeclaration.params, nameStartPos) + } - // Optional return type - val returnTypeMini: MiniTypeRef? = if (cc.peekNextNonWhitespace().type == Token.Type.COLON) { - parseTypeDeclarationWithMini().second - } else null + // Optional return type + returnTypeMini = if (cc.peekNextNonWhitespace().type == Token.Type.COLON) { + parseTypeDeclarationWithMini().second + } else null + } finally { + pendingTypeParamStack.removeLast() + } var isDelegated = false var delegateExpression: Statement? = null @@ -4485,8 +4652,9 @@ class Compiler( outerLabel?.let { cc.labels.add(it) } val paramNamesList = argsDeclaration.params.map { it.name } + val typeParamNames = typeParamDecls.map { it.name } val paramNames: Set = paramNamesList.toSet() - val paramSlotPlan = buildParamSlotPlan(paramNamesList) + val paramSlotPlan = buildParamSlotPlan(paramNamesList + typeParamNames) val capturePlan = CapturePlan(paramSlotPlan) val rangeParamNames = argsDeclaration.params .filter { isRangeType(it.type) } @@ -4588,6 +4756,7 @@ class Compiler( // load params from caller context argsDeclaration.assignToContext(context, callerContext.args, defaultAccessType = AccessType.Val) + bindTypeParamsAtRuntime(context, argsDeclaration, typeParamDecls) if (extTypeName != null) { context.thisObj = callerContext.thisObj } @@ -4897,6 +5066,8 @@ class Compiler( is TypeDecl.Generic -> type.name is TypeDecl.Function -> "Callable" is TypeDecl.TypeVar -> return null + is TypeDecl.Union -> return null + is TypeDecl.Intersection -> return null else -> return null } val name = rawName.substringAfterLast('.') diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/TypeDecl.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/TypeDecl.kt index 1294944..6c63233 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/TypeDecl.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/TypeDecl.kt @@ -31,6 +31,8 @@ sealed class TypeDecl(val isNullable:Boolean = false) { val nullable: Boolean = false ) : TypeDecl(nullable) data class TypeVar(val name: String, val nullable: Boolean = false) : TypeDecl(nullable) + data class Union(val options: List, val nullable: Boolean = false) : TypeDecl(nullable) + data class Intersection(val options: List, val nullable: Boolean = false) : TypeDecl(nullable) data class TypeParam( val name: String, val variance: Variance = Variance.Invariant, diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/miniast/DocLookupUtils.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/miniast/DocLookupUtils.kt index 8441aca..1d4c27e 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/miniast/DocLookupUtils.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/miniast/DocLookupUtils.kt @@ -1025,6 +1025,8 @@ object DocLookupUtils { is MiniGenericType -> simpleClassNameOf(t.base) is MiniFunctionType -> null is MiniTypeVar -> null + is MiniTypeUnion -> null + is MiniTypeIntersection -> null } fun typeOf(t: MiniTypeRef?): String = when (t) { @@ -1035,6 +1037,8 @@ object DocLookupUtils { r + "(" + t.params.joinToString(", ") { typeOf(it) } + ") -> " + typeOf(t.returnType) + (if (t.nullable) "?" else "") } is MiniTypeVar -> t.name + (if (t.nullable) "?" else "") + is MiniTypeUnion -> t.options.joinToString(" | ") { typeOf(it) } + (if (t.nullable) "?" else "") + is MiniTypeIntersection -> t.options.joinToString(" & ") { typeOf(it) } + (if (t.nullable) "?" else "") null -> "" } diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/miniast/MiniAst.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/miniast/MiniAst.kt index 4f67ab2..d1ced99 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/miniast/MiniAst.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/miniast/MiniAst.kt @@ -150,6 +150,18 @@ data class MiniTypeVar( val nullable: Boolean ) : MiniTypeRef +data class MiniTypeUnion( + override val range: MiniRange, + val options: List, + val nullable: Boolean +) : MiniTypeRef + +data class MiniTypeIntersection( + override val range: MiniRange, + val options: List, + val nullable: Boolean +) : MiniTypeRef + // Script and declarations (lean subset; can be extended later) sealed interface MiniNamedDecl : MiniNode { val name: String diff --git a/lynglib/src/commonTest/kotlin/ScriptTest.kt b/lynglib/src/commonTest/kotlin/ScriptTest.kt index 0104f50..fdb761f 100644 --- a/lynglib/src/commonTest/kotlin/ScriptTest.kt +++ b/lynglib/src/commonTest/kotlin/ScriptTest.kt @@ -5273,6 +5273,40 @@ class ScriptTest { assertEquals(ObjFalse, scope.eval("isInt(\"42\")")) } + @Test + fun testGenericBoundsAndReifiedTypeParams() = runTest { + val resInt = eval( + """ + fun square(x: T) = x * x + square(2) + """.trimIndent() + ) + assertEquals(4L, (resInt as ObjInt).value) + val resReal = eval( + """ + fun square(x: T) = x * x + square(1.5) + """.trimIndent() + ) + assertEquals(2.25, (resReal as ObjReal).value, 0.00001) + assertFailsWith { + eval( + """ + fun square(x: T) = x * x + square("x") + """.trimIndent() + ) + } + + val reified = eval( + """ + fun sameType(x: T, y: Object) = y is T + sameType(1, "a") + """.trimIndent() + ) + assertEquals(false, (reified as ObjBool).value) + } + @Test fun testFilterBug() = runTest { eval(