Added documentation to functions and basic JWT generation

This commit is contained in:
2025-02-09 00:48:30 +11:00
parent 597fc6f072
commit 25868becf3
29 changed files with 254 additions and 58 deletions

View File

@@ -14,5 +14,9 @@ dev:
air &\
tailwindcss -i ./static/css/input.css -o ./static/css/output.css --watch
test:
go mod tidy && \
go run . --port 3232 --test
clean:
go clean

View File

@@ -6,6 +6,7 @@ import (
"time"
)
// Check the value of "pagefrom" cookie, delete the cookie, and return the value
func CheckPageFrom(w http.ResponseWriter, r *http.Request) string {
pageFromCookie, err := r.Cookie("pagefrom")
if err != nil {
@@ -17,12 +18,17 @@ func CheckPageFrom(w http.ResponseWriter, r *http.Request) string {
return pageFrom
}
// Check the referer of the request, and if it matches the trustedHost, set
// the "pagefrom" cookie as the Path of the referer
func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) {
referer := r.Referer()
parsedURL, err := url.Parse(referer)
if err != nil {
return
}
// NOTE: its possible this could cause an infinite redirect
// if that happens, will need to add a way to 'blacklist' certain paths
// from being set here
var pageFrom string
if parsedURL.Path == "" || parsedURL.Host != trustedHost {
pageFrom = "/"

View File

@@ -8,6 +8,7 @@ import (
_ "github.com/tursodatabase/libsql-client-go/libsql"
)
// Returns a database connection handle for the Turso DB
func ConnectToDatabase(primaryUrl *string, authToken *string) (*sql.DB, error) {
url := fmt.Sprintf("libsql://%s.turso.io?authToken=%s", *primaryUrl, *authToken)

View File

@@ -9,12 +9,13 @@ import (
)
type User struct {
ID int
Username string
Password_hash string
Created_at int64
ID int // Integer ID (index primary key)
Username string // Username (unique)
Password_hash string // Bcrypt password hash
Created_at int64 // Epoch timestamp when the user was added to the database
}
// Uses bcrypt to set the users Password_hash from the given password
func (user *User) SetPassword(conn *sql.DB, password string) error {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
@@ -30,6 +31,7 @@ func (user *User) SetPassword(conn *sql.DB, password string) error {
return nil
}
// Uses bcrypt to check if the given password matches the users Password_hash
func (user *User) CheckPassword(password string) error {
err := bcrypt.CompareHashAndPassword([]byte(user.Password_hash), []byte(password))
if err != nil {
@@ -38,6 +40,8 @@ func (user *User) CheckPassword(password string) error {
return nil
}
// Queries the database for a user matching the given username.
// Query is case insensitive
func GetUserFromUsername(conn *sql.DB, username string) (User, error) {
query := `SELECT id, username, password_hash, created_at FROM users
WHERE username = ? COLLATE NOCASE`

1
go.mod
View File

@@ -4,6 +4,7 @@ go 1.23.5
require (
github.com/a-h/templ v0.3.833
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/joho/godotenv v1.5.1
github.com/pkg/errors v0.9.1
github.com/tursodatabase/libsql-client-go v0.0.0-20240902231107-85af5b9d094d

2
go.sum
View File

@@ -4,6 +4,8 @@ github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8
github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g=
github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo=
github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=

View File

@@ -2,9 +2,12 @@ package handlers
import (
"net/http"
"projectreshoot/view/page"
)
// Handles responses to the / path. Also serves a 404 Page for paths that
// don't have explicit handlers
func HandleRoot() http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {

View File

@@ -13,6 +13,8 @@ import (
"github.com/pkg/errors"
)
// Validates the username matches a user in the database and the password
// is correct. Returns the corresponding user
func validateLogin(conn *sql.DB, r *http.Request) (db.User, error) {
formUsername := r.FormValue("username")
formPassword := r.FormValue("password")
@@ -29,6 +31,7 @@ func validateLogin(conn *sql.DB, r *http.Request) (db.User, error) {
return user, nil
}
// Returns result of the "Remember me?" checkbox as a boolean
func checkRememberMe(r *http.Request) bool {
rememberMe := r.FormValue("remember-me")
if rememberMe == "on" {
@@ -38,7 +41,10 @@ func checkRememberMe(r *http.Request) bool {
}
}
func HandleLoginRequest(conn *sql.DB) http.Handler {
// Handles an attempted login request. On success will return a HTMX redirect
// and on fail will return the login form again, passing the error to the
// template for user feedback
func HandleLoginRequest(conn *sql.DB, secretKey string) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
r.ParseForm()
@@ -62,6 +68,8 @@ func HandleLoginRequest(conn *sql.DB) http.Handler {
)
}
// Handles a request to view the login page. Will attempt to set "pagefrom"
// cookie so a successful login can redirect the user to the page they came
func HandleLoginPage(trustedHost string) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {

View File

@@ -1,10 +1,13 @@
package handlers
import (
"github.com/a-h/templ"
"net/http"
"github.com/a-h/templ"
)
// Handler for static pages. Will render the given templ.Component to the
// http.ResponseWriter
func HandlePage(Page templ.Component) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {

View File

@@ -5,26 +5,43 @@ import (
"os"
)
// Wrapper for default FileSystem
type justFilesFilesystem struct {
fs http.FileSystem
}
// Wrapper for default File
type neuteredReaddirFile struct {
http.File
}
// Modifies the behavior of FileSystem.Open to return the neutered version of File
func (fs justFilesFilesystem) Open(name string) (http.File, error) {
f, err := fs.fs.Open(name)
if err != nil {
return nil, err
}
// Check if the requested path is a directory
// and explicitly return an error to trigger a 404
fileInfo, err := f.Stat()
if err != nil {
return nil, err
}
if fileInfo.IsDir() {
return nil, os.ErrNotExist
}
return neuteredReaddirFile{f}, nil
}
// Overrides the Readdir method of File to always return nil
func (f neuteredReaddirFile) Readdir(count int) ([]os.FileInfo, error) {
return nil, nil
}
// Handles requests for static files, without allowing access to the
// directory viewer and returning 404 if an exact file is not found
func HandleStatic() http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {

43
jwt/createtoken.go Normal file
View File

@@ -0,0 +1,43 @@
package jwt
import (
"time"
"projectreshoot/db"
"projectreshoot/server"
"github.com/golang-jwt/jwt"
"github.com/pkg/errors"
)
// Generates an access token for the provided user, using the variables set
// in the config object
func GenerateAccessToken(
config *server.Config,
user *db.User,
fresh bool,
) (string, error) {
issuedAt := time.Now().Unix()
expiresAt := issuedAt + (config.AccessTokenExpiry * 60)
var freshExpiresAt int64
if fresh {
freshExpiresAt = issuedAt + (config.TokenFreshTime * 60)
} else {
freshExpiresAt = issuedAt
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
jwt.MapClaims{
"iss": config.TrustedHost,
"sub": user.ID,
"aud": config.TrustedHost,
"iat": issuedAt,
"exp": expiresAt,
"fresh": freshExpiresAt,
})
signedToken, err := token.SignedString([]byte(config.SecretKey))
if err != nil {
return "", errors.Wrap(err, "token.SignedString")
}
return signedToken, nil
}

23
main.go
View File

@@ -3,12 +3,14 @@ package main
import (
"context"
"embed"
"flag"
"fmt"
"io"
"net"
"net/http"
"os"
"os/signal"
"strconv"
"sync"
"time"
@@ -18,16 +20,17 @@ import (
"github.com/pkg/errors"
)
func run(ctx context.Context, w io.Writer) error {
// Initializes and runs the server
func run(ctx context.Context, w io.Writer, args []string) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()
config, err := server.GetConfig()
config, err := server.GetConfig(args)
if err != nil {
return errors.Wrap(err, "server.GetConfig")
}
conn, err := db.ConnectToDatabase(&config.TursoURL, &config.TursoToken)
conn, err := db.ConnectToDatabase(&config.TursoDBName, &config.TursoToken)
if err != nil {
return errors.Wrap(err, "db.ConnectToDatabase")
}
@@ -38,6 +41,12 @@ func run(ctx context.Context, w io.Writer) error {
Handler: srv,
}
// TEST: runs function for testing in dev if --test flag true
if args[1] == "true" {
test(config, conn, httpServer)
return nil
}
go func() {
fmt.Fprintf(w, "Listening on %s\n", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
@@ -65,9 +74,15 @@ func run(ctx context.Context, w io.Writer) error {
//go:embed static/*
var static embed.FS
// Start of runtime. Parse commandline arguments & flags, Initializes context
// and starts the server
func main() {
port := flag.String("port", "", "Override port")
test := flag.Bool("test", false, "Run test function")
flag.Parse()
args := []string{*port, strconv.FormatBool(*test)}
ctx := context.Background()
if err := run(ctx, os.Stdout); err != nil {
if err := run(ctx, os.Stdout, args); err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}

View File

@@ -6,16 +6,19 @@ import (
"time"
)
// Wraps the http.ResponseWriter, adding a statusCode field
type wrappedWriter struct {
http.ResponseWriter
statusCode int
}
// Extends WriteHeader to the ResponseWriter to add the status code
func (w *wrappedWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
w.statusCode = statusCode
}
// Middleware to add logs to console with details of the request
func Logging(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()

63
server/config.go Normal file
View File

@@ -0,0 +1,63 @@
package server
import (
"errors"
"fmt"
"os"
"github.com/joho/godotenv"
)
type Config struct {
Host string // Host to listen on
Port string // Port to listen on
TrustedHost string // Domain/Hostname to accept as trusted
TursoDBName string // DB Name for Turso DB/Branch
TursoToken string // Bearer token for Turso DB/Branch
SecretKey string // Secret key for signing tokens
AccessTokenExpiry int64 // Access token expiry in minutes
RefreshTokenExpiry int64 // Refresh token expiry in minutes
TokenFreshTime int64 // Time for tokens to stay fresh in minutes
}
// Load the application configuration and get a pointer to the Config object
func GetConfig(args []string) (*Config, error) {
err := godotenv.Load(".env")
if err != nil {
fmt.Println(".env file not found.")
}
var port string
if args[0] != "" {
port = args[0]
} else {
port = GetEnvDefault("PORT", "3333")
}
config := &Config{
Host: GetEnvDefault("HOST", "127.0.0.1"),
Port: port,
TrustedHost: os.Getenv("TRUSTED_HOST"),
TursoDBName: os.Getenv("TURSO_DB_NAME"),
TursoToken: os.Getenv("TURSO_AUTH_TOKEN"),
SecretKey: os.Getenv("SECRET_KEY"),
AccessTokenExpiry: GetEnvInt64("ACCESS_TOKEN_EXPIRY", 5),
RefreshTokenExpiry: GetEnvInt64("REFRESH_TOKEN_EXPIRY", 1440), // defaults to 1 day
TokenFreshTime: GetEnvInt64("TOKEN_FRESH_TIME", 5),
}
if config.TrustedHost == "" {
return nil, errors.New("Envar not set: TRUSTED_HOST")
}
if config.TursoDBName == "" {
return nil, errors.New("Envar not set: TURSO_DB_NAME")
}
if config.TursoToken == "" {
return nil, errors.New("Envar not set: TURSO_AUTH_TOKEN")
}
if config.SecretKey == "" {
return nil, errors.New("Envar not set: SECRET_KEY")
}
return config, nil
}

31
server/environment.go Normal file
View File

@@ -0,0 +1,31 @@
package server
import (
"os"
"strconv"
)
// Get an environment variable, specifying a default value if its not set
func GetEnvDefault(key string, defaultValue string) string {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
return val
}
// Get an environment variable as an int64, specifying a default value if its
// not set or can't be parsed properly into an int64
func GetEnvInt64(key string, defaultValue int64) int64 {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
intVal, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return defaultValue
}
return intVal
}

View File

@@ -3,10 +3,12 @@ package server
import (
"database/sql"
"net/http"
"projectreshoot/handlers"
"projectreshoot/view/page"
)
// Add all the handled routes to the mux
func addRoutes(
mux *http.ServeMux,
config *Config,
@@ -23,5 +25,5 @@ func addRoutes(
// Login page and handlers
mux.Handle("GET /login", handlers.HandleLoginPage(config.TrustedHost))
mux.Handle("POST /login", handlers.HandleLoginRequest(conn))
mux.Handle("POST /login", handlers.HandleLoginRequest(conn, config.SecretKey))
}

View File

@@ -2,56 +2,12 @@ package server
import (
"database/sql"
"errors"
"fmt"
"net/http"
"os"
"projectreshoot/middleware"
"github.com/joho/godotenv"
)
type Config struct {
Host string
Port string
TrustedHost string
TursoURL string
TursoToken string
}
func GetConfig() (*Config, error) {
err := godotenv.Load(".env")
if err != nil {
fmt.Println(".env file not found.")
}
config := &Config{
Host: os.Getenv("HOST"),
Port: os.Getenv("PORT"),
TrustedHost: os.Getenv("TRUSTED_HOST"),
TursoURL: os.Getenv("TURSO_DATABASE_URL"),
TursoToken: os.Getenv("TURSO_AUTH_TOKEN"),
}
if config.Host == "" {
return nil, errors.New("Envar not set: HOST")
}
if config.Port == "" {
return nil, errors.New("Envar not set: PORT")
}
if config.TrustedHost == "" {
return nil, errors.New("Envar not set: TRUSTED_HOST")
}
if config.TursoURL == "" {
return nil, errors.New("Envar not set: TURSO_DATABASE_URL")
}
if config.TursoToken == "" {
return nil, errors.New("Envar not set: TURSO_AUTH_TOKEN")
}
return config, nil
}
// Returns a new http.Handler with all the routes and middleware added
func NewServer(config *Config, conn *sql.DB) http.Handler {
mux := http.NewServeMux()
addRoutes(

15
tester.go Normal file
View File

@@ -0,0 +1,15 @@
package main
import (
"database/sql"
"net/http"
"projectreshoot/server"
)
// This function will only be called if the --test commandline flag is set.
// After the function finishes the application will close.
// Running command `make test` will run the test using port 3232 to avoid
// conflicts on the default 3333. Useful for testing things out during dev
func test(config *server.Config, conn *sql.DB, srv *http.Server) {
}

View File

@@ -5,6 +5,7 @@ type FooterItem struct {
href string
}
// Specify the links to show in the footer
func getFooterItems() []FooterItem {
return []FooterItem{
{
@@ -18,6 +19,7 @@ func getFooterItems() []FooterItem {
}
}
// Returns the template fragment for the Footer
templ Footer() {
<footer class="bg-mantle mt-10">
<div

View File

@@ -2,6 +2,10 @@ package form
import "fmt"
// Login Form. If loginError is not an empty string, it will display the
// contents of loginError to the user.
// If loginError is "Username or password incorrect" it will also show
// error icons on the username and password field
templ LoginForm(loginError string) {
{{
var errCreds string

View File

@@ -1,10 +1,11 @@
package nav
type NavItem struct {
name string
href string
name string // Label to display
href string // Link reference
}
// Return the list of navbar links
func getNavItems() []NavItem {
return []NavItem{
{
@@ -14,6 +15,7 @@ func getNavItems() []NavItem {
}
}
// Returns the navbar template fragment
templ Navbar() {
{{ navItems := getNavItems() }}
<div x-data="{ open: false }">

View File

@@ -1,5 +1,6 @@
package nav
// Returns the left portion of the navbar
templ navLeft(navItems []NavItem) {
<nav aria-label="Global" class="hidden sm:block">
<ul class="flex items-center gap-6 text-xl">

View File

@@ -1,5 +1,6 @@
package nav
// Returns the right portion of the navbar
templ navRight() {
<div class="flex items-center gap-2">
<div class="sm:flex sm:gap-2">

View File

@@ -1,5 +1,6 @@
package nav
// Returns the mobile version of the navbar thats only visible when activated
templ sideNav(navItems []NavItem) {
<div
x-show="open"

View File

@@ -3,6 +3,8 @@ package layout
import "projectreshoot/view/component/nav"
import "projectreshoot/view/component/footer"
// Global page layout. Includes HTML document settings, header tags
// navbar and footer
templ Global() {
<!DOCTYPE html>
<html

View File

@@ -2,6 +2,7 @@ package page
import "projectreshoot/view/layout"
// Returns the about page content
templ About() {
@layout.Global() {
<div class="text-center max-w-150 m-auto">

View File

@@ -2,6 +2,9 @@ package page
import "projectreshoot/view/layout"
// Page template for Error pages. Error code should be a HTTP status code as
// a string, and err should be the corresponding response title.
// Message is a custom error message displayed below the code and error.
templ Error(code string, err string, message string) {
@layout.Global() {
<div

View File

@@ -2,6 +2,7 @@ package page
import "projectreshoot/view/layout"
// Page content for the index page
templ Index() {
@layout.Global() {
<div class="text-center mt-24">

View File

@@ -3,6 +3,7 @@ package page
import "projectreshoot/view/layout"
import "projectreshoot/view/component/form"
// Returns the login page
templ Login() {
@layout.Global() {
<div class="max-w-100 mx-auto px-2">