refactor: changed file structure
This commit is contained in:
69
cmd/migrate/migrate.go
Normal file
69
cmd/migrate/migrate.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/pressly/goose/v3"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
//go:embed migrations
|
||||
var migrationsFS embed.FS
|
||||
|
||||
func main() {
|
||||
if len(os.Args) != 4 {
|
||||
fmt.Println("Usage: migrate <file_path> up-to|down-to <version>")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
filePath := os.Args[1]
|
||||
direction := os.Args[2]
|
||||
versionStr := os.Args[3]
|
||||
|
||||
version, err := strconv.Atoi(versionStr)
|
||||
if err != nil {
|
||||
log.Fatalf("Invalid version number: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
log.Fatalf("Database file does not exist: %v", filePath)
|
||||
}
|
||||
db, err := sql.Open("sqlite", filePath)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
migrations, err := fs.Sub(migrationsFS, "migrations")
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get migrations from embedded filesystem")
|
||||
}
|
||||
provider, err := goose.NewProvider(goose.DialectSQLite3, db, migrations)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create migration provider: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
switch direction {
|
||||
case "up-to":
|
||||
_, err = provider.UpTo(ctx, int64(version))
|
||||
case "down-to":
|
||||
_, err = provider.DownTo(ctx, int64(version))
|
||||
default:
|
||||
log.Fatalf("Invalid direction: use 'up-to' or 'down-to'")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("Migration failed: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println("Migration successful!")
|
||||
}
|
||||
27
cmd/migrate/migrations/00001_init.sql
Normal file
27
cmd/migrate/migrations/00001_init.sql
Normal file
@@ -0,0 +1,27 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
PRAGMA foreign_keys=ON;
|
||||
CREATE TABLE IF NOT EXISTS jwtblacklist (
|
||||
jti TEXT PRIMARY KEY CHECK(jti GLOB '[0-9a-fA-F-]*'),
|
||||
exp INTEGER NOT NULL
|
||||
) STRICT;
|
||||
CREATE TABLE IF NOT EXISTS "users" (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT DEFAULT "",
|
||||
created_at INTEGER DEFAULT (unixepoch()),
|
||||
bio TEXT DEFAULT ""
|
||||
) STRICT;
|
||||
CREATE TRIGGER IF NOT EXISTS cleanup_expired_tokens
|
||||
AFTER INSERT ON jwtblacklist
|
||||
BEGIN
|
||||
DELETE FROM jwtblacklist WHERE exp < strftime('%s', 'now');
|
||||
END;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
DROP TRIGGER IF EXISTS cleanup_expired_tokens;
|
||||
DROP TABLE IF EXISTS jwtblacklist;
|
||||
DROP TABLE IF EXISTS users;
|
||||
-- +goose StatementEnd
|
||||
37
cmd/projectreshoot/dbconn.go
Normal file
37
cmd/projectreshoot/dbconn.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"projectreshoot/pkg/config"
|
||||
"projectreshoot/pkg/db"
|
||||
"projectreshoot/pkg/tests"
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func setupDBConn(
|
||||
args map[string]string,
|
||||
logger *zerolog.Logger,
|
||||
config *config.Config,
|
||||
) (*db.SafeConn, error) {
|
||||
if args["test"] == "true" {
|
||||
logger.Debug().Msg("Server in test mode, using test database")
|
||||
ver, err := strconv.ParseInt(config.DBName, 10, 0)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "strconv.ParseInt")
|
||||
}
|
||||
wconn, rconn, err := tests.SetupTestDB(ver)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tests.SetupTestDB")
|
||||
}
|
||||
conn := db.MakeSafe(wconn, rconn, logger)
|
||||
return conn, nil
|
||||
} else {
|
||||
conn, err := db.ConnectToDatabase(config.DBName, logger)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.ConnectToDatabase")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
30
cmd/projectreshoot/flags.go
Normal file
30
cmd/projectreshoot/flags.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func setupFlags() map[string]string {
|
||||
// Parse commandline args
|
||||
host := flag.String("host", "", "Override host to listen on")
|
||||
port := flag.String("port", "", "Override port to listen on")
|
||||
test := flag.Bool("test", false, "Run server in test mode")
|
||||
tester := flag.Bool("tester", false, "Run tester function instead of main program")
|
||||
dbver := flag.Bool("dbver", false, "Get the version of the database required")
|
||||
loglevel := flag.String("loglevel", "", "Set log level")
|
||||
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
|
||||
flag.Parse()
|
||||
|
||||
// Map the args for easy access
|
||||
args := map[string]string{
|
||||
"host": *host,
|
||||
"port": *port,
|
||||
"test": strconv.FormatBool(*test),
|
||||
"tester": strconv.FormatBool(*tester),
|
||||
"dbver": strconv.FormatBool(*dbver),
|
||||
"loglevel": *loglevel,
|
||||
"logoutput": *logoutput,
|
||||
}
|
||||
return args
|
||||
}
|
||||
16
cmd/projectreshoot/main.go
Normal file
16
cmd/projectreshoot/main.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func main() {
|
||||
args := setupFlags()
|
||||
ctx := context.Background()
|
||||
if err := run(ctx, os.Stdout, args); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
145
cmd/projectreshoot/main_test.go
Normal file
145
cmd/projectreshoot/main_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_main(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
t.Cleanup(cancel)
|
||||
args := map[string]string{"test": "true"}
|
||||
var stdout bytes.Buffer
|
||||
os.Setenv("SECRET_KEY", ".")
|
||||
os.Setenv("TMDB_API_TOKEN", ".")
|
||||
os.Setenv("HOST", "127.0.0.1")
|
||||
os.Setenv("PORT", "3232")
|
||||
runSrvErr := make(chan error)
|
||||
go func() {
|
||||
if err := run(ctx, &stdout, args); err != nil {
|
||||
runSrvErr <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err := waitForReady(ctx, 10*time.Second, "http://127.0.0.1:3232/healthz")
|
||||
if err != nil {
|
||||
runSrvErr <- err
|
||||
return
|
||||
}
|
||||
runSrvErr <- nil
|
||||
}()
|
||||
select {
|
||||
case err := <-runSrvErr:
|
||||
if err != nil {
|
||||
t.Fatalf("Error starting test server: %s", err)
|
||||
return
|
||||
}
|
||||
t.Log("Test server started")
|
||||
}
|
||||
|
||||
t.Run("SIGUSR1 puts database into global lock", func(t *testing.T) {
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
expected := "Global database lock acquired"
|
||||
for {
|
||||
if strings.Contains(stdout.String(), expected) {
|
||||
done <- true
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
proc, err := os.FindProcess(os.Getpid())
|
||||
require.NoError(t, err)
|
||||
proc.Signal(syscall.SIGUSR1)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("found")
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
t.Errorf("Not found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SIGUSR2 releases database global lock", func(t *testing.T) {
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
expected := "Global database lock released"
|
||||
for {
|
||||
if strings.Contains(stdout.String(), expected) {
|
||||
done <- true
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
proc, err := os.FindProcess(os.Getpid())
|
||||
require.NoError(t, err)
|
||||
proc.Signal(syscall.SIGUSR2)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("found")
|
||||
case <-time.After(250 * time.Millisecond):
|
||||
t.Errorf("Not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func waitForReady(
|
||||
ctx context.Context,
|
||||
timeout time.Duration,
|
||||
endpoint string,
|
||||
) error {
|
||||
client := http.Client{}
|
||||
startTime := time.Now()
|
||||
for {
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
endpoint,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("Error making request: %s\n", err.Error())
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
fmt.Println("Endpoint is ready!")
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
if time.Since(startTime) >= timeout {
|
||||
return fmt.Errorf("timeout reached while waiting for endpoint")
|
||||
}
|
||||
// wait a little while between checks
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
}
|
||||
119
cmd/projectreshoot/run.go
Normal file
119
cmd/projectreshoot/run.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"projectreshoot/internal/httpserver"
|
||||
"projectreshoot/pkg/config"
|
||||
"projectreshoot/pkg/embedfs"
|
||||
"projectreshoot/pkg/logging"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var maint uint32 // atomic: 1 if in maintenance mode
|
||||
|
||||
// Initializes and runs the server
|
||||
func run(ctx context.Context, w io.Writer, args map[string]string) error {
|
||||
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
config, err := config.GetConfig(args)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "server.GetConfig")
|
||||
}
|
||||
|
||||
// Return the version of the database required
|
||||
if args["dbver"] == "true" {
|
||||
fmt.Fprintf(w, "Database version: %s\n", config.DBName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Setup the logfile
|
||||
var logfile *os.File = nil
|
||||
if config.LogOutput == "both" || config.LogOutput == "file" {
|
||||
logfile, err = logging.GetLogFile(config.LogDir)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "logging.GetLogFile")
|
||||
}
|
||||
defer logfile.Close()
|
||||
}
|
||||
|
||||
// Setup the console writer
|
||||
var consoleWriter io.Writer
|
||||
if config.LogOutput == "both" || config.LogOutput == "console" {
|
||||
consoleWriter = w
|
||||
}
|
||||
|
||||
// Setup the logger
|
||||
logger, err := logging.GetLogger(
|
||||
config.LogLevel,
|
||||
consoleWriter,
|
||||
logfile,
|
||||
config.LogDir,
|
||||
)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "logging.GetLogger")
|
||||
}
|
||||
|
||||
// Setup the database connection
|
||||
logger.Debug().Msg("Config loaded and logger started")
|
||||
logger.Debug().Msg("Connecting to database")
|
||||
conn, err := setupDBConn(args, logger, config)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "setupDBConn")
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Setup embedded files
|
||||
logger.Debug().Msg("Getting embedded files")
|
||||
staticFS, err := embedfs.GetEmbeddedFS()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "getStaticFiles")
|
||||
}
|
||||
|
||||
logger.Debug().Msg("Setting up HTTP server")
|
||||
httpServer := httpserver.NewServer(config, logger, conn, &staticFS, &maint)
|
||||
|
||||
// Runs function for testing in dev if --tester flag true
|
||||
if args["tester"] == "true" {
|
||||
logger.Debug().Msg("Running tester function")
|
||||
test(config, logger, conn, httpServer)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Setups a channel to listen for os.Signal
|
||||
handleMaintSignals(conn, config, httpServer, logger)
|
||||
|
||||
// Runs the http server
|
||||
logger.Debug().Msg("Starting up the HTTP server")
|
||||
go func() {
|
||||
logger.Info().Str("address", httpServer.Addr).Msg("Listening for requests")
|
||||
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error().Err(err).Msg("Error listening and serving")
|
||||
}
|
||||
}()
|
||||
|
||||
// Handles graceful shutdown
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-ctx.Done()
|
||||
shutdownCtx := context.Background()
|
||||
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
|
||||
defer cancel()
|
||||
if err := httpServer.Shutdown(shutdownCtx); err != nil {
|
||||
logger.Error().Err(err).Msg("Error shutting down server")
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
logger.Info().Msg("Shutting down")
|
||||
return nil
|
||||
}
|
||||
50
cmd/projectreshoot/signals.go
Normal file
50
cmd/projectreshoot/signals.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"projectreshoot/pkg/config"
|
||||
"projectreshoot/pkg/db"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// Handle SIGUSR1 and SIGUSR2 syscalls to toggle maintenance mode
|
||||
func handleMaintSignals(
|
||||
conn *db.SafeConn,
|
||||
config *config.Config,
|
||||
srv *http.Server,
|
||||
logger *zerolog.Logger,
|
||||
) {
|
||||
logger.Debug().Msg("Starting signal listener")
|
||||
ch := make(chan os.Signal, 1)
|
||||
srv.RegisterOnShutdown(func() {
|
||||
logger.Debug().Msg("Shutting down signal listener")
|
||||
close(ch)
|
||||
})
|
||||
go func() {
|
||||
for sig := range ch {
|
||||
switch sig {
|
||||
case syscall.SIGUSR1:
|
||||
if atomic.LoadUint32(&maint) != 1 {
|
||||
atomic.StoreUint32(&maint, 1)
|
||||
logger.Info().Msg("Signal received: Starting maintenance")
|
||||
logger.Info().Msg("Attempting to acquire database lock")
|
||||
conn.Pause(config.DBLockTimeout * time.Second)
|
||||
}
|
||||
case syscall.SIGUSR2:
|
||||
if atomic.LoadUint32(&maint) != 0 {
|
||||
logger.Info().Msg("Signal received: Maintenance over")
|
||||
logger.Info().Msg("Releasing database lock")
|
||||
conn.Resume()
|
||||
atomic.StoreUint32(&maint, 0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
signal.Notify(ch, syscall.SIGUSR1, syscall.SIGUSR2)
|
||||
}
|
||||
32
cmd/projectreshoot/tester.go
Normal file
32
cmd/projectreshoot/tester.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"projectreshoot/pkg/config"
|
||||
"projectreshoot/pkg/db"
|
||||
"projectreshoot/pkg/tmdb"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// This function will only be called if the --test commandline flag is set.
|
||||
// After the function finishes the application will close.
|
||||
// Running command `make tester` will run the test using port 3232 to avoid
|
||||
// conflicts on the default 3333. Useful for testing things out during dev.
|
||||
// If you add code here, remember to run:
|
||||
// `git update-index --assume-unchanged tester.go` to avoid tracking changes
|
||||
func test(
|
||||
config *config.Config,
|
||||
logger *zerolog.Logger,
|
||||
conn *db.SafeConn,
|
||||
srv *http.Server,
|
||||
) {
|
||||
query := "a few good men"
|
||||
search, err := tmdb.SearchMovies(config.TMDBToken, query, false, 1)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("error occured")
|
||||
return
|
||||
}
|
||||
logger.Info().Interface("results", search).Msg("search results received")
|
||||
}
|
||||
Reference in New Issue
Block a user