Skip to content

Commit 2c2136f

Browse files
feat: add initial implementation of the 2FA recovery handler
1 parent f9b482e commit 2c2136f

File tree

11 files changed

+366
-4
lines changed

11 files changed

+366
-4
lines changed

internal/dbsqlc/copyfrom.go

+45
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/dbsqlc/db.go

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/dbsqlc/models.go

+11
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
-- name: CreateRecoveryCodeBatch :copyfrom
2+
INSERT INTO shield_recovery_codes
3+
(id, user_id, recovery_code_hash, is_consumable)
4+
VALUES
5+
(@id::UUID, @user_id::UUID, @recovery_code_hash, @is_consumable);
6+
7+
-- name: EvictUnconsumedRecoveryCodeBatch :exec
8+
UPDATE shield_recovery_codes
9+
SET
10+
evicted_by = @evicted_by::UUID,
11+
evicted_at = NOW()
12+
WHERE user_id = @user_id AND is_consumable = TRUE;

internal/dbsqlc/recovery_code_query.sql.go

+37
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
-- migration: 1731842623024_recovery_code.sql
2+
3+
CREATE TABLE IF NOT EXISTS shield_recovery_codes (
4+
id UUID NOT NULL,
5+
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
6+
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
7+
user_id UUID NOT NULL,
8+
recovery_code_hash VARCHAR(4095) NOT NULL,
9+
is_consumable BOOL NOT NULL DEFAULT TRUE,
10+
evicted_by UUID NULL,
11+
evicted_at TIMESTAMP NULL DEFAULT NULL,
12+
PRIMARY KEY (id),
13+
FOREIGN KEY (user_id) REFERENCES shield_users (id)
14+
ON DELETE CASCADE,
15+
FOREIGN KEY (evicted_by) REFERENCES shield_users (id)
16+
ON DELETE CASCADE
17+
);
18+
19+
---- create above / drop below ----
20+
21+
DROP TABLE IF EXISTS shield_recovery_codes;

internal/random/random.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ func secureBytes(l int) ([]byte, error) {
1111
bytes := make([]byte, l)
1212
_, err := rand.Read(bytes)
1313
if err != nil {
14-
return bytes, fmt.Errorf("random: error reading random bytes: %w", err)
14+
return bytes, fmt.Errorf("shield: error reading random bytes: %w", err)
1515
}
1616
return bytes, nil
1717
}

shieldpassword/bcrypt.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ type bcryptPasswordHasher struct {
1717
cost int
1818
}
1919

20-
// NewBcryptPasswordHasher implements a password hashing algorithm with bcrypt.
20+
// NewBcryptPasswordHasher creates a password hasher using the bcrypt algorithm.
21+
//
22+
// Please note that bcrypt has a maximum input length of 72 bytes. For passwords
23+
// requiring more than 72 bytes of data, consider using an alternative algorithm
24+
// such as Argon2.
2125
func NewBcryptPasswordHasher(cost int) PasswordHasher {
2226
return &bcryptPasswordHasher{cost}
2327
}

shieldrecoverycode/handler.go

