diff --git a/conn.go b/conn.go index e050d535..7d83f672 100644 --- a/conn.go +++ b/conn.go @@ -31,8 +31,10 @@ var ( ErrNotSupported = errors.New("pq: Unsupported command") ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") - ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less") - ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly") + ErrSSLKeyUnknownOwnership = errors.New("pq: Could not get owner information for private key, may not be properly protected") + ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key has world access. Permissions should be u=rw,g=r (0640) if owned by root, or u=rw (0600), or less") + + ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly") errUnexpectedReady = errors.New("unexpected ReadyForQuery") errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") @@ -322,7 +324,7 @@ func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { if err != nil { return nil, err } - c.dialer = d + c.Dialer(d) return c.open(context.Background()) } diff --git a/connector.go b/connector.go index d7d47261..1145e122 100644 --- a/connector.go +++ b/connector.go @@ -27,6 +27,11 @@ func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { return c.open(ctx) } +// Dialer allows change the dialer used to open connections. +func (c *Connector) Dialer(dialer Dialer) { + c.dialer = dialer +} + // Driver returns the underlying driver of this Connector. func (c *Connector) Driver() driver.Driver { return &Driver{} diff --git a/copy.go b/copy.go index c072bc3b..2f5c1ec8 100644 --- a/copy.go +++ b/copy.go @@ -1,6 +1,7 @@ package pq import ( + "context" "database/sql/driver" "encoding/binary" "errors" @@ -273,6 +274,43 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { return driver.RowsAffected(0), nil } +// CopyData inserts a raw string into the COPY stream. The insert is +// asynchronous and CopyData can return errors from previous CopyData calls to +// the same COPY stmt. +// +// You need to call Exec(nil) to sync the COPY stream and to get any +// errors from pending data, since Stmt.Close() doesn't return errors +// to the user. +func (ci *copyin) CopyData(ctx context.Context, line string) (r driver.Result, err error) { + if ci.closed { + return nil, errCopyInClosed + } + + if finish := ci.cn.watchCancel(ctx); finish != nil { + defer finish() + } + + if err := ci.getBad(); err != nil { + return nil, err + } + defer ci.cn.errRecover(&err) + + if err := ci.err(); err != nil { + return nil, err + } + + ci.buffer = append(ci.buffer, []byte(line)...) + ci.buffer = append(ci.buffer, '\n') + + if len(ci.buffer) > ciBufferFlushSize { + ci.flush(ci.buffer) + // reset buffer, keep bytes for message identifier and length + ci.buffer = ci.buffer[:5] + } + + return driver.RowsAffected(0), nil +} + func (ci *copyin) Close() (err error) { if ci.closed { // Don't do anything, we're already closed return nil diff --git a/encode.go b/encode.go index 210b1ec3..bffe6096 100644 --- a/encode.go +++ b/encode.go @@ -422,7 +422,7 @@ func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, erro if remainderIdx < len(str) && str[remainderIdx] == '.' { fracStart := remainderIdx + 1 - fracOff := strings.IndexAny(str[fracStart:], "-+ ") + fracOff := strings.IndexAny(str[fracStart:], "-+Z ") if fracOff < 0 { fracOff = len(str) - fracStart } @@ -432,7 +432,7 @@ func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, erro remainderIdx += fracOff + 1 } if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') { - // time zone separator is always '-' or '+' (UTC is +00) + // time zone separator is always '-' or '+' or 'Z' (UTC is +00) var tzSign int switch c := str[tzStart]; c { case '-': @@ -454,7 +454,11 @@ func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, erro remainderIdx += 3 } tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) + } else if tzStart < len(str) && str[tzStart] == 'Z' { + // time zone Z separator indicates UTC is +00 + remainderIdx += 1 } + var isoYear int if isBC { diff --git a/encode_test.go b/encode_test.go index bffa8386..69f9ebb1 100644 --- a/encode_test.go +++ b/encode_test.go @@ -828,6 +828,43 @@ func TestAppendEscapedTextExistingBuffer(t *testing.T) { } } +var formatAndParseTimestamp = []struct { + time time.Time + expected string +}{ + {time.Time{}, "0001-01-01 00:00:00Z"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03 04:05:06.123456789Z"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03 04:05:06.123456789+02:00"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03 04:05:06.123456789-06:00"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03 04:05:06-07:30:09"}, + + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z"}, + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00"}, + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00"}, + + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z BC"}, + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00 BC"}, + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00 BC"}, + + {time.Date(1, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09"}, + {time.Date(0, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09 BC"}, +} + +func TestFormatAndParseTimestamp(t *testing.T) { + for _, val := range formatAndParseTimestamp { + formattedTime := FormatTimestamp(val.time) + parsedTime, err := ParseTimestamp(nil, string(formattedTime)) + + if err != nil { + t.Errorf("invalid parsing, err: %v", err.Error()) + } + + if val.time.UTC() != parsedTime.UTC() { + t.Errorf("invalid parsing from formatted timestamp, got %v; expected %v", parsedTime.String(), val.time.String()) + } + } +} + func BenchmarkAppendEscapedText(b *testing.B) { longString := "" for i := 0; i < 100; i++ { diff --git a/error.go b/error.go index 5cfe9c6e..21b3d933 100644 --- a/error.go +++ b/error.go @@ -402,6 +402,11 @@ func (err *Error) Fatal() bool { return err.Severity == Efatal } +// SQLState returns the SQLState of the error. +func (err *Error) SQLState() string { + return string(err.Code) +} + // Get implements the legacy PGError interface. New code should use the fields // of the Error struct directly. func (err *Error) Get(k byte) (v string) { diff --git a/go18_test.go b/go18_test.go index bcc02006..6166db27 100644 --- a/go18_test.go +++ b/go18_test.go @@ -334,3 +334,19 @@ func TestTxOptions(t *testing.T) { t.Errorf("Expected error to mention isolation level, got %q", err) } } + +func TestErrorSQLState(t *testing.T) { + r := readBuf([]byte{67, 52, 48, 48, 48, 49, 0, 0}) // 40001 + err := parseError(&r) + var sqlErr errWithSQLState + if !errors.As(err, &sqlErr) { + t.Fatal("SQLState interface not satisfied") + } + if state := err.SQLState(); state != "40001" { + t.Fatalf("unexpected SQL state %v", state) + } +} + +type errWithSQLState interface { + SQLState() string +} diff --git a/ssl_permissions.go b/ssl_permissions.go index 014af6a1..d587f102 100644 --- a/ssl_permissions.go +++ b/ssl_permissions.go @@ -3,7 +3,28 @@ package pq -import "os" +import ( + "errors" + "os" + "syscall" +) + +const ( + rootUserID = uint32(0) + + // The maximum permissions that a private key file owned by a regular user + // is allowed to have. This translates to u=rw. + maxUserOwnedKeyPermissions os.FileMode = 0600 + + // The maximum permissions that a private key file owned by root is allowed + // to have. This translates to u=rw,g=r. + maxRootOwnedKeyPermissions os.FileMode = 0640 +) + +var ( + errSSLKeyHasUnacceptableUserPermissions = errors.New("permissions for files not owned by root should be u=rw (0600) or less") + errSSLKeyHasUnacceptableRootPermissions = errors.New("permissions for root owned files should be u=rw,g=r (0640) or less") +) // sslKeyPermissions checks the permissions on user-supplied ssl key files. // The key file should have very little access. @@ -14,8 +35,59 @@ func sslKeyPermissions(sslkey string) error { if err != nil { return err } - if info.Mode().Perm()&0077 != 0 { - return ErrSSLKeyHasWorldPermissions + + err = hasCorrectPermissions(info) + + // return ErrSSLKeyHasWorldPermissions for backwards compatability with + // existing code. + if err == errSSLKeyHasUnacceptableUserPermissions || err == errSSLKeyHasUnacceptableRootPermissions { + err = ErrSSLKeyHasWorldPermissions } - return nil + return err +} + +// hasCorrectPermissions checks the file info (and the unix-specific stat_t +// output) to verify that the permissions on the file are correct. +// +// If the file is owned by the same user the process is running as, +// the file should only have 0600 (u=rw). If the file is owned by root, +// and the group matches the group that the process is running in, the +// permissions cannot be more than 0640 (u=rw,g=r). The file should +// never have world permissions. +// +// Returns an error when the permission check fails. +func hasCorrectPermissions(info os.FileInfo) error { + // if file's permission matches 0600, allow access. + userPermissionMask := (os.FileMode(0777) ^ maxUserOwnedKeyPermissions) + + // regardless of if we're running as root or not, 0600 is acceptable, + // so we return if we match the regular user permission mask. + if info.Mode().Perm()&userPermissionMask == 0 { + return nil + } + + // We need to pull the Unix file information to get the file's owner. + // If we can't access it, there's some sort of operating system level error + // and we should fail rather than attempting to use faulty information. + sysInfo := info.Sys() + if sysInfo == nil { + return ErrSSLKeyUnknownOwnership + } + + unixStat, ok := sysInfo.(*syscall.Stat_t) + if !ok { + return ErrSSLKeyUnknownOwnership + } + + // if the file is owned by root, we allow 0640 (u=rw,g=r) to match what + // Postgres does. + if unixStat.Uid == rootUserID { + rootPermissionMask := (os.FileMode(0777) ^ maxRootOwnedKeyPermissions) + if info.Mode().Perm()&rootPermissionMask != 0 { + return errSSLKeyHasUnacceptableRootPermissions + } + return nil + } + + return errSSLKeyHasUnacceptableUserPermissions } diff --git a/ssl_permissions_test.go b/ssl_permissions_test.go new file mode 100644 index 00000000..b0bdca10 --- /dev/null +++ b/ssl_permissions_test.go @@ -0,0 +1,109 @@ +//go:build !windows +// +build !windows + +package pq + +import ( + "os" + "syscall" + "testing" + "time" +) + +type stat_t_wrapper struct { + stat syscall.Stat_t +} + +func (stat_t *stat_t_wrapper) Name() string { + return "pem.key" +} + +func (stat_t *stat_t_wrapper) Size() int64 { + return int64(100) +} + +func (stat_t *stat_t_wrapper) Mode() os.FileMode { + return os.FileMode(stat_t.stat.Mode) +} + +func (stat_t *stat_t_wrapper) ModTime() time.Time { + return time.Now() +} + +func (stat_t *stat_t_wrapper) IsDir() bool { + return true +} + +func (stat_t *stat_t_wrapper) Sys() interface{} { + return &stat_t.stat +} + +func TestHasCorrectRootGroupPermissions(t *testing.T) { + currentUID := uint32(os.Getuid()) + currentGID := uint32(os.Getgid()) + + testData := []struct { + expectedError error + stat syscall.Stat_t + }{ + { + expectedError: nil, + stat: syscall.Stat_t{ + Mode: 0600, + Uid: currentUID, + Gid: currentGID, + }, + }, + { + expectedError: nil, + stat: syscall.Stat_t{ + Mode: 0640, + Uid: 0, + Gid: currentGID, + }, + }, + { + expectedError: errSSLKeyHasUnacceptableUserPermissions, + stat: syscall.Stat_t{ + Mode: 0666, + Uid: currentUID, + Gid: currentGID, + }, + }, + { + expectedError: errSSLKeyHasUnacceptableRootPermissions, + stat: syscall.Stat_t{ + Mode: 0666, + Uid: 0, + Gid: currentGID, + }, + }, + } + + for _, test := range testData { + wrapper := &stat_t_wrapper{ + stat: test.stat, + } + + if test.expectedError != hasCorrectPermissions(wrapper) { + if test.expectedError == nil { + t.Errorf( + "file owned by %d:%d with %s should not have failed check with error \"%s\"", + test.stat.Uid, + test.stat.Gid, + wrapper.Mode(), + hasCorrectPermissions(wrapper), + ) + continue + } + t.Errorf( + "file owned by %d:%d with %s, expected \"%s\", got \"%s\"", + test.stat.Uid, + test.stat.Gid, + wrapper.Mode(), + test.expectedError, + hasCorrectPermissions(wrapper), + ) + } + } +}