diff --git a/buildmode/buildmode.go b/buildmode/buildmode.go new file mode 100644 index 0000000000000..1636d589cb790 --- /dev/null +++ b/buildmode/buildmode.go @@ -0,0 +1,21 @@ +package buildmode + +import ( + "flag" + "strings" +) + +// BuildMode is injected at build time. +var ( + BuildMode string +) + +// Dev returns true when built to run in a dev deployment. +func Dev() bool { + return strings.HasPrefix(BuildMode, "dev") +} + +// Test returns true when running inside a unit test. +func Test() bool { + return flag.Lookup("test.v") != nil +} diff --git a/coderd/api.go b/coderd/api.go new file mode 100644 index 0000000000000..1fe4d7140d8e2 --- /dev/null +++ b/coderd/api.go @@ -0,0 +1,17 @@ +package coderd + +// API offers an HTTP API. Routes are located in routes.go. +type API struct { + // Services. + projectService *projectService + workspaceService *workspaceService +} + +// New returns an instantiated API. +func NewAPI() *API { + api := &API{ + projectService: newProjectService(), + workspaceService: newWorkspaceService(), + } + return api +} diff --git a/coderd/coderd.go b/coderd/coderd.go index e3c7c2b546b9f..c57ef3c1372a8 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -16,8 +16,14 @@ type Options struct { Database database.Store } +const ( + provisionerTerraform = "provisioner:terraform" + provisionerBasic = "provisioner:basic" +) + // New constructs the Coder API into an HTTP handler. func New(options *Options) http.Handler { + api := NewAPI() r := chi.NewRouter() r.Route("/api/v2", func(r chi.Router) { r.Get("/", func(w http.ResponseWriter, r *http.Request) { @@ -27,6 +33,29 @@ func New(options *Options) http.Handler { Message: "👋", }) }) + + // Projects endpoint + r.Route("/projects", func(r chi.Router) { + r.Route("/{organization}", func(r chi.Router) { + // TODO: Authentication + // TODO: User extraction + // TODO: Extract organization and add to context + r.Get("/", api.projectService.getProjects) + r.Post("/", api.projectService.createProject) + + r.Get("/{projectId}", api.projectService.getProjectById) + // TODO: Get project by id + }) + }) + + // Workspaces endpoint + r.Route("/workspaces", func(r chi.Router) { + r.Route("/{organization}", func(r chi.Router) { + r.Get("/", api.workspaceService.getWorkspaces) + r.Get("/{projectId}", api.workspaceService.getWorkspaceById) + }) + }) + }) r.NotFound(site.Handler().ServeHTTP) return r diff --git a/coderd/projects.go b/coderd/projects.go new file mode 100644 index 0000000000000..81230b42eed3d --- /dev/null +++ b/coderd/projects.go @@ -0,0 +1,94 @@ +package coderd + +import ( + "net/http" + + "github.com/coder/coder/xjson" +) + +type ProjectParameter struct { + Id string `json:"id" validate:"required"` + Name string `json:"name" validate:"required"` + Description string `json:"description"` + + // Validation Parameters + ValueType string `json:"validation_value_type"` +} + +// Project is a Go representation of the workspaces v2 project, +// defined here: https://www.notion.so/coderhq/Workspaces-v2-e908a8cd54804ddd910367abf03c8d0a#befa328add894231979e6cf8a378d2ec +type Project struct { + Id string `json:"id" validate:"required"` + Name string `json:"name" validate:"required"` + Description string `json:"description" validate:"required"` + ProvisionerType string `json:"provisioner_type" validate:"required"` + + Parameters []ProjectParameter `json:"parameters" validate:"required"` +} + +// Placeholder type of projectService +type projectService struct { +} + +func newProjectService() *projectService { + projectService := &projectService{} + return projectService +} + +func (ps *projectService) getProjects(w http.ResponseWriter, r *http.Request) { + // Construct a couple hard-coded projects to return the UI + terraformProject := Project{ + Id: "test_terraform_project_id", + Name: "Terraform", + Description: "Kubernetes on Terraform", + Parameters: []ProjectParameter{ + { + Id: "parameter_cluster_namespace", + Name: "Namespace", + Description: "Kubernetes namespace to host workspace pod", + ValueType: "string", + }, + { + Id: "parameter_cpu", + Name: "CPU", + Description: "CPU Cores to Allocate", + ValueType: "number", + }, + }, + } + + echoProject := Project{ + Id: "test_echo_project_id", + Name: "Echo Project", + Description: "A simple echo provider", + Parameters: []ProjectParameter{ + { + Id: "parameter_echo_string", + Name: "Echo String", + Description: "String that should be echo'd out in build log", + ValueType: "string", + }, + }, + } + + projects := []Project{ + terraformProject, + echoProject, + } + + xjson.Write(w, http.StatusOK, projects) +} + +func (ps *projectService) getProjectById(w http.ResponseWriter, r *http.Request) { + // TODO: Get a project by id + xjson.Write(w, http.StatusNotFound, nil) +} + +func (ps *projectService) createProject(w http.ResponseWriter, r *http.Request) { + // TODO: Validate arguments + // Organization context + // User + // Parameter values + // Submit to provisioner + xjson.Write(w, http.StatusOK, nil) +} diff --git a/coderd/workspaces.go b/coderd/workspaces.go new file mode 100644 index 0000000000000..be236ef6b5936 --- /dev/null +++ b/coderd/workspaces.go @@ -0,0 +1,48 @@ +package coderd + +import ( + "net/http" + + "github.com/coder/coder/xjson" +) + +type Workspace struct { + Id string `json:"id" validate:"required"` + Name string `json:"name" validate:"required"` + ProjectId string `json:"project_id" validate:"required"` +} + +// Placeholder type of workspaceService +type workspaceService struct { +} + +func newWorkspaceService() *workspaceService { + workspaceService := &workspaceService{} + return workspaceService +} + +func (ws *workspaceService) getWorkspaces(w http.ResponseWriter, r *http.Request) { + // Dummy workspace to return + workspace := Workspace{ + Id: "test-workspace", + Name: "Test Workspace", + ProjectId: "test-project-id", + } + + workspaces := []Workspace{ + workspace, + } + + xjson.Write(w, http.StatusOK, workspaces) +} + +func (ws *workspaceService) getWorkspaceById(w http.ResponseWriter, r *http.Request) { + // TODO: Read workspace off context + // Dummy workspace to return + workspace := Workspace{ + Id: "test-workspace", + Name: "Test Workspace", + ProjectId: "test-project-id", + } + xjson.Write(w, http.StatusOK, workspace) +} \ No newline at end of file diff --git a/go.mod b/go.mod index 577e7278aa438..214e497aa8dd3 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( github.com/alecthomas/chroma v0.9.1 // indirect github.com/apparentlymart/go-textseg v1.0.0 // indirect github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect + github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d // indirect github.com/cenkalti/backoff/v4 v4.1.2 // indirect github.com/containerd/continuity v0.1.0 // indirect github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect @@ -53,6 +54,9 @@ require ( github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.4.0 // indirect github.com/fatih/color v1.13.0 // indirect + github.com/go-playground/locales v0.14.0 // indirect + github.com/go-playground/universal-translator v0.18.0 // indirect + github.com/go-playground/validator/v10 v10.10.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/go-cmp v0.5.6 // indirect @@ -66,6 +70,7 @@ require ( github.com/imdario/mergo v0.3.12 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/justinas/nosurf v1.1.1 // indirect + github.com/leodido/go-urn v1.2.1 // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/mitchellh/go-wordwrap v1.0.0 // indirect @@ -107,4 +112,5 @@ require ( google.golang.org/grpc v1.43.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect + k8s.io/utils v0.0.0-20211208161948-7d6a63dca704 // indirect ) diff --git a/go.sum b/go.sum index 009f6283c4c39..3803458b6f3ac 100644 --- a/go.sum +++ b/go.sum @@ -151,6 +151,8 @@ github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= +github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d h1:Byv0BzEl3/e6D5CLfI0j/7hiIEtvGVFPCZ7Ei2oq8iQ= +github.com/asaskevich/govalidator v0.0.0-20210307081110-f21760c49a8d/go.mod h1:WaHUgvxTVq04UNunO+XhnAqY/wQc+bxr74GqbsZ/Jqw= github.com/aws/aws-sdk-go v1.15.11/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0= github.com/aws/aws-sdk-go v1.15.78/go.mod h1:E3/ieXAlvM0XWO57iftYVDLLvQ824smPP3ATZkfNZeM= github.com/aws/aws-sdk-go v1.17.7 h1:/4+rDPe0W95KBmNGYCG+NUvdL8ssPYBMxL+aSCg6nIA= @@ -468,6 +470,13 @@ github.com/go-openapi/jsonreference v0.19.3/go.mod h1:rjx6GuL8TTa9VaixXglHmQmIL9 github.com/go-openapi/spec v0.19.3/go.mod h1:FpwSN1ksY1eteniUU7X0N/BgJ7a4WvBFVA8Lj9mJglo= github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= +github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= +github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= +github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= +github.com/go-playground/validator/v10 v10.10.0 h1:I7mrTYv78z8k8VXa/qJlOlEXn/nBh+BF8dHX5nt/dr0= +github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -804,6 +813,8 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA= github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= @@ -818,6 +829,8 @@ github.com/kylecarbs/terraform-config-inspect v0.0.0-20211215004401-bbc517866b88 github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= +github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= @@ -1020,6 +1033,7 @@ github.com/pion/webrtc/v3 v3.1.13 h1:2XxgGstOqt03ba8QD5+m9S8DCA3Ez53mULT4If8onOg github.com/pion/webrtc/v3 v3.1.13/go.mod h1:RACpyE1EDYlzonfbdPvXkIGDaqD8+NsHqZJN0yEbRbA= github.com/pkg/browser v0.0.0-20210706143420-7d21f8c997e2/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1-0.20171018195549-f15c970de5b7/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -1069,6 +1083,8 @@ github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6L github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= @@ -1265,6 +1281,7 @@ golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= @@ -1889,6 +1906,8 @@ k8s.io/klog/v2 v2.4.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y= k8s.io/kube-openapi v0.0.0-20201113171705-d219536bb9fd/go.mod h1:WOJ3KddDSol4tAGcJo0Tvi+dK12EcqSLqcWsryKMpfM= k8s.io/kubernetes v1.13.0/go.mod h1:ocZa8+6APFNC2tX1DZASIbocyYT5jHzqFVsY5aoB7Jk= k8s.io/utils v0.0.0-20201110183641-67b214c5f920/go.mod h1:jPW/WVKK9YHAvNhRxK0md/EJ228hCsBRufyofKtW8HA= +k8s.io/utils v0.0.0-20211208161948-7d6a63dca704 h1:ZKMMxTvduyf5WUtREOqg5LiXaN1KO/+0oOQPRFrClpo= +k8s.io/utils v0.0.0-20211208161948-7d6a63dca704/go.mod h1:jPW/WVKK9YHAvNhRxK0md/EJ228hCsBRufyofKtW8HA= modernc.org/b v1.0.0/go.mod h1:uZWcZfRj1BpYzfN9JTerzlNUnnPsV9O2ZA8JsRcubNg= modernc.org/cc/v3 v3.32.4/go.mod h1:0R6jl1aZlIl2avnYfbfHBS1QB6/f+16mihBObaBC878= modernc.org/ccgo/v3 v3.9.2/go.mod h1:gnJpy6NIVqkETT+L5zPsQFj7L2kkhfPMzOghRNv/CFo= diff --git a/longid/id.go b/longid/id.go new file mode 100644 index 0000000000000..db520409b8924 --- /dev/null +++ b/longid/id.go @@ -0,0 +1,214 @@ +package longid + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/binary" + "encoding/hex" + "fmt" + "hash/fnv" + "math/rand" + "os" + "sync/atomic" + "time" + + "golang.org/x/xerrors" +) + +// bit counts. +const ( + TimeBits = 32 + IncrementorBits = 10 + + RandomBits = 70 + HostIDBits = 8 + + // amount of random bits in each uint64 + RandomBits1 = 64 - (TimeBits + IncrementorBits) + RandomBits2 = RandomBits - RandomBits1 + + HostMask = 0x00000000000000FF +) + +var ( + inc uint32 + hostID int64 +) + +func init() { + rand.Seed(time.Now().UnixNano()) + + hostname, err := os.Hostname() + if err != nil { + panic(err) + } + hash := fnv.New64a() + _, _ = hash.Write([]byte(hostname)) + hostID = (int64(hash.Sum64())) & HostMask + inc = rand.Uint32() +} + +// HostID returns the host ID for the current machine. +func HostID() int64 { + return hostID +} + +// ID describes a 128 bit ID +type ID [16]byte + +// parse errors +var ( + ErrWrongSize = xerrors.New("id in string form should be exactly 33 bytes") +) + +// FromSlice converts a slice into an ID. +func FromSlice(b []byte) ID { + var l ID + copy(l[:], b) + return l +} + +func part1() int64 { + seconds := time.Now().Unix() + + // place time portion properly + time := seconds << (64 - TimeBits) + + i := atomic.AddUint32(&inc, 1) + + // reset incrementor if it's too big + atomic.CompareAndSwapUint32(&inc, ((1 << IncrementorBits) - 1), 0) + + i <<= (RandomBits1) + + var randBuf [4]byte + _, _ = rand.Read(randBuf[:]) + + rand := (binary.BigEndian.Uint32(randBuf[:]) >> (32 - RandomBits1)) + + return time + int64(i) + int64(rand) +} + +func part2() int64 { + var randBuf [8]byte + _, _ = rand.Read(randBuf[:]) + rand := binary.BigEndian.Uint64(randBuf[:]) << HostIDBits + // fmt.Printf("%x\n", rand) + + return int64(rand) + hostID +} + +// New generates a long ID. +func New() ID { + var id ID + binary.BigEndian.PutUint64(id[:8], uint64(part1())) + binary.BigEndian.PutUint64(id[8:], uint64(part2())) + return id +} + +// Bytes returns a byte slice from l. +func (l ID) Bytes() []byte { + return l[:] +} + +// CreatedAt returns the time the ID was created at. +func (l ID) CreatedAt() time.Time { + epoch := (time.Now().Unix() >> (TimeBits)) << (TimeBits) + + ts := binary.BigEndian.Uint64(l[:8]) >> (64 - TimeBits) + + // fmt.Printf("%064b\n", epoch) + // fmt.Printf("%064b\n", ts) + + return time.Unix(epoch+int64(ts), 0) +} + +// String returns the text representation of l +func (l ID) String() string { + return fmt.Sprintf("%08x-%024x", l[:4], l[4:]) +} + +// MarshalText marshals l +func (l ID) MarshalText() ([]byte, error) { + return []byte(l.String()), nil +} + +// UnmarshalText parses b +func (l *ID) UnmarshalText(b []byte) error { + ll, err := Parse(string(b)) + if err != nil { + return err + } + copy(l[:], ll[:]) + return nil +} + +// MarshalJSON marshals l +func (l ID) MarshalJSON() ([]byte, error) { + return []byte("\"" + l.String() + "\""), nil +} + +// UnmarshalJSON parses b +func (l *ID) UnmarshalJSON(b []byte) error { + return l.UnmarshalText(bytes.Trim(b, "\"")) +} + +var _ = driver.Valuer(New()) +var _ = sql.Scanner(&ID{}) + +func (l ID) Value() (driver.Value, error) { + return l.Bytes(), nil +} + +func (l *ID) Scan(v interface{}) error { + b, ok := v.([]byte) + if !ok { + return xerrors.New("can only scan binary types") + } + if len(b) != 16 { + return xerrors.New("must be 16 bytes") + } + copy(l[:], b) + return nil +} + +// Parse parses the String() representation of a Long +func Parse(l string) (ID, error) { + var ( + id ID + err error + ) + if len(l) != 33 { + return id, ErrWrongSize + } + + p1, err := hex.DecodeString(l[:8]) + if err != nil { + return id, xerrors.Errorf("failed to decode short portion: %w", err) + } + + p2, err := hex.DecodeString(l[9:]) + if err != nil { + return id, xerrors.Errorf("failed to decode rand portion: %w", err) + } + + copy(id[:4], p1) + copy(id[4:], p2) + + return id, nil +} + +// TimeReset the current bounds of +// validity for timestamps extracted from longs +func TimeReset() (last time.Time, next time.Time) { + const lastStr = "00000000-00680e087d8fff20a11d24e6" + const nextStr = "ffffffff-00680e087d8fff20a11d24e6" + l, _ := Parse(lastStr) + last = l.CreatedAt() + + l, _ = Parse(nextStr) + next = l.CreatedAt() + + return +} diff --git a/longid/id_test.go b/longid/id_test.go new file mode 100644 index 0000000000000..ce8df492b2c82 --- /dev/null +++ b/longid/id_test.go @@ -0,0 +1,80 @@ +package longid + +import ( + "fmt" + "math/rand" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestID(t *testing.T) { + last, next := TimeReset() + t.Logf("Long Reset: Last: %v, Next: %v (🔼 %v)\n", last, next, next.Sub(last)) + t.Run("New()", func(t *testing.T) { + for i := 0; i < 5; i++ { + l := New() + fmt.Printf("Long: %v\n", l) + assert.WithinDuration(t, time.Now(), l.CreatedAt(), time.Second) + } + }) + + t.Run("Parse()", func(t *testing.T) { + t.Run("Good", func(t *testing.T) { + want := New() + got, err := Parse(want.String()) + require.Nil(t, err) + require.Equal(t, want, got) + }) + + t.Run("Bad Size", func(t *testing.T) { + _, err := Parse(New().String() + "ab") + require.NotNil(t, err) + }) + + t.Run("Bad Hex", func(t *testing.T) { + str := New().String() + str = "O" + str[1:] + _, err := Parse(str) + require.NotNil(t, err) + }) + }) + + t.Run("FromSlice", func(t *testing.T) { + l := New() + assert.Equal(t, l, FromSlice(l[:])) + }) + + t.Run("Scan", func(t *testing.T) { + var l ID + b := make([]byte, 16) + _, err := rand.Read(b) + require.NoError(t, err) + + require.NoError(t, l.Scan(b)) + assert.Equal(t, b, l.Bytes()) + }) +} + +func TestLongRaces(_ *testing.T) { + var wg sync.WaitGroup + for i := 0; i < 16; i++ { + go func() { + for i := 0; i < 1000; i++ { + New() + } + }() + } + wg.Wait() +} + +func BenchmarkLong(b *testing.B) { + b.Run("New()", func(b *testing.B) { + for i := 0; i < b.N; i++ { + New() + } + }) +} diff --git a/srverr/error.go b/srverr/error.go new file mode 100644 index 0000000000000..022387ca2f28c --- /dev/null +++ b/srverr/error.go @@ -0,0 +1,10 @@ +package srverr + +// Error is an interface for specifying how specific errors should be +// dispatched by the API. The underlying struct is sent under the `details` +// field. +type Error interface { + Status() int + PublicMessage() string + Code() Code +} diff --git a/srverr/error_test.go b/srverr/error_test.go new file mode 100644 index 0000000000000..908df65d56e7d --- /dev/null +++ b/srverr/error_test.go @@ -0,0 +1,19 @@ +package srverr + +import ( + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" +) + +func TestErrorChain(t *testing.T) { + t.Run("wrapping", func(t *testing.T) { + err := xerrors.Errorf("im an error") + err = Upgrade(err, ResourceNotFoundError{}) + err = xerrors.Errorf("wrapped http error: %w", err) + + var herr Error + require.ErrorAs(t, err, &herr, "should find http error details") + }) +} diff --git a/srverr/errors.go b/srverr/errors.go new file mode 100644 index 0000000000000..411ea05e09219 --- /dev/null +++ b/srverr/errors.go @@ -0,0 +1,65 @@ +package srverr + +import ( + "net/http" +) + +// SettableError describes a structured error that can accept an error. This is +// useful to prevent handlers from needing to insert the error into Upgrade +// twice. xjson.HandleError uses this interface set the final error string +// before marshaling. +type VerboseError interface { + SetVerbose(err error) +} + +// Verbose is for reusing the `verbose` field between error types. It +// implements VerboseError so it's not necessary to prefill the struct with the +// verbose error. +type Verbose struct { + Verbose string `json:"verbose"` +} + +func (e *Verbose) SetVerbose(err error) { e.Verbose = err.Error() } + +// Code is a string enum indicating the structure of the details field in an +// error response. Each error type should correspond to a unique Code. +type Code string + +const ( + CodeServerError Code = "server_error" + CodeDatabaseError Code = "database_error" + CodeResourceNotFound Code = "resource_not_found" +) + +var _ VerboseError = &ServerError{} + +// ServerError describes an error of unknown origins. +type ServerError struct { + Verbose +} + +func (*ServerError) Status() int { return http.StatusInternalServerError } +func (*ServerError) PublicMessage() string { return "An internal server error occurred." } +func (*ServerError) Code() Code { return CodeServerError } +func (*ServerError) Error() string { return "internal server error" } + +// DatabaseError describes an unknown error from the database. +type DatabaseError struct { + Verbose +} + +func (*DatabaseError) Status() int { return http.StatusInternalServerError } +func (*DatabaseError) PublicMessage() string { return "A database error occurred." } +func (*DatabaseError) Code() Code { return CodeDatabaseError } +func (*DatabaseError) Error() string { return "database error" } + +// ResourceNotFoundError describes an error when a provided resource ID was not +// found within the database or the user does not have the proper permission to +// view it. +type ResourceNotFoundError struct { +} + +func (ResourceNotFoundError) Status() int { return http.StatusNotFound } +func (e ResourceNotFoundError) PublicMessage() string { return "Resource not found." } +func (ResourceNotFoundError) Code() Code { return CodeResourceNotFound } +func (ResourceNotFoundError) Error() string { return "resource not found" } diff --git a/srverr/wrap.go b/srverr/wrap.go new file mode 100644 index 0000000000000..d87229d0bbb27 --- /dev/null +++ b/srverr/wrap.go @@ -0,0 +1,48 @@ +package srverr + +import ( + "encoding/json" +) + +// Upgrade transparently upgrades any error chain by adding information on how +// the error should be converted into an HTTP response. Since this adds it to +// the chain transparently, there is no indication from the error string that +// it is an upgraded error. You must use xerrors.As to check if an error chain +// contains an upgraded error. +// An error may be upgraded multiple times. The last call to Upgrade will +// always be used. +func Upgrade(err error, herr Error) error { + return wrapped{ + err: err, + herr: herr, + } +} + +var _ VerboseError = wrapped{} + +type wrapped struct { + err error + herr Error +} + +// Make sure the wrapped error still behaves as if it was a regular call to +// xerrors.Errorf. +func (w wrapped) Error() string { return w.err.Error() } +func (w wrapped) Unwrap() error { return w.err } + +// Pass through srverr.Error interface functions from the underlying +// srverr.Error. +func (w wrapped) Status() int { return w.herr.Status() } +func (w wrapped) PublicMessage() string { return w.herr.PublicMessage() } +func (w wrapped) Code() Code { return w.herr.Code() } + +// When a wrapped error is marshaled, we want to make sure it marshals the +// underlying srverr.Error, not the wrapped structure. +func (w wrapped) MarshalJSON() ([]byte, error) { return json.Marshal(w.herr) } + +// If the underlying srverr.Error implements VerboseError, pass through. +func (w wrapped) SetVerbose(err error) { + if v, ok := w.herr.(VerboseError); ok { + v.SetVerbose(err) + } +} diff --git a/validate/devurl.go b/validate/devurl.go new file mode 100644 index 0000000000000..4577f14d9defd --- /dev/null +++ b/validate/devurl.go @@ -0,0 +1,38 @@ +package validate + +import ( + "regexp" + "strings" + + "golang.org/x/xerrors" +) + +// NOTE: disallowing leading and trailing hyphens to avoid semantic confusion with hyphen used as separator. +// Disallowing leading and trailing underscores to avoid potential clashes with mDNS-related stuff. +var devURLValidNameRx = regexp.MustCompile("^[a-zA-Z]([a-zA-Z0-9_-]{0,41}[a-zA-Z0-9])?$") +var devURLInvalidLenMsg = "invalid devurl name %q: names may not be more than 64 characters in length." +var devURLInvalidNameMsg = "invalid devurl name %q: names must begin with a letter, followed by zero or more letters," + + " digits, hyphens, or underscores, and end with a letter or digit." + +const ( + // DevURLDelimiter is the separator for parts of a DevURL. + // eg. kyle--test--name.cdr.co + DevURLDelimiter = "--" +) + +// DevURLName only validates the name of the devurl, not the fully resolved subdomain. +func DevURLName(name string) error { + if len(name) == 0 { + return nil + } + if len(name) > 43 { + return xerrors.Errorf(devURLInvalidLenMsg, name) + } + if name != "" && !devURLValidNameRx.MatchString(name) { + return xerrors.Errorf(devURLInvalidNameMsg, name) + } + if strings.Contains(name, DevURLDelimiter) { + return xerrors.Errorf(devURLInvalidNameMsg, name) + } + return nil +} diff --git a/validate/devurl_test.go b/validate/devurl_test.go new file mode 100644 index 0000000000000..0df6ed88b6622 --- /dev/null +++ b/validate/devurl_test.go @@ -0,0 +1,47 @@ +package validate + +import ( + "testing" + + "github.com/stretchr/testify/require" + "k8s.io/utils/pointer" +) + +func Test_DevURLName(t *testing.T) { + testCases := []struct { + S string + Err *string + }{ + {"", nil}, + {"a", nil}, + {"a1", nil}, + {"a1a", nil}, + {"a-b", nil}, + {"a_b", nil}, + {"a_-b", nil}, + {"a_bc", nil}, + {"a-b-c", nil}, + {"a-b_c", nil}, + {"a_b-c", nil}, + {"a_b_c", nil}, + {"1", pointer.String("names must begin with a letter")}, + {"1a", pointer.String("names must begin with a letter")}, + {"1a1", pointer.String("names must begin with a letter")}, + {"1234", pointer.String("names must begin with a letter")}, + {"-a", pointer.String("names must begin with a letter")}, + {"a-", pointer.String("names must begin with a letter")}, + {"_a", pointer.String("names must begin with a letter")}, + {"a_", pointer.String("names must begin with a letter")}, + {"a--b", pointer.String("names must begin with a letter")}, + {"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", pointer.String("names may not be more than 64 characters in length")}, + } + for _, tc := range testCases { + err := DevURLName(tc.S) + if tc.Err != nil { + require.Errorf(t, err, "expected error for test case %q", tc.S) + require.Containsf(t, err.Error(), *tc.Err, "expected error for test case %q", tc.S) + } else { + require.NoError(t, err, tc.S) + } + } +} diff --git a/validate/numeric.go b/validate/numeric.go new file mode 100644 index 0000000000000..cdabeb2ecbd77 --- /dev/null +++ b/validate/numeric.go @@ -0,0 +1,15 @@ +package validate + +// Numeric returns true if s contains only digits. +// Returns false otherwise. +func Numeric(s string) bool { + if s == "" { + return false + } + for _, r := range s { + if r < '0' || r > '9' { + return false + } + } + return true +} diff --git a/validate/numeric_test.go b/validate/numeric_test.go new file mode 100644 index 0000000000000..b86e13be569c1 --- /dev/null +++ b/validate/numeric_test.go @@ -0,0 +1,24 @@ +package validate + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_Numeric(t *testing.T) { + testCases := []struct { + S string + Expected bool + }{ + {"", false}, + {"a1", false}, + {"1a", false}, + {"1a1", false}, + {"1234", true}, + } + for _, tc := range testCases { + actual := Numeric(tc.S) + require.Equal(t, tc.Expected, actual, tc.S) + } +} diff --git a/validate/struct.go b/validate/struct.go new file mode 100644 index 0000000000000..0449f018eb50c --- /dev/null +++ b/validate/struct.go @@ -0,0 +1,285 @@ +package validate + +import ( + "reflect" + "strings" + + "github.com/go-playground/validator/v10" + "golang.org/x/xerrors" +) + +const ( + validateTag = "validate" // Used by go-playground/validator + jsonTag = "json" // Stdlib json tag +) + +var ErrNotAStruct = xerrors.Errorf("value not a struct") + +// FieldsMissingValidation returns a list of fields that are missing appropriate +// validation tags. Any field that is exported, does not have a "-" json tag +// value, and is not a bool or pointer to a bool will be included in the +// returned list of fields if it lacks a "validate" tag. +// +// This will recursively check nested structs, stopping at fields that are +// unexported, or fields that do not unmarshal from json. `v` should be a +// struct. +// +// Nested struct fields that do not have a "validate" tag will be included in +// the returned list. +func FieldsMissingValidation(v interface{}) ([]reflect.StructField, error) { + _, ok := isStruct(v) + if !ok { + return nil, ErrNotAStruct + } + + fields, err := SelectFields(v, + SelectAll{ + FieldSelectorFunc(IsExported), + NegateSelector(FieldSelectorFunc(IsBool)), + NegateSelector(FieldSelectorFunc(HasSkipJSON)), + NegateSelector(ValidateTagKeyFieldSelector), + }, + SelectAny{ + NegateSelector(FieldSelectorFunc(IsExported)), + FieldSelectorFunc(HasSkipJSON), + FieldSelectorFunc(HasSkipValidate), + }, + ) + if err != nil { + return nil, xerrors.Errorf("select fields: %w", err) + } + + return fields, nil +} + +// FieldsWithValidation returns a list of fields with a "validate" tag. +// +// This will recursively check nested structs, stopping at fields that are +// unexported, or fields that do not unmarshal from json. `v` should be a +// struct. +// +// Nested struct fields that do have a "validate" tag will be included in the +// returned list. +func FieldsWithValidation(v interface{}) ([]reflect.StructField, error) { + _, ok := isStruct(v) + if !ok { + return nil, ErrNotAStruct + } + + fields, err := SelectFields(v, + SelectAll{ + FieldSelectorFunc(IsExported), + NegateSelector(FieldSelectorFunc(IsBool)), + NegateSelector(FieldSelectorFunc(HasSkipJSON)), + ValidateTagKeyFieldSelector, + }, + SelectAny{ + NegateSelector(FieldSelectorFunc(IsExported)), + FieldSelectorFunc(HasSkipJSON), + FieldSelectorFunc(HasSkipValidate), + }, + ) + if err != nil { + return nil, xerrors.Errorf("select fields: %w", err) + } + + return fields, nil +} + +type FieldSelector interface { + Matches(field reflect.StructField) bool +} + +type SelectAny []FieldSelector + +func (fs SelectAny) Matches(field reflect.StructField) bool { + for _, f := range fs { + if f.Matches(field) { + return true + } + } + return false +} + +type SelectAll []FieldSelector + +func (fs SelectAll) Matches(field reflect.StructField) bool { + for _, f := range fs { + if !f.Matches(field) { + return false + } + } + return true +} + +// JSONTagValueFieldSelector selects all fields that has a given json tag value. +type JSONTagValueFieldSelector string + +func (fs JSONTagValueFieldSelector) Matches(field reflect.StructField) bool { + tagVal, ok := field.Tag.Lookup(jsonTag) + if !ok { + return false + } + for _, s := range strings.Split(tagVal, ",") { + if s == string(fs) { + return true + } + } + return false +} + +// TagKeyFieldSelector selects all fields with the given tag key. +type TagKeyFieldSelector string + +const ( + ValidateTagKeyFieldSelector TagKeyFieldSelector = validateTag +) + +func (fs TagKeyFieldSelector) Matches(field reflect.StructField) bool { + _, ok := field.Tag.Lookup(string(fs)) + return ok +} + +type FieldSelectorFunc func(field reflect.StructField) bool + +func (fs FieldSelectorFunc) Matches(field reflect.StructField) bool { + return fs(field) +} + +// IsExported checks if the field is exported. +func IsExported(field reflect.StructField) bool { + // PkgPath is empty for exported fields (noted in doc for PkgPath). + return field.PkgPath == "" +} + +// IsBool checks if the field is a bool, or a *bool. +func IsBool(field reflect.StructField) bool { + // Field is a bool. + if field.Type.Kind() == reflect.Bool { + return true + } + // Field is a *bool. + if field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Bool { + return true + } + return false +} + +func HasSkipJSON(field reflect.StructField) bool { + jsonVal, jsonFound := field.Tag.Lookup(jsonTag) + skipJSON := jsonFound && jsonVal == "-" + return skipJSON +} + +func HasSkipValidate(field reflect.StructField) bool { + jsonVal, jsonFound := field.Tag.Lookup(validateTag) + skipJSON := jsonFound && jsonVal == "-" + return skipJSON +} + +func NegateSelector(fs FieldSelector) FieldSelector { + return FieldSelectorFunc(func(field reflect.StructField) bool { + return !fs.Matches(field) + }) +} + +// Field validates struct `v`, returning just the validation error for a +// field that Matches FieldSelector. If the selector matches more than one +// field, only the first will be checked. +func Field(v interface{}, fs FieldSelector) error { + _, ok := isStruct(v) + if !ok { + return ErrNotAStruct + } + + err := Validator().Struct(v) + if err == nil { + return nil + } + + var vErrs validator.ValidationErrors + if xerrors.As(err, &vErrs) { + fields, _ := SelectFields(v, fs, nil) // Can only error if `v` isn't a struct. + for _, field := range fields { + for _, vErr := range vErrs { + if field.Name == vErr.StructField() { + return vErr + } + } + } + // Field selector either matched no fields that failed validation, or + // all matched fields passed validation. + return nil + } + + return xerrors.Errorf("non-validation error when validating: %w", err) +} + +// SelectFields selects all fields from struct `v` that match the field +// selector. +// +// This will recurse through nested structs, stopping at fields that are +// selected with `skipFields`. A value of nil for `skipFields` will continue to +// recurse indiscriminately. Infinite recursion is avoided by detecting if a +// field of a struct has the same type as the struct itself. +func SelectFields(v interface{}, fs FieldSelector, skipFields FieldSelector) ([]reflect.StructField, error) { + return selectFieldsWithVisited(v, fs, skipFields, nil) +} + +type fieldType struct { + pkg string + name string +} + +func selectFieldsWithVisited(v interface{}, fs FieldSelector, skipFields FieldSelector, visited []*fieldType) ([]reflect.StructField, error) { + st, ok := isStruct(v) + if !ok { + return nil, ErrNotAStruct + } + + var fields []reflect.StructField + + // Check to make sure we haven't visited this type yet. If we have, there's + // no need to continue. + for _, ft := range visited { + if st.Name() == ft.name && st.PkgPath() == ft.pkg { + return fields, nil + } + } + visited = append(visited, &fieldType{pkg: st.PkgPath(), name: st.Name()}) + + for i := 0; i < st.NumField(); i++ { + field := st.Field(i) + if fs.Matches(field) { + fields = append(fields, field) + } + + if skipFields != nil && skipFields.Matches(field) { + continue + } + + fv := reflect.Zero(field.Type) + if field.Type.Kind() == reflect.Ptr { + fv = reflect.Zero(field.Type.Elem()) + } + if fv.Kind() == reflect.Struct { + nestedFields, err := selectFieldsWithVisited(fv.Interface(), fs, skipFields, visited) + if err != nil { + return nil, xerrors.Errorf("select fields, field: %s: %w", field.Name, err) + } + fields = append(fields, nestedFields...) + } + } + + return fields, nil +} + +// isStruct checks to make sure `v` is either a struct, or a pointer to a +// struct. +func isStruct(v interface{}) (reflect.Type, bool) { + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Ptr { + rv = rv.Elem() + } + return rv.Type(), rv.Kind() == reflect.Struct +} diff --git a/validate/struct_test.go b/validate/struct_test.go new file mode 100644 index 0000000000000..606a5213d43a4 --- /dev/null +++ b/validate/struct_test.go @@ -0,0 +1,206 @@ +package validate + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestFieldsValidation(t *testing.T) { + t.Parallel() + + t.Run("AllFieldsValidated", func(t *testing.T) { + type s struct { + Field string `validate:"min=3"` + } + var v s + fs, err := FieldsMissingValidation(v) + require.NoError(t, err, "err") + require.Equal(t, 0, len(fs), "num fields") + fs, err = FieldsWithValidation(v) + require.NoError(t, err, "err") + require.Equal(t, 1, len(fs), "num fields") + }) + + t.Run("Pointer", func(t *testing.T) { + type s struct { + Field string `validate:"min=3"` + } + var v s + fs, err := FieldsMissingValidation(&v) + require.NoError(t, err, "err") + require.Equal(t, 0, len(fs), "num fields") + fs, err = FieldsWithValidation(&v) + require.NoError(t, err, "err") + require.Equal(t, 1, len(fs), "num fields") + }) + + t.Run("MissingValidations", func(t *testing.T) { + type s struct { + Field1 string + Field2 string + } + var v s + fs, err := FieldsMissingValidation(v) + require.NoError(t, err, "err") + require.Equal(t, 2, len(fs), "num fields") + fs, err = FieldsWithValidation(v) + require.NoError(t, err, "err") + require.Equal(t, 0, len(fs), "num fields") + }) + + t.Run("UnexportedFields", func(t *testing.T) { + type s struct { + field string + } + var v = s{field: "string"} + fs, err := FieldsMissingValidation(v) + require.NoError(t, err, "err") + require.Equal(t, 0, len(fs), "num fields") + fs, err = FieldsWithValidation(v) + require.NoError(t, err, "err") + require.Equal(t, 0, len(fs), "num fields") + }) + + t.Run("Bools", func(t *testing.T) { + type s struct { + Field1 *bool + Field2 bool + } + var v s + fs, err := FieldsMissingValidation(v) + require.NoError(t, err, "err") + require.Equal(t, 0, len(fs), "num fields") + fs, err = FieldsWithValidation(v) + require.NoError(t, err, "err") + require.Equal(t, 0, len(fs), "num fields") + }) + + t.Run("Nested", func(t *testing.T) { + type nested struct { + Field string `validate:"min=3"` + } + type s struct { + Nested *nested `validate:"required"` + } + var v s + fs, err := FieldsMissingValidation(v) + require.NoError(t, err, "err") + require.Equal(t, 0, len(fs), "num fields") + fs, err = FieldsWithValidation(v) + require.NoError(t, err, "err") + require.Equal(t, 2, len(fs), "num fields") + }) + + t.Run("NestedUnexported", func(t *testing.T) { + // Specifically using time since it has known unexported fields, and is + // from a different package. + type s struct { + Time time.Time `validate:"required"` + } + var v s + fs, err := FieldsMissingValidation(v) + require.NoError(t, err, "err") + require.Equal(t, 0, len(fs), "num fields") + fs, err = FieldsWithValidation(v) + require.NoError(t, err, "err") + require.Equal(t, 1, len(fs), "num fields") + }) +} + +func TestValidateField(t *testing.T) { + t.Parallel() + + t.Run("NoMatch", func(t *testing.T) { + type s struct { + Field string `validate:"min=3"` + } + err := Field(s{}, JSONTagValueFieldSelector("hello")) + require.NoError(t, err, "validate") + }) + + t.Run("MatchValid", func(t *testing.T) { + type s struct { + Field string `json:"hello" validate:"min=3"` + } + err := Field(s{Field: "world"}, JSONTagValueFieldSelector("hello")) + require.NoError(t, err, "validate") + }) + + t.Run("MatchInvalid", func(t *testing.T) { + type s struct { + Field string `json:"hello" validate:"min=3"` + } + err := Field(s{Field: "hi"}, JSONTagValueFieldSelector("hello")) + require.Error(t, err, "validate") + }) +} + +func TestSelectFields(t *testing.T) { + t.Parallel() + + t.Run("Some", func(t *testing.T) { + type s struct { + Field1 string `bogus:"bogus"` + Field2 string + } + fs := TagKeyFieldSelector("bogus") + fields, err := SelectFields(s{}, fs, nil) + require.NoError(t, err, "select") + require.Equal(t, 1, len(fields), "num fields") + }) + + t.Run("Nested", func(t *testing.T) { + type nested struct { + Field1 string `bogus:"bogus"` + } + type s struct { + Field1 string `bogus:"bogus"` + Nested *nested + } + fs := TagKeyFieldSelector("bogus") + fields, err := SelectFields(s{}, fs, nil) + require.NoError(t, err, "select") + require.Equal(t, 2, len(fields), "num fields") + }) + + t.Run("Embedded", func(t *testing.T) { + type embedded struct { + Field1 string `bogus:"bogus"` + } + type s struct { + embedded + } + fs := TagKeyFieldSelector("bogus") + fields, err := SelectFields(s{}, fs, nil) + require.NoError(t, err, "select") + require.Equal(t, 1, len(fields), "num fields") + }) + + t.Run("InfiniteRecursion", func(t *testing.T) { + type s struct { + Field *s `bogus:"bogus"` + } + fs := TagKeyFieldSelector("bogus") + fields, err := SelectFields(s{}, fs, nil) + require.NoError(t, err, "select") + require.Equal(t, 1, len(fields), "num fields") + }) +} + +func TestValidateStruct(t *testing.T) { + t.Run("Anonymous", func(t *testing.T) { + // Test validate on anonymous fields + type s struct { + json.RawMessage `validate:"min=4"` + } + + // Not enough bytes + v := s{RawMessage: []byte("{}}")} + + err := Validator().Struct(v) + require.Error(t, err, "validate") + }) +} diff --git a/validate/user.go b/validate/user.go new file mode 100644 index 0000000000000..47560756089db --- /dev/null +++ b/validate/user.go @@ -0,0 +1,37 @@ +package validate + +import ( + "regexp" + + "golang.org/x/xerrors" +) + +const ( + // UsernameMaxLength is the maximum length a username can be. + UsernameMaxLength = 32 +) + +// Matches alphanumeric usernames with `-`, but not consecutively. +var usernameRx = regexp.MustCompile("^[a-zA-Z0-9]+(?:-[a-zA-Z0-9]+)*$") + +var ErrInvalidUsernameRegex = xerrors.Errorf("username must conform to regex %v", usernameRx.String()) +var ErrUsernameTooLong = xerrors.Errorf("usernames must be a maximum length of %d", UsernameMaxLength) + +// Username validates the string provided to be a valid Coder username. +// Coder usernames follow GitHub's username rules. Here are the rules: +// 1. Must be alphanumeric. +// 2. Minimum length of 1, maximum of 32. +// 3. Cannot start with a hyphen. +// 4. Cannot include consecutive hyphens. +func Username(s string) error { + if len(s) > UsernameMaxLength { + return ErrUsernameTooLong + } + if len(s) < 1 { + return ErrInvalidUsernameRegex + } + if !usernameRx.MatchString(s) { + return ErrInvalidUsernameRegex + } + return nil +} diff --git a/validate/user_test.go b/validate/user_test.go new file mode 100644 index 0000000000000..f94e2e4c7eb76 --- /dev/null +++ b/validate/user_test.go @@ -0,0 +1,65 @@ +package validate + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_Username(t *testing.T) { + t.Parallel() + testCases := []struct { + Input string + Err error + }{ + {"1", nil}, + {"12", nil}, + {"123", nil}, + {"12345678901234567890", nil}, + {"123456789012345678901", nil}, + {"a", nil}, + {"a1", nil}, + {"a1b2", nil}, + {"a1b2c3d4e5f6g7h8i9j0", nil}, + {"a1b2c3d4e5f6g7h8i9j0k", nil}, + {"aa", nil}, + {"abc", nil}, + {"abcdefghijklmnopqrst", nil}, + {"abcdefghijklmnopqrstu", nil}, + {"wow-test", nil}, + + {"", ErrInvalidUsernameRegex}, + {" ", ErrInvalidUsernameRegex}, + {" a", ErrInvalidUsernameRegex}, + {" a ", ErrInvalidUsernameRegex}, + {" 1", ErrInvalidUsernameRegex}, + {"1 ", ErrInvalidUsernameRegex}, + {" aa", ErrInvalidUsernameRegex}, + {"aa ", ErrInvalidUsernameRegex}, + {" 12", ErrInvalidUsernameRegex}, + {"12 ", ErrInvalidUsernameRegex}, + {" a1", ErrInvalidUsernameRegex}, + {"a1 ", ErrInvalidUsernameRegex}, + {" abcdefghijklmnopqrstu", ErrInvalidUsernameRegex}, + {"abcdefghijklmnopqrstu ", ErrInvalidUsernameRegex}, + {" 123456789012345678901", ErrInvalidUsernameRegex}, + {" a1b2c3d4e5f6g7h8i9j0k", ErrInvalidUsernameRegex}, + {"a1b2c3d4e5f6g7h8i9j0k ", ErrInvalidUsernameRegex}, + {"bananas_wow", ErrInvalidUsernameRegex}, + {"test--now", ErrInvalidUsernameRegex}, + + {"123456789012345678901234567890123", ErrUsernameTooLong}, + {"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", ErrUsernameTooLong}, + {"123456789012345678901234567890123123456789012345678901234567890123", ErrUsernameTooLong}, + } + for _, testCase := range testCases { + t.Run(testCase.Input, func(t *testing.T) { + if testCase.Err == nil { + require.NoError(t, Username(testCase.Input), fmt.Sprintf("username %q should be valid", testCase.Input)) + } else { + require.Equal(t, Username(testCase.Input), testCase.Err, fmt.Sprintf("username %q should not be valid", testCase.Input)) + } + }) + } +} diff --git a/validate/validator.go b/validate/validator.go new file mode 100644 index 0000000000000..df2034512237a --- /dev/null +++ b/validate/validator.go @@ -0,0 +1,46 @@ +package validate + +import ( + "github.com/go-playground/validator/v10" + "golang.org/x/xerrors" +) + +func init() { + validate = validator.New() + mustRegisterValidation(validate, "longid", validateLongID) +} + +func mustRegisterValidation(v *validator.Validate, tag string, fn validator.Func) { + if err := v.RegisterValidation(tag, fn); err != nil { + panic(xerrors.Errorf("register validation: %w", err)) + } +} + +// Global validation struct +// +// Custom validators should be added to this struct if needed (see +// https://github.com/go-playground/validator/blob/master/_examples/custom-validation/main.go +// for an example). +var validate *validator.Validate + +// Validator returns a copy of the global validator. +func Validator() *validator.Validate { + v := *validate + return &v +} + +// validateLongID validates that a field is a string, and that the string does +// not exceed the max length of a long ID. +// +// Additional formatting checks are omitted as there's a lot of tests that don't +// use actual long IDs when generating test requests, and the system admin's ID +// also does not follow the long ID format. +func validateLongID(fl validator.FieldLevel) bool { + f := fl.Field().Interface() + s, ok := f.(string) + if !ok { + return false + } + const longIDLen = 33 // The format string for a long id is "%08x-%024x" + return len(s) <= longIDLen +} diff --git a/validate/validator_test.go b/validate/validator_test.go new file mode 100644 index 0000000000000..7afb9245be09d --- /dev/null +++ b/validate/validator_test.go @@ -0,0 +1,46 @@ +package validate + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/longid" +) + +func TestValidateLongID(t *testing.T) { + t.Parallel() + + t.Run("OK", func(t *testing.T) { + type myStruct struct { + ID string `validate:"longid"` + } + v := &myStruct{ + ID: longid.New().String(), + } + err := Validator().Struct(v) + require.NoError(t, err, "validate") + }) + + t.Run("Invalid", func(t *testing.T) { + type myStruct struct { + ID string `validate:"longid"` + } + v := &myStruct{ + ID: longid.New().String() + "hello", + } + err := Validator().Struct(v) + require.Error(t, err, "unexpectedly validated") + }) + + t.Run("WrongType", func(t *testing.T) { + type myStruct struct { + ID int `validate:"longid"` + } + v := &myStruct{ + ID: 123, + } + err := Validator().Struct(v) + require.Error(t, err, "unexpectedly validated") + }) +} diff --git a/validate/vassert/assert.go b/validate/vassert/assert.go new file mode 100644 index 0000000000000..974d9b8c6c063 --- /dev/null +++ b/validate/vassert/assert.go @@ -0,0 +1,68 @@ +package vassert + +import ( + "testing" + + "github.com/coder/coder/validate" +) + +// Tags asserts that most fields on a struct with a "json" tag also have a +// "validate" tag, unless the json tag value is "-". +// +// This will recursively check nested structs. +// +// Boolean values and boolean pointers do not require a validate tag. +// +// `v` should be a struct. +func Tags(t *testing.T, v interface{}) { + t.Helper() + fields, err := validate.FieldsMissingValidation(v) + if err != nil { + t.Fatalf("failed to get missing field validations: %s", err) + } + if len(fields) > 0 { + names := make([]string, len(fields)) + for i, f := range fields { + names[i] = f.Name + } + t.Fatalf("the following fields are missing validations: %v", names) + } +} + +// FieldValid asserts that the field with the correspting `jsonField` +// value validates. +// +// `v` should be a struct. +func FieldValid(t *testing.T, v interface{}, jsonField string) { + t.Helper() + ensureHasJSONField(t, v, jsonField) + err := validate.Field(v, validate.JSONTagValueFieldSelector(jsonField)) + if err != nil { + t.Fatalf("expected field %q to validate: %v", jsonField, err) + } +} + +// FieldInvalid asserts that the field with the correspting `jsonField` +// value does not validate. +// +// `v` should be a struct. +func FieldInvalid(t *testing.T, v interface{}, jsonField string) { + t.Helper() + ensureHasJSONField(t, v, jsonField) + err := validate.Field(v, validate.JSONTagValueFieldSelector(jsonField)) + if err == nil { + t.Fatalf("expected field %q to be invalid", jsonField) + } +} + +func ensureHasJSONField(t *testing.T, v interface{}, jsonField string) { + t.Helper() + fs := validate.JSONTagValueFieldSelector(jsonField) + fields, err := validate.SelectFields(v, fs, nil) + if err != nil { + t.Fatalf("failed to select fields: %v", err) + } + if len(fields) == 0 { + t.Fatalf("%q matches no fields on the struct", jsonField) + } +} diff --git a/validate/vassert/doc.go b/validate/vassert/doc.go new file mode 100644 index 0000000000000..413342a68cc13 --- /dev/null +++ b/validate/vassert/doc.go @@ -0,0 +1,6 @@ +// Package vassert provides testing utilities for asserting validations for +// request bodies. +// +// All functions in this package assume the use of +// https://github.com/go-playground/validator. +package vassert diff --git a/xjson/duration.go b/xjson/duration.go new file mode 100644 index 0000000000000..32f296314d196 --- /dev/null +++ b/xjson/duration.go @@ -0,0 +1,32 @@ +package xjson + +import ( + "encoding/json" + "strconv" + "time" +) + +// Duration is a time.Duration that marshals to millisecond precision. +// Most javascript applications expect durations to be in milliseconds. +// Although this would typically be a time.Duration it was changed to +// an int64 to avoid errors in the swaggo/swag tool we use to auto -generate +// documentation. +type Duration int64 + +// MarshalJSON marshals the duration to millisecond precision. +func (d Duration) MarshalJSON() ([]byte, error) { + du := time.Duration(d) + return json.Marshal(du.Milliseconds()) +} + +// UnmarshalJSON unmarshals a millisecond-precision integer to +// a time.Duration. +func (d *Duration) UnmarshalJSON(b []byte) error { + i, err := strconv.ParseInt(string(b), 10, 64) + if err != nil { + return err + } + + *d = Duration(time.Duration(i) * time.Millisecond) + return nil +} diff --git a/xjson/duration_test.go b/xjson/duration_test.go new file mode 100644 index 0000000000000..e7aff2990eddf --- /dev/null +++ b/xjson/duration_test.go @@ -0,0 +1,22 @@ +package xjson + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestDuration(t *testing.T) { + t.Run("MarshalUnmarshalJSON", func(t *testing.T) { + var dur = Duration(time.Hour) + b, err := json.Marshal(dur) + require.NoError(t, err, "marshal duration") + + var unmarshalDur Duration + err = json.Unmarshal(b, &unmarshalDur) + require.NoError(t, err, "unmarshal duration") + require.Equal(t, dur, unmarshalDur, "Did not parse to milliseconds") + }) +} diff --git a/xjson/error.go b/xjson/error.go new file mode 100644 index 0000000000000..84711cb672b50 --- /dev/null +++ b/xjson/error.go @@ -0,0 +1,207 @@ +package xjson + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/coder/coder/srverr" +) + +const ( + // codeVerbose indicates a details object with a 'verbose' field + // exists in the error response. + codeVerbose srverr.Code = "verbose" + // codeEmpty indicates that no details object exists. + codeEmpty srverr.Code = "empty" + // codeSolution indicates the details field has a payload for the + // error and has a solution to resolve the error. + codeSolution srverr.Code = "solution" +) + +// WriteBadRequestWithCode writes a 400 to the response using a custom code, msg, and json marshaled details +func WriteBadRequestWithCode(w http.ResponseWriter, code srverr.Code, humanMsg string, details interface{}) { + Write(w, http.StatusBadRequest, errorResponse{ + Error: errorPayload{ + Msg: humanMsg, + Code: code, + Details: details, + }, + }) +} + +// WriteBadRequest writes a 400 to the response. +func WriteBadRequest(w http.ResponseWriter, humanMsg string) { + WriteError(w, http.StatusBadRequest, humanMsg, nil) +} + +// WriteUnauthorized writes a 401 to the response. +func WriteUnauthorized(w http.ResponseWriter, humanMsg string) { + WriteError(w, http.StatusUnauthorized, humanMsg, nil) +} + +// WriteForbidden writes a 403 to the response. +func WriteForbidden(w http.ResponseWriter, humanMsg string) { + WriteError(w, http.StatusForbidden, humanMsg, nil) +} + +// WriteConflict writes a 409 to the response. +func WriteConflict(w http.ResponseWriter, humanMsg string) { + WriteError(w, http.StatusConflict, humanMsg, nil) +} + +// WritePreconditionFailed writes a 412 to the response. If the err is non-nil +// a verbose field is written with the contents of the error. +func WritePreconditionFailed(w http.ResponseWriter, humanMsg string, err error) { + WriteError(w, http.StatusPreconditionFailed, humanMsg, err) +} + +func WriteErrorWithSolution(w http.ResponseWriter, statusCode int, humanMsg string, solution string, err error) { + Write(w, statusCode, errorResponse{ + Error: errorPayload{ + Msg: humanMsg, + Code: codeSolution, + Details: detailsPrecondition{ + Message: humanMsg, + Error: err.Error(), + Solution: solution, + Verbose: err.Error(), //nolint:deprecated + }, + }, + }) +} + +// WriteFieldedPreconditionFailed writes a 412 to the response and the +// proper json fielded payload for decoding the error + solution +func WriteFieldedPreconditionFailed(w http.ResponseWriter, humanMsg string, solution string, err error) { + WriteErrorWithSolution(w, http.StatusPreconditionFailed, humanMsg, solution, err) +} + +// WriteNotFound writes a 404 to the response. It returns a generic public +// message such as "Environment not found." using the provided resource. +func WriteNotFound(w http.ResponseWriter, resource string) { + WriteError(w, http.StatusNotFound, fmt.Sprintf("%s not found.", resource), nil) +} + +// WriteCustomNotFound writes a 400 to the response. +func WriteCustomNotFound(w http.ResponseWriter, humanMsg string) { + WriteError(w, http.StatusNotFound, humanMsg, nil) +} + +// WriteInternalServerError writes a 500 to the response. It uses a generic +// message as the public message and writes the error the 'verbose' field +// in 'details' if it is non-nil. +func WriteInternalServerError(w http.ResponseWriter, err error) { + WriteCustomInternalServerError(w, "An internal server error occurred.", err) +} + +// WriteCustomInternalServerError writes a 500 to the response. Instead of the +// generic "An internal server error" occurred, the provided humanMsg is used. +func WriteCustomInternalServerError(w http.ResponseWriter, humanMsg string, err error) { + WriteError(w, http.StatusInternalServerError, humanMsg, err) +} + +// WriteError is a generic endpoint for writing error responses. If err is non-nil +// a 'verbose' field is written to the 'details' object. +func WriteError(w http.ResponseWriter, status int, humanMsg string, err error) { + Write(w, status, defaultErrorParams{ + msg: humanMsg, + verbose: err, + }) +} + +// defaultErrorParams contains common parameters across most error responses. +// Since the nature of the error payload is nested this type exists to allow +// assigning the values to a friendly, flat type. +type defaultErrorParams struct { + msg string + verbose error +} + +// MarshalJSON marshals the default error parameters into our structured error +// response. +func (d defaultErrorParams) MarshalJSON() ([]byte, error) { + payload := errorResponse{ + Error: errorPayload{ + Msg: d.msg, + Code: codeEmpty, + }, + } + if d.verbose != nil { + payload.Error.Code = codeVerbose + payload.Error.Details = detailsVerbose{ + Verbose: d.verbose.Error(), + } + } + + return json.Marshal(payload) +} + +// detailsVerbose is a simple object that can be assigned to the 'details' +// field of an erro response. It contains a more verbose explanation of the +// error. It tends to be the raw output of err.Error(). +type detailsVerbose struct { + Verbose string `json:"verbose,omitempty"` +} + +// detailsPrecondition is a details object that should be paired with 412 status +// codes. It contains the Go error, a human message, and a solution note. +type detailsPrecondition struct { + // Error is err.Error() and from Go + Error string `json:"error"` + // Message is the human readable error message + Message string `json:"message"` + // Solution is a helpful hint on how to solve the error + Solution string `json:"solution"` + + // Verbose is a copy of Error. + // Deprecated: Should remove this field, but the ui expects 'verbose' messages + // still and have not been moved to use the new fields for this error type. + Verbose string `json:"verbose,omitempty"` +} + +// errorResponse is the root of the error payload we send for status codes 400 +// and above. +type errorResponse struct { + Error errorPayload `json:"error"` +} + +// errorPayload contains the contents of an error response. +type errorPayload struct { + // Msg is a human-readable message. + Msg string `json:"msg"` + // Code dictates the structure of the details field. + Code srverr.Code `json:"code"` + // Details is an arbitrary object containing extra information + // on a particular error. Its structure is dictated by Code. + Details interface{} `json:"details,omitempty"` +} + +// HTTPError represents an error from the Coder API. +type HTTPError struct { + *http.Response + // we can't read the body lazily when Error is invoked + // so this must be populated at construction + Body []byte +} + +var _ error = &HTTPError{} + +// Error implements error. +func (e *HTTPError) Error() string { + var msg errorResponse + // Try to decode the payload as an error, if it fails or if there is no error message, + // return the response URL with the status. + if err := json.Unmarshal(e.Body, &msg); err != nil || msg.Error.Msg == "" { + return fmt.Sprintf("%s: %d %s", e.Request.URL, e.StatusCode, e.Status) + } + + // If the payload was a in the expected error format with a message, include it. + return msg.Error.Msg +} + +func BodyError(resp *http.Response) *HTTPError { + body, _ := io.ReadAll(resp.Body) + return &HTTPError{Response: resp, Body: body} +} diff --git a/xjson/error_page.html b/xjson/error_page.html new file mode 100644 index 0000000000000..5a0dc52b3d638 --- /dev/null +++ b/xjson/error_page.html @@ -0,0 +1,117 @@ + + + + + + + + + + +
+ +
+
+ +

