Skip to content

Commit 416b5ff

Browse files
committed
review
1 parent ef7f40a commit 416b5ff

File tree

5 files changed

+100
-77
lines changed

5 files changed

+100
-77
lines changed

cli/cliui/agent.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,10 +403,10 @@ func ConnDiagnostics(w io.Writer, d ConnDiags) {
403403
}
404404

405405
if d.ClientIPIsAWS {
406-
_, _ = fmt.Fprint(w, "❗ Client IP address is within an AWS range, and is therefore behind a hard NAT\n")
406+
_, _ = fmt.Fprint(w, "❗ Client IP address is within an AWS range, which is known to cause problems with forming direct connections (AWS uses hard NAT)\n")
407407
}
408408

409409
if d.AgentIPIsAWS {
410-
_, _ = fmt.Fprint(w, "❗ Agent IP address is within an AWS range, and is therefore behind a hard NAT\n")
410+
_, _ = fmt.Fprint(w, "❗ Agent IP address is within an AWS range, which is known to cause problems with forming direct connections (AWS uses hard NAT)\n")
411411
}
412412
}

cli/cliui/agent_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ func TestConnDiagnostics(t *testing.T) {
790790
},
791791
want: []string{
792792
`❗ You are connected via a DERP relay, not directly (p2p)`,
793-
`❗ Client IP address is within an AWS range, and is therefore behind a hard NAT`,
793+
`❗ Client IP address is within an AWS range, which is known to cause problems with forming direct connections (AWS uses hard NAT)`,
794794
},
795795
},
796796
{
@@ -801,7 +801,7 @@ func TestConnDiagnostics(t *testing.T) {
801801
},
802802
want: []string{
803803
`❗ You are connected via a DERP relay, not directly (p2p)`,
804-
`❗ Agent IP address is within an AWS range, and is therefore behind a hard NAT`,
804+
`❗ Agent IP address is within an AWS range, which is known to cause problems with forming direct connections (AWS uses hard NAT)`,
805805
},
806806
},
807807
}

cli/cliutil/awscheck.go

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,33 +11,39 @@ import (
1111
"golang.org/x/xerrors"
1212
)
1313

14-
const awsIPRangesURL = "https://ip-ranges.amazonaws.com/ip-ranges.json"
14+
const AWSIPRangesURL = "https://ip-ranges.amazonaws.com/ip-ranges.json"
1515

