Compare commits

...

5 Commits

Author SHA1 Message Date
ade3fa0454 imported env module 2026-01-02 19:03:07 +11:00
516be905a9 imported cookies module 2026-01-02 18:25:38 +11:00
6e632267ea added cookie control to jwt 2026-01-02 18:15:49 +11:00
05aad5f11b fixed transaction issues 2026-01-01 22:44:39 +11:00
c4574e32c7 imported tmdb module 2026-01-01 20:42:50 +11:00
25 changed files with 675 additions and 69 deletions

19
cookies/delete.go Normal file
View File

@@ -0,0 +1,19 @@
package cookies
import (
"net/http"
"time"
)
// Tell the browser to delete the cookie matching the name provided
// Path must match the original set cookie for it to delete
func DeleteCookie(w http.ResponseWriter, name string, path string) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: "",
Path: path,
Expires: time.Unix(0, 0), // Expire in the past
MaxAge: -1, // Immediately expire
HttpOnly: true,
})
}

3
cookies/go.mod Normal file
View File

@@ -0,0 +1,3 @@
module git.haelnorr.com/h/golib/cookies
go 1.25.5

36
cookies/pagefrom.go Normal file
View File

@@ -0,0 +1,36 @@
package cookies
import (
"net/http"
"net/url"
)
// Check the value of "pagefrom" cookie, delete the cookie, and return the value
func CheckPageFrom(w http.ResponseWriter, r *http.Request) string {
pageFromCookie, err := r.Cookie("pagefrom")
if err != nil {
return "/"
}
pageFrom := pageFromCookie.Value
DeleteCookie(w, pageFromCookie.Name, pageFromCookie.Path)
return pageFrom
}
// Check the referer of the request, and if it matches the trustedHost, set
// the "pagefrom" cookie as the Path of the referer
func SetPageFrom(w http.ResponseWriter, r *http.Request, trustedHost string) {
referer := r.Referer()
parsedURL, err := url.Parse(referer)
if err != nil {
return
}
var pageFrom string
if parsedURL.Path == "" || parsedURL.Host != trustedHost {
pageFrom = "/"
} else if parsedURL.Path == "/login" || parsedURL.Path == "/register" {
return
} else {
pageFrom = parsedURL.Path
}
SetCookie(w, "pagefrom", "/", pageFrom, 0)
}

23
cookies/set.go Normal file
View File

@@ -0,0 +1,23 @@
package cookies
import (
"net/http"
)
// Set a cookie with the given name, path and value. maxAge directly relates
// to cookie MaxAge (0 for no max age, >0 for TTL in seconds)
func SetCookie(
w http.ResponseWriter,
name string,
path string,
value string,
maxAge int,
) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: value,
Path: path,
HttpOnly: true,
MaxAge: maxAge,
})
}

35
env/boolean.go vendored Normal file
View File

@@ -0,0 +1,35 @@
package env
import (
"os"
"strings"
)
// Get an environment variable as a boolean, specifying a default value if its
// not set or can't be parsed properly into a bool
func Bool(key string, defaultValue bool) bool {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
truthy := map[string]bool{
"true": true, "t": true, "yes": true, "y": true, "on": true, "1": true,
"enable": true, "enabled": true, "active": true, "affirmative": true,
}
falsy := map[string]bool{
"false": false, "f": false, "no": false, "n": false, "off": false, "0": false,
"disable": false, "disabled": false, "inactive": false, "negative": false,
}
normalized := strings.TrimSpace(strings.ToLower(val))
if val, ok := truthy[normalized]; ok {
return val
}
if val, ok := falsy[normalized]; ok {
return val
}
return defaultValue
}

23
env/duration.go vendored Normal file
View File

@@ -0,0 +1,23 @@
package env
import (
"os"
"strconv"
"time"
)
// Get an environment variable as a time.Duration, specifying a default value if its
// not set or can't be parsed properly
func Duration(key string, defaultValue time.Duration) time.Duration {
val, exists := os.LookupEnv(key)
if !exists {
return time.Duration(defaultValue)
}
intVal, err := strconv.Atoi(val)
if err != nil {
return time.Duration(defaultValue)
}
return time.Duration(intVal)
}

3
env/go.mod vendored Normal file
View File

@@ -0,0 +1,3 @@
module git.haelnorr.com/h/golib/env
go 1.25.5

37
env/int.go vendored Normal file
View File

