Skip to content

Commit f37d3bb

Browse files
committed
Merge branch 'network-gateway' into 'master'
fix: allow access to clones from an internal Docker network See merge request postgres-ai/database-lab!802
2 parents 08590ae + 2a76647 commit f37d3bb

File tree

4 files changed

+115
-7
lines changed

4 files changed

+115
-7
lines changed

engine/cmd/database-lab/main.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"syscall"
1919
"time"
2020

21+
"github.com/docker/docker/api/types"
2122
"github.com/docker/docker/client"
2223
"github.com/pkg/errors"
2324

@@ -124,7 +125,9 @@ func main() {
124125
}
125126

126127
// Create a cloning service to provision new clones.
127-
provisioner, err := provision.New(ctx, &cfg.Provision, dbCfg, docker, pm, engProps.InstanceID, internalNetworkID)
128+
networkGateway := getNetworkGateway(docker, internalNetworkID)
129+
130+
provisioner, err := provision.New(ctx, &cfg.Provision, dbCfg, docker, pm, engProps.InstanceID, internalNetworkID, networkGateway)
128131
if err != nil {
129132
log.Errf(errors.WithMessage(err, `error in the "provision" section of the config`).Error())
130133
}
@@ -253,6 +256,22 @@ func main() {
253256
tm.SendEvent(ctxBackground, telemetry.EngineStoppedEvent, telemetry.EngineStopped{Uptime: server.Uptime()})
254257
}
255258

259+
func getNetworkGateway(docker *client.Client, internalNetworkID string) string {
260+
gateway := ""
261+
262+
networkResource, err := docker.NetworkInspect(context.Background(), internalNetworkID, types.NetworkInspectOptions{})
263+
if err != nil {
264+
log.Err(err.Error())
265+
return gateway
266+
}
267+
268+
if len(networkResource.IPAM.Config) > 0 {
269+
gateway = networkResource.IPAM.Config[0].Gateway
270+
}
271+
272+
return gateway
273+
}
274+
256275
func getEngineProperties(ctx context.Context, docker *client.Client, cfg *config.Config) (global.EngineProps, error) {
257276
hostname := os.Getenv("HOSTNAME")
258277
if hostname == "" {

engine/internal/cloning/storage_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func newProvisioner() (*provision.Provisioner, error) {
8383
From: 1,
8484
To: 5,
8585
},
86-
}, nil, nil, nil, "instID", "nwID")
86+
}, nil, nil, nil, "instID", "nwID", "")
8787
}
8888

