Compare commits
3 Commits
hlog/v0.9.
...
jwt/v0.9.1
| Author | SHA1 | Date | |
|---|---|---|---|
| 05aad5f11b | |||
| c4574e32c7 | |||
| c466cd3163 |
62
jwt/generator.go
Normal file
62
jwt/generator.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenGenerator struct {
|
||||||
|
accessExpireAfter int64 // Access Token expiry time in minutes
|
||||||
|
refreshExpireAfter int64 // Refresh Token expiry time in minutes
|
||||||
|
freshExpireAfter int64 // Token freshness expiry time in minutes
|
||||||
|
trustedHost string // Trusted hostname to use for the tokens
|
||||||
|
secretKey string // Secret key to use for token hashing
|
||||||
|
dbConn *sql.DB // Database handle for token blacklisting
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateGenerator creates and returns a new TokenGenerator using the provided configuration.
|
||||||
|
// All expiry times should be provided in minutes.
|
||||||
|
// trustedHost and secretKey strings must be provided.
|
||||||
|
// dbConn can be nil, but doing this will disable token revocation
|
||||||
|
func CreateGenerator(
|
||||||
|
accessExpireAfter int64,
|
||||||
|
refreshExpireAfter int64,
|
||||||
|
freshExpireAfter int64,
|
||||||
|
trustedHost string,
|
||||||
|
secretKey string,
|
||||||
|
dbConn *sql.DB,
|
||||||
|
) (gen *TokenGenerator, err error) {
|
||||||
|
if accessExpireAfter <= 0 {
|
||||||
|
return nil, errors.New("accessExpireAfter must be greater than 0")
|
||||||
|
}
|
||||||
|
if refreshExpireAfter <= 0 {
|
||||||
|
return nil, errors.New("refreshExpireAfter must be greater than 0")
|
||||||
|
}
|
||||||
|
if freshExpireAfter <= 0 {
|
||||||
|
return nil, errors.New("freshExpireAfter must be greater than 0")
|
||||||
|
}
|
||||||
|
if trustedHost == "" {
|
||||||
|
return nil, errors.New("trustedHost cannot be an empty string")
|
||||||
|
}
|
||||||
|
if secretKey == "" {
|
||||||
|
return nil, errors.New("secretKey cannot be an empty string")
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbConn != nil {
|
||||||
|
err := dbConn.Ping()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("Failed to ping database")
|
||||||
|
}
|
||||||
|
// TODO: check if jwtblacklist table exists
|
||||||
|
// TODO: create jwtblacklist table if not existing
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TokenGenerator{
|
||||||
|
accessExpireAfter: accessExpireAfter,
|
||||||
|
refreshExpireAfter: refreshExpireAfter,
|
||||||
|
freshExpireAfter: freshExpireAfter,
|
||||||
|
trustedHost: trustedHost,
|
||||||
|
secretKey: secretKey,
|
||||||
|
dbConn: dbConn,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
90
jwt/generator_test.go
Normal file
90
jwt/generator_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateGenerator_Success_NoDB(t *testing.T) {
|
||||||
|
gen, err := CreateGenerator(
|
||||||
|
15,
|
||||||
|
60,
|
||||||
|
5,
|
||||||
|
"example.com",
|
||||||
|
"secret",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, gen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateGenerator_Success_WithDB(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
gen, err := CreateGenerator(
|
||||||
|
15,
|
||||||
|
60,
|
||||||
|
5,
|
||||||
|
"example.com",
|
||||||
|
"secret",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, gen)
|
||||||
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateGenerator_InvalidInputs(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fn func() error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"access expiry <= 0",
|
||||||
|
func() error {
|
||||||
|
_, err := CreateGenerator(0, 1, 1, "h", "s", nil)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"refresh expiry <= 0",
|
||||||
|
func() error {
|
||||||
|
_, err := CreateGenerator(1, 0, 1, "h", "s", nil)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"fresh expiry <= 0",
|
||||||
|
func() error {
|
||||||
|
_, err := CreateGenerator(1, 1, 0, "h", "s", nil)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"empty trustedHost",
|
||||||
|
func() error {
|
||||||
|
_, err := CreateGenerator(1, 1, 1, "", "s", nil)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"empty secretKey",
|
||||||
|
func() error {
|
||||||
|
_, err := CreateGenerator(1, 1, 1, "h", "", nil)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
require.Error(t, tt.fn())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
17
jwt/go.mod
Normal file
17
jwt/go.mod
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
module git.haelnorr.com/h/golib/jwt
|
||||||
|
|
||||||
|
go 1.25.5
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/pkg/errors v0.9.1
|
||||||
|
github.com/stretchr/testify v1.11.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
||||||
19
jwt/go.sum
Normal file
19
jwt/go.sum
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||||
|
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
38
jwt/revoke.go
Normal file
38
jwt/revoke.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Revoke a token by adding it to the database
|
||||||
|
func (gen *TokenGenerator) revoke(tx *sql.Tx, t Token) error {
|
||||||
|
if gen.dbConn == nil {
|
||||||
|
return errors.New("No DB provided, unable to use this function")
|
||||||
|
}
|
||||||
|
jti := t.GetJTI()
|
||||||
|
exp := t.GetEXP()
|
||||||
|
query := `INSERT INTO jwtblacklist (jti, exp) VALUES (?, ?)`
|
||||||
|
_, err := tx.Exec(query, jti, exp)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "tx.Exec")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if a token has been revoked. Returns true if not revoked.
|
||||||
|
func (gen *TokenGenerator) checkNotRevoked(tx *sql.Tx, t Token) (bool, error) {
|
||||||
|
if gen.dbConn == nil {
|
||||||
|
return false, errors.New("No DB provided, unable to use this function")
|
||||||
|
}
|
||||||
|
jti := t.GetJTI()
|
||||||
|
query := `SELECT 1 FROM jwtblacklist WHERE jti = ? LIMIT 1`
|
||||||
|
rows, err := tx.Query(query, jti)
|
||||||
|
if err != nil {
|
||||||
|
return false, errors.Wrap(err, "tx.Query")
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
revoked := rows.Next()
|
||||||
|
return !revoked, nil
|
||||||
|
}
|
||||||
83
jwt/revoke_test.go
Normal file
83
jwt/revoke_test.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newGeneratorWithNoDB(t *testing.T) *TokenGenerator {
|
||||||
|
gen, err := CreateGenerator(
|
||||||
|
15,
|
||||||
|
60,
|
||||||
|
5,
|
||||||
|
"example.com",
|
||||||
|
"supersecret",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return gen
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoDBFail(t *testing.T) {
|
||||||
|
jti := uuid.New()
|
||||||
|
exp := time.Now().Add(time.Hour).Unix()
|
||||||
|
|
||||||
|
token := AccessToken{
|
||||||
|
JTI: jti,
|
||||||
|
EXP: exp,
|
||||||
|
gen: &TokenGenerator{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke should fail due to no DB
|
||||||
|
err := token.Revoke(&sql.Tx{})
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// CheckNotRevoked should fail
|
||||||
|
_, err = token.CheckNotRevoked(&sql.Tx{})
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRevokeAndCheckNotRevoked(t *testing.T) {
|
||||||
|
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
jti := uuid.New()
|
||||||
|
exp := time.Now().Add(time.Hour).Unix()
|
||||||
|
|
||||||
|
token := AccessToken{
|
||||||
|
JTI: jti,
|
||||||
|
EXP: exp,
|
||||||
|
gen: gen,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke expectations
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectExec(`INSERT INTO jwtblacklist`).
|
||||||
|
WithArgs(jti, exp).
|
||||||
|
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||||
|
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
|
||||||
|
WithArgs(jti).
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"1"}).AddRow(1))
|
||||||
|
mock.ExpectCommit()
|
||||||
|
|
||||||
|
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||||
|
defer tx.Rollback()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = token.Revoke(tx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
valid, err := token.CheckNotRevoked(tx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, valid)
|
||||||
|
|
||||||
|
require.NoError(t, tx.Commit())
|
||||||
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
|
}
|
||||||
79
jwt/tokengen.go
Normal file
79
jwt/tokengen.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Generates an access token for the provided subject
|
||||||
|
func (gen *TokenGenerator) NewAccess(
|
||||||
|
subjectID int,
|
||||||
|
fresh bool,
|
||||||
|
rememberMe bool,
|
||||||
|
) (tokenString string, expiresIn int64, err error) {
|
||||||
|
issuedAt := time.Now().Unix()
|
||||||
|
expiresAt := issuedAt + (gen.accessExpireAfter * 60)
|
||||||
|
var freshExpiresAt int64
|
||||||
|
if fresh {
|
||||||
|
freshExpiresAt = issuedAt + (gen.freshExpireAfter * 60)
|
||||||
|
} else {
|
||||||
|
freshExpiresAt = issuedAt
|
||||||
|
}
|
||||||
|
var ttl string
|
||||||
|
if rememberMe {
|
||||||
|
ttl = "exp"
|
||||||
|
} else {
|
||||||
|
ttl = "session"
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||||
|
jwt.MapClaims{
|
||||||
|
"iss": gen.trustedHost,
|
||||||
|
"scope": "access",
|
||||||
|
"ttl": ttl,
|
||||||
|
"jti": uuid.New(),
|
||||||
|
"iat": issuedAt,
|
||||||
|
"exp": expiresAt,
|
||||||
|
"fresh": freshExpiresAt,
|
||||||
|
"sub": subjectID,
|
||||||
|
})
|
||||||
|
|
||||||
|
signedToken, err := token.SignedString([]byte(gen.secretKey))
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, errors.Wrap(err, "token.SignedString")
|
||||||
|
}
|
||||||
|
return signedToken, expiresAt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generates a refresh token for the provided user
|
||||||
|
func (gen *TokenGenerator) NewRefresh(
|
||||||
|
subjectID int,
|
||||||
|
rememberMe bool,
|
||||||
|
) (tokenStr string, exp int64, err error) {
|
||||||
|
issuedAt := time.Now().Unix()
|
||||||
|
expiresAt := issuedAt + (gen.refreshExpireAfter * 60)
|
||||||
|
var ttl string
|
||||||
|
if rememberMe {
|
||||||
|
ttl = "exp"
|
||||||
|
} else {
|
||||||
|
ttl = "session"
|
||||||
|
}
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256,
|
||||||
|
jwt.MapClaims{
|
||||||
|
"iss": gen.trustedHost,
|
||||||
|
"scope": "refresh",
|
||||||
|
"ttl": ttl,
|
||||||
|
"jti": uuid.New(),
|
||||||
|
"iat": issuedAt,
|
||||||
|
"exp": expiresAt,
|
||||||
|
"sub": subjectID,
|
||||||
|
})
|
||||||
|
|
||||||
|
signedToken, err := token.SignedString([]byte(gen.secretKey))
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, errors.Wrap(err, "token.SignedString")
|
||||||
|
}
|
||||||
|
return signedToken, expiresAt, nil
|
||||||
|
}
|
||||||
38
jwt/tokengen_test.go
Normal file
38
jwt/tokengen_test.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestGenerator(t *testing.T) *TokenGenerator {
|
||||||
|
gen, err := CreateGenerator(
|
||||||
|
15,
|
||||||
|
60,
|
||||||
|
5,
|
||||||
|
"example.com",
|
||||||
|
"supersecret",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return gen
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewAccessToken(t *testing.T) {
|
||||||
|
gen := newTestGenerator(t)
|
||||||
|
|
||||||
|
tokenStr, exp, err := gen.NewAccess(123, true, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, tokenStr)
|
||||||
|
require.Greater(t, exp, int64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewRefreshToken(t *testing.T) {
|
||||||
|
gen := newTestGenerator(t)
|
||||||
|
|
||||||
|
tokenStr, exp, err := gen.NewRefresh(123, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, tokenStr)
|
||||||
|
require.Greater(t, exp, int64(0))
|
||||||
|
}
|
||||||
71
jwt/tokens.go
Normal file
71
jwt/tokens.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Token interface {
|
||||||
|
GetJTI() uuid.UUID
|
||||||
|
GetEXP() int64
|
||||||
|
GetScope() string
|
||||||
|
Revoke(*sql.Tx) error
|
||||||
|
CheckNotRevoked(*sql.Tx) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Access token
|
||||||
|
type AccessToken struct {
|
||||||
|
ISS string // Issuer, generally TrustedHost
|
||||||
|
IAT int64 // Time issued at
|
||||||
|
EXP int64 // Time expiring at
|
||||||
|
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
||||||
|
SUB int // Subject (user) ID
|
||||||
|
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||||
|
Fresh int64 // Time freshness expiring at
|
||||||
|
Scope string // Should be "access"
|
||||||
|
gen *TokenGenerator
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh token
|
||||||
|
type RefreshToken struct {
|
||||||
|
ISS string // Issuer, generally TrustedHost
|
||||||
|
IAT int64 // Time issued at
|
||||||
|
EXP int64 // Time expiring at
|
||||||
|
TTL string // Time-to-live: "session" or "exp". Used with 'remember me'
|
||||||
|
SUB int // Subject (user) ID
|
||||||
|
JTI uuid.UUID // UUID-4 used for identifying blacklisted tokens
|
||||||
|
Scope string // Should be "refresh"
|
||||||
|
gen *TokenGenerator
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a AccessToken) GetJTI() uuid.UUID {
|
||||||
|
return a.JTI
|
||||||
|
}
|
||||||
|
func (r RefreshToken) GetJTI() uuid.UUID {
|
||||||
|
return r.JTI
|
||||||
|
}
|
||||||
|
func (a AccessToken) GetEXP() int64 {
|
||||||
|
return a.EXP
|
||||||
|
}
|
||||||
|
func (r RefreshToken) GetEXP() int64 {
|
||||||
|
return r.EXP
|
||||||
|
}
|
||||||
|
func (a AccessToken) GetScope() string {
|
||||||
|
return a.Scope
|
||||||
|
}
|
||||||
|
func (r RefreshToken) GetScope() string {
|
||||||
|
return r.Scope
|
||||||
|
}
|
||||||
|
func (a AccessToken) Revoke(tx *sql.Tx) error {
|
||||||
|
return a.gen.revoke(tx, a)
|
||||||
|
}
|
||||||
|
func (r RefreshToken) Revoke(tx *sql.Tx) error {
|
||||||
|
return r.gen.revoke(tx, r)
|
||||||
|
}
|
||||||
|
func (a AccessToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
||||||
|
return a.gen.checkNotRevoked(tx, a)
|
||||||
|
}
|
||||||
|
func (r RefreshToken) CheckNotRevoked(tx *sql.Tx) (bool, error) {
|
||||||
|
return r.gen.checkNotRevoked(tx, r)
|
||||||
|
}
|
||||||
123
jwt/util.go
Normal file
123
jwt/util.go
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Parse a token, validating its signing sigature and returning the claims
|
||||||
|
func parseToken(secretKey string, tokenString string) (jwt.MapClaims, error) {
|
||||||
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(secretKey), nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "jwt.Parse")
|
||||||
|
}
|
||||||
|
// Token decoded, parse the claims
|
||||||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("Failed to parse claims")
|
||||||
|
}
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if a token is expired. Returns the expiry if not expired
|
||||||
|
func checkTokenExpired(expiry interface{}) (int64, error) {
|
||||||
|
// Coerce the expiry to a float64 to avoid scientific notation
|
||||||
|
expFloat, ok := expiry.(float64)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("Missing or invalid 'exp' claim")
|
||||||
|
}
|
||||||
|
// Convert to the int64 time we expect :)
|
||||||
|
expiryTime := int64(expFloat)
|
||||||
|
|
||||||
|
// Check if its expired
|
||||||
|
isExpired := time.Now().After(time.Unix(expiryTime, 0))
|
||||||
|
if isExpired {
|
||||||
|
return 0, errors.New("Token has expired")
|
||||||
|
}
|
||||||
|
return expiryTime, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if a token has a valid issuer. Returns the issuer if valid
|
||||||
|
func checkTokenIssuer(trustedHost string, issuer interface{}) (string, error) {
|
||||||
|
issuerVal, ok := issuer.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("Missing or invalid 'iss' claim")
|
||||||
|
}
|
||||||
|
if issuer != trustedHost {
|
||||||
|
return "", errors.New("Issuer does not matched trusted host")
|
||||||
|
}
|
||||||
|
return issuerVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the scope matches the expected scope. Returns scope if true
|
||||||
|
func getTokenScope(scope interface{}) (string, error) {
|
||||||
|
scopeStr, ok := scope.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("Missing or invalid 'scope' claim")
|
||||||
|
}
|
||||||
|
return scopeStr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the TTL of the token, either "session" or "exp"
|
||||||
|
func getTokenTTL(ttl interface{}) (string, error) {
|
||||||
|
ttlStr, ok := ttl.(string)
|
||||||
|
if !ok {
|
||||||
|
return "", errors.New("Missing or invalid 'ttl' claim")
|
||||||
|
}
|
||||||
|
if ttlStr != "exp" && ttlStr != "session" {
|
||||||
|
return "", errors.New("TTL value is not recognised")
|
||||||
|
}
|
||||||
|
return ttlStr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the time the token was issued at
|
||||||
|
func getIssuedTime(issued interface{}) (int64, error) {
|
||||||
|
// Same float64 -> int64 trick as expiry
|
||||||
|
issuedFloat, ok := issued.(float64)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("Missing or invalid 'iat' claim")
|
||||||
|
}
|
||||||
|
issuedAt := int64(issuedFloat)
|
||||||
|
return issuedAt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the freshness expiry timestamp
|
||||||
|
func getFreshTime(fresh interface{}) (int64, error) {
|
||||||
|
freshUntil, ok := fresh.(float64)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("Missing or invalid 'fresh' claim")
|
||||||
|
}
|
||||||
|
return int64(freshUntil), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the subject of the token
|
||||||
|
func getTokenSubject(sub interface{}) (int, error) {
|
||||||
|
subject, ok := sub.(float64)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("Missing or invalid 'sub' claim")
|
||||||
|
}
|
||||||
|
return int(subject), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the JTI of the token
|
||||||
|
func getTokenJTI(jti interface{}) (uuid.UUID, error) {
|
||||||
|
jtiStr, ok := jti.(string)
|
||||||
|
if !ok {
|
||||||
|
return uuid.UUID{}, errors.New("Missing or invalid 'jti' claim")
|
||||||
|
}
|
||||||
|
jtiUUID, err := uuid.Parse(jtiStr)
|
||||||
|
if err != nil {
|
||||||
|
return uuid.UUID{}, errors.New("JTI is not a valid UUID")
|
||||||
|
}
|
||||||
|
return jtiUUID, nil
|
||||||
|
}
|
||||||
146
jwt/validate.go
Normal file
146
jwt/validate.go
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Parse an access token and return a struct with all the claims. Does validation on
|
||||||
|
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||||
|
// has the correct scope.
|
||||||
|
func (gen *TokenGenerator) ValidateAccess(
|
||||||
|
tx *sql.Tx,
|
||||||
|
tokenString string,
|
||||||
|
) (*AccessToken, error) {
|
||||||
|
if tokenString == "" {
|
||||||
|
return nil, errors.New("Access token string not provided")
|
||||||
|
}
|
||||||
|
claims, err := parseToken(gen.secretKey, tokenString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "parseToken")
|
||||||
|
}
|
||||||
|
expiry, err := checkTokenExpired(claims["exp"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "checkTokenExpired")
|
||||||
|
}
|
||||||
|
issuer, err := checkTokenIssuer(gen.trustedHost, claims["iss"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "checkTokenIssuer")
|
||||||
|
}
|
||||||
|
ttl, err := getTokenTTL(claims["ttl"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenTTL")
|
||||||
|
}
|
||||||
|
scope, err := getTokenScope(claims["scope"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenScope")
|
||||||
|
}
|
||||||
|
if scope != "access" {
|
||||||
|
return nil, errors.New("Token is not an Access token")
|
||||||
|
}
|
||||||
|
issuedAt, err := getIssuedTime(claims["iat"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getIssuedTime")
|
||||||
|
}
|
||||||
|
subject, err := getTokenSubject(claims["sub"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenSubject")
|
||||||
|
}
|
||||||
|
fresh, err := getFreshTime(claims["fresh"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getFreshTime")
|
||||||
|
}
|
||||||
|
jti, err := getTokenJTI(claims["jti"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenJTI")
|
||||||
|
}
|
||||||
|
|
||||||
|
token := &AccessToken{
|
||||||
|
ISS: issuer,
|
||||||
|
TTL: ttl,
|
||||||
|
EXP: expiry,
|
||||||
|
IAT: issuedAt,
|
||||||
|
SUB: subject,
|
||||||
|
Fresh: fresh,
|
||||||
|
JTI: jti,
|
||||||
|
Scope: scope,
|
||||||
|
gen: gen,
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, err := token.CheckNotRevoked(tx)
|
||||||
|
if err != nil && gen.dbConn != nil {
|
||||||
|
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||||
|
}
|
||||||
|
if !valid && gen.dbConn != nil {
|
||||||
|
return nil, errors.New("Token has been revoked")
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse a refresh token and return a struct with all the claims. Does validation on
|
||||||
|
// all the claims, including checking if it is expired, has a valid issuer, and
|
||||||
|
// has the correct scope.
|
||||||
|
func (gen *TokenGenerator) ValidateRefresh(
|
||||||
|
tx *sql.Tx,
|
||||||
|
tokenString string,
|
||||||
|
) (*RefreshToken, error) {
|
||||||
|
if tokenString == "" {
|
||||||
|
return nil, errors.New("Refresh token string not provided")
|
||||||
|
}
|
||||||
|
claims, err := parseToken(gen.secretKey, tokenString)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "parseToken")
|
||||||
|
}
|
||||||
|
expiry, err := checkTokenExpired(claims["exp"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "checkTokenExpired")
|
||||||
|
}
|
||||||
|
issuer, err := checkTokenIssuer(gen.trustedHost, claims["iss"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "checkTokenIssuer")
|
||||||
|
}
|
||||||
|
ttl, err := getTokenTTL(claims["ttl"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenTTL")
|
||||||
|
}
|
||||||
|
scope, err := getTokenScope(claims["scope"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenScope")
|
||||||
|
}
|
||||||
|
if scope != "refresh" {
|
||||||
|
return nil, errors.New("Token is not an Refresh token")
|
||||||
|
}
|
||||||
|
issuedAt, err := getIssuedTime(claims["iat"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getIssuedTime")
|
||||||
|
}
|
||||||
|
subject, err := getTokenSubject(claims["sub"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenSubject")
|
||||||
|
}
|
||||||
|
jti, err := getTokenJTI(claims["jti"])
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "getTokenJTI")
|
||||||
|
}
|
||||||
|
|
||||||
|
token := &RefreshToken{
|
||||||
|
ISS: issuer,
|
||||||
|
TTL: ttl,
|
||||||
|
EXP: expiry,
|
||||||
|
IAT: issuedAt,
|
||||||
|
SUB: subject,
|
||||||
|
JTI: jti,
|
||||||
|
Scope: scope,
|
||||||
|
gen: gen,
|
||||||
|
}
|
||||||
|
|
||||||
|
valid, err := token.CheckNotRevoked(tx)
|
||||||
|
if err != nil && gen.dbConn != nil {
|
||||||
|
return nil, errors.Wrap(err, "token.CheckNotRevoked")
|
||||||
|
}
|
||||||
|
if !valid && gen.dbConn != nil {
|
||||||
|
return nil, errors.New("Token has been revoked")
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
118
jwt/validate_test.go
Normal file
118
jwt/validate_test.go
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
package jwt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newGeneratorWithMockDB(t *testing.T) (*TokenGenerator, sqlmock.Sqlmock, func()) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
gen, err := CreateGenerator(
|
||||||
|
15,
|
||||||
|
60,
|
||||||
|
5,
|
||||||
|
"example.com",
|
||||||
|
"supersecret",
|
||||||
|
db,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return gen, mock, func() { db.Close() }
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectNotRevoked(mock sqlmock.Sqlmock, jti any) {
|
||||||
|
mock.ExpectBegin()
|
||||||
|
mock.ExpectQuery(`SELECT 1 FROM jwtblacklist`).
|
||||||
|
WithArgs(jti).
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{}))
|
||||||
|
mock.ExpectCommit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateAccess_Success(t *testing.T) {
|
||||||
|
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
tokenStr, _, err := gen.NewAccess(42, true, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// We don't know the JTI beforehand; match any arg
|
||||||
|
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||||
|
|
||||||
|
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
token, err := gen.ValidateAccess(tx, tokenStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 42, token.SUB)
|
||||||
|
require.Equal(t, "access", token.Scope)
|
||||||
|
tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateAccess_NoDB(t *testing.T) {
|
||||||
|
gen := newGeneratorWithNoDB(t)
|
||||||
|
|
||||||
|
tokenStr, _, err := gen.NewAccess(42, true, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, err := gen.ValidateAccess(&sql.Tx{}, tokenStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 42, token.SUB)
|
||||||
|
require.Equal(t, "access", token.Scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRefresh_Success(t *testing.T) {
|
||||||
|
gen, mock, cleanup := newGeneratorWithMockDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
tokenStr, _, err := gen.NewRefresh(42, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
expectNotRevoked(mock, sqlmock.AnyArg())
|
||||||
|
|
||||||
|
tx, err := gen.dbConn.BeginTx(context.Background(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
token, err := gen.ValidateRefresh(tx, tokenStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 42, token.SUB)
|
||||||
|
require.Equal(t, "refresh", token.Scope)
|
||||||
|
tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRefresh_NoDB(t *testing.T) {
|
||||||
|
gen := newGeneratorWithNoDB(t)
|
||||||
|
|
||||||
|
tokenStr, _, err := gen.NewRefresh(42, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, err := gen.ValidateRefresh(nil, tokenStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 42, token.SUB)
|
||||||
|
require.Equal(t, "refresh", token.Scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateAccess_EmptyToken(t *testing.T) {
|
||||||
|
gen := newTestGenerator(t)
|
||||||
|
|
||||||
|
_, err := gen.ValidateAccess(nil, "")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRefresh_WrongScope(t *testing.T) {
|
||||||
|
gen := newTestGenerator(t)
|
||||||
|
|
||||||
|
// Create access token but validate as refresh
|
||||||
|
tokenStr, _, err := gen.NewAccess(1, false, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = gen.ValidateRefresh(nil, tokenStr)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
32
tmdb/config.go
Normal file
32
tmdb/config.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Image Image `json:"images"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Image struct {
|
||||||
|
BaseURL string `json:"base_url"`
|
||||||
|
SecureBaseURL string `json:"secure_base_url"`
|
||||||
|
BackdropSizes []string `json:"backdrop_sizes"`
|
||||||
|
LogoSizes []string `json:"logo_sizes"`
|
||||||
|
PosterSizes []string `json:"poster_sizes"`
|
||||||
|
ProfileSizes []string `json:"profile_sizes"`
|
||||||
|
StillSizes []string `json:"still_sizes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetConfig(token string) (*Config, error) {
|
||||||
|
url := "https://api.themoviedb.org/3/configuration"
|
||||||
|
data, err := tmdbGet(url, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "tmdbGet")
|
||||||
|
}
|
||||||
|
config := Config{}
|
||||||
|
json.Unmarshal(data, &config)
|
||||||
|
return &config, nil
|
||||||
|
}
|
||||||
54
tmdb/credits.go
Normal file
54
tmdb/credits.go
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Credits struct {
|
||||||
|
ID int32 `json:"id"`
|
||||||
|
Cast []Cast `json:"cast"`
|
||||||
|
Crew []Crew `json:"crew"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Cast struct {
|
||||||
|
Adult bool `json:"adult"`
|
||||||
|
Gender int `json:"gender"`
|
||||||
|
ID int32 `json:"id"`
|
||||||
|
KnownFor string `json:"known_for_department"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
OriginalName string `json:"original_name"`
|
||||||
|
Popularity int `json:"popularity"`
|
||||||
|
Profile string `json:"profile_path"`
|
||||||
|
CastID int32 `json:"cast_id"`
|
||||||
|
Character string `json:"character"`
|
||||||
|
CreditID string `json:"credit_id"`
|
||||||
|
Order int `json:"order"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Crew struct {
|
||||||
|
Adult bool `json:"adult"`
|
||||||
|
Gender int `json:"gender"`
|
||||||
|
ID int32 `json:"id"`
|
||||||
|
KnownFor string `json:"known_for_department"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
OriginalName string `json:"original_name"`
|
||||||
|
Popularity int `json:"popularity"`
|
||||||
|
Profile string `json:"profile_path"`
|
||||||
|
CreditID string `json:"credit_id"`
|
||||||
|
Department string `json:"department"`
|
||||||
|
Job string `json:"job"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCredits(movieid int32, token string) (*Credits, error) {
|
||||||
|
url := fmt.Sprintf("https://api.themoviedb.org/3/movie/%v/credits?language=en-US", movieid)
|
||||||
|
data, err := tmdbGet(url, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "tmdbGet")
|
||||||
|
}
|
||||||
|
credits := Credits{}
|
||||||
|
json.Unmarshal(data, &credits)
|
||||||
|
return &credits, nil
|
||||||
|
}
|
||||||
41
tmdb/crew_functions.go
Normal file
41
tmdb/crew_functions.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import "sort"
|
||||||
|
|
||||||
|
type BilledCrew struct {
|
||||||
|
Name string
|
||||||
|
Roles []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (credits *Credits) BilledCrew() []BilledCrew {
|
||||||
|
crewmap := make(map[string][]string)
|
||||||
|
billedcrew := []BilledCrew{}
|
||||||
|
for _, crew := range credits.Crew {
|
||||||
|
if crew.Job == "Director" ||
|
||||||
|
crew.Job == "Screenplay" ||
|
||||||
|
crew.Job == "Writer" ||
|
||||||
|
crew.Job == "Novel" ||
|
||||||
|
crew.Job == "Story" {
|
||||||
|
crewmap[crew.Name] = append(crewmap[crew.Name], crew.Job)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, jobs := range crewmap {
|
||||||
|
billedcrew = append(billedcrew, BilledCrew{Name: name, Roles: jobs})
|
||||||
|
}
|
||||||
|
for i := range billedcrew {
|
||||||
|
sort.Strings(billedcrew[i].Roles)
|
||||||
|
}
|
||||||
|
sort.Slice(billedcrew, func(i, j int) bool {
|
||||||
|
return billedcrew[i].Roles[0] < billedcrew[j].Roles[0]
|
||||||
|
})
|
||||||
|
return billedcrew
|
||||||
|
}
|
||||||
|
|
||||||
|
func (billedcrew *BilledCrew) FRoles() string {
|
||||||
|
jobs := ""
|
||||||
|
for _, job := range billedcrew.Roles {
|
||||||
|
jobs += job + ", "
|
||||||
|
}
|
||||||
|
return jobs[:len(jobs)-2]
|
||||||
|
}
|
||||||
5
tmdb/go.mod
Normal file
5
tmdb/go.mod
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
module git.haelnorr.com/h/golib/tmdb
|
||||||
|
|
||||||
|
go 1.25.5
|
||||||
|
|
||||||
|
require github.com/pkg/errors v0.9.1
|
||||||
2
tmdb/go.sum
Normal file
2
tmdb/go.sum
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
45
tmdb/movie.go
Normal file
45
tmdb/movie.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Movie struct {
|
||||||
|
Adult bool `json:"adult"`
|
||||||
|
Backdrop string `json:"backdrop_path"`
|
||||||
|
Collection string `json:"belongs_to_collection"`
|
||||||
|
Budget int `json:"budget"`
|
||||||
|
Genres []Genre `json:"genres"`
|
||||||
|
Homepage string `json:"homepage"`
|
||||||
|
ID int32 `json:"id"`
|
||||||
|
IMDbID string `json:"imdb_id"`
|
||||||
|
OriginalLanguage string `json:"original_language"`
|
||||||
|
OriginalTitle string `json:"original_title"`
|
||||||
|
Overview string `json:"overview"`
|
||||||
|
Popularity float32 `json:"popularity"`
|
||||||
|
Poster string `json:"poster_path"`
|
||||||
|
ProductionCompanies []ProductionCompany `json:"production_companies"`
|
||||||
|
ProductionCountries []ProductionCountry `json:"production_countries"`
|
||||||
|
ReleaseDate string `json:"release_date"`
|
||||||
|
Revenue int `json:"revenue"`
|
||||||
|
Runtime int `json:"runtime"`
|
||||||
|
SpokenLanguages []SpokenLanguage `json:"spoken_languages"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Tagline string `json:"tagline"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Video bool `json:"video"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetMovie(id int32, token string) (*Movie, error) {
|
||||||
|
url := fmt.Sprintf("https://api.themoviedb.org/3/movie/%v?language=en-US", id)
|
||||||
|
data, err := tmdbGet(url, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "tmdbGet")
|
||||||
|
}
|
||||||
|
movie := Movie{}
|
||||||
|
json.Unmarshal(data, &movie)
|
||||||
|
return &movie, nil
|
||||||
|
}
|
||||||
42
tmdb/movie_functions.go
Normal file
42
tmdb/movie_functions.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"path"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (movie *Movie) FRuntime() string {
|
||||||
|
hours := movie.Runtime / 60
|
||||||
|
mins := movie.Runtime % 60
|
||||||
|
return fmt.Sprintf("%dh %02dm", hours, mins)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (movie *Movie) GetPoster(image *Image, size string) string {
|
||||||
|
base, err := url.Parse(image.SecureBaseURL)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
fullPath := path.Join(base.Path, size, movie.Poster)
|
||||||
|
base.Path = fullPath
|
||||||
|
return base.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (movie *Movie) ReleaseYear() string {
|
||||||
|
if movie.ReleaseDate == "" {
|
||||||
|
return ""
|
||||||
|
} else {
|
||||||
|
return "(" + movie.ReleaseDate[:4] + ")"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (movie *Movie) FGenres() string {
|
||||||
|
genres := ""
|
||||||
|
for _, genre := range movie.Genres {
|
||||||
|
genres += genre.Name + ", "
|
||||||
|
}
|
||||||
|
if len(genres) > 2 {
|
||||||
|
return genres[:len(genres)-2]
|
||||||
|
}
|
||||||
|
return genres
|
||||||
|
}
|
||||||
28
tmdb/request.go
Normal file
28
tmdb/request.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func tmdbGet(url string, token string) ([]byte, error) {
|
||||||
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "http.NewRequest")
|
||||||
|
}
|
||||||
|
req.Header.Add("accept", "application/json")
|
||||||
|
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
|
||||||
|
res, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "http.DefaultClient.Do")
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "io.ReadAll")
|
||||||
|
}
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
79
tmdb/search.go
Normal file
79
tmdb/search.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"path"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
Page int `json:"page"`
|
||||||
|
TotalPages int `json:"total_pages"`
|
||||||
|
TotalResults int `json:"total_results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResultMovies struct {
|
||||||
|
Result
|
||||||
|
Results []ResultMovie `json:"results"`
|
||||||
|
}
|
||||||
|
type ResultMovie struct {
|
||||||
|
Adult bool `json:"adult"`
|
||||||
|
BackdropPath string `json:"backdrop_path"`
|
||||||
|
GenreIDs []int `json:"genre_ids"`
|
||||||
|
ID int32 `json:"id"`
|
||||||
|
OriginalLanguage string `json:"original_language"`
|
||||||
|
OriginalTitle string `json:"original_title"`
|
||||||
|
Overview string `json:"overview"`
|
||||||
|
Popularity int `json:"popularity"`
|
||||||
|
PosterPath string `json:"poster_path"`
|
||||||
|
ReleaseDate string `json:"release_date"`
|
||||||
|
Title string `json:"title"`
|
||||||
|
Video bool `json:"video"`
|
||||||
|
VoteAverage int `json:"vote_average"`
|
||||||
|
VoteCount int `json:"vote_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (movie *ResultMovie) GetPoster(image *Image, size string) string {
|
||||||
|
base, err := url.Parse(image.SecureBaseURL)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
fullPath := path.Join(base.Path, size, movie.PosterPath)
|
||||||
|
base.Path = fullPath
|
||||||
|
return base.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (movie *ResultMovie) ReleaseYear() string {
|
||||||
|
if movie.ReleaseDate == "" {
|
||||||
|
return ""
|
||||||
|
} else {
|
||||||
|
return "(" + movie.ReleaseDate[:4] + ")"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: genres list https://developer.themoviedb.org/reference/genre-movie-list
|
||||||
|
// func (movie *ResultMovie) FGenres() string {
|
||||||
|
// genres := ""
|
||||||
|
// for _, genre := range movie.Genres {
|
||||||
|
// genres += genre.Name + ", "
|
||||||
|
// }
|
||||||
|
// return genres[:len(genres)-2]
|
||||||
|
// }
|
||||||
|
|
||||||
|
func SearchMovies(token string, query string, adult bool, page int) (*ResultMovies, error) {
|
||||||
|
url := "https://api.themoviedb.org/3/search/movie" +
|
||||||
|
fmt.Sprintf("?query=%s", url.QueryEscape(query)) +
|
||||||
|
fmt.Sprintf("&include_adult=%t", adult) +
|
||||||
|
fmt.Sprintf("&page=%v", page) +
|
||||||
|
"&language=en-US"
|
||||||
|
response, err := tmdbGet(url, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "tmdbGet")
|
||||||
|
}
|
||||||
|
var results ResultMovies
|
||||||
|
json.Unmarshal(response, &results)
|
||||||
|
return &results, nil
|
||||||
|
}
|
||||||
24
tmdb/structs.go
Normal file
24
tmdb/structs.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package tmdb
|
||||||
|
|
||||||
|
type Genre struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProductionCompany struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Logo string `json:"logo_path"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
OriginCountry string `json:"origin_country"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProductionCountry struct {
|
||||||
|
ISO_3166_1 string `json:"iso_3166_1"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SpokenLanguage struct {
|
||||||
|
EnglishName string `json:"english_name"`
|
||||||
|
ISO_639_1 string `json:"iso_639_1"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user