From 346a5b8de15ecd0b4107f5a9776dfef00aaf0c21 Mon Sep 17 00:00:00 2001 From: Karl Isenberg Date: Sun, 27 Apr 2025 20:23:09 -0700 Subject: [PATCH 1/2] test: validate that watchers close without error - Add channel test helpers to k8s.io/client-go/util/testing - Use the new channel test helpers in the client and server watch tests to validate that the result channel closes without error when the client stops the watcher. - Use the new channel test helpers in the client watcher decoder tests to validate that encoded watch events can be decoded and that the decoder errors with EOF when stopped asynchronously. These new tests uncovered existing errors (added TODOs): 1. The watch client doesn't close the response body when it encounters a NegotiateError. 2. The watch server has a race condition that sometimes sends a watch error on the result channel after the client watcher has been stopped. --- .../pkg/endpoints/handlers/watch_test.go | 189 ++++++++++++++---- .../src/k8s.io/client-go/rest/request_test.go | 113 +++++++---- .../client-go/rest/watch/decoder_test.go | 160 ++++++++------- .../k8s.io/client-go/util/testing/channels.go | 154 ++++++++++++++ 4 files changed, 467 insertions(+), 149 deletions(-) create mode 100644 staging/src/k8s.io/client-go/util/testing/channels.go diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch_test.go b/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch_test.go index 77410bfa5e30f..f9089589ca9e8 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch_test.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch_test.go @@ -17,8 +17,8 @@ limitations under the License. 
package handlers import ( - "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -27,8 +27,9 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "k8s.io/apimachinery/pkg/api/errors" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" @@ -38,6 +39,7 @@ import ( apitesting "k8s.io/apiserver/pkg/endpoints/testing" "k8s.io/client-go/dynamic" restclient "k8s.io/client-go/rest" + utiltesting "k8s.io/client-go/util/testing" ) // Fake API versions, similar to api/latest.go @@ -61,6 +63,7 @@ func init() { } func TestWatchHTTPErrors(t *testing.T) { + ctx := t.Context() watcher := watch.NewFake() timeoutCh := make(chan time.Time) doneCh := make(chan struct{}) @@ -88,19 +91,24 @@ func TestWatchHTTPErrors(t *testing.T) { defer s.Close() // Setup a client - dest, _ := url.Parse(s.URL) + dest, err := url.Parse(s.URL) + require.NoError(t, err) dest.Path = "/" + namedGroupPrefix + "/" + testGroupV2.Group + "/" + testGroupV2.Version + "/simple" dest.RawQuery = "watch=true" - req, _ := http.NewRequest(http.MethodGet, dest.String(), nil) + // Start watch request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, dest.String(), nil) + require.NoError(t, err) client := http.Client{} resp, err := client.Do(req) require.NoError(t, err) - errStatus := errors.NewInternalError(fmt.Errorf("we got an error")).Status() + defer assertClosed(t, resp.Body) + + // Send error to server from storage + errStatus := apierrors.NewInternalError(fmt.Errorf("we got an error")).Status() watcher.Error(&errStatus) - watcher.Stop() - // Make sure we can actually watch an endpoint + // Decode error from the response decoder := json.NewDecoder(resp.Body) var got watchJSON err = decoder.Decode(&got) @@ -121,12 +129,27 @@ func TestWatchHTTPErrors(t *testing.T) { Details: errStatus.Details, } require.Equal(t, expectedStatus, 
status) + + // Close the response body to signal the server to stop serving. + require.NoError(t, resp.Body.Close()) + + // Wait for the server to call the CancelFunc returned by + // TimeoutFactory.TimeoutCh, closing the done channel. + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, doneCh) + require.NoError(t, err) + + // Wait for the server to call watcher.Stop, closing the result channel. + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, watcher.ResultChan()) + require.NoError(t, err) + + // Confirm watcher.Stop was called by the server. + require.Truef(t, watcher.IsStopped(), + "Leaked watcher goroutine after request done") } func TestWatchHTTPErrorsBeforeServe(t *testing.T) { + ctx := t.Context() watcher := watch.NewFake() - timeoutCh := make(chan time.Time) - doneCh := make(chan struct{}) info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), runtime.ContentTypeJSON) if !ok || info.StreamSerializer == nil { @@ -147,24 +170,29 @@ func TestWatchHTTPErrorsBeforeServe(t *testing.T) { Encoder: testCodecV2, EmbeddedEncoder: testCodecV2, - TimeoutFactory: &fakeTimeoutFactory{timeoutCh, doneCh}, + // TimeoutFactory should not be needed, because the server should error + // before calling TimeoutFactory.TimeoutCh. 
} - statusErr := errors.NewInternalError(fmt.Errorf("we got an error")) + statusErr := apierrors.NewInternalError(fmt.Errorf("we got an error")) errStatus := statusErr.Status() s := httptest.NewServer(serveWatch(watcher, watchServer, statusErr)) defer s.Close() // Setup a client - dest, _ := url.Parse(s.URL) + dest, err := url.Parse(s.URL) + require.NoError(t, err) dest.Path = "/" + namedGroupPrefix + "/" + testGroupV2.Group + "/" + testGroupV2.Version + "/simple" dest.RawQuery = "watch=true" - req, _ := http.NewRequest(http.MethodGet, dest.String(), nil) + // Start watch request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, dest.String(), nil) + require.NoError(t, err) client := http.Client{} resp, err := client.Do(req) require.NoError(t, err) + defer assertClosed(t, resp.Body) // We had already got an error before watch serve started decoder := json.NewDecoder(resp.Body) @@ -184,15 +212,25 @@ func TestWatchHTTPErrorsBeforeServe(t *testing.T) { } require.Equal(t, expectedStatus, status) - // check for leaks + // Close the response body to signal the server to stop serving. + // This isn't strictly necessary, since the test serveWatch doesn't block, + // but it would be if this were the real watch server. + require.NoError(t, resp.Body.Close()) + + // Wait for the server to call watcher.Stop, closing the result channel. + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, watcher.ResultChan()) + require.NoError(t, err) + + // Confirm watcher.Stop was called by the server. 
require.Truef(t, watcher.IsStopped(), - "Leaked watcher goruntine after request done") + "Leaked watcher goroutine after request done") } func TestWatchHTTPDynamicClientErrors(t *testing.T) { + ctx := t.Context() watcher := watch.NewFake() timeoutCh := make(chan time.Time) - done := make(chan struct{}) + doneCh := make(chan struct{}) info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), runtime.ContentTypeJSON) if !ok || info.StreamSerializer == nil { @@ -210,7 +248,7 @@ func TestWatchHTTPDynamicClientErrors(t *testing.T) { Encoder: testCodecV2, EmbeddedEncoder: testCodecV2, - TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done}, + TimeoutFactory: &fakeTimeoutFactory{timeoutCh: timeoutCh, done: doneCh}, } s := httptest.NewServer(serveWatch(watcher, watchServer, nil)) @@ -222,14 +260,32 @@ func TestWatchHTTPDynamicClientErrors(t *testing.T) { APIPath: "/" + namedGroupPrefix, }).Resource(testGroupV2.WithResource("simple")) - _, err := client.Watch(context.TODO(), metav1.ListOptions{}) + _, err := client.Watch(ctx, metav1.ListOptions{}) require.Equal(t, runtime.NegotiateError{Stream: true, ContentType: "testcase/json"}, err) + + // The client should automatically close the connection on error. + //TODO: Fix watch client to close the response body and stop the storage watcher after a NegotiateError + require.False(t, watcher.IsStopped()) + + // Wait for the server to call the CancelFunc returned by + // TimeoutFactory.TimeoutCh, closing the done channel. + // err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, doneCh) + // require.NoError(t, err) + + // Wait for the server to call watcher.Stop, closing the result channel. + // err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, watcher.ResultChan()) + // require.NoError(t, err) + + // Confirm watcher.Stop was called by the server. 
+ // require.Truef(t, watcher.IsStopped(), + // "Leaked watcher goroutine after request done") } func TestWatchHTTPTimeout(t *testing.T) { + ctx := t.Context() watcher := watch.NewFake() timeoutCh := make(chan time.Time) - done := make(chan struct{}) + doneCh := make(chan struct{}) info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), runtime.ContentTypeJSON) if !ok || info.StreamSerializer == nil { @@ -247,21 +303,27 @@ func TestWatchHTTPTimeout(t *testing.T) { Encoder: testCodecV2, EmbeddedEncoder: testCodecV2, - TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done}, + TimeoutFactory: &fakeTimeoutFactory{timeoutCh: timeoutCh, done: doneCh}, } s := httptest.NewServer(serveWatch(watcher, watchServer, nil)) defer s.Close() // Setup a client - dest, _ := url.Parse(s.URL) + dest, err := url.Parse(s.URL) + require.NoError(t, err) dest.Path = "/" + namedGroupPrefix + "/" + testGroupV2.Group + "/" + testGroupV2.Version + "/simple" dest.RawQuery = "watch=true" - req, _ := http.NewRequest(http.MethodGet, dest.String(), nil) + // Start watch request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, dest.String(), nil) + require.NoError(t, err) client := http.Client{} resp, err := client.Do(req) require.NoError(t, err) + defer assertClosed(t, resp.Body) + + // Send object added event to server from storage watcher.Add(&apitesting.Simple{TypeMeta: metav1.TypeMeta{APIVersion: testGroupV2.String()}}) // Make sure we can actually watch an endpoint @@ -270,29 +332,28 @@ func TestWatchHTTPTimeout(t *testing.T) { err = decoder.Decode(&got) require.NoError(t, err) - // Timeout and check for leaks + // Trigger server-side timeout. 
close(timeoutCh) - select { - case <-done: - eventCh := watcher.ResultChan() - select { - case _, opened := <-eventCh: - if opened { - t.Errorf("Watcher received unexpected event") - } - if !watcher.IsStopped() { - t.Errorf("Watcher is not stopped") - } - case <-time.After(wait.ForeverTestTimeout): - t.Errorf("Leaked watch on timeout") - } - case <-time.After(wait.ForeverTestTimeout): - t.Errorf("Failed to stop watcher after %s of timeout signal", wait.ForeverTestTimeout.String()) - } - // Make sure we can't receive any more events through the timeout watch + // Wait for the server to call the CancelFunc returned by + // TimeoutFactory.TimeoutCh, closing the done channel. + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, doneCh) + require.NoError(t, err) + + // Wait for the server to call watcher.Stop, closing the result channel. + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, watcher.ResultChan()) + require.NoError(t, err) + + // Confirm watcher.Stop was called by the server. + require.Truef(t, watcher.IsStopped(), + "Leaked watcher goroutine after request done") + + // Make sure we can't receive any more events after the watch timeout err = decoder.Decode(&got) require.Equal(t, io.EOF, err) + + // Close the response body to clean up watch client resources. + require.NoError(t, resp.Body.Close()) } // watchJSON defines the expected JSON wire equivalent of watch.Event. @@ -330,3 +391,49 @@ func serveWatch(watcher watch.Interface, watchServer *WatchServer, preServeErr e watchServer.HandleHTTP(w, req) } } + +// From https://github.com/golang/go/blob/go1.20/src/net/http/transport.go#L2779 +var errReadOnClosedResBody = errors.New("http: read on closed response body") + +// assertClosed fails the test if the ReadCloser is NOT already closed. +// If not already closed, the ReadCloser will be drained and closed. +// Defer when your test is expected to close the ReadCloser before ending. 
+func assertClosed(t *testing.T, rc io.ReadCloser) { + assert.Equal(t, errReadOnClosedResBody, drainAndClose(rc)) +} + +// drainAndClose reads from the ReadCloser until EOF, discarding the content, +// and closes the ReadCloser when finished or on error. +// Returns an error when either Read or Close error. If both error, the errors +// are joined and returned. +// +// In a defer from a test, use with t.Error or assert.NoError, NOT t.Fatal or +// require.NoError. +func drainAndClose(rc io.ReadCloser) error { + errCh := make(chan error) + go func() { + // Close after done reading + defer func() { + defer close(errCh) + if err := rc.Close(); err != nil { + errCh <- err + } + }() + // Read until EOF and discard + if _, err := io.Copy(io.Discard, rc); err != nil { + errCh <- err + } + }() + + // Wait until Read and Close are both done. + // Combine errors, if multiple. + var multiErr error + for err := range errCh { + if multiErr != nil { + multiErr = errors.Join(multiErr, err) + } else { + multiErr = err + } + } + return multiErr +} diff --git a/staging/src/k8s.io/client-go/rest/request_test.go b/staging/src/k8s.io/client-go/rest/request_test.go index fd64dcb028187..6e710062f11ca 100644 --- a/staging/src/k8s.io/client-go/rest/request_test.go +++ b/staging/src/k8s.io/client-go/rest/request_test.go @@ -39,11 +39,9 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/google/go-cmp/cmp" - v1 "k8s.io/api/core/v1" apiequality "k8s.io/apimachinery/pkg/api/equality" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -54,6 +52,7 @@ import ( "k8s.io/apimachinery/pkg/runtime/serializer/streaming" "k8s.io/apimachinery/pkg/util/intstr" utilnet "k8s.io/apimachinery/pkg/util/net" + "k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/kubernetes/scheme" restclientwatch "k8s.io/client-go/rest/watch" @@ -2069,6 +2068,7 @@ func TestWatch(t *testing.T) { for 
_, test := range tests { t.Run(test.name, func(t *testing.T) { + ctx := t.Context() var table = []struct { t watch.EventType obj runtime.Object @@ -2110,28 +2110,34 @@ func TestWatch(t *testing.T) { defer testServer.Close() s := testRESTClient(t, testServer) - watching, err := s.Get().Prefix("path/to/watch/thing"). - MaxRetries(test.maxRetries).Watch(context.Background()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + watcher, err := s.Get().Prefix("path/to/watch/thing"). + MaxRetries(test.maxRetries).Watch(ctx) + require.NoError(t, err) + defer watcher.Stop() + // Read the events from the result channel + resultCh := watcher.ResultChan() for _, item := range table { - got, ok := <-watching.ResultChan() + got, ok := <-resultCh if !ok { - t.Fatalf("Unexpected early close") - } - if e, a := item.t, got.Type; e != a { - t.Errorf("Expected %v, got %v", e, a) - } - if e, a := item.obj, got.Object; !apiequality.Semantic.DeepDerivative(e, a) { - t.Errorf("Expected %v, got %v", e, a) + t.Fatal("Unexpected early close") } + assert.Equal(t, item.t, got.Type) + assert.Equal(t, item.obj, got.Object) } - _, ok := <-watching.ResultChan() - if ok { - t.Fatal("Unexpected non-close") + // Stop watcher when done reading watch events + watcher.Stop() + + // Wait for the result channel to close + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, resultCh) + // TODO(karlkfi): Fix race condition that causes the watch server to sometimes send an error when stopped + if err != nil { + // Extract & compare the Event so we can see diff when it fails + event, unwrapErr := unwrapUnexectedWatchEvent(err) + require.NoError(t, unwrapErr) + expectedEvent := newClientWatchDecodingEvent(errReadOnClosedResBody) + require.Equal(t, expectedEvent, event) } }) } @@ -2173,27 +2179,32 @@ func TestWatchNonDefaultContentType(t *testing.T) { contentConfig := defaultContentConfig() contentConfig.ContentType = "application/vnd.kubernetes.protobuf" s := 
testRESTClientWithConfig(t, testServer, contentConfig) - watching, err := s.Get().Prefix("path/to/watch/thing").Watch(context.Background()) - if err != nil { - t.Fatalf("Unexpected error") - } + watcher, err := s.Get().Prefix("path/to/watch/thing").Watch(t.Context()) + require.NoError(t, err) + defer watcher.Stop() + resultCh := watcher.ResultChan() for _, item := range table { - got, ok := <-watching.ResultChan() + got, ok := <-resultCh if !ok { t.Fatalf("Unexpected early close") } - if e, a := item.t, got.Type; e != a { - t.Errorf("Expected %v, got %v", e, a) - } - if e, a := item.obj, got.Object; !apiequality.Semantic.DeepDerivative(e, a) { - t.Errorf("Expected %v, got %v", e, a) - } + assert.Equal(t, item.t, got.Type) + assert.Equal(t, item.obj, got.Object) } - _, ok := <-watching.ResultChan() - if ok { - t.Fatal("Unexpected non-close") + // Stop watcher when done reading watch events + watcher.Stop() + + // Wait for the result channel to close + err = utiltesting.WaitForChannelToCloseWithTimeout(t.Context(), wait.ForeverTestTimeout, resultCh) + // TODO(karlkfi): Fix race condition that causes the watch server to sometimes send an error when stopped + if err != nil { + // Extract & compare the Event so we can see diff when it fails + event, unwrapErr := unwrapUnexectedWatchEvent(err) + require.NoError(t, unwrapErr) + expectedEvent := newClientWatchDecodingEvent(errReadOnClosedResBody) + require.Equal(t, expectedEvent, event) } } @@ -2230,10 +2241,8 @@ func TestWatchUnknownContentType(t *testing.T) { defer testServer.Close() s := testRESTClient(t, testServer) - _, err := s.Get().Prefix("path/to/watch/thing").Watch(context.Background()) - if err == nil { - t.Fatalf("Expected to fail due to lack of known stream serialization for content type") - } + _, err := s.Get().Prefix("path/to/watch/thing").Watch(t.Context()) + require.Equal(t, runtime.NegotiateError{ContentType: "foobar", Stream: true}, err) } func TestStream(t *testing.T) { @@ -4253,3 +4262,33 @@ func 
TestRequestWarningHandler(t *testing.T) { assert.Nil(t, request.warningHandler) }) } + +// From https://github.com/golang/go/blob/go1.20/src/net/http/transport.go#L2779 +var errReadOnClosedResBody = errors.New("http: read on closed response body") + +// newClientWatchDecodingEvent simulates a InternalServerError of type +// "ClientWatchDecoding", wrapped as a watch.Event. These errors come from the +// watch client StreamWatcher. +func newClientWatchDecodingEvent(decodeErr error) watch.Event { + // From Request.newStreamWatcher + clientErrorReporter := apierrors.NewClientErrorReporter(http.StatusInternalServerError, "GET", "ClientWatchDecoding") + // From StreamWatcher.receive + // TODO(karlkfi): Use `%w` to format errors (errorlint) in StreamWatcher.receive + //nolint:errorlint // StreamWatcher.receive uses `%v` + serverErr := fmt.Errorf("unable to decode an event from the watch stream: %v", decodeErr) + return watch.Event{ + Type: watch.Error, + Object: clientErrorReporter.AsObject(serverErr), + } +} + +// unwrapUnexectedWatchEvent unwraps a utiltesting.UnexpectedEventError to +// extract the watch.Error. 
+func unwrapUnexectedWatchEvent(err error) (watch.Event, error) { + var expectedErr *utiltesting.UnexpectedEventError[watch.Event] + if !errors.As(err, &expectedErr) { + return watch.Event{}, fmt.Errorf("Error expected to be of type %v, but was %v", + reflect.TypeOf(expectedErr), reflect.TypeOf(err)) + } + return expectedErr.Event, nil +} diff --git a/staging/src/k8s.io/client-go/rest/watch/decoder_test.go b/staging/src/k8s.io/client-go/rest/watch/decoder_test.go index a6a95b884222b..6bc18d0abb871 100644 --- a/staging/src/k8s.io/client-go/rest/watch/decoder_test.go +++ b/staging/src/k8s.io/client-go/rest/watch/decoder_test.go @@ -21,10 +21,10 @@ import ( "fmt" "io" "testing" - "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" v1 "k8s.io/api/core/v1" - apiequality "k8s.io/apimachinery/pkg/api/equality" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" runtimejson "k8s.io/apimachinery/pkg/runtime/serializer/json" @@ -32,6 +32,7 @@ import ( "k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/kubernetes/scheme" + utiltesting "k8s.io/client-go/util/testing" ) // getDecoder mimics how k8s.io/client-go/rest.createSerializers creates a decoder @@ -45,86 +46,103 @@ func TestDecoder(t *testing.T) { table := []watch.EventType{watch.Added, watch.Deleted, watch.Modified, watch.Error, watch.Bookmark} for _, eventType := range table { - out, in := io.Pipe() - - decoder := NewDecoder(streaming.NewDecoder(out, getDecoder()), getDecoder()) - expect := &v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}} - encoder := json.NewEncoder(in) - eType := eventType - errc := make(chan error) - - go func() { - data, err := runtime.Encode(scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion), expect) - if err != nil { - errc <- fmt.Errorf("Unexpected error %v", err) - return - } - event := metav1.WatchEvent{ - Type: string(eType), - Object: runtime.RawExtension{Raw: json.RawMessage(data)}, - } - if 
err := encoder.Encode(&event); err != nil { - t.Errorf("Unexpected error %v", err) - } - in.Close() - }() - - done := make(chan struct{}) - go func() { - action, got, err := decoder.Decode() - if err != nil { - errc <- fmt.Errorf("Unexpected error %v", err) - return - } - if e, a := eType, action; e != a { - t.Errorf("Expected %v, got %v", e, a) - } - if e, a := expect, got; !apiequality.Semantic.DeepDerivative(e, a) { - t.Errorf("Expected %v, got %v", e, a) - } - t.Logf("Exited read") - close(done) - }() - select { - case err := <-errc: - t.Fatal(err) - case <-done: - } - - done = make(chan struct{}) - go func() { - _, _, err := decoder.Decode() - if err == nil { - t.Errorf("Unexpected nil error") - } - close(done) - }() - <-done - - decoder.Close() + t.Run(string(eventType), func(t *testing.T) { + ctx := t.Context() + out, in := io.Pipe() + defer assertNoCloseError(t, in) + decoder := NewDecoder(streaming.NewDecoder(out, getDecoder()), getDecoder()) + defer decoder.Close() + expect := &v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}} + encoder := json.NewEncoder(in) + eType := eventType + + encodeErrCh := make(chan error) + go func() { + defer close(encodeErrCh) + data, err := runtime.Encode(scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion), expect) + if err != nil { + encodeErrCh <- fmt.Errorf("encode error: %w", err) + return + } + event := metav1.WatchEvent{ + Type: string(eType), + Object: runtime.RawExtension{Raw: json.RawMessage(data)}, + } + if err := encoder.Encode(&event); err != nil { + encodeErrCh <- fmt.Errorf("encode error: %w", err) + return + } + }() + + decodeErrCh := make(chan error) + go func() { + defer close(decodeErrCh) + action, got, err := decoder.Decode() + if err != nil { + decodeErrCh <- fmt.Errorf("decode error: %w", err) + return + } + assert.Equal(t, eType, action) + assert.Equal(t, expect, got) + }() + + // Wait for encoder and decoder to return without error + err := utiltesting.WaitForAllChannelsToCloseWithTimeout(ctx, + 
wait.ForeverTestTimeout, encodeErrCh, decodeErrCh) + require.NoError(t, err) + + // Close the input pipe, which should cause the decoder to error + require.NoError(t, in.Close()) + + // Wait for decoder EOF error + decodeErrCh = make(chan error) + go func() { + defer close(decodeErrCh) + _, _, err := decoder.Decode() + if err != nil { + decodeErrCh <- err + } + }() + + // Wait for decoder EOF error + decodeErr, err := utiltesting.WaitForChannelEventWithTimeout(ctx, wait.ForeverTestTimeout, decodeErrCh) + require.NoError(t, err) + require.Equal(t, io.EOF, decodeErr) + }) } } func TestDecoder_SourceClose(t *testing.T) { + ctx := t.Context() out, in := io.Pipe() + defer assertNoCloseError(t, in) decoder := NewDecoder(streaming.NewDecoder(out, getDecoder()), getDecoder()) + defer decoder.Close() - done := make(chan struct{}) - + errCh := make(chan error) go func() { + defer close(errCh) _, _, err := decoder.Decode() - if err == nil { - t.Errorf("Unexpected nil error") + if err != nil { + errCh <- err } - close(done) }() - in.Close() + // Close the input pipe, which should cause the decoder to error + require.NoError(t, in.Close()) - select { - case <-done: - break - case <-time.After(wait.ForeverTestTimeout): - t.Error("Timeout") - } + // Wait for decoder EOF error + decodeErr, err := utiltesting.WaitForChannelEventWithTimeout(ctx, wait.ForeverTestTimeout, errCh) + require.NoError(t, err) + require.Equal(t, io.EOF, decodeErr) + + // Wait for errCh to close + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, errCh) + require.NoError(t, err) +} + +// assertNoCloseError asserts that closing the Closer doesn't error. +// Safe to call in a defer to ensure Closer.Close is called. 
+func assertNoCloseError(t *testing.T, c io.Closer) { + assert.NoError(t, c.Close()) } diff --git a/staging/src/k8s.io/client-go/util/testing/channels.go b/staging/src/k8s.io/client-go/util/testing/channels.go new file mode 100644 index 0000000000000..2611161b2f2f2 --- /dev/null +++ b/staging/src/k8s.io/client-go/util/testing/channels.go @@ -0,0 +1,154 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package testing + +import ( + "context" + "errors" + "fmt" + "reflect" + "time" +) + +// WaitForChannelEvent blocks until the channel receives an event. +// Returns an error if the channel is closed or the context is done. +func WaitForChannelEvent[T any](ctx context.Context, ch <-chan T) (T, error) { + var t T // zero value + for { + select { + case <-ctx.Done(): + err := ctx.Err() + switch err { + case context.DeadlineExceeded: + return t, fmt.Errorf("timed out waiting for channel to close: %w", err) + default: + return t, fmt.Errorf("context cancelled before channel closed: %w", err) + } + case event, ok := <-ch: + if !ok { + return t, errors.New("channel closed before receiving event") + } + return event, nil + } + } +} + +// WaitForChannelToClose blocks until the channel is closed. +// Returns an error if any events are received or the context is done. 
+func WaitForChannelToClose[T any](ctx context.Context, ch <-chan T) error { + for { + select { + case <-ctx.Done(): + err := ctx.Err() + switch err { + case context.DeadlineExceeded: + return fmt.Errorf("timed out waiting for channel to close: %w", err) + default: + return fmt.Errorf("context cancelled before channel closed: %w", err) + } + case event, ok := <-ch: + if !ok { + return nil + } + return &UnexpectedEventError[T]{Event: event} + } + } +} + +// WaitForAllChannelsToClose blocks until all the channels are closed. +// Returns an error if any events are received or the context is done. +func WaitForAllChannelsToClose[T any](ctx context.Context, channels ...<-chan T) error { + // Build a list of cases to select from + cases := make([]reflect.SelectCase, len(channels)+1) + for i, ch := range channels { + cases[i] = reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(ch), + } + } + // Add the context done channel as the last case + contextCaseIndex := len(channels) + cases[contextCaseIndex] = reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(ctx.Done()), + } + // Select from the cases until all channels are closed, an event is received, + // or the context is done. 
+ channelsRemaining := len(cases) + for channelsRemaining > 1 { + // Block until one of the channels receives an event or closes + chosenIndex, value, ok := reflect.Select(cases) + if !ok { + // Return error immediately if the context is done + if chosenIndex == contextCaseIndex { + err := ctx.Err() + switch err { + case context.DeadlineExceeded: + return fmt.Errorf("timed out waiting for channel to close: %w", err) + default: + return fmt.Errorf("context cancelled before channel closed: %w", err) + } + } + // Remove closed channel from case to ignore it going forward + cases[chosenIndex].Chan = reflect.ValueOf(nil) + channelsRemaining-- + continue + } + // All events received are treated as errors + return fmt.Errorf("channel %d: %w", chosenIndex, + &UnexpectedEventError[T]{Event: value.Interface().(T)}) + } + // All channels closed before the context was done + return nil +} + +// WaitForChannelEventWithTimeout blocks until the channel receives an event. +// Returns an error if the channel is closed, the context is done, or the +// timeout is reached +func WaitForChannelEventWithTimeout[T any](ctx context.Context, timeout time.Duration, ch <-chan T) (T, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return WaitForChannelEvent(ctx, ch) +} + +// WaitForChannelToCloseWithTimeout blocks until the channel is closed. +// Returns an error if any events are received, the context is done, or the +// timeout is reached +func WaitForChannelToCloseWithTimeout[T any](ctx context.Context, timeout time.Duration, ch <-chan T) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return WaitForChannelToClose(ctx, ch) +} + +// WaitForAllChannelsToCloseWithTimeout blocks until all the channels are closed. 
+// Returns an error if any events are received, the context is done, or the +// timeout is reached +func WaitForAllChannelsToCloseWithTimeout[T any](ctx context.Context, timeout time.Duration, channels ...<-chan T) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return WaitForAllChannelsToClose(ctx, channels...) +} + +// UnexpectedEventError wraps an event unexpectedly received from a channel. +type UnexpectedEventError[T any] struct { + Event T +} + +// Error implements the error interface +func (ue *UnexpectedEventError[T]) Error() string { + return fmt.Sprintf("channel received unexpected event: %#v", ue.Event) +} From e3cb60000136781943fb23c2f1875311a7a31251 Mon Sep 17 00:00:00 2001 From: Karl Isenberg Date: Wed, 16 Apr 2025 18:22:17 -0700 Subject: [PATCH 2/2] fix: watch client errors - Fix a bug in client-go Watch and WatchList that was keeping the response body open when a NegotiateError was encountered by Request.newStreamWatcher. This was causing the server to keep the storage watcher and timeout channel open until server-side timeout. - Fix a bug in client-go Watch and WatchList that was sometimes sending a "http: read on closed response body" error from the decoder to the result channel after the client watcher had been closed, which closes the http response body. The watcher is supposed to be closed by the client when done reading from the result channel, so the impact was minimal, but this helps avoid needing to drain the result channel before closing it. 
--- .../k8s.io/apimachinery/pkg/util/net/http.go | 2 + .../pkg/endpoints/handlers/watch_test.go | 14 ++- staging/src/k8s.io/client-go/rest/request.go | 102 +++++++++++------- .../src/k8s.io/client-go/rest/request_test.go | 50 +-------- 4 files changed, 77 insertions(+), 91 deletions(-) diff --git a/staging/src/k8s.io/apimachinery/pkg/util/net/http.go b/staging/src/k8s.io/apimachinery/pkg/util/net/http.go index 8cc1810af1330..8cd833781b878 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/net/http.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/net/http.go @@ -92,6 +92,8 @@ func IsProbableEOF(err error) bool { return true case msg == "http: can't write HTTP request on broken connection": return true + case msg == "http: read on closed response body": + return true case strings.Contains(msg, "http2: server sent GOAWAY and closed the connection"): return true case strings.Contains(msg, "connection reset by peer"): diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch_test.go b/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch_test.go index f9089589ca9e8..7d2577c0b5566 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch_test.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch_test.go @@ -264,21 +264,19 @@ func TestWatchHTTPDynamicClientErrors(t *testing.T) { require.Equal(t, runtime.NegotiateError{Stream: true, ContentType: "testcase/json"}, err) // The client should automatically close the connection on error. - //TODO: Fix watch client to close the response body and stop the storage watcher after a NegotiateError - require.False(t, watcher.IsStopped()) // Wait for the server to call the CancelFunc returned by // TimeoutFactory.TimeoutCh, closing the done channel. 
- // err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, doneCh) - // require.NoError(t, err) + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, doneCh) + require.NoError(t, err) // Wait for the server to call watcher.Stop, closing the result channel. - // err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, watcher.ResultChan()) - // require.NoError(t, err) + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, watcher.ResultChan()) + require.NoError(t, err) // Confirm watcher.Stop was called by the server. - // require.Truef(t, watcher.IsStopped(), - // "Leaked watcher goroutine after request done") + require.Truef(t, watcher.IsStopped(), + "Leaked watcher goroutine after request done") } func TestWatchHTTPTimeout(t *testing.T) { diff --git a/staging/src/k8s.io/client-go/rest/request.go b/staging/src/k8s.io/client-go/rest/request.go index 1eb2f9b42a0e2..ae988a8b48f81 100644 --- a/staging/src/k8s.io/client-go/rest/request.go +++ b/staging/src/k8s.io/client-go/rest/request.go @@ -786,49 +786,79 @@ func (r *Request) watchInternal(ctx context.Context) (watch.Interface, runtime.D } retry := r.retryFn(r.maxRetries) url := r.URL().String() + var done bool + var w watch.Interface + var d runtime.Decoder + var err error for { - if err := retry.Before(ctx, r); err != nil { - return nil, nil, retry.WrapPreviousError(err) - } - - req, err := r.newHTTPRequest(ctx) - if err != nil { - return nil, nil, err - } - - resp, err := client.Do(req) - retry.After(ctx, r, resp, err) - if err == nil && resp.StatusCode == http.StatusOK { - return r.newStreamWatcher(ctx, resp) - } - - done, transformErr := func() (bool, error) { - defer readAndCloseResponseBody(resp) + // TODO(karlkfi): extract this out to a Request method for readability + done, w, d, err = func(ctx context.Context) (bool, watch.Interface, runtime.Decoder, error) { + // Cleanup after each failed attempt + 
ctx, cancel := context.WithCancel(ctx) + defer func() { cancel() }() + + if err := retry.Before(ctx, r); err != nil { + return true, nil, nil, retry.WrapPreviousError(err) + } - if retry.IsNextRetry(ctx, r, req, resp, err, isErrRetryableFunc) { - return false, nil + req, err := r.newHTTPRequest(ctx) + if err != nil { + return true, nil, nil, err } - if resp == nil { - // the server must have sent us an error in 'err' - return true, nil + resp, err := client.Do(req) + retry.After(ctx, r, resp, err) + if err == nil && resp.StatusCode == http.StatusOK { + w, d, streamErr := r.newStreamWatcher(ctx, resp) + if streamErr == nil { + // Invalidate cancel() to defer until watcher is stopped + cancel = func() {} + return true, w, d, nil + } + // Cancel the request immediately + cancel() + // Handle stream error like a request error + err = streamErr } - result := r.transformResponse(ctx, resp, req) - if err := result.Error(); err != nil { - return true, err + + done, transformErr := func() (bool, error) { + defer readAndCloseResponseBody(resp) + + if retry.IsNextRetry(ctx, r, req, resp, err, isErrRetryableFunc) { + return false, nil + } + if err != nil { + // Read the response body until closed. + // Skip decoding and ignore the content. + return true, nil + } + if resp != nil { + // Read the response body until closed. + // Decode the content and return any error. + result := r.transformResponse(ctx, resp, req) + if respErr := result.Error(); respErr != nil { + return true, respErr + } + } + // No error from client or server, but we're done retrying. + // Return a minimal error, to be wrapped with previous errors. + return true, fmt.Errorf("for request %s, got status: %v", url, resp.StatusCode) + }() + if done { + if isErrRetryableFunc(req, err) { + return true, watch.NewEmptyWatch(), nil, nil + } + if err == nil { + // if the server sent us an HTTP Response object, + // we need to return the error object from that. 
+ err = transformErr + } + return true, nil, nil, retry.WrapPreviousError(err) } - return true, fmt.Errorf("for request %s, got status: %v", url, resp.StatusCode) - }() + return false, nil, nil, nil + }(ctx) if done { - if isErrRetryableFunc(req, err) { - return watch.NewEmptyWatch(), nil, nil - } - if err == nil { - // if the server sent us an HTTP Response object, - // we need to return the error object from that. - err = transformErr - } - return nil, nil, retry.WrapPreviousError(err) + return w, d, err } } } diff --git a/staging/src/k8s.io/client-go/rest/request_test.go b/staging/src/k8s.io/client-go/rest/request_test.go index 6e710062f11ca..6754378520db5 100644 --- a/staging/src/k8s.io/client-go/rest/request_test.go +++ b/staging/src/k8s.io/client-go/rest/request_test.go @@ -2131,14 +2131,7 @@ func TestWatch(t *testing.T) { // Wait for the result channel to close err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, resultCh) - // TODO(karlkfi): Fix race condition that causes the watch server to sometimes send an error when stopped - if err != nil { - // Extract & compare the Event so we can see diff when it fails - event, unwrapErr := unwrapUnexectedWatchEvent(err) - require.NoError(t, unwrapErr) - expectedEvent := newClientWatchDecodingEvent(errReadOnClosedResBody) - require.Equal(t, expectedEvent, event) - } + require.NoError(t, err) }) } } @@ -2197,15 +2190,8 @@ func TestWatchNonDefaultContentType(t *testing.T) { watcher.Stop() // Wait for the result channel to close - err = utiltesting.WaitForChannelToCloseWithTimeout(t.Context(), wait.ForeverTestTimeout, resultCh) - // TODO(karlkfi): Fix race condition that causes the watch server to sometimes send an error when stopped - if err != nil { - // Extract & compare the Event so we can see diff when it fails - event, unwrapErr := unwrapUnexectedWatchEvent(err) - require.NoError(t, unwrapErr) - expectedEvent := newClientWatchDecodingEvent(errReadOnClosedResBody) - require.Equal(t, 
expectedEvent, event) - } + err = utiltesting.WaitForChannelToCloseWithTimeout(t.Context(), wait.ForeverTestTimeout, watcher.ResultChan()) + require.NoError(t, err) } func TestWatchUnknownContentType(t *testing.T) { @@ -4262,33 +4248,3 @@ func TestRequestWarningHandler(t *testing.T) { assert.Nil(t, request.warningHandler) }) } - -// From https://github.com/golang/go/blob/go1.20/src/net/http/transport.go#L2779 -var errReadOnClosedResBody = errors.New("http: read on closed response body") - -// newClientWatchDecodingEvent simulates a InternalServerError of type -// "ClientWatchDecoding", wrapped as a watch.Event. These errors come from the -// watch client StreamWatcher. -func newClientWatchDecodingEvent(decodeErr error) watch.Event { - // From Request.newStreamWatcher - clientErrorReporter := apierrors.NewClientErrorReporter(http.StatusInternalServerError, "GET", "ClientWatchDecoding") - // From StreamWatcher.receive - // TODO(karlkfi): Use `%w` to format errors (errorlint) in StreamWatcher.receive - //nolint:errorlint // StreamWatcher.receive uses `%v` - serverErr := fmt.Errorf("unable to decode an event from the watch stream: %v", decodeErr) - return watch.Event{ - Type: watch.Error, - Object: clientErrorReporter.AsObject(serverErr), - } -} - -// unwrapUnexectedWatchEvent unwraps a utiltesting.UnexpectedEventError to -// extract the watch.Error. -func unwrapUnexectedWatchEvent(err error) (watch.Event, error) { - var expectedErr *utiltesting.UnexpectedEventError[watch.Event] - if !errors.As(err, &expectedErr) { - return watch.Event{}, fmt.Errorf("Error expected to be of type %v, but was %v", - reflect.TypeOf(expectedErr), reflect.TypeOf(err)) - } - return expectedErr.Event, nil -}