Skip to content

Commit c7f52b7

Browse files
authored
feat(coderd): add prometheus metrics to servertailnet (#11988)
1 parent c84a637 commit c7f52b7

File tree

5 files changed

+165
-62
lines changed

5 files changed

+165
-62
lines changed

coderd/coderd.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ func New(options *Options) *API {
472472

473473
api.Auditor.Store(&options.Auditor)
474474
api.TailnetCoordinator.Store(&options.TailnetCoordinator)
475-
api.agentProvider, err = NewServerTailnet(api.ctx,
475+
stn, err := NewServerTailnet(api.ctx,
476476
options.Logger,
477477
options.DERPServer,
478478
api.DERPMap,
@@ -485,6 +485,10 @@ func New(options *Options) *API {
485485
if err != nil {
486486
panic("failed to setup server tailnet: " + err.Error())
487487
}
488+
api.agentProvider = stn
489+
if options.DeploymentValues.Prometheus.Enable {
490+
options.PrometheusRegistry.MustRegister(stn)
491+
}
488492
api.TailnetClientService, err = tailnet.NewClientService(
489493
api.Logger.Named("tailnetclient"),
490494
&api.TailnetCoordinator,

coderd/database/pubsub/pubsub_test.go

+19-60
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"testing"
77

88
"github.com/prometheus/client_golang/prometheus"
9-
dto "github.com/prometheus/client_model/go"
109
"github.com/stretchr/testify/assert"
1110
"github.com/stretchr/testify/require"
1211

@@ -43,8 +42,8 @@ func TestPGPubsub_Metrics(t *testing.T) {
4342

4443
metrics, err := registry.Gather()
4544
require.NoError(t, err)
46-
require.True(t, gaugeHasValue(t, metrics, 0, "coder_pubsub_current_events"))
47-
require.True(t, gaugeHasValue(t, metrics, 0, "coder_pubsub_current_subscribers"))
45+
require.True(t, testutil.PromGaugeHasValue(t, metrics, 0, "coder_pubsub_current_events"))
46+
require.True(t, testutil.PromGaugeHasValue(t, metrics, 0, "coder_pubsub_current_subscribers"))
4847

4948
event := "test"
5049
data := "testing"
@@ -63,14 +62,14 @@ func TestPGPubsub_Metrics(t *testing.T) {
6362
require.Eventually(t, func() bool {
6463
metrics, err = registry.Gather()
6564
assert.NoError(t, err)
66-
return gaugeHasValue(t, metrics, 1, "coder_pubsub_current_events") &&
67-
gaugeHasValue(t, metrics, 1, "coder_pubsub_current_subscribers") &&
68-
gaugeHasValue(t, metrics, 1, "coder_pubsub_connected") &&
69-
counterHasValue(t, metrics, 1, "coder_pubsub_publishes_total", "true") &&
70-
counterHasValue(t, metrics, 1, "coder_pubsub_subscribes_total", "true") &&
71-
counterHasValue(t, metrics, 1, "coder_pubsub_messages_total", "normal") &&
72-
counterHasValue(t, metrics, 7, "coder_pubsub_received_bytes_total") &&
73-
counterHasValue(t, metrics, 7, "coder_pubsub_published_bytes_total")
65+
return testutil.PromGaugeHasValue(t, metrics, 1, "coder_pubsub_current_events") &&
66+
testutil.PromGaugeHasValue(t, metrics, 1, "coder_pubsub_current_subscribers") &&
67+
testutil.PromGaugeHasValue(t, metrics, 1, "coder_pubsub_connected") &&
68+
testutil.PromCounterHasValue(t, metrics, 1, "coder_pubsub_publishes_total", "true") &&
69+
testutil.PromCounterHasValue(t, metrics, 1, "coder_pubsub_subscribes_total", "true") &&
70+
testutil.PromCounterHasValue(t, metrics, 1, "coder_pubsub_messages_total", "normal") &&
71+
testutil.PromCounterHasValue(t, metrics, 7, "coder_pubsub_received_bytes_total") &&
72+
testutil.PromCounterHasValue(t, metrics, 7, "coder_pubsub_published_bytes_total")
7473
}, testutil.WaitShort, testutil.IntervalFast)
7574

7675
colossalData := make([]byte, 7600)
@@ -93,54 +92,14 @@ func TestPGPubsub_Metrics(t *testing.T) {
9392
require.Eventually(t, func() bool {
9493
metrics, err = registry.Gather()
9594
assert.NoError(t, err)
96-
return gaugeHasValue(t, metrics, 1, "coder_pubsub_current_events") &&
97-
gaugeHasValue(t, metrics, 2, "coder_pubsub_current_subscribers") &&
98-
gaugeHasValue(t, metrics, 1, "coder_pubsub_connected") &&
99-
counterHasValue(t, metrics, 2, "coder_pubsub_publishes_total", "true") &&
100-
counterHasValue(t, metrics, 2, "coder_pubsub_subscribes_total", "true") &&
101-
counterHasValue(t, metrics, 1, "coder_pubsub_messages_total", "normal") &&
102-
counterHasValue(t, metrics, 1, "coder_pubsub_messages_total", "colossal") &&
103-
counterHasValue(t, metrics, 7607, "coder_pubsub_received_bytes_total") &&
104-
counterHasValue(t, metrics, 7607, "coder_pubsub_published_bytes_total")
95+
return testutil.PromGaugeHasValue(t, metrics, 1, "coder_pubsub_current_events") &&
96+
testutil.PromGaugeHasValue(t, metrics, 2, "coder_pubsub_current_subscribers") &&
97+
testutil.PromGaugeHasValue(t, metrics, 1, "coder_pubsub_connected") &&
98+
testutil.PromCounterHasValue(t, metrics, 2, "coder_pubsub_publishes_total", "true") &&
99+
testutil.PromCounterHasValue(t, metrics, 2, "coder_pubsub_subscribes_total", "true") &&
100+
testutil.PromCounterHasValue(t, metrics, 1, "coder_pubsub_messages_total", "normal") &&
101+
testutil.PromCounterHasValue(t, metrics, 1, "coder_pubsub_messages_total", "colossal") &&
102+
testutil.PromCounterHasValue(t, metrics, 7607, "coder_pubsub_received_bytes_total") &&
103+
testutil.PromCounterHasValue(t, metrics, 7607, "coder_pubsub_published_bytes_total")
105104
}, testutil.WaitShort, testutil.IntervalFast)
106105
}
107-
108-
func gaugeHasValue(t testing.TB, metrics []*dto.MetricFamily, value float64, name string, label ...string) bool {
109-
t.Helper()
110-
for _, family := range metrics {
111-
if family.GetName() != name {
112-
continue
113-
}
114-
ms := family.GetMetric()
115-
for _, m := range ms {
116-
require.Equal(t, len(label), len(m.GetLabel()))
117-
for i, lv := range label {
118-
if lv != m.GetLabel()[i].GetValue() {
119-
continue
120-
}
121-
}
122-
return value == m.GetGauge().GetValue()
123-
}
124-
}
125-
return false
126-
}
127-
128-
func counterHasValue(t testing.TB, metrics []*dto.MetricFamily, value float64, name string, label ...string) bool {
129-
t.Helper()
130-
for _, family := range metrics {
131-
if family.GetName() != name {
132-
continue
133-
}
134-
ms := family.GetMetric()
135-
for _, m := range ms {
136-
require.Equal(t, len(label), len(m.GetLabel()))
137-
for i, lv := range label {
138-
if lv != m.GetLabel()[i].GetValue() {
139-
continue
140-
}
141-
}
142-
return value == m.GetCounter().GetValue()
143-
}
144-
}
145-
return false
146-
}

coderd/tailnet.go

+53-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"time"
1515

1616
"github.com/google/uuid"
17+
"github.com/prometheus/client_golang/prometheus"
1718
"go.opentelemetry.io/otel/trace"
1819
"golang.org/x/xerrors"
1920
"tailscale.com/derp"
@@ -97,6 +98,18 @@ func NewServerTailnet(
9798
agentConnectionTimes: map[uuid.UUID]time.Time{},
9899
agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{},
99100
transport: tailnetTransport.Clone(),
101+
connsPerAgent: prometheus.NewGaugeVec(prometheus.GaugeOpts{
102+
Namespace: "coder",
103+
Subsystem: "servertailnet",
104+
Name: "open_connections",
105+
Help: "Total number of TCP connections currently open to workspace agents.",
106+
}, []string{"network"}),
107+
totalConns: prometheus.NewCounterVec(prometheus.CounterOpts{
108+
Namespace: "coder",
109+
Subsystem: "servertailnet",
110+
Name: "connections_total",
111+
Help: "Total number of TCP connections made to workspace agents.",
112+
}, []string{"network"}),
100113
}
101114
tn.transport.DialContext = tn.dialContext
102115
// These options are mostly just picked at random, and they can likely be
@@ -170,6 +183,16 @@ func NewServerTailnet(
170183
return tn, nil
171184
}
172185

186+
func (s *ServerTailnet) Describe(descs chan<- *prometheus.Desc) {
187+
s.connsPerAgent.Describe(descs)
188+
s.totalConns.Describe(descs)
189+
}
190+
191+
func (s *ServerTailnet) Collect(metrics chan<- prometheus.Metric) {
192+
s.connsPerAgent.Collect(metrics)
193+
s.totalConns.Collect(metrics)
194+
}
195+
173196
func (s *ServerTailnet) expireOldAgents() {
174197
const (
175198
tick = 5 * time.Minute
@@ -304,6 +327,9 @@ type ServerTailnet struct {
304327
agentTickets map[uuid.UUID]map[uuid.UUID]struct{}
305328

306329
transport *http.Transport
330+
331+
connsPerAgent *prometheus.GaugeVec
332+
totalConns *prometheus.CounterVec
307333
}
308334

309335
func (s *ServerTailnet) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID) *httputil.ReverseProxy {
@@ -349,7 +375,18 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (
349375
return nil, xerrors.Errorf("no agent id attached")
350376
}
351377

352-
return s.DialAgentNetConn(ctx, agentID, network, addr)
378+
nc, err := s.DialAgentNetConn(ctx, agentID, network, addr)
379+
if err != nil {
380+
return nil, err
381+
}
382+
383+
s.connsPerAgent.WithLabelValues("tcp").Inc()
384+
s.totalConns.WithLabelValues("tcp").Inc()
385+
return &instrumentedConn{
386+
Conn: nc,
387+
agentID: agentID,
388+
connsPerAgent: s.connsPerAgent,
389+
}, nil
353390
}
354391

355392
func (s *ServerTailnet) ensureAgent(agentID uuid.UUID) error {
@@ -455,3 +492,18 @@ func (s *ServerTailnet) Close() error {
455492
<-s.derpMapUpdaterClosed
456493
return nil
457494
}
495+
496+
type instrumentedConn struct {
497+
net.Conn
498+
499+
agentID uuid.UUID
500+
closeOnce sync.Once
501+
connsPerAgent *prometheus.GaugeVec
502+
}
503+
504+
func (c *instrumentedConn) Close() error {
505+
c.closeOnce.Do(func() {
506+
c.connsPerAgent.WithLabelValues("tcp").Dec()
507+
})
508+
return c.Conn.Close()
509+
}

coderd/tailnet_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"testing"
1414

1515
"github.com/google/uuid"
16+
"github.com/prometheus/client_golang/prometheus"
1617
"github.com/spf13/afero"
1718
"github.com/stretchr/testify/assert"
1819
"github.com/stretchr/testify/require"
@@ -79,6 +80,43 @@ func TestServerTailnet_ReverseProxy(t *testing.T) {
7980
assert.Equal(t, http.StatusOK, res.StatusCode)
8081
})
8182

83+
t.Run("Metrics", func(t *testing.T) {
84+
t.Parallel()
85+
86+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
87+
defer cancel()
88+
89+
agents, serverTailnet := setupServerTailnetAgent(t, 1)
90+
a := agents[0]
91+
92+
registry := prometheus.NewRegistry()
93+
require.NoError(t, registry.Register(serverTailnet))
94+
95+
u, err := url.Parse(fmt.Sprintf("http://127.0.0.1:%d", codersdk.WorkspaceAgentHTTPAPIServerPort))
96+
require.NoError(t, err)
97+
98+
rp := serverTailnet.ReverseProxy(u, u, a.id)
99+
100+
rw := httptest.NewRecorder()
101+
req := httptest.NewRequest(
102+
http.MethodGet,
103+
u.String(),
104+
nil,
105+
).WithContext(ctx)
106+
107+
rp.ServeHTTP(rw, req)
108+
res := rw.Result()
109+
defer res.Body.Close()
110+
111+
assert.Equal(t, http.StatusOK, res.StatusCode)
112+
require.Eventually(t, func() bool {
113+
metrics, err := registry.Gather()
114+
assert.NoError(t, err)
115+
return testutil.PromCounterHasValue(t, metrics, 1, "coder_servertailnet_connections_total", "tcp") &&
116+
testutil.PromGaugeHasValue(t, metrics, 1, "coder_servertailnet_open_connections", "tcp")
117+
}, testutil.WaitShort, testutil.IntervalFast)
118+
})
119+
82120
t.Run("HostRewrite", func(t *testing.T) {
83121
t.Parallel()
84122

testutil/prometheus.go

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package testutil
2+
3+
import (
4+
"testing"
5+
6+
dto "github.com/prometheus/client_model/go"
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func PromGaugeHasValue(t testing.TB, metrics []*dto.MetricFamily, value float64, name string, label ...string) bool {
11+
t.Helper()
12+
for _, family := range metrics {
13+
if family.GetName() != name {
14+
continue
15+
}
16+
ms := family.GetMetric()
17+
metricsLoop:
18+
for _, m := range ms {
19+
require.Equal(t, len(label), len(m.GetLabel()))
20+
for i, lv := range label {
21+
if lv != m.GetLabel()[i].GetValue() {
22+
continue metricsLoop
23+
}
24+
}
25+
return value == m.GetGauge().GetValue()
26+
}
27+
}
28+
return false
29+
}
30+
31+
func PromCounterHasValue(t testing.TB, metrics []*dto.MetricFamily, value float64, name string, label ...string) bool {
32+
t.Helper()
33+
for _, family := range metrics {
34+
if family.GetName() != name {
35+
continue
36+
}
37+
ms := family.GetMetric()
38+
metricsLoop:
39+
for _, m := range ms {
40+
require.Equal(t, len(label), len(m.GetLabel()))
41+
for i, lv := range label {
42+
if lv != m.GetLabel()[i].GetValue() {
43+
continue metricsLoop
44+
}
45+
}
46+
return value == m.GetCounter().GetValue()
47+
}
48+
}
49+
return false
50+
}

0 commit comments

Comments
 (0)