Skip to content

Commit c1dd875

Browse files
committed
fix(cli): prevent sqlDB leaks in ConnectToPostgres
1 parent 3044091 commit c1dd875

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

cli/server.go

+19-15
Original file line numberDiff line numberDiff line change
@@ -1875,19 +1875,26 @@ func BuildLogger(inv *clibase.Invocation, cfg *codersdk.DeploymentValues) (slog.
18751875
}, nil
18761876
}
18771877

1878-
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string) (*sql.DB, error) {
1878+
func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, dbURL string) (sqlDB *sql.DB, err error) {
18791879
logger.Debug(ctx, "connecting to postgresql")
18801880

18811881
// Try to connect for 30 seconds.
18821882
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
18831883
defer cancel()
18841884

1885-
var (
1886-
sqlDB *sql.DB
1887-
err error
1888-
ok = false
1889-
tries int
1890-
)
1885+
var tries int
1886+
1887+
defer func() {
1888+
if err == nil {
1889+
return
1890+
}
1891+
if sqlDB != nil {
1892+
_ = sqlDB.Close()
1893+
sqlDB = nil
1894+
}
1895+
logger.Error(ctx, "connect to postgres; tries %d", slog.Error(err))
1896+
}()
1897+
18911898
for r := retry.New(time.Second, 3*time.Second); r.Wait(ctx); {
18921899
tries++
18931900

@@ -1900,18 +1907,16 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d
19001907
err = pingPostgres(ctx, sqlDB)
19011908
if err != nil {
19021909
logger.Warn(ctx, "ping postgres; retrying", slog.Error(err), slog.F("try", tries))
1910+
_ = sqlDB.Close()
1911+
sqlDB = nil
19031912
continue
19041913
}
19051914

19061915
break
19071916
}
1908-
// Make sure we close the DB in case it opened but the ping failed for some
1909-
// reason.
1910-
defer func() {
1911-
if !ok && sqlDB != nil {
1912-
_ = sqlDB.Close()
1913-
}
1914-
}()
1917+
if err == nil {
1918+
err = ctx.Err()
1919+
}
19151920
if err != nil {
19161921
return nil, xerrors.Errorf("connect to postgres; tries %d; last error: %w", tries, err)
19171922
}
@@ -1958,7 +1963,6 @@ func ConnectToPostgres(ctx context.Context, logger slog.Logger, driver string, d
19581963
// of connection churn.
19591964
sqlDB.SetMaxIdleConns(3)
19601965

1961-
ok = true
19621966
return sqlDB, nil
19631967
}
19641968

0 commit comments

Comments
 (0)