cleaned up main module
This commit is contained in:
@@ -10,7 +10,6 @@ func setupFlags() map[string]string {
|
|||||||
host := flag.String("host", "", "Override host to listen on")
|
host := flag.String("host", "", "Override host to listen on")
|
||||||
port := flag.String("port", "", "Override port to listen on")
|
port := flag.String("port", "", "Override port to listen on")
|
||||||
test := flag.Bool("test", false, "Run server in test mode")
|
test := flag.Bool("test", false, "Run server in test mode")
|
||||||
tester := flag.Bool("tester", false, "Run tester function instead of main program")
|
|
||||||
dbver := flag.Bool("dbver", false, "Get the version of the database required")
|
dbver := flag.Bool("dbver", false, "Get the version of the database required")
|
||||||
loglevel := flag.String("loglevel", "", "Set log level")
|
loglevel := flag.String("loglevel", "", "Set log level")
|
||||||
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
|
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
|
||||||
@@ -21,7 +20,6 @@ func setupFlags() map[string]string {
|
|||||||
"host": *host,
|
"host": *host,
|
||||||
"port": *port,
|
"port": *port,
|
||||||
"test": strconv.FormatBool(*test),
|
"test": strconv.FormatBool(*test),
|
||||||
"tester": strconv.FormatBool(*tester),
|
|
||||||
"dbver": strconv.FormatBool(*dbver),
|
"dbver": strconv.FormatBool(*dbver),
|
||||||
"loglevel": *loglevel,
|
"loglevel": *loglevel,
|
||||||
"logoutput": *logoutput,
|
"logoutput": *logoutput,
|
||||||
|
|||||||
46
cmd/projectreshoot/logger.go
Normal file
46
cmd/projectreshoot/logger.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"git.haelnorr.com/h/golib/hlog"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Take in the desired logOutput and a console writer to use
|
||||||
|
func setupLogger(logLevel hlog.Level, logOutput string, w *io.Writer, logDirectory string) (*hlog.Logger, error) {
|
||||||
|
// Setup the logfile
|
||||||
|
var logfile *os.File = nil
|
||||||
|
if logOutput == "both" || logOutput == "file" {
|
||||||
|
logfile, err := hlog.NewLogFile(logDirectory)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "hlog")
|
||||||
|
}
|
||||||
|
defer logfile.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup the console writer
|
||||||
|
var consoleWriter io.Writer
|
||||||
|
if logOutput == "both" || logOutput == "console" {
|
||||||
|
if w != nil {
|
||||||
|
consoleWriter = *w
|
||||||
|
} else {
|
||||||
|
if logOutput == "console" {
|
||||||
|
return nil, errors.New("Console logging specified as sole method but no writer provided")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup the logger
|
||||||
|
logger, err := hlog.NewLogger(
|
||||||
|
logLevel,
|
||||||
|
consoleWriter,
|
||||||
|
logfile,
|
||||||
|
logDirectory,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "hlog")
|
||||||
|
}
|
||||||
|
return logger, nil
|
||||||
|
}
|
||||||
@@ -4,12 +4,22 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"projectreshoot/internal/config"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
args := setupFlags()
|
args := setupFlags()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
if err := run(ctx, os.Stdout, args); err != nil {
|
|
||||||
|
config, err := config.GetConfig(args)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to load config"))
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := run(ctx, os.Stdout, args, config); err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,55 +13,25 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hlog"
|
|
||||||
"git.haelnorr.com/h/golib/jwt"
|
"git.haelnorr.com/h/golib/jwt"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
var maint uint32 // atomic: 1 if in maintenance mode
|
|
||||||
|
|
||||||
// Initializes and runs the server
|
// Initializes and runs the server
|
||||||
func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
func run(ctx context.Context, w io.Writer, args map[string]string, config *config.Config) error {
|
||||||
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
config, err := config.GetConfig(args)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "server.GetConfig")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the version of the database required
|
// Return the version of the database required
|
||||||
if args["dbver"] == "true" {
|
if args["dbver"] == "true" {
|
||||||
fmt.Fprintf(w, "Database version: %s\n", config.DBName)
|
fmt.Fprintf(w, "Database version: %s\n", config.DBName)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup the logfile
|
logger, err := setupLogger(config.LogLevel, config.LogOutput, &w, config.LogDir)
|
||||||
var logfile *os.File = nil
|
|
||||||
if config.LogOutput == "both" || config.LogOutput == "file" {
|
|
||||||
logfile, err = hlog.NewLogFile(config.LogDir)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "logging.GetLogFile")
|
|
||||||
}
|
|
||||||
defer logfile.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup the console writer
|
|
||||||
var consoleWriter io.Writer
|
|
||||||
if config.LogOutput == "both" || config.LogOutput == "console" {
|
|
||||||
consoleWriter = w
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup the logger
|
|
||||||
logger, err := hlog.NewLogger(
|
|
||||||
config.LogLevel,
|
|
||||||
consoleWriter,
|
|
||||||
logfile,
|
|
||||||
config.LogDir,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "logging.GetLogger")
|
return errors.Wrap(err, "setupLogger")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup the database connection
|
// Setup the database connection
|
||||||
@@ -92,10 +62,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.Debug().Msg("Setting up HTTP server")
|
logger.Debug().Msg("Setting up HTTP server")
|
||||||
httpServer := httpserver.NewServer(config, logger, conn, tokenGen, &staticFS, &maint)
|
httpServer := httpserver.NewServer(config, logger, conn, tokenGen, &staticFS)
|
||||||
|
|
||||||
// Setups a channel to listen for os.Signal
|
|
||||||
handleMaintSignals(httpServer, logger)
|
|
||||||
|
|
||||||
// Runs the http server
|
// Runs the http server
|
||||||
logger.Debug().Msg("Starting up the HTTP server")
|
logger.Debug().Msg("Starting up the HTTP server")
|
||||||
|
|||||||
@@ -1,41 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"sync/atomic"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"git.haelnorr.com/h/golib/hlog"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode
|
|
||||||
func handleMaintSignals(
|
|
||||||
srv *http.Server,
|
|
||||||
logger *hlog.Logger,
|
|
||||||
) {
|
|
||||||
logger.Debug().Msg("Starting signal listener")
|
|
||||||
ch := make(chan os.Signal, 1)
|
|
||||||
srv.RegisterOnShutdown(func() {
|
|
||||||
logger.Debug().Msg("Shutting down signal listener")
|
|
||||||
close(ch)
|
|
||||||
})
|
|
||||||
go func() {
|
|
||||||
for sig := range ch {
|
|
||||||
switch sig {
|
|
||||||
case syscall.SIGUSR1:
|
|
||||||
if atomic.LoadUint32(&maint) != 1 {
|
|
||||||
atomic.StoreUint32(&maint, 1)
|
|
||||||
logger.Info().Msg("Signal received: Starting maintenance")
|
|
||||||
}
|
|
||||||
case syscall.SIGUSR2:
|
|
||||||
if atomic.LoadUint32(&maint) != 0 {
|
|
||||||
logger.Info().Msg("Signal received: Maintenance over")
|
|
||||||
atomic.StoreUint32(&maint, 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2)
|
|
||||||
}
|
|
||||||
@@ -20,10 +20,9 @@ func NewServer(
|
|||||||
conn *sql.DB,
|
conn *sql.DB,
|
||||||
tokenGen *jwt.TokenGenerator,
|
tokenGen *jwt.TokenGenerator,
|
||||||
staticFS *fs.FS,
|
staticFS *fs.FS,
|
||||||
maint *uint32,
|
|
||||||
) *http.Server {
|
) *http.Server {
|
||||||
fs := http.FS(*staticFS)
|
fs := http.FS(*staticFS)
|
||||||
srv := createServer(config, logger, conn, tokenGen, &fs, maint)
|
srv := createServer(config, logger, conn, tokenGen, &fs)
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: net.JoinHostPort(config.Host, config.Port),
|
Addr: net.JoinHostPort(config.Host, config.Port),
|
||||||
Handler: srv,
|
Handler: srv,
|
||||||
@@ -41,7 +40,6 @@ func createServer(
|
|||||||
conn *sql.DB,
|
conn *sql.DB,
|
||||||
tokenGen *jwt.TokenGenerator,
|
tokenGen *jwt.TokenGenerator,
|
||||||
staticFS *http.FileSystem,
|
staticFS *http.FileSystem,
|
||||||
maint *uint32,
|
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
addRoutes(
|
addRoutes(
|
||||||
@@ -56,7 +54,7 @@ func createServer(
|
|||||||
// Add middleware here, must be added in reverse order of execution
|
// Add middleware here, must be added in reverse order of execution
|
||||||
// i.e. First in list will get executed last during the request handling
|
// i.e. First in list will get executed last during the request handling
|
||||||
handler = middleware.Logging(logger, handler)
|
handler = middleware.Logging(logger, handler)
|
||||||
handler = middleware.Authentication(logger, config, conn, tokenGen, handler, maint)
|
handler = middleware.Authentication(logger, config, conn, tokenGen, handler)
|
||||||
|
|
||||||
// Gzip
|
// Gzip
|
||||||
handler = middleware.Gzip(handler, config.GZIP)
|
handler = middleware.Gzip(handler, config.GZIP)
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"projectreshoot/internal/config"
|
"projectreshoot/internal/config"
|
||||||
@@ -104,7 +103,6 @@ func Authentication(
|
|||||||
conn *sql.DB,
|
conn *sql.DB,
|
||||||
tokenGen *jwt.TokenGenerator,
|
tokenGen *jwt.TokenGenerator,
|
||||||
next http.Handler,
|
next http.Handler,
|
||||||
maint *uint32,
|
|
||||||
) http.Handler {
|
) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.URL.Path == "/static/css/output.css" ||
|
if r.URL.Path == "/static/css/output.css" ||
|
||||||
@@ -114,9 +112,6 @@ func Authentication(
|
|||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if atomic.LoadUint32(maint) == 1 {
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the transaction
|
// Start the transaction
|
||||||
tx, err := conn.BeginTx(ctx, nil)
|
tx, err := conn.BeginTx(ctx, nil)
|
||||||
|
|||||||
Reference in New Issue
Block a user