8989
func TestLoadingSessionState(t *testing.T) {

engine/internal/provision/mode_local.go

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"regexp"
1616
"sort"
1717
"strconv"
18+
"strings"
1819
"sync"
1920
"sync/atomic"
2021
"time"
@@ -41,6 +42,7 @@ const (
4142
maxNumberOfPortsToCheck = 5
4243
portCheckingTimeout = 3 * time.Second
4344
unknownVersion = "unknown"
45+
wildcardIP = "0.0.0.0"
4446
)
4547

4648
// PortPool describes an available port range for clones.
@@ -73,11 +75,12 @@ type Provisioner struct {
7375
pm *pool.Manager
7476
networkID string
7577
instanceID string
78+
gateway string
7679
}
7780

7881
// New creates a new Provisioner instance.
7982
func New(ctx context.Context, cfg *Config, dbCfg *resources.DB, docker *client.Client, pm *pool.Manager,
80-
instanceID, networkID string) (*Provisioner, error) {
83+
instanceID, networkID, gateway string) (*Provisioner, error) {
8184
if err := IsValidConfig(*cfg); err != nil {
8285
return nil, errors.Wrap(err, "configuration is not valid")
8386
}
@@ -93,6 +96,7 @@ func New(ctx context.Context, cfg *Config, dbCfg *resources.DB, docker *client.C
9396
pm: pm,
9497
networkID: networkID,
9598
instanceID: instanceID,
99+
gateway: gateway,
96100
ports: make([]bool, cfg.PortPool.To-cfg.PortPool.From+1),
97101
}
98102

@@ -435,7 +439,7 @@ func getLatestSnapshot(snapshots []resources.Snapshot) (*resources.Snapshot, err
435439
func (p *Provisioner) RevisePortPool() error {
436440
log.Msg(fmt.Sprintf("Revising availability of the port range [%d - %d]", p.config.PortPool.From, p.config.PortPool.To))
437441

438-
host, err := externalIP()
442+
host, err := hostIP(p.gateway)
439443
if err != nil {
440444
return err
441445
}
@@ -468,13 +472,21 @@ func (p *Provisioner) RevisePortPool() error {
468472
return nil
469473
}
470474

475+
func hostIP(gateway string) (string, error) {
476+
if gateway != "" {
477+
return gateway, nil
478+
}
479+
480+
return externalIP()
481+
}
482+
471483
// allocatePort tries to find a free port and occupy it.
472484
func (p *Provisioner) allocatePort() (uint, error) {
473485
portOpts := p.config.PortPool
474486

475487
attempts := 0
476488

477-
host, err := externalIP()
489+
host, err := hostIP(p.gateway)
478490
if err != nil {
479491
return 0, err
480492
}
@@ -598,6 +610,8 @@ func (p *Provisioner) stopPoolSessions(fsm pool.FSManager, exceptClones map[stri
598610
}
599611

600612
func (p *Provisioner) getAppConfig(pool *resources.Pool, name string, port uint) *resources.AppConfig {
613+
provisionHosts := p.getProvisionHosts()
614+
601615
appConfig := &resources.AppConfig{
602616
CloneName: name,
603617
DockerImage: p.config.DockerImage,
@@ -607,12 +621,33 @@ func (p *Provisioner) getAppConfig(pool *resources.Pool, name string, port uint)
607621
Pool: pool,
608622
ContainerConf: p.config.ContainerConfig,
609623
NetworkID: p.networkID,
610-
ProvisionHosts: p.config.CloneAccessAddresses,
624+
ProvisionHosts: provisionHosts,
611625
}
612626

613627
return appConfig
614628
}
615629

630+
// getProvisionHosts adds an internal Docker gateway to the hosts rule if the user restricts access to IP addresses.
631+
func (p *Provisioner) getProvisionHosts() string {
632+
provisionHosts := p.config.CloneAccessAddresses
633+
634+
if provisionHosts == "" || provisionHosts == wildcardIP {
635+
return provisionHosts
636+
}
637+
638+
hostSet := []string{p.gateway}
639+
640+
for _, hostIP := range strings.Split(provisionHosts, ",") {
641+
if hostIP != p.gateway {
642+
hostSet = append(hostSet, hostIP)
643+
}
644+
}
645+
646+
provisionHosts = strings.Join(hostSet, ",")
647+
648+
return provisionHosts
649+
}
650+
616651
// LastSessionActivity returns the time of the last session activity.
617652
func (p *Provisioner) LastSessionActivity(session *resources.Session, minimumTime time.Time) (*time.Time, error) {
618653
fsm, err := p.pm.GetFSManager(session.Pool)

engine/internal/provision/mode_local_test.go

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func TestPortAllocation(t *testing.T) {
2626
},
2727
}
2828

29-
p, err := New(context.Background(), cfg, &resources.DB{}, &client.Client{}, &pool.Manager{}, "instanceID", "networkID")
29+
p, err := New(context.Background(), cfg, &resources.DB{}, &client.Client{}, &pool.Manager{}, "instanceID", "networkID", "")
3030
require.NoError(t, err)
3131

3232
// Allocate a new port.
@@ -330,3 +330,57 @@ func createTempConfigFile(testCaseDir, fileName string, content string) error {
330330

331331
return os.WriteFile(fn, []byte(content), 0666)
332332
}
333+
334+
func TestProvisionHosts(t *testing.T) {
335+
tests := []struct {
336+
name string
337+
udAddresses string
338+
gateway string
339+
expectedHosts string
340+
}{
341+
{
342+
name: "Empty fields",
343+
udAddresses: "",
344+
gateway: "",
345+
expectedHosts: "",
346+
},
347+
{
348+
name: "Empty user-defined address",
349+
udAddresses: "",
350+
gateway: "172.20.0.1",
351+
expectedHosts: "",
352+
},
353+
{
354+
name: "Wildcard IP",
355+
udAddresses: "0.0.0.0",
356+
gateway: "172.20.0.1",
357+
expectedHosts: "0.0.0.0",
358+
},
359+
{
360+
name: "User-defined address",
361+
udAddresses: "192.168.1.1",
362+
gateway: "172.20.0.1",
363+
expectedHosts: "172.20.0.1,192.168.1.1",
364+
},
365+
{
366+
name: "Multiple user-defined addresses",
367+
udAddresses: "192.168.1.1,10.0.58.1",
368+
gateway: "172.20.0.1",
369+
expectedHosts: "172.20.0.1,192.168.1.1,10.0.58.1",
370+
},
371+
}
372+
373+
for _, tt := range tests {
374+
t.Run(tt.name, func(t *testing.T) {
375+
376+
p := Provisioner{
377+
config: &Config{
378+
CloneAccessAddresses: tt.udAddresses,
379+
},
380+
gateway: tt.gateway,
381+
}
382+
383+
assert.Equal(t, tt.expectedHosts, p.getProvisionHosts())
384+
})
385+
}
386+
}

0 commit comments

Comments
 (0)