diff --git a/coderd/coderd.go b/coderd/coderd.go index 82486b98722ef..bb800ee954805 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -126,7 +126,7 @@ func New(options *Options) http.Handler { }) }) }) - r.NotFound(site.Handler().ServeHTTP) + r.NotFound(site.Handler(options.Logger).ServeHTTP) return r } diff --git a/go.mod b/go.mod index d026e00efbe84..14386bcf6927b 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/pion/logging v0.2.2 github.com/pion/transport v0.13.0 github.com/pion/webrtc/v3 v3.1.23 + github.com/psanford/memfs v0.0.0-20210214183328-a001468d78ef github.com/quasilyte/go-ruleguard/dsl v0.3.16 github.com/spf13/cobra v1.3.0 github.com/stretchr/testify v1.7.0 diff --git a/go.sum b/go.sum index 0e9fb0192b6ed..8bf39c8e1bb37 100644 --- a/go.sum +++ b/go.sum @@ -1088,6 +1088,8 @@ github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O github.com/prometheus/procfs v0.2.0/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= +github.com/psanford/memfs v0.0.0-20210214183328-a001468d78ef h1:NKxTG6GVGbfMXc2mIk+KphcH6hagbVXhcFkbTgYleTI= +github.com/psanford/memfs v0.0.0-20210214183328-a001468d78ef/go.mod h1:tcaRap0jS3eifrEEllL6ZMd9dg8IlDpi2S1oARrQ+NI= github.com/quasilyte/go-ruleguard/dsl v0.3.16 h1:yJtIpd4oyNS+/c/gKqxNwoGO9+lPOsy1A4BzKjJRcrI= github.com/quasilyte/go-ruleguard/dsl v0.3.16/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU= github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= diff --git a/site/embed.go b/site/embed.go index 69326b22403b7..eccbb26009d53 100644 --- a/site/embed.go +++ b/site/embed.go @@ -1,21 +1,18 @@ package site import ( - "bytes" "embed" "fmt" - "io" "io/fs" "net/http" - "path" - "path/filepath" "strings" - "text/template" // html/template escapes some nonces - "time" "github.com/justinas/nosurf" "github.com/unrolled/secure" - "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/coder/coder/site/nextrouter" ) // The `embed` package ignores recursively including directories @@ -27,53 +24,33 @@ import ( var site embed.FS // Handler returns an HTTP handler for serving the static site. -func Handler() http.Handler { +func Handler(logger slog.Logger) http.Handler { filesystem, err := fs.Sub(site, "out") if err != nil { // This can't happen... Go would throw a compilation error. panic(err) } - // html files are handled by a text/template. Non-html files - // are served by the default file server. - files, err := htmlFiles(filesystem) - if err != nil { - panic(xerrors.Errorf("Failed to return handler for static files. Html files failed to load: %w", err)) + // Render CSP and CSRF in the served pages + templateFunc := func(r *http.Request) interface{} { + return htmlState{ + // Nonce is the CSP nonce for the given request (if there is one present) + CSP: cspState{Nonce: secure.CSPNonce(r.Context())}, + // Token is the CSRF token for the given request + CSRF: csrfState{Token: nosurf.Token(r)}, + } } - return secureHeaders(&handler{ - fs: filesystem, - htmlFiles: files, - h: http.FileServer(http.FS(filesystem)), // All other non-html static files + nextRouterHandler, err := nextrouter.Handler(filesystem, &nextrouter.Options{ + Logger: logger, + TemplateDataFunc: templateFunc, }) -} - -type handler struct { - fs fs.FS - // htmlFiles is the text/template for all *.html files. - // This is needed to support Content Security Policy headers. - // Due to material UI, we are forced to use a nonce to allow inline - // scripts, and that nonce is passed through a template. - // We only do this for html files to reduce the amount of in memory caching - // of duplicate files as `fs`. - htmlFiles *htmlTemplates - h http.Handler -} - -// filePath returns the filepath of the requested file. -func (*handler) filePath(p string) string { - if !strings.HasPrefix(p, "/") { - p = "/" + p - } - return strings.TrimPrefix(path.Clean(p), "/") -} - -func (h *handler) exists(filePath string) bool { - f, err := h.fs.Open(filePath) - if err == nil { - _ = f.Close() + if err != nil { + // There was an error setting up our file system handler. + // This likely means a problem with our embedded file system. + panic(err) } - return err == nil + return secureHeaders(nextRouterHandler) } type htmlState struct { @@ -89,80 +66,6 @@ type csrfState struct { Token string } -func (h *handler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - // reqFile is the static file requested - reqFile := h.filePath(r.URL.Path) - state := htmlState{ - // Nonce is the CSP nonce for the given request (if there is one present) - CSP: cspState{Nonce: secure.CSPNonce(r.Context())}, - // Token is the CSRF token for the given request - CSRF: csrfState{Token: nosurf.Token(r)}, - } - - // First check if it's a file we have in our templates - if h.serveHTML(rw, r, reqFile, state) { - return - } - - // If the original file path exists we serve it. - if h.exists(reqFile) { - h.h.ServeHTTP(rw, r) - return - } - - // Serve the file assuming it's an html file - // This matches paths like `/app/terminal.html` - r.URL.Path = strings.TrimSuffix(r.URL.Path, "/") - r.URL.Path += ".html" - - reqFile = h.filePath(r.URL.Path) - // All html files should be served by the htmlFile templates - if h.serveHTML(rw, r, reqFile, state) { - return - } - - // If we don't have the file... we should redirect to `/` - // for our single-page-app. - r.URL.Path = "/" - if h.serveHTML(rw, r, "", state) { - return - } - - // This will send a correct 404 - h.h.ServeHTTP(rw, r) -} - -func (h *handler) serveHTML(rw http.ResponseWriter, r *http.Request, reqPath string, state htmlState) bool { - if data, err := h.htmlFiles.renderWithState(reqPath, state); err == nil { - if reqPath == "" { - // Pass "index.html" to the ServeContent so the ServeContent sets the right content headers. - reqPath = "index.html" - } - http.ServeContent(rw, r, reqPath, time.Time{}, bytes.NewReader(data)) - return true - } - return false -} - -type htmlTemplates struct { - tpls *template.Template -} - -// renderWithState will render the file using the given nonce if the file exists -// as a template. If it does not, it will return an error. -func (t *htmlTemplates) renderWithState(filePath string, state htmlState) ([]byte, error) { - var buf bytes.Buffer - if filePath == "" { - filePath = "index.html" - } - err := t.tpls.ExecuteTemplate(&buf, filePath, state) - if err != nil { - return nil, err - } - - return buf.Bytes(), nil -} - // cspDirectives is a map of all csp fetch directives to their values. // Each directive is a set of values that is joined by a space (' '). // All directives are semi-colon separated as a single string for the csp header. @@ -264,52 +167,3 @@ func secureHeaders(next http.Handler) http.Handler { ReferrerPolicy: "no-referrer", }).Handler(next) } - -// htmlFiles recursively walks the file system passed finding all *.html files. -// The template returned has all html files parsed. -func htmlFiles(files fs.FS) (*htmlTemplates, error) { - // root is the collection of html templates. All templates are named by their pathing. - // So './404.html' is named '404.html'. './subdir/index.html' is 'subdir/index.html' - root := template.New("") - - rootPath := "." - err := fs.WalkDir(files, rootPath, func(path string, dirEntry fs.DirEntry, err error) error { - if err != nil { - return err - } - - if dirEntry.IsDir() { - return nil - } - - if filepath.Ext(dirEntry.Name()) != ".html" { - return nil - } - - file, err := files.Open(path) - if err != nil { - return err - } - - data, err := io.ReadAll(file) - if err != nil { - return err - } - - tPath := strings.TrimPrefix(path, rootPath+string(filepath.Separator)) - _, err = root.New(tPath).Parse(string(data)) - if err != nil { - return err - } - - return nil - }) - - if err != nil { - return nil, err - } - - return &htmlTemplates{ - tpls: root, - }, nil -} diff --git a/site/embed_test.go b/site/embed_test.go index 4e43f2a56bb37..9378c9cdda0bf 100644 --- a/site/embed_test.go +++ b/site/embed_test.go @@ -9,13 +9,15 @@ import ( "github.com/stretchr/testify/require" + "cdr.dev/slog" + "github.com/coder/coder/site" ) func TestIndexPageRenders(t *testing.T) { t.Parallel() - srv := httptest.NewServer(site.Handler()) + srv := httptest.NewServer(site.Handler(slog.Logger{})) req, err := http.NewRequestWithContext(context.Background(), "GET", srv.URL, nil) require.NoError(t, err) diff --git a/site/nextrouter/nextrouter.go b/site/nextrouter/nextrouter.go new file mode 100644 index 0000000000000..adf9a00cd23d4 --- /dev/null +++ b/site/nextrouter/nextrouter.go @@ -0,0 +1,238 @@ +package nextrouter + +import ( + "bytes" + "context" + "html/template" + "io/fs" + "net/http" + "path/filepath" + "strings" + "time" + + "github.com/go-chi/chi/v5" + + "cdr.dev/slog" +) + +// Options for configuring a nextrouter +type Options struct { + Logger slog.Logger + TemplateDataFunc HTMLTemplateHandler +} + +// HTMLTemplateHandler is a function that lets the consumer of `nextrouter` +// inject arbitrary template parameters, based on the request. This is useful +// if the Request object is carrying CSRF tokens, session tokens, etc - +// they can be emitted in the page. +type HTMLTemplateHandler func(*http.Request) interface{} + +// Handler returns an HTTP handler for serving a next-based static site +// This handler respects NextJS-based routing rules: +// https://nextjs.org/docs/routing/dynamic-routes +// +// 1) If a file is of the form `[org]`, it's a dynamic route for a single-parameter +// 2) If a file is of the form `[[...any]]`, it's a dynamic route for any parameters +func Handler(fileSystem fs.FS, options *Options) (http.Handler, error) { + if options == nil { + options = &Options{ + Logger: slog.Logger{}, + TemplateDataFunc: nil, + } + } + router := chi.NewRouter() + + // Build up a router that matches NextJS routing rules, for HTML files + err := registerRoutes(router, fileSystem, *options) + if err != nil { + return nil, err + } + + // Fallback to static file server for non-HTML files + // Non-HTML files don't have special routing rules, so we can just leverage + // the built-in http.FileServer for it. + fileHandler := http.FileServer(http.FS(fileSystem)) + router.NotFound(fileHandler.ServeHTTP) + + // Finally, if there is a 404.html available, serve that + err = register404(fileSystem, router, *options) + if err != nil { + // An error may be expected if a 404.html is not present + options.Logger.Warn(context.Background(), "Unable to find 404.html", slog.Error(err)) + } + + return router, nil +} + +// registerRoutes recursively traverses the file-system, building routes +// as appropriate for respecting NextJS dynamic rules. +func registerRoutes(rtr chi.Router, fileSystem fs.FS, options Options) error { + files, err := fs.ReadDir(fileSystem, ".") + if err != nil { + return err + } + + // Loop through everything in the current directory... + for _, file := range files { + name := file.Name() + + // If we're working with a file - just serve it up + if !file.IsDir() { + serveFile(rtr, fileSystem, name, options) + continue + } + + // ...otherwise, if it's a directory, create a sub-route by + // recursively calling `buildRouter` + sub, err := fs.Sub(fileSystem, name) + if err != nil { + return err + } + + // In the special case where the folder is dynamic, + // like `[org]`, we can convert to a chi-style dynamic route + // (which uses `{` instead of `[`) + routeName := name + if isDynamicRoute(name) { + routeName = "{dynamic}" + } + + options.Logger.Debug(context.Background(), "Registering route", slog.F("name", name), slog.F("routeName", routeName)) + rtr.Route("/"+routeName, func(r chi.Router) { + err := registerRoutes(r, sub, options) + if err != nil { + options.Logger.Error(context.Background(), "Error registering route", slog.F("name", routeName), slog.Error(err)) + } + }) + } + + return nil +} + +// serveFile is responsible for serving up HTML files in our next router +// It handles various special cases, like trailing-slashes or handling routes w/o the .html suffix. +func serveFile(router chi.Router, fileSystem fs.FS, fileName string, options Options) { + // We only handle .html files for now + ext := filepath.Ext(fileName) + if ext != ".html" { + return + } + + options.Logger.Debug(context.Background(), "Reading file", slog.F("fileName", fileName)) + + data, err := fs.ReadFile(fileSystem, fileName) + if err != nil { + options.Logger.Error(context.Background(), "Unable to read file", slog.F("fileName", fileName)) + return + } + + // Create a template from the data - we can inject custom parameters like CSRF here + tpls, err := template.New(fileName).Parse(string(data)) + if err != nil { + options.Logger.Error(context.Background(), "Unable to create template for file", slog.F("fileName", fileName)) + return + } + + handler := func(writer http.ResponseWriter, request *http.Request) { + var buf bytes.Buffer + + // See if there are any template parameters we need to inject! + // Things like CSRF tokens, etc... + //templateData := struct{}{} + var templateData interface{} + templateData = nil + if options.TemplateDataFunc != nil { + templateData = options.TemplateDataFunc(request) + } + + options.Logger.Debug(context.Background(), "Applying template parameters", slog.F("fileName", fileName), slog.F("templateData", templateData)) + err := tpls.ExecuteTemplate(&buf, fileName, templateData) + + if err != nil { + options.Logger.Error(request.Context(), "Error executing template", slog.F("template_parameters", templateData)) + http.Error(writer, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + http.ServeContent(writer, request, fileName, time.Time{}, bytes.NewReader(buf.Bytes())) + } + + fileNameWithoutExtension := removeFileExtension(fileName) + + // Handle the `[[...any]]` catch-all case + if isCatchAllRoute(fileNameWithoutExtension) { + options.Logger.Info(context.Background(), "Registering catch-all route", slog.F("fileName", fileName)) + router.NotFound(handler) + return + } + + // Handle the `[org]` dynamic route case + if isDynamicRoute(fileNameWithoutExtension) { + options.Logger.Debug(context.Background(), "Registering dynamic route", slog.F("fileName", fileName)) + router.Get("/{dynamic}", handler) + return + } + + // Handle the basic file cases + // Directly accessing a file, ie `/providers.html` + router.Get("/"+fileName, handler) + // Accessing a file without an extension, ie `/providers` + router.Get("/"+fileNameWithoutExtension, handler) + + // Special case: '/' should serve index.html + if fileName == "index.html" { + router.Get("/", handler) + return + } + + // Otherwise, handling the trailing slash case - + // for examples, `providers.html` should serve `/providers/` + router.Get("/"+fileNameWithoutExtension+"/", handler) +} + +func register404(fileSystem fs.FS, router chi.Router, options Options) error { + // Get the file contents + fileBytes, err := fs.ReadFile(fileSystem, "404.html") + if err != nil { + // An error is expected if the file doesn't exist + return err + } + + router.NotFound(func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusNotFound) + _, err = writer.Write(fileBytes) + if err != nil { + options.Logger.Error(request.Context(), "Unable to write bytes for 404") + return + } + }) + + return nil +} + +// isDynamicRoute returns true if the file is a NextJS dynamic route, like `[orgs]` +// Returns false if the file is not a dynamic route, or if it is a catch-all route (`[[...any]]`) +// NOTE: The extension should be removed from the file name +func isDynamicRoute(fileWithoutExtension string) bool { + // Assuming ASCII encoding - `len` in go works on bytes + byteLen := len(fileWithoutExtension) + if byteLen < 2 { + return false + } + + return fileWithoutExtension[0] == '[' && fileWithoutExtension[1] != '[' && fileWithoutExtension[byteLen-1] == ']' +} + +// isCatchAllRoute returns true if the file is a catch-all route, like `[[...any]]` +// Return false otherwise +// NOTE: The extension should be removed from the file name +func isCatchAllRoute(fileWithoutExtension string) bool { + ret := strings.HasPrefix(fileWithoutExtension, "[[.") + return ret +} + +// removeFileExtension removes the extension from a file +// For example, removeFileExtension("index.html") would return "index" +func removeFileExtension(fileName string) string { + return strings.TrimSuffix(fileName, filepath.Ext(fileName)) +} diff --git a/site/nextrouter/nextrouter_test.go b/site/nextrouter/nextrouter_test.go new file mode 100644 index 0000000000000..96ebf3d7ef40b --- /dev/null +++ b/site/nextrouter/nextrouter_test.go @@ -0,0 +1,374 @@ +package nextrouter_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/psanford/memfs" + "github.com/stretchr/testify/require" + + "cdr.dev/slog" + + "github.com/coder/coder/site/nextrouter" +) + +func TestNextRouter(t *testing.T) { + t.Parallel() + + t.Run("Serves file at root", func(t *testing.T) { + t.Parallel() + rootFS := memfs.New() + err := rootFS.WriteFile("test.html", []byte("test123"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/test.html") + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, string(body), "test123") + require.Equal(t, res.StatusCode, 200) + }) + + // This is a test case for the issue we hit in V1 w/ NextJS migration + t.Run("Prefer file over folder w/ trailing slash", func(t *testing.T) { + t.Parallel() + rootFS := memfs.New() + err := rootFS.MkdirAll("folder", 0777) + require.NoError(t, err) + err = rootFS.WriteFile("folder.html", []byte("folderFile"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/folder/") + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, string(body), "folderFile") + require.Equal(t, res.StatusCode, 200) + }) + + t.Run("Serves non-html files at root", func(t *testing.T) { + t.Parallel() + rootFS := memfs.New() + err := rootFS.WriteFile("test.png", []byte("png-bytes"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/test.png") + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, res.Header.Get("Content-Type"), "image/png") + require.Equal(t, string(body), "png-bytes") + require.Equal(t, res.StatusCode, 200) + }) + + t.Run("Serves html file without extension", func(t *testing.T) { + t.Parallel() + rootFS := memfs.New() + err := rootFS.WriteFile("test.html", []byte("test-no-extension"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/test") + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, string(body), "test-no-extension") + require.Equal(t, res.StatusCode, 200) + }) + + t.Run("Defaults to index.html at root", func(t *testing.T) { + t.Parallel() + rootFS := memfs.New() + err := rootFS.WriteFile("index.html", []byte("test-root-index"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/") + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, res.Header.Get("Content-Type"), "text/html; charset=utf-8") + require.Equal(t, string(body), "test-root-index") + require.Equal(t, res.StatusCode, 200) + }) + + t.Run("Serves nested file", func(t *testing.T) { + t.Parallel() + + rootFS := memfs.New() + err := rootFS.MkdirAll("test/a/b", 0777) + require.NoError(t, err) + + rootFS.WriteFile("test/a/b/c.html", []byte("test123"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/test/a/b/c.html") + require.NoError(t, err) + defer res.Body.Close() + + res, err = request(server, "/test/a/b/c.html") + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, string(body), "test123") + require.Equal(t, res.StatusCode, 200) + }) + + t.Run("Uses index.html in nested path", func(t *testing.T) { + t.Parallel() + + rootFS := memfs.New() + err := rootFS.MkdirAll("test/a/b/c", 0777) + require.NoError(t, err) + + rootFS.WriteFile("test/a/b/c/index.html", []byte("test-abc-index"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/test/a/b/c") + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, string(body), "test-abc-index") + require.Equal(t, res.StatusCode, 200) + }) + + t.Run("404 if file at root is not found", func(t *testing.T) { + t.Parallel() + + rootFS := memfs.New() + err := rootFS.WriteFile("test.html", []byte("test123"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/test-non-existent.html") + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, res.StatusCode, 404) + }) + + t.Run("404 if file at root is not found", func(t *testing.T) { + t.Parallel() + + rootFS := memfs.New() + err := rootFS.WriteFile("test.html", []byte("test123"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/test-non-existent.html") + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, res.StatusCode, 404) + }) + + t.Run("Serve custom 404.html if available", func(t *testing.T) { + t.Parallel() + + rootFS := memfs.New() + err := rootFS.WriteFile("404.html", []byte("404 custom content"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/test-non-existent.html") + require.NoError(t, err) + defer res.Body.Close() + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, res.StatusCode, 404) + require.Equal(t, string(body), "404 custom content") + }) + + t.Run("Serves dynamic-routed file", func(t *testing.T) { + t.Parallel() + rootFS := memfs.New() + err := rootFS.MkdirAll("folder", 0777) + require.NoError(t, err) + err = rootFS.WriteFile("folder/[orgs].html", []byte("test-dynamic-path"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/folder/org-1") + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, string(body), "test-dynamic-path") + require.Equal(t, res.StatusCode, 200) + }) + + t.Run("Handles dynamic-routed folders", func(t *testing.T) { + t.Parallel() + rootFS := memfs.New() + err := rootFS.MkdirAll("folder/[org]/[project]", 0777) + require.NoError(t, err) + err = rootFS.WriteFile("folder/[org]/[project]/create.html", []byte("test-create"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/folder/org-1/project-1/create") + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, string(body), "test-create") + require.Equal(t, res.StatusCode, 200) + }) + + t.Run("Handles catch-all routes", func(t *testing.T) { + t.Parallel() + rootFS := memfs.New() + err := rootFS.MkdirAll("folder", 0777) + require.NoError(t, err) + err = rootFS.WriteFile("folder/[[...any]].html", []byte("test-catch-all"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/folder/org-1/project-1/random") + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, string(body), "test-catch-all") + require.Equal(t, res.StatusCode, 200) + }) + + t.Run("Static routes should be preferred to dynamic routes", func(t *testing.T) { + t.Parallel() + rootFS := memfs.New() + err := rootFS.MkdirAll("folder", 0777) + require.NoError(t, err) + err = rootFS.WriteFile("folder/[orgs].html", []byte("test-dynamic-path"), 0755) + require.NoError(t, err) + err = rootFS.WriteFile("folder/create.html", []byte("test-create"), 0755) + require.NoError(t, err) + + router, err := nextrouter.Handler(rootFS, nil) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/folder/create") + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, string(body), "test-create") + require.Equal(t, res.StatusCode, 200) + }) + + t.Run("Injects template parameters", func(t *testing.T) { + t.Parallel() + + rootFS := memfs.New() + err := rootFS.WriteFile("test.html", []byte("{{ .CSRF.Token }}"), 0755) + require.NoError(t, err) + + type csrfState struct { + Token string + } + + type template struct { + CSRF csrfState + } + + // Add custom template function + templateFunc := func(request *http.Request) interface{} { + return template{ + CSRF: csrfState{ + Token: "hello-csrf", + }, + } + } + + router, err := nextrouter.Handler(rootFS, &nextrouter.Options{ + Logger: slog.Logger{}, + TemplateDataFunc: templateFunc, + }) + require.NoError(t, err) + server := httptest.NewServer(router) + + res, err := request(server, "/test.html") + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, string(body), "hello-csrf") + require.Equal(t, res.StatusCode, 200) + }) +} + +func request(server *httptest.Server, path string) (*http.Response, error) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL+path, nil) + if err != nil { + return nil, err + } + res, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + return res, err +}