another import fix

This commit is contained in:
Sergey Chernov 2026-04-13 20:43:11 +03:00
parent ab39110834
commit 3b6bdda0a4
3 changed files with 219 additions and 36 deletions

View File

@ -1,15 +1,74 @@
/*
* Copyright 2026 Sergey S. Chernov real.sergeych@gmail.com
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package net.sergeych
import kotlinx.coroutines.runBlocking
import net.sergeych.lyng.EvalSession
import net.sergeych.lyng.Source
import net.sergeych.lyng.obj.ObjString
import kotlinx.coroutines.runBlocking
import org.junit.After
import org.junit.Before
import java.io.ByteArrayOutputStream
import java.io.PrintStream
import java.nio.file.Files
import kotlin.io.path.writeText
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
class CliLocalModuleImportRegressionJvmTest {
private val originalOut: PrintStream = System.out
private val originalErr: PrintStream = System.err
private class TestExit(val code: Int) : RuntimeException()
@Before
fun setUp() {
jvmExitImpl = { code -> throw TestExit(code) }
}
@After
fun tearDown() {
System.setOut(originalOut)
System.setErr(originalErr)
jvmExitImpl = { code -> kotlin.system.exitProcess(code) }
}
private data class CliResult(val out: String, val err: String, val exitCode: Int?)
private fun runCli(vararg args: String): CliResult {
val outBuf = ByteArrayOutputStream()
val errBuf = ByteArrayOutputStream()
System.setOut(PrintStream(outBuf, true, Charsets.UTF_8))
System.setErr(PrintStream(errBuf, true, Charsets.UTF_8))
var exitCode: Int? = null
try {
runMain(arrayOf(*args))
} catch (e: TestExit) {
exitCode = e.code
} finally {
System.out.flush()
System.err.flush()
}
return CliResult(outBuf.toString("UTF-8"), errBuf.toString("UTF-8"), exitCode)
}
private fun writeTransitiveImportTree(root: java.nio.file.Path) {
val packageDir = Files.createDirectories(root.resolve("package1"))
@ -74,6 +133,49 @@ class CliLocalModuleImportRegressionJvmTest {
)
}
private fun writeNestedLaunchImportBugTree(root: java.nio.file.Path) {
val packageDir = Files.createDirectories(root.resolve("package1"))
packageDir.resolve("alpha.lyng").writeText(
"""
import lyng.io.net
import package1.bravo
class Alpha {
val tcpServer: TcpServer
val headers = Map<String, String>()
fn startListen(port, host) {
tcpServer = Net.tcpListen(port, host)
// println("tcpServer.isOpen: " + tcpServer.isOpen()) // historical workaround; should not be needed
launch {
try {
while (true) {
val tcpSocket = tcpServer.accept()
var bravo = Bravo()
bravo.doSomething()
tcpSocket.close()
break
}
} finally {
tcpServer.close()
}
}
}
}
""".trimIndent()
)
packageDir.resolve("bravo.lyng").writeText(
"""
class Bravo {
fn doSomething() {
println("Bravo.doSomething")
}
}
""".trimIndent()
)
}
@Test
fun localModuleUsingLaunchAndNetImportsWithoutStdlibRedefinition() = runBlocking {
val root = Files.createTempDirectory("lyng-cli-import-regression")
@ -134,4 +236,43 @@ class CliLocalModuleImportRegressionJvmTest {
root.toFile().deleteRecursively()
}
}
@Test
fun localModuleImportUsedOnlyInsideMethodLaunchClosureRemainsPrepared() = runBlocking {
val root = Files.createTempDirectory("lyng-cli-import-regression-launch")
try {
val mainFile = root.resolve("main.lyng")
val port = java.net.ServerSocket(0).let {
val selected = it.localPort
it.close()
selected
}
writeNestedLaunchImportBugTree(root)
mainFile.writeText(
"""
import lyng.io.net
import package1.alpha
val alpha = Alpha()
alpha.startListen($port, "127.0.0.1")
delay(50)
val socket = Net.tcpConnect("127.0.0.1", $port)
socket.writeUtf8("ping")
socket.flush()
socket.close()
delay(50)
""".trimIndent()
)
val result = runCli(mainFile.toString())
assertTrue(result.err.isBlank(), result.err)
assertFalse(result.out.contains("module capture 'Bravo'"), result.out)
assertTrue(result.out.contains("Bravo.doSomething"), result.out)
} finally {
root.toFile().deleteRecursively()
}
}
}

View File

@ -220,6 +220,38 @@ class Compiler(
return result
}
private fun captureNamesForBytecodeFunction(
bytecodeFn: CmdFunction,
declaredCaptureNames: List<String> = emptyList()
): List<String> {
val ordered = LinkedHashSet<String>()
ordered.addAll(declaredCaptureNames)
val names = bytecodeFn.localSlotNames
val captures = bytecodeFn.localSlotCaptures
for (i in names.indices) {
if (captures.getOrNull(i) != true) continue
val name = names[i] ?: continue
ordered.add(name)
}
collectNestedModuleCaptureNames(bytecodeFn, ordered)
return ordered.toList()
}
private fun collectNestedModuleCaptureNames(bytecodeFn: CmdFunction, out: LinkedHashSet<String>) {
for (constant in bytecodeFn.constants) {
val lambda = constant as? BytecodeConst.LambdaFn ?: continue
val table = lambda.captureTableId?.let { bytecodeFn.constants.getOrNull(it) as? BytecodeConst.CaptureTable }
if (table != null) {
for ((index, entry) in table.entries.withIndex()) {
if (entry.ownerKind != CaptureOwnerFrameKind.MODULE) continue
val name = lambda.captureNames.getOrNull(index) ?: continue
out.add(name)
}
}
collectNestedModuleCaptureNames(lambda.fn, out)
}
}
private fun seedSlotPlanFromScope(scope: Scope, includeParents: Boolean = false) {
val plan = moduleSlotPlan() ?: return
seedingSlotPlan = true
@ -9166,20 +9198,10 @@ class Compiler(
val declaredNames = bytecodeFn.constants
.mapNotNull { it as? BytecodeConst.LocalDecl }
.mapTo(mutableSetOf()) { it.name }
val captureNames = if (captureSlots.isNotEmpty()) {
val captureNames = captureNamesForBytecodeFunction(
bytecodeFn,
captureSlots.map { it.name }
} else {
val fn = bytecodeFn
val names = fn.localSlotNames
val captures = fn.localSlotCaptures
val ordered = LinkedHashSet<String>()
for (i in names.indices) {
if (captures.getOrNull(i) != true) continue
val name = names[i] ?: continue
ordered.add(name)
}
ordered.toList()
}
)
val prebuiltCaptures = closureBox.captureRecords
if (prebuiltCaptures != null && captureNames.isNotEmpty()) {
context.captureRecords = prebuiltCaptures

View File

@ -2536,10 +2536,8 @@ class CmdDeclFunction(internal val constId: Int, internal val slot: Int) : Cmd()
}
private fun captureNamesForFunctionDecl(spec: net.sergeych.lyng.FunctionDeclSpec): List<String> {
if (spec.captureSlots.isNotEmpty()) {
return spec.captureSlots.map { it.name }
}
return captureNamesForStatement(spec.fnBody)
val declaredCaptures = spec.captureSlots.map { it.name }
return mergeCaptureNames(declaredCaptures, captureNamesForStatement(spec.fnBody))
}
private fun captureNamesForStatement(stmt: Statement?): List<String> {
@ -2549,6 +2547,10 @@ private fun captureNamesForStatement(stmt: Statement?): List<String> {
is BytecodeBodyProvider -> stmt.bytecodeBody()?.bytecodeFunction()
else -> null
} ?: return emptyList()
return captureNamesForBytecode(bytecode)
}
private fun captureNamesForBytecode(bytecode: CmdFunction): List<String> {
val names = bytecode.localSlotNames
val captures = bytecode.localSlotCaptures
val ordered = LinkedHashSet<String>()
@ -2557,9 +2559,33 @@ private fun captureNamesForStatement(stmt: Statement?): List<String> {
val name = names[i] ?: continue
ordered.add(name)
}
collectNestedModuleCaptureNames(bytecode, ordered)
return ordered.toList()
}
private fun collectNestedModuleCaptureNames(bytecode: CmdFunction, out: LinkedHashSet<String>) {
for (constant in bytecode.constants) {
val lambda = constant as? BytecodeConst.LambdaFn ?: continue
val table = lambda.captureTableId?.let { bytecode.constants.getOrNull(it) as? BytecodeConst.CaptureTable }
if (table != null) {
for ((index, entry) in table.entries.withIndex()) {
if (entry.ownerKind != CaptureOwnerFrameKind.MODULE) continue
val name = lambda.captureNames.getOrNull(index) ?: continue
out.add(name)
}
}
collectNestedModuleCaptureNames(lambda.fn, out)
}
}
private fun findInheritedCaptureRecord(scope: Scope, name: String): ObjRecord? {
val inheritedNames = scope.captureNames ?: return null
val inheritedRecords = scope.captureRecords ?: return null
val inheritedIndex = inheritedNames.indexOf(name)
if (inheritedIndex < 0) return null
return inheritedRecords.getOrNull(inheritedIndex)
}
private fun freezeImmutableCaptureRecord(record: ObjRecord): ObjRecord {
val value = record.value as Obj?
if (record.isMutable || record.type == ObjRecord.Type.Delegated || record.type == ObjRecord.Type.Property || value is ObjProperty) {
@ -4230,12 +4256,7 @@ class CmdFrame(
}
val name = captureNames?.getOrNull(index)
if (name != null) {
val inheritedNames = scope.captureNames
val inheritedRecords = scope.captureRecords
if (inheritedNames != null && inheritedRecords != null) {
val inheritedIndex = inheritedNames.indexOf(name)
if (inheritedIndex >= 0) {
val inherited = inheritedRecords.getOrNull(inheritedIndex)
val inherited = findInheritedCaptureRecord(scope, name)
if (inherited != null) {
val copied = ObjRecord(
value = inherited.value,
@ -4248,8 +4269,6 @@ class CmdFrame(
return@mapIndexed copied
}
}
}
}
val isMutable = fn.localSlotMutables.getOrNull(localIndex) ?: false
val isDelegated = fn.localSlotDelegated.getOrNull(localIndex) ?: false
if (isDelegated) {
@ -4293,6 +4312,7 @@ class CmdFrame(
// Fallback to current scope in case the module scope isn't in the parent chain
// or doesn't carry the imported symbol yet.
findNamedExistingRecord(scope, name)?.let { return@mapIndexed it }
findInheritedCaptureRecord(scope, name)?.let { return@mapIndexed it }
}
if (slotId < target.slotCount) {
val existing = target.getSlotRecord(slotId)