pkg/auth/auth.go 6.2 K raw
1
// Package auth provides session storage, password verification and small
2
// helpers shared by andromeda Go apps.
3
package auth
4
5
import (
6
	"crypto/rand"
7
	"crypto/subtle"
8
	"database/sql"
9
	"encoding/hex"
10
	"errors"
11
	"net/http"
12
	"strings"
13
	"time"
14
15
	"golang.org/x/crypto/bcrypt"
16
)
17
18
// shortIDAlphabet is the set of URL-safe characters GenerateShortID draws from.
// It has 64 entries, which divides 256 evenly, so reducing a random byte modulo
// its length selects every character with equal probability.
const shortIDAlphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-"
19
20
// Store is a database-backed session store. It keeps session tokens in a
// "sessions" table of DB and carries the settings used when building the
// session cookie.
type Store struct {
	// DB is the handle every session query is executed against.
	DB *sql.DB
	// CookieName is the name of the cookie that carries the session token.
	CookieName string
	// CookieSecure is copied to the Secure attribute of the cookies the
	// store builds.
	CookieSecure bool
	// MaxAge is the session and cookie lifetime. Zero or negative values
	// fall back to seven days.
	MaxAge time.Duration
}
28
29
// EnsureSchema creates the sessions table if it does not exist.
30
func (s *Store) EnsureSchema() error {
31
	_, err := s.DB.Exec(`CREATE TABLE IF NOT EXISTS sessions (
32
		id         INTEGER PRIMARY KEY AUTOINCREMENT,
33
		token      TEXT NOT NULL UNIQUE,
34
		expires_at TEXT NOT NULL
35
	)`)
36
	return err
37
}
38
39
// Create issues a new session, persists it, and returns the raw token.
40
func (s *Store) Create() (string, error) {
41
	token, err := GenerateSessionToken()
42
	if err != nil {
43
		return "", err
44
	}
45
	expires := time.Now().UTC().Add(s.maxAge())
46
	if _, err := s.DB.Exec(`INSERT INTO sessions (token, expires_at) VALUES (?, ?)`,
47
		token, expires.Format("2006-01-02 15:04:05")); err != nil {
48
		return "", err
49
	}
50
	return token, nil
51
}
52
53
// Valid reports whether the given token exists and has not expired.
54
func (s *Store) Valid(token string) bool {
55
	if token == "" {
56
		return false
57
	}
58
	var expires string
59
	err := s.DB.QueryRow(`SELECT expires_at FROM sessions WHERE token = ?`, token).Scan(&expires)
60
	if err != nil {
61
		return false
62
	}
63
	t, err := time.ParseInLocation("2006-01-02 15:04:05", expires, time.UTC)
64
	return err == nil && t.After(time.Now().UTC())
65
}
66
67
// Delete removes the given session token if present.
68
func (s *Store) Delete(token string) {
69
	_, _ = s.DB.Exec(`DELETE FROM sessions WHERE token = ?`, token)
70
}
71
72
// PruneExpired removes all expired session rows.
73
func (s *Store) PruneExpired() {
74
	_, _ = s.DB.Exec(`DELETE FROM sessions WHERE expires_at < datetime('now')`)
75
}
76
77
// HasValid checks the cookie on r and returns true if it carries a live session.
78
func (s *Store) HasValid(r *http.Request) bool {
79
	c, err := r.Cookie(s.CookieName)
80
	if err != nil || c.Value == "" {
81
		return false
82
	}
83
	return s.Valid(c.Value)
84
}
85
86
// RequireSession wraps next and redirects to redirectPath when no valid session
87
// cookie is present.
88
func (s *Store) RequireSession(redirectPath string, next http.HandlerFunc) http.HandlerFunc {
89
	return func(w http.ResponseWriter, r *http.Request) {
90
		if !s.HasValid(r) {
91
			http.Redirect(w, r, redirectPath, http.StatusSeeOther)
92
			return
93
		}
94
		next(w, r)
95
	}
96
}
97
98
// SessionCookie returns a cookie configured with the store's name, security
99
// flag and MaxAge.
100
func (s *Store) SessionCookie(token string) *http.Cookie {
101
	return &http.Cookie{
102
		Name:     s.CookieName,
103
		Value:    token,
104
		Path:     "/",
105
		HttpOnly: true,
106
		Secure:   s.CookieSecure,
107
		SameSite: http.SameSiteStrictMode,
108
		MaxAge:   int(s.maxAge().Seconds()),
109
	}
110
}
111
112
// ClearCookie returns an expired cookie used to log out a session.
113
func (s *Store) ClearCookie() *http.Cookie {
114
	return &http.Cookie{
115
		Name:     s.CookieName,
116
		Value:    "",
117
		Path:     "/",
118
		HttpOnly: true,
119
		Secure:   s.CookieSecure,
120
		SameSite: http.SameSiteStrictMode,
121
		MaxAge:   -1,
122
	}
123
}
124
125
// maxAge returns the configured session lifetime, falling back to seven days
// when MaxAge is zero or negative.
func (s *Store) maxAge() time.Duration {
	const fallback = 7 * 24 * time.Hour
	if s.MaxAge <= 0 {
		return fallback
	}
	return s.MaxAge
}
131
132
// RequireAPIKey wraps next with a strict X-API-Key header check.
133
// Returns 403 if expected is empty, 401 on mismatch.
134
func RequireAPIKey(expected string, next http.HandlerFunc) http.HandlerFunc {
135
	return func(w http.ResponseWriter, r *http.Request) {
136
		if expected == "" {
137
			http.Error(w, "API key not configured on server", http.StatusForbidden)
138
			return
139
		}
140
		if !SecureEqual(r.Header.Get("x-api-key"), expected) {
141
			http.Error(w, "Invalid API key", http.StatusUnauthorized)
142
			return
143
		}
144
		next(w, r)
145
	}
146
}
147
148
// RequireBearerOrSession allows requests carrying either a matching bearer
149
// token or a valid session cookie. Falls through with 401 JSON otherwise.
150
func RequireBearerOrSession(store *Store, expectedBearer string, next http.HandlerFunc) http.HandlerFunc {
151
	return func(w http.ResponseWriter, r *http.Request) {
152
		if expectedBearer != "" {
153
			authz := r.Header.Get("Authorization")
154
			if strings.HasPrefix(strings.ToLower(authz), "bearer ") &&
155
				SecureEqual(strings.TrimSpace(authz[7:]), expectedBearer) {
156
				next(w, r)
157
				return
158
			}
159
		}
160
		if store != nil && store.HasValid(r) {
161
			next(w, r)
162
			return
163
		}
164
		w.Header().Set("Content-Type", "application/json")
165
		w.WriteHeader(http.StatusUnauthorized)
166
		_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
167
	}
168
}
169
170
// VerifyPassword checks input against expected. If expected looks like a bcrypt
171
// hash (`$2`-prefixed) it is compared as such, otherwise as plain text.
172
func VerifyPassword(input, expected string) bool {
173
	if strings.HasPrefix(expected, "$2") {
174
		return bcrypt.CompareHashAndPassword([]byte(expected), []byte(input)) == nil
175
	}
176
	return SecureEqual(input, expected)
177
}
178
179
// SecureEqual reports whether a and b are byte-for-byte equal, using
// constant-time primitives for the comparison.
//
// The first 256 bytes of each input are copied into fixed-size zero-padded
// buffers that are compared in full, so for inputs up to that size the
// comparison work does not depend on their lengths. A separate length check
// rejects inputs that differ only by trailing zero bytes. Bytes beyond the
// first 256 are compared as well, so long inputs are never silently
// truncated.
func SecureEqual(a, b string) bool {
	const padLen = 256
	var bufA, bufB [padLen]byte
	ab := []byte(a)
	bb := []byte(b)
	na := min(len(ab), padLen)
	nb := min(len(bb), padLen)
	copy(bufA[:], ab[:na])
	copy(bufB[:], bb[:nb])

	lengthsMatch := subtle.ConstantTimeEq(int32(len(ab)), int32(len(bb)))
	headMatch := subtle.ConstantTimeCompare(bufA[:], bufB[:])
	// Compare whatever did not fit into the padded buffers. Both tails are
	// empty (and match) for inputs of at most padLen bytes; tails of
	// different length never match.
	tailMatch := subtle.ConstantTimeCompare(ab[na:], bb[nb:])
	return lengthsMatch&headMatch&tailMatch == 1
}
195
196
// GenerateSessionToken returns 32 bytes read from crypto/rand, hex-encoded as
// a 64-character lowercase string. If the random source fails, the error is
// returned with an empty token.
func GenerateSessionToken() (string, error) {
	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
204
205
// GenerateShortID returns an n-character URL-safe identifier.
206
func GenerateShortID(n int) (string, error) {
207
	if n <= 0 {
208
		return "", errors.New("auth: short id length must be positive")
209
	}
210
	buf := make([]byte, n)
211
	if _, err := rand.Read(buf); err != nil {
212
		return "", err
213
	}
214
	out := make([]byte, n)
215
	for i, b := range buf {
216
		out[i] = shortIDAlphabet[int(b)%len(shortIDAlphabet)]
217
	}
218
	return string(out), nil
219
}