migrated out more modules and refactored db system
This commit is contained in:
@@ -1,37 +1,53 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"projectreshoot/pkg/config"
|
||||
"projectreshoot/pkg/db"
|
||||
"projectreshoot/pkg/tests"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func setupDBConn(
|
||||
args map[string]string,
|
||||
logger *zerolog.Logger,
|
||||
config *config.Config,
|
||||
) (*db.SafeConn, error) {
|
||||
if args["test"] == "true" {
|
||||
logger.Debug().Msg("Server in test mode, using test database")
|
||||
ver, err := strconv.ParseInt(config.DBName, 10, 0)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "strconv.ParseInt")
|
||||
}
|
||||
wconn, rconn, err := tests.SetupTestDB(ver)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tests.SetupTestDB")
|
||||
}
|
||||
conn := db.MakeSafe(wconn, rconn, logger)
|
||||
return conn, nil
|
||||
} else {
|
||||
conn, err := db.ConnectToDatabase(config.DBName, logger)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.ConnectToDatabase")
|
||||
}
|
||||
return conn, nil
|
||||
func setupDBConn(dbName string) (*sql.DB, error) {
|
||||
opts := "_journal_mode=WAL&_synchronous=NORMAL&_txlock=IMMEDIATE"
|
||||
file := fmt.Sprintf("file:%s.db?%s", dbName, opts)
|
||||
conn, err := sql.Open("sqlite3", file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
err = checkDBVersion(conn, dbName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkDBVersion")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Check the database version
|
||||
func checkDBVersion(db *sql.DB, dbName string) error {
|
||||
expectVer, err := strconv.Atoi(dbName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "strconv.Atoi")
|
||||
}
|
||||
query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
|
||||
ORDER BY version_id DESC LIMIT 1`
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
if rows.Next() {
|
||||
var version int
|
||||
err = rows.Scan(&version)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rows.Scan")
|
||||
}
|
||||
if version != expectVer {
|
||||
return errors.New("Version mismatch")
|
||||
}
|
||||
} else {
|
||||
return errors.New("No version found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_main(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
t.Cleanup(cancel)
|
||||
args := map[string]string{"test": "true"}
|
||||
var stdout bytes.Buffer
|
||||
os.Setenv("SECRET_KEY", ".")
|
||||
os.Setenv("TMDB_API_TOKEN", ".")
|
||||
os.Setenv("HOST", "127.0.0.1")
|
||||
os.Setenv("PORT", "3232")
|
||||
runSrvErr := make(chan error)
|
||||
go func() {
|
||||
if err := run(ctx, &stdout, args); err != nil {
|
||||
runSrvErr <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := waitForReady(ctx, 10*time.Second, "http://127.0.0.1:3232/healthz")
|
||||
if err != nil {
|
||||
runSrvErr <- err
|
||||
return
|
||||
}
|
||||
runSrvErr <- nil
|
||||
}()
|
||||
select {
|
||||
case err := <-runSrvErr:
|
||||
if err != nil {
|
||||
t.Fatalf("Error starting test server: %s", err)
|
||||
return
|
||||
}
|
||||
t.Log("Test server started")
|
||||
}
|
||||
|
||||
t.Run("SIGUSR1 puts database into global lock", func(t *testing.T) {
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
expected := "Global database lock acquired"
|
||||
for {
|
||||
if strings.Contains(stdout.String(), expected) {
|
||||
done <- true
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
proc, err := os.FindProcess(os.Getpid())
|
||||
require.NoError(t, err)
|
||||
proc.Signal(syscall.SIGUSR1)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("found")
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
t.Errorf("Not found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SIGUSR2 releases database global lock", func(t *testing.T) {
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
expected := "Global database lock released"
|
||||
for {
|
||||
if strings.Contains(stdout.String(), expected) {
|
||||
done <- true
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
proc, err := os.FindProcess(os.Getpid())
|
||||
require.NoError(t, err)
|
||||
proc.Signal(syscall.SIGUSR2)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("found")
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
t.Errorf("Not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func waitForReady(
|
||||
ctx context.Context,
|
||||
timeout time.Duration,
|
||||
endpoint string,
|
||||
) error {
|
||||
client := http.Client{}
|
||||
startTime := time.Now()
|
||||
for {
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
endpoint,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("Error making request: %s\n", err.Error())
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
fmt.Println("Endpoint is ready!")
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
if time.Since(startTime) >= timeout {
|
||||
return fmt.Errorf("timeout reached while waiting for endpoint")
|
||||
}
|
||||
// wait a little while between checks
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"git.haelnorr.com/haelnorr/golibh/logger"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -14,6 +13,9 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"git.haelnorr.com/h/golib/jwt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -38,7 +40,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
// Setup the logfile
|
||||
var logfile *os.File = nil
|
||||
if config.LogOutput == "both" || config.LogOutput == "file" {
|
||||
logfile, err = logging.GetLogFile(config.LogDir)
|
||||
logfile, err = hlog.NewLogFile(config.LogDir)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "logging.GetLogFile")
|
||||
}
|
||||
@@ -52,7 +54,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
}
|
||||
|
||||
// Setup the logger
|
||||
logger, err := logging.GetLogger(
|
||||
logger, err := hlog.NewLogger(
|
||||
config.LogLevel,
|
||||
consoleWriter,
|
||||
logfile,
|
||||
@@ -65,7 +67,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
// Setup the database connection
|
||||
logger.Debug().Msg("Config loaded and logger started")
|
||||
logger.Debug().Msg("Connecting to database")
|
||||
conn, err := setupDBConn(args, logger, config)
|
||||
conn, err := setupDBConn(config.DBName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "setupDBConn")
|
||||
}
|
||||
@@ -78,18 +80,22 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
return errors.Wrap(err, "getStaticFiles")
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Setting up HTTP server")
|
||||
httpServer := httpserver.NewServer(config, logger, conn, &staticFS, &maint)
|
||||
// Setup TokenGenerator
|
||||
logger.Debug().Msg("Creating TokenGenerator")
|
||||
tokenGen, err := jwt.CreateGenerator(
|
||||
config.AccessTokenExpiry,
|
||||
config.RefreshTokenExpiry,
|
||||
config.TokenFreshTime,
|
||||
config.TrustedHost,
|
||||
config.SecretKey,
|
||||
conn,
|
||||
)
|
||||
|
||||
// Runs function for testing in dev if --tester flag true
|
||||
if args["tester"] == "true" {
|
||||
logger.Debug().Msg("Running tester function")
|
||||
test(config, logger, conn, httpServer)
|
||||
return nil
|
||||
}
|
||||
logger.Debug().Msg("Setting up HTTP server")
|
||||
httpServer := httpserver.NewServer(config, logger, conn, tokenGen, &staticFS, &maint)
|
||||
|
||||
// Setups a channel to listen for os.Signal
|
||||
handleMaintSignals(conn, config, httpServer, logger)
|
||||
handleMaintSignals(httpServer, logger)
|
||||
|
||||
// Runs the http server
|
||||
logger.Debug().Msg("Starting up the HTTP server")
|
||||
@@ -102,9 +108,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
|
||||
// Handles graceful shutdown
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
<-ctx.Done()
|
||||
shutdownCtx := context.Background()
|
||||
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
|
||||
@@ -112,7 +116,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
logger.Error().Err(err).Msg("Error shutting down server")
|
||||
}
|
||||
}()
|
||||
})
|
||||
wg.Wait()
|
||||
logger.Info().Msg("Shutting down")
|
||||
return nil
|
||||
|
||||
@@ -4,21 +4,16 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"projectreshoot/pkg/config"
|
||||
"projectreshoot/pkg/db"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
)
|
||||
|
||||
// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode
|
||||
func handleMaintSignals(
|
||||
conn *db.SafeConn,
|
||||
config *config.Config,
|
||||
srv *http.Server,
|
||||
logger *zerolog.Logger,
|
||||
logger *hlog.Logger,
|
||||
) {
|
||||
logger.Debug().Msg("Starting signal listener")
|
||||
ch := make(chan os.Signal, 1)
|
||||
@@ -33,14 +28,10 @@ func handleMaintSignals(
|
||||
if atomic.LoadUint32(&maint) != 1 {
|
||||
atomic.StoreUint32(&maint, 1)
|
||||
logger.Info().Msg("Signal received: Starting maintenance")
|
||||
logger.Info().Msg("Attempting to acquire database lock")
|
||||
conn.Pause(config.DBLockTimeout * time.Second)
|
||||
}
|
||||
case syscall.SIGUSR2:
|
||||
if atomic.LoadUint32(&maint) != 0 {
|
||||
logger.Info().Msg("Signal received: Maintenance over")
|
||||
logger.Info().Msg("Releasing database lock")
|
||||
conn.Resume()
|
||||
atomic.StoreUint32(&maint, 0)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/pkg/config"
|
||||
"projectreshoot/pkg/db"
|
||||
"projectreshoot/pkg/tmdb"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// This function will only be called if the --test commandline flag is set.
|
||||
// After the function finishes the application will close.
|
||||
// Running command `make tester` will run the test using port 3232 to avoid
|
||||
// conflicts on the default 3333. Useful for testing things out during dev.
|
||||
// If you add code here, remember to run:
|
||||
// `git update-index --assume-unchanged tester.go` to avoid tracking changes
|
||||
func test(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
conn *db.SafeConn,
|
||||
srv *http.Server,
|
||||
) {
|
||||
query := "a few good men"
|
||||
search, err := tmdb.SearchMovies(config.TMDBToken, query, false, 1)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("error occured")
|
||||
return
|
||||
}
|
||||
logger.Info().Interface("results", search).Msg("search results received")
|
||||
}
|
||||
Reference in New Issue
Block a user