Files
projectreshoot/pkg/db/safetx.go

164 lines
3.9 KiB
Go

package db
import (
"context"
"database/sql"
"regexp"
"strings"
"github.com/pkg/errors"
)
// SafeTX is the method set shared by the SafeRTX and SafeWTX transaction
// wrappers: both can run row-returning statements and both are finished by
// calling either Commit or Rollback.
type SafeTX interface {
	// Commit ends the transaction by committing it.
	Commit() error
	// Rollback ends the transaction by aborting it.
	Rollback() error
	// Query runs a statement that returns rows.
	Query(ctx context.Context, q string, params ...interface{}) (*sql.Rows, error)
	// QueryRow runs a statement and returns a single *sql.Row.
	QueryRow(ctx context.Context, q string, params ...interface{}) (*sql.Row, error)
}
// SafeWTX and SafeRTX each wrap a *sql.Tx together with the SafeConn it
// belongs to. Their Commit and Rollback methods clear tx and call
// sc.releaseReadLock once the transaction is finished.
type (
	// SafeWTX is the read-write transaction wrapper; in addition to
	// Query/QueryRow it exposes Exec for write statements.
	SafeWTX struct {
		tx *sql.Tx   // underlying transaction; nil after Commit or Rollback
		sc *SafeConn // owning connection whose read lock is released on finish
	}

	// SafeRTX is the read transaction wrapper; it exposes only Query and
	// QueryRow, both of which reject statements flagged by isWriteOperation.
	SafeRTX struct {
		tx *sql.Tx   // underlying transaction; nil after Commit or Rollback
		sc *SafeConn // owning connection whose read lock is released on finish
	}
)
// writeOperationRegex matches a statement that starts with a data- or
// schema-modifying keyword followed by whitespace. It is compiled once at
// package initialisation rather than on every isWriteOperation call. The (?i)
// flag makes the keyword match case-insensitive, so the query no longer has to
// be upper-cased (and copied) before matching.
var writeOperationRegex = regexp.MustCompile(
	`(?i)^(INSERT|UPDATE|DELETE|REPLACE|MERGE|CREATE|DROP|ALTER|TRUNCATE)\s+`,
)