{{status .}}

+

{{.Msg}}

+ {{if .DevURL}} + Retry + {{end}} + + {{if .Err}} +

{{.Err}}

+ {{end}} Back to Site +
+
+
+
+ diff --git a/xjson/error_test.go b/xjson/error_test.go new file mode 100644 index 0000000000000..017dd60084655 --- /dev/null +++ b/xjson/error_test.go @@ -0,0 +1,238 @@ +package xjson + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/net/html" + "golang.org/x/xerrors" + + "github.com/coder/coder/srverr" +) + +func TestDefaultErrorParams(t *testing.T) { + t.Parallel() + + t.Run("VerboseDetails", func(t *testing.T) { + var ( + testMsg = "Testing." + testErr = xerrors.Errorf("testing verbose") + ) + p := defaultErrorParams{ + msg: testMsg, + verbose: testErr, + } + + actualResponse, err := p.MarshalJSON() + require.NoError(t, err, "marshal error params") + + expectedResponse, err := json.Marshal(errorResponse{ + Error: errorPayload{ + Msg: testMsg, + Code: codeVerbose, + Details: detailsVerbose{ + Verbose: testErr.Error(), + }, + }, + }) + require.NoError(t, err, "marshal expected response") + require.Equal(t, expectedResponse, actualResponse, "responses differ") + }) + + t.Run("EmptyDetails", func(t *testing.T) { + var ( + testMsg = "Testing." + ) + p := defaultErrorParams{ + msg: testMsg, + } + + actualResponse, err := p.MarshalJSON() + require.NoError(t, err, "marshal error params") + + expectedResponse, err := json.Marshal(errorResponse{ + Error: errorPayload{ + Msg: testMsg, + Code: codeEmpty, + }, + }) + require.NoError(t, err, "marshal expected response") + require.Equal(t, string(expectedResponse), string(actualResponse), "responses differ") + }) +} + +// Checking Xjson errors write the correct code +func Test_JsonErrors(t *testing.T) { + t.Parallel() + const testMessage = "test message" + + vs := []struct { + Name string + Write func(w http.ResponseWriter) + ExpectedStatusCode int + ErrorCode srverr.Code + RespContains string + }{ + { + Name: "CustomBadRequest", + Write: func(w http.ResponseWriter) { + WriteBadRequestWithCode(w, "test", testMessage, nil) + }, + ExpectedStatusCode: http.StatusBadRequest, + RespContains: testMessage, + ErrorCode: "test", + }, + { + Name: "BadRequest", + Write: func(w http.ResponseWriter) { + WriteBadRequest(w, testMessage) + }, + ExpectedStatusCode: http.StatusBadRequest, + RespContains: testMessage, + }, + { + Name: "Unauthorized", + Write: func(w http.ResponseWriter) { + WriteUnauthorized(w, testMessage) + }, + ExpectedStatusCode: http.StatusUnauthorized, + RespContains: testMessage, + }, + { + Name: "Forbidden", + Write: func(w http.ResponseWriter) { + WriteForbidden(w, testMessage) + }, + ExpectedStatusCode: http.StatusForbidden, + RespContains: testMessage, + }, + { + Name: "Conflict", + Write: func(w http.ResponseWriter) { + WriteConflict(w, testMessage) + }, + ExpectedStatusCode: http.StatusConflict, + RespContains: testMessage, + }, + { + Name: "PreconditionFailed", + Write: func(w http.ResponseWriter) { + WritePreconditionFailed(w, testMessage, xerrors.New("random")) + }, + ExpectedStatusCode: http.StatusPreconditionFailed, + RespContains: testMessage, + ErrorCode: codeVerbose, + }, + { + Name: "FieldedPreconditionFailed", + Write: func(w http.ResponseWriter) { + WriteFieldedPreconditionFailed(w, testMessage, "this is a solution", xerrors.New("random")) + }, + ExpectedStatusCode: http.StatusPreconditionFailed, + RespContains: testMessage, + ErrorCode: codeSolution, + }, + { + Name: "NotFound", + Write: func(w http.ResponseWriter) { + WriteNotFound(w, testMessage) + }, + ExpectedStatusCode: http.StatusNotFound, + RespContains: testMessage, + }, + { + Name: "CustomNotFound", + Write: func(w http.ResponseWriter) { + WriteCustomNotFound(w, testMessage) + }, + ExpectedStatusCode: http.StatusNotFound, + RespContains: testMessage, + }, + { + Name: "CustomServerError", + Write: func(w http.ResponseWriter) { + WriteCustomInternalServerError(w, testMessage, xerrors.New("random")) + }, + ExpectedStatusCode: http.StatusInternalServerError, + RespContains: testMessage, + ErrorCode: codeVerbose, + }, + { + Name: "ServerError", + Write: func(w http.ResponseWriter) { + WriteInternalServerError(w, xerrors.New(testMessage)) + }, + ExpectedStatusCode: http.StatusInternalServerError, + RespContains: "server error", + ErrorCode: codeVerbose, + }, + } + + for _, v := range vs { + if v.ErrorCode == "" { + v.ErrorCode = codeEmpty + } + t.Run(v.Name, func(t *testing.T) { + w := httptest.NewRecorder() + v.Write(w) + + resp := w.Result() + require.Equal(t, v.ExpectedStatusCode, resp.StatusCode, "BadRequest") + + respErr := BodyError(resp) + _ = resp.Body.Close() + + // Assure the body is a full json payload + var eResp errorResponse + err := json.Unmarshal(respErr.Body, &eResp) + + require.True(t, strings.Contains(respErr.Error(), v.RespContains), "contains") + + require.NoError(t, err, "body decode") + require.Equal(t, v.ErrorCode, eResp.Error.Code, "correct code") + }) + } +} + +// Test_EmptyHTTPError checks to ensure the error works on empty responses. +// Athought the response is not an error code, the BodyError should still wrap the response data +// corectly. +func Test_EmptyHTTPError(t *testing.T) { + t.Parallel() + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + t.Cleanup(srv.Close) + + r, err := http.NewRequestWithContext(context.Background(), "GET", srv.URL, nil) + require.NoError(t, err, "new request") + resp, err := http.DefaultClient.Do(r) + require.NoError(t, err, "GET request") + + e := BodyError(resp) + require.Contains(t, e.Error(), strconv.Itoa(resp.StatusCode), "has status code") +} + +// TestHTMLErrorPage tests that the embedded page is valid HTML. +func TestHTMLErrorPage(t *testing.T) { + t.Parallel() + + recorder := httptest.NewRecorder() + + WriteErrPage(recorder, ErrPage{ + DevURL: "https://*.master.cdr.dev", + AccessURL: "https://master.cdr.dev", + }) + + node, err := html.Parse(recorder.Body) + require.NoError(t, err, "HTML error page does not appear valid") + require.NotNil(t, node, "require the node to be non-nil") + require.Nil(t, node.Parent, "parent should be nil (root element)") + require.Equal(t, html.DocumentNode, node.Type, "node is document") +} diff --git a/xjson/json.go b/xjson/json.go new file mode 100644 index 0000000000000..e3505067a3cf9 --- /dev/null +++ b/xjson/json.go @@ -0,0 +1,372 @@ +package xjson + +import ( + "context" + _ "embed" + "encoding/json" + "fmt" + "html/template" + "io" + "net/http" + "reflect" + "runtime" + "strings" + + "github.com/asaskevich/govalidator" + "github.com/go-playground/validator/v10" + "golang.org/x/xerrors" + + "github.com/coder/coder/buildmode" + "github.com/coder/coder/validate" + + "cdr.dev/slog" +) + +// This contains the raw mark-up for our dynamic server-side error page. +// Reasons for using a string literal: +// +// 1. It's a small/simple amount of markup +// 2. We avoid possible file-path errors from files moving around +// 3. More performant than doing file i/o +// 4. Development will be easier because it becomes hot-swappable +//go:embed error_page.html +var errPageMarkup string + +// m is a a helper struct for marshaling arbitrary json. +type m map[string]interface{} + +// SuccessMsg contains a single 'msg' field +// with an arbitrary value indicating success. +type SuccessMsg struct { + Msg string `json:"msg"` +} // @name SuccessMsg + +// Write writes a json response. +func Write(w http.ResponseWriter, status int, body interface{}) { + if body == nil { + w.WriteHeader(status) + return + } + + if strBody, ok := body.(string); ok { + body = SuccessMsg{ + Msg: strBody, + } + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + + err := encodeBody(w, body) + if err != nil { + // We can't write to hijacked connections. Don't panic in that + // case. + if xerrors.Is(err, http.ErrHijacked) { + return + } + + panic(err) + } +} + +func encodeBody(w io.Writer, body interface{}) error { + enc := json.NewEncoder(w) + // Format the response nicely. + enc.SetIndent("", "\t") + enc.SetEscapeHTML(false) + return enc.Encode(body) +} + +// ErrPage is used for writing error templates +// to the response writer using WriteErrPage. +type ErrPage struct { + DevURL string + AccessURL string + Msg string + Code int + Err error +} + +// WriteErrPage writes error templates to w after dynamically constructing it +// based on the contents of p. +// +// If p.Code == 0: +// The status code will default to http.StatusInternalServerError. +// +// If p.Msg == "": +// The error message will default to the status text of the status code. + +// If p.Err == nil: +// The error will not render. It is optional to provide a value for p.Err since +// p.Msg is rendered as the public-facing error that the user will see. p.Err +// can be used for development debugging purposes. +// +// If p.AccessURL == "": +// The Back to Site button on the page won't work. AccessURL should have it's +// value set to a database.ConfigGeneral.AccessURL.URL. +// +// If p.DevURL == "": +// The retry button linking back to the dev url will not appear on the rendered page. +func WriteErrPage(w http.ResponseWriter, p ErrPage) { + if p.Code == 0 { + p.Code = http.StatusInternalServerError + } + + if p.Msg == "" { + p.Msg = http.StatusText(p.Code) + } + + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(p.Code) + + t, err := template.New("").Funcs( + template.FuncMap{ + "status": func(p ErrPage) string { + return fmt.Sprintf("%d - %s", p.Code, http.StatusText(p.Code)) + }, + }, + ).Parse(errPageMarkup) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Load the contents of p into the template then write the template to w. + if err := t.ExecuteTemplate(w, "", p); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } +} + +// ErrUnauthorized is an error returned when a user tries to access a resource without +// having sufficient permissions. +var ErrUnauthorized = xerrors.New("Insufficient permissions to access resource") + +// WriteUnauthorizedError writes out an error formatted in JSON +// about the user not having sufficient permissions to access a resource. +func WriteUnauthorizedError(w http.ResponseWriter) { + WriteUnauthorized(w, ErrUnauthorized.Error()) +} + +// DatabaseError writes 500 with a database error message to the response writer, +// and logs details about the error. +func DatabaseError(ctx context.Context, log slog.Logger, w http.ResponseWriter, err error) { + // To maintain backwards-compatible behavior we do not write the database error to + // the details. + WriteCustomInternalServerError(w, "A database error occurred.", nil) + slog.Helper() + log.Error(ctx, "A database error occurred.", slog.Error(err)) +} + +// ServerError writes a 500 with the message to the response writer, and logs details +// about the error. +func ServerError(ctx context.Context, log slog.Logger, w http.ResponseWriter, err error, msg string) { + WriteInternalServerError(w, nil) + slog.Helper() + log.Error(ctx, "server error", + slog.F("msg", msg), + slog.Error(err), + ) +} + +// validatorErrorMessage constructs a human readable message from a validation error. +func validatorErrorMessage(err govalidator.Error) string { + switch { + case err.Validator == "required": + return fmt.Sprintf("Field %q is required.", err.Name) + default: + return fmt.Sprintf("Field %q is invalid (%v).", err.Name, err.Err.Error()) + } +} + +// convertValidationErrors converts govalidator errors into structured JSON. +// Each entry of the returned []m has at least `msg` set. +func convertValidationErrors(errs govalidator.Errors) []m { + var r []m + + for _, err := range errs { + switch e := err.(type) { // nolint: errorlint + case govalidator.Errors: + // For some reason govalidator nests another Errors sometimes. + // Let's just flatten and append it. + r = append(r, convertValidationErrors(e)...) + case govalidator.Error: + r = append(r, m{ + "msg": validatorErrorMessage(e), + "field": e.Name, + "error": e.Err.Error(), + "validator": e.Validator, + }) + default: + r = append(r, m{ + "msg": err, + // type is provided to aid in debugging. It offers no contract. + "type": reflect.TypeOf(err).String(), + }) + } + } + + return r +} + +// ReadBody reads a json object from the request body. If the read fails, a 400 +// is sent back to the client, and this will return false. +// +// To ensure proper validation during development, this function will fatal if +// the current build mode is "dev", if there's at least one field with the +// "validate" tag, and there's additional fields on the struct that are not +// validated (in accordance to `validate.FieldsMissingValidation`). +func ReadBody(log slog.Logger, w http.ResponseWriter, r *http.Request, v interface{}) bool { + err := json.NewDecoder(r.Body).Decode(v) + if err != nil { + log.Warn(r.Context(), "failed to read body", slog.Error(err)) + WriteError(w, http.StatusBadRequest, "Failed to read body.", err) + return false + } + + if buildmode.Dev() { + mustConsistentlyValidate(r.Context(), log, v) + } + + return Validate(w, v) +} + +func mustConsistentlyValidate(ctx context.Context, log slog.Logger, v interface{}) { + // Only make the logger if we need it. + logger := func() slog.Logger { + // Add caller and object type to know where to find struct that is failing. + _, file, line, _ := runtime.Caller(3) + return log.With( + slog.F("v", v), + slog.F("type", reflect.TypeOf(v).String()), + slog.F("caller", fmt.Sprintf("%s:%d", file, line)), + ) + } + + // Get explicitly tagged fields. + explicit, err := validate.FieldsWithValidation(v) + // Errors only if `v` isn't a struct. + if err != nil { + logger().Debug(ctx, "failed to check for fields with validation", slog.Error(err)) + return + } + if len(explicit) > 0 { + notValidated, err := validate.FieldsMissingValidation(v) + if err != nil { + logger().Debug(ctx, "failed to check for fields missing validation", slog.Error(err)) + return + } + if len(notValidated) > 0 { + logger().Fatal(ctx, "some fields missing validation", + slog.F("explicitly_validated", explicit), + slog.F("not_validated", notValidated), + ) + } + } +} + +func summarizeValidationErrors(subErrors []m) string { + var sb strings.Builder + _, _ = fmt.Fprint(&sb, "Input validation failed.") + for _, v := range subErrors { + _, _ = fmt.Fprint(&sb, "\n", "• ", v["msg"]) + } + return sb.String() +} + +// summarizeFieldErrors formats an error string suitable for displaying directly +// to the user from field validation errors. +func summarizeFieldErrors(errs []validator.FieldError) string { + var sb strings.Builder + for i, err := range errs { + // Error is decent enough. Will produce a string in the form of: + // "Key: '%s' Error:Field validation for '%s' failed on the '%s' tag" + _, _ = sb.WriteString(err.Error()) + if i != len(errs)-1 { + _, _ = sb.WriteString(", ") + } + } + + return sb.String() +} + +// convertFieldErrors converts field errors into structured JSON. Each entry of +// the returned []m has at least `msg` set. +func convertFieldErrors(errs []validator.FieldError) []m { + ms := make([]m, len(errs)) + for i, err := range errs { + ms[i] = m{ + "msg": err.Error(), + } + } + + return ms +} + +// Validate will call `Check` on the provided value, and write an appropriate +// response if validation fails. +func Validate(w http.ResponseWriter, v interface{}) bool { + if err := Check(v); err != nil { + WriteError(w, http.StatusBadRequest, "Request failed to validate.", err) + return false + } + return true +} + +type checkError struct { + Message string `json:"msg"` + Errors []m `json:"validation_errors,omitempty"` +} + +func (e *checkError) Error() string { + return fmt.Sprintf("%s: [%s]", e.Message, summarizeValidationErrors(e.Errors)) +} + +// Check runs validation on fields of a struct. If the type passed in is not a +// struct, no validation will be done. +// +// If a struct field has a "valid" tag, asaskevich/govalidator will be used for +// validation. If a struct field has a "validate" tag, go-playground/validator +// will be used for validation. +func Check(v interface{}) error { + // govalidator returns an error if the type isn't struct or *struct. + rv := reflect.ValueOf(v) + for rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface { + rv = rv.Elem() + } + if rv.Kind() != reflect.Struct { + return nil + } + + // Validate using go-playground/validator first. + if err := validate.Validator().Struct(v); err != nil { + var vErrs validator.ValidationErrors + if xerrors.As(err, &vErrs) { + return &checkError{ + Message: summarizeFieldErrors(vErrs), + Errors: convertFieldErrors(vErrs), + } + } + return &checkError{Message: fmt.Sprintf("input validation: %s", err.Error())} + } + + // Validate using asaskevich/govalidator after the above. Eventually this + // should be removed when all validation is switched over. + if ok, err := govalidator.ValidateStruct(v); err != nil { + var gve govalidator.Errors + if xerrors.As(err, &gve) { + verrs := convertValidationErrors(gve) + + return &checkError{ + Message: summarizeValidationErrors(verrs), + Errors: verrs, + } + } + + return &checkError{Message: fmt.Sprintf("input validation: %s", err.Error())} + } else if !ok { + return &checkError{Message: "input validation failed"} + } + + return nil +} diff --git a/xjson/json_test.go b/xjson/json_test.go new file mode 100644 index 0000000000000..3550881931067 --- /dev/null +++ b/xjson/json_test.go @@ -0,0 +1,241 @@ +package xjson + +import ( + "bytes" + "fmt" + "html/template" + "io" + "net/http" + "net/http/httptest" + "reflect" + "testing" + + "github.com/asaskevich/govalidator" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" +) + +func Test_convertValidationErrors(t *testing.T) { + type args struct { + errs govalidator.Errors + } + tests := []struct { + name string + args args + want []m + }{ + {"none", args{govalidator.Errors{}}, nil}, + // TODO (AB): Add more tests. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := convertValidationErrors(tt.args.errs); !reflect.DeepEqual(got, tt.want) { + t.Errorf("convertValidationErrors() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWriteErrPage(t *testing.T) { + for _, test := range []struct { + name string + p ErrPage + }{ + { + name: "OK", + p: ErrPage{ + DevURL: "ok-expected-dev-url", + AccessURL: "ok-expected-access-url", + Msg: "ok-expected-msg", + Code: http.StatusBadGateway, + Err: xerrors.New("ok-expected-err"), + }, + }, + { + name: "Code Unset", + p: ErrPage{ + DevURL: "code-unset-expected-dev-url", + AccessURL: "code-unset-expected-access-url", + Msg: "code-unset-expected-msg", + Code: 0, + Err: xerrors.New("code-unset-expected-err"), + }, + }, + { + name: "Msg Unset", + p: ErrPage{ + DevURL: "msg-unset-expected-dev-url", + AccessURL: "msg-unset-expected-access-url", + Msg: "", + Code: http.StatusInternalServerError, + Err: xerrors.New("msg-unset-expected-err"), + }, + }, + { + name: "AccessURL Unset", + p: ErrPage{ + DevURL: "access-url-unset-expected-dev-url", + AccessURL: "", + Msg: "access-url-unset-expected-msg", + Code: http.StatusInternalServerError, + Err: xerrors.New("access-url-unset-expected-err"), + }, + }, + { + name: "DevURL Unset", + p: ErrPage{ + DevURL: "", + AccessURL: "dev-url-unset-expected-access-url", + Msg: "dev-url-unset-expected-msg", + Code: http.StatusInternalServerError, + Err: xerrors.New("dev-url-unset-expected-err"), + }, + }, + { + name: "Nil Err", + p: ErrPage{ + DevURL: "nil-err-expected-dev-url", + AccessURL: "nil-err-expected-access-url", + Msg: "nil-err-expected-msg", + Code: http.StatusInternalServerError, + Err: nil, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + handler := http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + WriteErrPage(w, test.p) + }, + ) + + s := httptest.NewServer(handler) + defer s.Close() + + req := httptest.NewRequest(http.MethodGet, s.URL, nil) + respRecorder := httptest.NewRecorder() + handler(respRecorder, req) + + resp := respRecorder.Result() + require.NotNil(t, resp, "expected non-nil response") + defer resp.Body.Close() + require.Equal(t, http.StatusInternalServerError, resp.StatusCode, "status codes differ") + + got, err := io.ReadAll(resp.Body) + require.Equal(t, nil, err, "read body") + + switch test.name { + case "OK": + require.Equal(t, expectedOKErrPage(t), got, "OK response data does not match") + case "Code Unset": + require.Equal(t, expectedCodeUnsetErrPage(t), got, "code-unset response data does not match") + case "Msg Unset": + require.Equal(t, expectedMsgUnsetErrPage(t), got, "msg-unset response data does not match") + case "DevURL Unset": + require.Equal(t, expectedDevURLUnsetErrPage(t), got, "dev-url-unset response data does not match") + case "AccessURL Unset": + require.Equal(t, expectedAccessURLUnsetErrPage(t), got, "access-url-unset response data does not match") + case "Nil Err": + require.Equal(t, expectedNilErrPage(t), got, "nil-err response data does not match") + default: + t.Fail() + } + }, + ) + } +} + +func toTemplateData(t *testing.T, ep ErrPage) []byte { + view, err := template.New("").Funcs( + template.FuncMap{ + "status": func(p ErrPage) string { + return fmt.Sprintf("%d - %s", p.Code, http.StatusText(p.Code)) + }, + }, + ).Parse(errPageMarkup) + + require.NoError(t, err) + b := bytes.NewBuffer(nil) + // We can use a bytes.Buffer as replacement for the http.ResponseWriter because + // it implements the io.Writer interface. + // Write the ErrPage to the template then writes the template to the buffer. + require.NoError(t, view.ExecuteTemplate(b, "", ep)) + return b.Bytes() +} + +func expectedOKErrPage(t *testing.T) []byte { + // when this is turned into template data, + // the retry button should be rendered by the handler + // and match accordingly. + return toTemplateData(t, ErrPage{ + DevURL: "code-unset-expected-dev-url", + AccessURL: "code-unset-expected-access-url", + Msg: "code-unset-expected-msg", + Code: http.StatusBadGateway, + Err: xerrors.New("code-unset-expected-err"), + }) +} + +func expectedCodeUnsetErrPage(t *testing.T) []byte { + return toTemplateData(t, ErrPage{ + DevURL: "code-unset-expected-dev-url", + AccessURL: "code-unset-expected-access-url", + Msg: "code-unset-expected-msg", + // If Code is unset, it should default to a 500. + Code: http.StatusInternalServerError, + Err: xerrors.New("code-unset-expected-err"), + }) +} + +func expectedMsgUnsetErrPage(t *testing.T) []byte { + return toTemplateData(t, ErrPage{ + DevURL: "msg-unset-expected-dev-url", + AccessURL: "msg-unset-expected-access-url", + // If Msg was unset, it should default to the status text + // of its status code.. + Msg: http.StatusText(http.StatusInternalServerError), + Code: http.StatusInternalServerError, + Err: xerrors.New("msg-unset-expected-err"), + }) +} + +func expectedAccessURLUnsetErrPage(t *testing.T) []byte { + return toTemplateData(t, ErrPage{ + DevURL: "access-url-unset-expected-dev-url", + // AccessURL's do not get auto-corrected. + // 'Back to Site' button will not work when + // writing an ErrPage with an unset AccessURL + // to a template. + AccessURL: "", + Msg: "access-url-unset-expected-msg", + Code: http.StatusInternalServerError, + Err: xerrors.New("access-url-unset-expected-err"), + }) +} + +func expectedNilErrPage(t *testing.T) []byte { + return toTemplateData(t, ErrPage{ + DevURL: "nil-err-expected-dev-url", + AccessURL: "nil-err-expected-access-url", + Msg: "nil-err-expected-msg", + Code: http.StatusInternalServerError, + Err: nil, + }) +} + +func expectedDevURLUnsetErrPage(t *testing.T) []byte { + return toTemplateData(t, ErrPage{ + // The DevURL remains unadjusted because no logic + // will adjust it. When the err page is turned into + // template data the retry button won't be present. + // The test server handler should do the same and provide + // the expected result. + DevURL: "", + AccessURL: "nil-err-expected-access-url", + Msg: "nil-err-expected-msg", + Code: http.StatusInternalServerError, + Err: nil, + }) +}