Skip to content

Commit a60ed77

Browse files
authored
Merge pull request #62 from Songmu/preprepare
add PrePrepare, Prepare and PostPrepare to hook prepare
2 parents b039787 + 61cb1ab commit a60ed77

File tree

4 files changed

+138
-10
lines changed

4 files changed

+138
-10
lines changed

conn.go

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,30 +50,44 @@ func (conn *Conn) Prepare(query string) (driver.Stmt, error) {
5050

5151
// PrepareContext returns a prepared statement which is wrapped by Stmt.
5252
func (conn *Conn) PrepareContext(c context.Context, query string) (driver.Stmt, error) {
53-
var stmt driver.Stmt
53+
var ctx interface{}
54+
var stmt = &Stmt{
55+
QueryString: query,
56+
Proxy: conn.Proxy,
57+
Conn: conn,
58+
}
5459
var err error
60+
hooks := conn.Proxy.getHooks(c)
61+
if hooks != nil {
62+
defer func() { hooks.postPrepare(c, ctx, stmt, err) }()
63+
if ctx, err = hooks.prePrepare(c, stmt); err != nil {
64+
return nil, err
65+
}
66+
}
67+
5568
if connCtx, ok := conn.Conn.(driver.ConnPrepareContext); ok {
56-
stmt, err = connCtx.PrepareContext(c, query)
69+
stmt.Stmt, err = connCtx.PrepareContext(c, stmt.QueryString)
5770
} else {
58-
stmt, err = conn.Conn.Prepare(query)
71+
stmt.Stmt, err = conn.Conn.Prepare(stmt.QueryString)
5972
if err == nil {
6073
select {
6174
default:
6275
case <-c.Done():
63-
stmt.Close()
76+
stmt.Stmt.Close()
6477
return nil, c.Err()
6578
}
6679
}
6780
}
6881
if err != nil {
6982
return nil, err
7083
}
71-
return &Stmt{
72-
Stmt: stmt,
73-
QueryString: query,
74-
Proxy: conn.Proxy,
75-
Conn: conn,
76-
}, nil
84+
85+
if hooks != nil {
86+
if err = hooks.prepare(c, ctx, stmt); err != nil {
87+
return nil, err
88+
}
89+
}
90+
return stmt, nil
7791
}
7892

7993
// Close calls the original Close method.

hooks.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ type hooks interface {
1717
preOpen(c context.Context, name string) (interface{}, error)
1818
open(c context.Context, ctx interface{}, conn *Conn) error
1919
postOpen(c context.Context, ctx interface{}, conn *Conn, err error) error
20+
prePrepare(c context.Context, stmt *Stmt) (interface{}, error)
21+
prepare(c context.Context, ctx interface{}, stmt *Stmt) error
22+
postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error
2023
preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error)
2124
exec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result) error
2225
postExec(c context.Context, ctx interface{}, stmt *Stmt, args []driver.NamedValue, result driver.Result, err error) error
@@ -109,6 +112,36 @@ type HooksContext struct {
109112
// `Hooks.PreOpen` method, and may be nil.
110113
PostOpen func(c context.Context, ctx interface{}, conn *Conn, err error) error
111114

115+
// PrePrepare is a callback that gets called prior to calling
116+
// `db.Prepare`, and is ALWAYS called. If this callback returns an
117+
// error, the underlying driver's `db.Exec` and `Hooks.Prepare` methods
118+
// are not called.
119+
//
120+
// The first return value is passed to both `Hooks.Prepare` and
121+
// `Hooks.PostPrepare` callbacks. You may specify anything you want.
122+
// Return nil if you do not need to use it.
123+
//
124+
// The second return value is indicates the error found while
125+
// executing this hook.
126+
PrePrepare func(c context.Context, stmt *Stmt) (interface{}, error)
127+
128+
// Prepare is called after the underlying driver's `db.Prepare` method
129+
// returns without any errors.
130+
//
131+
// The `ctx` parameter is the return value supplied from the
132+
// `Hooks.PrePrepare` method, and may be nil.
133+
//
134+
// If this callback returns an error, then the error from this
135+
// callback is returned by the `db.Prepare` method.
136+
Prepare func(c context.Context, ctx interface{}, stmt *Stmt) error
137+
138+
// PostPrepare is a callback that gets called at the end of
139+
// the call to `db.Prepare`. It is ALWAYS called.
140+
//
141+
// The `ctx` parameter is the return value supplied from the
142+
// `Hooks.PrePrepare` method, and may be nil.
143+
PostPrepare func(c context.Context, ctx interface{}, stmt *Stmt, err error) error
144+
112145
// PreExec is a callback that gets called prior to calling
113146
// `Stmt.Exec`, and is ALWAYS called. If this callback returns an
114147
// error, the underlying driver's `Stmt.Exec` and `Hooks.Exec` methods
@@ -405,6 +438,27 @@ func (h *HooksContext) postOpen(c context.Context, ctx interface{}, conn *Conn,
405438
return h.PostOpen(c, ctx, conn, err)
406439
}
407440

441+
func (h *HooksContext) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) {
442+
if h == nil || h.PrePrepare == nil {
443+
return nil, nil
444+
}
445+
return h.PrePrepare(c, stmt)
446+
}
447+
448+
func (h *HooksContext) prepare(c context.Context, ctx interface{}, stmt *Stmt) error {
449+
if h == nil || h.Prepare == nil {
450+
return nil
451+
}
452+
return h.Prepare(c, ctx, stmt)
453+
}
454+
455+
func (h *HooksContext) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error {
456+
if h == nil || h.PostPrepare == nil {
457+
return nil
458+
}
459+
return h.PostPrepare(c, ctx, stmt, err)
460+
}
461+
408462
func (h *HooksContext) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
409463
if h == nil || h.PreExec == nil {
410464
return nil, nil
@@ -929,6 +983,18 @@ func (h *Hooks) postOpen(c context.Context, ctx interface{}, conn *Conn, err err
929983
return h.PostOpen(ctx, conn)
930984
}
931985

986+
func (h *Hooks) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) {
987+
return nil, nil
988+
}
989+
990+
func (h *Hooks) prepare(c context.Context, ctx interface{}, stmt *Stmt) error {
991+
return nil
992+
}
993+
994+
func (h *Hooks) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error {
995+
return nil
996+
}
997+
932998
func (h *Hooks) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
933999
if h == nil || h.PreExec == nil {
9341000
return nil, nil
@@ -1187,6 +1253,24 @@ func (h multipleHooks) postOpen(c context.Context, ctx interface{}, conn *Conn,
11871253
})
11881254
}
11891255