@@ -0,0 +1,37 @@
package env
import (
"os"
"strconv"
)
// Get an environment variable as an int, specifying a default value if its
// not set or can't be parsed properly into an int
func Int(key string, defaultValue int) int {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
intVal, err := strconv.Atoi(val)
if err != nil {
return defaultValue
}
return intVal
}
// Get an environment variable as an int64, specifying a default value if its
// not set or can't be parsed properly into an int64
func Int64(key string, defaultValue int64) int64 {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
intVal, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return defaultValue
}
return intVal
}

14
env/string.go vendored Normal file
View File

@@ -0,0 +1,14 @@
package env
import (
"os"
)
// Get an environment variable, specifying a default value if its not set
func String(key string, defaultValue string) string {
val, exists := os.LookupEnv(key)
if !exists {
return defaultValue
}
return val
}

73
jwt/cookies.go Normal file
View File

@@ -0,0 +1,73 @@
package jwt
import (
"github.com/pkg/errors"
"net/http"
"time"
)
// Get the value of the access and refresh tokens
func GetTokenCookies(
r *http.Request,
) (acc string, ref string) {
accCookie, accErr := r.Cookie("access")
refCookie, refErr := r.Cookie("refresh")
var (
accStr string = ""
refStr string = ""
)
if accErr == nil {
accStr = accCookie.Value
}
if refErr == nil {
refStr = refCookie.Value
}
return accStr, refStr
}
// Set a token with the provided details
func setToken(
w http.ResponseWriter,
token string,
scope string,
exp int64,
rememberme bool,
useSSL bool,
) {
tokenCookie := &http.Cookie{
Name: scope,
Value: token,
Path: "/",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: useSSL,
}
if rememberme {
tokenCookie.Expires = time.Unix(exp, 0)
}
http.SetCookie(w, tokenCookie)
}
// Generate new tokens for the subject and set them as cookies
func SetTokenCookies(
w http.ResponseWriter,
r *http.Request,
tokenGen *TokenGenerator,
subject int,
fresh bool,
rememberMe bool,
useSSL bool,
) error {
at, atexp, err := tokenGen.NewAccess(subject, fresh, rememberMe)
if err != nil {
return errors.Wrap(err, "jwt.GenerateAccessToken")
}
rt, rtexp, err := tokenGen.NewRefresh(subject, rememberMe)
if err != nil {
return errors.Wrap(err, "jwt.GenerateRefreshToken")
}
// Don't set the cookies until we know no errors occured
setToken(w, at, "access", atexp, rememberMe, useSSL)
setToken(w, rt, "refresh", rtexp, rememberMe, useSSL)
return nil
}

View File

