refactor: changed file structure

This commit is contained in:
2025-03-05 20:18:28 +11:00
parent 5c1089e0ce
commit 1d9af44d0a
137 changed files with 4986 additions and 581 deletions

69
cmd/migrate/migrate.go Normal file
View File

@@ -0,0 +1,69 @@
package main
import (
"context"
"database/sql"
"embed"
"fmt"
"io/fs"
"log"
"os"
"strconv"
"github.com/pressly/goose/v3"
_ "modernc.org/sqlite"
)
//go:embed migrations
var migrationsFS embed.FS
func main() {
if len(os.Args) != 4 {
fmt.Println("Usage: migrate <file_path> up-to|down-to <version>")
os.Exit(1)
}
filePath := os.Args[1]
direction := os.Args[2]
versionStr := os.Args[3]
version, err := strconv.Atoi(versionStr)
if err != nil {
log.Fatalf("Invalid version number: %v", err)
}
if _, err := os.Stat(filePath); os.IsNotExist(err) {
log.Fatalf("Database file does not exist: %v", filePath)
}
db, err := sql.Open("sqlite", filePath)
if err != nil {
log.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
migrations, err := fs.Sub(migrationsFS, "migrations")
if err != nil {
log.Fatalf("Failed to get migrations from embedded filesystem")
}
provider, err := goose.NewProvider(goose.DialectSQLite3, db, migrations)
if err != nil {
log.Fatalf("Failed to create migration provider: %v", err)
}
ctx := context.Background()
switch direction {
case "up-to":
_, err = provider.UpTo(ctx, int64(version))
case "down-to":
_, err = provider.DownTo(ctx, int64(version))
default:
log.Fatalf("Invalid direction: use 'up-to' or 'down-to'")
}
if err != nil {
log.Fatalf("Migration failed: %v", err)
}
fmt.Println("Migration successful!")
}

View File

@@ -0,0 +1,27 @@
-- +goose Up
-- +goose StatementBegin
PRAGMA foreign_keys=ON;
CREATE TABLE IF NOT EXISTS jwtblacklist (
jti TEXT PRIMARY KEY CHECK(jti GLOB '[0-9a-fA-F-]*'),
exp INTEGER NOT NULL
) STRICT;
CREATE TABLE IF NOT EXISTS "users" (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password_hash TEXT DEFAULT "",
created_at INTEGER DEFAULT (unixepoch()),
bio TEXT DEFAULT ""
) STRICT;
CREATE TRIGGER IF NOT EXISTS cleanup_expired_tokens
AFTER INSERT ON jwtblacklist
BEGIN
DELETE FROM jwtblacklist WHERE exp < strftime('%s', 'now');
END;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
DROP TRIGGER IF EXISTS cleanup_expired_tokens;
DROP TABLE IF EXISTS jwtblacklist;
DROP TABLE IF EXISTS users;
-- +goose StatementEnd

View File

@@ -0,0 +1,37 @@
package main
import (
"projectreshoot/pkg/config"
"projectreshoot/pkg/db"
"projectreshoot/pkg/tests"
"strconv"
"github.com/pkg/errors"
"github.com/rs/zerolog"
)
func setupDBConn(
args map[string]string,
logger *zerolog.Logger,
config *config.Config,
) (*db.SafeConn, error) {
if args["test"] == "true" {
logger.Debug().Msg("Server in test mode, using test database")
ver, err := strconv.ParseInt(config.DBName, 10, 0)
if err != nil {
return nil, errors.Wrap(err, "strconv.ParseInt")
}
wconn, rconn, err := tests.SetupTestDB(ver)
if err != nil {
return nil, errors.Wrap(err, "tests.SetupTestDB")
}
conn := db.MakeSafe(wconn, rconn, logger)
return conn, nil
} else {
conn, err := db.ConnectToDatabase(config.DBName, logger)
if err != nil {
return nil, errors.Wrap(err, "db.ConnectToDatabase")
}
return conn, nil
}
}

View File

@@ -0,0 +1,30 @@
package main
import (
"flag"
"strconv"
)
func setupFlags() map[string]string {
// Parse commandline args
host := flag.String("host", "", "Override host to listen on")
port := flag.String("port", "", "Override port to listen on")
test := flag.Bool("test", false, "Run server in test mode")
tester := flag.Bool("tester", false, "Run tester function instead of main program")
dbver := flag.Bool("dbver", false, "Get the version of the database required")
loglevel := flag.String("loglevel", "", "Set log level")
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
flag.Parse()
// Map the args for easy access
args := map[string]string{
"host": *host,
"port": *port,
"test": strconv.FormatBool(*test),
"tester": strconv.FormatBool(*tester),
"dbver": strconv.FormatBool(*dbver),
"loglevel": *loglevel,
"logoutput": *logoutput,
}
return args
}

View File

@@ -0,0 +1,16 @@
package main
import (
"context"
"fmt"
"os"
)
func main() {
args := setupFlags()
ctx := context.Background()
if err := run(ctx, os.Stdout, args); err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
}

View File

@@ -0,0 +1,145 @@
package main
import (
"bytes"
"context"
"fmt"
"net/http"
"os"
"strings"
"syscall"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func Test_main(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)
args := map[string]string{"test": "true"}
var stdout bytes.Buffer
os.Setenv("SECRET_KEY", ".")
os.Setenv("TMDB_API_TOKEN", ".")
os.Setenv("HOST", "127.0.0.1")
os.Setenv("PORT", "3232")
runSrvErr := make(chan error)
go func() {
if err := run(ctx, &stdout, args); err != nil {
runSrvErr <- err
return
}
}()
go func() {
err := waitForReady(ctx, 10*time.Second, "http://127.0.0.1:3232/healthz")
if err != nil {
runSrvErr <- err
return
}
runSrvErr <- nil
}()
select {
case err := <-runSrvErr:
if err != nil {
t.Fatalf("Error starting test server: %s", err)
return
}
t.Log("Test server started")
}
t.Run("SIGUSR1 puts database into global lock", func(t *testing.T) {
done := make(chan bool)
go func() {
expected := "Global database lock acquired"
for {
if strings.Contains(stdout.String(), expected) {
done <- true
return
}
time.Sleep(100 * time.Millisecond)
}
}()
proc, err := os.FindProcess(os.Getpid())
require.NoError(t, err)
proc.Signal(syscall.SIGUSR1)
select {
case <-done:
t.Log("found")
case <-time.After(250 * time.Millisecond):
t.Errorf("Not found")
}
})
t.Run("SIGUSR2 releases database global lock", func(t *testing.T) {
done := make(chan bool)
go func() {
expected := "Global database lock released"
for {
if strings.Contains(stdout.String(), expected) {
done <- true
return
}
time.Sleep(100 * time.Millisecond)
}
}()
proc, err := os.FindProcess(os.Getpid())
require.NoError(t, err)
proc.Signal(syscall.SIGUSR2)
select {
case <-done:
t.Log("found")
case <-time.After(250 * time.Millisecond):
t.Errorf("Not found")
}
})
}
func waitForReady(
ctx context.Context,
timeout time.Duration,
endpoint string,
) error {
client := http.Client{}
startTime := time.Now()
for {
req, err := http.NewRequestWithContext(
ctx,
http.MethodGet,
endpoint,
nil,
)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
fmt.Printf("Error making request: %s\n", err.Error())
time.Sleep(250 * time.Millisecond)
continue
}
if resp.StatusCode == http.StatusOK {
fmt.Println("Endpoint is ready!")
resp.Body.Close()
return nil
}
resp.Body.Close()
select {
case <-ctx.Done():
return ctx.Err()
default:
if time.Since(startTime) >= timeout {
return fmt.Errorf("timeout reached while waiting for endpoint")
}
// wait a little while between checks
time.Sleep(250 * time.Millisecond)
}
}
}