1256+
func (h multipleHooks) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) {
1257+
return h.preDo(func(h hooks) (interface{}, error) {
1258+
return h.prePrepare(c, stmt)
1259+
})
1260+
}
1261+
1262+
func (h multipleHooks) prepare(c context.Context, ctx interface{}, stmt *Stmt) error {
1263+
return h.do(ctx, func(h hooks, ctx interface{}) error {
1264+
return h.prepare(c, ctx, stmt)
1265+
})
1266+
}
1267+
1268+
func (h multipleHooks) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error {
1269+
return h.postDo(ctx, err, func(h hooks, ctx interface{}, err error) error {
1270+
return h.postPrepare(c, ctx, stmt, err)
1271+
})
1272+
}
1273+
11901274
func (h multipleHooks) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
11911275
return h.preDo(func(h hooks) (interface{}, error) {
11921276
return h.preExec(c, stmt, args)

logging_hook_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,27 @@ func (h *loggingHook) postOpen(c context.Context, ctx interface{}, conn *Conn, e
6161
return nil
6262
}
6363

64+
func (h *loggingHook) prePrepare(c context.Context, stmt *Stmt) (interface{}, error) {
65+
h.mu.Lock()
66+
defer h.mu.Unlock()
67+
fmt.Fprintln(h, "[PrePrepare]")
68+
return nil, nil
69+
}
70+
71+
func (h *loggingHook) prepare(c context.Context, ctx interface{}, stmt *Stmt) error {
72+
h.mu.Lock()
73+
defer h.mu.Unlock()
74+
fmt.Fprintln(h, "[Prepare]")
75+
return nil
76+
}
77+
78+
func (h *loggingHook) postPrepare(c context.Context, ctx interface{}, stmt *Stmt, err error) error {
79+
h.mu.Lock()
80+
defer h.mu.Unlock()
81+
fmt.Fprintln(h, "[PostPrepare]")
82+
return nil
83+
}
84+
6485
func (h *loggingHook) preExec(c context.Context, stmt *Stmt, args []driver.NamedValue) (interface{}, error) {
6586
h.mu.Lock()
6687
defer h.mu.Unlock()

proxy_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ func TestFakeDB(t *testing.T) {
3636
Name: "execAll",
3737
},
3838
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
39+
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
3940
"[PreExec]\n[Exec]\n[PostExec]\n" +
4041
"[PreClose]\n[Close]\n[PostClose]\n",
4142
f: func(db *sql.DB) error {
@@ -49,6 +50,7 @@ func TestFakeDB(t *testing.T) {
4950
FailExec: true,
5051
},
5152
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
53+
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
5254
"[PreExec]\n[PostExec]\n" +
5355
"[PreClose]\n[Close]\n[PostClose]\n",
5456
f: func(db *sql.DB) error {
@@ -64,6 +66,7 @@ func TestFakeDB(t *testing.T) {
6466
Name: "execError-NamedValue",
6567
},
6668
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
69+
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
6770
"[PreExec]\n[PostExec]\n" +
6871
"[PreClose]\n[Close]\n[PostClose]\n",
6972
f: func(db *sql.DB) error {
@@ -80,6 +83,7 @@ func TestFakeDB(t *testing.T) {
8083
Name: "queryAll",
8184
},
8285
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
86+
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
8387
"[PreQuery]\n[Query]\n[PostQuery]\n",
8488
f: func(db *sql.DB) error {
8589
_, err := db.Query("SELECT * FROM test WHERE id = ?", 123456789)
@@ -92,6 +96,7 @@ func TestFakeDB(t *testing.T) {
9296
FailQuery: true,
9397
},
9498
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
99+
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
95100
"[PreQuery]\n[PostQuery]\n" +
96101
"[PreClose]\n[Close]\n[PostClose]\n",
97102
f: func(db *sql.DB) error {
@@ -107,6 +112,7 @@ func TestFakeDB(t *testing.T) {
107112
Name: "prepare",
108113
},
109114
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
115+
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
110116
"[PreClose]\n[Close]\n[PostClose]\n",
111117
f: func(db *sql.DB) error {
112118
stmt, err := db.Prepare("SELECT * FROM test WHERE id = ?")
@@ -255,6 +261,7 @@ func TestFakeDB(t *testing.T) {
255261
ConnType: "fakeConnExt",
256262
},
257263
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
264+
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
258265
"[PreExec]\n[Exec]\n[PostExec]\n" +
259266
"[PreClose]\n[Close]\n[PostClose]\n",
260267
f: func(db *sql.DB) error {
@@ -325,6 +332,7 @@ func TestFakeDB(t *testing.T) {
325332
ConnType: "fakeConnCtx",
326333
},
327334
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
335+
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
328336
"[PreExec]\n[Exec]\n[PostExec]\n" +
329337
"[PreClose]\n[Close]\n[PostClose]\n",
330338
f: func(db *sql.DB) error {
@@ -343,6 +351,7 @@ func TestFakeDB(t *testing.T) {
343351
ConnType: "fakeConnCtx",
344352
},
345353
hooksLog: "[PreOpen]\n[Open]\n[PostOpen]\n" +
354+
"[PrePrepare]\n[Prepare]\n[PostPrepare]\n" +
346355
"[PreQuery]\n[Query]\n[PostQuery]\n",
347356
f: func(db *sql.DB) error {
348357
stmt, err := db.Prepare("SELECT * FROM test WHERE id = ?")

0 commit comments

Comments
 (0)