+232
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
package shieldrecoverycode
2+
3+
import (
4+
"cmp"
5+
"context"
6+
"fmt"
7+
"log/slog"
8+
9+
"github.com/google/uuid"
10+
"github.com/jackc/pgx/v5"
11+
"github.com/jackc/pgx/v5/pgxpool"
12+
"go.inout.gg/foundations/debug"
13+
"go.inout.gg/shield"
14+
"go.inout.gg/shield/internal/dbsqlc"
15+
"go.inout.gg/shield/internal/random"
16+
"go.inout.gg/shield/internal/uuidv7"
17+
"go.inout.gg/shield/shieldpassword"
18+
)
19+
20+
var _ Generator = (*generator)(nil)
21+
22+
const (
23+
DefaultRecoveryCodeTotalCount = 16
24+
DefaultRecoveryCodeLength = 16
25+
)
26+
27+
var DefaultGenerator Generator = &generator{}
28+
29+
// Generator provides methods to create a set of unique recovery codes used
30+
// for 2FA authentication recovery.
31+
type Generator interface {
32+
Generate(count int, len int) ([]string, error)
33+
}
34+
35+
type generator struct{}
36+
37+
// Generate creates count number of secure random recovery codes.
38+
// Each code is length bytes long and encoded as a hex string.
39+
func (g *generator) Generate(count, length int) ([]string, error) {
40+
codes := make([]string, count, count)
41+
42+
for i := range count {
43+
code, err := random.SecureHexString(length)
44+
if err != nil {
45+
return nil, err
46+
}
47+
48+
codes[i] = code
49+
}
50+
51+
return codes, nil
52+
}
53+
54+
type Config struct {
55+
Logger *slog.Logger
56+
PasswordHasher shieldpassword.PasswordHasher
57+
Generator Generator
58+
RecoveryCodeTotalCount int
59+
RecoveryCodeLength int
60+
}
61+
62+
func (c *Config) defaults() {
63+
c.PasswordHasher = cmp.Or(c.PasswordHasher, shieldpassword.DefaultPasswordHasher)
64+
c.Logger = cmp.Or(c.Logger, shield.DefaultLogger)
65+
c.RecoveryCodeTotalCount = cmp.Or(c.RecoveryCodeTotalCount, DefaultRecoveryCodeTotalCount)
66+
c.RecoveryCodeLength = cmp.Or(c.RecoveryCodeLength, DefaultRecoveryCodeLength)
67+
c.Generator = cmp.Or(c.Generator, DefaultGenerator)
68+
}
69+
70+
func (c *Config) assert() {
71+
debug.Assert(c.Logger != nil, "expected Logger to be defined")
72+
debug.Assert(c.PasswordHasher != nil, "expected PasswordHasher to be defined")
73+
debug.Assert(c.Generator != nil, "expected Generator to be defined")
74+
}
75+
76+
func NewConfig(opts ...func(*Config)) *Config {
77+
c := &Config{}
78+
for _, opt := range opts {
79+
opt(c)
80+
}
81+
82+
c.defaults()
83+
c.assert()
84+
85+
return c
86+
}
87+
88+
type Handler struct {
89+
config *Config
90+
pool *pgxpool.Pool
91+
}
92+
93+
func New(pool *pgxpool.Pool, config *Config) *Handler {
94+
if config == nil {
95+
config = NewConfig()
96+
}
97+
98+
h := Handler{config, pool}
99+
h.assert()
100+
101+
return &h
102+
}
103+
104+
func (h *Handler) assert() {
105+
h.config.assert()
106+
debug.Assert(h.pool != nil, "expected pool to be defined")
107+
}
108+
109+
func (h *Handler) Generate() ([]string, error) {
110+
codes, err := h.config.Generator.Generate(h.config.RecoveryCodeTotalCount, h.config.RecoveryCodeLength)
111+
if err != nil {
112+
return nil, err
113+
}
114+
115+
hashedCodes := make([]string, len(codes))
116+
for i, code := range codes {
117+
hashedCode, err := h.config.PasswordHasher.Hash(code)
118+
if err != nil {
119+
return nil, err
120+
}
121+
122+
hashedCodes[i] = hashedCode
123+
}
124+
125+
return hashedCodes, nil
126+
}
127+
128+
// CreateRecoveryCodes generates a new set of recovery codes
129+
func (h *Handler) CreateRecoveryCodes(ctx context.Context, userID uuid.UUID) error {
130+
codes, err := h.Generate()
131+
if err != nil {
132+
return err
133+
}
134+
135+
tx, err := h.pool.Begin(ctx)
136+
if err != nil {
137+
return fmt.Errorf("shield/recovery_code: failed to begin transaction: %w", err)
138+
}
139+
defer tx.Rollback(ctx)
140+
141+
if err := h.CreateRecoveryCodesInTx(ctx, userID, codes, tx); err != nil {
142+
return err
143+
}
144+
145+
if err := tx.Commit(ctx); err != nil {
146+
return fmt.Errorf("shield/recovery_code: failed to commit transaction: %w", err)
147+
}
148+
149+
return nil
150+
}
151+
152+
// ReplaceRecoveryCodes regenerates a new set of recovery codes and replaces
153+
// any previously unconsumed recovery codes with the newly generated set.
154+
//
155+
// userID is the ID of the user to update recovery codes for
156+
func (h *Handler) ReplaceRecoveryCodes(ctx context.Context, userID, replacedBy uuid.UUID) error {
157+
codes, err := h.Generate()
158+
if err != nil {
159+
return err
160+
}
161+
162+
tx, err := h.pool.Begin(ctx)
163+
if err != nil {
164+
return fmt.Errorf("shield/recovery_code: failed to begin transaction: %w", err)
165+
}
166+
defer tx.Rollback(ctx)
167+
168+
if err := h.ReplaceRecoveryCodesInTx(ctx, userID, replacedBy, codes, tx); err != nil {
169+
return err
170+
}
171+
172+
if err := tx.Commit(ctx); err != nil {
173+
return fmt.Errorf("shield/recovery_code: failed to commit transaction: %w", err)
174+
}
175+
176+
return nil
177+
}
178+
179+
func (h *Handler) ReplaceRecoveryCodesInTx(
180+
ctx context.Context,
181+
userID uuid.UUID,
182+
replacedBy uuid.UUID,
183+
codes []string,
184+
tx pgx.Tx,
185+
) error {
186+
if err := h.EvictRecoveryCodesInTx(ctx, userID, replacedBy, tx); err != nil {
187+
return err
188+
}
189+
190+
if err := h.CreateRecoveryCodesInTx(ctx, userID, codes, tx); err != nil {
191+
return err
192+
}
193+
194+
return nil
195+
}
196+
197+
func (h *Handler) EvictRecoveryCodesInTx(
198+
ctx context.Context,
199+
userID uuid.UUID,
200+
evictedBy uuid.UUID,
201+
tx pgx.Tx,
202+
) error {
203+
arg := dbsqlc.EvictUnconsumedRecoveryCodeBatchParams{UserID: userID, EvictedBy: evictedBy}
204+
if err := dbsqlc.New().EvictUnconsumedRecoveryCodeBatch(ctx, tx, arg); err != nil {
205+
return fmt.Errorf("shield/recovery_code: failed to evict recovery codes: %w", err)
206+
}
207+
208+
return nil
209+
}
210+
211+
func (h *Handler) CreateRecoveryCodesInTx(
212+
ctx context.Context,
213+
userID uuid.UUID,
214+
codes []string,
215+
tx pgx.Tx,
216+
) error {
217+
rows := make([]dbsqlc.CreateRecoveryCodeBatchParams, len(codes))
218+
for i, code := range codes {
219+
rows[i] = dbsqlc.CreateRecoveryCodeBatchParams{
220+
ID: uuidv7.Must(),
221+
IsConsumable: true,
222+
RecoveryCodeHash: code,
223+
UserID: userID,
224+
}
225+
}
226+
227+
if _, err := dbsqlc.New().CreateRecoveryCodeBatch(ctx, tx, rows); err != nil {
228+
return fmt.Errorf("shield/recovery_code: failed to create recovery codes: %w", err)
229+
}
230+
231+
return nil
232+
}

0 commit comments

Comments
 (0)