updated to use bun and updated hws modules.
This commit is contained in:
37
internal/config/auth.go
Normal file
37
internal/config/auth.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"git.haelnorr.com/h/golib/env"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type HWSAUTHConfig struct {
|
||||
SSL bool // ENV HWSAUTH_SSL: Flag for SSL Mode (default: false)
|
||||
TrustedHost string // ENV HWSAUTH_TRUSTED_HOST: Full server address to accept as trusted SSL host (required if SSL is true)
|
||||
SecretKey string // ENV HWSAUTH_SECRET_KEY: Secret key for signing tokens (required)
|
||||
AccessTokenExpiry int64 // ENV HWSAUTH_ACCESS_TOKEN_EXPIRY: Access token expiry in minutes (default: 5)
|
||||
RefreshTokenExpiry int64 // ENV HWSAUTH_REFRESH_TOKEN_EXPIRY: Refresh token expiry in minutes (default: 1440)
|
||||
TokenFreshTime int64 // ENV HWSAUTH_TOKEN_FRESH_TIME: Time for tokens to stay fresh in minutes (default: 5)
|
||||
}
|
||||
|
||||
func setupHWSAuth() (*HWSAUTHConfig, error) {
|
||||
ssl := env.Bool("HWSAUTH_SSL", false)
|
||||
trustedHost := env.String("HWS_TRUSTED_HOST", "")
|
||||
if ssl && trustedHost == "" {
|
||||
return nil, errors.New("SSL is enabled and no HWS_TRUSTED_HOST set")
|
||||
}
|
||||
cfg := &HWSAUTHConfig{
|
||||
SSL: ssl,
|
||||
TrustedHost: trustedHost,
|
||||
SecretKey: env.String("HWSAUTH_SECRET_KEY", ""),
|
||||
AccessTokenExpiry: env.Int64("HWSAUTH_ACCESS_TOKEN_EXPIRY", 5),
|
||||
RefreshTokenExpiry: env.Int64("HWSAUTH_REFRESH_TOKEN_EXPIRY", 1440),
|
||||
TokenFreshTime: env.Int64("HWSAUTH_TOKEN_FRESH_TIME", 5),
|
||||
}
|
||||
|
||||
if cfg.SecretKey == "" {
|
||||
return nil, errors.New("Envar not set: HWSAUTH_SECRET_KEY")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
@@ -1,116 +1,55 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/env"
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"git.haelnorr.com/h/golib/tmdb"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string // Host to listen on
|
||||
Port string // Port to listen on
|
||||
TrustedHost string // Domain/Hostname to accept as trusted
|
||||
SSL bool // Flag for SSL Mode
|
||||
GZIP bool // Flag for GZIP compression on requests
|
||||
ReadHeaderTimeout time.Duration // Timeout for reading request headers in seconds
|
||||
WriteTimeout time.Duration // Timeout for writing requests in seconds
|
||||
IdleTimeout time.Duration // Timeout for idle connections in seconds
|
||||
DBName string // Filename of the db - hardcoded and doubles as DB version
|
||||
DBLockTimeout time.Duration // Timeout for acquiring database lock
|
||||
SecretKey string // Secret key for signing tokens
|
||||
AccessTokenExpiry int64 // Access token expiry in minutes
|
||||
RefreshTokenExpiry int64 // Refresh token expiry in minutes
|
||||
TokenFreshTime int64 // Time for tokens to stay fresh in minutes
|
||||
LogLevel hlog.Level // Log level for global logging. Defaults to info
|
||||
LogOutput string // "file", "console", or "both". Defaults to console
|
||||
LogDir string // Path to create log files
|
||||
TMDBToken string // Read access token for TMDB API
|
||||
TMDBConfig *tmdb.Config // Config data for interfacing with TMDB
|
||||
DB *DBConfig
|
||||
HWS *hws.Config
|
||||
HWSAuth *hwsauth.Config
|
||||
TMDB *TMDBConfig
|
||||
HLOG *HLOGConfig
|
||||
}
|
||||
|
||||
// Load the application configuration and get a pointer to the Config object
|
||||
func GetConfig(args map[string]string) (*Config, error) {
|
||||
godotenv.Load(".env")
|
||||
var (
|
||||
host string
|
||||
port string
|
||||
logLevel hlog.Level
|
||||
logOutput string
|
||||
valid bool
|
||||
)
|
||||
func GetConfig(envfile string) (*Config, error) {
|
||||
godotenv.Load(envfile)
|
||||
|
||||
if args["host"] != "" {
|
||||
host = args["host"]
|
||||
} else {
|
||||
host = env.String("HOST", "127.0.0.1")
|
||||
}
|
||||
if args["port"] != "" {
|
||||
port = args["port"]
|
||||
} else {
|
||||
port = env.String("PORT", "3010")
|
||||
}
|
||||
if args["loglevel"] != "" {
|
||||
logLevel = hlog.LogLevel(args["loglevel"])
|
||||
} else {
|
||||
logLevel = hlog.LogLevel(env.String("LOG_LEVEL", "info"))
|
||||
}
|
||||
if args["logoutput"] != "" {
|
||||
opts := map[string]string{
|
||||
"both": "both",
|
||||
"file": "file",
|
||||
"console": "console",
|
||||
}
|
||||
logOutput, valid = opts[args["logoutput"]]
|
||||
if !valid {
|
||||
logOutput = "console"
|
||||
fmt.Println(
|
||||
"Log output type was not parsed correctly. Defaulting to console only",
|
||||
)
|
||||
}
|
||||
} else {
|
||||
logOutput = env.String("LOG_OUTPUT", "console")
|
||||
}
|
||||
if logOutput != "both" && logOutput != "console" && logOutput != "file" {
|
||||
logOutput = "console"
|
||||
}
|
||||
tmdbcfg, err := tmdb.GetConfig(os.Getenv("TMDB_API_TOKEN"))
|
||||
db, err := setupDB()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tmdb.GetConfig")
|
||||
return nil, errors.Wrap(err, "setupDB")
|
||||
}
|
||||
|
||||
hws, err := hws.ConfigFromEnv()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "hws.ConfigFromEnv")
|
||||
}
|
||||
|
||||
hwsAuth, err := hwsauth.ConfigFromEnv()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "hwsauth.ConfigFromEnv")
|
||||
}
|
||||
|
||||
tmdb, err := setupTMDB()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "setupTMDB")
|
||||
}
|
||||
|
||||
hlog, err := setupHLOG()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "setupHLOG")
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Host: host,
|
||||
Port: port,
|
||||
TrustedHost: env.String("TRUSTED_HOST", "127.0.0.1"),
|
||||
SSL: env.Bool("SSL_MODE", false),
|
||||
GZIP: env.Bool("GZIP", false),
|
||||
ReadHeaderTimeout: env.Duration("READ_HEADER_TIMEOUT", 2) * time.Second,
|
||||
WriteTimeout: env.Duration("WRITE_TIMEOUT", 10) * time.Second,
|
||||
IdleTimeout: env.Duration("IDLE_TIMEOUT", 120) * time.Second,
|
||||
DBName: "00001",
|
||||
DBLockTimeout: env.Duration("DB_LOCK_TIMEOUT", 60),
|
||||
SecretKey: env.String("SECRET_KEY", ""),
|
||||
AccessTokenExpiry: env.Int64("ACCESS_TOKEN_EXPIRY", 5),
|
||||
RefreshTokenExpiry: env.Int64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day
|
||||
TokenFreshTime: env.Int64("TOKEN_FRESH_TIME", 5),
|
||||
LogLevel: logLevel,
|
||||
LogOutput: logOutput,
|
||||
LogDir: env.String("LOG_DIR", ""),
|
||||
TMDBToken: env.String("TMDB_API_TOKEN", ""),
|
||||
TMDBConfig: tmdbcfg,
|
||||
}
|
||||
|
||||
if config.SecretKey == "" && args["dbver"] != "true" {
|
||||
return nil, errors.New("Envar not set: SECRET_KEY")
|
||||
}
|
||||
if config.TMDBToken == "" && args["dbver"] != "true" {
|
||||
return nil, errors.New("Envar not set: TMDB_API_TOKEN")
|
||||
DB: db,
|
||||
HWS: hws,
|
||||
HWSAuth: hwsAuth,
|
||||
TMDB: tmdb,
|
||||
HLOG: hlog,
|
||||
}
|
||||
|
||||
return config, nil
|
||||
|
||||
55
internal/config/db.go
Normal file
55
internal/config/db.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"git.haelnorr.com/h/golib/env"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type DBConfig struct {
|
||||
User string // ENV DB_USER: Database user for authentication (required)
|
||||
Password string // ENV DB_PASSWORD: Database password for authentication (required)
|
||||
Host string // ENV DB_HOST: Database host address (required)
|
||||
Port uint16 // ENV DB_PORT: Database port (default: 5432)
|
||||
DB string // ENV DB_NAME: Database name to connect to (required)
|
||||
SSL string // ENV DB_SSL: SSL mode for connection (default: disable)
|
||||
}
|
||||
|
||||
func setupDB() (*DBConfig, error) {
|
||||
cfg := &DBConfig{
|
||||
User: env.String("DB_USER", ""),
|
||||
Password: env.String("DB_PASSWORD", ""),
|
||||
Host: env.String("DB_HOST", ""),
|
||||
Port: env.UInt16("DB_PORT", 5432),
|
||||
DB: env.String("DB_NAME", ""),
|
||||
SSL: env.String("DB_SSL", "disable"),
|
||||
}
|
||||
|
||||
// Validate SSL mode
|
||||
validSSLModes := map[string]bool{
|
||||
"disable": true,
|
||||
"require": true,
|
||||
"verify-ca": true,
|
||||
"verify-full": true,
|
||||
"allow": true,
|
||||
"prefer": true,
|
||||
}
|
||||
if !validSSLModes[cfg.SSL] {
|
||||
return nil, errors.Errorf("Invalid DB_SSL value: %s. Must be one of: disable, allow, prefer, require, verify-ca, verify-full", cfg.SSL)
|
||||
}
|
||||
|
||||
// Check required fields
|
||||
if cfg.User == "" {
|
||||
return nil, errors.New("Envar not set: DB_USER")
|
||||
}
|
||||
if cfg.Password == "" {
|
||||
return nil, errors.New("Envar not set: DB_PASSWORD")
|
||||
}
|
||||
if cfg.Host == "" {
|
||||
return nil, errors.New("Envar not set: DB_HOST")
|
||||
}
|
||||
if cfg.DB == "" {
|
||||
return nil, errors.New("Envar not set: DB_NAME")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
114
internal/config/envdoc.go
Normal file
114
internal/config/envdoc.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// EnvVar represents an environment variable with its documentation
|
||||
type EnvVar struct {
|
||||
Name string
|
||||
Description string
|
||||
Default string
|
||||
HasDefault bool
|
||||
Required bool
|
||||
}
|
||||
|
||||
// extractEnvVars parses a struct's field comments to extract environment variable documentation
|
||||
func extractEnvVars(structType reflect.Type, fieldIndex int) *EnvVar {
|
||||
field := structType.Field(fieldIndex)
|
||||
tag := field.Tag.Get("comment")
|
||||
if tag == "" {
|
||||
// Try to get the comment from the struct field's tag or use reflection
|
||||
// For now, we'll parse it manually from the comment string
|
||||
return nil
|
||||
}
|
||||
|
||||
comment := tag
|
||||
if !strings.HasPrefix(comment, "ENV ") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remove "ENV " prefix
|
||||
comment = strings.TrimPrefix(comment, "ENV ")
|
||||
|
||||
// Extract name and description
|
||||
parts := strings.SplitN(comment, ":", 2)
|
||||
if len(parts) != 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(parts[0])
|
||||
desc := strings.TrimSpace(parts[1])
|
||||
|
||||
// Check for default value in description
|
||||
defaultRegex := regexp.MustCompile(`\(default:\s*([^)]+)\)`)
|
||||
matches := defaultRegex.FindStringSubmatch(desc)
|
||||
|
||||
envVar := &EnvVar{
|
||||
Name: name,
|
||||
Description: desc,
|
||||
}
|
||||
|
||||
if len(matches) > 1 {
|
||||
envVar.Default = matches[1]
|
||||
envVar.HasDefault = true
|
||||
// Remove the default notation from description
|
||||
envVar.Description = strings.TrimSpace(defaultRegex.ReplaceAllString(desc, ""))
|
||||
}
|
||||
|
||||
return envVar
|
||||
}
|
||||
|
||||
// GetAllEnvVars returns a list of all environment variables used in the config
|
||||
func GetAllEnvVars() []EnvVar {
|
||||
var envVars []EnvVar
|
||||
|
||||
// Manually define all env vars based on the config structs
|
||||
// This is more reliable than reflection for extracting comments
|
||||
|
||||
// DBConfig
|
||||
envVars = append(envVars, []EnvVar{
|
||||
{Name: "DB_USER", Description: "Database user for authentication", HasDefault: false, Required: true},
|
||||
{Name: "DB_PASSWORD", Description: "Database password for authentication", HasDefault: false, Required: true},
|
||||
{Name: "DB_HOST", Description: "Database host address", HasDefault: false, Required: true},
|
||||
{Name: "DB_PORT", Description: "Database port", Default: "5432", HasDefault: true, Required: false},
|
||||
{Name: "DB_NAME", Description: "Database name to connect to", HasDefault: false, Required: true},
|
||||
{Name: "DB_SSL", Description: "SSL mode for connection", Default: "disable", HasDefault: true, Required: false},
|
||||
}...)
|
||||
|
||||
// HWSConfig
|
||||
envVars = append(envVars, []EnvVar{
|
||||
{Name: "HWS_HOST", Description: "Host to listen on", Default: "127.0.0.1", HasDefault: true, Required: false},
|
||||
{Name: "HWS_PORT", Description: "Port to listen on", Default: "3000", HasDefault: true, Required: false},
|
||||
{Name: "HWS_TRUSTED_HOST", Description: "Domain/Hostname to accept as trusted", Default: "same as Host", HasDefault: true, Required: false},
|
||||
{Name: "HWS_SSL", Description: "Flag for SSL Mode", Default: "false", HasDefault: true, Required: false},
|
||||
{Name: "HWS_GZIP", Description: "Flag for GZIP compression on requests", Default: "false", HasDefault: true, Required: false},
|
||||
{Name: "HWS_READ_HEADER_TIMEOUT", Description: "Timeout for reading request headers in seconds", Default: "2", HasDefault: true, Required: false},
|
||||
{Name: "HWS_WRITE_TIMEOUT", Description: "Timeout for writing requests in seconds", Default: "10", HasDefault: true, Required: false},
|
||||
{Name: "HWS_IDLE_TIMEOUT", Description: "Timeout for idle connections in seconds", Default: "120", HasDefault: true, Required: false},
|
||||
}...)
|
||||
|
||||
// HWSAUTHConfig
|
||||
envVars = append(envVars, []EnvVar{
|
||||
{Name: "HWSAUTH_SECRET_KEY", Description: "Secret key for signing tokens", HasDefault: false, Required: true},
|
||||
{Name: "HWSAUTH_ACCESS_TOKEN_EXPIRY", Description: "Access token expiry in minutes", Default: "5", HasDefault: true, Required: false},
|
||||
{Name: "HWSAUTH_REFRESH_TOKEN_EXPIRY", Description: "Refresh token expiry in minutes", Default: "1440", HasDefault: true, Required: false},
|
||||
{Name: "HWSAUTH_TOKEN_FRESH_TIME", Description: "Time for tokens to stay fresh in minutes", Default: "5", HasDefault: true, Required: false},
|
||||
}...)
|
||||
|
||||
// TMDBConfig
|
||||
envVars = append(envVars, []EnvVar{
|
||||
{Name: "TMDB_TOKEN", Description: "API token for TMDB", HasDefault: false, Required: true},
|
||||
}...)
|
||||
|
||||
// HLOGConfig
|
||||
envVars = append(envVars, []EnvVar{
|
||||
{Name: "LOG_LEVEL", Description: "Log level for global logging", Default: "info", HasDefault: true, Required: false},
|
||||
{Name: "LOG_OUTPUT", Description: "Output method for the logger (file, console, or both)", Default: "console", HasDefault: true, Required: false},
|
||||
{Name: "LOG_DIR", Description: "Path to create log files", HasDefault: false, Required: false},
|
||||
}...)
|
||||
|
||||
return envVars
|
||||
}
|
||||
95
internal/config/envgen.go
Normal file
95
internal/config/envgen.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GenerateDotEnv creates a new .env file with all environment variables and their defaults
|
||||
func GenerateDotEnv(filename string) error {
|
||||
envVars := GetAllEnvVars()
|
||||
|
||||
file, err := os.Create(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Write header
|
||||
fmt.Fprintln(file, "# Environment Configuration")
|
||||
fmt.Fprintln(file, "# Generated by Project Reshoot")
|
||||
fmt.Fprintln(file, "#")
|
||||
fmt.Fprintln(file, "# Variables marked as (required) must be set")
|
||||
fmt.Fprintln(file, "# Variables with defaults can be left commented out to use the default value")
|
||||
fmt.Fprintln(file)
|
||||
|
||||
// Group by prefix
|
||||
groups := map[string][]EnvVar{
|
||||
"DB_": {},
|
||||
"HWS_": {},
|
||||
"HWSAUTH_": {},
|
||||
"TMDB_": {},
|
||||
"LOG_": {},
|
||||
}
|
||||
|
||||
for _, ev := range envVars {
|
||||
assigned := false
|
||||
for prefix := range groups {
|
||||
if strings.HasPrefix(ev.Name, prefix) {
|
||||
groups[prefix] = append(groups[prefix], ev)
|
||||
assigned = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !assigned {
|
||||
// Handle ungrouped vars
|
||||
if _, ok := groups["OTHER"]; !ok {
|
||||
groups["OTHER"] = []EnvVar{}
|
||||
}
|
||||
groups["OTHER"] = append(groups["OTHER"], ev)
|
||||
}
|
||||
}
|
||||
|
||||
// Print each group
|
||||
groupOrder := []string{"DB_", "HWS_", "HWSAUTH_", "TMDB_", "LOG_", "OTHER"}
|
||||
groupTitles := map[string]string{
|
||||
"DB_": "Database Configuration",
|
||||
"HWS_": "HTTP Web Server Configuration",
|
||||
"HWSAUTH_": "Authentication Configuration",
|
||||
"TMDB_": "TMDB API Configuration",
|
||||
"LOG_": "Logging Configuration",
|
||||
"OTHER": "Other Configuration",
|
||||
}
|
||||
|
||||
for _, prefix := range groupOrder {
|
||||
vars, ok := groups[prefix]
|
||||
if !ok || len(vars) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(file, "# %s\n", groupTitles[prefix])
|
||||
fmt.Fprintln(file, strings.Repeat("#", len(groupTitles[prefix])+2))
|
||||
|
||||
for _, ev := range vars {
|
||||
// Write description as comment
|
||||
if ev.Required {
|
||||
fmt.Fprintf(file, "# %s (required)\n", ev.Description)
|
||||
// Leave required variables uncommented but empty
|
||||
fmt.Fprintf(file, "%s=\n", ev.Name)
|
||||
} else if ev.HasDefault {
|
||||
fmt.Fprintf(file, "# %s\n", ev.Description)
|
||||
// Comment out variables with defaults
|
||||
fmt.Fprintf(file, "# %s=%s\n", ev.Name, ev.Default)
|
||||
} else {
|
||||
fmt.Fprintf(file, "# %s\n", ev.Description)
|
||||
// Optional variables without defaults are commented out
|
||||
fmt.Fprintf(file, "# %s=\n", ev.Name)
|
||||
}
|
||||
fmt.Fprintln(file)
|
||||
}
|
||||
fmt.Fprintln(file)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
87
internal/config/envprint.go
Normal file
87
internal/config/envprint.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PrintEnvVars writes all environment variables and their documentation to the provided writer
|
||||
func PrintEnvVars(w io.Writer) error {
|
||||
envVars := GetAllEnvVars()
|
||||
|
||||
// Find the longest name for alignment
|
||||
maxNameLen := 0
|
||||
for _, ev := range envVars {
|
||||
if len(ev.Name) > maxNameLen {
|
||||
maxNameLen = len(ev.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// Print header
|
||||
fmt.Fprintln(w, "Environment Variables")
|
||||
fmt.Fprintln(w, strings.Repeat("=", 80))
|
||||
fmt.Fprintln(w)
|
||||
|
||||
// Group by prefix
|
||||
groups := map[string][]EnvVar{
|
||||
"DB_": {},
|
||||
"HWS_": {},
|
||||
"HWSAUTH_": {},
|
||||
"TMDB_": {},
|
||||
"LOG_": {},
|
||||
}
|
||||
|
||||
for _, ev := range envVars {
|
||||
assigned := false
|
||||
for prefix := range groups {
|
||||
if strings.HasPrefix(ev.Name, prefix) {
|
||||
groups[prefix] = append(groups[prefix], ev)
|
||||
assigned = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !assigned {
|
||||
// Handle ungrouped vars
|
||||
if _, ok := groups["OTHER"]; !ok {
|
||||
groups["OTHER"] = []EnvVar{}
|
||||
}
|
||||
groups["OTHER"] = append(groups["OTHER"], ev)
|
||||
}
|
||||
}
|
||||
|
||||
// Print each group
|
||||
groupOrder := []string{"DB_", "HWS_", "HWSAUTH_", "TMDB_", "LOG_", "OTHER"}
|
||||
groupTitles := map[string]string{
|
||||
"DB_": "Database Configuration",
|
||||
"HWS_": "HTTP Web Server Configuration",
|
||||
"HWSAUTH_": "Authentication Configuration",
|
||||
"TMDB_": "TMDB API Configuration",
|
||||
"LOG_": "Logging Configuration",
|
||||
"OTHER": "Other Configuration",
|
||||
}
|
||||
|
||||
for _, prefix := range groupOrder {
|
||||
vars, ok := groups[prefix]
|
||||
if !ok || len(vars) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Fprintf(w, "%s\n", groupTitles[prefix])
|
||||
fmt.Fprintln(w, strings.Repeat("-", len(groupTitles[prefix])))
|
||||
|
||||
for _, ev := range vars {
|
||||
padding := strings.Repeat(" ", maxNameLen-len(ev.Name))
|
||||
if ev.Required {
|
||||
fmt.Fprintf(w, " %s%s : %s (required)\n", ev.Name, padding, ev.Description)
|
||||
} else if ev.HasDefault {
|
||||
fmt.Fprintf(w, " %s%s : %s (default: %s)\n", ev.Name, padding, ev.Description, ev.Default)
|
||||
} else {
|
||||
fmt.Fprintf(w, " %s%s : %s\n", ev.Name, padding, ev.Description)
|
||||
}
|
||||
}
|
||||
fmt.Fprintln(w)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
29
internal/config/httpserver.go
Normal file
29
internal/config/httpserver.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/env"
|
||||
)
|
||||
|
||||
type HWSConfig struct {
|
||||
Host string // ENV HWS_HOST: Host to listen on (default: 127.0.0.1)
|
||||
Port uint64 // ENV HWS_PORT: Port to listen on (default: 3000)
|
||||
GZIP bool // ENV HWS_GZIP: Flag for GZIP compression on requests (default: false)
|
||||
ReadHeaderTimeout time.Duration // ENV HWS_READ_HEADER_TIMEOUT: Timeout for reading request headers in seconds (default: 2)
|
||||
WriteTimeout time.Duration // ENV HWS_WRITE_TIMEOUT: Timeout for writing requests in seconds (default: 10)
|
||||
IdleTimeout time.Duration // ENV HWS_IDLE_TIMEOUT: Timeout for idle connections in seconds (default: 120)
|
||||
}
|
||||
|
||||
func setupHWS() (*HWSConfig, error) {
|
||||
cfg := &HWSConfig{
|
||||
Host: env.String("HWS_HOST", "127.0.0.1"),
|
||||
Port: env.UInt64("HWS_PORT", 3000),
|
||||
GZIP: env.Bool("HWS_GZIP", false),
|
||||
ReadHeaderTimeout: time.Duration(env.Int("HWS_READ_HEADER_TIMEOUT", 2)) * time.Second,
|
||||
WriteTimeout: time.Duration(env.Int("HWS_WRITE_TIMEOUT", 10)) * time.Second,
|
||||
IdleTimeout: time.Duration(env.Int("HWS_IDLE_TIMEOUT", 120)) * time.Second,
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
36
internal/config/logger.go
Normal file
36
internal/config/logger.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"git.haelnorr.com/h/golib/env"
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type HLOGConfig struct {
|
||||
// ENV LOG_LEVEL: Log level for global logging. (default: info)
|
||||
LogLevel hlog.Level
|
||||
|
||||
// ENV LOG_OUTPUT: Output method for the logger. (default: console)
|
||||
// Valid options: "file", "console", "both"
|
||||
LogOutput string
|
||||
|
||||
// ENV LOG_DIR: Path to create log files
|
||||
LogDir string
|
||||
}
|
||||
|
||||
func setupHLOG() (*HLOGConfig, error) {
|
||||
logLevel, err := hlog.LogLevel(env.String("LOG_LEVEL", "info"))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "hlog.LogLevel")
|
||||
}
|
||||
logOutput := env.String("LOG_OUTPUT", "console")
|
||||
if logOutput != "both" && logOutput != "console" && logOutput != "file" {
|
||||
return nil, errors.Errorf("Invalid LOG_OUTPUT: %s", logOutput)
|
||||
}
|
||||
cfg := &HLOGConfig{
|
||||
LogLevel: logLevel,
|
||||
LogOutput: logOutput,
|
||||
LogDir: env.String("LOG_DIR", ""),
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
28
internal/config/tmdb.go
Normal file
28
internal/config/tmdb.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"git.haelnorr.com/h/golib/env"
|
||||
"git.haelnorr.com/h/golib/tmdb"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type TMDBConfig struct {
|
||||
Token string // ENV TMDB_TOKEN: API token for TMDB (required)
|
||||
Config *tmdb.Config // Config data for interfacing with TMDB
|
||||
}
|
||||
|
||||
func setupTMDB() (*TMDBConfig, error) {
|
||||
token := env.String("TMDB_TOKEN", "")
|
||||
if token == "" {
|
||||
return nil, errors.New("No TMDB API Token provided")
|
||||
}
|
||||
tmdbcfg, err := tmdb.GetConfig(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tmdb.GetConfig")
|
||||
}
|
||||
cfg := &TMDBConfig{
|
||||
Token: token,
|
||||
Config: tmdbcfg,
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
@@ -2,19 +2,20 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"projectreshoot/internal/models"
|
||||
"projectreshoot/internal/view/component/account"
|
||||
"projectreshoot/internal/view/page"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
|
||||
"git.haelnorr.com/h/golib/cookies"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// Renders the account page on the 'General' subpage
|
||||
@@ -46,8 +47,8 @@ func AccountSubpage() http.Handler {
|
||||
// Handles a request to change the users username
|
||||
func ChangeUsername(
|
||||
server *hws.Server,
|
||||
auth *hwsauth.Authenticator[*models.User],
|
||||
conn *sql.DB,
|
||||
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||
db *bun.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -55,7 +56,7 @@ func ChangeUsername(
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
@@ -69,7 +70,7 @@ func ChangeUsername(
|
||||
}
|
||||
r.ParseForm()
|
||||
newUsername := r.FormValue("username")
|
||||
unique, err := models.CheckUsernameUnique(tx, newUsername)
|
||||
unique, err := models.IsUsernameUnique(ctx, tx, newUsername)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
@@ -89,7 +90,7 @@ func ChangeUsername(
|
||||
return
|
||||
}
|
||||
user := auth.CurrentModel(r.Context())
|
||||
err = user.ChangeUsername(tx, newUsername)
|
||||
err = user.ChangeUsername(ctx, tx, newUsername)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
@@ -111,8 +112,8 @@ func ChangeUsername(
|
||||
// Handles a request to change the users bio
|
||||
func ChangeBio(
|
||||
server *hws.Server,
|
||||
auth *hwsauth.Authenticator[*models.User],
|
||||
conn *sql.DB,
|
||||
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||
db *bun.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -120,7 +121,7 @@ func ChangeBio(
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
@@ -142,7 +143,7 @@ func ChangeBio(
|
||||
return
|
||||
}
|
||||
user := auth.CurrentModel(r.Context())
|
||||
err = user.ChangeBio(tx, newBio)
|
||||
err = user.ChangeBio(ctx, tx, newBio)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
@@ -178,8 +179,8 @@ func validateChangePassword(
|
||||
// Handles a request to change the users password
|
||||
func ChangePassword(
|
||||
server *hws.Server,
|
||||
auth *hwsauth.Authenticator[*models.User],
|
||||
conn *sql.DB,
|
||||
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||
db *bun.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -187,7 +188,7 @@ func ChangePassword(
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
@@ -206,7 +207,7 @@ func ChangePassword(
|
||||
return
|
||||
}
|
||||
user := auth.CurrentModel(r.Context())
|
||||
err = user.SetPassword(tx, newPass)
|
||||
err = user.SetPassword(ctx, tx, newPass)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
|
||||
@@ -2,35 +2,41 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"projectreshoot/internal/models"
|
||||
"projectreshoot/internal/view/component/form"
|
||||
"projectreshoot/internal/view/page"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
|
||||
"git.haelnorr.com/h/golib/cookies"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// Validates the username matches a user in the database and the password
|
||||
// is correct. Returns the corresponding user
|
||||
func validateLogin(
|
||||
tx *sql.Tx,
|
||||
ctx context.Context,
|
||||
tx bun.Tx,
|
||||
r *http.Request,
|
||||
) (*models.User, error) {
|
||||
) (*models.UserBun, error) {
|
||||
formUsername := r.FormValue("username")
|
||||
formPassword := r.FormValue("password")
|
||||
user, err := models.GetUserFromUsername(tx, formUsername)
|
||||
user, err := models.GetUserByUsername(ctx, tx, formUsername)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.GetUserFromUsername")
|
||||
}
|
||||
|
||||
err = user.CheckPassword(tx, formPassword)
|
||||
if user == nil {
|
||||
return nil, errors.New("Username or password incorrect")
|
||||
}
|
||||
|
||||
err = user.CheckPassword(ctx, tx, formPassword)
|
||||
if err != nil {
|
||||
if !strings.Contains(err.Error(), "Username or password incorrect") {
|
||||
return nil, errors.Wrap(err, "user.CheckPassword")
|
||||
@@ -55,8 +61,8 @@ func checkRememberMe(r *http.Request) bool {
|
||||
// template for user feedback
|
||||
func LoginRequest(
|
||||
server *hws.Server,
|
||||
auth *hwsauth.Authenticator[*models.User],
|
||||
conn *sql.DB,
|
||||
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||
db *bun.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -64,7 +70,7 @@ func LoginRequest(
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
@@ -77,7 +83,7 @@ func LoginRequest(
|
||||
return
|
||||
}
|
||||
r.ParseForm()
|
||||
user, err := validateLogin(tx, r)
|
||||
user, err := validateLogin(ctx, tx, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
if err.Error() != "Username or password incorrect" {
|
||||
|
||||
@@ -2,26 +2,27 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"net/http"
|
||||
"projectreshoot/internal/models"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// Handle a logout request
|
||||
func Logout(
|
||||
server *hws.Server,
|
||||
auth *hwsauth.Authenticator[*models.User],
|
||||
conn *sql.DB,
|
||||
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||
db *bun.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
func Movie(
|
||||
server *hws.Server,
|
||||
config *config.Config,
|
||||
cfg *config.TMDBConfig,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -31,7 +31,7 @@ func Movie(
|
||||
}
|
||||
return
|
||||
}
|
||||
movie, err := tmdb.GetMovie(int32(movie_id), config.TMDBToken)
|
||||
movie, err := tmdb.GetMovie(int32(movie_id), cfg.Token)
|
||||
if err != nil {
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
@@ -43,7 +43,7 @@ func Movie(
|
||||
}
|
||||
return
|
||||
}
|
||||
credits, err := tmdb.GetCredits(int32(movie_id), config.TMDBToken)
|
||||
credits, err := tmdb.GetCredits(int32(movie_id), cfg.Token)
|
||||
if err != nil {
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
@@ -55,7 +55,7 @@ func Movie(
|
||||
}
|
||||
return
|
||||
}
|
||||
page.Movie(movie, credits, &config.TMDBConfig.Image).Render(r.Context(), w)
|
||||
page.Movie(movie, credits, &cfg.Config.Image).Render(r.Context(), w)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func SearchMovies(
|
||||
config *config.Config,
|
||||
cfg *config.TMDBConfig,
|
||||
logger *hlog.Logger,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
@@ -22,12 +22,12 @@ func SearchMovies(
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
movies, err := tmdb.SearchMovies(config.TMDBToken, query, false, 1)
|
||||
movies, err := tmdb.SearchMovies(cfg.Token, query, false, 1)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
search.MovieResults(movies, &config.TMDBConfig.Image).Render(r.Context(), w)
|
||||
search.MovieResults(movies, &cfg.Config.Image).Render(r.Context(), w)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -2,28 +2,30 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"projectreshoot/internal/models"
|
||||
"projectreshoot/internal/view/component/form"
|
||||
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
// Validate the provided password
|
||||
func validatePassword(
|
||||
auth *hwsauth.Authenticator[*models.User],
|
||||
tx *sql.Tx,
|
||||
ctx context.Context,
|
||||
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||
tx bun.Tx,
|
||||
r *http.Request,
|
||||
) error {
|
||||
r.ParseForm()
|
||||
password := r.FormValue("password")
|
||||
user := auth.CurrentModel(r.Context())
|
||||
err := user.CheckPassword(tx, password)
|
||||
err := user.CheckPassword(ctx, tx, password)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "user.CheckPassword")
|
||||
}
|
||||
@@ -33,8 +35,8 @@ func validatePassword(
|
||||
// Handle request to reauthenticate (i.e. make token fresh again)
|
||||
func Reauthenticate(
|
||||
server *hws.Server,
|
||||
auth *hwsauth.Authenticator[*models.User],
|
||||
conn *sql.DB,
|
||||
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||
db *bun.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -42,7 +44,7 @@ func Reauthenticate(
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
@@ -55,7 +57,7 @@ func Reauthenticate(
|
||||
return
|
||||
}
|
||||
defer tx.Rollback()
|
||||
err = validatePassword(auth, tx, r)
|
||||
err = validatePassword(ctx, auth, tx, r)
|
||||
if err != nil {
|
||||
w.WriteHeader(445)
|
||||
form.ConfirmPassword("Incorrect password").Render(r.Context(), w)
|
||||
|
||||
@@ -2,30 +2,30 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"projectreshoot/internal/config"
|
||||
"projectreshoot/internal/models"
|
||||
"projectreshoot/internal/view/component/form"
|
||||
"projectreshoot/internal/view/page"
|
||||
|
||||
"git.haelnorr.com/h/golib/cookies"
|
||||
"git.haelnorr.com/h/golib/hlog"
|
||||
"git.haelnorr.com/h/golib/jwt"
|
||||
"git.haelnorr.com/h/golib/hws"
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
func validateRegistration(
|
||||
tx *sql.Tx,
|
||||
ctx context.Context,
|
||||
tx bun.Tx,
|
||||
r *http.Request,
|
||||
) (*models.User, error) {
|
||||
) (*models.UserBun, error) {
|
||||
formUsername := r.FormValue("username")
|
||||
formPassword := r.FormValue("password")
|
||||
formConfirmPassword := r.FormValue("confirm-password")
|
||||
unique, err := models.CheckUsernameUnique(tx, formUsername)
|
||||
unique, err := models.IsUsernameUnique(ctx, tx, formUsername)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "models.CheckUsernameUnique")
|
||||
}
|
||||
@@ -38,7 +38,7 @@ func validateRegistration(
|
||||
if len(formPassword) > 72 {
|
||||
return nil, errors.New("Password exceeds maximum length of 72 bytes")
|
||||
}
|
||||
user, err := models.CreateNewUser(tx, formUsername, formPassword)
|
||||
user, err := models.CreateUser(ctx, tx, formUsername, formPassword)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "models.CreateNewUser")
|
||||
}
|
||||
@@ -47,10 +47,9 @@ func validateRegistration(
|
||||
}
|
||||
|
||||
func RegisterRequest(
|
||||
config *config.Config,
|
||||
logger *hlog.Logger,
|
||||
conn *sql.DB,
|
||||
tokenGen *jwt.TokenGenerator,
|
||||
server *hws.Server,
|
||||
auth *hwsauth.Authenticator[*models.UserBun, bun.Tx],
|
||||
db *bun.DB,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -58,21 +57,33 @@ func RegisterRequest(
|
||||
defer cancel()
|
||||
|
||||
// Start the transaction
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to set token cookies")
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Message: "Failed to start transaction",
|
||||
Error: err,
|
||||
})
|
||||
if err != nil {
|
||||
server.ThrowFatal(w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
r.ParseForm()
|
||||
user, err := validateRegistration(tx, r)
|
||||
user, err := validateRegistration(ctx, tx, r)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
if err.Error() != "Username is taken" &&
|
||||
err.Error() != "Passwords do not match" &&
|
||||
err.Error() != "Password exceeds maximum length of 72 bytes" {
|
||||
logger.Warn().Caller().Err(err).Msg("Registration request failed")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "Registration failed",
|
||||
Error: err,
|
||||
})
|
||||
if err != nil {
|
||||
server.ThrowFatal(w, err)
|
||||
}
|
||||
} else {
|
||||
form.RegisterForm(err.Error()).Render(r.Context(), w)
|
||||
}
|
||||
@@ -80,11 +91,17 @@ func RegisterRequest(
|
||||
}
|
||||
|
||||
rememberMe := checkRememberMe(r)
|
||||
err = jwt.SetTokenCookies(w, r, tokenGen, user.ID(), true, rememberMe, config.SSL)
|
||||
err = auth.Login(w, r, user, rememberMe)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
logger.Warn().Caller().Err(err).Msg("Failed to set token cookies")
|
||||
err := server.ThrowError(w, r, hws.HWSError{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Message: "Login failed",
|
||||
Error: err,
|
||||
})
|
||||
if err != nil {
|
||||
server.ThrowFatal(w, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
@@ -47,7 +47,7 @@ func (user *User) CheckPassword(tx *sql.Tx, password string) error {
|
||||
}
|
||||
err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bcrypt.CompareHashAndPassword")
|
||||
return errors.Wrap(err, "Username or password incorrect")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
163
internal/models/user_bun.go
Normal file
163
internal/models/user_bun.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/uptrace/bun"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type UserBun struct {
|
||||
bun.BaseModel `bun:"table:users,alias:u"`
|
||||
|
||||
ID int `bun:"id,pk,autoincrement"` // Integer ID (index primary key)
|
||||
Username string `bun:"username,unique"` // Username (unique)
|
||||
PasswordHash string `bun:"password_hash,nullzero"` // Bcrypt hashed password (not exported in JSON)
|
||||
CreatedAt int64 `bun:"created_at"` // Epoch timestamp when the user was added to the database
|
||||
Bio string `bun:"bio"` // Short byline set by the user
|
||||
}
|
||||
|
||||
func (user *UserBun) GetID() int {
|
||||
return user.ID
|
||||
}
|
||||
|
||||
// Uses bcrypt to set the users password_hash from the given password
|
||||
func (user *UserBun) SetPassword(ctx context.Context, tx bun.Tx, password string) error {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "bcrypt.GenerateFromPassword")
|
||||
}
|
||||
newPassword := string(hashedPassword)
|
||||
|
||||
_, err = tx.NewUpdate().
|
||||
Model(user).
|
||||
Set("password_hash = ?", newPassword).
|
||||
Where("id = ?", user.ID).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Update")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Uses bcrypt to check if the given password matches the users password_hash
|
||||
func (user *UserBun) CheckPassword(ctx context.Context, tx bun.Tx, password string) error {
|
||||
var hashedPassword string
|
||||
err := tx.NewSelect().
|
||||
Table("users").
|
||||
Column("password_hash").
|
||||
Where("id = ?", user.ID).
|
||||
Limit(1).
|
||||
Scan(ctx, &hashedPassword)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Select")
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Username or password incorrect")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change the user's username
|
||||
func (user *UserBun) ChangeUsername(ctx context.Context, tx bun.Tx, newUsername string) error {
|
||||
_, err := tx.NewUpdate().
|
||||
Model(user).
|
||||
Set("username = ?", newUsername).
|
||||
Where("id = ?", user.ID).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Update")
|
||||
}
|
||||
user.Username = newUsername
|
||||
return nil
|
||||
}
|
||||
|
||||
// Change the user's bio
|
||||
func (user *UserBun) ChangeBio(ctx context.Context, tx bun.Tx, newBio string) error {
|
||||
_, err := tx.NewUpdate().
|
||||
Model(user).
|
||||
Set("bio = ?", newBio).
|
||||
Where("id = ?", user.ID).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Update")
|
||||
}
|
||||
user.Bio = newBio
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user with the given username and password
|
||||
func CreateUser(ctx context.Context, tx bun.Tx, username, password string) (*UserBun, error) {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "bcrypt.GenerateFromPassword")
|
||||
}
|
||||
|
||||
user := &UserBun{
|
||||
Username: username,
|
||||
PasswordHash: string(hashedPassword),
|
||||
CreatedAt: 0, // You may want to set this to time.Now().Unix()
|
||||
Bio: "",
|
||||
}
|
||||
|
||||
_, err = tx.NewInsert().
|
||||
Model(user).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.Insert")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUserByID queries the database for a user matching the given ID
|
||||
// Returns nil, nil if no user is found
|
||||
func GetUserByID(ctx context.Context, tx bun.Tx, id int) (*UserBun, error) {
|
||||
user := new(UserBun)
|
||||
err := tx.NewSelect().
|
||||
Model(user).
|
||||
Where("id = ?", id).
|
||||
Limit(1).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err.Error() == "sql: no rows in result set" {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "tx.Select")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUserByUsername queries the database for a user matching the given username
|
||||
// Returns nil, nil if no user is found
|
||||
func GetUserByUsername(ctx context.Context, tx bun.Tx, username string) (*UserBun, error) {
|
||||
user := new(UserBun)
|
||||
err := tx.NewSelect().
|
||||
Model(user).
|
||||
Where("username = ?", username).
|
||||
Limit(1).
|
||||
Scan(ctx)
|
||||
if err != nil {
|
||||
if err.Error() == "sql: no rows in result set" {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.Wrap(err, "tx.Select")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// IsUsernameUnique checks if the given username is unique (not already taken)
|
||||
// Returns true if the username is available, false if it's taken
|
||||
func IsUsernameUnique(ctx context.Context, tx bun.Tx, username string) (bool, error) {
|
||||
count, err := tx.NewSelect().
|
||||
Model((*UserBun)(nil)).
|
||||
Where("username = ?", username).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tx.Count")
|
||||
}
|
||||
return count == 0, nil
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"git.haelnorr.com/h/golib/hwsauth"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
@@ -31,7 +32,9 @@ func CreateNewUser(
|
||||
|
||||
// Fetches data from the users table using "WHERE column = 'value'"
|
||||
func fetchUserData(
|
||||
tx *sql.Tx,
|
||||
tx interface {
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
},
|
||||
column string,
|
||||
value any,
|
||||
) (*sql.Rows, error) {
|
||||
@@ -87,7 +90,7 @@ func GetUserFromUsername(tx *sql.Tx, username string) (*User, error) {
|
||||
}
|
||||
|
||||
// Queries the database for a user matching the given ID.
|
||||
func GetUserFromID(tx *sql.Tx, id int) (*User, error) {
|
||||
func GetUserFromID(tx hwsauth.DBTransaction, id int) (*User, error) {
|
||||
rows, err := fetchUserData(tx, "id", id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fetchUserData")
|
||||
|
||||
Reference in New Issue
Block a user