migrated out more modules and refactored db system

This commit is contained in:
2026-01-01 21:56:21 +11:00
parent 03095448d6
commit 8f6b4b0026
81 changed files with 462 additions and 5016 deletions

View File

@@ -1,37 +1,53 @@
package main
import (
"projectreshoot/pkg/config"
"projectreshoot/pkg/db"
"projectreshoot/pkg/tests"
"database/sql"
"fmt"
"strconv"
"github.com/pkg/errors"
"github.com/rs/zerolog"
_ "github.com/mattn/go-sqlite3"
)
func setupDBConn(
args map[string]string,
logger *zerolog.Logger,
config *config.Config,
) (*db.SafeConn, error) {
if args["test"] == "true" {
logger.Debug().Msg("Server in test mode, using test database")
ver, err := strconv.ParseInt(config.DBName, 10, 0)
if err != nil {
return nil, errors.Wrap(err, "strconv.ParseInt")
}
wconn, rconn, err := tests.SetupTestDB(ver)
if err != nil {
return nil, errors.Wrap(err, "tests.SetupTestDB")
}
conn := db.MakeSafe(wconn, rconn, logger)
return conn, nil
} else {
conn, err := db.ConnectToDatabase(config.DBName, logger)
if err != nil {
return nil, errors.Wrap(err, "db.ConnectToDatabase")
}
return conn, nil
func setupDBConn(dbName string) (*sql.DB, error) {
opts := "_journal_mode=WAL&_synchronous=NORMAL&_txlock=IMMEDIATE"
file := fmt.Sprintf("file:%s.db?%s", dbName, opts)
conn, err := sql.Open("sqlite3", file)
if err != nil {
return nil, errors.Wrap(err, "sql.Open")
}
err = checkDBVersion(conn, dbName)
if err != nil {
return nil, errors.Wrap(err, "checkDBVersion")
}
return conn, nil
}
// Check the database version
func checkDBVersion(db *sql.DB, dbName string) error {
expectVer, err := strconv.Atoi(dbName)
if err != nil {
return errors.Wrap(err, "strconv.Atoi")
}
query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
ORDER BY version_id DESC LIMIT 1`
rows, err := db.Query(query)
if err != nil {
return errors.Wrap(err, "db.Query")
}
defer rows.Close()
if rows.Next() {
var version int
err = rows.Scan(&version)
if err != nil {
return errors.Wrap(err, "rows.Scan")
}
if version != expectVer {
return errors.New("Version mismatch")
}
} else {
return errors.New("No version found")
}
return nil
}

View File

@@ -1,145 +0,0 @@
package main
import (
"bytes"
"context"
"fmt"
"net/http"
"os"
"strings"
"syscall"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func Test_main(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)
args := map[string]string{"test": "true"}
var stdout bytes.Buffer
os.Setenv("SECRET_KEY", ".")
os.Setenv("TMDB_API_TOKEN", ".")
os.Setenv("HOST", "127.0.0.1")
os.Setenv("PORT", "3232")
runSrvErr := make(chan error)
go func() {
if err := run(ctx, &stdout, args); err != nil {
runSrvErr <- err
return
}
}()
go func() {
err := waitForReady(ctx, 10*time.Second, "http://127.0.0.1:3232/healthz")
if err != nil {
runSrvErr <- err
return
}
runSrvErr <- nil
}()
select {
case err := <-runSrvErr:
if err != nil {
t.Fatalf("Error starting test server: %s", err)
return
}
t.Log("Test server started")
}
t.Run("SIGUSR1 puts database into global lock", func(t *testing.T) {
done := make(chan bool)
go func() {
expected := "Global database lock acquired"
for {
if strings.Contains(stdout.String(), expected) {
done <- true
return
}
time.Sleep(100 * time.Millisecond)
}
}()
proc, err := os.FindProcess(os.Getpid())
require.NoError(t, err)
proc.Signal(syscall.SIGUSR1)
select {
case <-done:
t.Log("found")
case <-time.After(250 * time.Millisecond):
t.Errorf("Not found")
}
})
t.Run("SIGUSR2 releases database global lock", func(t *testing.T) {
done := make(chan bool)
go func() {
expected := "Global database lock released"
for {
if strings.Contains(stdout.String(), expected) {
done <- true
return
}
time.Sleep(100 * time.Millisecond)
}
}()
proc, err := os.FindProcess(os.Getpid())
require.NoError(t, err)
proc.Signal(syscall.SIGUSR2)
select {
case <-done:
t.Log("found")
case <-time.After(250 * time.Millisecond):
t.Errorf("Not found")
}
})
}
func waitForReady(
ctx context.Context,
timeout time.Duration,
endpoint string,
) error {
client := http.Client{}
startTime := time.Now()
for {
req, err := http.NewRequestWithContext(
ctx,
http.MethodGet,
endpoint,
nil,
)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
fmt.Printf("Error making request: %s\n", err.Error())
time.Sleep(250 * time.Millisecond)
continue
}
if resp.StatusCode == http.StatusOK {
fmt.Println("Endpoint is ready!")
resp.Body.Close()
return nil
}
resp.Body.Close()
select {
case <-ctx.Done():
return ctx.Err()
default:
if time.Since(startTime) >= timeout {
return fmt.Errorf("timeout reached while waiting for endpoint")
}
// wait a little while between checks
time.Sleep(250 * time.Millisecond)
}
}
}

View File

@@ -3,7 +3,6 @@ package main
import (
"context"
"fmt"
"git.haelnorr.com/haelnorr/golibh/logger"
"io"
"net/http"
"os"
@@ -14,6 +13,9 @@ import (
"sync"
"time"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
)
@@ -38,7 +40,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
// Setup the logfile
var logfile *os.File = nil
if config.LogOutput == "both" || config.LogOutput == "file" {
logfile, err = logging.GetLogFile(config.LogDir)
logfile, err = hlog.NewLogFile(config.LogDir)
if err != nil {
return errors.Wrap(err, "logging.GetLogFile")
}
@@ -52,7 +54,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
}
// Setup the logger
logger, err := logging.GetLogger(
logger, err := hlog.NewLogger(
config.LogLevel,
consoleWriter,
logfile,
@@ -65,7 +67,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
// Setup the database connection
logger.Debug().Msg("Config loaded and logger started")
logger.Debug().Msg("Connecting to database")
conn, err := setupDBConn(args, logger, config)
conn, err := setupDBConn(config.DBName)
if err != nil {
return errors.Wrap(err, "setupDBConn")
}
@@ -78,18 +80,22 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
return errors.Wrap(err, "getStaticFiles")
}
logger.Debug().Msg("Setting up HTTP server")
httpServer := httpserver.NewServer(config, logger, conn, &staticFS, &maint)
// Setup TokenGenerator
logger.Debug().Msg("Creating TokenGenerator")
tokenGen, err := jwt.CreateGenerator(
config.AccessTokenExpiry,
config.RefreshTokenExpiry,
config.TokenFreshTime,
config.TrustedHost,
config.SecretKey,
conn,
)
// Runs function for testing in dev if --tester flag true
if args["tester"] == "true" {
logger.Debug().Msg("Running tester function")
test(config, logger, conn, httpServer)
return nil
}
logger.Debug().Msg("Setting up HTTP server")
httpServer := httpserver.NewServer(config, logger, conn, tokenGen, &staticFS, &maint)
// Setups a channel to listen for os.Signal
handleMaintSignals(conn, config, httpServer, logger)
handleMaintSignals(httpServer, logger)
// Runs the http server
logger.Debug().Msg("Starting up the HTTP server")
@@ -102,9 +108,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
// Handles graceful shutdown
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
wg.Go(func() {
<-ctx.Done()
shutdownCtx := context.Background()
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
@@ -112,7 +116,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string) error {
if err := httpServer.Shutdown(shutdownCtx); err != nil {
logger.Error().Err(err).Msg("Error shutting down server")
}
}()
})
wg.Wait()
logger.Info().Msg("Shutting down")
return nil

View File

@@ -4,21 +4,16 @@ import (
"net/http"
"os"
"os/signal"
"projectreshoot/pkg/config"
"projectreshoot/pkg/db"
"sync/atomic"
"syscall"
"time"
"github.com/rs/zerolog"
"git.haelnorr.com/h/golib/hlog"
)
// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode
func handleMaintSignals(
conn *db.SafeConn,
config *config.Config,
srv *http.Server,
logger *zerolog.Logger,
logger *hlog.Logger,
) {
logger.Debug().Msg("Starting signal listener")
ch := make(chan os.Signal, 1)
@@ -33,14 +28,10 @@ func handleMaintSignals(
if atomic.LoadUint32(&maint) != 1 {
atomic.StoreUint32(&maint, 1)
logger.Info().Msg("Signal received: Starting maintenance")
logger.Info().Msg("Attempting to acquire database lock")
conn.Pause(config.DBLockTimeout * time.Second)
}
case syscall.SIGUSR2:
if atomic.LoadUint32(&maint) != 0 {
logger.Info().Msg("Signal received: Maintenance over")
logger.Info().Msg("Releasing database lock")
conn.Resume()
atomic.StoreUint32(&maint, 0)
}
}

View File

@@ -1,32 +0,0 @@
package main
import (
"net/http"
"projectreshoot/pkg/config"
"projectreshoot/pkg/db"
"projectreshoot/pkg/tmdb"
"github.com/rs/zerolog"
)
// This function will only be called if the --test commandline flag is set.
// After the function finishes the application will close.
// Running command `make tester` will run the test using port 3232 to avoid
// conflicts on the default 3333. Useful for testing things out during dev.
// If you add code here, remember to run:
// `git update-index --assume-unchanged tester.go` to avoid tracking changes
func test(
config *config.Config,
logger *zerolog.Logger,
conn *db.SafeConn,
srv *http.Server,
) {
query := "a few good men"
search, err := tmdb.SearchMovies(config.TMDBToken, query, false, 1)
if err != nil {
logger.Error().Err(err).Msg("error occured")
return
}
logger.Info().Interface("results", search).Msg("search results received")
}