-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathregistry.go
183 lines (156 loc) · 4.96 KB
/
registry.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
package sqld
import (
	"database/sql"
	"errors"
	"fmt"
	"reflect"
	"strings"
	"sync"
	"time"

	"github.com/jackc/pgx/v5/pgtype"
)
// Registry is a type-safe registry for model metadata and scanners.
//
// Create one with NewRegistry. The zero value has nil maps, and the
// registration methods write into those maps, which would panic.
type Registry struct {
	// models maps the reflect.Type of a model, exactly as passed to
	// Register/GetModelMetadata, to the metadata built for it.
	models map[reflect.Type]ModelMetadata
	// scanners maps a Go type to a factory that produces an sql.Scanner for it.
	scanners map[reflect.Type]func() sql.Scanner
	// mu guards models and scanners; each Registry method acquires it.
	mu sync.RWMutex
}
// NewRegistry creates an empty Registry whose model and scanner tables are
// already allocated, so it is ready for registration and lookups.
func NewRegistry() *Registry {
	reg := new(Registry)
	reg.models = map[reflect.Type]ModelMetadata{}
	reg.scanners = map[reflect.Type]func() sql.Scanner{}
	return reg
}
// defaultRegistry is the package-wide Registry behind the package-level
// helpers Register, RegisterScanner and getModelMetadata.
var (
	defaultRegistry = NewRegistry()
)
// Register adds the metadata for model type T to the package-level default
// registry. Registering a type that is already present is a no-op.
//
// If T is a pointer type, its zero value is a nil pointer, and calling
// TableName on it could dereference nil. In that case a zero value of the
// pointed-to type is allocated and passed instead. Its dynamic type is still
// T, so the registry key is unchanged.
func Register[T Model]() error {
	var model T
	if rt := reflect.TypeOf(model); rt != nil && rt.Kind() == reflect.Pointer {
		// reflect.New(rt.Elem()) has type rt, which is T, so the assertion holds.
		model = reflect.New(rt.Elem()).Interface().(T)
	}
	return defaultRegistry.Register(model)
}
// RegisterScanner installs scannerFactory for type t on the package-level
// default registry, replacing any factory previously installed for t.
func RegisterScanner(t reflect.Type, scannerFactory func() sql.Scanner) {
	reg := defaultRegistry
	reg.RegisterScanner(t, scannerFactory)
}
// getModelMetadata retrieves metadata for a model type
func getModelMetadata(model Model) (ModelMetadata, error) {
// First attempt to get from registry
metadata, err := defaultRegistry.GetModelMetadata(model)
if err != nil {
// Check if it's a "not registered" error
var notRegistered *ErrModelNotRegistered
if errors.As(err, ¬Registered) {
// Attempt lazy registration with proper locking
if regErr := defaultRegistry.Register(model); regErr != nil {
return ModelMetadata{}, fmt.Errorf("failed lazy-registering model: %w", regErr)
}
// After registration, try to get metadata again
metadata, err = defaultRegistry.GetModelMetadata(model)
if err != nil {
return ModelMetadata{}, fmt.Errorf("failed to get model metadata after lazy registration: %w", err)
}
return metadata, nil
}
// Some other error occurred
return ModelMetadata{}, err
}
return metadata, nil
}
// ErrModelNotRegistered is returned when a model is not found in the registry.
// ModelType is the reflect.Type that was looked up.
type ErrModelNotRegistered struct {
	ModelType reflect.Type
}

