Skip to content

Commit 298093d

Browse files
authored
Pass context all the way from API controller through pyramid (treeverse#1259)
* Plumb context through tier_fs, this time for real Get context when reading from surrounding call, not from setup time. * Use context from dependencies in API controller Fixes treeverse#1248.
1 parent ab82a01 commit 298093d

File tree

6 files changed

+68
-58
lines changed

6 files changed

+68
-58
lines changed

api/api_controller.go

Lines changed: 29 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,6 @@ func NewController(cataloger catalog.Cataloger, auth auth.Service, blockAdapter
120120
return c
121121
}
122122

123-
func (c *Controller) Context() context.Context {
124-
if c.deps.ctx != nil {
125-
return c.deps.ctx
126-
}
127-
return context.Background()
128-
}
129-
130123
// Configure attaches our API operations to a generated swagger API stub
131124
// Adding new handlers requires also adding them here so that the generated server will use them
132125
func (c *Controller) Configure(api *operations.LakefsAPI) {
@@ -312,7 +305,7 @@ func (c *Controller) ListRepositoriesHandler() repositories.ListRepositoriesHand
312305

313306
after, amount := getPaginationParams(params.After, params.Amount)
314307

315-
repos, hasMore, err := deps.Cataloger.ListRepositories(c.Context(), amount, after)
308+
repos, hasMore, err := deps.Cataloger.ListRepositories(deps.ctx, amount, after)
316309
if err != nil {
317310
return repositories.NewListRepositoriesDefault(http.StatusInternalServerError).
318311
WithPayload(responseError("error listing repositories: %s", err))
@@ -372,7 +365,7 @@ func (c *Controller) GetRepoHandler() repositories.GetRepositoryHandler {
372365
return repositories.NewGetRepositoryUnauthorized().WithPayload(responseErrorFrom(err))
373366
}
374367
deps.LogAction("get_repo")
375-
repo, err := deps.Cataloger.GetRepository(c.Context(), params.Repository)
368+
repo, err := deps.Cataloger.GetRepository(deps.ctx, params.Repository)
376369
if errors.Is(err, db.ErrNotFound) {
377370
return repositories.NewGetRepositoryNotFound().
378371
WithPayload(responseError("repository not found"))
@@ -404,7 +397,7 @@ func (c *Controller) GetCommitHandler() commits.GetCommitHandler {
404397
return commits.NewGetCommitUnauthorized().WithPayload(responseErrorFrom(err))
405398
}
406399
deps.LogAction("get_commit")
407-
commit, err := deps.Cataloger.GetCommit(c.Context(), params.Repository, params.CommitID)
400+
commit, err := deps.Cataloger.GetCommit(deps.ctx, params.Repository, params.CommitID)
408401
if errors.Is(err, db.ErrNotFound) {
409402
return commits.NewGetCommitNotFound().WithPayload(responseError("commit not found"))
410403
}
@@ -440,7 +433,7 @@ func (c *Controller) CommitHandler() commits.CommitHandler {
440433
}
441434
committer := userModel.Username
442435
commitMessage := swag.StringValue(params.Commit.Message)
443-
commit, err := deps.Cataloger.Commit(c.Context(), params.Repository,
436+
commit, err := deps.Cataloger.Commit(deps.ctx, params.Repository,
444437
params.Branch, commitMessage, committer, params.Commit.Metadata)
445438
if err != nil {
446439
return commits.NewCommitDefault(http.StatusInternalServerError).WithPayload(responseErrorFrom(err))
@@ -472,7 +465,7 @@ func (c *Controller) CommitsGetBranchCommitLogHandler() commits.GetBranchCommitL
472465

473466
after, amount := getPaginationParams(params.After, params.Amount)
474467
// get commit log
475-
commitLog, hasMore, err := cataloger.ListCommits(c.Context(), params.Repository, params.Branch, after, amount)
468+
commitLog, hasMore, err := cataloger.ListCommits(deps.ctx, params.Repository, params.Branch, after, amount)
476469
switch {
477470
case errors.Is(err, catalog.ErrBranchNotFound) || errors.Is(err, graveler.ErrBranchNotFound):
478471
return commits.NewGetBranchCommitLogNotFound().WithPayload(responseError("branch '%s' not found.", params.Branch))
@@ -552,7 +545,7 @@ func (c *Controller) CreateRepositoryHandler() repositories.CreateRepositoryHand
552545
return repositories.NewCreateRepositoryBadRequest().
553546
WithPayload(responseError("error creating repository: could not access storage namespace"))
554547
}
555-
repo, err := deps.Cataloger.CreateRepository(c.Context(),
548+
repo, err := deps.Cataloger.CreateRepository(deps.ctx,
556549
swag.StringValue(params.Repository.Name),
557550
swag.StringValue(params.Repository.StorageNamespace),
558551
params.Repository.DefaultBranch)
@@ -582,7 +575,7 @@ func (c *Controller) DeleteRepositoryHandler() repositories.DeleteRepositoryHand
582575
return repositories.NewDeleteRepositoryUnauthorized().WithPayload(responseErrorFrom(err))
583576
}
584577
deps.LogAction("delete_repo")
585-
err = deps.Cataloger.DeleteRepository(c.Context(), params.Repository)
578+
err = deps.Cataloger.DeleteRepository(deps.ctx, params.Repository)
586579
if errors.Is(err, db.ErrNotFound) {
587580
return repositories.NewDeleteRepositoryNotFound().WithPayload(responseError("repository not found"))
588581
}
@@ -610,7 +603,7 @@ func (c *Controller) ListBranchesHandler() branches.ListBranchesHandler {
610603

611604
after, amount := getPaginationParams(params.After, params.Amount)
612605

613-
res, hasMore, err := cataloger.ListBranches(c.Context(), params.Repository, "", amount, after)
606+
res, hasMore, err := cataloger.ListBranches(deps.ctx, params.Repository, "", amount, after)
614607
if err != nil {
615608
return branches.NewListBranchesDefault(http.StatusInternalServerError).
616609
WithPayload(responseError("could not list branches: %s", err))
@@ -654,7 +647,7 @@ func (c *Controller) GetBranchHandler() branches.GetBranchHandler {
654647
return branches.NewGetBranchUnauthorized().WithPayload(responseErrorFrom(err))
655648
}
656649
deps.LogAction("get_branch")
657-
reference, err := deps.Cataloger.GetBranchReference(c.Context(), params.Repository, params.Branch)
650+
reference, err := deps.Cataloger.GetBranchReference(deps.ctx, params.Repository, params.Branch)
658651

659652
switch {
660653
case errors.Is(err, catalog.ErrBranchNotFound) || errors.Is(err, graveler.ErrBranchNotFound):
@@ -685,7 +678,7 @@ func (c *Controller) CreateBranchHandler() branches.CreateBranchHandler {
685678
deps.LogAction("create_branch")
686679
cataloger := deps.Cataloger
687680
sourceRef := swag.StringValue(params.Branch.Source)
688-
commitLog, err := cataloger.CreateBranch(c.Context(), repository, branch, sourceRef)
681+
commitLog, err := cataloger.CreateBranch(deps.ctx, repository, branch, sourceRef)
689682
if err != nil {
690683
return branches.NewCreateBranchDefault(http.StatusInternalServerError).WithPayload(responseErrorFrom(err))
691684
}
@@ -706,7 +699,7 @@ func (c *Controller) DeleteBranchHandler() branches.DeleteBranchHandler {
706699
}
707700
deps.LogAction("delete_branch")
708701
cataloger := deps.Cataloger
709-
err = cataloger.DeleteBranch(c.Context(), params.Repository, params.Branch)
702+
err = cataloger.DeleteBranch(deps.ctx, params.Repository, params.Branch)
710703
switch {
711704
case errors.Is(err, catalog.ErrBranchNotFound) || errors.Is(err, graveler.ErrBranchNotFound):
712705
return branches.NewDeleteBranchNotFound().WithPayload(responseError("branch '%s' not found.", params.Branch))
@@ -742,7 +735,7 @@ func (c *Controller) MergeMergeIntoBranchHandler() refs.MergeIntoBranchHandler {
742735
message = params.Merge.Message
743736
metadata = params.Merge.Metadata
744737
}
745-
res, err := deps.Cataloger.Merge(c.Context(),
738+
res, err := deps.Cataloger.Merge(deps.ctx,
746739
params.Repository, params.SourceRef, params.DestinationRef,
747740
userModel.Username,
748741
message,
@@ -808,7 +801,7 @@ func (c *Controller) BranchesDiffBranchHandler() branches.DiffBranchHandler {
808801
cataloger := deps.Cataloger
809802
limit := int(swag.Int64Value(params.Amount))
810803
after := swag.StringValue(params.After)
811-
diff, hasMore, err := cataloger.DiffUncommitted(c.Context(), params.Repository, params.Branch, limit, after)
804+
diff, hasMore, err := cataloger.DiffUncommitted(deps.ctx, params.Repository, params.Branch, limit, after)
812805
if err != nil {
813806
return branches.NewDiffBranchDefault(http.StatusInternalServerError).
814807
WithPayload(responseError("could not diff branch: %s", err))
@@ -849,7 +842,7 @@ func (c *Controller) RefsDiffRefsHandler() refs.DiffRefsHandler {
849842
cataloger := deps.Cataloger
850843
limit := int(swag.Int64Value(params.Amount))
851844
after := swag.StringValue(params.After)
852-
diff, hasMore, err := cataloger.Diff(c.Context(), params.Repository, params.LeftRef, params.RightRef, catalog.DiffParams{
845+
diff, hasMore, err := cataloger.Diff(deps.ctx, params.Repository, params.LeftRef, params.RightRef, catalog.DiffParams{
853846
Limit: limit,
854847
After: after,
855848
})
@@ -895,15 +888,15 @@ func (c *Controller) ObjectsStatObjectHandler() objects.StatObjectHandler {
895888
deps.LogAction("stat_object")
896889
cataloger := deps.Cataloger
897890

898-
entry, err := cataloger.GetEntry(c.Context(), params.Repository, params.Ref, params.Path, catalog.GetEntryParams{ReturnExpired: true})
891+
entry, err := cataloger.GetEntry(deps.ctx, params.Repository, params.Ref, params.Path, catalog.GetEntryParams{ReturnExpired: true})
899892
if errors.Is(err, db.ErrNotFound) {
900893
return objects.NewStatObjectNotFound().WithPayload(responseError("resource not found"))
901894
}
902895
if err != nil {
903896
return objects.NewStatObjectDefault(http.StatusInternalServerError).WithPayload(responseErrorFrom(err))
904897
}
905898

906-
repo, err := cataloger.GetRepository(c.Context(), params.Repository)
899+
repo, err := cataloger.GetRepository(deps.ctx, params.Repository)
907900
if err != nil {
908901
return objects.NewStatObjectDefault(http.StatusInternalServerError).WithPayload(responseErrorFrom(err))
909902
}
@@ -945,15 +938,15 @@ func (c *Controller) ObjectsGetUnderlyingPropertiesHandler() objects.GetUnderlyi
945938
cataloger := deps.Cataloger
946939

947940
// read repo
948-
repo, err := cataloger.GetRepository(c.Context(), params.Repository)
941+
repo, err := cataloger.GetRepository(deps.ctx, params.Repository)
949942
if errors.Is(err, db.ErrNotFound) {
950943
return objects.NewGetUnderlyingPropertiesNotFound().WithPayload(responseError("resource not found"))
951944
}
952945
if err != nil {
953946
return objects.NewGetUnderlyingPropertiesDefault(http.StatusInternalServerError).WithPayload(responseErrorFrom(err))
954947
}
955948

956-
entry, err := cataloger.GetEntry(c.Context(), params.Repository, params.Ref, params.Path, catalog.GetEntryParams{})
949+
entry, err := cataloger.GetEntry(deps.ctx, params.Repository, params.Ref, params.Path, catalog.GetEntryParams{})
957950
if errors.Is(err, db.ErrNotFound) {
958951
return objects.NewGetUnderlyingPropertiesNotFound().WithPayload(responseError("resource not found"))
959952
}
@@ -989,7 +982,7 @@ func (c *Controller) ObjectsGetObjectHandler() objects.GetObjectHandler {
989982
cataloger := deps.Cataloger
990983

991984
// read repo
992-
repo, err := cataloger.GetRepository(c.Context(), params.Repository)
985+
repo, err := cataloger.GetRepository(deps.ctx, params.Repository)
993986
if errors.Is(err, db.ErrNotFound) {
994987
return objects.NewGetObjectNotFound().WithPayload(responseError("resource not found"))
995988
}
@@ -998,7 +991,7 @@ func (c *Controller) ObjectsGetObjectHandler() objects.GetObjectHandler {
998991
}
999992

1000993
// read the FS entry
1001-
entry, err := cataloger.GetEntry(c.Context(), params.Repository, params.Ref, params.Path, catalog.GetEntryParams{ReturnExpired: true})
994+
entry, err := cataloger.GetEntry(deps.ctx, params.Repository, params.Ref, params.Path, catalog.GetEntryParams{ReturnExpired: true})
1002995
if errors.Is(err, db.ErrNotFound) {
1003996
return objects.NewGetObjectNotFound().WithPayload(responseError("resource not found"))
1004997
}
@@ -1042,7 +1035,7 @@ func (c *Controller) MetadataCreateSymlinkHandler() metadataop.CreateSymlinkHand
10421035
cataloger := deps.Cataloger
10431036

10441037
// read repo
1045-
repo, err := cataloger.GetRepository(c.Context(), params.Repository)
1038+
repo, err := cataloger.GetRepository(deps.ctx, params.Repository)
10461039
if errors.Is(err, db.ErrNotFound) {
10471040
return metadataop.NewCreateSymlinkNotFound().WithPayload(responseError("resource not found"))
10481041
}
@@ -1057,7 +1050,7 @@ func (c *Controller) MetadataCreateSymlinkHandler() metadataop.CreateSymlinkHand
10571050
hasMore := true
10581051
for hasMore {
10591052
entries, hasMore, err = cataloger.ListEntries(
1060-
c.Context(),
1053+
deps.ctx,
10611054
params.Repository,
10621055
params.Branch,
10631056
swag.StringValue(params.Location),
@@ -1136,7 +1129,7 @@ func (c *Controller) ObjectsListObjectsHandler() objects.ListObjectsHandler {
11361129

11371130
delimiter := catalog.DefaultPathDelimiter
11381131
res, hasMore, err := cataloger.ListEntries(
1139-
c.Context(),
1132+
deps.ctx,
11401133
params.Repository,
11411134
params.Ref,
11421135
swag.StringValue(params.Prefix),
@@ -1151,7 +1144,7 @@ func (c *Controller) ObjectsListObjectsHandler() objects.ListObjectsHandler {
11511144
WithPayload(responseError("error while listing objects: %s", err))
11521145
}
11531146

1154-
repo, err := cataloger.GetRepository(c.Context(), params.Repository)
1147+
repo, err := cataloger.GetRepository(deps.ctx, params.Repository)
11551148
if err != nil {
11561149
return objects.NewStatObjectDefault(http.StatusInternalServerError).WithPayload(responseErrorFrom(err))
11571150
}
@@ -1215,15 +1208,15 @@ func (c *Controller) ObjectsUploadObjectHandler() objects.UploadObjectHandler {
12151208
deps.LogAction("put_object")
12161209
cataloger := deps.Cataloger
12171210

1218-
repo, err := cataloger.GetRepository(c.Context(), params.Repository)
1211+
repo, err := cataloger.GetRepository(deps.ctx, params.Repository)
12191212
if errors.Is(err, db.ErrNotFound) {
12201213
return objects.NewUploadObjectNotFound().WithPayload(responseError("repository not found"))
12211214
}
12221215
if err != nil {
12231216
return objects.NewUploadObjectDefault(http.StatusInternalServerError).WithPayload(responseErrorFrom(err))
12241217
}
12251218
// check if branch exists - it is still a possibility, but we don't want to upload large object when the branch was not there in the first place
1226-
branchExists, err := cataloger.BranchExists(c.Context(), params.Repository, params.Branch)
1219+
branchExists, err := cataloger.BranchExists(deps.ctx, params.Repository, params.Branch)
12271220
if err != nil {
12281221
return objects.NewUploadObjectDefault(http.StatusInternalServerError).WithPayload(responseErrorFrom(err))
12291222
}
@@ -1252,7 +1245,7 @@ func (c *Controller) ObjectsUploadObjectHandler() objects.UploadObjectHandler {
12521245
Size: blob.Size,
12531246
Checksum: blob.Checksum,
12541247
}
1255-
err = cataloger.CreateEntry(c.Context(), repo.Name, params.Branch, entry,
1248+
err = cataloger.CreateEntry(deps.ctx, repo.Name, params.Branch, entry,
12561249
catalog.CreateEntryParams{
12571250
Dedup: catalog.DedupParams{
12581251
ID: blob.DedupID,
@@ -1296,7 +1289,7 @@ func (c *Controller) ObjectsDeleteObjectHandler() objects.DeleteObjectHandler {
12961289
deps.LogAction("delete_object")
12971290
cataloger := deps.Cataloger
12981291

1299-
err = cataloger.DeleteEntry(c.Context(), params.Repository, params.Branch, params.Path)
1292+
err = cataloger.DeleteEntry(deps.ctx, params.Repository, params.Branch, params.Path)
13001293
if errors.Is(err, db.ErrNotFound) {
13011294
return objects.NewDeleteObjectNotFound().WithPayload(responseError("resource not found"))
13021295
}
@@ -1322,7 +1315,7 @@ func (c *Controller) RevertBranchHandler() branches.RevertBranchHandler {
13221315
deps.LogAction("revert_branch")
13231316
cataloger := deps.Cataloger
13241317

1325-
ctx := c.Context()
1318+
ctx := deps.ctx
13261319
switch swag.StringValue(params.Revert.Type) {
13271320
case models.RevertCreationTypeCommit:
13281321
err = cataloger.RollbackCommit(ctx, params.Repository, params.Branch, params.Revert.Commit)

cache/reference_counted.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cache
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"hash/maphash"
@@ -20,7 +21,7 @@ type Derefer func() error
2021
// not a Cache.
2122
type CacheWithDisposal interface {
2223
Name() string
23-
GetOrSet(k interface{}, setFn SetFn) (v interface{}, disposer Derefer, err error)
24+
GetOrSet(ctx context.Context, k interface{}, setFn SetFn) (v interface{}, disposer Derefer, err error)
2425
}
2526

2627
type ParamsWithDisposal struct {
@@ -57,15 +58,15 @@ func (s *shardedCacheWithDisposal) Name() string {
5758
return s.name
5859
}
5960

60-
func (s *shardedCacheWithDisposal) GetOrSet(k interface{}, setFn SetFn) (interface{}, Derefer, error) {
61+
func (s *shardedCacheWithDisposal) GetOrSet(ctx context.Context, k interface{}, setFn SetFn) (interface{}, Derefer, error) {
6162
hash := maphash.Hash{}
6263
hash.SetSeed(s.Seed)
6364
// (Explicitly ignore return value from hash.WriteString: its godoc says "it always
6465
// writes all of s and never fails; the count and error result are for implementing
6566
// io.StringWriter."
6667
_, _ = hash.WriteString(fmt.Sprintf("%+#v", k))
6768
hashVal := hash.Sum64() % uint64(len(s.Shards))
68-
return s.Shards[hashVal].GetOrSet(k, setFn)
69+
return s.Shards[hashVal].GetOrSet(ctx, k, setFn)
6970
}
7071

7172
// SingleThreadedCacheWithDisposal is a CacheWithDisposal that uses a single critical section
@@ -127,7 +128,7 @@ func NewSingleThreadedCacheWithDisposal(p ParamsWithDisposal) *SingleThreadedCac
127128
return ret
128129
}
129130

130-
func (c *SingleThreadedCacheWithDisposal) GetOrSet(k interface{}, setFn SetFn) (interface{}, Derefer, error) {
131+
func (c *SingleThreadedCacheWithDisposal) GetOrSet(ctx context.Context, k interface{}, setFn SetFn) (interface{}, Derefer, error) {
131132
c.mu.Lock()
132133
defer c.mu.Unlock()
133134
var entry *cacheEntry

cache/reference_counted_test.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cache_test
22

33
import (
4+
"context"
45
"fmt"
56
"sync"
67
"sync/atomic"
@@ -22,6 +23,7 @@ func TestCacheWithDisposal(t *testing.T) {
2223
disposed int32
2324
live int32
2425
}
26+
ctx := context.Background()
2527
var elements [size]*record
2628
ch := make(chan error, parallelism*2)
2729
allErrs := make(chan []error)
@@ -60,6 +62,7 @@ func TestCacheWithDisposal(t *testing.T) {
6062
for j := 0; j < repeats; j++ {
6163
k := j % size
6264
v, release, err := c.GetOrSet(
65+
ctx,
6366
k, func() (interface{}, error) {
6467
e := &record{
6568
key: k,

graveler/sstable/cache.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ type namespaceID struct {
111111
}
112112

113113
func (c *lruCache) GetOrOpen(ctx context.Context, namespace string, id committed.ID) (*sstable.Reader, Derefer, error) {
114-
e, derefer, err := c.c.GetOrSet(namespaceID{namespace, id}, func() (interface{}, error) {
114+
e, derefer, err := c.c.GetOrSet(ctx, namespaceID{namespace, id}, func() (interface{}, error) {
115115
r, err := c.open(ctx, namespace, string(id))
116116
if err != nil {
117117
return nil, fmt.Errorf("open SSTable %s after fetch from next tier: %w", id, err)

logging/logger.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ type logrusEntryWrapper struct {
5252
}
5353

5454
func (l *logrusEntryWrapper) WithContext(ctx context.Context) Logger {
55-
return &logrusEntryWrapper{l.e.WithContext(ctx)}
55+
return addFromContext(
56+
&logrusEntryWrapper{l.e.WithContext(ctx)},
57+
ctx,
58+
)
5659
}
5760

5861
func (l *logrusEntryWrapper) WithField(key string, value interface{}) Logger {
@@ -151,14 +154,17 @@ func Default() Logger {
151154
}
152155
}
153156

154-
func FromContext(ctx context.Context) Logger {
155-
log := Default()
157+
func addFromContext(log Logger, ctx context.Context) Logger {
156158
fields := ctx.Value(LogFieldsContextKey)
157-
if fields != nil {
158-
loggerFields := fields.(Fields)
159-
return log.WithFields(loggerFields)
159+
if fields == nil {
160+
return log
160161
}
161-
return log
162+
loggerFields := fields.(Fields)
163+
return log.WithFields(loggerFields)
164+
}
165+
166+
func FromContext(ctx context.Context) Logger {
167+
return addFromContext(Default(), ctx)
162168
}
163169

164170
func AddFields(ctx context.Context, fields Fields) context.Context {

0 commit comments

Comments
 (0)