Skip to content

Commit a631a4e

Browse files
committed
Break out read/write/modify SSH config
1 parent a605d79 commit a631a4e

File tree

2 files changed

+93
-57
lines changed

2 files changed

+93
-57
lines changed

src/main/kotlin/com/coder/gateway/sdk/CoderCLIManager.kt

Lines changed: 88 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ import javax.xml.bind.annotation.adapters.HexBinaryAdapter
2424
/**
2525
* Manage the CLI for a single deployment.
2626
*/
27-
class CoderCLIManager @JvmOverloads constructor(private val deploymentURL: URL, destinationDir: Path = getDataDir()) {
27+
class CoderCLIManager @JvmOverloads constructor(
28+
private val deploymentURL: URL,
29+
destinationDir: Path = getDataDir(),
30+
private val sshConfigPath: Path = Path.of(System.getProperty("user.home")).resolve(".ssh/config"),
31+
) {
2832
private var remoteBinaryUrl: URL
2933
var localBinaryPath: Path
3034
private var coderConfigPath: Path
@@ -163,10 +167,27 @@ class CoderCLIManager @JvmOverloads constructor(private val deploymentURL: URL,
163167
/**
164168
* Configure SSH to use this binary.
165169
*/
166-
fun configSsh(
167-
workspaces: List<WorkspaceAgentModel>,
168-
sshConfigPath: Path = Path.of(System.getProperty("user.home")).resolve(".ssh/config"),
169-
) {
170+
fun configSsh(workspaces: List<WorkspaceAgentModel>) {
171+
writeSSHConfig(modifySSHConfig(readSSHConfig(), workspaces))
172+
}
173+
174+
/**
175+
* Return the contents of the SSH config or null if it does not exist.
176+
*/
177+
private fun readSSHConfig(): String? {
178+
return try {
179+
sshConfigPath.toFile().readText()
180+
} catch (e: FileNotFoundException) {
181+
null
182+
}
183+
}
184+
185+
/**
186+
* Given an existing SSH config modify it to add or remove the config for
187+
* this deployment and return the modified config or null if it does not
188+
* need to be modified.
189+
*/
190+
private fun modifySSHConfig(contents: String?, workspaces: List<WorkspaceAgentModel>): String? {
170191
val host = getSafeHost(deploymentURL)
171192
val startBlock = "# --- START CODER JETBRAINS $host"
172193
val endBlock = "# --- END CODER JETBRAINS $host"
@@ -187,53 +208,68 @@ class CoderCLIManager @JvmOverloads constructor(private val deploymentURL: URL,
187208
SetEnv CODER_SSH_SESSION_TYPE=JetBrains
188209
""".trimIndent().replace("\n", System.lineSeparator())
189210
})
190-
Files.createDirectories(sshConfigPath.parent)
191-
try {
192-
val contents = sshConfigPath.toFile().readText()
193-
val start = "(\\s*)$startBlock".toRegex().find(contents)
194-
val end = "$endBlock(\\s*)".toRegex().find(contents)
195-
if (start == null && end == null && isRemoving) {
196-
logger.info("Leaving $sshConfigPath alone since there are no workspaces and no config to remove")
197-
} else if (start == null && end == null) {
198-
logger.info("Appending config to $sshConfigPath")
199-
val toAppend = if (contents.isEmpty()) blockContent else listOf(
200-
contents,
201-
blockContent
202-
).joinToString(System.lineSeparator())
203-
sshConfigPath.toFile().writeText(toAppend + System.lineSeparator())
204-
} else if (start == null) {
205-
throw SSHConfigFormatException("End block exists but no start block")
206-
} else if (end == null) {
207-
throw SSHConfigFormatException("Start block exists but no end block")
208-
} else if (start.range.first > end.range.first) {
209-
throw SSHConfigFormatException("Start block found after end block")
210-
} else if (isRemoving) {
211-
logger.info("Removing config from $sshConfigPath")
212-
sshConfigPath.toFile().writeText(
213-
listOf(
214-
contents.substring(0, start.range.first),
215-
// Need to keep the trailing newline(s) if we are not at
216-
// the front of the file otherwise the before and after
217-
// lines would get joined.
218-
if (start.range.first > 0) end.groupValues[1] else "",
219-
contents.substring(end.range.last + 1)
220-
).joinToString("")
221-
)
222-
} else {
223-
logger.info("Replacing config in $sshConfigPath")
224-
sshConfigPath.toFile().writeText(
225-
listOf(
226-
contents.substring(0, start.range.first),
227-
start.groupValues[1], // Leading newline(s).
228-
blockContent,
229-
end.groupValues[1], // Trailing newline(s).
230-
contents.substring(end.range.last + 1)
231-
).joinToString("")
232-
)
233-
}
234-
} catch (e: FileNotFoundException) {
235-
logger.info("Writing config to $sshConfigPath")
236-
sshConfigPath.toFile().writeText(blockContent + System.lineSeparator())
211+
212+
if (contents == null) {
213+
logger.info("No existing SSH config to modify")
214+
return blockContent + System.lineSeparator()
215+
}
216+
217+
val start = "(\\s*)$startBlock".toRegex().find(contents)
218+
val end = "$endBlock(\\s*)".toRegex().find(contents)
219+
220+
if (start == null && end == null && isRemoving) {
221+
logger.info("No workspaces and no existing config blocks to remove")
222+
return null
223+
}
224+
225+
if (start == null && end == null) {
226+
logger.info("Appending config block")
227+
val toAppend = if (contents.isEmpty()) blockContent else listOf(
228+
contents,
229+
blockContent
230+
).joinToString(System.lineSeparator())
231+
return toAppend + System.lineSeparator()
232+
}
233+
234+
if (start == null) {
235+
throw SSHConfigFormatException("End block exists but no start block")
236+
}
237+
if (end == null) {
238+
throw SSHConfigFormatException("Start block exists but no end block")
239+
}
240+
if (start.range.first > end.range.first) {
241+
throw SSHConfigFormatException("Start block found after end block")
242+
}
243+
244+
if (isRemoving) {
245+
logger.info("No workspaces; removing config block")
246+
return listOf(
247+
contents.substring(0, start.range.first),
248+
// Need to keep the trailing newline(s) if we are not at the
249+
// front of the file otherwise the before and after lines would
250+
// get joined.
251+
if (start.range.first > 0) end.groupValues[1] else "",
252+
contents.substring(end.range.last + 1)
253+
).joinToString("")
254+
}
255+
256+
logger.info("Replacing existing config block")
257+
return listOf(
258+
contents.substring(0, start.range.first),
259+
start.groupValues[1], // Leading newline(s).
260+
blockContent,
261+
end.groupValues[1], // Trailing newline(s).
262+
contents.substring(end.range.last + 1)
263+
).joinToString("")
264+
}
265+
266+
/**
267+
* Write the provided SSH config or do nothing if null.
268+
*/
269+
private fun writeSSHConfig(contents: String?) {
270+
if (contents != null) {
271+
Files.createDirectories(sshConfigPath.parent)
272+
sshConfigPath.toFile().writeText(contents)
237273
}
238274
}
239275

src/test/groovy/CoderCLIManagerTest.groovy

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,8 @@ class CoderCLIManagerTest extends spock.lang.Specification {
322322

323323
def "configures an SSH file"() {
324324
given:
325-
def ccm = new CoderCLIManager(new URL("https://test.coder.invalid"), tmpdir)
326325
def sshConfigPath = tmpdir.resolve(input + "_to_" + output + ".conf")
326+
def ccm = new CoderCLIManager(new URL("https://test.coder.invalid"), tmpdir, sshConfigPath)
327327
if (input != null) {
328328
Files.createDirectories(sshConfigPath.getParent())
329329
def originalConf = Path.of("src/test/fixtures/inputs").resolve(input + ".conf").toFile().text
@@ -338,13 +338,13 @@ class CoderCLIManagerTest extends spock.lang.Specification {
338338
.replace("/tmp/coder-gateway/test.coder.invalid/coder-linux-amd64", ccm.localBinaryPath.toString())
339339

340340
when:
341-
ccm.configSsh(workspaces.collect { randWorkspace(it) }, sshConfigPath)
341+
ccm.configSsh(workspaces.collect { randWorkspace(it) })
342342

343343
then:
344344
sshConfigPath.toFile().text == expectedConf
345345

346346
when:
347-
ccm.configSsh(List.of(), sshConfigPath)
347+
ccm.configSsh(List.of())
348348

349349
then:
350350
sshConfigPath.toFile().text == Path.of("src/test/fixtures/inputs").resolve(remove + ".conf").toFile().text
@@ -367,8 +367,8 @@ class CoderCLIManagerTest extends spock.lang.Specification {
367367

368368
def "fails if config is malformed"() {
369369
given:
370-
def ccm = new CoderCLIManager(new URL("https://test.coder.invalid"), tmpdir)
371370
def sshConfigPath = tmpdir.resolve("configured" + input + ".conf")
371+
def ccm = new CoderCLIManager(new URL("https://test.coder.invalid"), tmpdir, sshConfigPath)
372372
Files.createDirectories(sshConfigPath.getParent())
373373
Files.copy(
374374
Path.of("src/test/fixtures/inputs").resolve(input + ".conf"),
@@ -377,7 +377,7 @@ class CoderCLIManagerTest extends spock.lang.Specification {
377377
)
378378

379379
when:
380-
ccm.configSsh(List.of(), sshConfigPath)
380+
ccm.configSsh(List.of())
381381

382382
then:
383383
thrown(SSHConfigFormatException)

0 commit comments

Comments
 (0)