updated to use bun and updated hws modules.

This commit is contained in:
2026-01-11 23:39:10 +11:00
parent 6e03c98ae8
commit 1eedbc5220
33 changed files with 984 additions and 375 deletions

View File

@@ -1,7 +1,7 @@
package main
import (
"database/sql"
"context"
"projectreshoot/internal/config"
"projectreshoot/internal/handler"
"projectreshoot/internal/models"
@@ -11,19 +11,25 @@ import (
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func setupAuth(
config *config.Config,
logger *hlog.Logger,
conn *sql.DB,
db *bun.DB,
server *hws.Server,
ignoredPaths []string,
) (*hwsauth.Authenticator[*models.User], error) {
) (*hwsauth.Authenticator[*models.UserBun, bun.Tx], error) {
beginTx := func(ctx context.Context) (hwsauth.DBTransaction, error) {
tx, err := db.BeginTx(ctx, nil)
return tx, err
}
auth, err := hwsauth.NewAuthenticator(
models.GetUserFromID,
config.HWSAuth,
models.GetUserByID,
server,
conn,
beginTx,
logger,
handler.ErrorPage,
)
@@ -31,21 +37,8 @@ func setupAuth(
return nil, errors.Wrap(err, "hwsauth.NewAuthenticator")
}
auth.SSL = config.SSL
auth.AccessTokenExpiry = config.AccessTokenExpiry
auth.RefreshTokenExpiry = config.RefreshTokenExpiry
auth.TokenFreshTime = config.TokenFreshTime
auth.TrustedHost = config.TrustedHost
auth.SecretKey = config.SecretKey
auth.LandingPage = "/profile"
auth.IgnorePaths(ignoredPaths...)
err = auth.Initialise()
if err != nil {
return nil, errors.Wrap(err, "auth.Initialise")
}
contexts.CurrentUser = auth.CurrentModel
return auth, nil

52
cmd/projectreshoot/db.go Normal file
View File

@@ -0,0 +1,52 @@
package main
import (
"context"
"database/sql"
"fmt"
"projectreshoot/internal/config"
"projectreshoot/internal/models"
"github.com/pkg/errors"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
)
func setupBun(ctx context.Context, cfg *config.DBConfig, resetDB bool) (db *bun.DB, close func() error, err error) {
dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DB, cfg.SSL)
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
db = bun.NewDB(sqldb, pgdialect.New())
close = sqldb.Close
err = loadModels(ctx, db, resetDB)
if err != nil {
return nil, nil, errors.Wrap(err, "loadModels")
}
return db, close, nil
}
func loadModels(ctx context.Context, db *bun.DB, resetDB bool) error {
models := []any{
(*models.UserBun)(nil),
}
for _, model := range models {
_, err := db.NewCreateTable().
Model(model).
IfNotExists().
Exec(ctx)
if err != nil {
return errors.Wrap(err, "db.NewCreateTable")
}
if resetDB {
err = db.ResetModel(ctx, model)
if err != nil {
return errors.Wrap(err, "db.ResetModel")
}
}
}
return nil
}

View File

@@ -1,53 +0,0 @@
package main
import (
"database/sql"
"fmt"
"strconv"
"github.com/pkg/errors"
_ "github.com/mattn/go-sqlite3"
)
func setupDBConn(dbName string) (*sql.DB, error) {
opts := "_journal_mode=WAL&_synchronous=NORMAL&_txlock=IMMEDIATE"
file := fmt.Sprintf("file:%s.db?%s", dbName, opts)
conn, err := sql.Open("sqlite3", file)
if err != nil {
return nil, errors.Wrap(err, "sql.Open")
}
err = checkDBVersion(conn, dbName)
if err != nil {
return nil, errors.Wrap(err, "checkDBVersion")
}
return conn, nil
}
// Check the database version
func checkDBVersion(db *sql.DB, dbName string) error {
expectVer, err := strconv.Atoi(dbName)
if err != nil {
return errors.Wrap(err, "strconv.Atoi")
}
query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
ORDER BY version_id DESC LIMIT 1`
rows, err := db.Query(query)
if err != nil {
return errors.Wrap(err, "db.Query")
}
defer rows.Close()
if rows.Next() {
var version int
err = rows.Scan(&version)
if err != nil {
return errors.Wrap(err, "rows.Scan")
}
if version != expectVer {
return errors.New("Version mismatch")
}
} else {
return errors.New("No version found")
}
return nil
}

View File

@@ -7,22 +7,18 @@ import (
func setupFlags() map[string]string {
// Parse commandline args
host := flag.String("host", "", "Override host to listen on")
port := flag.String("port", "", "Override port to listen on")
test := flag.Bool("test", false, "Run server in test mode")
dbver := flag.Bool("dbver", false, "Get the version of the database required")
loglevel := flag.String("loglevel", "", "Set log level")
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
resetDB := flag.Bool("resetdb", false, "Reset all the database tables with the updated models")
printEnv := flag.Bool("printenv", false, "Print all environment variables and their documentation")
genEnv := flag.String("genenv", "", "Generate a .env file with all environment variables (specify filename)")
envfile := flag.String("envfile", ".env", "Specify a .env file to use for the configuration")
flag.Parse()
// Map the args for easy access
args := map[string]string{
"host": *host,
"port": *port,
"test": strconv.FormatBool(*test),
"dbver": strconv.FormatBool(*dbver),
"loglevel": *loglevel,
"logoutput": *logoutput,
"resetdb": strconv.FormatBool(*resetDB),
"printenv": strconv.FormatBool(*printEnv),
"genenv": *genEnv,
"envfile": *envfile,
}
return args
}

View File

@@ -1,7 +1,6 @@
package main
import (
"database/sql"
"io/fs"
"net/http"
"projectreshoot/internal/config"
@@ -10,39 +9,31 @@ import (
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func setupHttpServer(
staticFS *fs.FS,
config *config.Config,
logger *hlog.Logger,
conn *sql.DB,
tokenGen *jwt.TokenGenerator,
bun *bun.DB,
) (server *hws.Server, err error) {
if staticFS == nil {
return nil, errors.New("No filesystem provided")
}
fs := http.FS(*staticFS)
httpServer, err := hws.NewServer(
config.Host,
config.Port,
)
httpServer, err := hws.NewServer(config.HWS)
if err != nil {
return nil, errors.Wrap(err, "hws.NewServer")
}
httpServer.ReadHeaderTimeout(config.ReadHeaderTimeout)
httpServer.WriteTimeout(config.WriteTimeout)
httpServer.IdleTimeout(config.IdleTimeout)
httpServer.GZIP = config.GZIP
ignoredPaths := []string{
"/static/css/output.css",
"/static/favicon.ico",
}
auth, err := setupAuth(config, logger, conn, httpServer, ignoredPaths)
auth, err := setupAuth(config, logger, bun, httpServer, ignoredPaths)
if err != nil {
return nil, errors.Wrap(err, "setupAuth")
}
@@ -62,7 +53,7 @@ func setupHttpServer(
return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths")
}
err = addRoutes(httpServer, &fs, config, logger, conn, tokenGen, auth)
err = addRoutes(httpServer, &fs, config, logger, bun, auth)
if err != nil {
return nil, errors.Wrap(err, "addRoutes")
}

View File

@@ -3,17 +3,18 @@ package main
import (
"io"
"os"
"projectreshoot/internal/config"
"git.haelnorr.com/h/golib/hlog"
"github.com/pkg/errors"
)
// Take in the desired logOutput and a console writer to use
func setupLogger(logLevel hlog.Level, logOutput string, w *io.Writer, logDirectory string) (*hlog.Logger, error) {
func setupLogger(cfg *config.HLOGConfig, w *io.Writer) (*hlog.Logger, error) {
// Setup the logfile
var logfile *os.File = nil
if logOutput == "both" || logOutput == "file" {
logfile, err := hlog.NewLogFile(logDirectory)
if cfg.LogOutput == "both" || cfg.LogOutput == "file" {
logfile, err := hlog.NewLogFile(cfg.LogDir)
if err != nil {
return nil, errors.Wrap(err, "hlog")
}
@@ -22,11 +23,11 @@ func setupLogger(logLevel hlog.Level, logOutput string, w *io.Writer, logDirecto
// Setup the console writer
var consoleWriter io.Writer
if logOutput == "both" || logOutput == "console" {
if cfg.LogOutput == "both" || cfg.LogOutput == "console" {
if w != nil {
consoleWriter = *w
} else {
if logOutput == "console" {
if cfg.LogOutput == "console" {
return nil, errors.New("Console logging specified as sole method but no writer provided")
}
}
@@ -34,10 +35,10 @@ func setupLogger(logLevel hlog.Level, logOutput string, w *io.Writer, logDirecto
// Setup the logger
logger, err := hlog.NewLogger(
logLevel,
cfg.LogLevel,
consoleWriter,
logfile,
logDirectory,
cfg.LogDir,
)
if err != nil {
return nil, errors.Wrap(err, "hlog")

View File

@@ -13,13 +13,32 @@ func main() {
args := setupFlags()
ctx := context.Background()
config, err := config.GetConfig(args)
// Handle printenv flag
if args["printenv"] == "true" {
if err := config.PrintEnvVars(os.Stdout); err != nil {
fmt.Fprintf(os.Stderr, "Failed to print environment variables: %s\n", err)
os.Exit(1)
}
return
}
// Handle genenv flag
if args["genenv"] != "" {
if err := config.GenerateDotEnv(args["genenv"]); err != nil {
fmt.Fprintf(os.Stderr, "Failed to generate .env file: %s\n", err)
os.Exit(1)
}
fmt.Printf("Successfully generated .env file: %s\n", args["genenv"])
return
}
cfg, err := config.GetConfig(args["envfile"])
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to load config"))
os.Exit(1)
}
if err := run(ctx, os.Stdout, args, config); err != nil {
if err := run(ctx, os.Stdout, args, cfg); err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}

View File

@@ -1,16 +1,18 @@
package main
import (
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"projectreshoot/internal/models"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func addMiddleware(
server *hws.Server,
auth *hwsauth.Authenticator[*models.User],
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
) error {
err := server.AddMiddleware(

View File

@@ -1,18 +1,18 @@
package main
import (
"database/sql"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"net/http"
"projectreshoot/internal/config"
"projectreshoot/internal/handler"
"projectreshoot/internal/models"
"projectreshoot/internal/view/page"
"git.haelnorr.com/h/golib/hws"
"git.haelnorr.com/h/golib/hwsauth"
"git.haelnorr.com/h/golib/hlog"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
"github.com/uptrace/bun"
)
func addRoutes(
@@ -20,9 +20,8 @@ func addRoutes(
staticFS *http.FileSystem,
config *config.Config,
logger *hlog.Logger,
conn *sql.DB,
tokenGen *jwt.TokenGenerator,
auth *hwsauth.Authenticator[*models.User],
db *bun.DB,
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
) error {
// Create the routes
routes := []hws.Route{
@@ -44,32 +43,32 @@ func addRoutes(
{
Path: "/login",
Method: hws.MethodGET,
Handler: auth.LogoutReq(handler.LoginPage(config.TrustedHost)),
Handler: auth.LogoutReq(handler.LoginPage(config.HWSAuth.TrustedHost)),
},
{
Path: "/login",
Method: hws.MethodPOST,
Handler: auth.LogoutReq(handler.LoginRequest(server, auth, conn)),
Handler: auth.LogoutReq(handler.LoginRequest(server, auth, db)),
},
{
Path: "/register",
Method: hws.MethodGET,
Handler: auth.LogoutReq(handler.RegisterPage(config.TrustedHost)),
Handler: auth.LogoutReq(handler.RegisterPage(config.HWSAuth.TrustedHost)),
},
{
Path: "/register",
Method: hws.MethodPOST,
Handler: auth.LogoutReq(handler.RegisterRequest(config, logger, conn, tokenGen)),
Handler: auth.LogoutReq(handler.RegisterRequest(server, auth, db)),
},
{
Path: "/logout",
Method: hws.MethodPOST,
Handler: handler.Logout(server, auth, conn),
Handler: handler.Logout(server, auth, db),
},
{
Path: "/reauthenticate",
Method: hws.MethodPOST,
Handler: auth.LoginReq(handler.Reauthenticate(server, auth, conn)),
Handler: auth.LoginReq(handler.Reauthenticate(server, auth, db)),
},
{
Path: "/profile",
@@ -89,17 +88,17 @@ func addRoutes(
{
Path: "/change-username",
Method: hws.MethodPOST,
Handler: auth.LoginReq(auth.FreshReq(handler.ChangeUsername(server, auth, conn))),
Handler: auth.LoginReq(auth.FreshReq(handler.ChangeUsername(server, auth, db))),
},
{
Path: "/change-password",
Method: hws.MethodPOST,
Handler: auth.LoginReq(auth.FreshReq(handler.ChangePassword(server, auth, conn))),
Handler: auth.LoginReq(auth.FreshReq(handler.ChangePassword(server, auth, db))),
},
{
Path: "/change-bio",
Method: hws.MethodPOST,
Handler: auth.LoginReq(handler.ChangeBio(server, auth, conn)),
Handler: auth.LoginReq(handler.ChangeBio(server, auth, db)),
},
{
Path: "/movies",
@@ -109,12 +108,12 @@ func addRoutes(
{
Path: "/search-movies",
Method: hws.MethodPOST,
Handler: handler.SearchMovies(config, logger),
Handler: handler.SearchMovies(config.TMDB, logger),
},
{
Path: "/movie/{movie_id}",
Method: hws.MethodGET,
Handler: handler.Movie(server, config),
Handler: handler.Movie(server, config.TMDB),
},
}

View File

@@ -2,17 +2,15 @@ package main
import (
"context"
"fmt"
"io"
"os"
"os/signal"
"projectreshoot/internal/config"
"projectreshoot/pkg/embedfs"
"strconv"
"sync"
"time"
"git.haelnorr.com/h/golib/jwt"
"github.com/pkg/errors"
)
@@ -21,13 +19,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string, config *confi
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()
// Return the version of the database required
if args["dbver"] == "true" {
fmt.Fprintf(w, "Database version: %s\n", config.DBName)
return nil
}
logger, err := setupLogger(config.LogLevel, config.LogOutput, &w, config.LogDir)
logger, err := setupLogger(config.HLOG, &w)
if err != nil {
return errors.Wrap(err, "setupLogger")
}
@@ -35,11 +27,15 @@ func run(ctx context.Context, w io.Writer, args map[string]string, config *confi
// Setup the database connection
logger.Debug().Msg("Config loaded and logger started")
logger.Debug().Msg("Connecting to database")
conn, err := setupDBConn(config.DBName)
resetdb, err := strconv.ParseBool(args["resetdb"])
if err != nil {
return errors.Wrap(err, "strconv.ParseBool")
}
bun, closedb, err := setupBun(ctx, config.DB, resetdb)
if err != nil {
return errors.Wrap(err, "setupDBConn")
}
defer conn.Close()
defer closedb()
// Setup embedded files
logger.Debug().Msg("Getting embedded files")
@@ -48,19 +44,8 @@ func run(ctx context.Context, w io.Writer, args map[string]string, config *confi
return errors.Wrap(err, "getStaticFiles")
}
// Setup TokenGenerator
logger.Debug().Msg("Creating TokenGenerator")
tokenGen, err := jwt.CreateGenerator(
config.AccessTokenExpiry,
config.RefreshTokenExpiry,
config.TokenFreshTime,
config.TrustedHost,
config.SecretKey,
conn,
)
logger.Debug().Msg("Setting up HTTP server")
httpServer, err := setupHttpServer(&staticFS, config, logger, conn, tokenGen)
httpServer, err := setupHttpServer(&staticFS, config, logger, bun)
if err != nil {
return errors.Wrap(err, "setupHttpServer")
}