119
cmd/projectreshoot/run.go Normal file
View File

@@ -0,0 +1,119 @@
package main
import (
"context"
"fmt"
"io"
"net/http"
"os"
"os/signal"
"projectreshoot/internal/httpserver"
"projectreshoot/pkg/config"
"projectreshoot/pkg/embedfs"
"projectreshoot/pkg/logging"
"sync"
"time"
"github.com/pkg/errors"
)
var maint uint32 // atomic: 1 if in maintenance mode
// Initializes and runs the server
func run(ctx context.Context, w io.Writer, args map[string]string) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()
config, err := config.GetConfig(args)
if err != nil {
return errors.Wrap(err, "server.GetConfig")
}
// Return the version of the database required
if args["dbver"] == "true" {
fmt.Fprintf(w, "Database version: %s\n", config.DBName)
return nil
}
// Setup the logfile
var logfile *os.File = nil
if config.LogOutput == "both" || config.LogOutput == "file" {
logfile, err = logging.GetLogFile(config.LogDir)
if err != nil {
return errors.Wrap(err, "logging.GetLogFile")
}
defer logfile.Close()
}
// Setup the console writer
var consoleWriter io.Writer
if config.LogOutput == "both" || config.LogOutput == "console" {
consoleWriter = w
}
// Setup the logger
logger, err := logging.GetLogger(
config.LogLevel,
consoleWriter,
logfile,
config.LogDir,
)
if err != nil {
return errors.Wrap(err, "logging.GetLogger")
}
// Setup the database connection
logger.Debug().Msg("Config loaded and logger started")
logger.Debug().Msg("Connecting to database")
conn, err := setupDBConn(args, logger, config)
if err != nil {
return errors.Wrap(err, "setupDBConn")
}
defer conn.Close()
// Setup embedded files
logger.Debug().Msg("Getting embedded files")
staticFS, err := embedfs.GetEmbeddedFS()
if err != nil {
return errors.Wrap(err, "getStaticFiles")
}
logger.Debug().Msg("Setting up HTTP server")
httpServer := httpserver.NewServer(config, logger, conn, &staticFS, &maint)
// Runs function for testing in dev if --tester flag true
if args["tester"] == "true" {
logger.Debug().Msg("Running tester function")
test(config, logger, conn, httpServer)
return nil
}
// Setups a channel to listen for os.Signal
handleMaintSignals(conn, config, httpServer, logger)
// Runs the http server
logger.Debug().Msg("Starting up the HTTP server")
go func() {
logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests")
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error().Err(err).Msg("Error listening and serving")
}
}()
// Handles graceful shutdown
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
<-ctx.Done()
shutdownCtx := context.Background()
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
logger.Error().Err(err).Msg("Error shutting down server")
}
}()
wg.Wait()
logger.Info().Msg("Shutting down")
return nil
}

