another import fix

This commit is contained in:
Sergey Chernov 2026-04-13 20:43:11 +03:00
parent ab39110834
commit 3b6bdda0a4
3 changed files with 219 additions and 36 deletions

View File

@ -1,15 +1,74 @@
/*
* Copyright 2026 Sergey S. Chernov real.sergeych@gmail.com
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package net.sergeych package net.sergeych
import kotlinx.coroutines.runBlocking
import net.sergeych.lyng.EvalSession import net.sergeych.lyng.EvalSession
import net.sergeych.lyng.Source import net.sergeych.lyng.Source
import net.sergeych.lyng.obj.ObjString import net.sergeych.lyng.obj.ObjString
import kotlinx.coroutines.runBlocking import org.junit.After
import org.junit.Before
import java.io.ByteArrayOutputStream
import java.io.PrintStream
import java.nio.file.Files import java.nio.file.Files
import kotlin.io.path.writeText import kotlin.io.path.writeText
import kotlin.test.Test import kotlin.test.Test
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
class CliLocalModuleImportRegressionJvmTest { class CliLocalModuleImportRegressionJvmTest {
private val originalOut: PrintStream = System.out
private val originalErr: PrintStream = System.err
private class TestExit(val code: Int) : RuntimeException()
@Before
fun setUp() {
jvmExitImpl = { code -> throw TestExit(code) }
}
@After
fun tearDown() {
System.setOut(originalOut)
System.setErr(originalErr)
jvmExitImpl = { code -> kotlin.system.exitProcess(code) }
}
private data class CliResult(val out: String, val err: String, val exitCode: Int?)
private fun runCli(vararg args: String): CliResult {
val outBuf = ByteArrayOutputStream()
val errBuf = ByteArrayOutputStream()
System.setOut(PrintStream(outBuf, true, Charsets.UTF_8))
System.setErr(PrintStream(errBuf, true, Charsets.UTF_8))
var exitCode: Int? = null
try {
runMain(arrayOf(*args))
} catch (e: TestExit) {
exitCode = e.code
} finally {
System.out.flush()
System.err.flush()
}
return CliResult(outBuf.toString("UTF-8"), errBuf.toString("UTF-8"), exitCode)
}
private fun writeTransitiveImportTree(root: java.nio.file.Path) { private fun writeTransitiveImportTree(root: java.nio.file.Path) {
val packageDir = Files.createDirectories(root.resolve("package1")) val packageDir = Files.createDirectories(root.resolve("package1"))
@ -74,6 +133,49 @@ class CliLocalModuleImportRegressionJvmTest {
) )
} }
private fun writeNestedLaunchImportBugTree(root: java.nio.file.Path) {
val packageDir = Files.createDirectories(root.resolve("package1"))
packageDir.resolve("alpha.lyng").writeText(
"""
import lyng.io.net
import package1.bravo
class Alpha {
val tcpServer: TcpServer
val headers = Map<String, String>()
fn startListen(port, host) {
tcpServer = Net.tcpListen(port, host)
// println("tcpServer.isOpen: " + tcpServer.isOpen()) // historical workaround; should not be needed
launch {
try {
while (true) {
val tcpSocket = tcpServer.accept()
var bravo = Bravo()
bravo.doSomething()
tcpSocket.close()
break
}
} finally {
tcpServer.close()
}
}
}
}
""".trimIndent()
)
packageDir.resolve("bravo.lyng").writeText(
"""
class Bravo {
fn doSomething() {
println("Bravo.doSomething")
}
}
""".trimIndent()
)
}
@Test @Test
fun localModuleUsingLaunchAndNetImportsWithoutStdlibRedefinition() = runBlocking { fun localModuleUsingLaunchAndNetImportsWithoutStdlibRedefinition() = runBlocking {
val root = Files.createTempDirectory("lyng-cli-import-regression") val root = Files.createTempDirectory("lyng-cli-import-regression")
@ -134,4 +236,43 @@ class CliLocalModuleImportRegressionJvmTest {
root.toFile().deleteRecursively() root.toFile().deleteRecursively()
} }
} }
@Test
fun localModuleImportUsedOnlyInsideMethodLaunchClosureRemainsPrepared() = runBlocking {
val root = Files.createTempDirectory("lyng-cli-import-regression-launch")
try {
val mainFile = root.resolve("main.lyng")
val port = java.net.ServerSocket(0).let {
val selected = it.localPort
it.close()
selected
}
writeNestedLaunchImportBugTree(root)
mainFile.writeText(
"""
import lyng.io.net
import package1.alpha
val alpha = Alpha()
alpha.startListen($port, "127.0.0.1")
delay(50)
val socket = Net.tcpConnect("127.0.0.1", $port)
socket.writeUtf8("ping")
socket.flush()
socket.close()
delay(50)
""".trimIndent()
)
val result = runCli(mainFile.toString())
assertTrue(result.err.isBlank(), result.err)
assertFalse(result.out.contains("module capture 'Bravo'"), result.out)
assertTrue(result.out.contains("Bravo.doSomething"), result.out)
} finally {
root.toFile().deleteRecursively()
}
}
} }

View File

@ -220,6 +220,38 @@ class Compiler(
return result return result
} }
private fun captureNamesForBytecodeFunction(
bytecodeFn: CmdFunction,
declaredCaptureNames: List<String> = emptyList()
): List<String> {
val ordered = LinkedHashSet<String>()
ordered.addAll(declaredCaptureNames)
val names = bytecodeFn.localSlotNames
val captures = bytecodeFn.localSlotCaptures
for (i in names.indices) {
if (captures.getOrNull(i) != true) continue
val name = names[i] ?: continue
ordered.add(name)
}
collectNestedModuleCaptureNames(bytecodeFn, ordered)
return ordered.toList()
}
private fun collectNestedModuleCaptureNames(bytecodeFn: CmdFunction, out: LinkedHashSet<String>) {
for (constant in bytecodeFn.constants) {
val lambda = constant as? BytecodeConst.LambdaFn ?: continue
val table = lambda.captureTableId?.let { bytecodeFn.constants.getOrNull(it) as? BytecodeConst.CaptureTable }
if (table != null) {
for ((index, entry) in table.entries.withIndex()) {
if (entry.ownerKind != CaptureOwnerFrameKind.MODULE) continue
val name = lambda.captureNames.getOrNull(index) ?: continue
out.add(name)
}
}
collectNestedModuleCaptureNames(lambda.fn, out)
}
}
private fun seedSlotPlanFromScope(scope: Scope, includeParents: Boolean = false) { private fun seedSlotPlanFromScope(scope: Scope, includeParents: Boolean = false) {
val plan = moduleSlotPlan() ?: return val plan = moduleSlotPlan() ?: return
seedingSlotPlan = true seedingSlotPlan = true
@ -9166,20 +9198,10 @@ class Compiler(
val declaredNames = bytecodeFn.constants val declaredNames = bytecodeFn.constants
.mapNotNull { it as? BytecodeConst.LocalDecl } .mapNotNull { it as? BytecodeConst.LocalDecl }
.mapTo(mutableSetOf()) { it.name } .mapTo(mutableSetOf()) { it.name }
val captureNames = if (captureSlots.isNotEmpty()) { val captureNames = captureNamesForBytecodeFunction(
bytecodeFn,
captureSlots.map { it.name } captureSlots.map { it.name }
} else { )
val fn = bytecodeFn
val names = fn.localSlotNames
val captures = fn.localSlotCaptures
val ordered = LinkedHashSet<String>()
for (i in names.indices) {
if (captures.getOrNull(i) != true) continue
val name = names[i] ?: continue
ordered.add(name)
}
ordered.toList()
}
val prebuiltCaptures = closureBox.captureRecords val prebuiltCaptures = closureBox.captureRecords
if (prebuiltCaptures != null && captureNames.isNotEmpty()) { if (prebuiltCaptures != null && captureNames.isNotEmpty()) {
context.captureRecords = prebuiltCaptures context.captureRecords = prebuiltCaptures

View File

@ -2536,10 +2536,8 @@ class CmdDeclFunction(internal val constId: Int, internal val slot: Int) : Cmd()
} }
private fun captureNamesForFunctionDecl(spec: net.sergeych.lyng.FunctionDeclSpec): List<String> { private fun captureNamesForFunctionDecl(spec: net.sergeych.lyng.FunctionDeclSpec): List<String> {
if (spec.captureSlots.isNotEmpty()) { val declaredCaptures = spec.captureSlots.map { it.name }
return spec.captureSlots.map { it.name } return mergeCaptureNames(declaredCaptures, captureNamesForStatement(spec.fnBody))
}
return captureNamesForStatement(spec.fnBody)
} }
private fun captureNamesForStatement(stmt: Statement?): List<String> { private fun captureNamesForStatement(stmt: Statement?): List<String> {
@ -2549,6 +2547,10 @@ private fun captureNamesForStatement(stmt: Statement?): List<String> {
is BytecodeBodyProvider -> stmt.bytecodeBody()?.bytecodeFunction() is BytecodeBodyProvider -> stmt.bytecodeBody()?.bytecodeFunction()
else -> null else -> null
} ?: return emptyList() } ?: return emptyList()
return captureNamesForBytecode(bytecode)
}
private fun captureNamesForBytecode(bytecode: CmdFunction): List<String> {
val names = bytecode.localSlotNames val names = bytecode.localSlotNames
val captures = bytecode.localSlotCaptures val captures = bytecode.localSlotCaptures
val ordered = LinkedHashSet<String>() val ordered = LinkedHashSet<String>()
@ -2557,9 +2559,33 @@ private fun captureNamesForStatement(stmt: Statement?): List<String> {
val name = names[i] ?: continue val name = names[i] ?: continue
ordered.add(name) ordered.add(name)
} }
collectNestedModuleCaptureNames(bytecode, ordered)
return ordered.toList() return ordered.toList()
} }
private fun collectNestedModuleCaptureNames(bytecode: CmdFunction, out: LinkedHashSet<String>) {
for (constant in bytecode.constants) {
val lambda = constant as? BytecodeConst.LambdaFn ?: continue
val table = lambda.captureTableId?.let { bytecode.constants.getOrNull(it) as? BytecodeConst.CaptureTable }
if (table != null) {
for ((index, entry) in table.entries.withIndex()) {
if (entry.ownerKind != CaptureOwnerFrameKind.MODULE) continue
val name = lambda.captureNames.getOrNull(index) ?: continue
out.add(name)
}
}
collectNestedModuleCaptureNames(lambda.fn, out)
}
}
private fun findInheritedCaptureRecord(scope: Scope, name: String): ObjRecord? {
val inheritedNames = scope.captureNames ?: return null
val inheritedRecords = scope.captureRecords ?: return null
val inheritedIndex = inheritedNames.indexOf(name)
if (inheritedIndex < 0) return null
return inheritedRecords.getOrNull(inheritedIndex)
}
private fun freezeImmutableCaptureRecord(record: ObjRecord): ObjRecord { private fun freezeImmutableCaptureRecord(record: ObjRecord): ObjRecord {
val value = record.value as Obj? val value = record.value as Obj?
if (record.isMutable || record.type == ObjRecord.Type.Delegated || record.type == ObjRecord.Type.Property || value is ObjProperty) { if (record.isMutable || record.type == ObjRecord.Type.Delegated || record.type == ObjRecord.Type.Property || value is ObjProperty) {
@ -4230,24 +4256,17 @@ class CmdFrame(
} }
val name = captureNames?.getOrNull(index) val name = captureNames?.getOrNull(index)
if (name != null) { if (name != null) {
val inheritedNames = scope.captureNames val inherited = findInheritedCaptureRecord(scope, name)
val inheritedRecords = scope.captureRecords if (inherited != null) {
if (inheritedNames != null && inheritedRecords != null) { val copied = ObjRecord(
val inheritedIndex = inheritedNames.indexOf(name) value = inherited.value,
if (inheritedIndex >= 0) { isMutable = inherited.isMutable,
val inherited = inheritedRecords.getOrNull(inheritedIndex) visibility = inherited.visibility,
if (inherited != null) { isTransient = inherited.isTransient,
val copied = ObjRecord( type = inherited.type
value = inherited.value, )
isMutable = inherited.isMutable, copied.delegate = inherited.delegate
visibility = inherited.visibility, return@mapIndexed copied
isTransient = inherited.isTransient,
type = inherited.type
)
copied.delegate = inherited.delegate
return@mapIndexed copied
}
}
} }
} }
val isMutable = fn.localSlotMutables.getOrNull(localIndex) ?: false val isMutable = fn.localSlotMutables.getOrNull(localIndex) ?: false
@ -4293,6 +4312,7 @@ class CmdFrame(
// Fallback to current scope in case the module scope isn't in the parent chain // Fallback to current scope in case the module scope isn't in the parent chain
// or doesn't carry the imported symbol yet. // or doesn't carry the imported symbol yet.
findNamedExistingRecord(scope, name)?.let { return@mapIndexed it } findNamedExistingRecord(scope, name)?.let { return@mapIndexed it }
findInheritedCaptureRecord(scope, name)?.let { return@mapIndexed it }
} }
if (slotId < target.slotCount) { if (slotId < target.slotCount) {
val existing = target.getSlotRecord(slotId) val existing = target.getSlotRecord(slotId)