Skip to content

Commit 8da3601

Browse files
committed
fixups
1 parent a305e70 commit 8da3601

File tree

5 files changed

+136
-17
lines changed

5 files changed

+136
-17
lines changed

coderd/database/dbmem/dbmem.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1187,7 +1187,7 @@ func (q *FakeQuerier) CustomRoles(_ context.Context, arg database.CustomRolesPar
11871187
role := role
11881188
if len(arg.LookupRoles) > 0 {
11891189
if !slices.ContainsFunc(arg.LookupRoles, func(pair database.NameOrganizationPair) bool {
1190-
if !strings.EqualFold(pair.Name, role.Name) {
1190+
if pair.Name != role.Name {
11911191
return false
11921192
}
11931193

coderd/database/querier_test.go

Lines changed: 118 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,7 @@ func TestReadCustomRoles(t *testing.T) {
527527
sqlDB := testSQLDB(t)
528528
err := migrations.Up(sqlDB)
529529
require.NoError(t, err)
530+
530531
db := database.New(sqlDB)
531532
ctx := testutil.Context(t, testutil.WaitLong)
532533

@@ -536,28 +537,44 @@ func TestReadCustomRoles(t *testing.T) {
536537
orgIDs[i] = uuid.New()
537538
}
538539

539-
roles := make([]database.CustomRole, 0)
540+
allRoles := make([]database.CustomRole, 0)
541+
siteRoles := make([]database.CustomRole, 0)
542+
orgRoles := make([]database.CustomRole, 0)
540543
for i := 0; i < 15; i++ {
541-
orgID := uuid.NullUUID{}
542-
543-
orgID = uuid.NullUUID{
544+
orgID := uuid.NullUUID{
544545
UUID: orgIDs[i%len(orgIDs)],
545546
Valid: true,
546547
}
548+
if i%4 == 0 {
549+
// Some should be site wide
550+
orgID = uuid.NullUUID{}
551+
}
547552

548553
role, err := db.UpsertCustomRole(ctx, database.UpsertCustomRoleParams{
549554
Name: fmt.Sprintf("role-%d", i),
550555
OrganizationID: orgID,
551556
})
552557
require.NoError(t, err)
553-
roles = append(roles, role)
558+
allRoles = append(allRoles, role)
559+
if orgID.Valid {
560+
orgRoles = append(orgRoles, role)
561+
} else {
562+
siteRoles = append(siteRoles, role)
563+
}
554564
}
555565

556566
// normalizedRoleName allows for the simple ElementsMatch to work properly.
557567
normalizedRoleName := func(role database.CustomRole) string {
558568
return role.Name + ":" + role.OrganizationID.UUID.String()
559569
}
560570

571+
roleToLookup := func(role database.CustomRole) database.NameOrganizationPair {
572+
return database.NameOrganizationPair{
573+
Name: role.Name,
574+
OrganizationID: role.OrganizationID.UUID,
575+
}
576+
}
577+
561578
testCases := []struct {
562579
Name string
563580
Params database.CustomRolesParams
@@ -598,17 +615,108 @@ func TestReadCustomRoles(t *testing.T) {
598615
},
599616
},
600617
{
601-
Name: "SpecificRole",
618+
Name: "SpecificOrgRole",
619+
Params: database.CustomRolesParams{
620+
LookupRoles: []database.NameOrganizationPair{
621+
{
622+
Name: orgRoles[0].Name,
623+
OrganizationID: orgRoles[0].OrganizationID.UUID,
624+
},
625+
},
626+
},
627+
Match: func(role database.CustomRole) bool {
628+
return role.Name == orgRoles[0].Name && role.OrganizationID.UUID == orgRoles[0].OrganizationID.UUID
629+
},
630+
},
631+
{
632+
Name: "SpecificSiteRole",
633+
Params: database.CustomRolesParams{
634+
LookupRoles: []database.NameOrganizationPair{
635+
{
636+
Name: siteRoles[0].Name,
637+
OrganizationID: siteRoles[0].OrganizationID.UUID,
638+
},
639+
},
640+
},
641+
Match: func(role database.CustomRole) bool {
642+
return role.Name == siteRoles[0].Name && role.OrganizationID.UUID == siteRoles[0].OrganizationID.UUID
643+
},
644+
},
645+
{
646+
Name: "FewSpecificRoles",
647+
Params: database.CustomRolesParams{
648+
LookupRoles: []database.NameOrganizationPair{
649+
{
650+
Name: orgRoles[0].Name,
651+
OrganizationID: orgRoles[0].OrganizationID.UUID,
652+
},
653+
{
654+
Name: orgRoles[1].Name,
655+
OrganizationID: orgRoles[1].OrganizationID.UUID,
656+
},
657+
{
658+
Name: siteRoles[0].Name,
659+
OrganizationID: siteRoles[0].OrganizationID.UUID,
660+
},
661+
},
662+
},
663+
Match: func(role database.CustomRole) bool {
664+
return (role.Name == orgRoles[0].Name && role.OrganizationID.UUID == orgRoles[0].OrganizationID.UUID) ||
665+
(role.Name == orgRoles[1].Name && role.OrganizationID.UUID == orgRoles[1].OrganizationID.UUID) ||
666+
(role.Name == siteRoles[0].Name && role.OrganizationID.UUID == siteRoles[0].OrganizationID.UUID)
667+
},
668+
},
669+
{
670+
Name: "AllRolesByLookup",
671+
Params: database.CustomRolesParams{
672+
LookupRoles: db2sdk.List(allRoles, roleToLookup),
673+
},
674+
Match: func(role database.CustomRole) bool {
675+
return true
676+
},
677+
},
678+
{
679+
Name: "NotExists",
602680
Params: database.CustomRolesParams{
603681
LookupRoles: []database.NameOrganizationPair{
604682
{
605-
Name: roles[0].Name,
606-
OrganizationID: roles[0].OrganizationID.UUID,
683+
Name: "not-exists",
684+
OrganizationID: uuid.New(),
685+
},
686+
{
687+
Name: "not-exists",
688+
OrganizationID: uuid.Nil,
689+
},
690+
},
691+
},
692+
Match: func(role database.CustomRole) bool {
693+
return false
694+
},
695+
},
696+
{
697+
Name: "Mixed",
698+
Params: database.CustomRolesParams{
699+
LookupRoles: []database.NameOrganizationPair{
700+
{
701+
Name: "not-exists",
702+
OrganizationID: uuid.New(),
703+
},
704+
{
705+
Name: "not-exists",
706+
OrganizationID: uuid.Nil,
707+
},
708+
{
709+
Name: orgRoles[0].Name,
710+
OrganizationID: orgRoles[0].OrganizationID.UUID,
711+
},
712+
{
713+
Name: siteRoles[0].Name,
607714
},
608715
},
609716
},
610717
Match: func(role database.CustomRole) bool {
611-
return role.Name == roles[0].Name && role.OrganizationID.UUID == roles[0].OrganizationID.UUID
718+
return (role.Name == orgRoles[0].Name && role.OrganizationID.UUID == orgRoles[0].OrganizationID.UUID) ||
719+
(role.Name == siteRoles[0].Name && role.OrganizationID.UUID == siteRoles[0].OrganizationID.UUID)
612720
},
613721
},
614722
}
@@ -619,12 +727,11 @@ func TestReadCustomRoles(t *testing.T) {
619727
t.Run(tc.Name, func(t *testing.T) {
620728
t.Parallel()
621729

622-
t.Log(tc.Params)
623730
ctx := testutil.Context(t, testutil.WaitLong)
624731
found, err := db.CustomRoles(ctx, tc.Params)
625732
require.NoError(t, err)
626733
filtered := make([]database.CustomRole, 0)
627-
for _, role := range roles {
734+
for _, role := range allRoles {
628735
if tc.Match(role) {
629736
filtered = append(filtered, role)
630737
}

coderd/database/queries.sql.go

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/roles.sql

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ WHERE
1414
END
1515
-- This allows fetching all roles, or just site wide roles
1616
AND CASE WHEN @exclude_org_roles :: boolean THEN
17-
organization_id IS null OR true
17+
organization_id IS null
1818
ELSE true
1919
END
2020
-- Allows fetching all roles to a particular organization
2121
AND CASE WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
22-
organization_id = @organization_id OR true
22+
organization_id = @organization_id
2323
ELSE true
2424
END
2525
;

coderd/database/types.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,21 @@ func (*NameOrganizationPair) Scan(_ interface{}) error {
155155
return xerrors.Errorf("this should never happen, type 'NameOrganizationPair' should only be used as a parameter")
156156
}
157157

158+
// Value returns the tuple **literal**
159+
// To get the literal value to return, you can use the expression syntax in a psql
160+
// shell.
161+
//
162+
// SELECT ('customrole'::text,'ece79dac-926e-44ca-9790-2ff7c5eb6e0c'::uuid);
163+
// To see 'null' option
164+
// SELECT ('customrole',null);
165+
//
166+
// This value is usually used as an array, NameOrganizationPair[]. You can see
167+
// what that literal is as well, with proper quoting.
168+
//
169+
// SELECT ARRAY[('customrole'::text,'ece79dac-926e-44ca-9790-2ff7c5eb6e0c'::uuid)];
158170
func (a NameOrganizationPair) Value() (driver.Value, error) {
159171
if a.OrganizationID == uuid.Nil {
160-
return fmt.Sprintf(`('%s', NULL)`, a.Name), nil
172+
return fmt.Sprintf(`('%s',)`, a.Name), nil
161173
}
162174

163175
return fmt.Sprintf(`(%s,%s)`, a.Name, a.OrganizationID.String()), nil

0 commit comments

Comments
 (0)