View File

@@ -0,0 +1,50 @@
package main
import (
"net/http"
"os"
"os/signal"
"projectreshoot/pkg/config"
"projectreshoot/pkg/db"
"sync/atomic"
"syscall"
"time"
"github.com/rs/zerolog"
)
// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode
func handleMaintSignals(
conn *db.SafeConn,
config *config.Config,
srv *http.Server,
logger *zerolog.Logger,
) {
logger.Debug().Msg("Starting signal listener")
ch := make(chan os.Signal, 1)
srv.RegisterOnShutdown(func() {
logger.Debug().Msg("Shutting down signal listener")
close(ch)
})
go func() {
for sig := range ch {
switch sig {
case syscall.SIGUSR1:
if atomic.LoadUint32(&maint) != 1 {
atomic.StoreUint32(&maint, 1)
logger.Info().Msg("Signal received: Starting maintenance")
logger.Info().Msg("Attempting to acquire database lock")
conn.Pause(config.DBLockTimeout * time.Second)
}
case syscall.SIGUSR2:
if atomic.LoadUint32(&maint) != 0 {
logger.Info().Msg("Signal received: Maintenance over")
logger.Info().Msg("Releasing database lock")
conn.Resume()
atomic.StoreUint32(&maint, 0)
}
}
}
}()
signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2)
}

View File

@@ -0,0 +1,32 @@
package main
import (
"net/http"
"projectreshoot/pkg/config"
"projectreshoot/pkg/db"
"projectreshoot/pkg/tmdb"
"github.com/rs/zerolog"
)
// This function will only be called if the --test commandline flag is set.
// After the function finishes the application will close.
// Running command `make tester` will run the test using port 3232 to avoid
// conflicts on the default 3333. Useful for testing things out during dev.
// If you add code here, remember to run:
// `git update-index --assume-unchanged tester.go` to avoid tracking changes
func test(
config *config.Config,
logger *zerolog.Logger,
conn *db.SafeConn,
srv *http.Server,
) {
query := "a few good men"
search, err := tmdb.SearchMovies(config.TMDBToken, query, false, 1)
if err != nil {
logger.Error().Err(err).Msg("error occured")
return
}
logger.Info().Interface("results", search).Msg("search results received")
}