diff --git a/staging/src/k8s.io/apimachinery/pkg/util/net/http.go b/staging/src/k8s.io/apimachinery/pkg/util/net/http.go index 8cc1810af1330..8cd833781b878 100644 --- a/staging/src/k8s.io/apimachinery/pkg/util/net/http.go +++ b/staging/src/k8s.io/apimachinery/pkg/util/net/http.go @@ -92,6 +92,8 @@ func IsProbableEOF(err error) bool { return true case msg == "http: can't write HTTP request on broken connection": return true + case msg == "http: read on closed response body": + return true case strings.Contains(msg, "http2: server sent GOAWAY and closed the connection"): return true case strings.Contains(msg, "connection reset by peer"): diff --git a/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch_test.go b/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch_test.go index 77410bfa5e30f..7d2577c0b5566 100644 --- a/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch_test.go +++ b/staging/src/k8s.io/apiserver/pkg/endpoints/handlers/watch_test.go @@ -17,8 +17,8 @@ limitations under the License. package handlers import ( - "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -27,8 +27,9 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "k8s.io/apimachinery/pkg/api/errors" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" @@ -38,6 +39,7 @@ import ( apitesting "k8s.io/apiserver/pkg/endpoints/testing" "k8s.io/client-go/dynamic" restclient "k8s.io/client-go/rest" + utiltesting "k8s.io/client-go/util/testing" ) // Fake API versions, similar to api/latest.go @@ -61,6 +63,7 @@ func init() { } func TestWatchHTTPErrors(t *testing.T) { + ctx := t.Context() watcher := watch.NewFake() timeoutCh := make(chan time.Time) doneCh := make(chan struct{}) @@ -88,19 +91,24 @@ func TestWatchHTTPErrors(t *testing.T) { defer s.Close() // Setup a client - dest, _ := url.Parse(s.URL) + dest, err := url.Parse(s.URL) + require.NoError(t, err) dest.Path = "/" + namedGroupPrefix + "/" + testGroupV2.Group + "/" + testGroupV2.Version + "/simple" dest.RawQuery = "watch=true" - req, _ := http.NewRequest(http.MethodGet, dest.String(), nil) + // Start watch request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, dest.String(), nil) + require.NoError(t, err) client := http.Client{} resp, err := client.Do(req) require.NoError(t, err) - errStatus := errors.NewInternalError(fmt.Errorf("we got an error")).Status() + defer assertClosed(t, resp.Body) + + // Send error to server from storage + errStatus := apierrors.NewInternalError(fmt.Errorf("we got an error")).Status() watcher.Error(&errStatus) - watcher.Stop() - // Make sure we can actually watch an endpoint + // Decode error from the response decoder := json.NewDecoder(resp.Body) var got watchJSON err = decoder.Decode(&got) @@ -121,12 +129,27 @@ func TestWatchHTTPErrors(t *testing.T) { Details: errStatus.Details, } require.Equal(t, expectedStatus, status) + + // Close the response body to signal the server to stop serving. + require.NoError(t, resp.Body.Close()) + + // Wait for the server to call the CancelFunc returned by + // TimeoutFactory.TimeoutCh, closing the done channel. + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, doneCh) + require.NoError(t, err) + + // Wait for the server to call watcher.Stop, closing the result channel. + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, watcher.ResultChan()) + require.NoError(t, err) + + // Confirm watcher.Stop was called by the server. + require.Truef(t, watcher.IsStopped(), + "Leaked watcher goroutine after request done") } func TestWatchHTTPErrorsBeforeServe(t *testing.T) { + ctx := t.Context() watcher := watch.NewFake() - timeoutCh := make(chan time.Time) - doneCh := make(chan struct{}) info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), runtime.ContentTypeJSON) if !ok || info.StreamSerializer == nil { @@ -147,24 +170,29 @@ func TestWatchHTTPErrorsBeforeServe(t *testing.T) { Encoder: testCodecV2, EmbeddedEncoder: testCodecV2, - TimeoutFactory: &fakeTimeoutFactory{timeoutCh, doneCh}, + // TimeoutFactory should not be needed, because the server should error + // before calling TimeoutFactory.TimeoutCh. } - statusErr := errors.NewInternalError(fmt.Errorf("we got an error")) + statusErr := apierrors.NewInternalError(fmt.Errorf("we got an error")) errStatus := statusErr.Status() s := httptest.NewServer(serveWatch(watcher, watchServer, statusErr)) defer s.Close() // Setup a client - dest, _ := url.Parse(s.URL) + dest, err := url.Parse(s.URL) + require.NoError(t, err) dest.Path = "/" + namedGroupPrefix + "/" + testGroupV2.Group + "/" + testGroupV2.Version + "/simple" dest.RawQuery = "watch=true" - req, _ := http.NewRequest(http.MethodGet, dest.String(), nil) + // Start watch request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, dest.String(), nil) + require.NoError(t, err) client := http.Client{} resp, err := client.Do(req) require.NoError(t, err) + defer assertClosed(t, resp.Body) // We had already got an error before watch serve started decoder := json.NewDecoder(resp.Body) @@ -184,15 +212,25 @@ func TestWatchHTTPErrorsBeforeServe(t *testing.T) { } require.Equal(t, expectedStatus, status) - // check for leaks + // Close the response body to signal the server to stop serving. + // This isn't strictly necessary, since the test serveWatch doesn't block, + // but it would be if this were the real watch server. + require.NoError(t, resp.Body.Close()) + + // Wait for the server to call watcher.Stop, closing the result channel. + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, watcher.ResultChan()) + require.NoError(t, err) + + // Confirm watcher.Stop was called by the server. require.Truef(t, watcher.IsStopped(), - "Leaked watcher goruntine after request done") + "Leaked watcher goroutine after request done") } func TestWatchHTTPDynamicClientErrors(t *testing.T) { + ctx := t.Context() watcher := watch.NewFake() timeoutCh := make(chan time.Time) - done := make(chan struct{}) + doneCh := make(chan struct{}) info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), runtime.ContentTypeJSON) if !ok || info.StreamSerializer == nil { @@ -210,7 +248,7 @@ func TestWatchHTTPDynamicClientErrors(t *testing.T) { Encoder: testCodecV2, EmbeddedEncoder: testCodecV2, - TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done}, + TimeoutFactory: &fakeTimeoutFactory{timeoutCh: timeoutCh, done: doneCh}, } s := httptest.NewServer(serveWatch(watcher, watchServer, nil)) @@ -222,14 +260,30 @@ func TestWatchHTTPDynamicClientErrors(t *testing.T) { APIPath: "/" + namedGroupPrefix, }).Resource(testGroupV2.WithResource("simple")) - _, err := client.Watch(context.TODO(), metav1.ListOptions{}) + _, err := client.Watch(ctx, metav1.ListOptions{}) require.Equal(t, runtime.NegotiateError{Stream: true, ContentType: "testcase/json"}, err) + + // The client should automatically close the connection on error. + + // Wait for the server to call the CancelFunc returned by + // TimeoutFactory.TimeoutCh, closing the done channel. + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, doneCh) + require.NoError(t, err) + + // Wait for the server to call watcher.Stop, closing the result channel. + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, watcher.ResultChan()) + require.NoError(t, err) + + // Confirm watcher.Stop was called by the server. + require.Truef(t, watcher.IsStopped(), + "Leaked watcher goroutine after request done") } func TestWatchHTTPTimeout(t *testing.T) { + ctx := t.Context() watcher := watch.NewFake() timeoutCh := make(chan time.Time) - done := make(chan struct{}) + doneCh := make(chan struct{}) info, ok := runtime.SerializerInfoForMediaType(codecs.SupportedMediaTypes(), runtime.ContentTypeJSON) if !ok || info.StreamSerializer == nil { @@ -247,21 +301,27 @@ func TestWatchHTTPTimeout(t *testing.T) { Encoder: testCodecV2, EmbeddedEncoder: testCodecV2, - TimeoutFactory: &fakeTimeoutFactory{timeoutCh, done}, + TimeoutFactory: &fakeTimeoutFactory{timeoutCh: timeoutCh, done: doneCh}, } s := httptest.NewServer(serveWatch(watcher, watchServer, nil)) defer s.Close() // Setup a client - dest, _ := url.Parse(s.URL) + dest, err := url.Parse(s.URL) + require.NoError(t, err) dest.Path = "/" + namedGroupPrefix + "/" + testGroupV2.Group + "/" + testGroupV2.Version + "/simple" dest.RawQuery = "watch=true" - req, _ := http.NewRequest(http.MethodGet, dest.String(), nil) + // Start watch request + req, err := http.NewRequestWithContext(ctx, http.MethodGet, dest.String(), nil) + require.NoError(t, err) client := http.Client{} resp, err := client.Do(req) require.NoError(t, err) + defer assertClosed(t, resp.Body) + + // Send object added event to server from storage watcher.Add(&apitesting.Simple{TypeMeta: metav1.TypeMeta{APIVersion: testGroupV2.String()}}) // Make sure we can actually watch an endpoint @@ -270,29 +330,28 @@ func TestWatchHTTPTimeout(t *testing.T) { err = decoder.Decode(&got) require.NoError(t, err) - // Timeout and check for leaks + // Trigger server-side timeout. close(timeoutCh) - select { - case <-done: - eventCh := watcher.ResultChan() - select { - case _, opened := <-eventCh: - if opened { - t.Errorf("Watcher received unexpected event") - } - if !watcher.IsStopped() { - t.Errorf("Watcher is not stopped") - } - case <-time.After(wait.ForeverTestTimeout): - t.Errorf("Leaked watch on timeout") - } - case <-time.After(wait.ForeverTestTimeout): - t.Errorf("Failed to stop watcher after %s of timeout signal", wait.ForeverTestTimeout.String()) - } - // Make sure we can't receive any more events through the timeout watch + // Wait for the server to call the CancelFunc returned by + // TimeoutFactory.TimeoutCh, closing the done channel. + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, doneCh) + require.NoError(t, err) + + // Wait for the server to call watcher.Stop, closing the result channel. + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, watcher.ResultChan()) + require.NoError(t, err) + + // Confirm watcher.Stop was called by the server. + require.Truef(t, watcher.IsStopped(), + "Leaked watcher goroutine after request done") + + // Make sure we can't receive any more events after the watch timeout err = decoder.Decode(&got) require.Equal(t, io.EOF, err) + + // Close the response body to clean up watch client resources. + require.NoError(t, resp.Body.Close()) } // watchJSON defines the expected JSON wire equivalent of watch.Event. @@ -330,3 +389,49 @@ func serveWatch(watcher watch.Interface, watchServer *WatchServer, preServeErr e watchServer.HandleHTTP(w, req) } } + +// From https://github.com/golang/go/blob/go1.20/src/net/http/transport.go#L2779 +var errReadOnClosedResBody = errors.New("http: read on closed response body") + +// assertClosed fails the test if the ReadCloser is NOT already closed. +// If not already closed, the ReadCloser will be drained and closed. +// Defer when your test is expected to close the ReadCloser before ending. +func assertClosed(t *testing.T, rc io.ReadCloser) { + assert.Equal(t, errReadOnClosedResBody, drainAndClose(rc)) +} + +// drainAndClose reads from the ReadCloser until EOF, discarding the content, +// and closes the ReadCloser when finished or on error. +// Returns an error when either Read or Close error. If both error, the errors +// are joined and returned. +// +// In a defer from a test, use with t.Error or assert.NoError, NOT t.Fatal or +// require.NoError. +func drainAndClose(rc io.ReadCloser) error { + errCh := make(chan error) + go func() { + // Close after done reading + defer func() { + defer close(errCh) + if err := rc.Close(); err != nil { + errCh <- err + } + }() + // Read until EOF and discard + if _, err := io.Copy(io.Discard, rc); err != nil { + errCh <- err + } + }() + + // Wait until Read and Close are both done. + // Combine errors, if multiple. + var multiErr error + for err := range errCh { + if multiErr != nil { + multiErr = errors.Join(multiErr, err) + } else { + multiErr = err + } + } + return multiErr +} diff --git a/staging/src/k8s.io/client-go/rest/request.go b/staging/src/k8s.io/client-go/rest/request.go index 1eb2f9b42a0e2..ae988a8b48f81 100644 --- a/staging/src/k8s.io/client-go/rest/request.go +++ b/staging/src/k8s.io/client-go/rest/request.go @@ -786,49 +786,79 @@ func (r *Request) watchInternal(ctx context.Context) (watch.Interface, runtime.D } retry := r.retryFn(r.maxRetries) url := r.URL().String() + var done bool + var w watch.Interface + var d runtime.Decoder + var err error for { - if err := retry.Before(ctx, r); err != nil { - return nil, nil, retry.WrapPreviousError(err) - } - - req, err := r.newHTTPRequest(ctx) - if err != nil { - return nil, nil, err - } - - resp, err := client.Do(req) - retry.After(ctx, r, resp, err) - if err == nil && resp.StatusCode == http.StatusOK { - return r.newStreamWatcher(ctx, resp) - } - - done, transformErr := func() (bool, error) { - defer readAndCloseResponseBody(resp) + // TODO(karlkfi): extract this out to a Request method for readability + done, w, d, err = func(ctx context.Context) (bool, watch.Interface, runtime.Decoder, error) { + // Cleanup after each failed attempt + ctx, cancel := context.WithCancel(ctx) + defer func() { cancel() }() + + if err := retry.Before(ctx, r); err != nil { + return true, nil, nil, retry.WrapPreviousError(err) + } - if retry.IsNextRetry(ctx, r, req, resp, err, isErrRetryableFunc) { - return false, nil + req, err := r.newHTTPRequest(ctx) + if err != nil { + return true, nil, nil, err } - if resp == nil { - // the server must have sent us an error in 'err' - return true, nil + resp, err := client.Do(req) + retry.After(ctx, r, resp, err) + if err == nil && resp.StatusCode == http.StatusOK { + w, d, streamErr := r.newStreamWatcher(ctx, resp) + if streamErr == nil { + // Invalidate cancel() to defer until watcher is stopped + cancel = func() {} + return true, w, d, nil + } + // Cancel the request immediately + cancel() + // Handle stream error like a request error + err = streamErr } - result := r.transformResponse(ctx, resp, req) - if err := result.Error(); err != nil { - return true, err + + done, transformErr := func() (bool, error) { + defer readAndCloseResponseBody(resp) + + if retry.IsNextRetry(ctx, r, req, resp, err, isErrRetryableFunc) { + return false, nil + } + if err != nil { + // Read the response body until closed. + // Skip decoding and ignore the content. + return true, nil + } + if resp != nil { + // Read the response body until closed. + // Decode the content and return any error. + result := r.transformResponse(ctx, resp, req) + if respErr := result.Error(); respErr != nil { + return true, respErr + } + } + // No error from client or server, but we're done retrying. + // Return a minimal error, to be wrapped with previous errors. + return true, fmt.Errorf("for request %s, got status: %v", url, resp.StatusCode) + }() + if done { + if isErrRetryableFunc(req, err) { + return true, watch.NewEmptyWatch(), nil, nil + } + if err == nil { + // if the server sent us an HTTP Response object, + // we need to return the error object from that. + err = transformErr + } + return true, nil, nil, retry.WrapPreviousError(err) } - return true, fmt.Errorf("for request %s, got status: %v", url, resp.StatusCode) - }() + return false, nil, nil, nil + }(ctx) if done { - if isErrRetryableFunc(req, err) { - return watch.NewEmptyWatch(), nil, nil - } - if err == nil { - // if the server sent us an HTTP Response object, - // we need to return the error object from that. - err = transformErr - } - return nil, nil, retry.WrapPreviousError(err) + return w, d, err } } } diff --git a/staging/src/k8s.io/client-go/rest/request_test.go b/staging/src/k8s.io/client-go/rest/request_test.go index fd64dcb028187..6754378520db5 100644 --- a/staging/src/k8s.io/client-go/rest/request_test.go +++ b/staging/src/k8s.io/client-go/rest/request_test.go @@ -39,11 +39,9 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/google/go-cmp/cmp" - v1 "k8s.io/api/core/v1" apiequality "k8s.io/apimachinery/pkg/api/equality" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -54,6 +52,7 @@ import ( "k8s.io/apimachinery/pkg/runtime/serializer/streaming" "k8s.io/apimachinery/pkg/util/intstr" utilnet "k8s.io/apimachinery/pkg/util/net" + "k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/kubernetes/scheme" restclientwatch "k8s.io/client-go/rest/watch" @@ -2069,6 +2068,7 @@ func TestWatch(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + ctx := t.Context() var table = []struct { t watch.EventType obj runtime.Object @@ -2110,29 +2110,28 @@ func TestWatch(t *testing.T) { defer testServer.Close() s := testRESTClient(t, testServer) - watching, err := s.Get().Prefix("path/to/watch/thing"). - MaxRetries(test.maxRetries).Watch(context.Background()) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } + watcher, err := s.Get().Prefix("path/to/watch/thing"). + MaxRetries(test.maxRetries).Watch(ctx) + require.NoError(t, err) + defer watcher.Stop() + // Read the events from the result channel + resultCh := watcher.ResultChan() for _, item := range table { - got, ok := <-watching.ResultChan() + got, ok := <-resultCh if !ok { - t.Fatalf("Unexpected early close") - } - if e, a := item.t, got.Type; e != a { - t.Errorf("Expected %v, got %v", e, a) - } - if e, a := item.obj, got.Object; !apiequality.Semantic.DeepDerivative(e, a) { - t.Errorf("Expected %v, got %v", e, a) + t.Fatal("Unexpected early close") } + assert.Equal(t, item.t, got.Type) + assert.Equal(t, item.obj, got.Object) } - _, ok := <-watching.ResultChan() - if ok { - t.Fatal("Unexpected non-close") - } + // Stop watcher when done reading watch events + watcher.Stop() + + // Wait for the result channel to close + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, resultCh) + require.NoError(t, err) }) } } @@ -2173,28 +2172,26 @@ func TestWatchNonDefaultContentType(t *testing.T) { contentConfig := defaultContentConfig() contentConfig.ContentType = "application/vnd.kubernetes.protobuf" s := testRESTClientWithConfig(t, testServer, contentConfig) - watching, err := s.Get().Prefix("path/to/watch/thing").Watch(context.Background()) - if err != nil { - t.Fatalf("Unexpected error") - } + watcher, err := s.Get().Prefix("path/to/watch/thing").Watch(t.Context()) + require.NoError(t, err) + defer watcher.Stop() + resultCh := watcher.ResultChan() for _, item := range table { - got, ok := <-watching.ResultChan() + got, ok := <-resultCh if !ok { t.Fatalf("Unexpected early close") } - if e, a := item.t, got.Type; e != a { - t.Errorf("Expected %v, got %v", e, a) - } - if e, a := item.obj, got.Object; !apiequality.Semantic.DeepDerivative(e, a) { - t.Errorf("Expected %v, got %v", e, a) - } + assert.Equal(t, item.t, got.Type) + assert.Equal(t, item.obj, got.Object) } - _, ok := <-watching.ResultChan() - if ok { - t.Fatal("Unexpected non-close") - } + // Stop watcher when done reading watch events + watcher.Stop() + + // Wait for the result channel to close + err = utiltesting.WaitForChannelToCloseWithTimeout(t.Context(), wait.ForeverTestTimeout, watcher.ResultChan()) + require.NoError(t, err) } func TestWatchUnknownContentType(t *testing.T) { @@ -2230,10 +2227,8 @@ func TestWatchUnknownContentType(t *testing.T) { defer testServer.Close() s := testRESTClient(t, testServer) - _, err := s.Get().Prefix("path/to/watch/thing").Watch(context.Background()) - if err == nil { - t.Fatalf("Expected to fail due to lack of known stream serialization for content type") - } + _, err := s.Get().Prefix("path/to/watch/thing").Watch(t.Context()) + require.Equal(t, runtime.NegotiateError{ContentType: "foobar", Stream: true}, err) } func TestStream(t *testing.T) { diff --git a/staging/src/k8s.io/client-go/rest/watch/decoder_test.go b/staging/src/k8s.io/client-go/rest/watch/decoder_test.go index a6a95b884222b..6bc18d0abb871 100644 --- a/staging/src/k8s.io/client-go/rest/watch/decoder_test.go +++ b/staging/src/k8s.io/client-go/rest/watch/decoder_test.go @@ -21,10 +21,10 @@ import ( "fmt" "io" "testing" - "time" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" v1 "k8s.io/api/core/v1" - apiequality "k8s.io/apimachinery/pkg/api/equality" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" runtimejson "k8s.io/apimachinery/pkg/runtime/serializer/json" @@ -32,6 +32,7 @@ import ( "k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/kubernetes/scheme" + utiltesting "k8s.io/client-go/util/testing" ) // getDecoder mimics how k8s.io/client-go/rest.createSerializers creates a decoder @@ -45,86 +46,103 @@ func TestDecoder(t *testing.T) { table := []watch.EventType{watch.Added, watch.Deleted, watch.Modified, watch.Error, watch.Bookmark} for _, eventType := range table { - out, in := io.Pipe() - - decoder := NewDecoder(streaming.NewDecoder(out, getDecoder()), getDecoder()) - expect := &v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}} - encoder := json.NewEncoder(in) - eType := eventType - errc := make(chan error) - - go func() { - data, err := runtime.Encode(scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion), expect) - if err != nil { - errc <- fmt.Errorf("Unexpected error %v", err) - return - } - event := metav1.WatchEvent{ - Type: string(eType), - Object: runtime.RawExtension{Raw: json.RawMessage(data)}, - } - if err := encoder.Encode(&event); err != nil { - t.Errorf("Unexpected error %v", err) - } - in.Close() - }() - - done := make(chan struct{}) - go func() { - action, got, err := decoder.Decode() - if err != nil { - errc <- fmt.Errorf("Unexpected error %v", err) - return - } - if e, a := eType, action; e != a { - t.Errorf("Expected %v, got %v", e, a) - } - if e, a := expect, got; !apiequality.Semantic.DeepDerivative(e, a) { - t.Errorf("Expected %v, got %v", e, a) - } - t.Logf("Exited read") - close(done) - }() - select { - case err := <-errc: - t.Fatal(err) - case <-done: - } - - done = make(chan struct{}) - go func() { - _, _, err := decoder.Decode() - if err == nil { - t.Errorf("Unexpected nil error") - } - close(done) - }() - <-done - - decoder.Close() + t.Run(string(eventType), func(t *testing.T) { + ctx := t.Context() + out, in := io.Pipe() + defer assertNoCloseError(t, in) + decoder := NewDecoder(streaming.NewDecoder(out, getDecoder()), getDecoder()) + defer decoder.Close() + expect := &v1.Pod{ObjectMeta: metav1.ObjectMeta{Name: "foo"}} + encoder := json.NewEncoder(in) + eType := eventType + + encodeErrCh := make(chan error) + go func() { + defer close(encodeErrCh) + data, err := runtime.Encode(scheme.Codecs.LegacyCodec(v1.SchemeGroupVersion), expect) + if err != nil { + encodeErrCh <- fmt.Errorf("encode error: %w", err) + return + } + event := metav1.WatchEvent{ + Type: string(eType), + Object: runtime.RawExtension{Raw: json.RawMessage(data)}, + } + if err := encoder.Encode(&event); err != nil { + encodeErrCh <- fmt.Errorf("encode error: %w", err) + return + } + }() + + decodeErrCh := make(chan error) + go func() { + defer close(decodeErrCh) + action, got, err := decoder.Decode() + if err != nil { + decodeErrCh <- fmt.Errorf("decode error: %w", err) + return + } + assert.Equal(t, eType, action) + assert.Equal(t, expect, got) + }() + + // Wait for encoder and decoder to return without error + err := utiltesting.WaitForAllChannelsToCloseWithTimeout(ctx, + wait.ForeverTestTimeout, encodeErrCh, decodeErrCh) + require.NoError(t, err) + + // Close the input pipe, which should cause the decoder to error + require.NoError(t, in.Close()) + + // Wait for decoder EOF error + decodeErrCh = make(chan error) + go func() { + defer close(decodeErrCh) + _, _, err := decoder.Decode() + if err != nil { + decodeErrCh <- err + } + }() + + // Wait for decoder EOF error + decodeErr, err := utiltesting.WaitForChannelEventWithTimeout(ctx, wait.ForeverTestTimeout, decodeErrCh) + require.NoError(t, err) + require.Equal(t, io.EOF, decodeErr) + }) } } func TestDecoder_SourceClose(t *testing.T) { + ctx := t.Context() out, in := io.Pipe() + defer assertNoCloseError(t, in) decoder := NewDecoder(streaming.NewDecoder(out, getDecoder()), getDecoder()) + defer decoder.Close() - done := make(chan struct{}) - + errCh := make(chan error) go func() { + defer close(errCh) _, _, err := decoder.Decode() - if err == nil { - t.Errorf("Unexpected nil error") + if err != nil { + errCh <- err } - close(done) }() - in.Close() + // Close the input pipe, which should cause the decoder to error + require.NoError(t, in.Close()) - select { - case <-done: - break - case <-time.After(wait.ForeverTestTimeout): - t.Error("Timeout") - } + // Wait for decoder EOF error + decodeErr, err := utiltesting.WaitForChannelEventWithTimeout(ctx, wait.ForeverTestTimeout, errCh) + require.NoError(t, err) + require.Equal(t, io.EOF, decodeErr) + + // Wait for errCh to close + err = utiltesting.WaitForChannelToCloseWithTimeout(ctx, wait.ForeverTestTimeout, errCh) + require.NoError(t, err) +} + +// assertNoCloseError asserts that closing the Closer doesn't error. +// Safe to call in a defer to ensure Closer.Close is called. +func assertNoCloseError(t *testing.T, c io.Closer) { + assert.NoError(t, c.Close()) } diff --git a/staging/src/k8s.io/client-go/util/testing/channels.go b/staging/src/k8s.io/client-go/util/testing/channels.go new file mode 100644 index 0000000000000..2611161b2f2f2 --- /dev/null +++ b/staging/src/k8s.io/client-go/util/testing/channels.go @@ -0,0 +1,154 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package testing + +import ( + "context" + "errors" + "fmt" + "reflect" + "time" +) + +// WaitForChannelEvent blocks until the channel receives an event. +// Returns an error if the channel is closed or the context is done. +func WaitForChannelEvent[T any](ctx context.Context, ch <-chan T) (T, error) { + var t T // zero value + for { + select { + case <-ctx.Done(): + err := ctx.Err() + switch err { + case context.DeadlineExceeded: + return t, fmt.Errorf("timed out waiting for channel to close: %w", err) + default: + return t, fmt.Errorf("context cancelled before channel closed: %w", err) + } + case event, ok := <-ch: + if !ok { + return t, errors.New("channel closed before receiving event") + } + return event, nil + } + } +} + +// WaitForChannelToClose blocks until the channel is closed. +// Returns an error if any events are received or the context is done. +func WaitForChannelToClose[T any](ctx context.Context, ch <-chan T) error { + for { + select { + case <-ctx.Done(): + err := ctx.Err() + switch err { + case context.DeadlineExceeded: + return fmt.Errorf("timed out waiting for channel to close: %w", err) + default: + return fmt.Errorf("context cancelled before channel closed: %w", err) + } + case event, ok := <-ch: + if !ok { + return nil + } + return &UnexpectedEventError[T]{Event: event} + } + } +} + +// WaitForAllChannelsToClose blocks until all the channels are closed. +// Returns an error if any events are received or the context is done. +func WaitForAllChannelsToClose[T any](ctx context.Context, channels ...<-chan T) error { + // Build a list of cases to select from + cases := make([]reflect.SelectCase, len(channels)+1) + for i, ch := range channels { + cases[i] = reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(ch), + } + } + // Add the context done channel as the last case + contextCaseIndex := len(channels) + cases[contextCaseIndex] = reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(ctx.Done()), + } + // Select from the cases until all channels are closed, an event is received, + // or the context is done. + channelsRemaining := len(cases) + for channelsRemaining > 1 { + // Block until one of the channels receives an event or closes + chosenIndex, value, ok := reflect.Select(cases) + if !ok { + // Return error immediately if the context is done + if chosenIndex == contextCaseIndex { + err := ctx.Err() + switch err { + case context.DeadlineExceeded: + return fmt.Errorf("timed out waiting for channel to close: %w", err) + default: + return fmt.Errorf("context cancelled before channel closed: %w", err) + } + } + // Remove closed channel from case to ignore it going forward + cases[chosenIndex].Chan = reflect.ValueOf(nil) + channelsRemaining-- + continue + } + // All events received are treated as errors + return fmt.Errorf("channel %d: %w", chosenIndex, + &UnexpectedEventError[T]{Event: value.Interface().(T)}) + } + // All channels closed before the context was done + return nil +} + +// WaitForChannelEventWithTimeout blocks until the channel receives an event. +// Returns an error if the channel is closed, the context is done, or the +// timeout is reached +func WaitForChannelEventWithTimeout[T any](ctx context.Context, timeout time.Duration, ch <-chan T) (T, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return WaitForChannelEvent(ctx, ch) +} + +// WaitForChannelToClose blocks until the channel is closed. +// Returns an error if any events are received, the context is done, or the +// timeout is reached +func WaitForChannelToCloseWithTimeout[T any](ctx context.Context, timeout time.Duration, ch <-chan T) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return WaitForChannelToClose(ctx, ch) +} + +// WaitForAllChannelsToCloseWithTimeout blocks until all the channels are closed. +// Returns an error if any events are received, the context is done, or the +// timeout is reached +func WaitForAllChannelsToCloseWithTimeout[T any](ctx context.Context, timeout time.Duration, channels ...<-chan T) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return WaitForAllChannelsToClose(ctx, channels...) +} + +// UnexpectedEventError wraps an event unexpectedly received from a channel. +type UnexpectedEventError[T any] struct { + Event T +} + +// Error implements the error interface +func (ue *UnexpectedEventError[T]) Error() string { + return fmt.Sprintf("channel received unexpected event: %#v", ue.Event) +}