Fix generic type checks and explicit type arg runtime binding

This commit is contained in:
Sergey Chernov 2026-03-14 19:18:09 +03:00
parent 18394ce286
commit 65edf9fe67
8 changed files with 95 additions and 18 deletions

View File

@ -220,6 +220,7 @@ data class ParsedArgument(
val list: List<Obj>,
val tailBlockMode: Boolean = false,
val named: Map<String, Obj> = emptyMap(),
val explicitTypeArgs: List<TypeDecl> = emptyList(),
) : List<Obj> by list {
constructor(vararg values: Obj) : this(values.toList())

View File

@ -4487,6 +4487,7 @@ class Compiler(
is ListLiteralRef -> ObjList.type
is MapLiteralRef -> ObjMap.type
is RangeRef -> ObjRange.type
is ClassOperatorRef -> ObjClassType
is CastRef -> resolveTypeRefClass(ref.castTypeRef())
else -> null
}
@ -4580,6 +4581,7 @@ class Compiler(
is ListLiteralRef -> ObjList.type
is MapLiteralRef -> ObjMap.type
is RangeRef -> ObjRange.type
is ClassOperatorRef -> ObjClassType
is CastRef -> resolveTypeRefClass(ref.castTypeRef())
is QualifiedThisRef -> resolveClassByName(ref.typeName)
is StatementRef -> (ref.statement as? ExpressionStatement)?.let { resolveReceiverClassForMember(it.ref) }
@ -5346,6 +5348,10 @@ class Compiler(
typeParams: List<TypeDecl.TypeParam>
): Map<String, Obj> {
if (typeParams.isEmpty()) return emptyMap()
val explicitTypeArgs = context.args.explicitTypeArgs
if (explicitTypeArgs.size > typeParams.size) {
context.raiseError("too many type arguments: expected ${typeParams.size}, got ${explicitTypeArgs.size}")
}
val inferred = mutableMapOf<String, TypeDecl>()
val argValues = context.args.list
for ((index, param) in argsDeclaration.params.withIndex()) {
@ -5358,8 +5364,11 @@ class Compiler(
collectRuntimeTypeVarBindings(param.type, value, inferred)
}
val boundValues = LinkedHashMap<String, Obj>(typeParams.size)
for (tp in typeParams) {
val inferredType = inferred[tp.name] ?: tp.defaultType ?: TypeDecl.TypeAny
for ((index, tp) in typeParams.withIndex()) {
val inferredType = explicitTypeArgs.getOrNull(index)
?: inferred[tp.name]
?: tp.defaultType
?: TypeDecl.TypeAny
val normalized = normalizeRuntimeTypeDecl(inferredType)
val cls = resolveTypeDeclObjClass(normalized)
val boundValue = if (cls != null &&
@ -5681,7 +5690,7 @@ class Compiler(
val result = when (left) {
is ImplicitThisMemberRef ->
if (left.methodId == null && left.fieldId != null) {
CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
CallRef(left, args, detectedBlockArgument, isOptional, explicitTypeArgs).also { callRef ->
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
}
} else {
@ -5716,7 +5725,7 @@ class Compiler(
checkFunctionTypeCallArity(left, args, left.pos())
checkFunctionTypeCallTypes(left, args, left.pos())
checkGenericBoundsAtCall(left.name, args, left.pos())
CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
CallRef(left, args, detectedBlockArgument, isOptional, explicitTypeArgs).also { callRef ->
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
}
}
@ -5742,12 +5751,12 @@ class Compiler(
checkFunctionTypeCallArity(left, args, left.pos())
checkFunctionTypeCallTypes(left, args, left.pos())
checkGenericBoundsAtCall(left.name, args, left.pos())
CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
CallRef(left, args, detectedBlockArgument, isOptional, explicitTypeArgs).also { callRef ->
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
}
}
}
else -> CallRef(left, args, detectedBlockArgument, isOptional).also { callRef ->
else -> CallRef(left, args, detectedBlockArgument, isOptional, explicitTypeArgs).also { callRef ->
applyExplicitCallTypeArgs(callRef, explicitTypeArgs)
}
}

View File

@ -4395,7 +4395,7 @@ class BytecodeCompiler(
val dst = allocSlot()
val encodedMethodId = encodeMemberId(receiverClass, methodId) ?: methodId
if (!ref.isOptionalInvoke) {
val args = compileCallArgs(ref.args, ref.tailBlock) ?: return null
val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null
setPos(callPos)
builder.emit(Opcode.CALL_MEMBER_SLOT, receiver.slot, encodedMethodId, args.base, encodedCount, dst)
@ -4410,7 +4410,7 @@ class BytecodeCompiler(
Opcode.JMP_IF_TRUE,
listOf(CmdBuilder.Operand.IntVal(cmpSlot), CmdBuilder.Operand.LabelRef(nullLabel))
)
val args = compileCallArgs(ref.args, ref.tailBlock) ?: return null
val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null
setPos(callPos)
builder.emit(Opcode.CALL_MEMBER_SLOT, receiver.slot, encodedMethodId, args.base, encodedCount, dst)
@ -4450,7 +4450,7 @@ class BytecodeCompiler(
val callee = compileRefWithFallback(ref.target, null, refPosOrCurrent(ref.target)) ?: return null
val dst = allocSlot()
if (!ref.isOptionalInvoke) {
val args = compileCallArgs(ref.args, ref.tailBlock) ?: return null
val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null
setPos(callPos)
builder.emit(
@ -4475,7 +4475,7 @@ class BytecodeCompiler(
Opcode.JMP_IF_TRUE,
listOf(CmdBuilder.Operand.IntVal(cmpSlot), CmdBuilder.Operand.LabelRef(nullLabel))
)
val args = compileCallArgs(ref.args, ref.tailBlock) ?: return null
val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null
setPos(callPos)
builder.emit(
@ -4876,10 +4876,14 @@ class BytecodeCompiler(
return CallArgs(base = argSlots[0], count = argSlots.size, planId = planId)
}
private fun compileCallArgs(args: List<ParsedArgument>, tailBlock: Boolean): CallArgs? {
if (args.isEmpty()) return CallArgs(base = 0, count = 0, planId = null)
private fun compileCallArgs(
args: List<ParsedArgument>,
tailBlock: Boolean,
explicitTypeArgs: List<TypeDecl>? = null
): CallArgs? {
if (args.isEmpty() && explicitTypeArgs.isNullOrEmpty()) return CallArgs(base = 0, count = 0, planId = null)
val argSlots = IntArray(args.size) { allocSlot() }
val needPlan = tailBlock || args.any { it.isSplat || it.name != null }
val needPlan = tailBlock || args.any { it.isSplat || it.name != null } || !explicitTypeArgs.isNullOrEmpty()
val specs = if (needPlan) ArrayList<BytecodeConst.CallArgSpec>(args.size) else null
for ((index, arg) in args.withIndex()) {
val compiled = compileArgValue(arg.value) ?: return null
@ -4891,11 +4895,17 @@ class BytecodeCompiler(
specs?.add(BytecodeConst.CallArgSpec(arg.name, arg.isSplat))
}
val planId = if (needPlan) {
builder.addConst(BytecodeConst.CallArgsPlan(tailBlock, specs ?: emptyList()))
builder.addConst(
BytecodeConst.CallArgsPlan(
tailBlock = tailBlock,
specs = specs ?: emptyList(),
explicitTypeArgs = explicitTypeArgs ?: emptyList()
)
)
} else {
null
}
return CallArgs(base = argSlots[0], count = argSlots.size, planId = planId)
return CallArgs(base = if (argSlots.isEmpty()) 0 else argSlots[0], count = argSlots.size, planId = planId)
}
private fun compileArgValue(value: Obj): CompiledValue? {
@ -8464,6 +8474,9 @@ class BytecodeCompiler(
collectScopeSlotsRef(ref.targetRef)
collectScopeSlotsRef(ref.indexRef)
}
is ClassOperatorRef -> {
collectScopeSlotsRef(ref.target)
}
is ListLiteralRef -> {
for (entry in ref.entries()) {
when (entry) {

View File

@ -180,6 +180,10 @@ sealed class BytecodeConst {
val pattern: ListLiteralRef,
val pos: Pos,
) : BytecodeConst()
data class CallArgsPlan(val tailBlock: Boolean, val specs: List<CallArgSpec>) : BytecodeConst()
data class CallArgsPlan(
val tailBlock: Boolean,
val specs: List<CallArgSpec>,
val explicitTypeArgs: List<TypeDecl> = emptyList()
) : BytecodeConst()
data class CallArgSpec(val name: String?, val isSplat: Boolean)
}

View File

@ -4699,7 +4699,12 @@ class CmdFrame(
positional.add(value)
}
}
return Arguments(positional, plan.tailBlock, named ?: emptyMap())
return Arguments(
list = positional,
tailBlockMode = plan.tailBlock,
named = named ?: emptyMap(),
explicitTypeArgs = plan.explicitTypeArgs
)
}
private fun resolveLocalScope(localIndex: Int): Scope? {

View File

@ -469,6 +469,7 @@ class CallRef(
internal val args: List<ParsedArgument>,
internal val tailBlock: Boolean,
internal val isOptionalInvoke: Boolean,
internal val explicitTypeArgs: List<TypeDecl>? = null,
) : ObjRef {
override suspend fun get(scope: Scope): ObjRecord = scope.raiseObjRefEvalDisabled()
}

View File

@ -99,7 +99,9 @@ internal fun typeDeclIsSubtype(scope: Scope, left: TypeDecl, right: TypeDecl): B
is TypeDecl.Simple, is TypeDecl.Generic, is TypeDecl.Function, is TypeDecl.Ellipsis -> {
val leftClass = resolveTypeDeclClass(scope, l) ?: return false
val rightClass = resolveTypeDeclClass(scope, r) ?: return false
leftClass == rightClass || leftClass.allParentsSet.contains(rightClass)
leftClass == rightClass ||
rightClass == Obj.rootObjectType ||
leftClass.allParentsSet.contains(rightClass)
}
else -> false
}

View File

@ -510,4 +510,46 @@ class TypesTest {
)
}
}
@Test
fun testClassName() = runTest {
eval("""
class X {
var x = 1
}
assert( X::class is Class)
assertEquals("Class", X::class.name)
""".trimIndent())
}
@Test
fun testGenericTypes() = runTest {
eval("""
fun t<T>(): String =
when(T) {
null -> "%s is Null"(T::class.name)
is Object -> "%s is Object"(T::class.name)
else -> throw "It should not happen"
}
assert( Int is Object)
assertEquals( t<Int>(), "Class is Object")
""".trimIndent())
}
// @Test fun nonTrivialOperatorsTest() = runTest {
// val s = Script.newScope()
// s.eval("""
// class Matrix<T>(val rows: Int, val cols: Int,initialValue:T?) {
// val data
// init {
// val v = initalValue?
// }
// data = List(rows*cols) { initialValue } }
// fun getAt(row: Int, col: Int) = data[row*cols+col]
// fun setAt(row: Int, col: Int, value: T) { data[row*cols+col] = value }
// }
// val m = Matrix(1,1)
//
// """.trimIndent())
// }
}