updated to use bun and updated hws modules.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"context"
|
||||
"projectreshoot/internal/config"
|
||||
"projectreshoot/internal/handler"
|
||||
"projectreshoot/internal/models"
|
||||
@@ -11,19 +11,25 @@ import (
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func setupAuth(
|
||||
config *config.Config,
|
||||
logger *hlog.Logger,
|
||||
conn *sql.DB,
|
||||
db *bun.DB,
|
||||
server *hws.Server,
|
||||
ignoredPaths []string,
|
||||
) (*hwsauth.Authenticator[*models.User], error) {
|
||||
) (*hwsauth.Authenticator[*models.UserBun, bun.Tx], error) {
|
||||
beginTx := func(ctx context.Context) (hwsauth.DBTransaction, error) {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
return tx, err
|
||||
}
|
||||
auth, err := hwsauth.NewAuthenticator(
|
||||
models.GetUserFromID,
|
||||
config.HWSAuth,
|
||||
models.GetUserByID,
|
||||
server,
|
||||
conn,
|
||||
beginTx,
|
||||
logger,
|
||||
handler.ErrorPage,
|
||||
)
|
||||
@@ -31,21 +37,8 @@ func setupAuth(
|
||||
return nil, errors.Wrap(err, "hwsauth.NewAuthenticator")
|
||||
}
|
||||
|
||||
auth.SSL = config.SSL
|
||||
auth.AccessTokenExpiry = config.AccessTokenExpiry
|
||||
auth.RefreshTokenExpiry = config.RefreshTokenExpiry
|
||||
auth.TokenFreshTime = config.TokenFreshTime
|
||||
auth.TrustedHost = config.TrustedHost
|
||||
auth.SecretKey = config.SecretKey
|
||||
auth.LandingPage = "/profile"
|
||||
|
||||
auth.IgnorePaths(ignoredPaths...)
|
||||
|
||||
err = auth.Initialise()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "auth.Initialise")
|
||||
}
|
||||
|
||||
contexts.CurrentUser = auth.CurrentModel
|
||||
|
||||
return auth, nil
|
||||
|
||||
52
cmd/projectreshoot/db.go
Normal file
52
cmd/projectreshoot/db.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"projectreshoot/internal/config"
|
||||
"projectreshoot/internal/models"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect/pgdialect"
|
||||
"github.com/uptrace/bun/driver/pgdriver"
|
||||
)
|
||||
|
||||
func setupBun(ctx context.Context, cfg *config.DBConfig, resetDB bool) (db *bun.DB, close func() error, err error) {
|
||||
dsn := fmt.Sprintf("postgres://%s:%s@%s:%v/%s?sslmode=%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DB, cfg.SSL)
|
||||
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(dsn)))
|
||||
db = bun.NewDB(sqldb, pgdialect.New())
|
||||
close = sqldb.Close
|
||||
|
||||
err = loadModels(ctx, db, resetDB)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "loadModels")
|
||||
}
|
||||
|
||||
return db, close, nil
|
||||
}
|
||||
|
||||
func loadModels(ctx context.Context, db *bun.DB, resetDB bool) error {
|
||||
models := []any{
|
||||
(*models.UserBun)(nil),
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
_, err := db.NewCreateTable().
|
||||
Model(model).
|
||||
IfNotExists().
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.NewCreateTable")
|
||||
}
|
||||
if resetDB {
|
||||
err = db.ResetModel(ctx, model)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.ResetModel")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func setupDBConn(dbName string) (*sql.DB, error) {
|
||||
opts := "_journal_mode=WAL&_synchronous=NORMAL&_txlock=IMMEDIATE"
|
||||
file := fmt.Sprintf("file:%s.db?%s", dbName, opts)
|
||||
conn, err := sql.Open("sqlite3", file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
err = checkDBVersion(conn, dbName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkDBVersion")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Check the database version
|
||||
func checkDBVersion(db *sql.DB, dbName string) error {
|
||||
expectVer, err := strconv.Atoi(dbName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "strconv.Atoi")
|
||||
}
|
||||
query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
|
||||
ORDER BY version_id DESC LIMIT 1`
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
if rows.Next() {
|
||||
var version int
|
||||
err = rows.Scan(&version)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rows.Scan")
|
||||
}
|
||||
if version != expectVer {
|
||||
return errors.New("Version mismatch")
|
||||
}
|
||||
} else {
|
||||
return errors.New("No version found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -7,22 +7,18 @@ import (
|
||||
|
||||
func setupFlags() map[string]string {
|
||||
// Parse commandline args
|
||||
host := flag.String("host", "", "Override host to listen on")
|
||||
port := flag.String("port", "", "Override port to listen on")
|
||||
test := flag.Bool("test", false, "Run server in test mode")
|
||||
dbver := flag.Bool("dbver", false, "Get the version of the database required")
|
||||
loglevel := flag.String("loglevel", "", "Set log level")
|
||||
logoutput := flag.String("logoutput", "", "Set log destination (file, console or both)")
|
||||
resetDB := flag.Bool("resetdb", false, "Reset all the database tables with the updated models")
|
||||
printEnv := flag.Bool("printenv", false, "Print all environment variables and their documentation")
|
||||
genEnv := flag.String("genenv", "", "Generate a .env file with all environment variables (specify filename)")
|
||||
envfile := flag.String("envfile", ".env", "Specify a .env file to use for the configuration")
|
||||
flag.Parse()
|
||||
|
||||
// Map the args for easy access
|
||||
args := map[string]string{
|
||||
"host": *host,
|
||||
"port": *port,
|
||||
"test": strconv.FormatBool(*test),
|
||||
"dbver": strconv.FormatBool(*dbver),
|
||||
"loglevel": *loglevel,
|
||||
"logoutput": *logoutput,
|
||||
"resetdb": strconv.FormatBool(*resetDB),
|
||||
"printenv": strconv.FormatBool(*printEnv),
|
||||
"genenv": *genEnv,
|
||||
"envfile": *envfile,
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"projectreshoot/internal/config"
|
||||
@@ -10,39 +9,31 @@ import (
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"git.haelnorr.com/h/golib/jwt"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func setupHttpServer(
|
||||
staticFS *fs.FS,
|
||||
config *config.Config,
|
||||
logger *hlog.Logger,
|
||||
conn *sql.DB,
|
||||
tokenGen *jwt.TokenGenerator,
|
||||
bun *bun.DB,
|
||||
) (server *hws.Server, err error) {
|
||||
if staticFS == nil {
|
||||
return nil, errors.New("No filesystem provided")
|
||||
}
|
||||
fs := http.FS(*staticFS)
|
||||
httpServer, err := hws.NewServer(
|
||||
config.Host,
|
||||
config.Port,
|
||||
)
|
||||
httpServer, err := hws.NewServer(config.HWS)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "hws.NewServer")
|
||||
}
|
||||
httpServer.ReadHeaderTimeout(config.ReadHeaderTimeout)
|
||||
httpServer.WriteTimeout(config.WriteTimeout)
|
||||
httpServer.IdleTimeout(config.IdleTimeout)
|
||||
httpServer.GZIP = config.GZIP
|
||||
|
||||
ignoredPaths := []string{
|
||||
"/static/css/output.css",
|
||||
"/static/favicon.ico",
|
||||
}
|
||||
|
||||
auth, err := setupAuth(config, logger, conn, httpServer, ignoredPaths)
|
||||
auth, err := setupAuth(config, logger, bun, httpServer, ignoredPaths)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "setupAuth")
|
||||
}
|
||||
@@ -62,7 +53,7 @@ func setupHttpServer(
|
||||
return nil, errors.Wrap(err, "httpServer.LoggerIgnorePaths")
|
||||
}
|
||||
|
||||
err = addRoutes(httpServer, &fs, config, logger, conn, tokenGen, auth)
|
||||
err = addRoutes(httpServer, &fs, config, logger, bun, auth)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "addRoutes")
|
||||
}
|
||||
|
||||
@@ -3,17 +3,18 @@ package main
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"projectreshoot/internal/config"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Take in the desired logOutput and a console writer to use
|
||||
func setupLogger(logLevel hlog.Level, logOutput string, w *io.Writer, logDirectory string) (*hlog.Logger, error) {
|
||||
func setupLogger(cfg *config.HLOGConfig, w *io.Writer) (*hlog.Logger, error) {
|
||||
// Setup the logfile
|
||||
var logfile *os.File = nil
|
||||
if logOutput == "both" || logOutput == "file" {
|
||||
logfile, err := hlog.NewLogFile(logDirectory)
|
||||
if cfg.LogOutput == "both" || cfg.LogOutput == "file" {
|
||||
logfile, err := hlog.NewLogFile(cfg.LogDir)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "hlog")
|
||||
}
|
||||
@@ -22,11 +23,11 @@ func setupLogger(logLevel hlog.Level, logOutput string, w *io.Writer, logDirecto
|
||||
|
||||
// Setup the console writer
|
||||
var consoleWriter io.Writer
|
||||
if logOutput == "both" || logOutput == "console" {
|
||||
if cfg.LogOutput == "both" || cfg.LogOutput == "console" {
|
||||
if w != nil {
|
||||
consoleWriter = *w
|
||||
} else {
|
||||
if logOutput == "console" {
|
||||
if cfg.LogOutput == "console" {
|
||||
return nil, errors.New("Console logging specified as sole method but no writer provided")
|
||||
}
|
||||
}
|
||||
@@ -34,10 +35,10 @@ func setupLogger(logLevel hlog.Level, logOutput string, w *io.Writer, logDirecto
|
||||
|
||||
// Setup the logger
|
||||
logger, err := hlog.NewLogger(
|
||||
logLevel,
|
||||
cfg.LogLevel,
|
||||
consoleWriter,
|
||||
logfile,
|
||||
logDirectory,
|
||||
cfg.LogDir,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "hlog")
|
||||
|
||||
@@ -13,13 +13,32 @@ func main() {
|
||||
args := setupFlags()
|
||||
ctx := context.Background()
|
||||
|
||||
config, err := config.GetConfig(args)
|
||||
// Handle printenv flag
|
||||
if args["printenv"] == "true" {
|
||||
if err := config.PrintEnvVars(os.Stdout); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to print environment variables: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Handle genenv flag
|
||||
if args["genenv"] != "" {
|
||||
if err := config.GenerateDotEnv(args["genenv"]); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Failed to generate .env file: %s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("Successfully generated .env file: %s\n", args["genenv"])
|
||||
return
|
||||
}
|
||||
|
||||
cfg, err := config.GetConfig(args["envfile"])
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s\n", errors.Wrap(err, "Failed to load config"))
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := run(ctx, os.Stdout, args, config); err != nil {
|
||||
if err := run(ctx, os.Stdout, args, cfg); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"projectreshoot/internal/models"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func addMiddleware(
|
||||
server *hws.Server,
|
||||
auth *hwsauth.Authenticator[*models.User],
|
||||
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||
) error {
|
||||
|
||||
err := server.AddMiddleware(
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"net/http"
|
||||
"projectreshoot/internal/config"
|
||||
"projectreshoot/internal/handler"
|
||||
"projectreshoot/internal/models"
|
||||
"projectreshoot/internal/view/page"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"git.haelnorr.com/h/golib/jwt"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func addRoutes(
|
||||
@@ -20,9 +20,8 @@ func addRoutes(
|
||||
staticFS *http.FileSystem,
|
||||
config *config.Config,
|
||||
logger *hlog.Logger,
|
||||
conn *sql.DB,
|
||||
tokenGen *jwt.TokenGenerator,
|
||||
auth *hwsauth.Authenticator[*models.User],
|
||||
db *bun.DB,
|
||||
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||
) error {
|
||||
// Create the routes
|
||||
routes := []hws.Route{
|
||||
@@ -44,32 +43,32 @@ func addRoutes(
|
||||
{
|
||||
Path: "/login",
|
||||
Method: hws.MethodGET,
|
||||
Handler: auth.LogoutReq(handler.LoginPage(config.TrustedHost)),
|
||||
Handler: auth.LogoutReq(handler.LoginPage(config.HWSAuth.TrustedHost)),
|
||||
},
|
||||
{
|
||||
Path: "/login",
|
||||
Method: hws.MethodPOST,
|
||||
Handler: auth.LogoutReq(handler.LoginRequest(server, auth, conn)),
|
||||
Handler: auth.LogoutReq(handler.LoginRequest(server, auth, db)),
|
||||
},
|
||||
{
|
||||
Path: "/register",
|
||||
Method: hws.MethodGET,
|
||||
Handler: auth.LogoutReq(handler.RegisterPage(config.TrustedHost)),
|
||||
Handler: auth.LogoutReq(handler.RegisterPage(config.HWSAuth.TrustedHost)),
|
||||
},
|
||||
{
|
||||
Path: "/register",
|
||||
Method: hws.MethodPOST,
|
||||
Handler: auth.LogoutReq(handler.RegisterRequest(config, logger, conn, tokenGen)),
|
||||
Handler: auth.LogoutReq(handler.RegisterRequest(server, auth, db)),
|
||||
},
|
||||
{
|
||||
Path: "/logout",
|
||||
Method: hws.MethodPOST,
|
||||
Handler: handler.Logout(server, auth, conn),
|
||||
Handler: handler.Logout(server, auth, db),
|
||||
},
|
||||
{
|
||||
Path: "/reauthenticate",
|
||||
Method: hws.MethodPOST,
|
||||
Handler: auth.LoginReq(handler.Reauthenticate(server, auth, conn)),
|
||||
Handler: auth.LoginReq(handler.Reauthenticate(server, auth, db)),
|
||||
},
|
||||
{
|
||||
Path: "/profile",
|
||||
@@ -89,17 +88,17 @@ func addRoutes(
|
||||
{
|
||||
Path: "/change-username",
|
||||
Method: hws.MethodPOST,
|
||||
Handler: auth.LoginReq(auth.FreshReq(handler.ChangeUsername(server, auth, conn))),
|
||||
Handler: auth.LoginReq(auth.FreshReq(handler.ChangeUsername(server, auth, db))),
|
||||
},
|
||||
{
|
||||
Path: "/change-password",
|
||||
Method: hws.MethodPOST,
|
||||
Handler: auth.LoginReq(auth.FreshReq(handler.ChangePassword(server, auth, conn))),
|
||||
Handler: auth.LoginReq(auth.FreshReq(handler.ChangePassword(server, auth, db))),
|
||||
},
|
||||
{
|
||||
Path: "/change-bio",
|
||||
Method: hws.MethodPOST,
|
||||
Handler: auth.LoginReq(handler.ChangeBio(server, auth, conn)),
|
||||
Handler: auth.LoginReq(handler.ChangeBio(server, auth, db)),
|
||||
},
|
||||
{
|
||||
Path: "/movies",
|
||||
@@ -109,12 +108,12 @@ func addRoutes(
|
||||
{
|
||||
Path: "/search-movies",
|
||||
Method: hws.MethodPOST,
|
||||
Handler: handler.SearchMovies(config, logger),
|
||||
Handler: handler.SearchMovies(config.TMDB, logger),
|
||||
},
|
||||
{
|
||||
Path: "/movie/{movie_id}",
|
||||
Method: hws.MethodGET,
|
||||
Handler: handler.Movie(server, config),
|
||||
Handler: handler.Movie(server, config.TMDB),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -2,17 +2,15 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"projectreshoot/internal/config"
|
||||
"projectreshoot/pkg/embedfs"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/jwt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -21,13 +19,7 @@ func run(ctx context.Context, w io.Writer, args map[string]string, config *confi
|
||||
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
||||
defer cancel()
|
||||
|
||||
// Return the version of the database required
|
||||
if args["dbver"] == "true" {
|
||||
fmt.Fprintf(w, "Database version: %s\n", config.DBName)
|
||||
return nil
|
||||
}
|
||||
|
||||
logger, err := setupLogger(config.LogLevel, config.LogOutput, &w, config.LogDir)
|
||||
logger, err := setupLogger(config.HLOG, &w)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "setupLogger")
|
||||
}
|
||||
@@ -35,11 +27,15 @@ func run(ctx context.Context, w io.Writer, args map[string]string, config *confi
|
||||
// Setup the database connection
|
||||
logger.Debug().Msg("Config loaded and logger started")
|
||||
logger.Debug().Msg("Connecting to database")
|
||||
conn, err := setupDBConn(config.DBName)
|
||||
resetdb, err := strconv.ParseBool(args["resetdb"])
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "strconv.ParseBool")
|
||||
}
|
||||
bun, closedb, err := setupBun(ctx, config.DB, resetdb)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "setupDBConn")
|
||||
}
|
||||
defer conn.Close()
|
||||
defer closedb()
|
||||
|
||||
// Setup embedded files
|
||||
logger.Debug().Msg("Getting embedded files")
|
||||
@@ -48,19 +44,8 @@ func run(ctx context.Context, w io.Writer, args map[string]string, config *confi
|
||||
return errors.Wrap(err, "getStaticFiles")
|
||||
}
|
||||
|
||||
// Setup TokenGenerator
|
||||
logger.Debug().Msg("Creating TokenGenerator")
|
||||
tokenGen, err := jwt.CreateGenerator(
|
||||
config.AccessTokenExpiry,
|
||||
config.RefreshTokenExpiry,
|
||||
config.TokenFreshTime,
|
||||
config.TrustedHost,
|
||||
config.SecretKey,
|
||||
conn,
|
||||
)
|
||||
|
||||
logger.Debug().Msg("Setting up HTTP server")
|
||||
httpServer, err := setupHttpServer(&staticFS, config, logger, conn, tokenGen)
|
||||
httpServer, err := setupHttpServer(&staticFS, config, logger, bun)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "setupHttpServer")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user