Fixed issue where tests will never fail if server fails to launch

This commit is contained in:
2025-02-22 12:42:49 +11:00
parent 9ee5aa7b8f
commit ef48091906
3 changed files with 53 additions and 10 deletions

25
main.go
View File

@@ -21,6 +21,7 @@ import (
"projectreshoot/db" "projectreshoot/db"
"projectreshoot/logging" "projectreshoot/logging"
"projectreshoot/server" "projectreshoot/server"
"projectreshoot/tests"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@@ -96,7 +97,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
// Return the version of the database required // Return the version of the database required
if args["dbver"] == "true" { if args["dbver"] == "true" {
fmt.Printf("Database version: %s\n", config.DBName) fmt.Fprintf(w, "Database version: %s\n", config.DBName)
return nil return nil
} }
@@ -126,10 +127,24 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
logger.Debug().Msg("Config loaded and logger started") logger.Debug().Msg("Config loaded and logger started")
logger.Debug().Msg("Connecting to database") logger.Debug().Msg("Connecting to database")
conn, err := db.ConnectToDatabase(config.DBName, logger) var conn *db.SafeConn
if args["test"] == "true" {
logger.Debug().Msg("Server in test mode, using test database")
ver, err := strconv.ParseInt(config.DBName, 10, 0)
if err != nil {
return errors.Wrap(err, "strconv.ParseInt")
}
testconn, err := tests.SetupTestDB(ver)
if err != nil {
return errors.Wrap(err, "tests.SetupTestDB")
}
conn = db.MakeSafe(testconn, logger)
} else {
conn, err = db.ConnectToDatabase(config.DBName, logger)
if err != nil { if err != nil {
return errors.Wrap(err, "db.ConnectToDatabase") return errors.Wrap(err, "db.ConnectToDatabase")
} }
}
defer conn.Close() defer conn.Close()
logger.Debug().Msg("Getting static files") logger.Debug().Msg("Getting static files")
@@ -149,7 +164,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
} }
// Runs function for testing in dev if --test flag true // Runs function for testing in dev if --test flag true
if args["test"] == "true" { if args["tester"] == "true" {
logger.Debug().Msg("Running tester function") logger.Debug().Msg("Running tester function")
test(config, logger, conn, httpServer) test(config, logger, conn, httpServer)
return nil return nil
@@ -191,7 +206,8 @@ func main() {
// Parse commandline args // Parse commandline args
host := flag.String("host", "", "Override host to listen on") host := flag.String("host", "", "Override host to listen on")
port := flag.String("port", "", "Override port to listen on") port := flag.String("port", "", "Override port to listen on")
test := flag.Bool("test", false, "Run test function instead of main program") test := flag.Bool("test", false, "Run server in test mode")
tester := flag.Bool("test", false, "Run tester function instead of main program")
dbver := flag.Bool("dbver", false, "Get the version of the database required") dbver := flag.Bool("dbver", false, "Get the version of the database required")
loglevel := flag.String("loglevel", "", "Set log level") loglevel := flag.String("loglevel", "", "Set log level")
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)") logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
@@ -202,6 +218,7 @@ func main() {
"host": *host, "host": *host,
"port": *port, "port": *port,
"test": strconv.FormatBool(*test), "test": strconv.FormatBool(*test),
"tester": strconv.FormatBool(*tester),
"dbver": strconv.FormatBool(*dbver), "dbver": strconv.FormatBool(*dbver),
"loglevel": *loglevel, "loglevel": *loglevel,
"logoutput": *logoutput, "logoutput": *logoutput,

View File

@@ -18,14 +18,35 @@ func Test_main(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel) t.Cleanup(cancel)
args := map[string]string{} args := map[string]string{"test": "true"}
var stdout bytes.Buffer var stdout bytes.Buffer
os.Setenv("SECRET_KEY", ".") os.Setenv("SECRET_KEY", ".")
os.Setenv("HOST", "127.0.0.1") os.Setenv("HOST", "127.0.0.1")
os.Setenv("PORT", "3232") os.Setenv("PORT", "3232")
go run(ctx, &stdout, args) runSrvErr := make(chan error)
go func() {
if err := run(ctx, &stdout, args); err != nil {
runSrvErr <- err
return
}
}()
waitForReady(ctx, 10*time.Second, "http://127.0.0.1:3232/healthz") go func() {
err := waitForReady(ctx, 10*time.Second, "http://127.0.0.1:3232/healthz")
if err != nil {
runSrvErr <- err
return
}
runSrvErr <- nil
}()
select {
case err := <-runSrvErr:
if err != nil {
t.Fatalf("Error starting test server: %s", err)
return
}
t.Log("Test server started")
}
t.Run("SIGUSR1 puts database into global lock", func(t *testing.T) { t.Run("SIGUSR1 puts database into global lock", func(t *testing.T) {
done := make(chan bool) done := make(chan bool)
@@ -99,6 +120,7 @@ func waitForReady(
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
fmt.Printf("Error making request: %s\n", err.Error()) fmt.Printf("Error making request: %s\n", err.Error())
time.Sleep(250 * time.Millisecond)
continue continue
} }
if resp.StatusCode == http.StatusOK { if resp.StatusCode == http.StatusOK {

View File

@@ -24,6 +24,10 @@ func NilLogger() *zerolog.Logger {
// Return a logger that makes use of the T.Log method to enable debugging tests // Return a logger that makes use of the T.Log method to enable debugging tests
func DebugLogger(t *testing.T) *zerolog.Logger { func DebugLogger(t *testing.T) *zerolog.Logger {
logger := zerolog.New(&TLogWriter{t: t}) logger := zerolog.New(GetTLogWriter(t))
return &logger return &logger
} }
func GetTLogWriter(t *testing.T) *TLogWriter {
return &TLogWriter{t: t}
}