From 65edf9fe67159bd0f7be5cd3d30d7f2850b05c51 Mon Sep 17 00:00:00 2001 From: sergeych Date: Sat, 14 Mar 2026 19:18:09 +0300 Subject: [PATCH] Fix generic type checks and explicit type arg runtime binding --- .../kotlin/net/sergeych/lyng/Arguments.kt | 1 + .../kotlin/net/sergeych/lyng/Compiler.kt | 21 +++++++--- .../lyng/bytecode/BytecodeCompiler.kt | 31 ++++++++++---- .../sergeych/lyng/bytecode/BytecodeConst.kt | 6 ++- .../net/sergeych/lyng/bytecode/CmdRuntime.kt | 7 +++- .../kotlin/net/sergeych/lyng/obj/ObjRef.kt | 1 + .../net/sergeych/lyng/obj/ObjTypeExpr.kt | 4 +- lynglib/src/commonTest/kotlin/TypesTest.kt | 42 +++++++++++++++++++ 8 files changed, 95 insertions(+), 18 deletions(-) diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Arguments.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Arguments.kt index 61bb174..ea7ccf0 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Arguments.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Arguments.kt @@ -220,6 +220,7 @@ data class ParsedArgument( val list: List, val tailBlockMode: Boolean = false, val named: Map = emptyMap(), + val explicitTypeArgs: List = emptyList(), ) : List by list { constructor(vararg values: Obj) : this(values.toList()) diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt index 2fd2b63..ab84791 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/Compiler.kt @@ -4487,6 +4487,7 @@ class Compiler( is ListLiteralRef -> ObjList.type is MapLiteralRef -> ObjMap.type is RangeRef -> ObjRange.type + is ClassOperatorRef -> ObjClassType is CastRef -> resolveTypeRefClass(ref.castTypeRef()) else -> null } @@ -4580,6 +4581,7 @@ class Compiler( is ListLiteralRef -> ObjList.type is MapLiteralRef -> ObjMap.type is RangeRef -> ObjRange.type + is ClassOperatorRef -> ObjClassType is CastRef -> resolveTypeRefClass(ref.castTypeRef()) is QualifiedThisRef -> resolveClassByName(ref.typeName) is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverClassForMember(it.ref) } @@ -5346,6 +5348,10 @@ class Compiler( typeParams: List ): Map { if (typeParams.isEmpty()) return emptyMap() + val explicitTypeArgs = context.args.explicitTypeArgs + if (explicitTypeArgs.size > typeParams.size) { + context.raiseError("too many type arguments: expected ${typeParams.size}, got ${explicitTypeArgs.size}") + } val inferred = mutableMapOf() val argValues = context.args.list for ((index, param) in argsDeclaration.params.withIndex()) { @@ -5358,8 +5364,11 @@ class Compiler( collectRuntimeTypeVarBindings(param.type, value, inferred) } val boundValues = LinkedHashMap(typeParams.size) - for (tp in typeParams) { - val inferredType = inferred[tp.name] ?: tp.defaultType ?: TypeDecl.TypeAny + for ((index, tp) in typeParams.withIndex()) { + val inferredType = explicitTypeArgs.getOrNull(index) + ?: inferred[tp.name] + ?: tp.defaultType + ?: TypeDecl.TypeAny val normalized = normalizeRuntimeTypeDecl(inferredType) val cls = resolveTypeDeclObjClass(normalized) val boundValue = if (cls != null && @@ -5681,7 +5690,7 @@ class Compiler( val result = when (left) { is ImplicitThisMemberRef -> if (left.methodId == null && left.fieldId != null) { - CallRef(left, args, detectedBlockArgument, isOptional).also { callRef -> + CallRef(left, args, detectedBlockArgument, isOptional, explicitTypeArgs).also { callRef -> applyExplicitCallTypeArgs(callRef, explicitTypeArgs) } } else { @@ -5716,7 +5725,7 @@ class Compiler( checkFunctionTypeCallArity(left, args, left.pos()) checkFunctionTypeCallTypes(left, args, left.pos()) checkGenericBoundsAtCall(left.name, args, left.pos()) - CallRef(left, args, detectedBlockArgument, isOptional).also { callRef -> + CallRef(left, args, detectedBlockArgument, isOptional, explicitTypeArgs).also { callRef -> applyExplicitCallTypeArgs(callRef, explicitTypeArgs) } } @@ -5742,12 +5751,12 @@ class Compiler( checkFunctionTypeCallArity(left, args, left.pos()) checkFunctionTypeCallTypes(left, args, left.pos()) checkGenericBoundsAtCall(left.name, args, left.pos()) - CallRef(left, args, detectedBlockArgument, isOptional).also { callRef -> + CallRef(left, args, detectedBlockArgument, isOptional, explicitTypeArgs).also { callRef -> applyExplicitCallTypeArgs(callRef, explicitTypeArgs) } } } - else -> CallRef(left, args, detectedBlockArgument, isOptional).also { callRef -> + else -> CallRef(left, args, detectedBlockArgument, isOptional, explicitTypeArgs).also { callRef -> applyExplicitCallTypeArgs(callRef, explicitTypeArgs) } } diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt index f18186b..7d3fdc9 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeCompiler.kt @@ -4395,7 +4395,7 @@ class BytecodeCompiler( val dst = allocSlot() val encodedMethodId = encodeMemberId(receiverClass, methodId) ?: methodId if (!ref.isOptionalInvoke) { - val args = compileCallArgs(ref.args, ref.tailBlock) ?: return null + val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null val encodedCount = encodeCallArgCount(args) ?: return null setPos(callPos) builder.emit(Opcode.CALL_MEMBER_SLOT, receiver.slot, encodedMethodId, args.base, encodedCount, dst) @@ -4410,7 +4410,7 @@ class BytecodeCompiler( Opcode.JMP_IF_TRUE, listOf(CmdBuilder.Operand.IntVal(cmpSlot), CmdBuilder.Operand.LabelRef(nullLabel)) ) - val args = compileCallArgs(ref.args, ref.tailBlock) ?: return null + val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null val encodedCount = encodeCallArgCount(args) ?: return null setPos(callPos) builder.emit(Opcode.CALL_MEMBER_SLOT, receiver.slot, encodedMethodId, args.base, encodedCount, dst) @@ -4450,7 +4450,7 @@ class BytecodeCompiler( val callee = compileRefWithFallback(ref.target, null, refPosOrCurrent(ref.target)) ?: return null val dst = allocSlot() if (!ref.isOptionalInvoke) { - val args = compileCallArgs(ref.args, ref.tailBlock) ?: return null + val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null val encodedCount = encodeCallArgCount(args) ?: return null setPos(callPos) builder.emit( @@ -4475,7 +4475,7 @@ class BytecodeCompiler( Opcode.JMP_IF_TRUE, listOf(CmdBuilder.Operand.IntVal(cmpSlot), CmdBuilder.Operand.LabelRef(nullLabel)) ) - val args = compileCallArgs(ref.args, ref.tailBlock) ?: return null + val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null val encodedCount = encodeCallArgCount(args) ?: return null setPos(callPos) builder.emit( @@ -4876,10 +4876,14 @@ class BytecodeCompiler( return CallArgs(base = argSlots[0], count = argSlots.size, planId = planId) } - private fun compileCallArgs(args: List, tailBlock: Boolean): CallArgs? { - if (args.isEmpty()) return CallArgs(base = 0, count = 0, planId = null) + private fun compileCallArgs( + args: List, + tailBlock: Boolean, + explicitTypeArgs: List? = null + ): CallArgs? { + if (args.isEmpty() && explicitTypeArgs.isNullOrEmpty()) return CallArgs(base = 0, count = 0, planId = null) val argSlots = IntArray(args.size) { allocSlot() } - val needPlan = tailBlock || args.any { it.isSplat || it.name != null } + val needPlan = tailBlock || args.any { it.isSplat || it.name != null } || !explicitTypeArgs.isNullOrEmpty() val specs = if (needPlan) ArrayList(args.size) else null for ((index, arg) in args.withIndex()) { val compiled = compileArgValue(arg.value) ?: return null @@ -4891,11 +4895,17 @@ class BytecodeCompiler( specs?.add(BytecodeConst.CallArgSpec(arg.name, arg.isSplat)) } val planId = if (needPlan) { - builder.addConst(BytecodeConst.CallArgsPlan(tailBlock, specs ?: emptyList())) + builder.addConst( + BytecodeConst.CallArgsPlan( + tailBlock = tailBlock, + specs = specs ?: emptyList(), + explicitTypeArgs = explicitTypeArgs ?: emptyList() + ) + ) } else { null } - return CallArgs(base = argSlots[0], count = argSlots.size, planId = planId) + return CallArgs(base = if (argSlots.isEmpty()) 0 else argSlots[0], count = argSlots.size, planId = planId) } private fun compileArgValue(value: Obj): CompiledValue? { @@ -8464,6 +8474,9 @@ class BytecodeCompiler( collectScopeSlotsRef(ref.targetRef) collectScopeSlotsRef(ref.indexRef) } + is ClassOperatorRef -> { + collectScopeSlotsRef(ref.target) + } is ListLiteralRef -> { for (entry in ref.entries()) { when (entry) { diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeConst.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeConst.kt index 4a546b0..7f19d16 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeConst.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/BytecodeConst.kt @@ -180,6 +180,10 @@ sealed class BytecodeConst { val pattern: ListLiteralRef, val pos: Pos, ) : BytecodeConst() - data class CallArgsPlan(val tailBlock: Boolean, val specs: List) : BytecodeConst() + data class CallArgsPlan( + val tailBlock: Boolean, + val specs: List, + val explicitTypeArgs: List = emptyList() + ) : BytecodeConst() data class CallArgSpec(val name: String?, val isSplat: Boolean) } diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt index d9a496c..437550c 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/bytecode/CmdRuntime.kt @@ -4699,7 +4699,12 @@ class CmdFrame( positional.add(value) } } - return Arguments(positional, plan.tailBlock, named ?: emptyMap()) + return Arguments( + list = positional, + tailBlockMode = plan.tailBlock, + named = named ?: emptyMap(), + explicitTypeArgs = plan.explicitTypeArgs + ) } private fun resolveLocalScope(localIndex: Int): Scope? { diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjRef.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjRef.kt index 7c1f863..8f73b60 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjRef.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjRef.kt @@ -469,6 +469,7 @@ class CallRef( internal val args: List, internal val tailBlock: Boolean, internal val isOptionalInvoke: Boolean, + internal val explicitTypeArgs: List? = null, ) : ObjRef { override suspend fun get(scope: Scope): ObjRecord = scope.raiseObjRefEvalDisabled() } diff --git a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjTypeExpr.kt b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjTypeExpr.kt index 00494cc..89d063e 100644 --- a/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjTypeExpr.kt +++ b/lynglib/src/commonMain/kotlin/net/sergeych/lyng/obj/ObjTypeExpr.kt @@ -99,7 +99,9 @@ internal fun typeDeclIsSubtype(scope: Scope, left: TypeDecl, right: TypeDecl): B is TypeDecl.Simple, is TypeDecl.Generic, is TypeDecl.Function, is TypeDecl.Ellipsis -> { val leftClass = resolveTypeDeclClass(scope, l) ?: return false val rightClass = resolveTypeDeclClass(scope, r) ?: return false - leftClass == rightClass || leftClass.allParentsSet.contains(rightClass) + leftClass == rightClass || + rightClass == Obj.rootObjectType || + leftClass.allParentsSet.contains(rightClass) } else -> false } diff --git a/lynglib/src/commonTest/kotlin/TypesTest.kt b/lynglib/src/commonTest/kotlin/TypesTest.kt index 2318feb..c685f5f 100644 --- a/lynglib/src/commonTest/kotlin/TypesTest.kt +++ b/lynglib/src/commonTest/kotlin/TypesTest.kt @@ -510,4 +510,46 @@ class TypesTest { ) } } + + @Test + fun testClassName() = runTest { + eval(""" + class X { + var x = 1 + } + assert( X::class is Class) + assertEquals("Class", X::class.name) + """.trimIndent()) + } + + @Test + fun testGenericTypes() = runTest { + eval(""" + fun t(): String = + when(T) { + null -> "%s is Null"(T::class.name) + is Object -> "%s is Object"(T::class.name) + else -> throw "It should not happen" + } + assert( Int is Object) + assertEquals( t(), "Class is Object") + """.trimIndent()) + } + +// @Test fun nonTrivialOperatorsTest() = runTest { +// val s = Script.newScope() +// s.eval(""" +// class Matrix(val rows: Int, val cols: Int,initialValue:T?) { +// val data +// init { +// val v = initalValue? +// } +// data = List(rows*cols) { initialValue } } +// fun getAt(row: Int, col: Int) = data[row*cols+col] +// fun setAt(row: Int, col: Int, value: T) { data[row*cols+col] = value } +// } +// val m = Matrix(1,1) +// +// """.trimIndent()) +// } }