added discord api limiting
This commit is contained in:
46
internal/store/newlogin.go
Normal file
46
internal/store/newlogin.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"git.haelnorr.com/h/oslstats/internal/discord"
|
||||
"github.com/bwmarrin/discordgo"
|
||||
)
|
||||
|
||||
type RegistrationSession struct {
|
||||
DiscordUser *discordgo.User
|
||||
Token *discord.Token
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
func (s *Store) CreateRegistrationSession(user *discordgo.User, token *discord.Token) (string, error) {
|
||||
if user == nil {
|
||||
return "", errors.New("user cannot be nil")
|
||||
}
|
||||
if token == nil {
|
||||
return "", errors.New("token cannot be nil")
|
||||
}
|
||||
id := generateID()
|
||||
s.sessions.Store(id, &RegistrationSession{
|
||||
DiscordUser: user,
|
||||
Token: token,
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute),
|
||||
})
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetRegistrationSession(id string) (*RegistrationSession, bool) {
|
||||
val, ok := s.sessions.Load(id)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
session := val.(*RegistrationSession)
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
s.sessions.Delete(id)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return session, true
|
||||
}
|
||||
95
internal/store/redirects.go
Normal file
95
internal/store/redirects.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// getClientIP extracts the client IP address, checking X-Forwarded-For first
|
||||
func getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header (comma-separated list, first is client)
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP in the list
|
||||
ips := strings.Split(xff, ",")
|
||||
if len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr (format: "IP:port" or "[IPv6]:port")
|
||||
// Use net.SplitHostPort to properly handle both IPv4 and IPv6
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
// If SplitHostPort fails, return as-is (shouldn't happen with valid RemoteAddr)
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// TrackRedirect increments the redirect counter for this IP+UA+Path combination
|
||||
// Returns the current attempt count, whether limit was exceeded, and the track details
|
||||
func (s *Store) TrackRedirect(r *http.Request, path string, maxAttempts int) (attempts int, exceeded bool, track *RedirectTrack) {
|
||||
if r == nil {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
ip := getClientIP(r)
|
||||
userAgent := r.UserAgent()
|
||||
key := redirectKey(ip, userAgent, path)
|
||||
|
||||
now := time.Now()
|
||||
expiresAt := now.Add(5 * time.Minute)
|
||||
|
||||
// Try to load existing track
|
||||
val, exists := s.redirectTracks.Load(key)
|
||||
if exists {
|
||||
track = val.(*RedirectTrack)
|
||||
|
||||
// Check if expired
|
||||
if now.After(track.ExpiresAt) {
|
||||
// Expired, start fresh
|
||||
track = &RedirectTrack{
|
||||
IP: ip,
|
||||
UserAgent: userAgent,
|
||||
Path: path,
|
||||
Attempts: 1,
|
||||
FirstSeen: now,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
s.redirectTracks.Store(key, track)
|
||||
return 1, false, track
|
||||
}
|
||||
|
||||
// Increment existing
|
||||
track.Attempts++
|
||||
track.ExpiresAt = expiresAt // Extend expiry
|
||||
exceeded = track.Attempts >= maxAttempts
|
||||
return track.Attempts, exceeded, track
|
||||
}
|
||||
|
||||
// Create new track
|
||||
track = &RedirectTrack{
|
||||
IP: ip,
|
||||
UserAgent: userAgent,
|
||||
Path: path,
|
||||
Attempts: 1,
|
||||
FirstSeen: now,
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
s.redirectTracks.Store(key, track)
|
||||
return 1, false, track
|
||||
}
|
||||
|
||||
// ClearRedirectTrack removes a redirect tracking entry (called after successful completion)
|
||||
func (s *Store) ClearRedirectTrack(r *http.Request, path string) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ip := getClientIP(r)
|
||||
userAgent := r.UserAgent()
|
||||
key := redirectKey(ip, userAgent, path)
|
||||
s.redirectTracks.Delete(key)
|
||||
}
|
||||
80
internal/store/store.go
Normal file
80
internal/store/store.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RedirectTrack represents a single redirect attempt tracking entry
|
||||
type RedirectTrack struct {
|
||||
IP string // Client IP (X-Forwarded-For aware)
|
||||
UserAgent string // Full User-Agent string for debugging
|
||||
Path string // Request path (without query params)
|
||||
Attempts int // Number of redirect attempts
|
||||
FirstSeen time.Time // When first redirect was tracked
|
||||
ExpiresAt time.Time // When to clean up this entry
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
sessions sync.Map // key: string, value: *RegistrationSession
|
||||
redirectTracks sync.Map // key: string, value: *RedirectTrack
|
||||
cleanup *time.Ticker
|
||||
}
|
||||
|
||||
func NewStore() *Store {
|
||||
s := &Store{
|
||||
cleanup: time.NewTicker(1 * time.Minute),
|
||||
}
|
||||
|
||||
// Background cleanup of expired sessions
|
||||
go func() {
|
||||
for range s.cleanup.C {
|
||||
s.cleanupExpired()
|
||||
}
|
||||
}()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Store) Delete(id string) {
|
||||
s.sessions.Delete(id)
|
||||
}
|
||||
func (s *Store) cleanupExpired() {
|
||||
now := time.Now()
|
||||
|
||||
// Clean up expired registration sessions
|
||||
s.sessions.Range(func(key, value any) bool {
|
||||
session := value.(*RegistrationSession)
|
||||
if now.After(session.ExpiresAt) {
|
||||
s.sessions.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Clean up expired redirect tracks
|
||||
s.redirectTracks.Range(func(key, value any) bool {
|
||||
track := value.(*RedirectTrack)
|
||||
if now.After(track.ExpiresAt) {
|
||||
s.redirectTracks.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
func generateID() string {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
// redirectKey generates a unique key for tracking redirects
|
||||
// Uses IP + first 100 chars of UA + path as key (not hashed for debugging)
|
||||
func redirectKey(ip, userAgent, path string) string {
|
||||
ua := userAgent
|
||||
if len(ua) > 100 {
|
||||
ua = ua[:100]
|
||||
}
|
||||
return fmt.Sprintf("%s:%s:%s", ip, ua, path)
|
||||
}
|
||||
Reference in New Issue
Block a user