From 4a21ba3821b5db399c4e908215efa0f4fe5f7575 Mon Sep 17 00:00:00 2001 From: Haelnorr Date: Fri, 2 Jan 2026 19:50:10 +1100 Subject: [PATCH] cleaned up main module --- cmd/projectreshoot/flags.go | 2 -- cmd/projectreshoot/logger.go | 46 +++++++++++++++++++++++++++ cmd/projectreshoot/main.go | 12 ++++++- cmd/projectreshoot/run.go | 41 +++--------------------- cmd/projectreshoot/signals.go | 41 ------------------------ internal/httpserver/server.go | 6 ++-- internal/middleware/authentication.go | 5 --- 7 files changed, 63 insertions(+), 90 deletions(-) create mode 100644 cmd/projectreshoot/logger.go delete mode 100644 cmd/projectreshoot/signals.go diff --git a/cmd/projectreshoot/flags.go b/cmd/projectreshoot/flags.go index 3501397..4ad83cf 100644 --- a/cmd/projectreshoot/flags.go +++ b/cmd/projectreshoot/flags.go @@ -10,7 +10,6 @@ func setupFlags() map[string]string { host := flag.String("host", "", "Override host to listen on") port := flag.String("port", "", "Override port to listen on") test := flag.Bool("test", false, "Run server in test mode") - tester := flag.Bool("tester", false, "Run tester function instead of main program") dbver := flag.Bool("dbver", false, "Get the version of the database required") loglevel := flag.String("loglevel", "", "Set log level") logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)") @@ -21,7 +20,6 @@ func setupFlags() map[string]string { "host": *host, "port": *port, "test": strconv.FormatBool(*test), - "tester": strconv.FormatBool(*tester), "dbver": strconv.FormatBool(*dbver), "loglevel": *loglevel, "logoutput": *logoutput, diff --git a/cmd/projectreshoot/logger.go b/cmd/projectreshoot/logger.go new file mode 100644 index 0000000..9876890 --- /dev/null +++ b/cmd/projectreshoot/logger.go @@ -0,0 +1,46 @@ +package main + +import ( + "io" + "os" + + "git.haelnorr.com/h/golib/hlog" + "github.com/pkg/errors" +) + +// Take in the desired logOutput and a console writer to use +func setupLogger(logLevel hlog.Level, logOutput string, w *io.Writer, logDirectory string) (*hlog.Logger, error) { + // Setup the logfile + var logfile *os.File = nil + if logOutput == "both" || logOutput == "file" { + logfile, err := hlog.NewLogFile(logDirectory) + if err != nil { + return nil, errors.Wrap(err, "hlog") + } + defer logfile.Close() + } + + // Setup the console writer + var consoleWriter io.Writer + if logOutput == "both" || logOutput == "console" { + if w != nil { + consoleWriter = *w + } else { + if logOutput == "console" { + return nil, errors.New("Console logging specified as sole method but no writer provided") + } + } + } + + // Setup the logger + logger, err := hlog.NewLogger( + logLevel, + consoleWriter, + logfile, + logDirectory, + ) + if err != nil { + return nil, errors.Wrap(err, "hlog") + } + return logger, nil +} diff --git a/cmd/projectreshoot/main.go b/cmd/projectreshoot/main.go index 65f289b..504ebb3 100644 --- a/cmd/projectreshoot/main.go +++ b/cmd/projectreshoot/main.go @@ -4,12 +4,22 @@ import ( "context" "fmt" "os" + "projectreshoot/internal/config" + + "github.com/pkg/errors" ) func main() { args := setupFlags() ctx := context.Background() - if err := run(ctx, os.Stdout, args); err != nil { + + config, err := config.GetConfig(args) + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to load config")) + os.Exit(1) + } + + if err := run(ctx, os.Stdout, args, config); err != nil { fmt.Fprintf(os.Stderr, "%s\n", err) os.Exit(1) } diff --git a/cmd/projectreshoot/run.go b/cmd/projectreshoot/run.go index 7ae00f4..d63c47a 100644 --- a/cmd/projectreshoot/run.go +++ b/cmd/projectreshoot/run.go @@ -13,55 +13,25 @@ import ( "sync" "time" - "git.haelnorr.com/h/golib/hlog" "git.haelnorr.com/h/golib/jwt" "github.com/pkg/errors" ) -var maint uint32 // atomic: 1 if in maintenance mode - // Initializes and runs the server -func run(ctx context.Context, w io.Writer, args map[string]string) error { +func run(ctx context.Context, w io.Writer, args map[string]string, config *config.Config) error { ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) defer cancel() - config, err := config.GetConfig(args) - if err != nil { - return errors.Wrap(err, "server.GetConfig") - } - // Return the version of the database required if args["dbver"] == "true" { fmt.Fprintf(w, "Database version: %s\n", config.DBName) return nil } - // Setup the logfile - var logfile *os.File = nil - if config.LogOutput == "both" || config.LogOutput == "file" { - logfile, err = hlog.NewLogFile(config.LogDir) - if err != nil { - return errors.Wrap(err, "logging.GetLogFile") - } - defer logfile.Close() - } - - // Setup the console writer - var consoleWriter io.Writer - if config.LogOutput == "both" || config.LogOutput == "console" { - consoleWriter = w - } - - // Setup the logger - logger, err := hlog.NewLogger( - config.LogLevel, - consoleWriter, - logfile, - config.LogDir, - ) + logger, err := setupLogger(config.LogLevel, config.LogOutput, &w, config.LogDir) if err != nil { - return errors.Wrap(err, "logging.GetLogger") + return errors.Wrap(err, "setupLogger") } // Setup the database connection @@ -92,10 +62,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { ) logger.Debug().Msg("Setting up HTTP server") - httpServer := httpserver.NewServer(config, logger, conn, tokenGen, &staticFS, &maint) - - // Setups a channel to listen for os.Signal - handleMaintSignals(httpServer, logger) + httpServer := httpserver.NewServer(config, logger, conn, tokenGen, &staticFS) // Runs the http server logger.Debug().Msg("Starting up the HTTP server") diff --git a/cmd/projectreshoot/signals.go b/cmd/projectreshoot/signals.go deleted file mode 100644 index 10566b0..0000000 --- a/cmd/projectreshoot/signals.go +++ /dev/null @@ -1,41 +0,0 @@ -package main - -import ( - "net/http" - "os" - "os/signal" - "sync/atomic" - "syscall" - - "git.haelnorr.com/h/golib/hlog" -) - -// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode -func handleMaintSignals( - srv *http.Server, - logger *hlog.Logger, -) { - logger.Debug().Msg("Starting signal listener") - ch := make(chan os.Signal, 1) - srv.RegisterOnShutdown(func() { - logger.Debug().Msg("Shutting down signal listener") - close(ch) - }) - go func() { - for sig := range ch { - switch sig { - case syscall.SIGUSR1: - if atomic.LoadUint32(&maint) != 1 { - atomic.StoreUint32(&maint, 1) - logger.Info().Msg("Signal received: Starting maintenance") - } - case syscall.SIGUSR2: - if atomic.LoadUint32(&maint) != 0 { - logger.Info().Msg("Signal received: Maintenance over") - atomic.StoreUint32(&maint, 0) - } - } - } - }() - signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2) -} diff --git a/internal/httpserver/server.go b/internal/httpserver/server.go index c407027..c9fdc7d 100644 --- a/internal/httpserver/server.go +++ b/internal/httpserver/server.go @@ -20,10 +20,9 @@ func NewServer( conn *sql.DB, tokenGen *jwt.TokenGenerator, staticFS *fs.FS, - maint *uint32, ) *http.Server { fs := http.FS(*staticFS) - srv := createServer(config, logger, conn, tokenGen, &fs, maint) + srv := createServer(config, logger, conn, tokenGen, &fs) httpServer := &http.Server{ Addr: net.JoinHostPort(config.Host, config.Port), Handler: srv, @@ -41,7 +40,6 @@ func createServer( conn *sql.DB, tokenGen *jwt.TokenGenerator, staticFS *http.FileSystem, - maint *uint32, ) http.Handler { mux := http.NewServeMux() addRoutes( @@ -56,7 +54,7 @@ func createServer( // Add middleware here, must be added in reverse order of execution // i.e. First in list will get executed last during the request handling handler = middleware.Logging(logger, handler) - handler = middleware.Authentication(logger, config, conn, tokenGen, handler, maint) + handler = middleware.Authentication(logger, config, conn, tokenGen, handler) // Gzip handler = middleware.Gzip(handler, config.GZIP) diff --git a/internal/middleware/authentication.go b/internal/middleware/authentication.go index 71a9a42..b2e4830 100644 --- a/internal/middleware/authentication.go +++ b/internal/middleware/authentication.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "net/http" - "sync/atomic" "time" "projectreshoot/internal/config" @@ -104,7 +103,6 @@ func Authentication( conn *sql.DB, tokenGen *jwt.TokenGenerator, next http.Handler, - maint *uint32, ) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/static/css/output.css" || @@ -114,9 +112,6 @@ func Authentication( } ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) defer cancel() - if atomic.LoadUint32(maint) == 1 { - cancel() - } // Start the transaction tx, err := conn.BeginTx(ctx, nil)