// isWriteOperation reports whether query, after trimming surrounding
// whitespace, begins with one of the write keywords matched by
// writeOperationRegex (in any letter case) followed by whitespace.
//
// NOTE(review): this is a leading-keyword heuristic, not a SQL parser. A write
// hidden behind a CTE (`WITH ... INSERT ...`) or preceded by a SQL comment is
// not detected.
func isWriteOperation(query string) bool {
	return writeOperationRegex.MatchString(strings.TrimSpace(query))
}
// Query runs a row-returning statement inside the read transaction.
//
// It returns an error without touching the database when the transaction
// has already been committed or rolled back (stx.tx is nil), or when
// isWriteOperation classifies the statement as a write. Errors from
// QueryContext are wrapped with the "tx.QueryContext" context.
func (stx *SafeRTX) Query(
	ctx context.Context,
	query string,
	args ...interface{},
) (*sql.Rows, error) {
	switch {
	case stx.tx == nil:
		return nil, errors.New("Cannot query without a transaction")
	case isWriteOperation(query):
		return nil, errors.New("Cannot query with a write operation")
	}

	result, queryErr := stx.tx.QueryContext(ctx, query, args...)
	if queryErr != nil {
		return nil, errors.Wrap(queryErr, "tx.QueryContext")
	}
	return result, nil
}
// Query runs a row-returning statement inside the write transaction.
//
// Statements that isWriteOperation classifies as writes are rejected here.
// They must be sent through Exec instead. The call also fails when the
// transaction has already been committed or rolled back (stx.tx is nil).
// Errors from QueryContext are wrapped with the "tx.QueryContext" context.
func (stx *SafeWTX) Query(
	ctx context.Context,
	query string,
	args ...interface{},
) (*sql.Rows, error) {
	switch {
	case stx.tx == nil:
		return nil, errors.New("Cannot query without a transaction")
	case isWriteOperation(query):
		return nil, errors.New("Cannot query with a write operation")
	}

	result, queryErr := stx.tx.QueryContext(ctx, query, args...)
	if queryErr != nil {
		return nil, errors.Wrap(queryErr, "tx.QueryContext")
	}
	return result, nil
}
// QueryRow runs a statement expected to yield a single row inside the read
// transaction.
//
// The returned error only covers the pre-checks: a finished transaction
// (stx.tx is nil) or a statement that isWriteOperation flags as a write.
// Errors from executing the query itself are reported by database/sql when
// the returned *sql.Row is scanned.
func (stx *SafeRTX) QueryRow(
	ctx context.Context,
	query string,
	args ...interface{},
) (*sql.Row, error) {
	switch {
	case stx.tx == nil:
		return nil, errors.New("Cannot query without a transaction")
	case isWriteOperation(query):
		return nil, errors.New("Cannot query with a write operation")
	}

	row := stx.tx.QueryRowContext(ctx, query, args...)
	return row, nil
}
// QueryRow runs a statement expected to yield a single row inside the write
// transaction.
//
// Write statements (as judged by isWriteOperation) are rejected, as is any
// call after the transaction has been committed or rolled back (stx.tx is
// nil). Errors from executing the query itself are reported by database/sql
// when the returned *sql.Row is scanned.
func (stx *SafeWTX) QueryRow(
	ctx context.Context,
	query string,
	args ...interface{},
) (*sql.Row, error) {
	switch {
	case stx.tx == nil:
		return nil, errors.New("Cannot query without a transaction")
	case isWriteOperation(query):
		return nil, errors.New("Cannot query with a write operation")
	}

	row := stx.tx.QueryRowContext(ctx, query, args...)
	return row, nil
}
// Exec runs a statement that does not return rows (for example an INSERT,
// UPDATE or DELETE) inside the write transaction.
//
// Unlike Query and QueryRow, Exec does not consult isWriteOperation. It
// fails only when the transaction has already been committed or rolled back
// (stx.tx is nil). Errors from ExecContext are wrapped with the
// "tx.ExecContext" context.
func (stx *SafeWTX) Exec(
	ctx context.Context,
	query string,
	args ...interface{},
) (sql.Result, error) {
	tx := stx.tx
	if tx == nil {
		return nil, errors.New("Cannot exec without a transaction")
	}

	outcome, execErr := tx.ExecContext(ctx, query, args...)
	if execErr != nil {
		return nil, errors.Wrap(execErr, "tx.ExecContext")
	}
	return outcome, nil
}
// Commit commits the read transaction, then clears the stored *sql.Tx and
// releases the SafeConn read lock. It returns whatever error the underlying
// Commit reported.
//
// A second Commit or Rollback finds stx.tx nil and returns an error instead
// of releasing the lock again.
func (stx *SafeRTX) Commit() error {
	tx := stx.tx
	if tx == nil {
		return errors.New("Cannot commit without a transaction")
	}

	commitErr := tx.Commit()
	// Clean up even when the commit failed: database/sql treats the
	// transaction as finished after Commit returns, success or not.
	stx.tx = nil
	stx.sc.releaseReadLock()
	return commitErr
}
// Commit commits the write transaction, then clears the stored *sql.Tx and
// releases the SafeConn read lock. It returns whatever error the underlying
// Commit reported.
//
// A second Commit or Rollback finds stx.tx nil and returns an error instead
// of releasing the lock again.
func (stx *SafeWTX) Commit() error {
	tx := stx.tx
	if tx == nil {
		return errors.New("Cannot commit without a transaction")
	}

	commitErr := tx.Commit()
	// Clean up even when the commit failed: database/sql treats the
	// transaction as finished after Commit returns, success or not.
	stx.tx = nil
	stx.sc.releaseReadLock()
	return commitErr
}
// Rollback aborts the read transaction, then clears the stored *sql.Tx and
// releases the SafeConn read lock. It returns whatever error the underlying
// Rollback reported.
//
// Calling it after Commit or a previous Rollback returns an error and leaves
// the lock untouched.
func (stx *SafeRTX) Rollback() error {
	tx := stx.tx
	if tx == nil {
		return errors.New("Cannot rollback without a transaction")
	}

	rollbackErr := tx.Rollback()
	// The transaction is finished regardless of rollbackErr, so always
	// drop the handle and hand the lock back.
	stx.tx = nil
	stx.sc.releaseReadLock()
	return rollbackErr
}
// Rollback aborts the write transaction, then clears the stored *sql.Tx and
// releases the SafeConn read lock. It returns whatever error the underlying
// Rollback reported.
//
// Calling it after Commit or a previous Rollback returns an error and leaves
// the lock untouched.
func (stx *SafeWTX) Rollback() error {
	tx := stx.tx
	if tx == nil {
		return errors.New("Cannot rollback without a transaction")
	}

	rollbackErr := tx.Rollback()
	// The transaction is finished regardless of rollbackErr, so always
	// drop the handle and hand the lock back.
	stx.tx = nil
	stx.sc.releaseReadLock()
	return rollbackErr
}