refactor: changed file structure
This commit is contained in:
118
pkg/config/config.go
Normal file
118
pkg/config/config.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"projectreshoot/pkg/logging"
|
||||
"projectreshoot/pkg/tmdb"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string // Host to listen on
|
||||
Port string // Port to listen on
|
||||
TrustedHost string // Domain/Hostname to accept as trusted
|
||||
SSL bool // Flag for SSL Mode
|
||||
GZIP bool // Flag for GZIP compression on requests
|
||||
ReadHeaderTimeout time.Duration // Timeout for reading request headers in seconds
|
||||
WriteTimeout time.Duration // Timeout for writing requests in seconds
|
||||
IdleTimeout time.Duration // Timeout for idle connections in seconds
|
||||
DBName string // Filename of the db - hardcoded and doubles as DB version
|
||||
DBLockTimeout time.Duration // Timeout for acquiring database lock
|
||||
SecretKey string // Secret key for signing tokens
|
||||
AccessTokenExpiry int64 // Access token expiry in minutes
|
||||
RefreshTokenExpiry int64 // Refresh token expiry in minutes
|
||||
TokenFreshTime int64 // Time for tokens to stay fresh in minutes
|
||||
LogLevel zerolog.Level // Log level for global logging. Defaults to info
|
||||
LogOutput string // "file", "console", or "both". Defaults to console
|
||||
LogDir string // Path to create log files
|
||||
TMDBToken string // Read access token for TMDB API
|
||||
TMDBConfig *tmdb.Config // Config data for interfacing with TMDB
|
||||
}
|
||||
|
||||
// Load the application configuration and get a pointer to the Config object
|
||||
func GetConfig(args map[string]string) (*Config, error) {
|
||||
godotenv.Load(".env")
|
||||
var (
|
||||
host string
|
||||
port string
|
||||
logLevel zerolog.Level
|
||||
logOutput string
|
||||
valid bool
|
||||
)
|
||||
|
||||
if args["host"] != "" {
|
||||
host = args["host"]
|
||||
} else {
|
||||
host = GetEnvDefault("HOST", "127.0.0.1")
|
||||
}
|
||||
if args["port"] != "" {
|
||||
port = args["port"]
|
||||
} else {
|
||||
port = GetEnvDefault("PORT", "3010")
|
||||
}
|
||||
if args["loglevel"] != "" {
|
||||
logLevel = logging.GetLogLevel(args["loglevel"])
|
||||
} else {
|
||||
logLevel = logging.GetLogLevel(GetEnvDefault("LOG_LEVEL", "info"))
|
||||
}
|
||||
if args["logoutput"] != "" {
|
||||
opts := map[string]string{
|
||||
"both": "both",
|
||||
"file": "file",
|
||||
"console": "console",
|
||||
}
|
||||
logOutput, valid = opts[args["logoutput"]]
|
||||
if !valid {
|
||||
logOutput = "console"
|
||||
fmt.Println(
|
||||
"Log output type was not parsed correctly. Defaulting to console only",
|
||||
)
|
||||
}
|
||||
} else {
|
||||
logOutput = GetEnvDefault("LOG_OUTPUT", "console")
|
||||
}
|
||||
if logOutput != "both" && logOutput != "console" && logOutput != "file" {
|
||||
logOutput = "console"
|
||||
}
|
||||
tmdbcfg, err := tmdb.GetConfig(os.Getenv("TMDB_API_TOKEN"))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tmdb.GetConfig")
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Host: host,
|
||||
Port: port,
|
||||
TrustedHost: GetEnvDefault("TRUSTED_HOST", "127.0.0.1"),
|
||||
SSL: GetEnvBool("SSL_MODE", false),
|
||||
GZIP: GetEnvBool("GZIP", false),
|
||||
ReadHeaderTimeout: GetEnvDur("READ_HEADER_TIMEOUT", 2),
|
||||
WriteTimeout: GetEnvDur("WRITE_TIMEOUT", 10),
|
||||
IdleTimeout: GetEnvDur("IDLE_TIMEOUT", 120),
|
||||
DBName: "00001",
|
||||
DBLockTimeout: GetEnvDur("DB_LOCK_TIMEOUT", 60),
|
||||
SecretKey: os.Getenv("SECRET_KEY"),
|
||||
AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5),
|
||||
RefreshTokenExpiry: GetEnvInt64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day
|
||||
TokenFreshTime: GetEnvInt64("TOKEN_FRESH_TIME", 5),
|
||||
LogLevel: logLevel,
|
||||
LogOutput: logOutput,
|
||||
LogDir: GetEnvDefault("LOG_DIR", ""),
|
||||
TMDBToken: os.Getenv("TMDB_API_TOKEN"),
|
||||
TMDBConfig: tmdbcfg,
|
||||
}
|
||||
|
||||
if config.SecretKey == "" && args["dbver"] != "true" {
|
||||
return nil, errors.New("Envar not set: SECRET_KEY")
|
||||
}
|
||||
if config.TMDBToken == "" && args["dbver"] != "true" {
|
||||
return nil, errors.New("Envar not set: TMDB_API_TOKEN")
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
94
pkg/config/environment.go
Normal file
94
pkg/config/environment.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Get an environment variable, specifying a default value if its not set
|
||||
func GetEnvDefault(key string, defaultValue string) string {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// Get an environment variable as a time.Duration, specifying a default value if its
|
||||
// not set or can't be parsed properly
|
||||
func GetEnvDur(key string, defaultValue time.Duration) time.Duration {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return time.Duration(defaultValue)
|
||||
}
|
||||
|
||||
intVal, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return time.Duration(defaultValue)
|
||||
}
|
||||
return time.Duration(intVal)
|
||||
|
||||
}
|
||||
|
||||
// Get an environment variable as an int, specifying a default value if its
|
||||
// not set or can't be parsed properly into an int
|
||||
func GetEnvInt(key string, defaultValue int) int {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intVal, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return intVal
|
||||
|
||||
}
|
||||
|
||||
// Get an environment variable as an int64, specifying a default value if its
|
||||
// not set or can't be parsed properly into an int64
|
||||
func GetEnvInt64(key string, defaultValue int64) int64 {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intVal, err := strconv.ParseInt(val, 10, 64)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return intVal
|
||||
|
||||
}
|
||||
|
||||
// Get an environment variable as a boolean, specifying a default value if its
|
||||
// not set or can't be parsed properly into a bool
|
||||
func GetEnvBool(key string, defaultValue bool) bool {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
truthy := map[string]bool{
|
||||
"true": true, "t": true, "yes": true, "y": true, "on": true, "1": true,
|
||||
"enable": true, "enabled": true, "active": true, "affirmative": true,
|
||||
}
|
||||
|
||||
falsy := map[string]bool{
|
||||
"false": false, "f": false, "no": false, "n": false, "off": false, "0": false,
|
||||
"disable": false, "disabled": false, "inactive": false, "negative": false,
|
||||
}
|
||||
|
||||
normalized := strings.TrimSpace(strings.ToLower(val))
|
||||
|
||||
if val, ok := truthy[normalized]; ok {
|
||||
return val
|
||||
}
|
||||
if val, ok := falsy[normalized]; ok {
|
||||
return val
|
||||
}
|
||||
|
||||
return defaultValue
|
||||
}
|
||||
12
pkg/contexts/keys.go
Normal file
12
pkg/contexts/keys.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package contexts
|
||||
|
||||
type contextKey string
|
||||
|
||||
func (c contextKey) String() string {
|
||||
return "projectreshoot context key " + string(c)
|
||||
}
|
||||
|
||||
var (
|
||||
contextKeyAuthorizedUser = contextKey("auth-user")
|
||||
contextKeyRequestTime = contextKey("req-time")
|
||||
)
|
||||
21
pkg/contexts/request_timer.go
Normal file
21
pkg/contexts/request_timer.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package contexts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Set the start time of the request
|
||||
func SetStart(ctx context.Context, time time.Time) context.Context {
|
||||
return context.WithValue(ctx, contextKeyRequestTime, time)
|
||||
}
|
||||
|
||||
// Get the start time of the request
|
||||
func GetStartTime(ctx context.Context) (time.Time, error) {
|
||||
start, ok := ctx.Value(contextKeyRequestTime).(time.Time)
|
||||
if !ok {
|
||||
return time.Time{}, errors.New("Failed to get start time of request")
|
||||
}
|
||||
return start, nil
|
||||
}
|
||||
25
pkg/contexts/user.go
Normal file
25
pkg/contexts/user.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package contexts
|
||||
|
||||
import (
|
||||
"context"
|
||||
"projectreshoot/internal/models"
|
||||
)
|
||||
|
||||
type AuthenticatedUser struct {
|
||||
*models.User
|
||||
Fresh int64
|
||||
}
|
||||
|
||||
// Return a new context with the user added in
|
||||
func SetUser(ctx context.Context, u *AuthenticatedUser) context.Context {
|
||||
return context.WithValue(ctx, contextKeyAuthorizedUser, u)
|
||||
}
|
||||
|
||||
// Retrieve a user from the given context. Returns nil if not set
|
||||
func GetUser(ctx context.Context) *AuthenticatedUser {
|
||||
user, ok := ctx.Value(contextKeyAuthorizedUser).(*AuthenticatedUser)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return user
|
||||
}
|
||||
37
pkg/cookies/functions.go
Normal file
37
pkg/cookies/functions.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package cookies
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Tell the browser to delete the cookie matching the name provided
|
||||
// Path must match the original set cookie for it to delete
|
||||
func DeleteCookie(w http.ResponseWriter, name string, path string) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: path,
|
||||
Expires: time.Unix(0, 0), // Expire in the past
|
||||
MaxAge: -1, // Immediately expire
|
||||
HttpOnly: true,
|
||||
})
|
||||
}
|
||||
|
||||
// Set a cookie with the given name, path and value. maxAge directly relates
|
||||
// to cookie MaxAge (0 for no max age, >0 for TTL in seconds)
|
||||
func SetCookie(
|
||||
w http.ResponseWriter,
|
||||
name string,
|
||||
path string,
|
||||
value string,
|
||||
maxAge int,
|
||||
) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: path,
|
||||
HttpOnly: true,
|
||||
MaxAge: maxAge,
|
||||
})
|
||||
}
|
||||
36
pkg/cookies/pagefrom.go
Normal file
36
pkg/cookies/pagefrom.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package cookies
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Check the value of "pagefrom" cookie, delete the cookie, and return the value
|
||||
func CheckPageFrom(w http.ResponseWriter, r *http.Request) string {
|
||||
pageFromCookie, err := r.Cookie("pagefrom")
|
||||
if err != nil {
|
||||
return "/"
|
||||
}
|
||||
pageFrom := pageFromCookie.Value
|
||||
DeleteCookie(w, pageFromCookie.Name, pageFromCookie.Path)
|
||||
return pageFrom
|
||||
}
|
||||
|
||||
// Check the referer of the request, and if it matches the trustedHost, set
|
||||
// the "pagefrom" cookie as the Path of the referer
|
||||
func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) {
|
||||
referer := r.Referer()
|
||||
parsedURL, err := url.Parse(referer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var pageFrom string
|
||||
if parsedURL.Path == "" || parsedURL.Host != trustedHost {
|
||||
pageFrom = "/"
|
||||
} else if parsedURL.Path == "/login" || parsedURL.Path == "/register" {
|
||||
return
|
||||
} else {
|
||||
pageFrom = parsedURL.Path
|
||||
}
|
||||
SetCookie(w, "pagefrom", "/", pageFrom, 0)
|
||||
}
|
||||
77
pkg/cookies/tokens.go
Normal file
77
pkg/cookies/tokens.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package cookies
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"projectreshoot/internal/models"
|
||||
"projectreshoot/pkg/config"
|
||||
"projectreshoot/pkg/jwt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Get the value of the access and refresh tokens
|
||||
func GetTokenStrings(
|
||||
r *http.Request,
|
||||
) (acc string, ref string) {
|
||||
accCookie, accErr := r.Cookie("access")
|
||||
refCookie, refErr := r.Cookie("refresh")
|
||||
var (
|
||||
accStr string = ""
|
||||
refStr string = ""
|
||||
)
|
||||
if accErr == nil {
|
||||
accStr = accCookie.Value
|
||||
}
|
||||
if refErr == nil {
|
||||
refStr = refCookie.Value
|
||||
}
|
||||
return accStr, refStr
|
||||
}
|
||||
|
||||
// Set a token with the provided details
|
||||
func setToken(
|
||||
w http.ResponseWriter,
|
||||
config *config.Config,
|
||||
token string,
|
||||
scope string,
|
||||
exp int64,
|
||||
rememberme bool,
|
||||
) {
|
||||
tokenCookie := &http.Cookie{
|
||||
Name: scope,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: config.SSL,
|
||||
}
|
||||
if rememberme {
|
||||
tokenCookie.Expires = time.Unix(exp, 0)
|
||||
}
|
||||
http.SetCookie(w, tokenCookie)
|
||||
}
|
||||
|
||||
// Generate new tokens for the user and set them as cookies
|
||||
func SetTokenCookies(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
config *config.Config,
|
||||
user *models.User,
|
||||
fresh bool,
|
||||
rememberMe bool,
|
||||
) error {
|
||||
at, atexp, err := jwt.GenerateAccessToken(config, user, fresh, rememberMe)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.GenerateAccessToken")
|
||||
}
|
||||
rt, rtexp, err := jwt.GenerateRefreshToken(config, user, rememberMe)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.GenerateRefreshToken")
|
||||
}
|
||||
// Don't set the cookies until we know no errors occured
|
||||
setToken(w, config, at, "access", atexp, rememberMe)
|
||||
setToken(w, config, rt, "refresh", rtexp, rememberMe)
|
||||
return nil
|
||||
}
|
||||
68
pkg/db/connection.go
Normal file
68
pkg/db/connection.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// Returns a database connection handle for the DB
|
||||
func ConnectToDatabase(
|
||||
dbName string,
|
||||
logger *zerolog.Logger,
|
||||
) (*SafeConn, error) {
|
||||
opts := "_journal_mode=WAL&_synchronous=NORMAL&_txlock=IMMEDIATE"
|
||||
file := fmt.Sprintf("file:%s.db?%s", dbName, opts)
|
||||
wconn, err := sql.Open("sqlite3", file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open (rw)")
|
||||
}
|
||||
wconn.SetMaxOpenConns(1)
|
||||
opts = "_synchronous=NORMAL&mode=ro"
|
||||
file = fmt.Sprintf("file:%s.db?%s", dbName, opts)
|
||||
|
||||
rconn, err := sql.Open("sqlite3", file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "sql.Open (ro)")
|
||||
}
|
||||
|
||||
version, err := strconv.Atoi(dbName)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "strconv.Atoi")
|
||||
}
|
||||
err = checkDBVersion(rconn, version)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkDBVersion")
|
||||
}
|
||||
conn := MakeSafe(wconn, rconn, logger)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Check the database version
|
||||
func checkDBVersion(db *sql.DB, expectVer int) error {
|
||||
query := `SELECT version_id FROM goose_db_version WHERE is_applied = 1
|
||||
ORDER BY version_id DESC LIMIT 1`
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "db.Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
if rows.Next() {
|
||||
var version int
|
||||
err = rows.Scan(&version)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "rows.Scan")
|
||||
}
|
||||
if version != expectVer {
|
||||
return errors.New("Version mismatch")
|
||||
}
|
||||
} else {
|
||||
return errors.New("No version found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
162
pkg/db/safeconn.go
Normal file
162
pkg/db/safeconn.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type SafeConn struct {
|
||||
wconn *sql.DB
|
||||
rconn *sql.DB
|
||||
readLockCount uint32
|
||||
globalLockStatus uint32
|
||||
globalLockRequested uint32
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
// Make the provided db handle safe and attach a logger to it
|
||||
func MakeSafe(wconn *sql.DB, rconn *sql.DB, logger *zerolog.Logger) *SafeConn {
|
||||
return &SafeConn{wconn: wconn, rconn: rconn, logger: logger}
|
||||
}
|
||||
|
||||
// Attempts to acquire a global lock on the database connection
|
||||
func (conn *SafeConn) acquireGlobalLock() bool {
|
||||
if conn.readLockCount > 0 || conn.globalLockStatus == 1 {
|
||||
return false
|
||||
}
|
||||
conn.globalLockStatus = 1
|
||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||
Msg("Global lock acquired")
|
||||
return true
|
||||
}
|
||||
|
||||
// Releases a global lock on the database connection
|
||||
func (conn *SafeConn) releaseGlobalLock() {
|
||||
conn.globalLockStatus = 0
|
||||
conn.logger.Debug().Uint32("global_lock_status", conn.globalLockStatus).
|
||||
Msg("Global lock released")
|
||||
}
|
||||
|
||||
// Acquire a read lock on the connection. Multiple read locks can be acquired
|
||||
// at the same time
|
||||
func (conn *SafeConn) acquireReadLock() bool {
|
||||
if conn.globalLockStatus == 1 || conn.globalLockRequested == 1 {
|
||||
return false
|
||||
}
|
||||
conn.readLockCount += 1
|
||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||
Msg("Read lock acquired")
|
||||
return true
|
||||
}
|
||||
|
||||
// Release a read lock. Decrements read lock count by 1
|
||||
func (conn *SafeConn) releaseReadLock() {
|
||||
conn.readLockCount -= 1
|
||||
conn.logger.Debug().Uint32("read_lock_count", conn.readLockCount).
|
||||
Msg("Read lock released")
|
||||
}
|
||||
|
||||
// Starts a new transaction based on the current context. Will cancel if
|
||||
// the context is closed/cancelled/done
|
||||
func (conn *SafeConn) Begin(ctx context.Context) (*SafeWTX, error) {
|
||||
lockAcquired := make(chan struct{})
|
||||
lockCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-lockCtx.Done():
|
||||
return
|
||||
default:
|
||||
if conn.acquireReadLock() {
|
||||
close(lockAcquired)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-lockAcquired:
|
||||
tx, err := conn.wconn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
conn.releaseReadLock()
|
||||
return nil, err
|
||||
}
|
||||
return &SafeWTX{tx: tx, sc: conn}, nil
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
return nil, errors.New("Transaction time out due to database lock")
|
||||
}
|
||||
}
|
||||
|
||||
// Starts a new READONLY transaction based on the current context. Will cancel if
|
||||
// the context is closed/cancelled/done
|
||||
func (conn *SafeConn) RBegin(ctx context.Context) (*SafeRTX, error) {
|
||||
lockAcquired := make(chan struct{})
|
||||
lockCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-lockCtx.Done():
|
||||
return
|
||||
default:
|
||||
if conn.acquireReadLock() {
|
||||
close(lockAcquired)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-lockAcquired:
|
||||
tx, err := conn.rconn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
conn.releaseReadLock()
|
||||
return nil, err
|
||||
}
|
||||
return &SafeRTX{tx: tx, sc: conn}, nil
|
||||
case <-ctx.Done():
|
||||
cancel()
|
||||
return nil, errors.New("Transaction time out due to database lock")
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire a global lock, preventing all transactions
|
||||
func (conn *SafeConn) Pause(timeoutAfter time.Duration) {
|
||||
conn.logger.Info().Msg("Attempting to acquire global database lock")
|
||||
conn.globalLockRequested = 1
|
||||
defer func() { conn.globalLockRequested = 0 }()
|
||||
timeout := time.After(timeoutAfter)
|
||||
attempt := 0
|
||||
for {
|
||||
if conn.acquireGlobalLock() {
|
||||
conn.logger.Info().Msg("Global database lock acquired")
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-timeout:
|
||||
conn.logger.Info().Msg("Timeout: Global database lock abandoned")
|
||||
return
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
attempt++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Release the global lock
|
||||
func (conn *SafeConn) Resume() {
|
||||
conn.releaseGlobalLock()
|
||||
conn.logger.Info().Msg("Global database lock released")
|
||||
}
|
||||
|
||||
// Close the database connection
|
||||
func (conn *SafeConn) Close() error {
|
||||
conn.logger.Debug().Msg("Acquiring global lock for connection close")
|
||||
conn.acquireGlobalLock()
|
||||
defer conn.releaseGlobalLock()
|
||||
conn.logger.Debug().Msg("Closing database connection")
|
||||
return conn.wconn.Close()
|
||||
}
|
||||
143
pkg/db/safeconntx_test.go
Normal file
143
pkg/db/safeconntx_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"projectreshoot/pkg/tests"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSafeConn(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
wconn, rconn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := MakeSafe(wconn, rconn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
t.Run("Global lock waits for read locks to finish", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
var requested sync.WaitGroup
|
||||
var engaged sync.WaitGroup
|
||||
requested.Add(1)
|
||||
engaged.Add(1)
|
||||
go func() {
|
||||
requested.Done()
|
||||
sconn.Pause(5 * time.Second)
|
||||
engaged.Done()
|
||||
}()
|
||||
requested.Wait()
|
||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
||||
assert.Equal(t, uint32(1), sconn.globalLockRequested)
|
||||
tx.Commit()
|
||||
engaged.Wait()
|
||||
assert.Equal(t, uint32(1), sconn.globalLockStatus)
|
||||
assert.Equal(t, uint32(0), sconn.globalLockRequested)
|
||||
sconn.Resume()
|
||||
})
|
||||
t.Run("Lock abandons after timeout", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
sconn.Pause(250 * time.Millisecond)
|
||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
||||
assert.Equal(t, uint32(0), sconn.globalLockRequested)
|
||||
tx.Commit()
|
||||
})
|
||||
t.Run("Pause blocks transactions and resume allows", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
var requested sync.WaitGroup
|
||||
var engaged sync.WaitGroup
|
||||
requested.Add(1)
|
||||
engaged.Add(1)
|
||||
go func() {
|
||||
requested.Done()
|
||||
sconn.Pause(5 * time.Second)
|
||||
engaged.Done()
|
||||
}()
|
||||
requested.Wait()
|
||||
assert.Equal(t, uint32(0), sconn.globalLockStatus)
|
||||
assert.Equal(t, uint32(1), sconn.globalLockRequested)
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
||||
defer cancel()
|
||||
_, err = sconn.Begin(ctx)
|
||||
require.Error(t, err)
|
||||
tx.Commit()
|
||||
engaged.Wait()
|
||||
_, err = sconn.Begin(ctx)
|
||||
require.Error(t, err)
|
||||
sconn.Resume()
|
||||
tx, err = sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
tx.Commit()
|
||||
})
|
||||
}
|
||||
func TestSafeTX(t *testing.T) {
|
||||
cfg, err := tests.TestConfig()
|
||||
require.NoError(t, err)
|
||||
logger := tests.NilLogger()
|
||||
ver, err := strconv.ParseInt(cfg.DBName, 10, 0)
|
||||
require.NoError(t, err)
|
||||
wconn, rconn, err := tests.SetupTestDB(ver)
|
||||
require.NoError(t, err)
|
||||
sconn := MakeSafe(wconn, rconn, logger)
|
||||
defer sconn.Close()
|
||||
|
||||
t.Run("Commit releases lock", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(1), sconn.readLockCount)
|
||||
tx.Commit()
|
||||
assert.Equal(t, uint32(0), sconn.readLockCount)
|
||||
})
|
||||
t.Run("Rollback releases lock", func(t *testing.T) {
|
||||
tx, err := sconn.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint32(1), sconn.readLockCount)
|
||||
tx.Rollback()
|
||||
assert.Equal(t, uint32(0), sconn.readLockCount)
|
||||
})
|
||||
t.Run("Multiple RTX can gain read lock", func(t *testing.T) {
|
||||
tx1, err := sconn.RBegin(t.Context())
|
||||
require.NoError(t, err)
|
||||
tx2, err := sconn.RBegin(t.Context())
|
||||
require.NoError(t, err)
|
||||
tx3, err := sconn.RBegin(t.Context())
|
||||
require.NoError(t, err)
|
||||
tx1.Commit()
|
||||
tx2.Commit()
|
||||
tx3.Commit()
|
||||
})
|
||||
t.Run("Lock acquiring times out after timeout", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
||||
defer cancel()
|
||||
sconn.acquireGlobalLock()
|
||||
defer sconn.releaseGlobalLock()
|
||||
_, err := sconn.Begin(ctx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
t.Run("Lock acquires if lock released", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 250*time.Millisecond)
|
||||
defer cancel()
|
||||
sconn.acquireGlobalLock()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
tx, err := sconn.Begin(ctx)
|
||||
require.NoError(t, err)
|
||||
tx.Commit()
|
||||
wg.Done()
|
||||
}()
|
||||
sconn.releaseGlobalLock()
|
||||
wg.Wait()
|
||||
})
|
||||
}
|
||||
163
pkg/db/safetx.go
Normal file
163
pkg/db/safetx.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type SafeTX interface {
|
||||
Query(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(ctx context.Context, query string, args ...interface{}) (*sql.Row, error)
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// Extends sql.Tx for use with SafeConn
|
||||
type SafeWTX struct {
|
||||
tx *sql.Tx
|
||||
sc *SafeConn
|
||||
}
|
||||
type SafeRTX struct {
|
||||
tx *sql.Tx
|
||||
sc *SafeConn
|
||||
}
|
||||
|
||||
func isWriteOperation(query string) bool {
|
||||
query = strings.TrimSpace(query)
|
||||
query = strings.ToUpper(query)
|
||||
writeOpsRegex := `^(INSERT|UPDATE|DELETE|REPLACE|MERGE|CREATE|DROP|ALTER|TRUNCATE)\s+`
|
||||
re := regexp.MustCompile(writeOpsRegex)
|
||||
return re.MatchString(query)
|
||||
}
|
||||
|
||||
// Query the database inside the transaction
|
||||
func (stx *SafeRTX) Query(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
args ...interface{},
|
||||
) (*sql.Rows, error) {
|
||||
if stx.tx == nil {
|
||||
return nil, errors.New("Cannot query without a transaction")
|
||||
}
|
||||
if isWriteOperation(query) {
|
||||
return nil, errors.New("Cannot query with a write operation")
|
||||
}
|
||||
rows, err := stx.tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.QueryContext")
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// Query the database inside the transaction
|
||||
func (stx *SafeWTX) Query(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
args ...interface{},
|
||||
) (*sql.Rows, error) {
|
||||
if stx.tx == nil {
|
||||
return nil, errors.New("Cannot query without a transaction")
|
||||
}
|
||||
if isWriteOperation(query) {
|
||||
return nil, errors.New("Cannot query with a write operation")
|
||||
}
|
||||
rows, err := stx.tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.QueryContext")
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// Query a row from the database inside the transaction
|
||||
func (stx *SafeRTX) QueryRow(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
args ...interface{},
|
||||
) (*sql.Row, error) {
|
||||
if stx.tx == nil {
|
||||
return nil, errors.New("Cannot query without a transaction")
|
||||
}
|
||||
if isWriteOperation(query) {
|
||||
return nil, errors.New("Cannot query with a write operation")
|
||||
}
|
||||
return stx.tx.QueryRowContext(ctx, query, args...), nil
|
||||
}
|
||||
|
||||
// Query a row from the database inside the transaction
|
||||
func (stx *SafeWTX) QueryRow(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
args ...interface{},
|
||||
) (*sql.Row, error) {
|
||||
if stx.tx == nil {
|
||||
return nil, errors.New("Cannot query without a transaction")
|
||||
}
|
||||
if isWriteOperation(query) {
|
||||
return nil, errors.New("Cannot query with a write operation")
|
||||
}
|
||||
return stx.tx.QueryRowContext(ctx, query, args...), nil
|
||||
}
|
||||
|
||||
// Exec a statement on the database inside the transaction
|
||||
func (stx *SafeWTX) Exec(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
args ...interface{},
|
||||
) (sql.Result, error) {
|
||||
if stx.tx == nil {
|
||||
return nil, errors.New("Cannot exec without a transaction")
|
||||
}
|
||||
res, err := stx.tx.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tx.ExecContext")
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Commit the current transaction and release the read lock
|
||||
func (stx *SafeRTX) Commit() error {
|
||||
if stx.tx == nil {
|
||||
return errors.New("Cannot commit without a transaction")
|
||||
}
|
||||
err := stx.tx.Commit()
|
||||
stx.tx = nil
|
||||
stx.sc.releaseReadLock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Commit the current transaction and release the read lock
|
||||
func (stx *SafeWTX) Commit() error {
|
||||
if stx.tx == nil {
|
||||
return errors.New("Cannot commit without a transaction")
|
||||
}
|
||||
err := stx.tx.Commit()
|
||||
stx.tx = nil
|
||||
stx.sc.releaseReadLock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Abort the current transaction, releasing the read lock
|
||||
func (stx *SafeRTX) Rollback() error {
|
||||
if stx.tx == nil {
|
||||
return errors.New("Cannot rollback without a transaction")
|
||||
}
|
||||
err := stx.tx.Rollback()
|
||||
stx.tx = nil
|
||||
stx.sc.releaseReadLock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Abort the current transaction, releasing the read lock
|
||||
func (stx *SafeWTX) Rollback() error {
|
||||
if stx.tx == nil {
|
||||
return errors.New("Cannot rollback without a transaction")
|
||||
}
|
||||
err := stx.tx.Rollback()
|
||||
stx.tx = nil
|
||||
stx.sc.releaseReadLock()
|
||||
return err
|
||||
}
|
||||
20
pkg/embedfs/embedfs.go
Normal file
20
pkg/embedfs/embedfs.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package embedfs
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
//go:embed files/*
|
||||
var embeddedFiles embed.FS
|
||||
|
||||
// Gets the embedded files
|
||||
func GetEmbeddedFS() (fs.FS, error) {
|
||||
subFS, err := fs.Sub(embeddedFiles, "files")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fs.Sub")
|
||||
}
|
||||
return subFS, nil
|
||||
}
|
||||
BIN
pkg/embedfs/files/assets/error.png
Normal file
BIN
pkg/embedfs/files/assets/error.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
117
pkg/embedfs/files/css/input.css
Normal file
117
pkg/embedfs/files/css/input.css
Normal file
@@ -0,0 +1,117 @@
|
||||
@import url("https://fonts.googleapis.com/css2?family=Ubuntu+Mono:ital,wght@0,400;0,700;1,400;1,700&display=swap");
|
||||
@import "tailwindcss";
|
||||
|
||||
[x-cloak] {
|
||||
display: none !important;
|
||||
}
|
||||
@theme inline {
|
||||
--color-rosewater: var(--rosewater);
|
||||
--color-flamingo: var(--flamingo);
|
||||
--color-pink: var(--pink);
|
||||
--color-mauve: var(--mauve);
|
||||
--color-red: var(--red);
|
||||
--color-dark-red: var(--dark-red);
|
||||
--color-maroon: var(--maroon);
|
||||
--color-peach: var(--peach);
|
||||
--color-yellow: var(--yellow);
|
||||
--color-green: var(--green);
|
||||
--color-teal: var(--teal);
|
||||
--color-sky: var(--sky);
|
||||
--color-sapphire: var(--sapphire);
|
||||
--color-blue: var(--blue);
|
||||
--color-lavender: var(--lavender);
|
||||
--color-text: var(--text);
|
||||
--color-subtext1: var(--subtext1);
|
||||
--color-subtext0: var(--subtext0);
|
||||
--color-overlay2: var(--overlay2);
|
||||
--color-overlay1: var(--overlay1);
|
||||
--color-overlay0: var(--overlay0);
|
||||
--color-surface2: var(--surface2);
|
||||
--color-surface1: var(--surface1);
|
||||
--color-surface0: var(--surface0);
|
||||
--color-base: var(--base);
|
||||
--color-mantle: var(--mantle);
|
||||
--color-crust: var(--crust);
|
||||
}
|
||||
:root {
|
||||
--rosewater: hsl(11, 59%, 67%);
|
||||
--flamingo: hsl(0, 60%, 67%);
|
||||
--pink: hsl(316, 73%, 69%);
|
||||
--mauve: hsl(266, 85%, 58%);
|
||||
--red: hsl(347, 87%, 44%);
|
||||
--dark-red: hsl(343, 50%, 82%);
|
||||
--maroon: hsl(355, 76%, 59%);
|
||||
--peach: hsl(22, 99%, 52%);
|
||||
--yellow: hsl(35, 77%, 49%);
|
||||
--green: hsl(109, 58%, 40%);
|
||||
--teal: hsl(183, 74%, 35%);
|
||||
--sky: hsl(197, 97%, 46%);
|
||||
--sapphire: hsl(189, 70%, 42%);
|
||||
--blue: hsl(220, 91%, 54%);
|
||||
--lavender: hsl(231, 97%, 72%);
|
||||
--text: hsl(234, 16%, 35%);
|
||||
--subtext1: hsl(233, 13%, 41%);
|
||||
--subtext0: hsl(233, 10%, 47%);
|
||||
--overlay2: hsl(232, 10%, 53%);
|
||||
--overlay1: hsl(231, 10%, 59%);
|
||||
--overlay0: hsl(228, 11%, 65%);
|
||||
--surface2: hsl(227, 12%, 71%);
|
||||
--surface1: hsl(225, 14%, 77%);
|
||||
--surface0: hsl(223, 16%, 83%);
|
||||
--base: hsl(220, 23%, 95%);
|
||||
--mantle: hsl(220, 22%, 92%);
|
||||
--crust: hsl(220, 21%, 89%);
|
||||
}
|
||||
|
||||
.dark {
|
||||
--rosewater: hsl(10, 56%, 91%);
|
||||
--flamingo: hsl(0, 59%, 88%);
|
||||
--pink: hsl(316, 72%, 86%);
|
||||
--mauve: hsl(267, 84%, 81%);
|
||||
--red: hsl(343, 81%, 75%);
|
||||
--dark-red: hsl(316, 19%, 27%);
|
||||
--maroon: hsl(350, 65%, 77%);
|
||||
--peach: hsl(23, 92%, 75%);
|
||||
--yellow: hsl(41, 86%, 83%);
|
||||
--green: hsl(115, 54%, 76%);
|
||||
--teal: hsl(170, 57%, 73%);
|
||||
--sky: hsl(189, 71%, 73%);
|
||||
--sapphire: hsl(199, 76%, 69%);
|
||||
--blue: hsl(217, 92%, 76%);
|
||||
--lavender: hsl(232, 97%, 85%);
|
||||
--text: hsl(226, 64%, 88%);
|
||||
--subtext1: hsl(227, 35%, 80%);
|
||||
--subtext0: hsl(228, 24%, 72%);
|
||||
--overlay2: hsl(228, 17%, 64%);
|
||||
--overlay1: hsl(230, 13%, 55%);
|
||||
--overlay0: hsl(231, 11%, 47%);
|
||||
--surface2: hsl(233, 12%, 39%);
|
||||
--surface1: hsl(234, 13%, 31%);
|
||||
--surface0: hsl(237, 16%, 23%);
|
||||
--base: hsl(240, 21%, 15%);
|
||||
--mantle: hsl(240, 21%, 12%);
|
||||
--crust: hsl(240, 23%, 9%);
|
||||
}
|
||||
.ubuntu-mono-regular {
|
||||
font-family: "Ubuntu Mono", serif;
|
||||
font-weight: 400;
|
||||
font-style: normal;
|
||||
}
|
||||
|
||||
.ubuntu-mono-bold {
|
||||
font-family: "Ubuntu Mono", serif;
|
||||
font-weight: 700;
|
||||
font-style: normal;
|
||||
}
|
||||
|
||||
.ubuntu-mono-regular-italic {
|
||||
font-family: "Ubuntu Mono", serif;
|
||||
font-weight: 400;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.ubuntu-mono-bold-italic {
|
||||
font-family: "Ubuntu Mono", serif;
|
||||
font-weight: 700;
|
||||
font-style: italic;
|
||||
}
|
||||
1913
pkg/embedfs/files/css/output.css
Normal file
1913
pkg/embedfs/files/css/output.css
Normal file
File diff suppressed because it is too large
Load Diff
BIN
pkg/embedfs/files/favicon.ico
Normal file
BIN
pkg/embedfs/files/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 834 B |
84
pkg/jwt/create.go
Normal file
84
pkg/jwt/create.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"projectreshoot/internal/models"
|
||||
"projectreshoot/pkg/config"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Generates an access token for the provided user
|
||||
func GenerateAccessToken(
|
||||
config *config.Config,
|
||||
user *models.User,
|
||||
fresh bool,
|
||||
rememberMe bool,
|
||||
) (tokenStr string, exp int64, err error) {
|
||||
issuedAt := time.Now().Unix()
|
||||
expiresAt := issuedAt + (config.AccessTokenExpiry * 60)
|
||||
var freshExpiresAt int64
|
||||
if fresh {
|
||||
freshExpiresAt = issuedAt + (config.TokenFreshTime * 60)
|
||||
} else {
|
||||
freshExpiresAt = issuedAt
|
||||
}
|
||||
var ttl string
|
||||
if rememberMe {
|
||||
ttl = "exp"
|
||||
} else {
|
||||
ttl = "session"
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||
jwt.MapClaims{
|
||||
"iss": config.TrustedHost,
|
||||
"scope": "access",
|
||||
"ttl": ttl,
|
||||
"jti": uuid.New(),
|
||||
"iat": issuedAt,
|
||||
"exp": expiresAt,
|
||||
"fresh": freshExpiresAt,
|
||||
"sub": user.ID,
|
||||
})
|
||||
|
||||
signedToken, err := token.SignedString([]byte(config.SecretKey))
|
||||
if err != nil {
|
||||
return "", 0, errors.Wrap(err, "token.SignedString")
|
||||
}
|
||||
return signedToken, expiresAt, nil
|
||||
}
|
||||
|
||||
// Generates a refresh token for the provided user
|
||||
func GenerateRefreshToken(
|
||||
config *config.Config,
|
||||
user *models.User,
|
||||
rememberMe bool,
|
||||
) (tokenStr string, exp int64, err error) {
|
||||
issuedAt := time.Now().Unix()
|
||||
expiresAt := issuedAt + (config.RefreshTokenExpiry * 60)
|
||||
var ttl string
|
||||
if rememberMe {
|
||||
ttl = "exp"
|
||||
} else {
|
||||
ttl = "session"
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||
jwt.MapClaims{
|
||||
"iss": config.TrustedHost,
|
||||
"scope": "refresh",
|
||||
"ttl": ttl,
|
||||
"jti": uuid.New(),
|
||||
"iat": issuedAt,
|
||||
"exp": expiresAt,
|
||||
"sub": user.ID,
|
||||
})
|
||||
|
||||
signedToken, err := token.SignedString([]byte(config.SecretKey))
|
||||
if err != nil {
|
||||
return "", 0, errors.Wrap(err, "token.SignedString")
|
||||
}
|
||||
return signedToken, expiresAt, nil
|
||||
}
|
||||
268
pkg/jwt/parse.go
Normal file
268
pkg/jwt/parse.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"projectreshoot/pkg/config"
|
||||
"projectreshoot/pkg/db"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Parse an access token and return a struct with all the claims. Does validation on
|
||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||
// has the correct scope.
|
||||
func ParseAccessToken(
|
||||
config *config.Config,
|
||||
ctx context.Context,
|
||||
tx db.SafeTX,
|
||||
tokenString string,
|
||||
) (*AccessToken, error) {
|
||||
if tokenString == "" {
|
||||
return nil, errors.New("Access token string not provided")
|
||||
}
|
||||
claims, err := parseToken(config.SecretKey, tokenString)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parseToken")
|
||||
}
|
||||
expiry, err := checkTokenExpired(claims["exp"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkTokenExpired")
|
||||
}
|
||||
issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkTokenIssuer")
|
||||
}
|
||||
ttl, err := getTokenTTL(claims["ttl"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenTTL")
|
||||
}
|
||||
scope, err := getTokenScope(claims["scope"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenScope")
|
||||
}
|
||||
if scope != "access" {
|
||||
return nil, errors.New("Token is not an Access token")
|
||||
}
|
||||
issuedAt, err := getIssuedTime(claims["iat"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getIssuedTime")
|
||||
}
|
||||
subject, err := getTokenSubject(claims["sub"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenSubject")
|
||||
}
|
||||
fresh, err := getFreshTime(claims["fresh"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getFreshTime")
|
||||
}
|
||||
jti, err := getTokenJTI(claims["jti"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenJTI")
|
||||
}
|
||||
|
||||
token := &AccessToken{
|
||||
ISS: issuer,
|
||||
TTL: ttl,
|
||||
EXP: expiry,
|
||||
IAT: issuedAt,
|
||||
SUB: subject,
|
||||
Fresh: fresh,
|
||||
JTI: jti,
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
||||
}
|
||||
if !valid {
|
||||
return nil, errors.New("Token has been revoked")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Parse a refresh token and return a struct with all the claims. Does validation on
|
||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||
// has the correct scope.
|
||||
func ParseRefreshToken(
|
||||
config *config.Config,
|
||||
ctx context.Context,
|
||||
tx db.SafeTX,
|
||||
tokenString string,
|
||||
) (*RefreshToken, error) {
|
||||
if tokenString == "" {
|
||||
return nil, errors.New("Refresh token string not provided")
|
||||
}
|
||||
claims, err := parseToken(config.SecretKey, tokenString)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parseToken")
|
||||
}
|
||||
expiry, err := checkTokenExpired(claims["exp"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkTokenExpired")
|
||||
}
|
||||
issuer, err := checkTokenIssuer(config.TrustedHost, claims["iss"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkTokenIssuer")
|
||||
}
|
||||
ttl, err := getTokenTTL(claims["ttl"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenTTL")
|
||||
}
|
||||
scope, err := getTokenScope(claims["scope"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenScope")
|
||||
}
|
||||
if scope != "refresh" {
|
||||
return nil, errors.New("Token is not an Refresh token")
|
||||
}
|
||||
issuedAt, err := getIssuedTime(claims["iat"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getIssuedTime")
|
||||
}
|
||||
subject, err := getTokenSubject(claims["sub"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenSubject")
|
||||
}
|
||||
jti, err := getTokenJTI(claims["jti"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenJTI")
|
||||
}
|
||||
|
||||
token := &RefreshToken{
|
||||
ISS: issuer,
|
||||
TTL: ttl,
|
||||
EXP: expiry,
|
||||
IAT: issuedAt,
|
||||
SUB: subject,
|
||||
JTI: jti,
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
valid, err := CheckTokenNotRevoked(ctx, tx, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "CheckTokenNotRevoked")
|
||||
}
|
||||
if !valid {
|
||||
return nil, errors.New("Token has been revoked")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Parse a token, validating its signing sigature and returning the claims
|
||||
func parseToken(secretKey string, tokenString string) (jwt.MapClaims, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
return []byte(secretKey), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "jwt.Parse")
|
||||
}
|
||||
// Token decoded, parse the claims
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, errors.New("Failed to parse claims")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Check if a token is expired. Returns the expiry if not expired
|
||||
func checkTokenExpired(expiry interface{}) (int64, error) {
|
||||
// Coerce the expiry to a float64 to avoid scientific notation
|
||||
expFloat, ok := expiry.(float64)
|
||||
if !ok {
|
||||
return 0, errors.New("Missing or invalid 'exp' claim")
|
||||
}
|
||||
// Convert to the int64 time we expect :)
|
||||
expiryTime := int64(expFloat)
|
||||
|
||||
// Check if its expired
|
||||
isExpired := time.Now().After(time.Unix(expiryTime, 0))
|
||||
if isExpired {
|
||||
return 0, errors.New("Token has expired")
|
||||
}
|
||||
return expiryTime, nil
|
||||
}
|
||||
|
||||
// Check if a token has a valid issuer. Returns the issuer if valid
|
||||
func checkTokenIssuer(trustedHost string, issuer interface{}) (string, error) {
|
||||
issuerVal, ok := issuer.(string)
|
||||
if !ok {
|
||||
return "", errors.New("Missing or invalid 'iss' claim")
|
||||
}
|
||||
if issuer != trustedHost {
|
||||
return "", errors.New("Issuer does not matched trusted host")
|
||||
}
|
||||
return issuerVal, nil
|
||||
}
|
||||
|
||||
// Check the scope matches the expected scope. Returns scope if true
|
||||
func getTokenScope(scope interface{}) (string, error) {
|
||||
scopeStr, ok := scope.(string)
|
||||
if !ok {
|
||||
return "", errors.New("Missing or invalid 'scope' claim")
|
||||
}
|
||||
return scopeStr, nil
|
||||
}
|
||||
|
||||
// Get the TTL of the token, either "session" or "exp"
|
||||
func getTokenTTL(ttl interface{}) (string, error) {
|
||||
ttlStr, ok := ttl.(string)
|
||||
if !ok {
|
||||
return "", errors.New("Missing or invalid 'ttl' claim")
|
||||
}
|
||||
if ttlStr != "exp" && ttlStr != "session" {
|
||||
return "", errors.New("TTL value is not recognised")
|
||||
}
|
||||
return ttlStr, nil
|
||||
}
|
||||
|
||||
// Get the time the token was issued at
|
||||
func getIssuedTime(issued interface{}) (int64, error) {
|
||||
// Same float64 -> int64 trick as expiry
|
||||
issuedFloat, ok := issued.(float64)
|
||||
if !ok {
|
||||
return 0, errors.New("Missing or invalid 'iat' claim")
|
||||
}
|
||||
issuedAt := int64(issuedFloat)
|
||||
return issuedAt, nil
|
||||
}
|
||||
|
||||
// Get the freshness expiry timestamp
|
||||
func getFreshTime(fresh interface{}) (int64, error) {
|
||||
freshUntil, ok := fresh.(float64)
|
||||
if !ok {
|
||||
return 0, errors.New("Missing or invalid 'fresh' claim")
|
||||
}
|
||||
return int64(freshUntil), nil
|
||||
}
|
||||
|
||||
// Get the subject of the token
|
||||
func getTokenSubject(sub interface{}) (int, error) {
|
||||
subject, ok := sub.(float64)
|
||||
if !ok {
|
||||
return 0, errors.New("Missing or invalid 'sub' claim")
|
||||
}
|
||||
return int(subject), nil
|
||||
}
|
||||
|
||||
// Get the JTI of the token
|
||||
func getTokenJTI(jti interface{}) (uuid.UUID, error) {
|
||||
jtiStr, ok := jti.(string)
|
||||
if !ok {
|
||||
return uuid.UUID{}, errors.New("Missing or invalid 'jti' claim")
|
||||
}
|
||||
jtiUUID, err := uuid.Parse(jtiStr)
|
||||
if err != nil {
|
||||
return uuid.UUID{}, errors.New("JTI is not a valid UUID")
|
||||
}
|
||||
return jtiUUID, nil
|
||||
}
|
||||
33
pkg/jwt/revoke.go
Normal file
33
pkg/jwt/revoke.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"projectreshoot/pkg/db"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Revoke a token by adding it to the database
|
||||
func RevokeToken(ctx context.Context, tx *db.SafeWTX, t Token) error {
|
||||
jti := t.GetJTI()
|
||||
exp := t.GetEXP()
|
||||
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
||||
_, err := tx.Exec(ctx, query, jti, exp)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if a token has been revoked. Returns true if not revoked.
|
||||
func CheckTokenNotRevoked(ctx context.Context, tx db.SafeTX, t Token) (bool, error) {
|
||||
jti := t.GetJTI()
|
||||
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
||||
rows, err := tx.Query(ctx, query, jti)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tx.Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
revoked := rows.Next()
|
||||
return !revoked, nil
|
||||
}
|
||||
74
pkg/jwt/tokens.go
Normal file
74
pkg/jwt/tokens.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"projectreshoot/internal/models"
|
||||
"projectreshoot/pkg/db"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Token interface {
|
||||
GetJTI() uuid.UUID
|
||||
GetEXP() int64
|
||||
GetScope() string
|
||||
GetUser(ctx context.Context, tx db.SafeTX) (*models.User, error)
|
||||
}
|
||||
|
||||
// Access token
|
||||
type AccessToken struct {
|
||||
ISS string // Issuer, generally TrustedHost
|
||||
IAT int64 // Time issued at
|
||||
EXP int64 // Time expiring at
|
||||
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
||||
SUB int // Subject (user) ID
|
||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||
Fresh int64 // Time freshness expiring at
|
||||
Scope string // Should be "access"
|
||||
}
|
||||
|
||||
// Refresh token
|
||||
type RefreshToken struct {
|
||||
ISS string // Issuer, generally TrustedHost
|
||||
IAT int64 // Time issued at
|
||||
EXP int64 // Time expiring at
|
||||
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
||||
SUB int // Subject (user) ID
|
||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||
Scope string // Should be "refresh"
|
||||
}
|
||||
|
||||
func (a AccessToken) GetUser(ctx context.Context, tx db.SafeTX) (*models.User, error) {
|
||||
user, err := models.GetUserFromID(ctx, tx, a.SUB)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
func (r RefreshToken) GetUser(ctx context.Context, tx db.SafeTX) (*models.User, error) {
|
||||
user, err := models.GetUserFromID(ctx, tx, r.SUB)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "db.GetUserFromID")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (a AccessToken) GetJTI() uuid.UUID {
|
||||
return a.JTI
|
||||
}
|
||||
func (r RefreshToken) GetJTI() uuid.UUID {
|
||||
return r.JTI
|
||||
}
|
||||
func (a AccessToken) GetEXP() int64 {
|
||||
return a.EXP
|
||||
}
|
||||
func (r RefreshToken) GetEXP() int64 {
|
||||
return r.EXP
|
||||
}
|
||||
func (a AccessToken) GetScope() string {
|
||||
return a.Scope
|
||||
}
|
||||
func (r RefreshToken) GetScope() string {
|
||||
return r.Scope
|
||||
}
|
||||
84
pkg/logging/logger.go
Normal file
84
pkg/logging/logger.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/pkgerrors"
|
||||
)
|
||||
|
||||
// Takes a log level as string and converts it to a zerolog.Level interface.
|
||||
// If the string is not a valid input it will return zerolog.InfoLevel
|
||||
func GetLogLevel(level string) zerolog.Level {
|
||||
levels := map[string]zerolog.Level{
|
||||
"trace": zerolog.TraceLevel,
|
||||
"debug": zerolog.DebugLevel,
|
||||
"info": zerolog.InfoLevel,
|
||||
"warn": zerolog.WarnLevel,
|
||||
"error": zerolog.ErrorLevel,
|
||||
"fatal": zerolog.FatalLevel,
|
||||
"panic": zerolog.PanicLevel,
|
||||
}
|
||||
logLevel, valid := levels[level]
|
||||
if !valid {
|
||||
return zerolog.InfoLevel
|
||||
}
|
||||
return logLevel
|
||||
}
|
||||
|
||||
// Returns a pointer to a new log file with the specified path.
|
||||
// Remember to call file.Close() when finished writing to the log file
|
||||
func GetLogFile(path string) (*os.File, error) {
|
||||
logPath := filepath.Join(path, "server.log")
|
||||
file, err := os.OpenFile(
|
||||
logPath,
|
||||
os.O_APPEND|os.O_CREATE|os.O_WRONLY,
|
||||
0663,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "os.OpenFile")
|
||||
}
|
||||
return file, nil
|
||||
}
|
||||
|
||||
// Get a pointer to a new zerolog.Logger with the specified level and output
|
||||
// Can provide a file, writer or both. Must provide at least one of the two
|
||||
func GetLogger(
|
||||
logLevel zerolog.Level,
|
||||
w io.Writer,
|
||||
logFile *os.File,
|
||||
logDir string,
|
||||
) (*zerolog.Logger, error) {
|
||||
if w == nil && logFile == nil {
|
||||
return nil, errors.New("No Writer provided for log output.")
|
||||
}
|
||||
|
||||
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
|
||||
zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack
|
||||
|
||||
var consoleWriter zerolog.ConsoleWriter
|
||||
if w != nil {
|
||||
consoleWriter = zerolog.ConsoleWriter{Out: w}
|
||||
}
|
||||
|
||||
var output io.Writer
|
||||
if logFile != nil {
|
||||
if w != nil {
|
||||
output = zerolog.MultiLevelWriter(logFile, consoleWriter)
|
||||
} else {
|
||||
output = logFile
|
||||
}
|
||||
} else {
|
||||
output = consoleWriter
|
||||
}
|
||||
logger := zerolog.New(output).
|
||||
With().
|
||||
Timestamp().
|
||||
Logger().
|
||||
Level(logLevel)
|
||||
|
||||
return &logger, nil
|
||||
}
|
||||
18
pkg/tests/config.go
Normal file
18
pkg/tests/config.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"os"
|
||||
"projectreshoot/pkg/config"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func TestConfig() (*config.Config, error) {
|
||||
os.Setenv("SECRET_KEY", ".")
|
||||
os.Setenv("TMDB_API_TOKEN", ".")
|
||||
cfg, err := config.GetConfig(map[string]string{})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "config.GetConfig")
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
116
pkg/tests/database.go
Normal file
116
pkg/tests/database.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/pressly/goose/v3"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func findMigrations() (*fs.FS, error) {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(dir, "Makefile")); err == nil {
|
||||
migrationsdir := os.DirFS(filepath.Join(dir, "cmd", "migrate", "migrations"))
|
||||
return &migrationsdir, nil
|
||||
}
|
||||
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir { // Reached root
|
||||
return nil, errors.New("Unable to locate migrations directory")
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
|
||||
func findTestData() (string, error) {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for {
|
||||
if _, err := os.Stat(filepath.Join(dir, "Makefile")); err == nil {
|
||||
return filepath.Join(dir, "pkg", "tests", "testdata.sql"), nil
|
||||
}
|
||||
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir { // Reached root
|
||||
return "", errors.New("Unable to locate test data")
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
|
||||
func migrateTestDB(wconn *sql.DB, version int64) error {
|
||||
migrations, err := findMigrations()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "findMigrations")
|
||||
}
|
||||
provider, err := goose.NewProvider(goose.DialectSQLite3, wconn, *migrations)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "goose.NewProvider")
|
||||
}
|
||||
ctx := context.Background()
|
||||
if _, err := provider.UpTo(ctx, version); err != nil {
|
||||
return errors.Wrap(err, "provider.UpTo")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadTestData(wconn *sql.DB) error {
|
||||
dataPath, err := findTestData()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "findSchema")
|
||||
}
|
||||
sqlBytes, err := os.ReadFile(dataPath)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "os.ReadFile")
|
||||
}
|
||||
dataSQL := string(sqlBytes)
|
||||
|
||||
_, err = wconn.Exec(dataSQL)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns two db connection handles. First is a readwrite connection, second
|
||||
// is a read only connection
|
||||
func SetupTestDB(version int64) (*sql.DB, *sql.DB, error) {
|
||||
opts := "_journal_mode=WAL&_synchronous=NORMAL&_txlock=IMMEDIATE"
|
||||
file := fmt.Sprintf("file::memory:?cache=shared&%s", opts)
|
||||
wconn, err := sql.Open("sqlite", file)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
|
||||
err = migrateTestDB(wconn, version)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "migrateTestDB")
|
||||
}
|
||||
err = loadTestData(wconn)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "loadTestData")
|
||||
}
|
||||
|
||||
opts = "_synchronous=NORMAL&mode=ro"
|
||||
file = fmt.Sprintf("file::memory:?cache=shared&%s", opts)
|
||||
rconn, err := sql.Open("sqlite", file)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "sql.Open")
|
||||
}
|
||||
return wconn, rconn, nil
|
||||
}
|
||||
33
pkg/tests/logger.go
Normal file
33
pkg/tests/logger.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type TLogWriter struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
// Write implements the io.Writer interface for TLogWriter.
|
||||
func (w *TLogWriter) Write(p []byte) (n int, err error) {
|
||||
w.t.Logf("%s", p)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Return a fake logger to satisfy functions that expect one
|
||||
func NilLogger() *zerolog.Logger {
|
||||
logger := zerolog.New(nil)
|
||||
return &logger
|
||||
}
|
||||
|
||||
// Return a logger that makes use of the T.Log method to enable debugging tests
|
||||
func DebugLogger(t *testing.T) *zerolog.Logger {
|
||||
logger := zerolog.New(GetTLogWriter(t))
|
||||
return &logger
|
||||
}
|
||||
|
||||
func GetTLogWriter(t *testing.T) *TLogWriter {
|
||||
return &TLogWriter{t: t}
|
||||
}
|
||||
3
pkg/tests/testdata.sql
Normal file
3
pkg/tests/testdata.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
INSERT INTO users VALUES(1,'testuser','hashedpassword',1738995274, 'bio');
|
||||
INSERT INTO jwtblacklist VALUES('0a6b338e-930a-43fe-8f70-1a6daed256fa', 33299675344);
|
||||
INSERT INTO jwtblacklist VALUES('b7fa51dc-8532-42e1-8756-5d25bfb2003a', 33299675344);
|
||||
32
pkg/tmdb/config.go
Normal file
32
pkg/tmdb/config.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package tmdb
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Image Image `json:"images"`
|
||||
}
|
||||
|
||||
type Image struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
SecureBaseURL string `json:"secure_base_url"`
|
||||
BackdropSizes []string `json:"backdrop_sizes"`
|
||||
LogoSizes []string `json:"logo_sizes"`
|
||||
PosterSizes []string `json:"poster_sizes"`
|
||||
ProfileSizes []string `json:"profile_sizes"`
|
||||
StillSizes []string `json:"still_sizes"`
|
||||
}
|
||||
|
||||
func GetConfig(token string) (*Config, error) {
|
||||
url := "https://api.themoviedb.org/3/configuration"
|
||||
data, err := tmdbGet(url, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tmdbGet")
|
||||
}
|
||||
config := Config{}
|
||||
json.Unmarshal(data, &config)
|
||||
return &config, nil
|
||||
}
|
||||
54
pkg/tmdb/credits.go
Normal file
54
pkg/tmdb/credits.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package tmdb
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Credits struct {
|
||||
ID int32 `json:"id"`
|
||||
Cast []Cast `json:"cast"`
|
||||
Crew []Crew `json:"crew"`
|
||||
}
|
||||
|
||||
type Cast struct {
|
||||
Adult bool `json:"adult"`
|
||||
Gender int `json:"gender"`
|
||||
ID int32 `json:"id"`
|
||||
KnownFor string `json:"known_for_department"`
|
||||
Name string `json:"name"`
|
||||
OriginalName string `json:"original_name"`
|
||||
Popularity int `json:"popularity"`
|
||||
Profile string `json:"profile_path"`
|
||||
CastID int32 `json:"cast_id"`
|
||||
Character string `json:"character"`
|
||||
CreditID string `json:"credit_id"`
|
||||
Order int `json:"order"`
|
||||
}
|
||||
|
||||
type Crew struct {
|
||||
Adult bool `json:"adult"`
|
||||
Gender int `json:"gender"`
|
||||
ID int32 `json:"id"`
|
||||
KnownFor string `json:"known_for_department"`
|
||||
Name string `json:"name"`
|
||||
OriginalName string `json:"original_name"`
|
||||
Popularity int `json:"popularity"`
|
||||
Profile string `json:"profile_path"`
|
||||
CreditID string `json:"credit_id"`
|
||||
Department string `json:"department"`
|
||||
Job string `json:"job"`
|
||||
}
|
||||
|
||||
func GetCredits(movieid int32, token string) (*Credits, error) {
|
||||
url := fmt.Sprintf("https://api.themoviedb.org/3/movie/%v/credits?language=en-US", movieid)
|
||||
data, err := tmdbGet(url, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tmdbGet")
|
||||
}
|
||||
credits := Credits{}
|
||||
json.Unmarshal(data, &credits)
|
||||
return &credits, nil
|
||||
}
|
||||
41
pkg/tmdb/crew_functions.go
Normal file
41
pkg/tmdb/crew_functions.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package tmdb
|
||||
|
||||
import "sort"
|
||||
|
||||
type BilledCrew struct {
|
||||
Name string
|
||||
Roles []string
|
||||
}
|
||||
|
||||
func (credits *Credits) BilledCrew() []BilledCrew {
|
||||
crewmap := make(map[string][]string)
|
||||
billedcrew := []BilledCrew{}
|
||||
for _, crew := range credits.Crew {
|
||||
if crew.Job == "Director" ||
|
||||
crew.Job == "Screenplay" ||
|
||||
crew.Job == "Writer" ||
|
||||
crew.Job == "Novel" ||
|
||||
crew.Job == "Story" {
|
||||
crewmap[crew.Name] = append(crewmap[crew.Name], crew.Job)
|
||||
}
|
||||
}
|
||||
|
||||
for name, jobs := range crewmap {
|
||||
billedcrew = append(billedcrew, BilledCrew{Name: name, Roles: jobs})
|
||||
}
|
||||
for i := range billedcrew {
|
||||
sort.Strings(billedcrew[i].Roles)
|
||||
}
|
||||
sort.Slice(billedcrew, func(i, j int) bool {
|
||||
return billedcrew[i].Roles[0] < billedcrew[j].Roles[0]
|
||||
})
|
||||
return billedcrew
|
||||
}
|
||||
|
||||
func (billedcrew *BilledCrew) FRoles() string {
|
||||
jobs := ""
|
||||
for _, job := range billedcrew.Roles {
|
||||
jobs += job + ", "
|
||||
}
|
||||
return jobs[:len(jobs)-2]
|
||||
}
|
||||
45
pkg/tmdb/movie.go
Normal file
45
pkg/tmdb/movie.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package tmdb
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Movie struct {
|
||||
Adult bool `json:"adult"`
|
||||
Backdrop string `json:"backdrop_path"`
|
||||
Collection string `json:"belongs_to_collection"`
|
||||
Budget int `json:"budget"`
|
||||
Genres []Genre `json:"genres"`
|
||||
Homepage string `json:"homepage"`
|
||||
ID int32 `json:"id"`
|
||||
IMDbID string `json:"imdb_id"`
|
||||
OriginalLanguage string `json:"original_language"`
|
||||
OriginalTitle string `json:"original_title"`
|
||||
Overview string `json:"overview"`
|
||||
Popularity float32 `json:"popularity"`
|
||||
Poster string `json:"poster_path"`
|
||||
ProductionCompanies []ProductionCompany `json:"production_companies"`
|
||||
ProductionCountries []ProductionCountry `json:"production_countries"`
|
||||
ReleaseDate string `json:"release_date"`
|
||||
Revenue int `json:"revenue"`
|
||||
Runtime int `json:"runtime"`
|
||||
SpokenLanguages []SpokenLanguage `json:"spoken_languages"`
|
||||
Status string `json:"status"`
|
||||
Tagline string `json:"tagline"`
|
||||
Title string `json:"title"`
|
||||
Video bool `json:"video"`
|
||||
}
|
||||
|
||||
func GetMovie(id int32, token string) (*Movie, error) {
|
||||
url := fmt.Sprintf("https://api.themoviedb.org/3/movie/%v?language=en-US", id)
|
||||
data, err := tmdbGet(url, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tmdbGet")
|
||||
}
|
||||
movie := Movie{}
|
||||
json.Unmarshal(data, &movie)
|
||||
return &movie, nil
|
||||
}
|
||||
42
pkg/tmdb/movie_functions.go
Normal file
42
pkg/tmdb/movie_functions.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package tmdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
)
|
||||
|
||||
func (movie *Movie) FRuntime() string {
|
||||
hours := movie.Runtime / 60
|
||||
mins := movie.Runtime % 60
|
||||
return fmt.Sprintf("%dh %02dm", hours, mins)
|
||||
}
|
||||
|
||||
func (movie *Movie) GetPoster(image *Image, size string) string {
|
||||
base, err := url.Parse(image.SecureBaseURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
fullPath := path.Join(base.Path, size, movie.Poster)
|
||||
base.Path = fullPath
|
||||
return base.String()
|
||||
}
|
||||
|
||||
func (movie *Movie) ReleaseYear() string {
|
||||
if movie.ReleaseDate == "" {
|
||||
return ""
|
||||
} else {
|
||||
return "(" + movie.ReleaseDate[:4] + ")"
|
||||
}
|
||||
}
|
||||
|
||||
func (movie *Movie) FGenres() string {
|
||||
genres := ""
|
||||
for _, genre := range movie.Genres {
|
||||
genres += genre.Name + ", "
|
||||
}
|
||||
if len(genres) > 2 {
|
||||
return genres[:len(genres)-2]
|
||||
}
|
||||
return genres
|
||||
}
|
||||
28
pkg/tmdb/request.go
Normal file
28
pkg/tmdb/request.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package tmdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func tmdbGet(url string, token string) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "http.NewRequest")
|
||||
}
|
||||
req.Header.Add("accept", "application/json")
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "http.DefaultClient.Do")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "io.ReadAll")
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
79
pkg/tmdb/search.go
Normal file
79
pkg/tmdb/search.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package tmdb
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
Page int `json:"page"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
TotalResults int `json:"total_results"`
|
||||
}
|
||||
|
||||
type ResultMovies struct {
|
||||
Result
|
||||
Results []ResultMovie `json:"results"`
|
||||
}
|
||||
type ResultMovie struct {
|
||||
Adult bool `json:"adult"`
|
||||
BackdropPath string `json:"backdrop_path"`
|
||||
GenreIDs []int `json:"genre_ids"`
|
||||
ID int32 `json:"id"`
|
||||
OriginalLanguage string `json:"original_language"`
|
||||
OriginalTitle string `json:"original_title"`
|
||||
Overview string `json:"overview"`
|
||||
Popularity int `json:"popularity"`
|
||||
PosterPath string `json:"poster_path"`
|
||||
ReleaseDate string `json:"release_date"`
|
||||
Title string `json:"title"`
|
||||
Video bool `json:"video"`
|
||||
VoteAverage int `json:"vote_average"`
|
||||
VoteCount int `json:"vote_count"`
|
||||
}
|
||||
|
||||
func (movie *ResultMovie) GetPoster(image *Image, size string) string {
|
||||
base, err := url.Parse(image.SecureBaseURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
fullPath := path.Join(base.Path, size, movie.PosterPath)
|
||||
base.Path = fullPath
|
||||
return base.String()
|
||||
}
|
||||
|
||||
func (movie *ResultMovie) ReleaseYear() string {
|
||||
if movie.ReleaseDate == "" {
|
||||
return ""
|
||||
} else {
|
||||
return "(" + movie.ReleaseDate[:4] + ")"
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: genres list https://developer.themoviedb.org/reference/genre-movie-list
|
||||
// func (movie *ResultMovie) FGenres() string {
|
||||
// genres := ""
|
||||
// for _, genre := range movie.Genres {
|
||||
// genres += genre.Name + ", "
|
||||
// }
|
||||
// return genres[:len(genres)-2]
|
||||
// }
|
||||
|
||||
func SearchMovies(token string, query string, adult bool, page int) (*ResultMovies, error) {
|
||||
url := "https://api.themoviedb.org/3/search/movie" +
|
||||
fmt.Sprintf("?query=%s", url.QueryEscape(query)) +
|
||||
fmt.Sprintf("&include_adult=%t", adult) +
|
||||
fmt.Sprintf("&page=%v", page) +
|
||||
"&language=en-US"
|
||||
response, err := tmdbGet(url, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tmdbGet")
|
||||
}
|
||||
var results ResultMovies
|
||||
json.Unmarshal(response, &results)
|
||||
return &results, nil
|
||||
}
|
||||
24
pkg/tmdb/structs.go
Normal file
24
pkg/tmdb/structs.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package tmdb
|
||||
|
||||
type Genre struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type ProductionCompany struct {
|
||||
ID int `json:"id"`
|
||||
Logo string `json:"logo_path"`
|
||||
Name string `json:"name"`
|
||||
OriginCountry string `json:"origin_country"`
|
||||
}
|
||||
|
||||
type ProductionCountry struct {
|
||||
ISO_3166_1 string `json:"iso_3166_1"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type SpokenLanguage struct {
|
||||
EnglishName string `json:"english_name"`
|
||||
ISO_639_1 string `json:"iso_639_1"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
Reference in New Issue
Block a user