diff --git a/main.go b/main.go index a68fb5e..a1fc9a2 100644 --- a/main.go +++ b/main.go @@ -21,6 +21,7 @@ import ( "projectreshoot/db" "projectreshoot/logging" "projectreshoot/server" + "projectreshoot/tests" "github.com/pkg/errors" "github.com/rs/zerolog" @@ -96,7 +97,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { // Return the version of the database required if args["dbver"] == "true" { - fmt.Printf("Database version: %s\n", config.DBName) + fmt.Fprintf(w, "Database version: %s\n", config.DBName) return nil } @@ -126,9 +127,23 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { logger.Debug().Msg("Config loaded and logger started") logger.Debug().Msg("Connecting to database") - conn, err := db.ConnectToDatabase(config.DBName, logger) - if err != nil { - return errors.Wrap(err, "db.ConnectToDatabase") + var conn *db.SafeConn + if args["test"] == "true" { + logger.Debug().Msg("Server in test mode, using test database") + ver, err := strconv.ParseInt(config.DBName, 10, 0) + if err != nil { + return errors.Wrap(err, "strconv.ParseInt") + } + testconn, err := tests.SetupTestDB(ver) + if err != nil { + return errors.Wrap(err, "tests.SetupTestDB") + } + conn = db.MakeSafe(testconn, logger) + } else { + conn, err = db.ConnectToDatabase(config.DBName, logger) + if err != nil { + return errors.Wrap(err, "db.ConnectToDatabase") + } } defer conn.Close() @@ -149,7 +164,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error { } // Runs function for testing in dev if --test flag true - if args["test"] == "true" { + if args["tester"] == "true" { logger.Debug().Msg("Running tester function") test(config, logger, conn, httpServer) return nil @@ -191,7 +206,8 @@ func main() { // Parse commandline args host := flag.String("host", "", "Override host to listen on") port := flag.String("port", "", "Override port to listen on") - test := flag.Bool("test", false, "Run test function instead of main program") + test := flag.Bool("test", false, "Run server in test mode") + tester := flag.Bool("test", false, "Run tester function instead of main program") dbver := flag.Bool("dbver", false, "Get the version of the database required") loglevel := flag.String("loglevel", "", "Set log level") logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)") @@ -202,6 +218,7 @@ func main() { "host": *host, "port": *port, "test": strconv.FormatBool(*test), + "tester": strconv.FormatBool(*tester), "dbver": strconv.FormatBool(*dbver), "loglevel": *loglevel, "logoutput": *logoutput, diff --git a/main_test.go b/main_test.go index 60faa95..bc71d15 100644 --- a/main_test.go +++ b/main_test.go @@ -18,14 +18,35 @@ func Test_main(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithCancel(ctx) t.Cleanup(cancel) - args := map[string]string{} + args := map[string]string{"test": "true"} var stdout bytes.Buffer os.Setenv("SECRET_KEY", ".") os.Setenv("HOST", "127.0.0.1") os.Setenv("PORT", "3232") - go run(ctx, &stdout, args) + runSrvErr := make(chan error) + go func() { + if err := run(ctx, &stdout, args); err != nil { + runSrvErr <- err + return + } + }() - waitForReady(ctx, 10*time.Second, "http://127.0.0.1:3232/healthz") + go func() { + err := waitForReady(ctx, 10*time.Second, "http://127.0.0.1:3232/healthz") + if err != nil { + runSrvErr <- err + return + } + runSrvErr <- nil + }() + select { + case err := <-runSrvErr: + if err != nil { + t.Fatalf("Error starting test server: %s", err) + return + } + t.Log("Test server started") + } t.Run("SIGUSR1 puts database into global lock", func(t *testing.T) { done := make(chan bool) @@ -99,6 +120,7 @@ func waitForReady( resp, err := client.Do(req) if err != nil { fmt.Printf("Error making request: %s\n", err.Error()) + time.Sleep(250 * time.Millisecond) continue } if resp.StatusCode == http.StatusOK { diff --git a/tests/logger.go b/tests/logger.go index d8a0dd9..c3a9118 100644 --- a/tests/logger.go +++ b/tests/logger.go @@ -24,6 +24,10 @@ func NilLogger() *zerolog.Logger { // Return a logger that makes use of the T.Log method to enable debugging tests func DebugLogger(t *testing.T) *zerolog.Logger { - logger := zerolog.New(&TLogWriter{t: t}) + logger := zerolog.New(GetTLogWriter(t)) return &logger } + +func GetTLogWriter(t *testing.T) *TLogWriter { + return &TLogWriter{t: t} +}