cleaned up main module

This commit is contained in:
2026-01-02 19:50:10 +11:00
parent 1bcdf0e813
commit 4a21ba3821
7 changed files with 63 additions and 90 deletions

View File

@@ -10,7 +10,6 @@ func setupFlags() map[string]string {
host := flag.String("host", "", "Override host to listen on")
port := flag.String("port", "", "Override port to listen on")
test := flag.Bool("test", false, "Run server in test mode")
tester := flag.Bool("tester", false, "Run tester function instead of main program")
dbver := flag.Bool("dbver", false, "Get the version of the database required")
loglevel := flag.String("loglevel", "", "Set log level")
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
@@ -21,7 +20,6 @@ func setupFlags() map[string]string {
"host": *host,
"port": *port,
"test": strconv.FormatBool(*test),
"tester": strconv.FormatBool(*tester),
"dbver": strconv.FormatBool(*dbver),
"loglevel": *loglevel,
"logoutput": *logoutput,

View File

@@ -0,0 +1,46 @@
package main
import (
"io"
"os"
"git.haelnorr.com/h/golib/hlog"
"github.com/pkg/errors"
)
// Take in the desired logOutput and a console writer to use
func setupLogger(logLevel hlog.Level, logOutput string, w *io.Writer, logDirectory string) (*hlog.Logger, error) {
// Setup the logfile
var logfile *os.File = nil
if logOutput == "both" || logOutput == "file" {
logfile, err := hlog.NewLogFile(logDirectory)
if err != nil {
return nil, errors.Wrap(err, "hlog")
}
defer logfile.Close()
}
// Setup the console writer
var consoleWriter io.Writer
if logOutput == "both" || logOutput == "console" {
if w != nil {
consoleWriter = *w
} else {
if logOutput == "console" {
return nil, errors.New("Console logging specified as sole method but no writer provided")
}
}
}
// Setup the logger
logger, err := hlog.NewLogger(
logLevel,
consoleWriter,
logfile,
logDirectory,
)
if err != nil {
return nil, errors.Wrap(err, "hlog")
}
return logger, nil
}

View File

@@ -4,12 +4,22 @@ import (
"context"
"fmt"
"os"
"projectreshoot/internal/config"
"github.com/pkg/errors"
)
func main() {
args := setupFlags()
ctx := context.Background()
if err := run(ctx, os.Stdout, args); err != nil {
config, err := config.GetConfig(args)
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to load config"))
os.Exit(1)
}
if err := run(ctx, os.Stdout, args, config); err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}

View File

@@ -13,55 +13,25 @@ import (
"sync"
"time"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
)
var maint uint32 // atomic: 1 if in maintenance mode
// Initializes and runs the server
func run(ctx context.Context, w io.Writer, args map[string]string) error {
func run(ctx context.Context, w io.Writer, args map[string]string, config *config.Config) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()
config, err := config.GetConfig(args)
if err != nil {
return errors.Wrap(err, "server.GetConfig")
}
// Return the version of the database required
if args["dbver"] == "true" {
fmt.Fprintf(w, "Database version: %s\n", config.DBName)
return nil
}
// Setup the logfile
var logfile *os.File = nil
if config.LogOutput == "both" || config.LogOutput == "file" {
logfile, err = hlog.NewLogFile(config.LogDir)
if err != nil {
return errors.Wrap(err, "logging.GetLogFile")
}
defer logfile.Close()
}
// Setup the console writer
var consoleWriter io.Writer
if config.LogOutput == "both" || config.LogOutput == "console" {
consoleWriter = w
}
// Setup the logger
logger, err := hlog.NewLogger(
config.LogLevel,
consoleWriter,
logfile,
config.LogDir,
)
logger, err := setupLogger(config.LogLevel, config.LogOutput, &w, config.LogDir)
if err != nil {
return errors.Wrap(err, "logging.GetLogger")
return errors.Wrap(err, "setupLogger")
}
// Setup the database connection
@@ -92,10 +62,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
)
logger.Debug().Msg("Setting up HTTP server")
httpServer := httpserver.NewServer(config, logger, conn, tokenGen, &staticFS, &maint)
// Setups a channel to listen for os.Signal
handleMaintSignals(httpServer, logger)
httpServer := httpserver.NewServer(config, logger, conn, tokenGen, &staticFS)
// Runs the http server
logger.Debug().Msg("Starting up the HTTP server")

View File

@@ -1,41 +0,0 @@
package main
import (
"net/http"
"os"
"os/signal"
"sync/atomic"
"syscall"
"git.haelnorr.com/h/golib/hlog"
)
// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode
func handleMaintSignals(
srv *http.Server,
logger *hlog.Logger,
) {
logger.Debug().Msg("Starting signal listener")
ch := make(chan os.Signal, 1)
srv.RegisterOnShutdown(func() {
logger.Debug().Msg("Shutting down signal listener")
close(ch)
})
go func() {
for sig := range ch {
switch sig {
case syscall.SIGUSR1:
if atomic.LoadUint32(&maint) != 1 {
atomic.StoreUint32(&maint, 1)
logger.Info().Msg("Signal received: Starting maintenance")
}
case syscall.SIGUSR2:
if atomic.LoadUint32(&maint) != 0 {
logger.Info().Msg("Signal received: Maintenance over")
atomic.StoreUint32(&maint, 0)
}
}
}
}()
signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2)
}

View File

@@ -20,10 +20,9 @@ func NewServer(
conn *sql.DB,
tokenGen *jwt.TokenGenerator,
staticFS *fs.FS,
maint *uint32,
) *http.Server {
fs := http.FS(*staticFS)
srv := createServer(config, logger, conn, tokenGen, &fs, maint)
srv := createServer(config, logger, conn, tokenGen, &fs)
httpServer := &http.Server{
Addr: net.JoinHostPort(config.Host, config.Port),
Handler: srv,
@@ -41,7 +40,6 @@ func createServer(
conn *sql.DB,
tokenGen *jwt.TokenGenerator,
staticFS *http.FileSystem,
maint *uint32,
) http.Handler {
mux := http.NewServeMux()
addRoutes(
@@ -56,7 +54,7 @@ func createServer(
// Add middleware here, must be added in reverse order of execution
// i.e. First in list will get executed last during the request handling
handler = middleware.Logging(logger, handler)
handler = middleware.Authentication(logger, config, conn, tokenGen, handler, maint)
handler = middleware.Authentication(logger, config, conn, tokenGen, handler)
// Gzip
handler = middleware.Gzip(handler, config.GZIP)

View File

@@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"net/http"
"sync/atomic"
"time"
"projectreshoot/internal/config"
@@ -104,7 +103,6 @@ func Authentication(
conn *sql.DB,
tokenGen *jwt.TokenGenerator,
next http.Handler,
maint *uint32,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/static/css/output.css" ||
@@ -114,9 +112,6 @@ func Authentication(
}
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
defer cancel()
if atomic.LoadUint32(maint) == 1 {
cancel()
}
// Start the transaction
tx, err := conn.BeginTx(ctx, nil)