@@ -1,47 +1,31 @@
package jwt
import (
"context"
"database/sql"
"github.com/pkg/errors"
)
// Revoke a token by adding it to the database
func revoke(ctx context.Context, t Token) error {
db := t.getDB()
if db == nil {
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
if gen.dbConn == nil {
return errors.New("No DB provided, unable to use this function")
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return errors.Wrap(err, "db.BeginTx")
}
defer tx.Rollback()
jti := t.GetJTI()
exp := t.GetEXP()
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
_, err = tx.Exec(query, jti, exp)
_, err := tx.Exec(query, jti, exp)
if err != nil {
return errors.Wrap(err, "tx.Exec")
}
err = tx.Commit()
if err != nil {
return errors.Wrap(err, "tx.Commit")
}
return nil
}
// Check if a token has been revoked. Returns true if not revoked.
func checkNotRevoked(ctx context.Context, t Token) (bool, error) {
db := t.getDB()
if db == nil {
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
if gen.dbConn == nil {
return false, errors.New("No DB provided, unable to use this function")
}
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return false, errors.Wrap(err, "db.BeginTx")
}
defer tx.Rollback()
jti := t.GetJTI()
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
rows, err := tx.Query(query, jti)
@@ -50,9 +34,5 @@ func checkNotRevoked(ctx context.Context, t Token) (bool, error) {
}
defer rows.Close()
revoked := rows.Next()
err = tx.Commit()
if err != nil {
return false, errors.Wrap(err, "tx.Commit")
}
return !revoked, nil
}

View File

@@ -2,6 +2,7 @@ package jwt
import (
"context"
"database/sql"
"testing"
"time"
@@ -31,14 +32,15 @@ func TestNoDBFail(t *testing.T) {
token := AccessToken{
JTI: jti,
EXP: exp,
gen: &TokenGenerator{},
}
// Revoke should fail due to no DB
err := token.Revoke(context.Background())
err := token.Revoke(&sql.Tx{})
require.Error(t, err)
// CheckNotRevoked should fail
_, err = token.CheckNotRevoked(context.Background())
_, err = token.CheckNotRevoked(&sql.Tx{})
require.Error(t, err)
}
@@ -52,7 +54,7 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
token := AccessToken{
JTI: jti,
EXP: exp,
db: gen.dbConn,
gen: gen,
}
// Revoke expectations
@@ -60,21 +62,22 @@ func TestRevokeAndCheckNotRevoked(t *testing.T) {
mock.ExpectExec(`INSERT INTO jwtblacklist`).
WithArgs(jti, exp).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
err := token.Revoke(context.Background())
require.NoError(t, err)
// CheckNotRevoked expectations (now revoked)
mock.ExpectBegin()
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
WithArgs(jti).
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
mock.ExpectCommit()
valid, err := token.CheckNotRevoked(context.Background())
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
defer tx.Rollback()
require.NoError(t, err)
err = token.Revoke(tx)
require.NoError(t, err)
valid, err := token.CheckNotRevoked(tx)
require.NoError(t, err)
require.False(t, valid)
require.NoError(t, tx.Commit())
require.NoError(t, mock.ExpectationsWereMet())
}

View File

@@ -1,7 +1,6 @@
package jwt
import (
"context"
"database/sql"
"github.com/google/uuid"
@@ -11,8 +10,8 @@ type Token interface {
GetJTI() uuid.UUID
GetEXP() int64
GetScope() string
getDB() *sql.DB
Revoke(context.Context) error
Revoke(*sql.Tx) error
CheckNotRevoked(*sql.Tx) (bool, error)
}
// Access token
@@ -25,7 +24,7 @@ type AccessToken struct {
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Fresh int64 // Time freshness expiring at
Scope string // Should be "access"
db *sql.DB
gen *TokenGenerator
}
// Refresh token
@@ -37,7 +36,7 @@ type RefreshToken struct {
SUB int // Subject (user) ID
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
Scope string // Should be "refresh"
db *sql.DB
gen *TokenGenerator
}
func (a AccessToken) GetJTI() uuid.UUID {
@@ -58,21 +57,15 @@ func (a AccessToken) GetScope() string {
func (r RefreshToken) GetScope() string {
return r.Scope
}
func (a AccessToken) getDB() *sql.DB {
return a.db
func (a AccessToken) Revoke(tx *sql.Tx) error {
return a.gen.revoke(tx, a)
}
func (r RefreshToken) getDB() *sql.DB {
return r.db
func (r RefreshToken) Revoke(tx *sql.Tx) error {
return r.gen.revoke(tx, r)
}
func (a AccessToken) Revoke(ctx context.Context) error {
return revoke(ctx, a)
func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
return a.gen.checkNotRevoked(tx, a)
}
func (r RefreshToken) Revoke(ctx context.Context) error {
return revoke(ctx, r)
}
func (a AccessToken) CheckNotRevoked(ctx context.Context) (bool, error) {
return checkNotRevoked(ctx, a)
}
func (r RefreshToken) CheckNotRevoked(ctx context.Context) (bool, error) {
return checkNotRevoked(ctx, r)
func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
return r.gen.checkNotRevoked(tx, r)
}

View File

@@ -1,7 +1,8 @@
package jwt
import (
"context"
"database/sql"
"github.com/pkg/errors"
)
@@ -9,7 +10,7 @@ import (
// all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope.
func (gen *TokenGenerator) ValidateAccess(
ctx context.Context,
tx *sql.Tx,
tokenString string,
) (*AccessToken, error) {
if tokenString == "" {
@@ -64,10 +65,10 @@ func (gen *TokenGenerator) ValidateAccess(
Fresh: fresh,
JTI: jti,
Scope: scope,
db: gen.dbConn,
gen: gen,
}
valid, err := token.CheckNotRevoked(ctx)
valid, err := token.CheckNotRevoked(tx)
if err != nil && gen.dbConn != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked")
}
@@ -81,7 +82,7 @@ func (gen *TokenGenerator) ValidateAccess(
// all the claims, including checking if it is expired, has a valid issuer, and
// has the correct scope.
func (gen *TokenGenerator) ValidateRefresh(
ctx context.Context,
tx *sql.Tx,
tokenString string,
) (*RefreshToken, error) {
if tokenString == "" {
@@ -131,10 +132,10 @@ func (gen *TokenGenerator) ValidateRefresh(
SUB: subject,
JTI: jti,
Scope: scope,
db: gen.dbConn,
gen: gen,
}
valid, err := token.CheckNotRevoked(ctx)
valid, err := token.CheckNotRevoked(tx)
if err != nil && gen.dbConn != nil {
return nil, errors.Wrap(err, "token.CheckNotRevoked")
}

View File

@@ -2,6 +2,7 @@ package jwt
import (
"context"
"database/sql"
"testing"
"github.com/DATA-DOG/go-sqlmock"
@@ -43,10 +44,15 @@ func TestValidateAccess_Success(t *testing.T) {
// We don't know the JTI beforehand; match any arg
expectNotRevoked(mock, sqlmock.AnyArg())
token, err := gen.ValidateAccess(context.Background(), tokenStr)
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
require.NoError(t, err)
defer tx.Rollback()
token, err := gen.ValidateAccess(tx, tokenStr)
require.NoError(t, err)
require.Equal(t, 42, token.SUB)
require.Equal(t, "access", token.Scope)
tx.Commit()
}
func TestValidateAccess_NoDB(t *testing.T) {
@@ -55,7 +61,7 @@ func TestValidateAccess_NoDB(t *testing.T) {
tokenStr, _, err := gen.NewAccess(42, true, false)
require.NoError(t, err)
token, err := gen.ValidateAccess(context.Background(), tokenStr)
token, err := gen.ValidateAccess(&sql.Tx{}, tokenStr)
require.NoError(t, err)
require.Equal(t, 42, token.SUB)
require.Equal(t, "access", token.Scope)
@@ -70,10 +76,15 @@ func TestValidateRefresh_Success(t *testing.T) {
expectNotRevoked(mock, sqlmock.AnyArg())
token, err := gen.ValidateRefresh(context.Background(), tokenStr)
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
require.NoError(t, err)
defer tx.Rollback()
token, err := gen.ValidateRefresh(tx, tokenStr)
require.NoError(t, err)
require.Equal(t, 42, token.SUB)
require.Equal(t, "refresh", token.Scope)
tx.Commit()
}
func TestValidateRefresh_NoDB(t *testing.T) {
@@ -82,7 +93,7 @@ func TestValidateRefresh_NoDB(t *testing.T) {
tokenStr, _, err := gen.NewRefresh(42, false)
require.NoError(t, err)
token, err := gen.ValidateRefresh(context.Background(), tokenStr)
token, err := gen.ValidateRefresh(nil, tokenStr)
require.NoError(t, err)
require.Equal(t, 42, token.SUB)
require.Equal(t, "refresh", token.Scope)
@@ -91,7 +102,7 @@ func TestValidateRefresh_NoDB(t *testing.T) {
func TestValidateAccess_EmptyToken(t *testing.T) {
gen := newTestGenerator(t)
_, err := gen.ValidateAccess(context.Background(), "")
_, err := gen.ValidateAccess(nil, "")
require.Error(t, err)
}
@@ -102,6 +113,6 @@ func TestValidateRefresh_WrongScope(t *testing.T) {
tokenStr, _, err := gen.NewAccess(1, false, false)
require.NoError(t, err)
_, err = gen.ValidateRefresh(context.Background(), tokenStr)
_, err = gen.ValidateRefresh(nil, tokenStr)
require.Error(t, err)
}

32
tmdb/config.go Normal file
View File

@@ -0,0 +1,32 @@
package tmdb
import (
"encoding/json"
"github.com/pkg/errors"
)
type Config struct {
Image Image `json:"images"`
}
type Image struct {
BaseURL string `json:"base_url"`
SecureBaseURL string `json:"secure_base_url"`
BackdropSizes []string `json:"backdrop_sizes"`
LogoSizes []string `json:"logo_sizes"`
PosterSizes []string `json:"poster_sizes"`
ProfileSizes []string `json:"profile_sizes"`
StillSizes []string `json:"still_sizes"`
}
func GetConfig(token string) (*Config, error) {
url := "https://api.themoviedb.org/3/configuration"
data, err := tmdbGet(url, token)
if err != nil {
return nil, errors.Wrap(err, "tmdbGet")
}
config := Config{}
json.Unmarshal(data, &config)
return &config, nil
}

54
tmdb/credits.go Normal file
View File

@@ -0,0 +1,54 @@
package tmdb
import (
"encoding/json"
"fmt"
"github.com/pkg/errors"
)
type Credits struct {
ID int32 `json:"id"`
Cast []Cast `json:"cast"`
Crew []Crew `json:"crew"`
}
type Cast struct {
Adult bool `json:"adult"`
Gender int `json:"gender"`
ID int32 `json:"id"`
KnownFor string `json:"known_for_department"`
Name string `json:"name"`
OriginalName string `json:"original_name"`
Popularity int `json:"popularity"`
Profile string `json:"profile_path"`
CastID int32 `json:"cast_id"`
Character string `json:"character"`
CreditID string `json:"credit_id"`
Order int `json:"order"`
}
type Crew struct {
Adult bool `json:"adult"`
Gender int `json:"gender"`
ID int32 `json:"id"`
KnownFor string `json:"known_for_department"`
Name string `json:"name"`
OriginalName string `json:"original_name"`
Popularity int `json:"popularity"`
Profile string `json:"profile_path"`
CreditID string `json:"credit_id"`
Department string `json:"department"`
Job string `json:"job"`
}
func GetCredits(movieid int32, token string) (*Credits, error) {
url := fmt.Sprintf("https://api.themoviedb.org/3/movie/%v/credits?language=en-US", movieid)
data, err := tmdbGet(url, token)
if err != nil {
return nil, errors.Wrap(err, "tmdbGet")
}
credits := Credits{}
json.Unmarshal(data, &credits)
return &credits, nil
}

41
tmdb/crew_functions.go Normal file
View File

@@ -0,0 +1,41 @@
package tmdb
import "sort"
type BilledCrew struct {
Name string
Roles []string
}
func (credits *Credits) BilledCrew() []BilledCrew {
crewmap := make(map[string][]string)
billedcrew := []BilledCrew{}
for _, crew := range credits.Crew {
if crew.Job == "Director" ||
crew.Job == "Screenplay" ||
crew.Job == "Writer" ||
crew.Job == "Novel" ||
crew.Job == "Story" {
crewmap[crew.Name] = append(crewmap[crew.Name], crew.Job)
}
}
for name, jobs := range crewmap {
billedcrew = append(billedcrew, BilledCrew{Name: name, Roles: jobs})
}
for i := range billedcrew {
sort.Strings(billedcrew[i].Roles)
}
sort.Slice(billedcrew, func(i, j int) bool {
return billedcrew[i].Roles[0] < billedcrew[j].Roles[0]
})
return billedcrew
}
func (billedcrew *BilledCrew) FRoles() string {
jobs := ""
for _, job := range billedcrew.Roles {
jobs += job + ", "
}
return jobs[:len(jobs)-2]
}

5
tmdb/go.mod Normal file
View File

@@ -0,0 +1,5 @@
module git.haelnorr.com/h/golib/tmdb
go 1.25.5
require github.com/pkg/errors v0.9.1

2
tmdb/go.sum Normal file
View File

@@ -0,0 +1,2 @@
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

45
tmdb/movie.go Normal file
View File

@@ -0,0 +1,45 @@
package tmdb
import (
"encoding/json"
"fmt"
"github.com/pkg/errors"
)
type Movie struct {
Adult bool `json:"adult"`
Backdrop string `json:"backdrop_path"`
Collection string `json:"belongs_to_collection"`
Budget int `json:"budget"`
Genres []Genre `json:"genres"`
Homepage string `json:"homepage"`
ID int32 `json:"id"`
IMDbID string `json:"imdb_id"`
OriginalLanguage string `json:"original_language"`
OriginalTitle string `json:"original_title"`
Overview string `json:"overview"`
Popularity float32 `json:"popularity"`
Poster string `json:"poster_path"`
ProductionCompanies []ProductionCompany `json:"production_companies"`
ProductionCountries []ProductionCountry `json:"production_countries"`
ReleaseDate string `json:"release_date"`
Revenue int `json:"revenue"`
Runtime int `json:"runtime"`
SpokenLanguages []SpokenLanguage `json:"spoken_languages"`
Status string `json:"status"`
Tagline string `json:"tagline"`
Title string `json:"title"`
Video bool `json:"video"`
}
func GetMovie(id int32, token string) (*Movie, error) {
url := fmt.Sprintf("https://api.themoviedb.org/3/movie/%v?language=en-US", id)
data, err := tmdbGet(url, token)
if err != nil {
return nil, errors.Wrap(err, "tmdbGet")
}
movie := Movie{}
json.Unmarshal(data, &movie)
return &movie, nil
}

42
tmdb/movie_functions.go Normal file
View File

@@ -0,0 +1,42 @@
package tmdb
import (
"fmt"
"net/url"
"path"
)
func (movie *Movie) FRuntime() string {
hours := movie.Runtime / 60
mins := movie.Runtime % 60
return fmt.Sprintf("%dh %02dm", hours, mins)
}
func (movie *Movie) GetPoster(image *Image, size string) string {
base, err := url.Parse(image.SecureBaseURL)
if err != nil {
return ""
}
fullPath := path.Join(base.Path, size, movie.Poster)
base.Path = fullPath
return base.String()
}
func (movie *Movie) ReleaseYear() string {
if movie.ReleaseDate == "" {
return ""
} else {
return "(" + movie.ReleaseDate[:4] + ")"
}
}
func (movie *Movie) FGenres() string {
genres := ""
for _, genre := range movie.Genres {
genres += genre.Name + ", "
}
if len(genres) > 2 {
return genres[:len(genres)-2]
}
return genres
}

28
tmdb/request.go Normal file
View File

@@ -0,0 +1,28 @@
package tmdb
import (
"fmt"
"io"
"net/http"
"github.com/pkg/errors"
)
func tmdbGet(url string, token string) ([]byte, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, errors.Wrap(err, "http.NewRequest")
}
req.Header.Add("accept", "application/json")
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, errors.Wrap(err, "http.DefaultClient.Do")
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return nil, errors.Wrap(err, "io.ReadAll")
}
return body, nil
}

79
tmdb/search.go Normal file
View File

@@ -0,0 +1,79 @@
package tmdb
import (
"encoding/json"
"fmt"
"net/url"
"path"
"github.com/pkg/errors"
)
type Result struct {
Page int `json:"page"`
TotalPages int `json:"total_pages"`
TotalResults int `json:"total_results"`
}
type ResultMovies struct {
Result
Results []ResultMovie `json:"results"`
}
type ResultMovie struct {
Adult bool `json:"adult"`
BackdropPath string `json:"backdrop_path"`
GenreIDs []int `json:"genre_ids"`
ID int32 `json:"id"`
OriginalLanguage string `json:"original_language"`
OriginalTitle string `json:"original_title"`
Overview string `json:"overview"`
Popularity int `json:"popularity"`
PosterPath string `json:"poster_path"`
ReleaseDate string `json:"release_date"`
Title string `json:"title"`
Video bool `json:"video"`
VoteAverage int `json:"vote_average"`
VoteCount int `json:"vote_count"`
}
func (movie *ResultMovie) GetPoster(image *Image, size string) string {
base, err := url.Parse(image.SecureBaseURL)
if err != nil {
return ""
}
fullPath := path.Join(base.Path, size, movie.PosterPath)
base.Path = fullPath
return base.String()
}
func (movie *ResultMovie) ReleaseYear() string {
if movie.ReleaseDate == "" {
return ""
} else {
return "(" + movie.ReleaseDate[:4] + ")"
}
}
// TODO: genres list https://developer.themoviedb.org/reference/genre-movie-list
// func (movie *ResultMovie) FGenres() string {
// genres := ""
// for _, genre := range movie.Genres {
// genres += genre.Name + ", "
// }
// return genres[:len(genres)-2]
// }
func SearchMovies(token string, query string, adult bool, page int) (*ResultMovies, error) {
url := "https://api.themoviedb.org/3/search/movie" +
fmt.Sprintf("?query=%s", url.QueryEscape(query)) +
fmt.Sprintf("&include_adult=%t", adult) +
fmt.Sprintf("&page=%v", page) +
"&language=en-US"
response, err := tmdbGet(url, token)
if err != nil {
return nil, errors.Wrap(err, "tmdbGet")
}
var results ResultMovies
json.Unmarshal(response, &results)
return &results, nil
}

24
tmdb/structs.go Normal file
View File

@@ -0,0 +1,24 @@
package tmdb
type Genre struct {
ID int `json:"id"`
Name string `json:"name"`
}
type ProductionCompany struct {
ID int `json:"id"`
Logo string `json:"logo_path"`
Name string `json:"name"`
OriginCountry string `json:"origin_country"`
}
type ProductionCountry struct {
ISO_3166_1 string `json:"iso_3166_1"`
Name string `json:"name"`
}
type SpokenLanguage struct {
EnglishName string `json:"english_name"`
ISO_639_1 string `json:"iso_639_1"`
Name string `json:"name"`
}