Skip to content

Commit 82f3b17

Browse files
committed
fix addToSet and push to support each op
1 parent 8008220 commit 82f3b17

File tree

4 files changed

+53
-7
lines changed

4 files changed

+53
-7
lines changed

pkg/mongoproxy/plugins/schema/schema.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,19 @@ import (
66
"io/ioutil"
77
"log"
88
"path"
9+
"strings"
910
"sync/atomic"
1011

1112
"github.com/cespare/xxhash/v2"
1213
"github.com/prometheus/client_golang/prometheus"
1314
"github.com/prometheus/client_golang/prometheus/promauto"
1415
"github.com/sirupsen/logrus"
15-
"go.mongodb.org/mongo-driver/bson"
16-
"gopkg.in/fsnotify.v1"
17-
1816
"github.com/wish/mongoproxy/pkg/bsonutil"
1917
"github.com/wish/mongoproxy/pkg/command"
2018
"github.com/wish/mongoproxy/pkg/mongoerror"
2119
"github.com/wish/mongoproxy/pkg/mongoproxy/plugins"
20+
"go.mongodb.org/mongo-driver/bson"
21+
"gopkg.in/fsnotify.v1"
2222
)
2323

2424
var (
@@ -114,12 +114,20 @@ func (p *SchemaPlugin) Configure(d bson.D) error {
114114
if err := p.LoadSchema(); err != nil {
115115
return err
116116
}
117+
// skip watcher for unit test
118+
if strings.HasPrefix(p.conf.SchemaPath, "example.json") {
119+
return nil
120+
}
121+
117122
// start watch
118123
watcher, err := fsnotify.NewWatcher()
119124
if err != nil {
120125
log.Fatal(err)
121126
}
122127

128+
defer watcher.Close()
129+
done := make(chan bool)
130+
123131
go func() {
124132
for {
125133
select {
@@ -145,7 +153,7 @@ func (p *SchemaPlugin) Configure(d bson.D) error {
145153
if err := watcher.Add(path.Dir(p.conf.SchemaPath)); err != nil {
146154
return err
147155
}
148-
156+
<-done
149157
return nil
150158
}
151159

@@ -168,7 +176,7 @@ func (p *SchemaPlugin) Process(ctx context.Context, r *plugins.Request, next plu
168176
case *command.FindAndModify:
169177
if len(cmd.Update) > 0 {
170178
schema := p.GetSchema()
171-
logrus.Infof("command findAndModify: %v", cmd.Update)
179+
logrus.Debugf("command findAndModify: %v", cmd.Update)
172180
if err := schema.ValidateUpdate(ctx, cmd.Database, cmd.Collection, cmd.Update, bsonutil.GetBoolDefault(cmd.Upsert, false)); err != nil {
173181
schemaDeny.WithLabelValues(cmd.Database, cmd.Collection, r.CommandName).Inc()
174182
if !p.conf.EnforceSchemaLogOnly {
@@ -182,7 +190,7 @@ func (p *SchemaPlugin) Process(ctx context.Context, r *plugins.Request, next plu
182190
case *command.Update:
183191
schema := p.GetSchema()
184192
for _, updateDoc := range cmd.Updates {
185-
logrus.Infof("print command Update: %v", updateDoc)
193+
logrus.Debugf("print command Update: %v", updateDoc)
186194
if err := schema.ValidateUpdate(ctx, cmd.Database, cmd.Collection, updateDoc.U, bsonutil.GetBoolDefault(updateDoc.Upsert, false)); err != nil {
187195
schemaDeny.WithLabelValues(cmd.Database, cmd.Collection, r.CommandName).Inc()
188196
if !p.conf.EnforceSchemaLogOnly {

pkg/mongoproxy/plugins/schema/type_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ var (
199199
// push extra field
200200
{DB: "testdb", Collection: "requireonlysuba", In: bson.D{{"$push", bson.D{{"doc.a", "name"}, {"doc.b", 1}}}}, Err: true},
201201
{DB: "testdb", Collection: "requireonlysuba", In: bson.D{{"$push", bson.D{{"a", "name"}, {"doc.b", 1}}}}, Err: true},
202+
//test with each
203+
{DB: "testdb", Collection: "nonrequire", In: bson.D{{"$push", bson.D{{"luckynumbers", bson.D{{"$each", bson.A{1, 2, 3}}}}}}}},
204+
//test with each
205+
{DB: "testdb", Collection: "nonrequire", In: bson.D{{"$push", bson.D{{"luckynumbers", bson.E{"$each", bson.A{1, 2, 3}}}}}}},
202206

203207
//
204208
// pull tests
@@ -337,6 +341,10 @@ var (
337341
// addToSet extra field
338342
{DB: "testdb", Collection: "requireonlysuba", In: bson.D{{"$addToSet", bson.D{{"doc.a", "name"}, {"doc.b", 1}}}}, Err: true},
339343
{DB: "testdb", Collection: "requireonlysuba", In: bson.D{{"$addToSet", bson.D{{"a", "name"}, {"doc.b", 1}}}}, Err: true},
344+
//test with each
345+
{DB: "testdb", Collection: "nonrequire", In: bson.D{{"$addToSet", bson.D{{"luckynumbers", bson.D{{"$each", bson.A{1, 2, 3}}}}}}}},
346+
//test with each
347+
{DB: "testdb", Collection: "nonrequire", In: bson.D{{"$addToSet", bson.D{{"luckynumbers", bson.E{"$each", bson.A{1, 2, 3}}}}}}},
340348

341349
//
342350
// rename tests

pkg/mongoproxy/plugins/schema/types.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ func (c *Collection) ValidateUpdate(ctx context.Context, obj bson.D, upsert bool
267267
}
268268
case "$rename":
269269
renameFields = e.Value.(bson.D).Map()
270-
case "$set", "$pull", "$push", "$addToSet", "$pullAll":
270+
case "$set", "$pull", "$pullAll":
271271
if setFields == nil {
272272
setFields = Mapify(e.Value.(bson.D))
273273
} else {
@@ -276,6 +276,12 @@ func (c *Collection) ValidateUpdate(ctx context.Context, obj bson.D, upsert bool
276276
setFields[item.Key] = item.Value
277277
}
278278
}
279+
case "$addToSet", "$push":
280+
if setFields == nil {
281+
setFields = make(bson.M, len(e.Value.(bson.D)))
282+
}
283+
setFields = MapifyWithOp(e.Value.(bson.D), setFields)
284+
279285
case "$setOnInsert":
280286
insertFields = Mapify(e.Value.(bson.D))
281287
case "$unset":

pkg/mongoproxy/plugins/schema/util.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"reflect"
77
"regexp"
88

9+
"github.com/sirupsen/logrus"
910
"go.mongodb.org/mongo-driver/bson"
1011
)
1112

@@ -141,6 +142,29 @@ func Mapify(d bson.D) bson.M {
141142
return m
142143
}
143144

145+
// Map creates a map from the elements of the D with operator
146+
// It makes additional process for arrays
147+
func MapifyWithOp(d bson.D, m bson.M) bson.M {
148+
for _, e := range d {
149+
e := processArray(e)
150+
itemType := fmt.Sprint(reflect.TypeOf(e.Value))
151+
if itemType == "primitive.D" {
152+
itemValueSet := e.Value.(bson.D).Map()
153+
if val, ok := itemValueSet["$each"]; ok {
154+
m[e.Key] = val
155+
} else {
156+
m[e.Key] = e.Value
157+
}
158+
} else if itemType == "primitive.E" && e.Value.(bson.E).Key == "$each" {
159+
m[e.Key] = e.Value.(bson.E).Value
160+
} else {
161+
m[e.Key] = e.Value
162+
logrus.Debugf("Add %s type element to set", itemType)
163+
}
164+
}
165+
return m
166+
}
167+
144168
// looping and process elements in object
145169
func handleObj(obj bson.D, m bson.M) bson.M {
146170
for _, e := range obj {

0 commit comments

Comments
 (0)