16-
type AWSIPv4Prefix struct {
16+
type awsIPv4Prefix struct {
1717
Prefix string `json:"ip_prefix"`
1818
Region string `json:"region"`
1919
Service string `json:"service"`
2020
NetworkBorderGroup string `json:"network_border_group"`
2121
}
2222

23-
type AWSIPv6Prefix struct {
24-
Prefix string `json:"ipv6_prefix"`
25-
Region string `json:"region"`
26-
Service string `json:"service"`
23+
type awsIPv6Prefix struct {
24+
Prefix string `json:"ipv6_prefix"`
25+
Region string `json:"region"`
26+
Service string `json:"service"`
27+
NetworkBorderGroup string `json:"network_border_group"`
2728
}
2829

2930
type AWSIPRanges struct {
31+
V4 []netip.Prefix
32+
V6 []netip.Prefix
33+
}
34+
35+
type awsIPRangesResponse struct {
3036
SyncToken string `json:"syncToken"`
3137
CreateDate string `json:"createDate"`
32-
IPV4Prefixes []AWSIPv4Prefix `json:"prefixes"`
33-
IPV6Prefixes []AWSIPv6Prefix `json:"ipv6_prefixes"`
38+
IPV4Prefixes []awsIPv4Prefix `json:"prefixes"`
39+
IPV6Prefixes []awsIPv6Prefix `json:"ipv6_prefixes"`
3440
}
3541

36-
func NewAWSIPRanges(ctx context.Context) (*AWSIPRanges, error) {
42+
func FetchAWSIPRanges(ctx context.Context, url string) (*AWSIPRanges, error) {
3743
client := &http.Client{}
3844
reqCtx, reqCancel := context.WithTimeout(ctx, 5*time.Second)
3945
defer reqCancel()
40-
req, _ := http.NewRequestWithContext(reqCtx, http.MethodGet, awsIPRangesURL, nil)
46+
req, _ := http.NewRequestWithContext(reqCtx, http.MethodGet, url, nil)
4147
resp, err := client.Do(req)
4248
if err != nil {
4349
return nil, err
@@ -49,40 +55,54 @@ func NewAWSIPRanges(ctx context.Context) (*AWSIPRanges, error) {
4955
return nil, xerrors.Errorf("unexpected status code %d: %s", resp.StatusCode, b)
5056
}
5157

52-
var out AWSIPRanges
53-
err = json.NewDecoder(resp.Body).Decode(&out)
58+
var body awsIPRangesResponse
59+
err = json.NewDecoder(resp.Body).Decode(&body)
5460
if err != nil {
5561
return nil, xerrors.Errorf("json decode: %w", err)
5662
}
57-
return &out, nil
63+
64+
out := &AWSIPRanges{
65+
V4: make([]netip.Prefix, 0, len(body.IPV4Prefixes)),
66+
V6: make([]netip.Prefix, 0, len(body.IPV6Prefixes)),
67+
}
68+
69+
for _, p := range body.IPV4Prefixes {
70+
prefix, err := netip.ParsePrefix(p.Prefix)
71+
if err != nil {
72+
return nil, xerrors.Errorf("parse ip prefix: %w", err)
73+
}
74+
out.V4 = append(out.V4, prefix)
75+
}
76+
77+
for _, p := range body.IPV6Prefixes {
78+
prefix, err := netip.ParsePrefix(p.Prefix)
79+
if err != nil {
80+
return nil, xerrors.Errorf("parse ip prefix: %w", err)
81+
}
82+
out.V6 = append(out.V6, prefix)
83+
}
84+
85+
return out, nil
5886
}
5987

6088
// CheckIP checks if the given IP address is an AWS IP.
61-
func (r *AWSIPRanges) CheckIP(ip netip.Addr) (bool, error) {
89+
func (r *AWSIPRanges) CheckIP(ip netip.Addr) bool {
6290
if ip.IsLoopback() || ip.IsLinkLocalMulticast() || ip.IsLinkLocalUnicast() || ip.IsPrivate() {
63-
return false, nil
91+
return false
6492
}
6593

6694
if ip.Is4() {
67-
for _, p := range r.IPV4Prefixes {
68-
prefix, err := netip.ParsePrefix(p.Prefix)
69-
if err != nil {
70-
return false, xerrors.Errorf("parse ip prefix: %w", err)
71-
}
72-
if prefix.Contains(ip) {
73-
return true, nil
95+
for _, p := range r.V4 {
96+
if p.Contains(ip) {
97+
return true
7498
}
7599
}
76100
} else {
77-
for _, p := range r.IPV6Prefixes {
78-
prefix, err := netip.ParsePrefix(p.Prefix)
79-
if err != nil {
80-
return false, xerrors.Errorf("parse ip prefix: %w", err)
81-
}
82-
if prefix.Contains(ip) {
83-
return true, nil
101+
for _, p := range r.V6 {
102+
if p.Contains(ip) {
103+
return true
84104
}
85105
}
86106
}
87-
return false, nil
107+
return false
88108
}
Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,95 @@
1-
package cliutil_test
1+
package cliutil
22

33
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
47
"net/netip"
58
"testing"
69

710
"github.com/stretchr/testify/require"
811

9-
"github.com/coder/coder/v2/cli/cliutil"
12+
"github.com/coder/coder/v2/coderd/httpapi"
1013
"github.com/coder/coder/v2/testutil"
1114
)
1215

1316
func TestIPV4Check(t *testing.T) {
1417
t.Parallel()
18+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
19+
httpapi.Write(context.Background(), w, http.StatusOK, awsIPRangesResponse{
20+
IPV4Prefixes: []awsIPv4Prefix{
21+
{
22+
Prefix: "3.24.0.0/14",
23+
},
24+
{
25+
Prefix: "15.230.15.29/32",
26+
},
27+
{
28+
Prefix: "47.128.82.100/31",
29+
},
30+
},
31+
IPV6Prefixes: []awsIPv6Prefix{
32+
{
33+
Prefix: "2600:9000:5206::/48",
34+
},
35+
{
36+
Prefix: "2406:da70:8800::/40",
37+
},
38+
{
39+
Prefix: "2600:1f68:5000::/40",
40+
},
41+
},
42+
})
43+
}))
1544
ctx := testutil.Context(t, testutil.WaitShort)
16-
ranges, err := cliutil.NewAWSIPRanges(ctx)
45+
ranges, err := FetchAWSIPRanges(ctx, srv.URL)
1746
require.NoError(t, err)
1847

1948
t.Run("Private/IPV4", func(t *testing.T) {
2049
t.Parallel()
2150
ip, err := netip.ParseAddr("192.168.0.1")
2251
require.NoError(t, err)
23-
isAws, err := ranges.CheckIP(ip)
24-
require.NoError(t, err)
52+
isAws := ranges.CheckIP(ip)
2553
require.False(t, isAws)
2654
})
2755

2856
t.Run("AWS/IPV4", func(t *testing.T) {
2957
t.Parallel()
3058
ip, err := netip.ParseAddr("3.25.61.113")
3159
require.NoError(t, err)
32-
isAws, err := ranges.CheckIP(ip)
33-
require.NoError(t, err)
60+
isAws := ranges.CheckIP(ip)
3461
require.True(t, isAws)
3562
})
3663

3764
t.Run("NonAWS/IPV4", func(t *testing.T) {
3865
t.Parallel()
3966
ip, err := netip.ParseAddr("159.196.123.40")
4067
require.NoError(t, err)
41-
isAws, err := ranges.CheckIP(ip)
42-
require.NoError(t, err)
68+
isAws := ranges.CheckIP(ip)
4369
require.False(t, isAws)
4470
})
4571

4672
t.Run("Private/IPV6", func(t *testing.T) {
4773
t.Parallel()
4874
ip, err := netip.ParseAddr("::1")
4975
require.NoError(t, err)
50-
isAws, err := ranges.CheckIP(ip)
51-
require.NoError(t, err)
76+
isAws := ranges.CheckIP(ip)
5277
require.False(t, isAws)
5378
})
5479

5580
t.Run("AWS/IPV6", func(t *testing.T) {
5681
t.Parallel()
5782
ip, err := netip.ParseAddr("2600:9000:5206:0001:0000:0000:0000:0001")
5883
require.NoError(t, err)
59-
isAws, err := ranges.CheckIP(ip)
60-
require.NoError(t, err)
84+
isAws := ranges.CheckIP(ip)
6185
require.True(t, isAws)
6286
})
6387

6488
t.Run("NonAWS/IPV6", func(t *testing.T) {
6589
t.Parallel()
6690
ip, err := netip.ParseAddr("2403:5807:885f:0:a544:49d4:58f8:aedf")
6791
require.NoError(t, err)
68-
isAws, err := ranges.CheckIP(ip)
69-
require.NoError(t, err)
92+
isAws := ranges.CheckIP(ip)
7093
require.False(t, isAws)
7194
})
7295
}

cli/ping.go

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,12 @@ func (r *RootCmd) ping() *serpent.Command {
160160
LocalNetInfo: ni,
161161
}
162162

163-
awsRanges, err := cliutil.NewAWSIPRanges(ctx)
163+
awsRanges, err := cliutil.FetchAWSIPRanges(ctx, cliutil.AWSIPRangesURL)
164164
if err != nil {
165165
_, _ = fmt.Fprintf(inv.Stdout, "Failed to retrieve AWS IP ranges: %v\n", err)
166166
}
167167

168-
clientIPIsAWS, err := isAWSIP(awsRanges, ni)
169-
if err != nil {
170-
_, _ = fmt.Fprintf(inv.Stdout, "Failed to determine if client IP is AWS: %v\n", err)
171-
}
172-
connDiags.ClientIPIsAWS = clientIPIsAWS
168+
connDiags.ClientIPIsAWS = isAWSIP(awsRanges, ni)
173169

174170
connInfo, err := client.AgentConnectionInfoGeneric(ctx)
175171
if err == nil {
@@ -187,11 +183,7 @@ func (r *RootCmd) ping() *serpent.Command {
187183
agentNetcheck, err := conn.Netcheck(ctx)
188184
if err == nil {
189185
connDiags.AgentNetcheck = &agentNetcheck
190-
agentIPIsAws, err := isAWSIP(awsRanges, agentNetcheck.NetInfo)
191-
if err != nil {
192-
_, _ = fmt.Fprintf(inv.Stdout, "Failed to determine if agent IP is AWS: %v\n", err)
193-
}
194-
connDiags.AgentIPIsAWS = agentIPIsAws
186+
connDiags.AgentIPIsAWS = isAWSIP(awsRanges, agentNetcheck.NetInfo)
195187
} else {
196188
var sdkErr *codersdk.Error
197189
if errors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusNotFound {
@@ -231,23 +223,11 @@ func (r *RootCmd) ping() *serpent.Command {
231223
return cmd
232224
}
233225

234-
func isAWSIP(awsRanges *cliutil.AWSIPRanges, ni *tailcfg.NetInfo) (bool, error) {
235-
var strIP string
236-
if ni.GlobalV4 != "" {
237-
strIP = ni.GlobalV4
238-
} else if ni.GlobalV6 != "" {
239-
strIP = ni.GlobalV6
240-
} else {
241-
return false, xerrors.Errorf("no public IP address found")
226+
func isAWSIP(awsRanges *cliutil.AWSIPRanges, ni *tailcfg.NetInfo) bool {
227+
checkIP := func(ipStr string) bool {
228+
ip, err := netip.ParseAddr(ipStr)
229+
return err == nil && awsRanges.CheckIP(ip)
242230
}
243231

244-
ip, err := netip.ParseAddr(strIP)
245-
if err != nil {
246-
return false, err
247-
}
248-
isAWS, err := awsRanges.CheckIP(ip)
249-
if err != nil {
250-
return false, err
251-
}
252-
return isAWS, nil
232+
return checkIP(ni.GlobalV4) || checkIP(ni.GlobalV6)
253233
}

0 commit comments

Comments
 (0)