diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index 7ecc8f2a9679..c4da5d78f1da 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -1,5 +1,12 @@ # Release History +## 1.6.1 (2023-06-06) + +### Bugs Fixed +* Retry policy always clones the underlying `*http.Request` before invoking the next policy. +* Added some non-standard error codes to the list of error codes for unregistered resource providers. +* Fixed an issue in `azcore.NewClient()` and `arm.NewClient()` that could cause an incorrect module name to be used in telemetry. + ## 1.6.0 (2023-05-04) ### Features Added diff --git a/sdk/azcore/arm/client.go b/sdk/azcore/arm/client.go index 94d018d43537..aa34575f66fb 100644 --- a/sdk/azcore/arm/client.go +++ b/sdk/azcore/arm/client.go @@ -28,12 +28,13 @@ type Client struct { // NewClient creates a new Client instance with the provided values. // This client is intended to be used with Azure Resource Manager endpoints. -// - clientName - the fully qualified name of the client ("package.Client"); this is used by the tracing provider when creating spans +// - clientName - the fully qualified name of the client ("module/package.Client"); this is used by the telemetry policy and tracing provider. +// if module and package are the same value, the "module/" prefix can be omitted. // - moduleVersion - the version of the containing module; used by the telemetry policy // - cred - the TokenCredential used to authenticate the request // - options - optional client configurations; pass nil to accept the default values func NewClient(clientName, moduleVersion string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) { - pkg, err := shared.ExtractPackageName(clientName) + mod, client, err := shared.ExtractModuleName(clientName) if err != nil { return nil, err } @@ -52,12 +53,12 @@ func NewClient(clientName, moduleVersion string, cred azcore.TokenCredential, op if c, ok := options.Cloud.Services[cloud.ResourceManager]; ok { ep = c.Endpoint } - pl, err := armruntime.NewPipeline(pkg, moduleVersion, cred, runtime.PipelineOptions{}, options) + pl, err := armruntime.NewPipeline(mod, moduleVersion, cred, runtime.PipelineOptions{}, options) if err != nil { return nil, err } - tr := options.TracingProvider.NewTracer(clientName, moduleVersion) + tr := options.TracingProvider.NewTracer(client, moduleVersion) return &Client{ep: ep, pl: pl, tr: tr}, nil } diff --git a/sdk/azcore/arm/runtime/pipeline_test.go b/sdk/azcore/arm/runtime/pipeline_test.go index 1fc6fef99aeb..d48a03f14fe8 100644 --- a/sdk/azcore/arm/runtime/pipeline_test.go +++ b/sdk/azcore/arm/runtime/pipeline_test.go @@ -104,7 +104,7 @@ func TestDisableAutoRPRegistration(t *testing.T) { srv, close := mock.NewServer() defer close() // initial response that RP is unregistered - srv.SetResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + srv.SetResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp1))) opts := &armpolicy.ClientOptions{DisableRPRegistration: true, ClientOptions: policy.ClientOptions{Transport: srv}} req, err := azruntime.NewRequest(context.Background(), http.MethodGet, srv.URL()) if err != nil { diff --git a/sdk/azcore/arm/runtime/policy_register_rp.go b/sdk/azcore/arm/runtime/policy_register_rp.go index 49e6608070f2..c3f5eeafe020 100644 --- a/sdk/azcore/arm/runtime/policy_register_rp.go +++ b/sdk/azcore/arm/runtime/policy_register_rp.go @@ -80,7 +80,6 @@ func (r *rpRegistrationPolicy) Do(req *azpolicy.Request) (*http.Response, error) // policy is disabled return req.Next() } - const unregisteredRPCode = "MissingSubscriptionRegistration" const registeredState = "Registered" var rp string var resp *http.Response @@ -101,7 +100,7 @@ func (r *rpRegistrationPolicy) Do(req *azpolicy.Request) (*http.Response, error) // to the caller so its error unmarshalling will kick in return resp, err } - if !strings.EqualFold(reqErr.ServiceError.Code, unregisteredRPCode) { + if !isUnregisteredRPCode(reqErr.ServiceError.Code) { // not a 409 due to unregistered RP. just return the response // to the caller so its error unmarshalling will kick in return resp, err @@ -173,6 +172,21 @@ func (r *rpRegistrationPolicy) Do(req *azpolicy.Request) (*http.Response, error) return resp, fmt.Errorf("exceeded attempts to register %s", rp) } +var unregisteredRPCodes = []string{ + "MissingSubscriptionRegistration", + "MissingRegistrationForResourceProvider", + "Subscription Not Registered", +} + +func isUnregisteredRPCode(errorCode string) bool { + for _, code := range unregisteredRPCodes { + if strings.EqualFold(errorCode, code) { + return true + } + } + return false +} + func getSubscription(path string) (string, error) { parts := strings.Split(path, "/") for i, v := range parts { diff --git a/sdk/azcore/arm/runtime/policy_register_rp_test.go b/sdk/azcore/arm/runtime/policy_register_rp_test.go index 567e38fd4e87..0edb866bf34e 100644 --- a/sdk/azcore/arm/runtime/policy_register_rp_test.go +++ b/sdk/azcore/arm/runtime/policy_register_rp_test.go @@ -24,7 +24,7 @@ import ( "github.com/stretchr/testify/require" ) -const rpUnregisteredResp = `{ +const rpUnregisteredResp1 = `{ "error":{ "code":"MissingSubscriptionRegistration", "message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions.", @@ -37,6 +37,19 @@ const rpUnregisteredResp = `{ } }` +const rpUnregisteredResp2 = `{ + "error":{ + "code":"MissingRegistrationForResourceProvider", + "message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions.", + "details":[{ + "code":"MissingRegistrationForResourceProvider", + "target":"Microsoft.Storage", + "message":"The subscription registration is in 'Unregistered' state. The subscription must be registered to use namespace 'Microsoft.Storage'. See https://aka.ms/rps-not-found for how to register subscriptions." + } + ] + } +}` + // some content was omitted here as it's not relevant const rpRegisteringResp = `{ "id": "/subscriptions/00000000-0000-0000-0000-000000000000/providers/Microsoft.Storage", @@ -89,7 +102,7 @@ func TestRPRegistrationPolicySuccess(t *testing.T) { srv, close := mock.NewServer() defer close() // initial response that RP is unregistered - srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp1))) // polling responses to Register() and Get(), in progress srv.RepeatResponse(5, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp))) // polling response, successful registration @@ -180,7 +193,7 @@ func TestRPRegistrationPolicyTimesOut(t *testing.T) { srv, close := mock.NewServer() defer close() // initial response that RP is unregistered - srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp1))) // polling responses to Register() and Get(), in progress but slow // tests registration takes too long, times out srv.RepeatResponse(10, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp)), mock.WithSlowResponse(400*time.Millisecond)) @@ -212,7 +225,7 @@ func TestRPRegistrationPolicyExceedsAttempts(t *testing.T) { // add a cycle of unregistered->registered so that we keep retrying and hit the cap for i := 0; i < 4; i++ { // initial response that RP is unregistered - srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp1))) // polling responses to Register() and Get(), in progress srv.RepeatResponse(2, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp))) // polling response, successful registration @@ -246,7 +259,7 @@ func TestRPRegistrationPolicyCanCancel(t *testing.T) { srv, close := mock.NewServer() defer close() // initial response that RP is unregistered - srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp2))) // polling responses to Register() and Get(), in progress but slow so we have time to cancel srv.RepeatResponse(10, mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp)), mock.WithSlowResponse(300*time.Millisecond)) // log only RP registration @@ -287,7 +300,7 @@ func TestRPRegistrationPolicyDisabled(t *testing.T) { srv, close := mock.NewServer() defer close() // initial response that RP is unregistered - srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp2))) ops := testRPRegistrationOptions(srv) ops.MaxAttempts = -1 client := newFakeClient(t, srv, ops) @@ -305,7 +318,7 @@ func TestRPRegistrationPolicyDisabled(t *testing.T) { require.Error(t, err) var respErr *exported.ResponseError require.ErrorAs(t, err, &respErr) - require.EqualValues(t, "MissingSubscriptionRegistration", respErr.ErrorCode) + require.EqualValues(t, "MissingRegistrationForResourceProvider", respErr.ErrorCode) require.Zero(t, resp) // shouldn't be any log entries require.Zero(t, logEntries) @@ -315,7 +328,7 @@ func TestRPRegistrationPolicyAudience(t *testing.T) { srv, close := mock.NewServer() defer close() // initial response that RP is unregistered - srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp))) + srv.AppendResponse(mock.WithStatusCode(http.StatusConflict), mock.WithBody([]byte(rpUnregisteredResp2))) // polling responses to Register() and Get(), in progress srv.AppendResponse(mock.WithStatusCode(http.StatusOK), mock.WithBody([]byte(rpRegisteringResp))) // polling response, successful registration @@ -399,6 +412,11 @@ func TestRPRegistrationPolicyEnvironmentsInSubExceeded(t *testing.T) { require.EqualValues(t, 0, logEntries) } +func TestIsUnregisteredRPCode(t *testing.T) { + require.True(t, isUnregisteredRPCode("Subscription Not Registered")) + require.False(t, isUnregisteredRPCode("Your subscription isn't registered")) +} + type fakeClient struct { ep string pl runtime.Pipeline diff --git a/sdk/azcore/core.go b/sdk/azcore/core.go index 72c2cf21eef3..27231ad920d5 100644 --- a/sdk/azcore/core.go +++ b/sdk/azcore/core.go @@ -76,12 +76,13 @@ type Client struct { } // NewClient creates a new Client instance with the provided values. -// - clientName - the fully qualified name of the client ("package.Client"); this is used by the tracing provider when creating spans +// - clientName - the fully qualified name of the client ("module/package.Client"); this is used by the telemetry policy and tracing provider. +// if module and package are the same value, the "module/" prefix can be omitted. // - moduleVersion - the semantic version of the containing module; used by the telemetry policy // - plOpts - pipeline configuration options; can be the zero-value // - options - optional client configurations; pass nil to accept the default values func NewClient(clientName, moduleVersion string, plOpts runtime.PipelineOptions, options *ClientOptions) (*Client, error) { - pkg, err := shared.ExtractPackageName(clientName) + mod, client, err := shared.ExtractModuleName(clientName) if err != nil { return nil, err } @@ -96,9 +97,9 @@ func NewClient(clientName, moduleVersion string, plOpts runtime.PipelineOptions, } } - pl := runtime.NewPipeline(pkg, moduleVersion, plOpts, options) + pl := runtime.NewPipeline(mod, moduleVersion, plOpts, options) - tr := options.TracingProvider.NewTracer(clientName, moduleVersion) + tr := options.TracingProvider.NewTracer(client, moduleVersion) return &Client{pl: pl, tr: tr}, nil } diff --git a/sdk/azcore/internal/shared/constants.go b/sdk/azcore/internal/shared/constants.go index 681167bcba57..269a831ed178 100644 --- a/sdk/azcore/internal/shared/constants.go +++ b/sdk/azcore/internal/shared/constants.go @@ -32,5 +32,5 @@ const ( Module = "azcore" // Version is the semantic version (see http://semver.org) of this module. - Version = "v1.6.0" + Version = "v1.6.1" ) diff --git a/sdk/azcore/internal/shared/shared.go b/sdk/azcore/internal/shared/shared.go index 930ab8c83999..db0aaa7cb956 100644 --- a/sdk/azcore/internal/shared/shared.go +++ b/sdk/azcore/internal/shared/shared.go @@ -13,7 +13,6 @@ import ( "reflect" "regexp" "strconv" - "strings" "time" ) @@ -79,14 +78,26 @@ func ValidateModVer(moduleVersion string) error { return nil } -// ExtractPackageName returns "package" from "package.Client". +// ExtractModuleName returns "module", "package.Client" from "module/package.Client" or +// "package", "package.Client" from "package.Client" when there's no "module/" prefix. // If clientName is malformed, an error is returned. -func ExtractPackageName(clientName string) (string, error) { - pkg, client, ok := strings.Cut(clientName, ".") - if !ok { - return "", fmt.Errorf("missing . in clientName %s", clientName) - } else if pkg == "" || client == "" { - return "", fmt.Errorf("malformed clientName %s", clientName) +func ExtractModuleName(clientName string) (string, string, error) { + // uses unnamed capturing for "module", "package.Client", and "package" + regex, err := regexp.Compile(`^(?:([a-z0-9]+)/)?(([a-z0-9]+)\.(?:[A-Za-z0-9]+))$`) + if err != nil { + return "", "", err } - return pkg, nil + + matches := regex.FindStringSubmatch(clientName) + if len(matches) < 4 { + return "", "", fmt.Errorf("malformed clientName %s", clientName) + } + + // the first match is the entire string, the second is "module", the third is + // "package.Client" and the fourth is "package". + // if there was no "module/" prefix, the second match will be the empty string + if matches[1] != "" { + return matches[1], matches[2], nil + } + return matches[3], matches[2], nil } diff --git a/sdk/azcore/internal/shared/shared_test.go b/sdk/azcore/internal/shared/shared_test.go index d868bc6e035d..f283d8921f31 100644 --- a/sdk/azcore/internal/shared/shared_test.go +++ b/sdk/azcore/internal/shared/shared_test.go @@ -85,24 +85,54 @@ func TestValidateModVer(t *testing.T) { require.Error(t, ValidateModVer("v1.2")) } -func TestExtractPackageName(t *testing.T) { - pkg, err := ExtractPackageName("package.Client") +func TestExtractModuleName(t *testing.T) { + mod, client, err := ExtractModuleName("module/package.Client") require.NoError(t, err) - require.Equal(t, "package", pkg) + require.Equal(t, "module", mod) + require.Equal(t, "package.Client", client) - pkg, err = ExtractPackageName("malformed") + mod, client, err = ExtractModuleName("malformed/") require.Error(t, err) - require.Empty(t, pkg) + require.Empty(t, mod) + require.Empty(t, client) - pkg, err = ExtractPackageName(".malformed") + mod, client, err = ExtractModuleName("malformed/malformed") require.Error(t, err) - require.Empty(t, pkg) + require.Empty(t, mod) + require.Empty(t, client) - pkg, err = ExtractPackageName("malformed.") + mod, client, err = ExtractModuleName("malformed/malformed.") require.Error(t, err) - require.Empty(t, pkg) + require.Empty(t, mod) + require.Empty(t, client) - pkg, err = ExtractPackageName("") + mod, client, err = ExtractModuleName("malformed/.malformed") require.Error(t, err) - require.Empty(t, pkg) + require.Empty(t, mod) + require.Empty(t, client) + + mod, client, err = ExtractModuleName("package.Client") + require.NoError(t, err) + require.Equal(t, "package", mod) + require.Equal(t, "package.Client", client) + + mod, client, err = ExtractModuleName("malformed") + require.Error(t, err) + require.Empty(t, mod) + require.Empty(t, client) + + mod, client, err = ExtractModuleName(".malformed") + require.Error(t, err) + require.Empty(t, mod) + require.Empty(t, client) + + mod, client, err = ExtractModuleName("malformed.") + require.Error(t, err) + require.Empty(t, mod) + require.Empty(t, client) + + mod, client, err = ExtractModuleName("") + require.Error(t, err) + require.Empty(t, mod) + require.Empty(t, client) } diff --git a/sdk/azcore/runtime/policy_retry.go b/sdk/azcore/runtime/policy_retry.go index 5f52ba75b459..e0c5929f3b70 100644 --- a/sdk/azcore/runtime/policy_retry.go +++ b/sdk/azcore/runtime/policy_retry.go @@ -125,7 +125,8 @@ func (p *retryPolicy) Do(req *policy.Request) (resp *http.Response, err error) { } if options.TryTimeout == 0 { - resp, err = req.Next() + clone := req.Clone(req.Raw().Context()) + resp, err = clone.Next() } else { // Set the per-try time for this particular retry operation and then Do the operation. tryCtx, tryCancel := context.WithTimeout(req.Raw().Context(), options.TryTimeout) diff --git a/sdk/azcore/runtime/policy_retry_test.go b/sdk/azcore/runtime/policy_retry_test.go index 61ce081b4204..228d0931a025 100644 --- a/sdk/azcore/runtime/policy_retry_test.go +++ b/sdk/azcore/runtime/policy_retry_test.go @@ -810,6 +810,39 @@ func TestRetryableRequestBodyWithCloser(t *testing.T) { require.True(t, tr.closeCalled) } +func TestRetryPolicySuccessWithRetryPreserveHeaders(t *testing.T) { + srv, close := mock.NewServer() + defer close() + srv.AppendResponse(mock.WithStatusCode(http.StatusRequestTimeout)) + srv.AppendResponse() + pl := exported.NewPipeline(srv, NewRetryPolicy(testRetryOptions()), exported.PolicyFunc(challengeLikePolicy)) + req, err := NewRequest(context.Background(), http.MethodGet, srv.URL()) + require.NoError(t, err) + body := newRewindTrackingBody("stuff") + require.NoError(t, req.SetBody(body, "text/plain")) + resp, err := pl.Do(req) + require.NoError(t, err) + require.EqualValues(t, http.StatusOK, resp.StatusCode) + require.EqualValues(t, 2, srv.Requests()) + require.EqualValues(t, 1, body.rcount) + require.True(t, body.closed) +} + +func challengeLikePolicy(req *policy.Request) (*http.Response, error) { + if req.Body() == nil { + return nil, errors.New("request body wasn't restored") + } + if req.Raw().Header.Get("content-type") != "text/plain" { + return nil, errors.New("content-type header wasn't restored") + } + + // remove the body and header. the retry policy should restore them + if err := req.SetBody(nil, ""); err != nil { + return nil, err + } + return req.Next() +} + func newRewindTrackingBody(s string) *rewindTrackingBody { // there are two rewinds that happen before rewinding for a retry // 1. to get the body's size in SetBody()