// Error implements the error interface.
//
// Named types are reported by their bare name, e.g. "model User not
// registered". Unnamed types such as *User have an empty Name(), so they fall
// back to String() and the message still identifies the model. A nil
// ModelType is reported as "<nil>" rather than panicking.
func (e *ErrModelNotRegistered) Error() string {
	if e.ModelType == nil {
		return "model <nil> not registered"
	}
	name := e.ModelType.Name()
	if name == "" {
		name = e.ModelType.String()
	}
	return fmt.Sprintf("model %s not registered", name)
}
// Register builds metadata for model's type and stores it in the registry.
//
// The registry key is reflect.TypeOf(model) exactly as passed. A model
// registered through a pointer must therefore also be looked up through a
// pointer. Fields are read from the underlying struct type, with any pointer
// layers stripped.
//
// Every field must carry both a `db` and a `json` tag. Options after a comma
// in the json tag are dropped, so `json:"name,omitempty"` is stored under
// "name". If the name part is empty (e.g. `json:",omitempty"`), the Go field
// name is used, as encoding/json does.
//
// Registering a type that is already present is a no-op. An error is returned
// for a nil model, for a model that is not a struct or pointer to struct, and
// for a field missing a required tag. In every error case nothing is stored.
func (r *Registry) Register(model Model) error {
	t := reflect.TypeOf(model)
	if t == nil {
		return errors.New("cannot register nil model")
	}

	// Fields live on the struct type, but the map key stays t so it matches
	// the key GetModelMetadata computes.
	st := t
	for st.Kind() == reflect.Pointer {
		st = st.Elem()
	}
	if st.Kind() != reflect.Struct {
		return fmt.Errorf("model %s must be a struct or pointer to struct, got %s", t, st.Kind())
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	// If model is already registered, silently succeed.
	if _, exists := r.models[t]; exists {
		return nil
	}

	metadata := ModelMetadata{
		TableName: model.TableName(),
		Fields:    make(map[string]Field, st.NumField()),
	}

	for i := 0; i < st.NumField(); i++ {
		field := st.Field(i)

		dbName := field.Tag.Get("db")
		if dbName == "" {
			return fmt.Errorf("field %q missing required db tag", field.Name)
		}

		jsonTag := field.Tag.Get("json")
		if jsonTag == "" {
			return fmt.Errorf("field %q missing required json tag", field.Name)
		}
		// Keep only the name part of the tag; "-" is kept as-is (not skipped).
		jsonName, _, _ := strings.Cut(jsonTag, ",")
		if jsonName == "" {
			jsonName = field.Name
		}

		metadata.Fields[jsonName] = Field{
			Name:           dbName,     // DB column name
			JSONName:       jsonName,   // JSON field name, without tag options
			GoFieldName:    field.Name, // Go struct field name
			Type:           field.Type,
			NormalizedType: normalizeReflectType(field.Type),
		}
	}

	r.models[t] = metadata
	return nil
}
// normalizeReflectType reduces a field type to the canonical Go type used for
// validation, in three steps:
//   - pointer layers are removed;
//   - the supported pgtype wrappers map to a plain Go type;
//   - any string-kinded type, including named string enums, becomes string.
//
// Any other type is returned with only its pointer layers removed.
func normalizeReflectType(rt reflect.Type) reflect.Type {
	base := rt
	for base.Kind() == reflect.Pointer {
		base = base.Elem()
	}

	stringType := reflect.TypeOf("")

	switch base {
	case reflect.TypeOf(pgtype.Text{}):
		return stringType
	case reflect.TypeOf(pgtype.Numeric{}):
		return reflect.TypeOf(float64(0))
	case reflect.TypeOf(pgtype.Int8{}):
		return reflect.TypeOf(int64(0))
	case reflect.TypeOf(pgtype.Int4{}):
		return reflect.TypeOf(int32(0))
	case reflect.TypeOf(pgtype.Bool{}):
		return reflect.TypeOf(false)
	case reflect.TypeOf(pgtype.Timestamptz{}), reflect.TypeOf(pgtype.Date{}):
		// Timestamps with time zone and dates are both validated as time.Time.
		return reflect.TypeOf(time.Time{})
	}

	if base.Kind() != reflect.String {
		return base
	}
	return stringType
}
// RegisterScanner records scannerFactory as the scanner factory for type t.
// Any factory previously recorded for t is overwritten. The write lock is
// held for the duration of the update.
func (r *Registry) RegisterScanner(t reflect.Type, scannerFactory func() sql.Scanner) {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Last registration wins.
	r.scanners[t] = scannerFactory
}
// GetModelMetadata looks up the metadata stored for model's dynamic type.
// If that type has not been registered, it returns the zero ModelMetadata and
// an *ErrModelNotRegistered. The lookup runs under the read lock.
func (r *Registry) GetModelMetadata(model Model) (ModelMetadata, error) {
	modelType := reflect.TypeOf(model)

	r.mu.RLock()
	defer r.mu.RUnlock()

	if metadata, found := r.models[modelType]; found {
		return metadata, nil
	}
	return ModelMetadata{}, &ErrModelNotRegistered{ModelType: modelType}
}
// GetScanner reports the scanner factory registered for t.
// The boolean is false, and the factory nil, when none is registered.
// The lookup runs under the read lock.
func (r *Registry) GetScanner(t reflect.Type) (func() sql.Scanner, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if fn, found := r.scanners[t]; found {
		return fn, true
	}
	return nil, false
}