Propagate exact callable refs across expressions

This commit is contained in:
Sergey Chernov 2026-04-21 14:33:23 +03:00
parent ffa64d691b
commit 029fe874fa
2 changed files with 364 additions and 37 deletions

View File

@ -91,6 +91,7 @@ class BytecodeCompiler(
private val intLoopVarNames = LinkedHashSet<String>() private val intLoopVarNames = LinkedHashSet<String>()
private val valueFnRefs = LinkedHashSet<ValueFnRef>() private val valueFnRefs = LinkedHashSet<ValueFnRef>()
private val exactLambdaRefBySlot = LinkedHashMap<Int, LambdaFnRef>() private val exactLambdaRefBySlot = LinkedHashMap<Int, LambdaFnRef>()
private val exactCallableObjBySlot = LinkedHashMap<Int, Obj>()
private val activeInlineLambdas = LinkedHashSet<LambdaFnRef>() private val activeInlineLambdas = LinkedHashSet<LambdaFnRef>()
private val inlineThisBindings = ArrayDeque<InlineThisBinding>() private val inlineThisBindings = ArrayDeque<InlineThisBinding>()
private val loopVarKeys = LinkedHashSet<ScopeSlotKey>() private val loopVarKeys = LinkedHashSet<ScopeSlotKey>()
@ -594,6 +595,7 @@ class BytecodeCompiler(
updateSlotType(local, resolved) updateSlotType(local, resolved)
if (resolved == SlotType.OBJ) { if (resolved == SlotType.OBJ) {
propagateObjClass(SlotType.OBJ, mapped, local) propagateObjClass(SlotType.OBJ, mapped, local)
seedExactCallableForNamedSlot(ref.name, local)
} }
return CompiledValue(local, resolved) return CompiledValue(local, resolved)
} }
@ -604,8 +606,12 @@ class BytecodeCompiler(
emitLoadFromAddr(addrSlot, local, SlotType.OBJ) emitLoadFromAddr(addrSlot, local, SlotType.OBJ)
updateSlotType(local, SlotType.OBJ) updateSlotType(local, SlotType.OBJ)
propagateObjClass(SlotType.OBJ, mapped, local) propagateObjClass(SlotType.OBJ, mapped, local)
seedExactCallableForNamedSlot(ref.name, local)
return CompiledValue(local, SlotType.OBJ) return CompiledValue(local, SlotType.OBJ)
} }
if (resolved == SlotType.OBJ) {
seedExactCallableForNamedSlot(ref.name, mapped)
}
CompiledValue(mapped, resolved) CompiledValue(mapped, resolved)
} }
is LocalVarRef -> { is LocalVarRef -> {
@ -614,12 +620,18 @@ class BytecodeCompiler(
} }
loopSlotOverrides[ref.name]?.let { slot -> loopSlotOverrides[ref.name]?.let { slot ->
val resolved = slotTypes[slot] ?: SlotType.UNKNOWN val resolved = slotTypes[slot] ?: SlotType.UNKNOWN
if (resolved == SlotType.OBJ) {
seedExactCallableForNamedSlot(ref.name, slot)
}
return CompiledValue(slot, resolved) return CompiledValue(slot, resolved)
} }
if (allowLocalSlots) { if (allowLocalSlots) {
scopeSlotIndexByName[ref.name]?.let { slot -> scopeSlotIndexByName[ref.name]?.let { slot ->
noteScopeSlotRef(slot, callSitePos()) noteScopeSlotRef(slot, callSitePos())
val resolved = slotTypes[slot] ?: SlotType.UNKNOWN val resolved = slotTypes[slot] ?: SlotType.UNKNOWN
if (resolved == SlotType.OBJ) {
seedExactCallableForNamedSlot(ref.name, slot)
}
return CompiledValue(slot, resolved) return CompiledValue(slot, resolved)
} }
} }
@ -631,6 +643,9 @@ class BytecodeCompiler(
} }
loopSlotOverrides[ref.name]?.let { slot -> loopSlotOverrides[ref.name]?.let { slot ->
val resolved = slotTypes[slot] ?: SlotType.UNKNOWN val resolved = slotTypes[slot] ?: SlotType.UNKNOWN
if (resolved == SlotType.OBJ) {
seedExactCallableForNamedSlot(ref.name, slot)
}
return CompiledValue(slot, resolved) return CompiledValue(slot, resolved)
} }
if (allowLocalSlots) { if (allowLocalSlots) {
@ -638,11 +653,17 @@ class BytecodeCompiler(
if (localIndex != null) { if (localIndex != null) {
val slot = scopeSlotCount + localIndex val slot = scopeSlotCount + localIndex
val resolved = slotTypes[slot] ?: SlotType.UNKNOWN val resolved = slotTypes[slot] ?: SlotType.UNKNOWN
if (resolved == SlotType.OBJ) {
seedExactCallableForNamedSlot(ref.name, slot)
}
return CompiledValue(slot, resolved) return CompiledValue(slot, resolved)
} }
scopeSlotIndexByName[ref.name]?.let { slot -> scopeSlotIndexByName[ref.name]?.let { slot ->
noteScopeSlotRef(slot, callSitePos()) noteScopeSlotRef(slot, callSitePos())
val resolved = slotTypes[slot] ?: SlotType.UNKNOWN val resolved = slotTypes[slot] ?: SlotType.UNKNOWN
if (resolved == SlotType.OBJ) {
seedExactCallableForNamedSlot(ref.name, slot)
}
return CompiledValue(slot, resolved) return CompiledValue(slot, resolved)
} }
} }
@ -660,9 +681,13 @@ class BytecodeCompiler(
updateSlotType(local, resolved) updateSlotType(local, resolved)
if (resolved == SlotType.OBJ) { if (resolved == SlotType.OBJ) {
propagateObjClass(SlotType.OBJ, slot, local) propagateObjClass(SlotType.OBJ, slot, local)
localSlotNames.getOrNull(slot - scopeSlotCount)?.let { seedExactCallableForNamedSlot(it, local) }
} }
return CompiledValue(local, resolved) return CompiledValue(local, resolved)
} }
if (resolved == SlotType.OBJ) {
localSlotNames.getOrNull(slot - scopeSlotCount)?.let { seedExactCallableForNamedSlot(it, slot) }
}
CompiledValue(slot, resolved) CompiledValue(slot, resolved)
} }
is ValueFnRef -> compileValueFnRef(ref) is ValueFnRef -> compileValueFnRef(ref)
@ -720,7 +745,7 @@ class BytecodeCompiler(
val calleeObj = ensureObjSlot(callee) val calleeObj = ensureObjSlot(callee)
val args = compileCallArgsWithReceiver(receiver, emptyList(), false) ?: return null val args = compileCallArgsWithReceiver(receiver, emptyList(), false) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null val encodedCount = encodeCallArgCount(args) ?: return null
builder.emit(Opcode.CALL_SLOT, calleeObj.slot, args.base, encodedCount, dst) emitCallCompiled(calleeObj, args.base, encodedCount, dst)
updateSlotType(dst, SlotType.OBJ) updateSlotType(dst, SlotType.OBJ)
annotateIndexedReceiverSlot(dst, ownerClass?.let { inferFieldReturnClass(it, ref.name) }) annotateIndexedReceiverSlot(dst, ownerClass?.let { inferFieldReturnClass(it, ref.name) })
return CompiledValue(dst, SlotType.OBJ) return CompiledValue(dst, SlotType.OBJ)
@ -788,7 +813,7 @@ class BytecodeCompiler(
val args = compileCallArgsWithReceiver(receiver, ref.arguments(), ref.hasTailBlock()) ?: return null val args = compileCallArgsWithReceiver(receiver, ref.arguments(), ref.hasTailBlock()) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null val encodedCount = encodeCallArgCount(args) ?: return null
setPos(callPos) setPos(callPos)
builder.emit(Opcode.CALL_SLOT, calleeObj.slot, args.base, encodedCount, dst) emitCallCompiled(calleeObj, args.base, encodedCount, dst)
return CompiledValue(dst, SlotType.OBJ) return CompiledValue(dst, SlotType.OBJ)
} }
val nullSlot = allocSlot() val nullSlot = allocSlot()
@ -804,7 +829,7 @@ class BytecodeCompiler(
val args = compileCallArgsWithReceiver(receiver, ref.arguments(), ref.hasTailBlock()) ?: return null val args = compileCallArgsWithReceiver(receiver, ref.arguments(), ref.hasTailBlock()) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null val encodedCount = encodeCallArgCount(args) ?: return null
setPos(callPos) setPos(callPos)
builder.emit(Opcode.CALL_SLOT, calleeObj.slot, args.base, encodedCount, dst) emitCallCompiled(calleeObj, args.base, encodedCount, dst)
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel))) builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
builder.mark(nullLabel) builder.mark(nullLabel)
builder.emit(Opcode.CONST_NULL, dst) builder.emit(Opcode.CONST_NULL, dst)
@ -860,15 +885,18 @@ class BytecodeCompiler(
val id = builder.addConst(BytecodeConst.StringVal(obj.value)) val id = builder.addConst(BytecodeConst.StringVal(obj.value))
builder.emit(Opcode.CONST_OBJ, id, slot) builder.emit(Opcode.CONST_OBJ, id, slot)
slotObjClass[slot] = ObjString.type slotObjClass[slot] = ObjString.type
trackExactCallableObjAtSlot(slot, obj)
return CompiledValue(slot, SlotType.OBJ) return CompiledValue(slot, SlotType.OBJ)
} }
ObjNull -> { ObjNull -> {
builder.emit(Opcode.CONST_NULL, slot) builder.emit(Opcode.CONST_NULL, slot)
trackExactCallableObjAtSlot(slot, null)
return CompiledValue(slot, SlotType.OBJ) return CompiledValue(slot, SlotType.OBJ)
} }
else -> { else -> {
val id = builder.addConst(BytecodeConst.ObjRef(obj)) val id = builder.addConst(BytecodeConst.ObjRef(obj))
builder.emit(Opcode.CONST_OBJ, id, slot) builder.emit(Opcode.CONST_OBJ, id, slot)
trackExactCallableObjAtSlot(slot, obj)
return CompiledValue(slot, SlotType.OBJ) return CompiledValue(slot, SlotType.OBJ)
} }
} }
@ -879,6 +907,34 @@ class BytecodeCompiler(
builder.emit(Opcode.CALL_DIRECT, calleeId, argBase, encodedCount, dst) builder.emit(Opcode.CALL_DIRECT, calleeId, argBase, encodedCount, dst)
} }
private fun seedExactCallableForNamedSlot(name: String, slot: Int) {
if (exactCallableObjBySlot[slot] != null) return
if (slotObjClass[slot] != ObjClassType) return
resolveTypeNameClass(name)?.let { trackExactCallableObjAtSlot(slot, it) }
}
private fun emitCallCompiled(
callee: CompiledValue,
argBase: Int,
encodedCount: Int,
dst: Int,
isExternCall: Boolean = false,
) {
if (!isExternCall) {
exactCallableObjBySlot[callee.slot]?.let {
emitCallDirect(it, argBase, encodedCount, dst)
return
}
}
builder.emit(
if (isExternCall) Opcode.CALL_BRIDGE_SLOT else Opcode.CALL_SLOT,
callee.slot,
argBase,
encodedCount,
dst
)
}
private fun compileValueFnRef(ref: ValueFnRef): CompiledValue? { private fun compileValueFnRef(ref: ValueFnRef): CompiledValue? {
if (ref is LambdaFnRef && ref.bytecodeFn != null) { if (ref is LambdaFnRef && ref.bytecodeFn != null) {
val captures = (lambdaCaptureEntriesByRef[ref] ?: ref.captureEntries).orEmpty() val captures = (lambdaCaptureEntriesByRef[ref] ?: ref.captureEntries).orEmpty()
@ -2529,6 +2585,7 @@ class BytecodeCompiler(
} }
updateSlotType(slot, value.type) updateSlotType(slot, value.type)
propagateObjClass(value.type, value.slot, slot) propagateObjClass(value.type, value.slot, slot)
trackExactCallableObjAtSlot(slot, null)
trackExactLambdaAtSlot(slot, null) trackExactLambdaAtSlot(slot, null)
updateNameObjClassFromSlot(localTarget.name, slot) updateNameObjClassFromSlot(localTarget.name, slot)
return value return value
@ -2573,6 +2630,7 @@ class BytecodeCompiler(
} }
updateSlotType(slot, value.type) updateSlotType(slot, value.type)
propagateObjClass(value.type, value.slot, slot) propagateObjClass(value.type, value.slot, slot)
trackExactCallableObjAtSlot(slot, null)
trackExactLambdaAtSlot(slot, null) trackExactLambdaAtSlot(slot, null)
updateNameObjClassFromSlot(nameTarget, slot) updateNameObjClassFromSlot(nameTarget, slot)
return value return value
@ -2674,7 +2732,7 @@ class BytecodeCompiler(
val encodedCount = encodeCallArgCount(callArgs) ?: return null val encodedCount = encodeCallArgCount(callArgs) ?: return null
val callDst = allocSlot() val callDst = allocSlot()
if (!target.isOptional) { if (!target.isOptional) {
builder.emit(Opcode.CALL_SLOT, callee.slot, callArgs.base, encodedCount, callDst) emitCallCompiled(callee, callArgs.base, encodedCount, callDst)
} else { } else {
val nullSlot = allocSlot() val nullSlot = allocSlot()
builder.emit(Opcode.CONST_NULL, nullSlot) builder.emit(Opcode.CONST_NULL, nullSlot)
@ -2685,7 +2743,7 @@ class BytecodeCompiler(
Opcode.JMP_IF_TRUE, Opcode.JMP_IF_TRUE,
listOf(CmdBuilder.Operand.IntVal(cmpSlot), CmdBuilder.Operand.LabelRef(endLabel)) listOf(CmdBuilder.Operand.IntVal(cmpSlot), CmdBuilder.Operand.LabelRef(endLabel))
) )
builder.emit(Opcode.CALL_SLOT, callee.slot, callArgs.base, encodedCount, callDst) emitCallCompiled(callee, callArgs.base, encodedCount, callDst)
builder.mark(endLabel) builder.mark(endLabel)
} }
return value return value
@ -2734,7 +2792,7 @@ class BytecodeCompiler(
val callArgs = CallArgs(base = argSlots[0], count = argSlots.size, planId = null) val callArgs = CallArgs(base = argSlots[0], count = argSlots.size, planId = null)
val encodedCount = encodeCallArgCount(callArgs) ?: return null val encodedCount = encodeCallArgCount(callArgs) ?: return null
val callDst = allocSlot() val callDst = allocSlot()
builder.emit(Opcode.CALL_SLOT, calleeObj.slot, callArgs.base, encodedCount, callDst) emitCallCompiled(calleeObj, callArgs.base, encodedCount, callDst)
return value return value
} }
builder.emit(Opcode.SET_MEMBER_SLOT, receiver.slot, fieldId, methodId, value.slot) builder.emit(Opcode.SET_MEMBER_SLOT, receiver.slot, fieldId, methodId, value.slot)
@ -3358,7 +3416,7 @@ class BytecodeCompiler(
val callArgs = CallArgs(base = argSlots[0], count = argSlots.size, planId = null) val callArgs = CallArgs(base = argSlots[0], count = argSlots.size, planId = null)
val encodedCount = encodeCallArgCount(callArgs) ?: return null val encodedCount = encodeCallArgCount(callArgs) ?: return null
if (!target.isOptional) { if (!target.isOptional) {
builder.emit(Opcode.CALL_SLOT, callee.slot, callArgs.base, encodedCount, resultSlot) emitCallCompiled(callee, callArgs.base, encodedCount, resultSlot)
} else { } else {
val recvNull = allocSlot() val recvNull = allocSlot()
builder.emit(Opcode.CONST_NULL, recvNull) builder.emit(Opcode.CONST_NULL, recvNull)
@ -3369,7 +3427,7 @@ class BytecodeCompiler(
Opcode.JMP_IF_TRUE, Opcode.JMP_IF_TRUE,
listOf(CmdBuilder.Operand.IntVal(recvCmp), CmdBuilder.Operand.LabelRef(skipLabel)) listOf(CmdBuilder.Operand.IntVal(recvCmp), CmdBuilder.Operand.LabelRef(skipLabel))
) )
builder.emit(Opcode.CALL_SLOT, callee.slot, callArgs.base, encodedCount, resultSlot) emitCallCompiled(callee, callArgs.base, encodedCount, resultSlot)
builder.mark(skipLabel) builder.mark(skipLabel)
} }
} }
@ -4405,12 +4463,12 @@ class BytecodeCompiler(
) )
val thenValue = compileRefWithFallback(ref.ifTrue, null, Pos.builtIn) ?: return null val thenValue = compileRefWithFallback(ref.ifTrue, null, Pos.builtIn) ?: return null
val thenObj = ensureObjSlot(thenValue) val thenObj = ensureObjSlot(thenValue)
builder.emit(Opcode.MOVE_OBJ, thenObj.slot, resultSlot) emitMove(thenObj, resultSlot)
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel))) builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
builder.mark(elseLabel) builder.mark(elseLabel)
val elseValue = compileRefWithFallback(ref.ifFalse, null, Pos.builtIn) ?: return null val elseValue = compileRefWithFallback(ref.ifFalse, null, Pos.builtIn) ?: return null
val elseObj = ensureObjSlot(elseValue) val elseObj = ensureObjSlot(elseValue)
builder.emit(Opcode.MOVE_OBJ, elseObj.slot, resultSlot) emitMove(elseObj, resultSlot)
builder.mark(endLabel) builder.mark(endLabel)
updateSlotType(resultSlot, SlotType.OBJ) updateSlotType(resultSlot, SlotType.OBJ)
return CompiledValue(resultSlot, SlotType.OBJ) return CompiledValue(resultSlot, SlotType.OBJ)
@ -4430,12 +4488,12 @@ class BytecodeCompiler(
Opcode.JMP_IF_TRUE, Opcode.JMP_IF_TRUE,
listOf(CmdBuilder.Operand.IntVal(cmpSlot), CmdBuilder.Operand.LabelRef(rightLabel)) listOf(CmdBuilder.Operand.IntVal(cmpSlot), CmdBuilder.Operand.LabelRef(rightLabel))
) )
builder.emit(Opcode.MOVE_OBJ, leftObj.slot, resultSlot) emitMove(leftObj, resultSlot)
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel))) builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
builder.mark(rightLabel) builder.mark(rightLabel)
val rightValue = compileRefWithFallback(ref.right, null, Pos.builtIn) ?: return null val rightValue = compileRefWithFallback(ref.right, null, Pos.builtIn) ?: return null
val rightObj = ensureObjSlot(rightValue) val rightObj = ensureObjSlot(rightValue)
builder.emit(Opcode.MOVE_OBJ, rightObj.slot, resultSlot) emitMove(rightObj, resultSlot)
builder.mark(endLabel) builder.mark(endLabel)
updateSlotType(resultSlot, SlotType.OBJ) updateSlotType(resultSlot, SlotType.OBJ)
return CompiledValue(resultSlot, SlotType.OBJ) return CompiledValue(resultSlot, SlotType.OBJ)
@ -4469,7 +4527,7 @@ class BytecodeCompiler(
val bodyValue = compileStatementValueOrFallback(case.block, wantResult) ?: return null val bodyValue = compileStatementValueOrFallback(case.block, wantResult) ?: return null
if (wantResult) { if (wantResult) {
val bodyObj = ensureObjSlot(bodyValue) val bodyObj = ensureObjSlot(bodyValue)
builder.emit(Opcode.MOVE_OBJ, bodyObj.slot, resultSlot) emitMove(bodyObj, resultSlot)
} }
restoreFlowTypeOverride(caseRestore) restoreFlowTypeOverride(caseRestore)
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel))) builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
@ -4479,7 +4537,7 @@ class BytecodeCompiler(
val elseValue = compileStatementValueOrFallback(it, wantResult) ?: return null val elseValue = compileStatementValueOrFallback(it, wantResult) ?: return null
if (wantResult) { if (wantResult) {
val elseObj = ensureObjSlot(elseValue) val elseObj = ensureObjSlot(elseValue)
builder.emit(Opcode.MOVE_OBJ, elseObj.slot, resultSlot) emitMove(elseObj, resultSlot)
} }
} }
builder.mark(endLabel) builder.mark(endLabel)
@ -4632,6 +4690,19 @@ class BytecodeCompiler(
} }
val localTarget = ref.target as? LocalVarRef val localTarget = ref.target as? LocalVarRef
val isExternCall = localTarget != null && externCallableNames.contains(localTarget.name) val isExternCall = localTarget != null && externCallableNames.contains(localTarget.name)
if (!isExternCall) {
val exactCallee = resolveExactCallableObj(ref.target)
if (exactCallee != null) {
val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null
val dst = allocSlot()
setPos(callPos)
emitCallDirect(exactCallee, args.base, encodedCount, dst)
updateSlotType(dst, SlotType.OBJ)
(exactCallee as? ObjClass)?.let { slotObjClass[dst] = it }
return CompiledValue(dst, SlotType.OBJ)
}
}
if (localTarget != null) { if (localTarget != null) {
val direct = resolveDirectNameSlot(localTarget.name) val direct = resolveDirectNameSlot(localTarget.name)
if (direct == null) { if (direct == null) {
@ -4660,13 +4731,7 @@ class BytecodeCompiler(
val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null val encodedCount = encodeCallArgCount(args) ?: return null
setPos(callPos) setPos(callPos)
builder.emit( emitCallCompiled(callee, args.base, encodedCount, dst, isExternCall = isExternCall)
if (isExternCall) Opcode.CALL_BRIDGE_SLOT else Opcode.CALL_SLOT,
callee.slot,
args.base,
encodedCount,
dst
)
if (initClass != null) { if (initClass != null) {
slotObjClass[dst] = initClass slotObjClass[dst] = initClass
} }
@ -4685,13 +4750,7 @@ class BytecodeCompiler(
val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null val encodedCount = encodeCallArgCount(args) ?: return null
setPos(callPos) setPos(callPos)
builder.emit( emitCallCompiled(callee, args.base, encodedCount, dst, isExternCall = isExternCall)
if (isExternCall) Opcode.CALL_BRIDGE_SLOT else Opcode.CALL_SLOT,
callee.slot,
args.base,
encodedCount,
dst
)
if (initClass != null) { if (initClass != null) {
slotObjClass[dst] = initClass slotObjClass[dst] = initClass
} }
@ -4885,7 +4944,7 @@ class BytecodeCompiler(
val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null val args = compileCallArgs(ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null val encodedCount = encodeCallArgCount(args) ?: return null
setPos(callPos) setPos(callPos)
builder.emit(Opcode.CALL_SLOT, memberSlot, args.base, encodedCount, dst) emitCallCompiled(CompiledValue(memberSlot, SlotType.OBJ), args.base, encodedCount, dst)
return CompiledValue(dst, SlotType.OBJ) return CompiledValue(dst, SlotType.OBJ)
} }
val extSlot = resolveExtensionCallableSlot(receiverClass, ref.name) val extSlot = resolveExtensionCallableSlot(receiverClass, ref.name)
@ -4898,7 +4957,7 @@ class BytecodeCompiler(
val args = compileCallArgsWithReceiver(receiver, ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null val args = compileCallArgsWithReceiver(receiver, ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null val encodedCount = encodeCallArgCount(args) ?: return null
setPos(callPos) setPos(callPos)
builder.emit(Opcode.CALL_SLOT, callee.slot, args.base, encodedCount, dst) emitCallCompiled(callee, args.base, encodedCount, dst)
return CompiledValue(dst, SlotType.OBJ) return CompiledValue(dst, SlotType.OBJ)
} }
val nullSlot = allocSlot() val nullSlot = allocSlot()
@ -4914,7 +4973,7 @@ class BytecodeCompiler(
val args = compileCallArgsWithReceiver(receiver, ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null val args = compileCallArgsWithReceiver(receiver, ref.args, ref.tailBlock, ref.explicitTypeArgs) ?: return null
val encodedCount = encodeCallArgCount(args) ?: return null val encodedCount = encodeCallArgCount(args) ?: return null
setPos(callPos) setPos(callPos)
builder.emit(Opcode.CALL_SLOT, callee.slot, args.base, encodedCount, dst) emitCallCompiled(callee, args.base, encodedCount, dst)
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel))) builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
builder.mark(nullLabel) builder.mark(nullLabel)
builder.emit(Opcode.CONST_NULL, dst) builder.emit(Opcode.CONST_NULL, dst)
@ -5435,6 +5494,14 @@ class BytecodeCompiler(
} }
} }
private fun trackExactCallableObjAtSlot(slot: Int, obj: Obj?) {
if (obj == null || obj === ObjNull || obj === ObjUnset || obj is ObjExternCallable) {
exactCallableObjBySlot.remove(slot)
} else {
exactCallableObjBySlot[slot] = obj
}
}
private fun preloadExactLambdaRefs() { private fun preloadExactLambdaRefs() {
if (exactLambdaRefByScopeId.isEmpty()) return if (exactLambdaRefByScopeId.isEmpty()) return
for ((scopeId, slots) in exactLambdaRefByScopeId) { for ((scopeId, slots) in exactLambdaRefByScopeId) {
@ -5453,6 +5520,24 @@ class BytecodeCompiler(
} }
} }
private fun preloadExactCallableNames() {
if (scopeSlotCount > 0) {
for (index in 0 until scopeSlotCount) {
val name = scopeSlotNames.getOrNull(index) ?: continue
resolveTypeNameClass(name)?.let { trackExactCallableObjAtSlot(index, it) }
}
}
if (!allowLocalSlots || localSlotNames.isEmpty()) return
for (localIndex in localSlotNames.indices) {
val name = localSlotNames[localIndex] ?: continue
val key = localSlotKeyByIndex.getOrNull(localIndex)
val isModuleLocal = key != null && moduleScopeId != null && key.scopeId == moduleScopeId
val isCapture = localSlotCaptures.getOrNull(localIndex) == true
if (!isModuleLocal && !isCapture) continue
resolveTypeNameClass(name)?.let { trackExactCallableObjAtSlot(scopeSlotCount + localIndex, it) }
}
}
private fun collectExactLambdaModuleCaptures() { private fun collectExactLambdaModuleCaptures() {
if (exactLambdaRefByScopeId.isEmpty()) return if (exactLambdaRefByScopeId.isEmpty()) return
val seen = LinkedHashSet<LambdaFnRef>() val seen = LinkedHashSet<LambdaFnRef>()
@ -5500,6 +5585,107 @@ class BytecodeCompiler(
} }
} }
private fun extractExactCallableObj(value: Obj?): Obj? {
return when (value) {
is ExpressionStatement -> resolveExactCallableObj(value.ref)
is IfStatement -> {
val thenObj = extractExactCallableObj(value.ifBody)
val elseObj = value.elseBody?.let { extractExactCallableObj(it) }
if (thenObj != null && thenObj === elseObj) thenObj else null
}
is WhenStatement -> {
var candidate: Obj? = null
for (case in value.cases) {
val current = extractExactCallableObj(case.block) ?: return null
if (candidate == null) {
candidate = current
} else if (candidate !== current) {
return null
}
}
val elseObj = value.elseCase?.let { extractExactCallableObj(it) }
if (candidate == null) return elseObj
if (elseObj == null || candidate !== elseObj) return null
candidate
}
else -> null
}
}
private fun isDefinitelyNullRef(ref: ObjRef): Boolean {
return when (ref) {
is ConstRef -> ref.constValue === ObjNull
is StatementRef -> {
val statement = ref.statement
statement is ExpressionStatement && isDefinitelyNullRef(statement.ref)
}
else -> false
}
}
private fun resolveNamedExactCallableObj(name: String): Obj? {
val direct = resolveDirectNameSlot(name)
if (direct != null) {
exactCallableObjBySlot[direct.slot]?.let { return it }
if (slotObjClass[direct.slot] == ObjClassType) {
return resolveTypeNameClass(name)
}
if (!canFallbackToNamedExactCallable(direct.slot)) {
return null
}
}
return resolveTypeNameClass(name)
}
private fun canFallbackToNamedExactCallable(slot: Int): Boolean {
if (slot < scopeSlotCount) return true
val localIndex = slot - scopeSlotCount
if (localSlotCaptures.getOrNull(localIndex) == true) return true
val key = localSlotKeyByIndex.getOrNull(localIndex) ?: return false
return moduleScopeId != null && key.scopeId == moduleScopeId
}
private fun resolveExactCallableObj(target: ObjRef): Obj? {
return when (target) {
is ConstRef -> target.constValue.takeUnless { it === ObjNull || it === ObjUnset || it is ObjExternCallable }
is ElvisRef -> {
val left = resolveExactCallableObj(target.left)
if (left != null) return left
if (isDefinitelyNullRef(target.left)) resolveExactCallableObj(target.right) else null
}
is ConditionalRef -> {
val thenObj = resolveExactCallableObj(target.ifTrue)
val elseObj = resolveExactCallableObj(target.ifFalse)
if (thenObj != null && thenObj === elseObj) thenObj else null
}
is StatementRef -> {
when (val statement = target.statement) {
is ExpressionStatement -> resolveExactCallableObj(statement.ref)
is IfStatement -> {
val thenObj = extractExactCallableObj(statement.ifBody)
val elseObj = statement.elseBody?.let { extractExactCallableObj(it) }
if (thenObj != null && thenObj === elseObj) thenObj else null
}
else -> null
}
}
is LocalSlotRef -> {
val resolvedSlot = resolveLocalSlotByRefOrName(target)
if (resolvedSlot != null) {
exactCallableObjBySlot[resolvedSlot]?.let { return it }
if (slotObjClass[resolvedSlot] == ObjClassType || canFallbackToNamedExactCallable(resolvedSlot)) {
return resolveTypeNameClass(target.name)
}
}
null
}
is LocalVarRef -> resolveNamedExactCallableObj(target.name)
is FastLocalVarRef -> resolveNamedExactCallableObj(target.name)
is BoundLocalVarRef -> exactCallableObjBySlot[target.slotIndex()]
else -> null
}
}
private fun compileInlineListFillInt(size: CompiledValue, lambdaRef: LambdaFnRef, inlineRef: ObjRef): CompiledValue { private fun compileInlineListFillInt(size: CompiledValue, lambdaRef: LambdaFnRef, inlineRef: ObjRef): CompiledValue {
if (isImplicitItIdentityRef(inlineRef)) { if (isImplicitItIdentityRef(inlineRef)) {
val dst = allocSlot() val dst = allocSlot()
@ -6020,6 +6206,15 @@ class BytecodeCompiler(
} ?: allocSlot() } ?: allocSlot()
builder.emit(Opcode.DECL_FUNCTION, constId, dst) builder.emit(Opcode.DECL_FUNCTION, constId, dst)
updateSlotType(dst, SlotType.OBJ) updateSlotType(dst, SlotType.OBJ)
if (!stmt.spec.actualExtern &&
!stmt.spec.isDelegated &&
stmt.spec.annotation == null &&
stmt.spec.extTypeName == null
) {
trackExactCallableObjAtSlot(dst, stmt.spec.fnBody)
} else {
trackExactCallableObjAtSlot(dst, null)
}
return CompiledValue(dst, SlotType.OBJ) return CompiledValue(dst, SlotType.OBJ)
} }
@ -6644,7 +6839,14 @@ class BytecodeCompiler(
?: updateSlotObjClass(localSlot, stmt.initializer, stmt.initializerObjClass) ?: updateSlotObjClass(localSlot, stmt.initializer, stmt.initializerObjClass)
updateListElementClassFromDecl(localSlot, scopeId, stmt.slotIndex) updateListElementClassFromDecl(localSlot, scopeId, stmt.slotIndex)
updateListElementClassFromInitializer(localSlot, stmt.initializer) updateListElementClassFromInitializer(localSlot, stmt.initializer)
trackExactLambdaAtSlot(localSlot, if (!stmt.isMutable) extractExactLambdaRef(stmt.initializer) else null) trackExactCallableObjAtSlot(
localSlot,
if (!stmt.isMutable) extractExactCallableObj(stmt.initializer) ?: exactCallableObjBySlot[localSlot] else null
)
trackExactLambdaAtSlot(
localSlot,
if (!stmt.isMutable) extractExactLambdaRef(stmt.initializer) ?: exactLambdaRefBySlot[localSlot] else null
)
updateNameObjClassFromSlot(stmt.name, localSlot) updateNameObjClassFromSlot(stmt.name, localSlot)
val shadowedScopeSlot = scopeSlotIndexByName.containsKey(stmt.name) val shadowedScopeSlot = scopeSlotIndexByName.containsKey(stmt.name)
val isModuleScope = moduleScopeId != null && scopeId == moduleScopeId val isModuleScope = moduleScopeId != null && scopeId == moduleScopeId
@ -6678,7 +6880,14 @@ class BytecodeCompiler(
?: updateSlotObjClass(scopeSlot, stmt.initializer, stmt.initializerObjClass) ?: updateSlotObjClass(scopeSlot, stmt.initializer, stmt.initializerObjClass)
updateListElementClassFromDecl(scopeSlot, scopeId, stmt.slotIndex) updateListElementClassFromDecl(scopeSlot, scopeId, stmt.slotIndex)
updateListElementClassFromInitializer(scopeSlot, stmt.initializer) updateListElementClassFromInitializer(scopeSlot, stmt.initializer)
trackExactLambdaAtSlot(scopeSlot, if (!stmt.isMutable) extractExactLambdaRef(stmt.initializer) else null) trackExactCallableObjAtSlot(
scopeSlot,
if (!stmt.isMutable) extractExactCallableObj(stmt.initializer) ?: exactCallableObjBySlot[scopeSlot] else null
)
trackExactLambdaAtSlot(
scopeSlot,
if (!stmt.isMutable) extractExactLambdaRef(stmt.initializer) ?: exactLambdaRefBySlot[scopeSlot] else null
)
val declId = builder.addConst( val declId = builder.addConst(
BytecodeConst.LocalDecl( BytecodeConst.LocalDecl(
stmt.name, stmt.name,
@ -7400,7 +7609,7 @@ class BytecodeCompiler(
val thenValue = compileStatementValueOrFallback(stmt.ifBody) ?: return null val thenValue = compileStatementValueOrFallback(stmt.ifBody) ?: return null
restoreFlowTypeOverride(thenRestore) restoreFlowTypeOverride(thenRestore)
val thenObj = ensureObjSlot(thenValue) val thenObj = ensureObjSlot(thenValue)
builder.emit(Opcode.MOVE_OBJ, thenObj.slot, resultSlot) emitMove(thenObj, resultSlot)
builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel))) builder.emit(Opcode.JMP, listOf(CmdBuilder.Operand.LabelRef(endLabel)))
builder.mark(elseLabel) builder.mark(elseLabel)
if (stmt.elseBody != null) { if (stmt.elseBody != null) {
@ -7408,7 +7617,7 @@ class BytecodeCompiler(
val elseValue = compileStatementValueOrFallback(stmt.elseBody) ?: return null val elseValue = compileStatementValueOrFallback(stmt.elseBody) ?: return null
restoreFlowTypeOverride(elseRestore) restoreFlowTypeOverride(elseRestore)
val elseObj = ensureObjSlot(elseValue) val elseObj = ensureObjSlot(elseValue)
builder.emit(Opcode.MOVE_OBJ, elseObj.slot, resultSlot) emitMove(elseObj, resultSlot)
} else { } else {
val id = builder.addConst(BytecodeConst.ObjRef(ObjVoid)) val id = builder.addConst(BytecodeConst.ObjRef(ObjVoid))
builder.emit(Opcode.CONST_OBJ, id, resultSlot) builder.emit(Opcode.CONST_OBJ, id, resultSlot)
@ -7727,6 +7936,7 @@ class BytecodeCompiler(
val addrSlot = ensureScopeAddr(srcSlot) val addrSlot = ensureScopeAddr(srcSlot)
emitLoadFromAddr(addrSlot, dstSlot, value.type) emitLoadFromAddr(addrSlot, dstSlot, value.type)
propagateObjClass(value.type, srcSlot, dstSlot) propagateObjClass(value.type, srcSlot, dstSlot)
propagateExactCallableObj(value.type, srcSlot, dstSlot)
propagateExactLambdaRef(value.type, srcSlot, dstSlot) propagateExactLambdaRef(value.type, srcSlot, dstSlot)
return return
} }
@ -7734,6 +7944,7 @@ class BytecodeCompiler(
val addrSlot = ensureScopeAddr(dstSlot) val addrSlot = ensureScopeAddr(dstSlot)
emitStoreToAddr(srcSlot, addrSlot, value.type) emitStoreToAddr(srcSlot, addrSlot, value.type)
propagateObjClass(value.type, srcSlot, dstSlot) propagateObjClass(value.type, srcSlot, dstSlot)
propagateExactCallableObj(value.type, srcSlot, dstSlot)
propagateExactLambdaRef(value.type, srcSlot, dstSlot) propagateExactLambdaRef(value.type, srcSlot, dstSlot)
return return
} }
@ -7746,6 +7957,7 @@ class BytecodeCompiler(
else -> builder.emit(Opcode.BOX_OBJ, srcSlot, dstSlot) else -> builder.emit(Opcode.BOX_OBJ, srcSlot, dstSlot)
} }
propagateObjClass(value.type, srcSlot, dstSlot) propagateObjClass(value.type, srcSlot, dstSlot)
propagateExactCallableObj(value.type, srcSlot, dstSlot)
propagateExactLambdaRef(value.type, srcSlot, dstSlot) propagateExactLambdaRef(value.type, srcSlot, dstSlot)
} }
@ -7783,6 +7995,14 @@ class BytecodeCompiler(
} }
} }
private fun propagateExactCallableObj(type: SlotType, srcSlot: Int, dstSlot: Int) {
if (type == SlotType.OBJ || type == SlotType.UNKNOWN) {
trackExactCallableObjAtSlot(dstSlot, exactCallableObjBySlot[srcSlot])
} else {
trackExactCallableObjAtSlot(dstSlot, null)
}
}
private fun setPos(pos: Pos?) { private fun setPos(pos: Pos?) {
currentPos = pos currentPos = pos
builder.setPos(pos) builder.setPos(pos)
@ -8748,6 +8968,7 @@ class BytecodeCompiler(
loopVarSlots.clear() loopVarSlots.clear()
valueFnRefs.clear() valueFnRefs.clear()
exactLambdaRefBySlot.clear() exactLambdaRefBySlot.clear()
exactCallableObjBySlot.clear()
activeInlineLambdas.clear() activeInlineLambdas.clear()
addrSlotByScopeSlot.clear() addrSlotByScopeSlot.clear()
loopStack.clear() loopStack.clear()
@ -9131,6 +9352,7 @@ class BytecodeCompiler(
} }
} }
} }
preloadExactCallableNames()
} }
is DelegatedVarDeclStatement -> { is DelegatedVarDeclStatement -> {
val slotIndex = stmt.slotIndex val slotIndex = stmt.slotIndex

View File

@ -567,6 +567,111 @@ class BytecodeRecentOpsTest {
assertEquals(2, scope.eval("calc()").toInt()) assertEquals(2, scope.eval("calc()").toInt())
} }
@Test
fun constructorNameUsesDirectCall() = runTest {
val scope = Script.newScope()
scope.eval(
"""
fun calc() {
Map().size
}
""".trimIndent()
)
val disasm = scope.disassembleSymbol("calc")
assertTrue(disasm.contains("CALL_DIRECT"), disasm)
assertFalse(disasm.contains("CALL_SLOT"), disasm)
assertEquals(0, scope.eval("calc()").toInt())
}
@Test
fun constructorAliasUsesDirectCall() = runTest {
val scope = Script.newScope()
scope.eval(
"""
fun calc() {
val ctor = Map
val m = ctor() as Map
m.size
}
""".trimIndent()
)
val disasm = scope.disassembleSymbol("calc")
assertTrue(disasm.contains("CALL_DIRECT"), disasm)
assertFalse(disasm.contains("CALL_SLOT"), disasm)
assertEquals(0, scope.eval("calc()").toInt())
}
@Test
fun ifExpressionConstructorAliasUsesDirectCall() = runTest {
val scope = Script.newScope()
scope.eval(
"""
fun calc(flag: Bool) {
val ctor = if(flag) Map else Map
val m = ctor() as Map
m.size
}
""".trimIndent()
)
val disasm = scope.disassembleSymbol("calc")
assertTrue(disasm.contains("CALL_DIRECT"), disasm)
assertFalse(disasm.contains("CALL_SLOT"), disasm)
assertEquals(0, scope.eval("calc(true)").toInt())
}
@Test
fun elvisConstructorAliasUsesDirectCall() = runTest {
val scope = Script.newScope()
scope.eval(
"""
fun calc() {
val ctor = null ?: Map
val m = ctor() as Map
m.size
}
""".trimIndent()
)
val disasm = scope.disassembleSymbol("calc")
assertTrue(disasm.contains("CALL_DIRECT"), disasm)
assertFalse(disasm.contains("CALL_SLOT"), disasm)
assertEquals(0, scope.eval("calc()").toInt())
}
@Test
fun localNamedFunctionUsesDirectCall() = runTest {
val scope = Script.newScope()
scope.eval(
"""
fun calc() {
fun twice(x: Int) { x * 2 }
twice(3)
}
""".trimIndent()
)
val disasm = scope.disassembleSymbol("calc")
assertTrue(disasm.contains("CALL_DIRECT"), disasm)
assertFalse(disasm.contains("CALL_SLOT"), disasm)
assertEquals(6, scope.eval("calc()").toInt())
}
@Test
fun localNamedFunctionAliasUsesDirectCall() = runTest {
val scope = Script.newScope()
scope.eval(
"""
fun calc() {
fun twice(x: Int) { x * 2 }
val f = twice
f(3)
}
""".trimIndent()
)
val disasm = scope.disassembleSymbol("calc")
assertTrue(disasm.contains("CALL_DIRECT"), disasm)
assertFalse(disasm.contains("CALL_SLOT"), disasm)
assertEquals(6, scope.eval("calc()").toInt())
}
@Test @Test
fun optionalIndexPreIncSkipsOnNullReceiver() = runTest { fun optionalIndexPreIncSkipsOnNullReceiver() = runTest {
eval( eval(