diff --git a/Makefile b/Makefile index 57e8d72..860fbd0 100644 --- a/Makefile +++ b/Makefile @@ -14,5 +14,9 @@ dev: air &\ tailwindcss -i ./static/css/input.css -o ./static/css/output.css --watch +test: + go mod tidy && \ + go run . --port 3232 --test + clean: go clean diff --git a/cookies/pagefrom.go b/cookies/pagefrom.go index 18c2af1..ace8040 100644 --- a/cookies/pagefrom.go +++ b/cookies/pagefrom.go @@ -6,6 +6,7 @@ import ( "time" ) +// Check the value of "pagefrom" cookie, delete the cookie, and return the value func CheckPageFrom(w http.ResponseWriter, r *http.Request) string { pageFromCookie, err := r.Cookie("pagefrom") if err != nil { @@ -17,12 +18,17 @@ func CheckPageFrom(w http.ResponseWriter, r *http.Request) string { return pageFrom } +// Check the referer of the request, and if it matches the trustedHost, set +// the "pagefrom" cookie as the Path of the referer func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) { referer := r.Referer() parsedURL, err := url.Parse(referer) if err != nil { return } + // NOTE: its possible this could cause an infinite redirect + // if that happens, will need to add a way to 'blacklist' certain paths + // from being set here var pageFrom string if parsedURL.Path == "" || parsedURL.Host != trustedHost { pageFrom = "/" diff --git a/db/connection.go b/db/connection.go index 23381a2..f2e6900 100644 --- a/db/connection.go +++ b/db/connection.go @@ -8,6 +8,7 @@ import ( _ "github.com/tursodatabase/libsql-client-go/libsql" ) +// Returns a database connection handle for the Turso DB func ConnectToDatabase(primaryUrl *string, authToken *string) (*sql.DB, error) { url := fmt.Sprintf("libsql://%s.turso.io?authToken=%s", *primaryUrl, *authToken) diff --git a/db/users.go b/db/users.go index b0cb5e8..27cc7a1 100644 --- a/db/users.go +++ b/db/users.go @@ -9,12 +9,13 @@ import ( ) type User struct { - ID int - Username string - Password_hash string - Created_at int64 + ID int // Integer ID (index primary key) + Username string // Username (unique) + Password_hash string // Bcrypt password hash + Created_at int64 // Epoch timestamp when the user was added to the database } +// Uses bcrypt to set the users Password_hash from the given password func (user *User) SetPassword(conn *sql.DB, password string) error { hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { @@ -30,6 +31,7 @@ func (user *User) SetPassword(conn *sql.DB, password string) error { return nil } +// Uses bcrypt to check if the given password matches the users Password_hash func (user *User) CheckPassword(password string) error { err := bcrypt.CompareHashAndPassword([]byte(user.Password_hash), []byte(password)) if err != nil { @@ -38,6 +40,8 @@ func (user *User) CheckPassword(password string) error { return nil } +// Queries the database for a user matching the given username. +// Query is case insensitive func GetUserFromUsername(conn *sql.DB, username string) (User, error) { query := `SELECT id, username, password_hash, created_at FROM users WHERE username = ? COLLATE NOCASE` diff --git a/go.mod b/go.mod index c3655ff..726d608 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.23.5 require ( github.com/a-h/templ v0.3.833 + github.com/golang-jwt/jwt v3.2.2+incompatible github.com/joho/godotenv v1.5.1 github.com/pkg/errors v0.9.1 github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d diff --git a/go.sum b/go.sum index 666d29a..7359dbf 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8 github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= diff --git a/handlers/index.go b/handlers/index.go index 9cf524e..072408a 100644 --- a/handlers/index.go +++ b/handlers/index.go @@ -2,9 +2,12 @@ package handlers import ( "net/http" + "projectreshoot/view/page" ) +// Handles responses to the / path. Also serves a 404 Page for paths that +// don't have explicit handlers func HandleRoot() http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { diff --git a/handlers/login.go b/handlers/login.go index 1bc6806..9c84442 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -13,6 +13,8 @@ import ( "github.com/pkg/errors" ) +// Validates the username matches a user in the database and the password +// is correct. Returns the corresponding user func validateLogin(conn *sql.DB, r *http.Request) (db.User, error) { formUsername := r.FormValue("username") formPassword := r.FormValue("password") @@ -29,6 +31,7 @@ func validateLogin(conn *sql.DB, r *http.Request) (db.User, error) { return user, nil } +// Returns result of the "Remember me?" checkbox as a boolean func checkRememberMe(r *http.Request) bool { rememberMe := r.FormValue("remember-me") if rememberMe == "on" { @@ -38,7 +41,10 @@ func checkRememberMe(r *http.Request) bool { } } -func HandleLoginRequest(conn *sql.DB) http.Handler { +// Handles an attempted login request. On success will return a HTMX redirect +// and on fail will return the login form again, passing the error to the +// template for user feedback +func HandleLoginRequest(conn *sql.DB, secretKey string) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { r.ParseForm() @@ -62,6 +68,8 @@ func HandleLoginRequest(conn *sql.DB) http.Handler { ) } +// Handles a request to view the login page. Will attempt to set "pagefrom" +// cookie so a successful login can redirect the user to the page they came func HandleLoginPage(trustedHost string) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { diff --git a/handlers/page.go b/handlers/page.go index 54f8041..223ff78 100644 --- a/handlers/page.go +++ b/handlers/page.go @@ -1,10 +1,13 @@ package handlers import ( - "github.com/a-h/templ" "net/http" + + "github.com/a-h/templ" ) +// Handler for static pages. Will render the given templ.Component to the +// http.ResponseWriter func HandlePage(Page templ.Component) http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { diff --git a/handlers/static.go b/handlers/static.go index ccee73f..768e6e1 100644 --- a/handlers/static.go +++ b/handlers/static.go @@ -5,26 +5,43 @@ import ( "os" ) +// Wrapper for default FileSystem type justFilesFilesystem struct { fs http.FileSystem } +// Wrapper for default File type neuteredReaddirFile struct { http.File } +// Modifies the behavior of FileSystem.Open to return the neutered version of File func (fs justFilesFilesystem) Open(name string) (http.File, error) { f, err := fs.fs.Open(name) if err != nil { return nil, err } + + // Check if the requested path is a directory + // and explicitly return an error to trigger a 404 + fileInfo, err := f.Stat() + if err != nil { + return nil, err + } + if fileInfo.IsDir() { + return nil, os.ErrNotExist + } + return neuteredReaddirFile{f}, nil } +// Overrides the Readdir method of File to always return nil func (f neuteredReaddirFile) Readdir(count int) ([]os.FileInfo, error) { return nil, nil } +// Handles requests for static files, without allowing access to the +// directory viewer and returning 404 if an exact file is not found func HandleStatic() http.Handler { return http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { diff --git a/jwt/createtoken.go b/jwt/createtoken.go new file mode 100644 index 0000000..ed43146 --- /dev/null +++ b/jwt/createtoken.go @@ -0,0 +1,43 @@ +package jwt + +import ( + "time" + + "projectreshoot/db" + "projectreshoot/server" + + "github.com/golang-jwt/jwt" + "github.com/pkg/errors" +) + +// Generates an access token for the provided user, using the variables set +// in the config object +func GenerateAccessToken( + config *server.Config, + user *db.User, + fresh bool, +) (string, error) { + issuedAt := time.Now().Unix() + expiresAt := issuedAt + (config.AccessTokenExpiry * 60) + var freshExpiresAt int64 + if fresh { + freshExpiresAt = issuedAt + (config.TokenFreshTime * 60) + } else { + freshExpiresAt = issuedAt + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, + jwt.MapClaims{ + "iss": config.TrustedHost, + "sub": user.ID, + "aud": config.TrustedHost, + "iat": issuedAt, + "exp": expiresAt, + "fresh": freshExpiresAt, + }) + + signedToken, err := token.SignedString([]byte(config.SecretKey)) + if err != nil { + return "", errors.Wrap(err, "token.SignedString") + } + return signedToken, nil +} diff --git a/main.go b/main.go index 3e5e31d..f8c20d8 100644 --- a/main.go +++ b/main.go @@ -3,12 +3,14 @@ package main import ( "context" "embed" + "flag" "fmt" "io" "net" "net/http" "os" "os/signal" + "strconv" "sync" "time" @@ -18,16 +20,17 @@ import ( "github.com/pkg/errors" ) -func run(ctx context.Context, w io.Writer) error { +// Initializes and runs the server +func run(ctx context.Context, w io.Writer, args []string) error { ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) defer cancel() - config, err := server.GetConfig() + config, err := server.GetConfig(args) if err != nil { return errors.Wrap(err, "server.GetConfig") } - conn, err := db.ConnectToDatabase(&config.TursoURL, &config.TursoToken) + conn, err := db.ConnectToDatabase(&config.TursoDBName, &config.TursoToken) if err != nil { return errors.Wrap(err, "db.ConnectToDatabase") } @@ -38,6 +41,12 @@ func run(ctx context.Context, w io.Writer) error { Handler: srv, } + // TEST: runs function for testing in dev if --test flag true + if args[1] == "true" { + test(config, conn, httpServer) + return nil + } + go func() { fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr) if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { @@ -65,9 +74,15 @@ func run(ctx context.Context, w io.Writer) error { //go:embed static/* var static embed.FS +// Start of runtime. Parse commandline arguments & flags, Initializes context +// and starts the server func main() { + port := flag.String("port", "", "Override port") + test := flag.Bool("test", false, "Run test function") + flag.Parse() + args := []string{*port, strconv.FormatBool(*test)} ctx := context.Background() - if err := run(ctx, os.Stdout); err != nil { + if err := run(ctx, os.Stdout, args); err != nil { fmt.Fprintf(os.Stderr, "%s\n", err) os.Exit(1) } diff --git a/middleware/logging.go b/middleware/logging.go index ae0e019..df95866 100644 --- a/middleware/logging.go +++ b/middleware/logging.go @@ -6,16 +6,19 @@ import ( "time" ) +// Wraps the http.ResponseWriter, adding a statusCode field type wrappedWriter struct { http.ResponseWriter statusCode int } +// Extends WriteHeader to the ResponseWriter to add the status code func (w *wrappedWriter) WriteHeader(statusCode int) { w.ResponseWriter.WriteHeader(statusCode) w.statusCode = statusCode } +// Middleware to add logs to console with details of the request func Logging(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() diff --git a/server/config.go b/server/config.go new file mode 100644 index 0000000..e4c46b7 --- /dev/null +++ b/server/config.go @@ -0,0 +1,63 @@ +package server + +import ( + "errors" + "fmt" + "os" + + "github.com/joho/godotenv" +) + +type Config struct { + Host string // Host to listen on + Port string // Port to listen on + TrustedHost string // Domain/Hostname to accept as trusted + TursoDBName string // DB Name for Turso DB/Branch + TursoToken string // Bearer token for Turso DB/Branch + SecretKey string // Secret key for signing tokens + AccessTokenExpiry int64 // Access token expiry in minutes + RefreshTokenExpiry int64 // Refresh token expiry in minutes + TokenFreshTime int64 // Time for tokens to stay fresh in minutes +} + +// Load the application configuration and get a pointer to the Config object +func GetConfig(args []string) (*Config, error) { + err := godotenv.Load(".env") + if err != nil { + fmt.Println(".env file not found.") + } + var port string + + if args[0] != "" { + port = args[0] + } else { + port = GetEnvDefault("PORT", "3333") + } + + config := &Config{ + Host: GetEnvDefault("HOST", "127.0.0.1"), + Port: port, + TrustedHost: os.Getenv("TRUSTED_HOST"), + TursoDBName: os.Getenv("TURSO_DB_NAME"), + TursoToken: os.Getenv("TURSO_AUTH_TOKEN"), + SecretKey: os.Getenv("SECRET_KEY"), + AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5), + RefreshTokenExpiry: GetEnvInt64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day + TokenFreshTime: GetEnvInt64("TOKEN_FRESH_TIME", 5), + } + + if config.TrustedHost == "" { + return nil, errors.New("Envar not set: TRUSTED_HOST") + } + if config.TursoDBName == "" { + return nil, errors.New("Envar not set: TURSO_DB_NAME") + } + if config.TursoToken == "" { + return nil, errors.New("Envar not set: TURSO_AUTH_TOKEN") + } + if config.SecretKey == "" { + return nil, errors.New("Envar not set: SECRET_KEY") + } + + return config, nil +} diff --git a/server/environment.go b/server/environment.go new file mode 100644 index 0000000..f932b53 --- /dev/null +++ b/server/environment.go @@ -0,0 +1,31 @@ +package server + +import ( + "os" + "strconv" +) + +// Get an environment variable, specifying a default value if its not set +func GetEnvDefault(key string, defaultValue string) string { + val, exists := os.LookupEnv(key) + if !exists { + return defaultValue + } + return val +} + +// Get an environment variable as an int64, specifying a default value if its +// not set or can't be parsed properly into an int64 +func GetEnvInt64(key string, defaultValue int64) int64 { + val, exists := os.LookupEnv(key) + if !exists { + return defaultValue + } + + intVal, err := strconv.ParseInt(val, 10, 64) + if err != nil { + return defaultValue + } + return intVal + +} diff --git a/server/routes.go b/server/routes.go index 1edc17a..2adc0e2 100644 --- a/server/routes.go +++ b/server/routes.go @@ -3,10 +3,12 @@ package server import ( "database/sql" "net/http" + "projectreshoot/handlers" "projectreshoot/view/page" ) +// Add all the handled routes to the mux func addRoutes( mux *http.ServeMux, config *Config, @@ -23,5 +25,5 @@ func addRoutes( // Login page and handlers mux.Handle("GET /login", handlers.HandleLoginPage(config.TrustedHost)) - mux.Handle("POST /login", handlers.HandleLoginRequest(conn)) + mux.Handle("POST /login", handlers.HandleLoginRequest(conn, config.SecretKey)) } diff --git a/server/server.go b/server/server.go index ba1eb82..8a531cf 100644 --- a/server/server.go +++ b/server/server.go @@ -2,56 +2,12 @@ package server import ( "database/sql" - "errors" - "fmt" "net/http" - "os" "projectreshoot/middleware" - - "github.com/joho/godotenv" ) -type Config struct { - Host string - Port string - TrustedHost string - TursoURL string - TursoToken string -} - -func GetConfig() (*Config, error) { - err := godotenv.Load(".env") - if err != nil { - fmt.Println(".env file not found.") - } - - config := &Config{ - Host: os.Getenv("HOST"), - Port: os.Getenv("PORT"), - TrustedHost: os.Getenv("TRUSTED_HOST"), - TursoURL: os.Getenv("TURSO_DATABASE_URL"), - TursoToken: os.Getenv("TURSO_AUTH_TOKEN"), - } - if config.Host == "" { - return nil, errors.New("Envar not set: HOST") - } - if config.Port == "" { - return nil, errors.New("Envar not set: PORT") - } - if config.TrustedHost == "" { - return nil, errors.New("Envar not set: TRUSTED_HOST") - } - if config.TursoURL == "" { - return nil, errors.New("Envar not set: TURSO_DATABASE_URL") - } - if config.TursoToken == "" { - return nil, errors.New("Envar not set: TURSO_AUTH_TOKEN") - } - - return config, nil -} - +// Returns a new http.Handler with all the routes and middleware added func NewServer(config *Config, conn *sql.DB) http.Handler { mux := http.NewServeMux() addRoutes( diff --git a/tester.go b/tester.go new file mode 100644 index 0000000..a0f4ea5 --- /dev/null +++ b/tester.go @@ -0,0 +1,15 @@ +package main + +import ( + "database/sql" + "net/http" + + "projectreshoot/server" +) + +// This function will only be called if the --test commandline flag is set. +// After the function finishes the application will close. +// Running command `make test` will run the test using port 3232 to avoid +// conflicts on the default 3333. Useful for testing things out during dev +func test(config *server.Config, conn *sql.DB, srv *http.Server) { +} diff --git a/view/component/footer/footer.templ b/view/component/footer/footer.templ index ff3c131..4db4868 100644 --- a/view/component/footer/footer.templ +++ b/view/component/footer/footer.templ @@ -5,6 +5,7 @@ type FooterItem struct { href string } +// Specify the links to show in the footer func getFooterItems() []FooterItem { return []FooterItem{ { @@ -18,6 +19,7 @@ func getFooterItems() []FooterItem { } } +// Returns the template fragment for the Footer templ Footer() {