diff --git a/cryptobyte/asn1.go b/cryptobyte/asn1.go index 6fc2838a3f..2492f796af 100644 --- a/cryptobyte/asn1.go +++ b/cryptobyte/asn1.go @@ -733,13 +733,14 @@ func (s *String) ReadOptionalASN1OctetString(out *[]byte, outPresent *bool, tag return true } -// ReadOptionalASN1Boolean sets *out to the value of the next ASN.1 BOOLEAN or, -// if the next bytes are not an ASN.1 BOOLEAN, to the value of defaultValue. -// It reports whether the operation was successful. -func (s *String) ReadOptionalASN1Boolean(out *bool, defaultValue bool) bool { +// ReadOptionalASN1Boolean attempts to read an optional ASN.1 BOOLEAN +// explicitly tagged with tag into out and advances. If no element with a +// matching tag is present, it sets "out" to defaultValue instead. It reports +// whether the read was successful. +func (s *String) ReadOptionalASN1Boolean(out *bool, tag asn1.Tag, defaultValue bool) bool { var present bool var child String - if !s.ReadOptionalASN1(&child, &present, asn1.BOOLEAN) { + if !s.ReadOptionalASN1(&child, &present, tag) { return false } @@ -748,7 +749,7 @@ func (s *String) ReadOptionalASN1Boolean(out *bool, defaultValue bool) bool { return true } - return s.ReadASN1Boolean(out) + return child.ReadASN1Boolean(out) } func (s *String) readASN1(out *String, outTag *asn1.Tag, skipHeader bool) bool { diff --git a/cryptobyte/asn1_test.go b/cryptobyte/asn1_test.go index e3f53a932e..93760b06e9 100644 --- a/cryptobyte/asn1_test.go +++ b/cryptobyte/asn1_test.go @@ -115,6 +115,28 @@ func TestReadASN1OptionalInteger(t *testing.T) { } } +const defaultBool = false + +var optionalBoolTestData = []readASN1Test{ + {"empty", []byte{}, 0xa0, true, false}, + {"invalid", []byte{0xa1, 0x3, 0x1, 0x2, 0x7f}, 0xa1, false, defaultBool}, + {"missing", []byte{0xa1, 0x3, 0x1, 0x1, 0x7f}, 0xa0, true, defaultBool}, + {"present", []byte{0xa1, 0x3, 0x1, 0x1, 0xff}, 0xa1, true, true}, +} + +func TestReadASN1OptionalBoolean(t *testing.T) { + for _, test := range optionalBoolTestData { + t.Run(test.name, func(t *testing.T) { + in := String(test.in) + var out bool + ok := in.ReadOptionalASN1Boolean(&out, test.tag, defaultBool) + if ok != test.ok || ok && out != test.out.(bool) { + t.Errorf("in.ReadOptionalASN1Boolean() = %v, want %v; out = %v, want %v", ok, test.ok, out, test.out) + } + }) + } +} + func TestReadASN1IntegerSigned(t *testing.T) { testData64 := []struct { in []byte diff --git a/curve25519/internal/field/_asm/go.mod b/curve25519/internal/field/_asm/go.mod index bf6dbc73cd..f6902f4a64 100644 --- a/curve25519/internal/field/_asm/go.mod +++ b/curve25519/internal/field/_asm/go.mod @@ -9,7 +9,7 @@ require ( require ( golang.org/x/mod v0.8.0 // indirect - golang.org/x/sys v0.13.0 // indirect + golang.org/x/sys v0.14.0 // indirect golang.org/x/tools v0.6.0 // indirect ) diff --git a/curve25519/internal/field/_asm/go.sum b/curve25519/internal/field/_asm/go.sum index 3f57ad913c..96b7915d4f 100644 --- a/curve25519/internal/field/_asm/go.sum +++ b/curve25519/internal/field/_asm/go.sum @@ -26,21 +26,21 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= +golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/go.mod b/go.mod index 90eec5425c..d676a454ad 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.18 require ( golang.org/x/net v0.10.0 // tagx:ignore - golang.org/x/sys v0.14.0 - golang.org/x/term v0.14.0 + golang.org/x/sys v0.15.0 + golang.org/x/term v0.15.0 ) require golang.org/x/text v0.14.0 // indirect diff --git a/go.sum b/go.sum index 49ce5c4aee..f9352ee97e 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,8 @@ golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= -golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.14.0 h1:LGK9IlZ8T9jvdy6cTdfKUCltatMFOehAQo9SRC46UQ8= -golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= +golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= diff --git a/ssh/client_auth.go b/ssh/client_auth.go index 5c3bc25723..34bf089d0b 100644 --- a/ssh/client_auth.go +++ b/ssh/client_auth.go @@ -307,7 +307,10 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand } var methods []string var errSigAlgo error - for _, signer := range signers { + + origSignersLen := len(signers) + for idx := 0; idx < len(signers); idx++ { + signer := signers[idx] pub := signer.PublicKey() as, algo, err := pickSignatureAlgorithm(signer, extensions) if err != nil && errSigAlgo == nil { @@ -321,6 +324,21 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand if err != nil { return authFailure, nil, err } + // OpenSSH 7.2-7.7 advertises support for rsa-sha2-256 and rsa-sha2-512 + // in the "server-sig-algs" extension but doesn't support these + // algorithms for certificate authentication, so if the server rejects + // the key try to use the obtained algorithm as if "server-sig-algs" had + // not been implemented if supported from the algorithm signer. + if !ok && idx < origSignersLen && isRSACert(algo) && algo != CertAlgoRSAv01 { + if contains(as.Algorithms(), KeyAlgoRSA) { + // We retry using the compat algorithm after all signers have + // been tried normally. + signers = append(signers, &multiAlgorithmSigner{ + AlgorithmSigner: as, + supportedAlgorithms: []string{KeyAlgoRSA}, + }) + } + } if !ok { continue } diff --git a/ssh/common.go b/ssh/common.go index dd2ab0d69a..7e9c2cbc64 100644 --- a/ssh/common.go +++ b/ssh/common.go @@ -127,6 +127,14 @@ func isRSA(algo string) bool { return contains(algos, underlyingAlgo(algo)) } +func isRSACert(algo string) bool { + _, ok := certKeyAlgoNames[algo] + if !ok { + return false + } + return isRSA(algo) +} + // supportedPubKeyAuthAlgos specifies the supported client public key // authentication algorithms. Note that this doesn't include certificate types // since those use the underlying algorithm. This list is sent to the client if diff --git a/ssh/example_test.go b/ssh/example_test.go index 0a6b0767c9..3920832c1a 100644 --- a/ssh/example_test.go +++ b/ssh/example_test.go @@ -16,6 +16,7 @@ import ( "os" "path/filepath" "strings" + "sync" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/terminal" @@ -98,8 +99,15 @@ func ExampleNewServerConn() { } log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"]) + var wg sync.WaitGroup + defer wg.Wait() + // The incoming Request channel must be serviced. - go ssh.DiscardRequests(reqs) + wg.Add(1) + go func() { + ssh.DiscardRequests(reqs) + wg.Done() + }() // Service the incoming Channel channel. for newChannel := range chans { @@ -119,16 +127,22 @@ func ExampleNewServerConn() { // Sessions have out-of-band requests such as "shell", // "pty-req" and "env". Here we handle only the // "shell" request. + wg.Add(1) go func(in <-chan *ssh.Request) { for req := range in { req.Reply(req.Type == "shell", nil) } + wg.Done() }(requests) term := terminal.NewTerminal(channel, "> ") + wg.Add(1) go func() { - defer channel.Close() + defer func() { + channel.Close() + wg.Done() + }() for { line, err := term.ReadLine() if err != nil { diff --git a/ssh/mux_test.go b/ssh/mux_test.go index 1db3be54a0..eae637d5e2 100644 --- a/ssh/mux_test.go +++ b/ssh/mux_test.go @@ -10,7 +10,6 @@ import ( "io" "sync" "testing" - "time" ) func muxPair() (*mux, *mux) { @@ -112,7 +111,11 @@ func TestMuxReadWrite(t *testing.T) { magic := "hello world" magicExt := "hello stderr" + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) go func() { + defer wg.Done() _, err := s.Write([]byte(magic)) if err != nil { t.Errorf("Write: %v", err) @@ -152,13 +155,15 @@ func TestMuxChannelOverflow(t *testing.T) { defer writer.Close() defer mux.Close() - wDone := make(chan int, 1) + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) go func() { + defer wg.Done() if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { t.Errorf("could not fill window: %v", err) } writer.Write(make([]byte, 1)) - wDone <- 1 }() writer.remoteWin.waitWriterBlocked() @@ -175,7 +180,6 @@ func TestMuxChannelOverflow(t *testing.T) { if _, err := reader.SendRequest("hello", true, nil); err == nil { t.Errorf("SendRequest succeeded.") } - <-wDone } func TestMuxChannelCloseWriteUnblock(t *testing.T) { @@ -184,20 +188,21 @@ func TestMuxChannelCloseWriteUnblock(t *testing.T) { defer writer.Close() defer mux.Close() - wDone := make(chan int, 1) + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) go func() { + defer wg.Done() if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { t.Errorf("could not fill window: %v", err) } if _, err := writer.Write(make([]byte, 1)); err != io.EOF { t.Errorf("got %v, want EOF for unblock write", err) } - wDone <- 1 }() writer.remoteWin.waitWriterBlocked() reader.Close() - <-wDone } func TestMuxConnectionCloseWriteUnblock(t *testing.T) { @@ -206,20 +211,21 @@ func TestMuxConnectionCloseWriteUnblock(t *testing.T) { defer writer.Close() defer mux.Close() - wDone := make(chan int, 1) + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) go func() { + defer wg.Done() if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { t.Errorf("could not fill window: %v", err) } if _, err := writer.Write(make([]byte, 1)); err != io.EOF { t.Errorf("got %v, want EOF for unblock write", err) } - wDone <- 1 }() writer.remoteWin.waitWriterBlocked() mux.Close() - <-wDone } func TestMuxReject(t *testing.T) { @@ -227,7 +233,12 @@ func TestMuxReject(t *testing.T) { defer server.Close() defer client.Close() + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) go func() { + defer wg.Done() + ch, ok := <-server.incomingChannels if !ok { t.Error("cannot accept channel") @@ -267,6 +278,7 @@ func TestMuxChannelRequest(t *testing.T) { var received int var wg sync.WaitGroup + t.Cleanup(wg.Wait) wg.Add(1) go func() { for r := range server.incomingRequests { @@ -295,7 +307,6 @@ func TestMuxChannelRequest(t *testing.T) { } if ok { t.Errorf("SendRequest(no): %v", ok) - } client.Close() @@ -389,13 +400,8 @@ func TestMuxUnknownChannelRequests(t *testing.T) { // Wait for the server to send the keepalive message and receive back a // response. - select { - case err := <-kDone: - if err != nil { - t.Fatal(err) - } - case <-time.After(10 * time.Second): - t.Fatalf("server never received ack") + if err := <-kDone; err != nil { + t.Fatal(err) } // Confirm client hasn't closed. @@ -403,13 +409,9 @@ func TestMuxUnknownChannelRequests(t *testing.T) { t.Fatalf("failed to send keepalive: %v", err) } - select { - case err := <-kDone: - if err != nil { - t.Fatal(err) - } - case <-time.After(10 * time.Second): - t.Fatalf("server never shut down") + // Wait for the server to shut down. + if err := <-kDone; err != nil { + t.Fatal(err) } } @@ -525,11 +527,7 @@ func TestMuxClosedChannel(t *testing.T) { defer ch.Close() // Wait for the server to close the channel and send the keepalive. - select { - case <-kDone: - case <-time.After(10 * time.Second): - t.Fatalf("server never received ack") - } + <-kDone // Make sure the channel closed. if _, ok := <-ch.incomingRequests; ok { @@ -541,22 +539,29 @@ func TestMuxClosedChannel(t *testing.T) { t.Fatalf("failed to send keepalive: %v", err) } - select { - case <-kDone: - case <-time.After(10 * time.Second): - t.Fatalf("server never shut down") - } + // Wait for the server to shut down. + <-kDone } func TestMuxGlobalRequest(t *testing.T) { + var sawPeek bool + var wg sync.WaitGroup + defer func() { + wg.Wait() + if !sawPeek { + t.Errorf("never saw 'peek' request") + } + }() + clientMux, serverMux := muxPair() defer serverMux.Close() defer clientMux.Close() - var seen bool + wg.Add(1) go func() { + defer wg.Done() for r := range serverMux.incomingRequests { - seen = seen || r.Type == "peek" + sawPeek = sawPeek || r.Type == "peek" if r.WantReply { err := r.Reply(r.Type == "yes", append([]byte(r.Type), r.Payload...)) @@ -586,10 +591,6 @@ func TestMuxGlobalRequest(t *testing.T) { t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", ok, data, err) } - - if !seen { - t.Errorf("never saw 'peek' request") - } } func TestMuxGlobalRequestUnblock(t *testing.T) { @@ -739,7 +740,13 @@ func TestMuxMaxPacketSize(t *testing.T) { t.Errorf("could not send packet") } - go a.SendRequest("hello", false, nil) + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + a.SendRequest("hello", false, nil) + wg.Done() + }() _, ok := <-b.incomingRequests if ok { diff --git a/ssh/server.go b/ssh/server.go index 8f1505af94..7f0c236a9a 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -337,7 +337,7 @@ func checkSourceAddress(addr net.Addr, sourceAddrs string) error { return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) } -func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *connection, +func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, token []byte, s *connection, sessionID []byte, userAuthReq userAuthRequestMsg) (authErr error, perms *Permissions, err error) { gssAPIServer := gssapiConfig.Server defer gssAPIServer.DeleteSecContext() @@ -347,7 +347,7 @@ func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *c outToken []byte needContinue bool ) - outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(firstToken) + outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(token) if err != nil { return err, nil, nil } @@ -369,6 +369,7 @@ func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *c if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil { return nil, nil, err } + token = userAuthGSSAPITokenReq.Token } packet, err := s.transport.readPacket() if err != nil { diff --git a/ssh/session_test.go b/ssh/session_test.go index 521677f9b1..807a913e5a 100644 --- a/ssh/session_test.go +++ b/ssh/session_test.go @@ -13,6 +13,7 @@ import ( "io" "math/rand" "net" + "sync" "testing" "golang.org/x/crypto/ssh/terminal" @@ -27,8 +28,14 @@ func dial(handler serverType, t *testing.T) *Client { t.Fatalf("netPipe: %v", err) } + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) go func() { - defer c1.Close() + defer func() { + c1.Close() + wg.Done() + }() conf := ServerConfig{ NoClientAuth: true, } @@ -39,7 +46,11 @@ func dial(handler serverType, t *testing.T) *Client { t.Errorf("Unable to handshake: %v", err) return } - go DiscardRequests(reqs) + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() for newCh := range chans { if newCh.ChannelType() != "session" { @@ -52,8 +63,10 @@ func dial(handler serverType, t *testing.T) *Client { t.Errorf("Accept: %v", err) continue } + wg.Add(1) go func() { handler(ch, inReqs, t) + wg.Done() }() } if err := conn.Wait(); err != io.EOF { @@ -338,8 +351,13 @@ func TestServerWindow(t *testing.T) { t.Fatal(err) } defer session.Close() - result := make(chan []byte) + serverStdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + + result := make(chan []byte) go func() { defer close(result) echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) @@ -355,10 +373,6 @@ func TestServerWindow(t *testing.T) { result <- echoedBuf.Bytes() }() - serverStdin, err := session.StdinPipe() - if err != nil { - t.Fatalf("StdinPipe failed: %v", err) - } written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes) if err != nil { t.Errorf("failed to copy origBuf to serverStdin: %v", err) @@ -648,29 +662,44 @@ func TestSessionID(t *testing.T) { User: "user", } + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + srvErrCh := make(chan error, 1) + wg.Add(1) go func() { + defer wg.Done() conn, chans, reqs, err := NewServerConn(c1, serverConf) srvErrCh <- err if err != nil { return } serverID <- conn.SessionID() - go DiscardRequests(reqs) + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() for ch := range chans { ch.Reject(Prohibited, "") } }() cliErrCh := make(chan error, 1) + wg.Add(1) go func() { + defer wg.Done() conn, chans, reqs, err := NewClientConn(c2, "", clientConf) cliErrCh <- err if err != nil { return } clientID <- conn.SessionID() - go DiscardRequests(reqs) + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() for ch := range chans { ch.Reject(Prohibited, "") } @@ -738,6 +767,8 @@ func TestHostKeyAlgorithms(t *testing.T) { serverConf.AddHostKey(testSigners["rsa"]) serverConf.AddHostKey(testSigners["ecdsa"]) + var wg sync.WaitGroup + t.Cleanup(wg.Wait) connect := func(clientConf *ClientConfig, want string) { var alg string clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error { @@ -751,7 +782,11 @@ func TestHostKeyAlgorithms(t *testing.T) { defer c1.Close() defer c2.Close() - go NewServerConn(c1, serverConf) + wg.Add(1) + go func() { + NewServerConn(c1, serverConf) + wg.Done() + }() _, _, _, err = NewClientConn(c2, "", clientConf) if err != nil { t.Fatalf("NewClientConn: %v", err) @@ -785,7 +820,11 @@ func TestHostKeyAlgorithms(t *testing.T) { defer c1.Close() defer c2.Close() - go NewServerConn(c1, serverConf) + wg.Add(1) + go func() { + NewServerConn(c1, serverConf) + wg.Done() + }() clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"} _, _, _, err = NewClientConn(c2, "", clientConf) if err == nil { @@ -818,14 +857,22 @@ func TestServerClientAuthCallback(t *testing.T) { User: someUsername, } + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) go func() { + defer wg.Done() _, chans, reqs, err := NewServerConn(c1, serverConf) if err != nil { t.Errorf("server handshake: %v", err) userCh <- "error" return } - go DiscardRequests(reqs) + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() for ch := range chans { ch.Reject(Prohibited, "") } diff --git a/ssh/tcpip.go b/ssh/tcpip.go index 80d35f5ec1..ef5059a11d 100644 --- a/ssh/tcpip.go +++ b/ssh/tcpip.go @@ -5,6 +5,7 @@ package ssh import ( + "context" "errors" "fmt" "io" @@ -332,6 +333,40 @@ func (l *tcpListener) Addr() net.Addr { return l.laddr } +// DialContext initiates a connection to the addr from the remote host. +// +// The provided Context must be non-nil. If the context expires before the +// connection is complete, an error is returned. Once successfully connected, +// any expiration of the context will not affect the connection. +// +// See func Dial for additional information. +func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + type connErr struct { + conn net.Conn + err error + } + ch := make(chan connErr) + go func() { + conn, err := c.Dial(n, addr) + select { + case ch <- connErr{conn, err}: + case <-ctx.Done(): + if conn != nil { + conn.Close() + } + } + }() + select { + case res := <-ch: + return res.conn, res.err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + // Dial initiates a connection to the addr from the remote host. // The resulting connection has a zero LocalAddr() and RemoteAddr(). func (c *Client) Dial(n, addr string) (net.Conn, error) { diff --git a/ssh/tcpip_test.go b/ssh/tcpip_test.go index f1265cb496..4d85114727 100644 --- a/ssh/tcpip_test.go +++ b/ssh/tcpip_test.go @@ -5,7 +5,10 @@ package ssh import ( + "context" + "net" "testing" + "time" ) func TestAutoPortListenBroken(t *testing.T) { @@ -18,3 +21,33 @@ func TestAutoPortListenBroken(t *testing.T) { t.Errorf("version %q marked as broken", works) } } + +func TestClientImplementsDialContext(t *testing.T) { + type ContextDialer interface { + DialContext(context.Context, string, string) (net.Conn, error) + } + // Belt and suspenders assertion, since package net does not + // declare a ContextDialer type. + var _ ContextDialer = &net.Dialer{} + var _ ContextDialer = &Client{} +} + +func TestClientDialContextWithCancel(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.Canceled { + t.Errorf("DialContext: got nil error, expected %v", context.Canceled) + } +} + +func TestClientDialContextWithDeadline(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithDeadline(context.Background(), time.Now()) + defer cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.DeadlineExceeded { + t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded) + } +} diff --git a/ssh/test/dial_unix_test.go b/ssh/test/dial_unix_test.go index 0a5f5e395f..8ec8d50a50 100644 --- a/ssh/test/dial_unix_test.go +++ b/ssh/test/dial_unix_test.go @@ -9,6 +9,7 @@ package test // direct-tcpip and direct-streamlocal functional tests import ( + "context" "fmt" "io" "net" @@ -46,7 +47,11 @@ func testDial(t *testing.T, n, listenAddr string, x dialTester) { } }() - conn, err := sshConn.Dial(n, l.Addr().String()) + ctx, cancel := context.WithCancel(context.Background()) + conn, err := sshConn.DialContext(ctx, n, l.Addr().String()) + // Canceling the context after dial should have no effect + // on the opened connection. + cancel() if err != nil { t.Fatalf("Dial: %v", err) } diff --git a/ssh/test/sshcli_test.go b/ssh/test/sshcli_test.go index d3b85d77e2..ac2f7c10a9 100644 --- a/ssh/test/sshcli_test.go +++ b/ssh/test/sshcli_test.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "testing" "golang.org/x/crypto/internal/testenv" @@ -34,6 +35,9 @@ func sshClient(t *testing.T) string { } func TestSSHCLIAuth(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skipf("always fails on Windows, see #64403") + } sshCLI := sshClient(t) dir := t.TempDir() keyPrivPath := filepath.Join(dir, "rsa")