Compare commits
7 Commits
hlog/v0.9.
...
hws/v0.1.0
| Author | SHA1 | Date | |
|---|---|---|---|
| 14eec74683 | |||
| ade3fa0454 | |||
| 516be905a9 | |||
| 6e632267ea | |||
| 05aad5f11b | |||
| c4574e32c7 | |||
| c466cd3163 |
19
cookies/delete.go
Normal file
19
cookies/delete.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package cookies
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Tell the browser to delete the cookie matching the name provided
|
||||
// Path must match the original set cookie for it to delete
|
||||
func DeleteCookie(w http.ResponseWriter, name string, path string) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: path,
|
||||
Expires: time.Unix(0, 0), // Expire in the past
|
||||
MaxAge: -1, // Immediately expire
|
||||
HttpOnly: true,
|
||||
})
|
||||
}
|
||||
3
cookies/go.mod
Normal file
3
cookies/go.mod
Normal file
@@ -0,0 +1,3 @@
|
||||
module git.haelnorr.com/h/golib/cookies
|
||||
|
||||
go 1.25.5
|
||||
36
cookies/pagefrom.go
Normal file
36
cookies/pagefrom.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package cookies
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Check the value of "pagefrom" cookie, delete the cookie, and return the value
|
||||
func CheckPageFrom(w http.ResponseWriter, r *http.Request) string {
|
||||
pageFromCookie, err := r.Cookie("pagefrom")
|
||||
if err != nil {
|
||||
return "/"
|
||||
}
|
||||
pageFrom := pageFromCookie.Value
|
||||
DeleteCookie(w, pageFromCookie.Name, pageFromCookie.Path)
|
||||
return pageFrom
|
||||
}
|
||||
|
||||
// Check the referer of the request, and if it matches the trustedHost, set
|
||||
// the "pagefrom" cookie as the Path of the referer
|
||||
func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) {
|
||||
referer := r.Referer()
|
||||
parsedURL, err := url.Parse(referer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var pageFrom string
|
||||
if parsedURL.Path == "" || parsedURL.Host != trustedHost {
|
||||
pageFrom = "/"
|
||||
} else if parsedURL.Path == "/login" || parsedURL.Path == "/register" {
|
||||
return
|
||||
} else {
|
||||
pageFrom = parsedURL.Path
|
||||
}
|
||||
SetCookie(w, "pagefrom", "/", pageFrom, 0)
|
||||
}
|
||||
23
cookies/set.go
Normal file
23
cookies/set.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package cookies
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Set a cookie with the given name, path and value. maxAge directly relates
|
||||
// to cookie MaxAge (0 for no max age, >0 for TTL in seconds)
|
||||
func SetCookie(
|
||||
w http.ResponseWriter,
|
||||
name string,
|
||||
path string,
|
||||
value string,
|
||||
maxAge int,
|
||||
) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: path,
|
||||
HttpOnly: true,
|
||||
MaxAge: maxAge,
|
||||
})
|
||||
}
|
||||
35
env/boolean.go
vendored
Normal file
35
env/boolean.go
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Get an environment variable as a boolean, specifying a default value if its
|
||||
// not set or can't be parsed properly into a bool
|
||||
func Bool(key string, defaultValue bool) bool {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
truthy := map[string]bool{
|
||||
"true": true, "t": true, "yes": true, "y": true, "on": true, "1": true,
|
||||
"enable": true, "enabled": true, "active": true, "affirmative": true,
|
||||
}
|
||||
|
||||
falsy := map[string]bool{
|
||||
"false": false, "f": false, "no": false, "n": false, "off": false, "0": false,
|
||||
"disable": false, "disabled": false, "inactive": false, "negative": false,
|
||||
}
|
||||
|
||||
normalized := strings.TrimSpace(strings.ToLower(val))
|
||||
|
||||
if val, ok := truthy[normalized]; ok {
|
||||
return val
|
||||
}
|
||||
if val, ok := falsy[normalized]; ok {
|
||||
return val
|
||||
}
|
||||
|
||||
return defaultValue
|
||||
}
|
||||
23
env/duration.go
vendored
Normal file
23
env/duration.go
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Get an environment variable as a time.Duration, specifying a default value if its
|
||||
// not set or can't be parsed properly
|
||||
func Duration(key string, defaultValue time.Duration) time.Duration {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return time.Duration(defaultValue)
|
||||
}
|
||||
|
||||
intVal, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return time.Duration(defaultValue)
|
||||
}
|
||||
return time.Duration(intVal)
|
||||
|
||||
}
|
||||
3
env/go.mod
vendored
Normal file
3
env/go.mod
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
module git.haelnorr.com/h/golib/env
|
||||
|
||||
go 1.25.5
|
||||
37
env/int.go
vendored
Normal file
37
env/int.go
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Get an environment variable as an int, specifying a default value if its
|
||||
// not set or can't be parsed properly into an int
|
||||
func Int(key string, defaultValue int) int {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intVal, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return intVal
|
||||
}
|
||||
|
||||
// Get an environment variable as an int64, specifying a default value if its
|
||||
// not set or can't be parsed properly into an int64
|
||||
func Int64(key string, defaultValue int64) int64 {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intVal, err := strconv.ParseInt(val, 10, 64)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return intVal
|
||||
|
||||
}
|
||||
14
env/string.go
vendored
Normal file
14
env/string.go
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
package env
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// Get an environment variable, specifying a default value if its not set
|
||||
func String(key string, defaultValue string) string {
|
||||
val, exists := os.LookupEnv(key)
|
||||
if !exists {
|
||||
return defaultValue
|
||||
}
|
||||
return val
|
||||
}
|
||||
39
hws/errors.go
Normal file
39
hws/errors.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package hws
|
||||
|
||||
import "net/http"
|
||||
|
||||
type HWSError struct {
|
||||
statusCode int // HTTP Status code
|
||||
message string // Error message
|
||||
error error // Error
|
||||
}
|
||||
|
||||
type ErrorPage func(statusCode int, w http.ResponseWriter, r *http.Request) error
|
||||
|
||||
func NewError(statusCode int, msg string, err error) *HWSError {
|
||||
return &HWSError{
|
||||
statusCode: statusCode,
|
||||
message: msg,
|
||||
error: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) AddErrorPage(page ErrorPage) {
|
||||
server.errorPage = page
|
||||
}
|
||||
|
||||
func (server *Server) ThrowError(w http.ResponseWriter, r *http.Request, error *HWSError) {
|
||||
w.WriteHeader(error.statusCode)
|
||||
server.logger.logger.Error().Err(error.error).Msg(error.message)
|
||||
if server.errorPage != nil {
|
||||
err := server.errorPage(error.statusCode, w, r)
|
||||
if err != nil {
|
||||
server.logger.logger.Error().Err(err).Msg("Failed to render error page")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) ThrowWarn(w http.ResponseWriter, error *HWSError) {
|
||||
w.WriteHeader(error.statusCode)
|
||||
server.logger.logger.Warn().Err(error.error).Msg(error.message)
|
||||
}
|
||||
14
hws/go.mod
Normal file
14
hws/go.mod
Normal file
@@ -0,0 +1,14 @@
|
||||
module git.haelnorr.com/h/golib/hws
|
||||
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/rs/zerolog v1.34.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
golang.org/x/sys v0.12.0 // indirect
|
||||
)
|
||||
16
hws/go.sum
Normal file
16
hws/go.sum
Normal file
@@ -0,0 +1,16 @@
|
||||
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
|
||||
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
|
||||
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||
github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY=
|
||||
github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ=
|
||||
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
31
hws/gzip.go
Normal file
31
hws/gzip.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package hws
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func addgzip(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Encoding", "gzip")
|
||||
gz := gzip.NewWriter(w)
|
||||
defer gz.Close()
|
||||
gzw := gzipResponseWriter{Writer: gz, ResponseWriter: w}
|
||||
next.ServeHTTP(gzw, r)
|
||||
})
|
||||
}
|
||||
|
||||
type gzipResponseWriter struct {
|
||||
io.Writer
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func (w gzipResponseWriter) Write(b []byte) (int, error) {
|
||||
return w.Writer.Write(b)
|
||||
}
|
||||
44
hws/logger.go
Normal file
44
hws/logger.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package hws
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
type logger struct {
|
||||
logger *zerolog.Logger
|
||||
ignoredPaths []string
|
||||
}
|
||||
|
||||
// Server.AddLogger adds a logger to the server to use for request logging.
|
||||
func (server *Server) AddLogger(zlogger *zerolog.Logger) error {
|
||||
if zlogger == nil {
|
||||
return errors.New("Unable to add logger, no logger provided")
|
||||
}
|
||||
server.logger = &logger{
|
||||
logger: zlogger,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Server.LoggerIgnorePaths sets a list of URL paths to ignore logging for.
|
||||
// Path should match the url.URL.Path field, see https://pkg.go.dev/net/url#URL
|
||||
// Useful for ignoring requests to CSS files or favicons
|
||||
func (server *Server) LoggerIgnorePaths(paths ...string) error {
|
||||
for _, path := range paths {
|
||||
u, err := url.Parse(path)
|
||||
valid := err == nil &&
|
||||
u.Scheme == "" &&
|
||||
u.Host == "" &&
|
||||
u.RawQuery == "" &&
|
||||
u.Fragment == ""
|
||||
if !valid {
|
||||
return fmt.Errorf("Invalid path: '%s'", path)
|
||||
}
|
||||
}
|
||||
server.logger.ignoredPaths = paths
|
||||
return nil
|
||||
}
|
||||
51
hws/middleware.go
Normal file
51
hws/middleware.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package hws
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Middleware func(h http.Handler) http.Handler
|
||||
type MiddlewareFunc func(w http.ResponseWriter, r *http.Request) (*http.Request, *HWSError)
|
||||
|
||||
// Server.AddMiddleware registers all the middleware.
|
||||
// Middleware will be run in the order that they are provided.
|
||||
func (server *Server) AddMiddleware(middleware ...Middleware) error {
|
||||
if !server.routes {
|
||||
return errors.New("Server.AddRoutes must be called before Server.AddMiddleware")
|
||||
}
|
||||
|
||||
// RUN LOGGING MIDDLEWARE FIRST
|
||||
server.server.Handler = logging(server.server.Handler, server.logger)
|
||||
|
||||
// LOOP PROVIDED MIDDLEWARE IN REVERSE order
|
||||
for i := len(middleware); i > 0; i-- {
|
||||
server.server.Handler = middleware[i-1](server.server.Handler)
|
||||
}
|
||||
|
||||
// RUN GZIP
|
||||
if server.gzip {
|
||||
server.server.Handler = addgzip(server.server.Handler)
|
||||
}
|
||||
// RUN TIMER MIDDLEWARE LAST
|
||||
server.server.Handler = startTimer(server.server.Handler)
|
||||
|
||||
server.middleware = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (server *Server) NewMiddleware(
|
||||
middlewareFunc MiddlewareFunc,
|
||||
) Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
newReq, herr := middlewareFunc(w, r)
|
||||
if herr != nil {
|
||||
server.ThrowError(w, r, herr)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, newReq)
|
||||
})
|
||||
}
|
||||
}
|
||||
38
hws/middleware_logging.go
Normal file
38
hws/middleware_logging.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package hws
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Middleware to add logs to console with details of the request
|
||||
func logging(next http.Handler, logger *logger) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if logger == nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if slices.Contains(logger.ignoredPaths, r.URL.Path) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
start, err := getStartTime(r.Context())
|
||||
if err != nil {
|
||||
logger.logger.Error().Err(err)
|
||||
return
|
||||
}
|
||||
wrapped := &wrappedWriter{
|
||||
ResponseWriter: w,
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
next.ServeHTTP(wrapped, r)
|
||||
logger.logger.Info().
|
||||
Int("status", wrapped.statusCode).
|
||||
Str("method", r.Method).
|
||||
Str("resource", r.URL.Path).
|
||||
Dur("time_elapsed", time.Since(start)).
|
||||
Str("remote_addr", r.Header.Get("X-Forwarded-For")).
|
||||
Msg("Served")
|
||||
})
|
||||
}
|
||||
33
hws/middleware_timer.go
Normal file
33
hws/middleware_timer.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package hws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func startTimer(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
ctx := setStart(r.Context(), start)
|
||||
newReq := r.WithContext(ctx)
|
||||
next.ServeHTTP(w, newReq)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// Set the start time of the request
|
||||
func setStart(ctx context.Context, time time.Time) context.Context {
|
||||
return context.WithValue(ctx, "hws context key request-timer", time)
|
||||
}
|
||||
|
||||
// Get the start time of the request
|
||||
func getStartTime(ctx context.Context) (time.Time, error) {
|
||||
start, ok := ctx.Value("hws context key request-timer").(time.Time)
|
||||
if !ok {
|
||||
return time.Time{}, errors.New("Failed to get start time of request")
|
||||
}
|
||||
return start, nil
|
||||
}
|
||||
15
hws/responsewriter.go
Normal file
15
hws/responsewriter.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package hws
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Wraps the http.ResponseWriter, adding a statusCode field
|
||||
type wrappedWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
// Extends WriteHeader to the ResponseWriter to add the status code
|
||||
func (w *wrappedWriter) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
w.statusCode = statusCode
|
||||
}
|
||||
62
hws/routes.go
Normal file
62
hws/routes.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package hws
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Route struct {
|
||||
Path string // Absolute path to the requested resource
|
||||
Method Method // HTTP Method
|
||||
Handler http.Handler // Handler to use for the request
|
||||
}
|
||||
|
||||
type Method string
|
||||
|
||||
const (
|
||||
MethodGET Method = "GET"
|
||||
MethodPOST Method = "POST"
|
||||
MethodPUT Method = "PUT"
|
||||
MethodHEAD Method = "HEAD"
|
||||
MethodDELETE Method = "DELETE"
|
||||
MethodCONNECT Method = "CONNECT"
|
||||
MethodOPTIONS Method = "OPTIONS"
|
||||
MethodTRACE Method = "TRACE"
|
||||
MethodPATCH Method = "PATCH"
|
||||
)
|
||||
|
||||
// Server.AddRoutes registers the page handlers for the server.
|
||||
// At least one route must be provided.
|
||||
func (server *Server) AddRoutes(routes ...Route) error {
|
||||
if len(routes) == 0 {
|
||||
return errors.New("No routes provided")
|
||||
}
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /healthz", func(http.ResponseWriter, *http.Request) {})
|
||||
for _, route := range routes {
|
||||
if !validMethod(route.Method) {
|
||||
return fmt.Errorf("Invalid method %s for path %s", route.Method, route.Path)
|
||||
}
|
||||
if route.Handler == nil {
|
||||
return fmt.Errorf("No handler provided for %s %s", route.Method, route.Path)
|
||||
}
|
||||
pattern := fmt.Sprintf("%s %s", route.Method, route.Path)
|
||||
mux.Handle(pattern, route.Handler)
|
||||
}
|
||||
|
||||
server.server.Handler = mux
|
||||
server.routes = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validMethod(m Method) bool {
|
||||
switch m {
|
||||
case MethodGET, MethodPOST, MethodPUT, MethodHEAD,
|
||||
MethodDELETE, MethodCONNECT, MethodOPTIONS, MethodTRACE, MethodPATCH:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
52
hws/safefileserver.go
Normal file
52
hws/safefileserver.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package hws
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Wrapper for default FileSystem
|
||||
type justFilesFilesystem struct {
|
||||
fs http.FileSystem
|
||||
}
|
||||
|
||||
// Wrapper for default File
|
||||
type neuteredReaddirFile struct {
|
||||
http.File
|
||||
}
|
||||
|
||||
// Modifies the behavior of FileSystem.Open to return the neutered version of File
|
||||
func (fs justFilesFilesystem) Open(name string) (http.File, error) {
|
||||
f, err := fs.fs.Open(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the requested path is a directory
|
||||
// and explicitly return an error to trigger a 404
|
||||
fileInfo, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fileInfo.IsDir() {
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
|
||||
return neuteredReaddirFile{f}, nil
|
||||
}
|
||||
|
||||
// Overrides the Readdir method of File to always return nil
|
||||
func (f neuteredReaddirFile) Readdir(count int) ([]os.FileInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func SafeFileServer(fileSystem *http.FileSystem) (http.Handler, error) {
|
||||
if fileSystem == nil {
|
||||
return nil, errors.New("No file system provided")
|
||||
}
|
||||
nfs := justFilesFilesystem{*fileSystem}
|
||||
fs := http.FileServer(nfs)
|
||||
return fs, nil
|
||||
}
|
||||
84
hws/server.go
Normal file
84
hws/server.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package hws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
server *http.Server
|
||||
logger *logger
|
||||
routes bool
|
||||
middleware bool
|
||||
gzip bool
|
||||
errorPage ErrorPage
|
||||
}
|
||||
|
||||
// NewServer returns a new hws.Server with the specified parameters.
|
||||
// The timeout options are specified in seconds
|
||||
func NewServer(
|
||||
host string,
|
||||
port string,
|
||||
readHeaderTimeout time.Duration,
|
||||
writeTimeout time.Duration,
|
||||
idleTimeout time.Duration,
|
||||
gzip bool,
|
||||
) (*Server, error) {
|
||||
// TODO: test that host and port are valid values
|
||||
httpServer := &http.Server{
|
||||
Addr: net.JoinHostPort(host, port),
|
||||
ReadHeaderTimeout: readHeaderTimeout * time.Second,
|
||||
WriteTimeout: writeTimeout * time.Second,
|
||||
IdleTimeout: idleTimeout * time.Second,
|
||||
}
|
||||
server := &Server{
|
||||
server: httpServer,
|
||||
routes: false,
|
||||
gzip: gzip,
|
||||
}
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (server *Server) Start() error {
|
||||
if !server.routes {
|
||||
return errors.New("Server.AddRoutes must be run before starting the server")
|
||||
}
|
||||
if !server.middleware {
|
||||
err := server.AddMiddleware()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "server.AddMiddleware")
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
if server.logger == nil {
|
||||
fmt.Printf("Listening for requests on %s", server.server.Addr)
|
||||
} else {
|
||||
server.logger.logger.Info().Str("address", server.server.Addr).Msg("Listening for requests")
|
||||
}
|
||||
if err := server.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
if server.logger == nil {
|
||||
fmt.Printf("Server encountered a fatal error: %s", err.Error())
|
||||
} else {
|
||||
server.logger.logger.Error().Err(err).Msg("Server encountered a fatal error")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (server *Server) Shutdown(ctx context.Context) {
|
||||
if err := server.server.Shutdown(ctx); err != nil {
|
||||
if server.logger == nil {
|
||||
fmt.Printf("Failed to gracefully shutdown the server: %s", err.Error())
|
||||
} else {
|
||||
server.logger.logger.Error().Err(err).Msg("Failed to gracefully shutdown the server")
|
||||
}
|
||||
}
|
||||
}
|
||||
73
jwt/cookies.go
Normal file
73
jwt/cookies.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"github.com/pkg/errors"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Get the value of the access and refresh tokens
|
||||
func GetTokenCookies(
|
||||
r *http.Request,
|
||||
) (acc string, ref string) {
|
||||
accCookie, accErr := r.Cookie("access")
|
||||
refCookie, refErr := r.Cookie("refresh")
|
||||
var (
|
||||
accStr string = ""
|
||||
refStr string = ""
|
||||
)
|
||||
if accErr == nil {
|
||||
accStr = accCookie.Value
|
||||
}
|
||||
if refErr == nil {
|
||||
refStr = refCookie.Value
|
||||
}
|
||||
return accStr, refStr
|
||||
}
|
||||
|
||||
// Set a token with the provided details
|
||||
func setToken(
|
||||
w http.ResponseWriter,
|
||||
token string,
|
||||
scope string,
|
||||
exp int64,
|
||||
rememberme bool,
|
||||
useSSL bool,
|
||||
) {
|
||||
tokenCookie := &http.Cookie{
|
||||
Name: scope,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Secure: useSSL,
|
||||
}
|
||||
if rememberme {
|
||||
tokenCookie.Expires = time.Unix(exp, 0)
|
||||
}
|
||||
http.SetCookie(w, tokenCookie)
|
||||
}
|
||||
|
||||
// Generate new tokens for the subject and set them as cookies
|
||||
func SetTokenCookies(
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
tokenGen *TokenGenerator,
|
||||
subject int,
|
||||
fresh bool,
|
||||
rememberMe bool,
|
||||
useSSL bool,
|
||||
) error {
|
||||
at, atexp, err := tokenGen.NewAccess(subject, fresh, rememberMe)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.GenerateAccessToken")
|
||||
}
|
||||
rt, rtexp, err := tokenGen.NewRefresh(subject, rememberMe)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "jwt.GenerateRefreshToken")
|
||||
}
|
||||
// Don't set the cookies until we know no errors occured
|
||||
setToken(w, at, "access", atexp, rememberMe, useSSL)
|
||||
setToken(w, rt, "refresh", rtexp, rememberMe, useSSL)
|
||||
return nil
|
||||
}
|
||||
62
jwt/generator.go
Normal file
62
jwt/generator.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type TokenGenerator struct {
|
||||
accessExpireAfter int64 // Access Token expiry time in minutes
|
||||
refreshExpireAfter int64 // Refresh Token expiry time in minutes
|
||||
freshExpireAfter int64 // Token freshness expiry time in minutes
|
||||
trustedHost string // Trusted hostname to use for the tokens
|
||||
secretKey string // Secret key to use for token hashing
|
||||
dbConn *sql.DB // Database handle for token blacklisting
|
||||
}
|
||||
|
||||
// CreateGenerator creates and returns a new TokenGenerator using the provided configuration.
|
||||
// All expiry times should be provided in minutes.
|
||||
// trustedHost and secretKey strings must be provided.
|
||||
// dbConn can be nil, but doing this will disable token revocation
|
||||
func CreateGenerator(
|
||||
accessExpireAfter int64,
|
||||
refreshExpireAfter int64,
|
||||
freshExpireAfter int64,
|
||||
trustedHost string,
|
||||
secretKey string,
|
||||
dbConn *sql.DB,
|
||||
) (gen *TokenGenerator, err error) {
|
||||
if accessExpireAfter <= 0 {
|
||||
return nil, errors.New("accessExpireAfter must be greater than 0")
|
||||
}
|
||||
if refreshExpireAfter <= 0 {
|
||||
return nil, errors.New("refreshExpireAfter must be greater than 0")
|
||||
}
|
||||
if freshExpireAfter <= 0 {
|
||||
return nil, errors.New("freshExpireAfter must be greater than 0")
|
||||
}
|
||||
if trustedHost == "" {
|
||||
return nil, errors.New("trustedHost cannot be an empty string")
|
||||
}
|
||||
if secretKey == "" {
|
||||
return nil, errors.New("secretKey cannot be an empty string")
|
||||
}
|
||||
|
||||
if dbConn != nil {
|
||||
err := dbConn.Ping()
|
||||
if err != nil {
|
||||
return nil, errors.New("Failed to ping database")
|
||||
}
|
||||
// TODO: check if jwtblacklist table exists
|
||||
// TODO: create jwtblacklist table if not existing
|
||||
}
|
||||
|
||||
return &TokenGenerator{
|
||||
accessExpireAfter: accessExpireAfter,
|
||||
refreshExpireAfter: refreshExpireAfter,
|
||||
freshExpireAfter: freshExpireAfter,
|
||||
trustedHost: trustedHost,
|
||||
secretKey: secretKey,
|
||||
dbConn: dbConn,
|
||||
}, nil
|
||||
}
|
||||
90
jwt/generator_test.go
Normal file
90
jwt/generator_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateGenerator_Success_NoDB(t *testing.T) {
|
||||
gen, err := CreateGenerator(
|
||||
15,
|
||||
60,
|
||||
5,
|
||||
"example.com",
|
||||
"secret",
|
||||
nil,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gen)
|
||||
}
|
||||
|
||||
func TestCreateGenerator_Success_WithDB(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
gen, err := CreateGenerator(
|
||||
15,
|
||||
60,
|
||||
5,
|
||||
"example.com",
|
||||
"secret",
|
||||
db,
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gen)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestCreateGenerator_InvalidInputs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func() error
|
||||
}{
|
||||
{
|
||||
"access expiry <= 0",
|
||||
func() error {
|
||||
_, err := CreateGenerator(0, 1, 1, "h", "s", nil)
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
"refresh expiry <= 0",
|
||||
func() error {
|
||||
_, err := CreateGenerator(1, 0, 1, "h", "s", nil)
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
"fresh expiry <= 0",
|
||||
func() error {
|
||||
_, err := CreateGenerator(1, 1, 0, "h", "s", nil)
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
"empty trustedHost",
|
||||
func() error {
|
||||
_, err := CreateGenerator(1, 1, 1, "", "s", nil)
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
"empty secretKey",
|
||||
func() error {
|
||||
_, err := CreateGenerator(1, 1, 1, "h", "", nil)
|
||||
return err
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Error(t, tt.fn())
|
||||
})
|
||||
}
|
||||
}
|
||||
17
jwt/go.mod
Normal file
17
jwt/go.mod
Normal file
@@ -0,0 +1,17 @@
|
||||
module git.haelnorr.com/h/golib/jwt
|
||||
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/stretchr/testify v1.11.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
19
jwt/go.sum
Normal file
19
jwt/go.sum
Normal file
@@ -0,0 +1,19 @@
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
38
jwt/revoke.go
Normal file
38
jwt/revoke.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Revoke a token by adding it to the database
|
||||
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
||||
if gen.dbConn == nil {
|
||||
return errors.New("No DB provided, unable to use this function")
|
||||
}
|
||||
jti := t.GetJTI()
|
||||
exp := t.GetEXP()
|
||||
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
||||
_, err := tx.Exec(query, jti, exp)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "tx.Exec")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if a token has been revoked. Returns true if not revoked.
|
||||
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
|
||||
if gen.dbConn == nil {
|
||||
return false, errors.New("No DB provided, unable to use this function")
|
||||
}
|
||||
jti := t.GetJTI()
|
||||
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
||||
rows, err := tx.Query(query, jti)
|
||||
if err != nil {
|
||||
return false, errors.Wrap(err, "tx.Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
revoked := rows.Next()
|
||||
return !revoked, nil
|
||||
}
|
||||
83
jwt/revoke_test.go
Normal file
83
jwt/revoke_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newGeneratorWithNoDB(t *testing.T) *TokenGenerator {
|
||||
gen, err := CreateGenerator(
|
||||
15,
|
||||
60,
|
||||
5,
|
||||
"example.com",
|
||||
"supersecret",
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return gen
|
||||
}
|
||||
|
||||
func TestNoDBFail(t *testing.T) {
|
||||
jti := uuid.New()
|
||||
exp := time.Now().Add(time.Hour).Unix()
|
||||
|
||||
token := AccessToken{
|
||||
JTI: jti,
|
||||
EXP: exp,
|
||||
gen: &TokenGenerator{},
|
||||
}
|
||||
|
||||
// Revoke should fail due to no DB
|
||||
err := token.Revoke(&sql.Tx{})
|
||||
require.Error(t, err)
|
||||
|
||||
// CheckNotRevoked should fail
|
||||
_, err = token.CheckNotRevoked(&sql.Tx{})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestRevokeAndCheckNotRevoked(t *testing.T) {
|
||||
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
||||
defer cleanup()
|
||||
|
||||
jti := uuid.New()
|
||||
exp := time.Now().Add(time.Hour).Unix()
|
||||
|
||||
token := AccessToken{
|
||||
JTI: jti,
|
||||
EXP: exp,
|
||||
gen: gen,
|
||||
}
|
||||
|
||||
// Revoke expectations
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec(`INSERT INTO jwtblacklist`).
|
||||
WithArgs(jti, exp).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
|
||||
WithArgs(jti).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||
defer tx.Rollback()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = token.Revoke(tx)
|
||||
require.NoError(t, err)
|
||||
|
||||
valid, err := token.CheckNotRevoked(tx)
|
||||
require.NoError(t, err)
|
||||
require.False(t, valid)
|
||||
|
||||
require.NoError(t, tx.Commit())
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
79
jwt/tokengen.go
Normal file
79
jwt/tokengen.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Generates an access token for the provided subject
|
||||
func (gen *TokenGenerator) NewAccess(
|
||||
subjectID int,
|
||||
fresh bool,
|
||||
rememberMe bool,
|
||||
) (tokenString string, expiresIn int64, err error) {
|
||||
issuedAt := time.Now().Unix()
|
||||
expiresAt := issuedAt + (gen.accessExpireAfter * 60)
|
||||
var freshExpiresAt int64
|
||||
if fresh {
|
||||
freshExpiresAt = issuedAt + (gen.freshExpireAfter * 60)
|
||||
} else {
|
||||
freshExpiresAt = issuedAt
|
||||
}
|
||||
var ttl string
|
||||
if rememberMe {
|
||||
ttl = "exp"
|
||||
} else {
|
||||
ttl = "session"
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||
jwt.MapClaims{
|
||||
"iss": gen.trustedHost,
|
||||
"scope": "access",
|
||||
"ttl": ttl,
|
||||
"jti": uuid.New(),
|
||||
"iat": issuedAt,
|
||||
"exp": expiresAt,
|
||||
"fresh": freshExpiresAt,
|
||||
"sub": subjectID,
|
||||
})
|
||||
|
||||
signedToken, err := token.SignedString([]byte(gen.secretKey))
|
||||
if err != nil {
|
||||
return "", 0, errors.Wrap(err, "token.SignedString")
|
||||
}
|
||||
return signedToken, expiresAt, nil
|
||||
}
|
||||
|
||||
// Generates a refresh token for the provided user
|
||||
func (gen *TokenGenerator) NewRefresh(
|
||||
subjectID int,
|
||||
rememberMe bool,
|
||||
) (tokenStr string, exp int64, err error) {
|
||||
issuedAt := time.Now().Unix()
|
||||
expiresAt := issuedAt + (gen.refreshExpireAfter * 60)
|
||||
var ttl string
|
||||
if rememberMe {
|
||||
ttl = "exp"
|
||||
} else {
|
||||
ttl = "session"
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||
jwt.MapClaims{
|
||||
"iss": gen.trustedHost,
|
||||
"scope": "refresh",
|
||||
"ttl": ttl,
|
||||
"jti": uuid.New(),
|
||||
"iat": issuedAt,
|
||||
"exp": expiresAt,
|
||||
"sub": subjectID,
|
||||
})
|
||||
|
||||
signedToken, err := token.SignedString([]byte(gen.secretKey))
|
||||
if err != nil {
|
||||
return "", 0, errors.Wrap(err, "token.SignedString")
|
||||
}
|
||||
return signedToken, expiresAt, nil
|
||||
}
|
||||
38
jwt/tokengen_test.go
Normal file
38
jwt/tokengen_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestGenerator(t *testing.T) *TokenGenerator {
|
||||
gen, err := CreateGenerator(
|
||||
15,
|
||||
60,
|
||||
5,
|
||||
"example.com",
|
||||
"supersecret",
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
return gen
|
||||
}
|
||||
|
||||
func TestNewAccessToken(t *testing.T) {
|
||||
gen := newTestGenerator(t)
|
||||
|
||||
tokenStr, exp, err := gen.NewAccess(123, true, false)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, tokenStr)
|
||||
require.Greater(t, exp, int64(0))
|
||||
}
|
||||
|
||||
func TestNewRefreshToken(t *testing.T) {
|
||||
gen := newTestGenerator(t)
|
||||
|
||||
tokenStr, exp, err := gen.NewRefresh(123, true)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, tokenStr)
|
||||
require.Greater(t, exp, int64(0))
|
||||
}
|
||||
71
jwt/tokens.go
Normal file
71
jwt/tokens.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Token interface {
|
||||
GetJTI() uuid.UUID
|
||||
GetEXP() int64
|
||||
GetScope() string
|
||||
Revoke(*sql.Tx) error
|
||||
CheckNotRevoked(*sql.Tx) (bool, error)
|
||||
}
|
||||
|
||||
// Access token
|
||||
type AccessToken struct {
|
||||
ISS string // Issuer, generally TrustedHost
|
||||
IAT int64 // Time issued at
|
||||
EXP int64 // Time expiring at
|
||||
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
||||
SUB int // Subject (user) ID
|
||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||
Fresh int64 // Time freshness expiring at
|
||||
Scope string // Should be "access"
|
||||
gen *TokenGenerator
|
||||
}
|
||||
|
||||
// Refresh token
|
||||
type RefreshToken struct {
|
||||
ISS string // Issuer, generally TrustedHost
|
||||
IAT int64 // Time issued at
|
||||
EXP int64 // Time expiring at
|
||||
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
||||
SUB int // Subject (user) ID
|
||||
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||
Scope string // Should be "refresh"
|
||||
gen *TokenGenerator
|
||||
}
|
||||
|
||||
func (a AccessToken) GetJTI() uuid.UUID {
|
||||
return a.JTI
|
||||
}
|
||||
func (r RefreshToken) GetJTI() uuid.UUID {
|
||||
return r.JTI
|
||||
}
|
||||
func (a AccessToken) GetEXP() int64 {
|
||||
return a.EXP
|
||||
}
|
||||
func (r RefreshToken) GetEXP() int64 {
|
||||
return r.EXP
|
||||
}
|
||||
func (a AccessToken) GetScope() string {
|
||||
return a.Scope
|
||||
}
|
||||
func (r RefreshToken) GetScope() string {
|
||||
return r.Scope
|
||||
}
|
||||
func (a AccessToken) Revoke(tx *sql.Tx) error {
|
||||
return a.gen.revoke(tx, a)
|
||||
}
|
||||
func (r RefreshToken) Revoke(tx *sql.Tx) error {
|
||||
return r.gen.revoke(tx, r)
|
||||
}
|
||||
func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
||||
return a.gen.checkNotRevoked(tx, a)
|
||||
}
|
||||
func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
||||
return r.gen.checkNotRevoked(tx, r)
|
||||
}
|
||||
123
jwt/util.go
Normal file
123
jwt/util.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Parse a token, validating its signing sigature and returning the claims
|
||||
func parseToken(secretKey string, tokenString string) (jwt.MapClaims, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
|
||||
return []byte(secretKey), nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "jwt.Parse")
|
||||
}
|
||||
// Token decoded, parse the claims
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, errors.New("Failed to parse claims")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Check if a token is expired. Returns the expiry if not expired
|
||||
func checkTokenExpired(expiry interface{}) (int64, error) {
|
||||
// Coerce the expiry to a float64 to avoid scientific notation
|
||||
expFloat, ok := expiry.(float64)
|
||||
if !ok {
|
||||
return 0, errors.New("Missing or invalid 'exp' claim")
|
||||
}
|
||||
// Convert to the int64 time we expect :)
|
||||
expiryTime := int64(expFloat)
|
||||
|
||||
// Check if its expired
|
||||
isExpired := time.Now().After(time.Unix(expiryTime, 0))
|
||||
if isExpired {
|
||||
return 0, errors.New("Token has expired")
|
||||
}
|
||||
return expiryTime, nil
|
||||
}
|
||||
|
||||
// Check if a token has a valid issuer. Returns the issuer if valid
|
||||
func checkTokenIssuer(trustedHost string, issuer interface{}) (string, error) {
|
||||
issuerVal, ok := issuer.(string)
|
||||
if !ok {
|
||||
return "", errors.New("Missing or invalid 'iss' claim")
|
||||
}
|
||||
if issuer != trustedHost {
|
||||
return "", errors.New("Issuer does not matched trusted host")
|
||||
}
|
||||
return issuerVal, nil
|
||||
}
|
||||
|
||||
// Check the scope matches the expected scope. Returns scope if true
|
||||
func getTokenScope(scope interface{}) (string, error) {
|
||||
scopeStr, ok := scope.(string)
|
||||
if !ok {
|
||||
return "", errors.New("Missing or invalid 'scope' claim")
|
||||
}
|
||||
return scopeStr, nil
|
||||
}
|
||||
|
||||
// Get the TTL of the token, either "session" or "exp"
|
||||
func getTokenTTL(ttl interface{}) (string, error) {
|
||||
ttlStr, ok := ttl.(string)
|
||||
if !ok {
|
||||
return "", errors.New("Missing or invalid 'ttl' claim")
|
||||
}
|
||||
if ttlStr != "exp" && ttlStr != "session" {
|
||||
return "", errors.New("TTL value is not recognised")
|
||||
}
|
||||
return ttlStr, nil
|
||||
}
|
||||
|
||||
// Get the time the token was issued at
|
||||
func getIssuedTime(issued interface{}) (int64, error) {
|
||||
// Same float64 -> int64 trick as expiry
|
||||
issuedFloat, ok := issued.(float64)
|
||||
if !ok {
|
||||
return 0, errors.New("Missing or invalid 'iat' claim")
|
||||
}
|
||||
issuedAt := int64(issuedFloat)
|
||||
return issuedAt, nil
|
||||
}
|
||||
|
||||
// Get the freshness expiry timestamp
|
||||
func getFreshTime(fresh interface{}) (int64, error) {
|
||||
freshUntil, ok := fresh.(float64)
|
||||
if !ok {
|
||||
return 0, errors.New("Missing or invalid 'fresh' claim")
|
||||
}
|
||||
return int64(freshUntil), nil
|
||||
}
|
||||
|
||||
// Get the subject of the token
|
||||
func getTokenSubject(sub interface{}) (int, error) {
|
||||
subject, ok := sub.(float64)
|
||||
if !ok {
|
||||
return 0, errors.New("Missing or invalid 'sub' claim")
|
||||
}
|
||||
return int(subject), nil
|
||||
}
|
||||
|
||||
// Get the JTI of the token
|
||||
func getTokenJTI(jti interface{}) (uuid.UUID, error) {
|
||||
jtiStr, ok := jti.(string)
|
||||
if !ok {
|
||||
return uuid.UUID{}, errors.New("Missing or invalid 'jti' claim")
|
||||
}
|
||||
jtiUUID, err := uuid.Parse(jtiStr)
|
||||
if err != nil {
|
||||
return uuid.UUID{}, errors.New("JTI is not a valid UUID")
|
||||
}
|
||||
return jtiUUID, nil
|
||||
}
|
||||
146
jwt/validate.go
Normal file
146
jwt/validate.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Parse an access token and return a struct with all the claims. Does validation on
|
||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||
// has the correct scope.
|
||||
func (gen *TokenGenerator) ValidateAccess(
|
||||
tx *sql.Tx,
|
||||
tokenString string,
|
||||
) (*AccessToken, error) {
|
||||
if tokenString == "" {
|
||||
return nil, errors.New("Access token string not provided")
|
||||
}
|
||||
claims, err := parseToken(gen.secretKey, tokenString)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parseToken")
|
||||
}
|
||||
expiry, err := checkTokenExpired(claims["exp"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkTokenExpired")
|
||||
}
|
||||
issuer, err := checkTokenIssuer(gen.trustedHost, claims["iss"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkTokenIssuer")
|
||||
}
|
||||
ttl, err := getTokenTTL(claims["ttl"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenTTL")
|
||||
}
|
||||
scope, err := getTokenScope(claims["scope"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenScope")
|
||||
}
|
||||
if scope != "access" {
|
||||
return nil, errors.New("Token is not an Access token")
|
||||
}
|
||||
issuedAt, err := getIssuedTime(claims["iat"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getIssuedTime")
|
||||
}
|
||||
subject, err := getTokenSubject(claims["sub"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenSubject")
|
||||
}
|
||||
fresh, err := getFreshTime(claims["fresh"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getFreshTime")
|
||||
}
|
||||
jti, err := getTokenJTI(claims["jti"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenJTI")
|
||||
}
|
||||
|
||||
token := &AccessToken{
|
||||
ISS: issuer,
|
||||
TTL: ttl,
|
||||
EXP: expiry,
|
||||
IAT: issuedAt,
|
||||
SUB: subject,
|
||||
Fresh: fresh,
|
||||
JTI: jti,
|
||||
Scope: scope,
|
||||
gen: gen,
|
||||
}
|
||||
|
||||
valid, err := token.CheckNotRevoked(tx)
|
||||
if err != nil && gen.dbConn != nil {
|
||||
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||
}
|
||||
if !valid && gen.dbConn != nil {
|
||||
return nil, errors.New("Token has been revoked")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Parse a refresh token and return a struct with all the claims. Does validation on
|
||||
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||
// has the correct scope.
|
||||
func (gen *TokenGenerator) ValidateRefresh(
|
||||
tx *sql.Tx,
|
||||
tokenString string,
|
||||
) (*RefreshToken, error) {
|
||||
if tokenString == "" {
|
||||
return nil, errors.New("Refresh token string not provided")
|
||||
}
|
||||
claims, err := parseToken(gen.secretKey, tokenString)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parseToken")
|
||||
}
|
||||
expiry, err := checkTokenExpired(claims["exp"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkTokenExpired")
|
||||
}
|
||||
issuer, err := checkTokenIssuer(gen.trustedHost, claims["iss"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "checkTokenIssuer")
|
||||
}
|
||||
ttl, err := getTokenTTL(claims["ttl"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenTTL")
|
||||
}
|
||||
scope, err := getTokenScope(claims["scope"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenScope")
|
||||
}
|
||||
if scope != "refresh" {
|
||||
return nil, errors.New("Token is not an Refresh token")
|
||||
}
|
||||
issuedAt, err := getIssuedTime(claims["iat"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getIssuedTime")
|
||||
}
|
||||
subject, err := getTokenSubject(claims["sub"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenSubject")
|
||||
}
|
||||
jti, err := getTokenJTI(claims["jti"])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getTokenJTI")
|
||||
}
|
||||
|
||||
token := &RefreshToken{
|
||||
ISS: issuer,
|
||||
TTL: ttl,
|
||||
EXP: expiry,
|
||||
IAT: issuedAt,
|
||||
SUB: subject,
|
||||
JTI: jti,
|
||||
Scope: scope,
|
||||
gen: gen,
|
||||
}
|
||||
|
||||
valid, err := token.CheckNotRevoked(tx)
|
||||
if err != nil && gen.dbConn != nil {
|
||||
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||
}
|
||||
if !valid && gen.dbConn != nil {
|
||||
return nil, errors.New("Token has been revoked")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
118
jwt/validate_test.go
Normal file
118
jwt/validate_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
gen, err := CreateGenerator(
|
||||
15,
|
||||
60,
|
||||
5,
|
||||
"example.com",
|
||||
"supersecret",
|
||||
db,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
return gen, mock, func() { db.Close() }
|
||||
}
|
||||
|
||||
func expectNotRevoked(mock sqlmock.Sqlmock, jti any) {
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
|
||||
WithArgs(jti).
|
||||
WillReturnRows(sqlmock.NewRows([]string{}))
|
||||
mock.ExpectCommit()
|
||||
}
|
||||
|
||||
func TestValidateAccess_Success(t *testing.T) {
|
||||
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
||||
defer cleanup()
|
||||
|
||||
tokenStr, _, err := gen.NewAccess(42, true, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We don't know the JTI beforehand; match any arg
|
||||
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||
|
||||
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
token, err := gen.ValidateAccess(tx, tokenStr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 42, token.SUB)
|
||||
require.Equal(t, "access", token.Scope)
|
||||
tx.Commit()
|
||||
}
|
||||
|
||||
func TestValidateAccess_NoDB(t *testing.T) {
|
||||
gen := newGeneratorWithNoDB(t)
|
||||
|
||||
tokenStr, _, err := gen.NewAccess(42, true, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := gen.ValidateAccess(&sql.Tx{}, tokenStr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 42, token.SUB)
|
||||
require.Equal(t, "access", token.Scope)
|
||||
}
|
||||
|
||||
func TestValidateRefresh_Success(t *testing.T) {
|
||||
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
||||
defer cleanup()
|
||||
|
||||
tokenStr, _, err := gen.NewRefresh(42, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||
|
||||
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
token, err := gen.ValidateRefresh(tx, tokenStr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 42, token.SUB)
|
||||
require.Equal(t, "refresh", token.Scope)
|
||||
tx.Commit()
|
||||
}
|
||||
|
||||
func TestValidateRefresh_NoDB(t *testing.T) {
|
||||
gen := newGeneratorWithNoDB(t)
|
||||
|
||||
tokenStr, _, err := gen.NewRefresh(42, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, err := gen.ValidateRefresh(nil, tokenStr)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 42, token.SUB)
|
||||
require.Equal(t, "refresh", token.Scope)
|
||||
}
|
||||
|
||||
func TestValidateAccess_EmptyToken(t *testing.T) {
|
||||
gen := newTestGenerator(t)
|
||||
|
||||
_, err := gen.ValidateAccess(nil, "")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestValidateRefresh_WrongScope(t *testing.T) {
|
||||
gen := newTestGenerator(t)
|
||||
|
||||
// Create access token but validate as refresh
|
||||
tokenStr, _, err := gen.NewAccess(1, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = gen.ValidateRefresh(nil, tokenStr)
|
||||
require.Error(t, err)
|
||||
}
|
||||
32
tmdb/config.go
Normal file
32
tmdb/config.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package tmdb
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Image Image `json:"images"`
|
||||
}
|
||||
|
||||
type Image struct {
|
||||
BaseURL string `json:"base_url"`
|
||||
SecureBaseURL string `json:"secure_base_url"`
|
||||
BackdropSizes []string `json:"backdrop_sizes"`
|
||||
LogoSizes []string `json:"logo_sizes"`
|
||||
PosterSizes []string `json:"poster_sizes"`
|
||||
ProfileSizes []string `json:"profile_sizes"`
|
||||
StillSizes []string `json:"still_sizes"`
|
||||
}
|
||||
|
||||
func GetConfig(token string) (*Config, error) {
|
||||
url := "https://api.themoviedb.org/3/configuration"
|
||||
data, err := tmdbGet(url, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tmdbGet")
|
||||
}
|
||||
config := Config{}
|
||||
json.Unmarshal(data, &config)
|
||||
return &config, nil
|
||||
}
|
||||
54
tmdb/credits.go
Normal file
54
tmdb/credits.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package tmdb
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Credits struct {
|
||||
ID int32 `json:"id"`
|
||||
Cast []Cast `json:"cast"`
|
||||
Crew []Crew `json:"crew"`
|
||||
}
|
||||
|
||||
type Cast struct {
|
||||
Adult bool `json:"adult"`
|
||||
Gender int `json:"gender"`
|
||||
ID int32 `json:"id"`
|
||||
KnownFor string `json:"known_for_department"`
|
||||
Name string `json:"name"`
|
||||
OriginalName string `json:"original_name"`
|
||||
Popularity int `json:"popularity"`
|
||||
Profile string `json:"profile_path"`
|
||||
CastID int32 `json:"cast_id"`
|
||||
Character string `json:"character"`
|
||||
CreditID string `json:"credit_id"`
|
||||
Order int `json:"order"`
|
||||
}
|
||||
|
||||
type Crew struct {
|
||||
Adult bool `json:"adult"`
|
||||
Gender int `json:"gender"`
|
||||
ID int32 `json:"id"`
|
||||
KnownFor string `json:"known_for_department"`
|
||||
Name string `json:"name"`
|
||||
OriginalName string `json:"original_name"`
|
||||
Popularity int `json:"popularity"`
|
||||
Profile string `json:"profile_path"`
|
||||
CreditID string `json:"credit_id"`
|
||||
Department string `json:"department"`
|
||||
Job string `json:"job"`
|
||||
}
|
||||
|
||||
func GetCredits(movieid int32, token string) (*Credits, error) {
|
||||
url := fmt.Sprintf("https://api.themoviedb.org/3/movie/%v/credits?language=en-US", movieid)
|
||||
data, err := tmdbGet(url, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tmdbGet")
|
||||
}
|
||||
credits := Credits{}
|
||||
json.Unmarshal(data, &credits)
|
||||
return &credits, nil
|
||||
}
|
||||
41
tmdb/crew_functions.go
Normal file
41
tmdb/crew_functions.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package tmdb
|
||||
|
||||
import "sort"
|
||||
|
||||
type BilledCrew struct {
|
||||
Name string
|
||||
Roles []string
|
||||
}
|
||||
|
||||
func (credits *Credits) BilledCrew() []BilledCrew {
|
||||
crewmap := make(map[string][]string)
|
||||
billedcrew := []BilledCrew{}
|
||||
for _, crew := range credits.Crew {
|
||||
if crew.Job == "Director" ||
|
||||
crew.Job == "Screenplay" ||
|
||||
crew.Job == "Writer" ||
|
||||
crew.Job == "Novel" ||
|
||||
crew.Job == "Story" {
|
||||
crewmap[crew.Name] = append(crewmap[crew.Name], crew.Job)
|
||||
}
|
||||
}
|
||||
|
||||
for name, jobs := range crewmap {
|
||||
billedcrew = append(billedcrew, BilledCrew{Name: name, Roles: jobs})
|
||||
}
|
||||
for i := range billedcrew {
|
||||
sort.Strings(billedcrew[i].Roles)
|
||||
}
|
||||
sort.Slice(billedcrew, func(i, j int) bool {
|
||||
return billedcrew[i].Roles[0] < billedcrew[j].Roles[0]
|
||||
})
|
||||
return billedcrew
|
||||
}
|
||||
|
||||
func (billedcrew *BilledCrew) FRoles() string {
|
||||
jobs := ""
|
||||
for _, job := range billedcrew.Roles {
|
||||
jobs += job + ", "
|
||||
}
|
||||
return jobs[:len(jobs)-2]
|
||||
}
|
||||
5
tmdb/go.mod
Normal file
5
tmdb/go.mod
Normal file
@@ -0,0 +1,5 @@
|
||||
module git.haelnorr.com/h/golib/tmdb
|
||||
|
||||
go 1.25.5
|
||||
|
||||
require github.com/pkg/errors v0.9.1
|
||||
2
tmdb/go.sum
Normal file
2
tmdb/go.sum
Normal file
@@ -0,0 +1,2 @@
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
45
tmdb/movie.go
Normal file
45
tmdb/movie.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package tmdb
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Movie struct {
|
||||
Adult bool `json:"adult"`
|
||||
Backdrop string `json:"backdrop_path"`
|
||||
Collection string `json:"belongs_to_collection"`
|
||||
Budget int `json:"budget"`
|
||||
Genres []Genre `json:"genres"`
|
||||
Homepage string `json:"homepage"`
|
||||
ID int32 `json:"id"`
|
||||
IMDbID string `json:"imdb_id"`
|
||||
OriginalLanguage string `json:"original_language"`
|
||||
OriginalTitle string `json:"original_title"`
|
||||
Overview string `json:"overview"`
|
||||
Popularity float32 `json:"popularity"`
|
||||
Poster string `json:"poster_path"`
|
||||
ProductionCompanies []ProductionCompany `json:"production_companies"`
|
||||
ProductionCountries []ProductionCountry `json:"production_countries"`
|
||||
ReleaseDate string `json:"release_date"`
|
||||
Revenue int `json:"revenue"`
|
||||
Runtime int `json:"runtime"`
|
||||
SpokenLanguages []SpokenLanguage `json:"spoken_languages"`
|
||||
Status string `json:"status"`
|
||||
Tagline string `json:"tagline"`
|
||||
Title string `json:"title"`
|
||||
Video bool `json:"video"`
|
||||
}
|
||||
|
||||
func GetMovie(id int32, token string) (*Movie, error) {
|
||||
url := fmt.Sprintf("https://api.themoviedb.org/3/movie/%v?language=en-US", id)
|
||||
data, err := tmdbGet(url, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tmdbGet")
|
||||
}
|
||||
movie := Movie{}
|
||||
json.Unmarshal(data, &movie)
|
||||
return &movie, nil
|
||||
}
|
||||
42
tmdb/movie_functions.go
Normal file
42
tmdb/movie_functions.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package tmdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
)
|
||||
|
||||
func (movie *Movie) FRuntime() string {
|
||||
hours := movie.Runtime / 60
|
||||
mins := movie.Runtime % 60
|
||||
return fmt.Sprintf("%dh %02dm", hours, mins)
|
||||
}
|
||||
|
||||
func (movie *Movie) GetPoster(image *Image, size string) string {
|
||||
base, err := url.Parse(image.SecureBaseURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
fullPath := path.Join(base.Path, size, movie.Poster)
|
||||
base.Path = fullPath
|
||||
return base.String()
|
||||
}
|
||||
|
||||
func (movie *Movie) ReleaseYear() string {
|
||||
if movie.ReleaseDate == "" {
|
||||
return ""
|
||||
} else {
|
||||
return "(" + movie.ReleaseDate[:4] + ")"
|
||||
}
|
||||
}
|
||||
|
||||
func (movie *Movie) FGenres() string {
|
||||
genres := ""
|
||||
for _, genre := range movie.Genres {
|
||||
genres += genre.Name + ", "
|
||||
}
|
||||
if len(genres) > 2 {
|
||||
return genres[:len(genres)-2]
|
||||
}
|
||||
return genres
|
||||
}
|
||||
28
tmdb/request.go
Normal file
28
tmdb/request.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package tmdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func tmdbGet(url string, token string) ([]byte, error) {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "http.NewRequest")
|
||||
}
|
||||
req.Header.Add("accept", "application/json")
|
||||
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "http.DefaultClient.Do")
|
||||
}
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "io.ReadAll")
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
79
tmdb/search.go
Normal file
79
tmdb/search.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package tmdb
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
Page int `json:"page"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
TotalResults int `json:"total_results"`
|
||||
}
|
||||
|
||||
type ResultMovies struct {
|
||||
Result
|
||||
Results []ResultMovie `json:"results"`
|
||||
}
|
||||
type ResultMovie struct {
|
||||
Adult bool `json:"adult"`
|
||||
BackdropPath string `json:"backdrop_path"`
|
||||
GenreIDs []int `json:"genre_ids"`
|
||||
ID int32 `json:"id"`
|
||||
OriginalLanguage string `json:"original_language"`
|
||||
OriginalTitle string `json:"original_title"`
|
||||
Overview string `json:"overview"`
|
||||
Popularity int `json:"popularity"`
|
||||
PosterPath string `json:"poster_path"`
|
||||
ReleaseDate string `json:"release_date"`
|
||||
Title string `json:"title"`
|
||||
Video bool `json:"video"`
|
||||
VoteAverage int `json:"vote_average"`
|
||||
VoteCount int `json:"vote_count"`
|
||||
}
|
||||
|
||||
func (movie *ResultMovie) GetPoster(image *Image, size string) string {
|
||||
base, err := url.Parse(image.SecureBaseURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
fullPath := path.Join(base.Path, size, movie.PosterPath)
|
||||
base.Path = fullPath
|
||||
return base.String()
|
||||
}
|
||||
|
||||
func (movie *ResultMovie) ReleaseYear() string {
|
||||
if movie.ReleaseDate == "" {
|
||||
return ""
|
||||
} else {
|
||||
return "(" + movie.ReleaseDate[:4] + ")"
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: genres list https://developer.themoviedb.org/reference/genre-movie-list
|
||||
// func (movie *ResultMovie) FGenres() string {
|
||||
// genres := ""
|
||||
// for _, genre := range movie.Genres {
|
||||
// genres += genre.Name + ", "
|
||||
// }
|
||||
// return genres[:len(genres)-2]
|
||||
// }
|
||||
|
||||
func SearchMovies(token string, query string, adult bool, page int) (*ResultMovies, error) {
|
||||
url := "https://api.themoviedb.org/3/search/movie" +
|
||||
fmt.Sprintf("?query=%s", url.QueryEscape(query)) +
|
||||
fmt.Sprintf("&include_adult=%t", adult) +
|
||||
fmt.Sprintf("&page=%v", page) +
|
||||
"&language=en-US"
|
||||
response, err := tmdbGet(url, token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "tmdbGet")
|
||||
}
|
||||
var results ResultMovies
|
||||
json.Unmarshal(response, &results)
|
||||
return &results, nil
|
||||
}
|
||||
24
tmdb/structs.go
Normal file
24
tmdb/structs.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package tmdb
|
||||
|
||||
type Genre struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type ProductionCompany struct {
|
||||
ID int `json:"id"`
|
||||
Logo string `json:"logo_path"`
|
||||
Name string `json:"name"`
|
||||
OriginCountry string `json:"origin_country"`
|
||||
}
|
||||
|
||||
type ProductionCountry struct {
|
||||
ISO_3166_1 string `json:"iso_3166_1"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type SpokenLanguage struct {
|
||||
EnglishName string `json:"english_name"`
|
||||
ISO_639_1 string `json:"iso_639_1"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
Reference in New Issue
Block a user