diff --git a/handlers/static.go b/handlers/static.go index 768e6e1..bc198dd 100644 --- a/handlers/static.go +++ b/handlers/static.go @@ -42,10 +42,10 @@ func (f neuteredReaddirFile) Readdir(count int) ([]os.FileInfo, error) { // Handles requests for static files, without allowing access to the // directory viewer and returning 404 if an exact file is not found -func HandleStatic() http.Handler { +func HandleStatic(staticFS *http.FileSystem) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - nfs := justFilesFilesystem{http.Dir("static")} + nfs := justFilesFilesystem{*staticFS} fs := http.FileServer(nfs) fs.ServeHTTP(w, r) }, diff --git a/main.go b/main.go index 1bcf240..84ed2fd 100644 --- a/main.go +++ b/main.go @@ -6,6 +6,7 @@ import ( "flag" "fmt" "io" + "io/fs" "net" "net/http" "os" @@ -22,6 +23,26 @@ import ( "github.com/pkg/errors" ) +//go:embed static/* +var embeddedStatic embed.FS + +// Gets the static files +func getStaticFiles() (http.FileSystem, error) { + if _, err := os.Stat("static"); err == nil { + // Use actual filesystem in development + fmt.Println("Using filesystem for static files") + return http.Dir("static"), nil + } else { + // Use embedded filesystem in production + fmt.Println("Using embedded static files") + subFS, err := fs.Sub(embeddedStatic, "static") + if err != nil { + return nil, errors.Wrap(err, "fs.Sub") + } + return http.FS(subFS), nil + } +} + // Initializes and runs the server func run(ctx context.Context, w io.Writer, args map[string]string) error { ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) @@ -62,7 +83,12 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { } defer conn.Close() - srv := server.NewServer(config, logger, conn) + staticFS, err := getStaticFiles() + if err != nil { + return errors.Wrap(err, "getStaticFiles") + } + + srv := server.NewServer(config, logger, conn, &staticFS) httpServer := &http.Server{ Addr: net.JoinHostPort(config.Host, config.Port), Handler: srv, @@ -101,9 +127,6 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { return nil } -//go:embed static/* -var static embed.FS - // Start of runtime. Parse commandline arguments & flags, Initializes context // and starts the server func main() { diff --git a/middleware/excluded.go b/middleware/excluded.go deleted file mode 100644 index cc31749..0000000 --- a/middleware/excluded.go +++ /dev/null @@ -1,25 +0,0 @@ -package middleware - -import ( - "net/http" - "strings" -) - -var excludedFiles = map[string]bool{ - "/static/css/output.css": true, -} - -// Checks is path requested if for an excluded file and returns the file -// instead of passing the request onto the next middleware -func ExcludedFiles(next http.Handler) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if excludedFiles[r.URL.Path] { - filePath := strings.TrimPrefix(r.URL.Path, "/") - http.ServeFile(w, r, filePath) - } else { - next.ServeHTTP(w, r) - } - }, - ) -} diff --git a/middleware/favicon.go b/middleware/favicon.go deleted file mode 100644 index 41385fa..0000000 --- a/middleware/favicon.go +++ /dev/null @@ -1,17 +0,0 @@ -package middleware - -import ( - "net/http" -) - -func Favicon(next http.Handler) http.Handler { - return http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/favicon.ico" { - http.ServeFile(w, r, "static/favicon.ico") - } else { - next.ServeHTTP(w, r) - } - }, - ) -} diff --git a/server/routes.go b/server/routes.go index 606036e..a92885f 100644 --- a/server/routes.go +++ b/server/routes.go @@ -18,12 +18,13 @@ func addRoutes( logger *zerolog.Logger, config *config.Config, conn *sql.DB, + staticFS *http.FileSystem, ) { // Health check mux.HandleFunc("GET /healthz", func(http.ResponseWriter, *http.Request) {}) // Static files - mux.Handle("GET /static/", http.StripPrefix("/static/", handlers.HandleStatic())) + mux.Handle("GET /static/", http.StripPrefix("/static/", handlers.HandleStatic(staticFS))) // Index page and unhandled catchall (404) mux.Handle("GET /", handlers.HandleRoot()) diff --git a/server/server.go b/server/server.go index e460546..648082f 100644 --- a/server/server.go +++ b/server/server.go @@ -15,6 +15,7 @@ func NewServer( config *config.Config, logger *zerolog.Logger, conn *sql.DB, + staticFS *http.FileSystem, ) http.Handler { mux := http.NewServeMux() addRoutes( @@ -22,6 +23,7 @@ func NewServer( logger, config, conn, + staticFS, ) var handler http.Handler = mux // Add middleware here, must be added in reverse order of execution @@ -29,10 +31,6 @@ func NewServer( handler = middleware.Logging(logger, handler) handler = middleware.Authentication(logger, config, conn, handler) - // Serve the favicon and exluded files before any middleware is added - handler = middleware.ExcludedFiles(handler) - handler = middleware.Favicon(handler) - // Gzip handler = middleware.Gzip(handler, config.GZIP) diff --git a/view/layout/global.templ b/view/layout/global.templ index 0b5c155..17f58eb 100644 --- a/view/layout/global.templ +++ b/view/layout/global.templ @@ -34,6 +34,7 @@ templ Global() { Project Reshoot +