diff --git a/provider/agent.go b/provider/agent.go index 4d628927..0ca6305d 100644 --- a/provider/agent.go +++ b/provider/agent.go @@ -6,6 +6,7 @@ import ( "os" "reflect" "strings" + "time" "github.com/google/uuid" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" @@ -137,6 +138,13 @@ func agentResource() *schema.Resource { Optional: true, Description: "This option defines whether or not the user can (by default) login to the workspace before it is ready. Ready means that e.g. the startup_script is done and has exited. When enabled, users may see an incomplete workspace when logging in.", }, + "ssh_max_timeout": { + Type: schema.TypeString, + Default: "0", + ForceNew: true, + Optional: true, + Description: "The max timeout for SSH connections. By default there is no timeout.", + }, }, } } @@ -197,10 +205,19 @@ func updateInitScript(resourceData *schema.ResourceData, i interface{}) diag.Dia if err != nil { return diag.Errorf("parse access url: %s", err) } + sshMaxTimeout, valid := resourceData.Get("ssh_max_timeout").(string) + if !valid { + return diag.Errorf("ssh_max_timeout was unexpected type %q", reflect.TypeOf(resourceData.Get("ssh_max_timeout"))) + } + _, err = time.ParseDuration(sshMaxTimeout) + if err != nil { + return diag.Errorf("expected '%s' to be a duration string e.g. '2h45m10s'. Full error: %s", sshMaxTimeout, err.Error()) + } script := os.Getenv(fmt.Sprintf("CODER_AGENT_SCRIPT_%s_%s", operatingSystem, arch)) if script != "" { script = strings.ReplaceAll(script, "${ACCESS_URL}", accessURL.String()) script = strings.ReplaceAll(script, "${AUTH_TYPE}", auth) + script = strings.ReplaceAll(script, "${SSH_MAX_TIMEOUT}", sshMaxTimeout) } err = resourceData.Set("init_script", script) if err != nil { diff --git a/provider/agent_test.go b/provider/agent_test.go index 445ed75b..cd608ca1 100644 --- a/provider/agent_test.go +++ b/provider/agent_test.go @@ -37,6 +37,7 @@ func TestAgent(t *testing.T) { shutdown_script = "echo bye bye" shutdown_script_timeout = 120 login_before_ready = false + ssh_max_timeout = "1m" } `, Check: func(state *terraform.State) error {