From f9fb7c13d1d1e8f408d8cf9bc9596587ef9f0efc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 24 Feb 2023 07:18:34 +0000 Subject: [PATCH] Bump github.com/jinzhu/gorm Bumps [github.com/jinzhu/gorm](https://github.com/jinzhu/gorm) from 0.0.0-20160404144928-5174cc5c242a to 1.9.16. - [Release notes](https://github.com/jinzhu/gorm/releases) - [Commits](https://github.com/jinzhu/gorm/commits/v1.9.16) --- updated-dependencies: - dependency-name: github.com/jinzhu/gorm dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- go.mod | 7 +- go.sum | 17 +- .../github.com/jinzhu/gorm/.codeclimate.yml | 11 - vendor/github.com/jinzhu/gorm/.gitignore | 1 + vendor/github.com/jinzhu/gorm/CONTRIBUTING.md | 52 -- vendor/github.com/jinzhu/gorm/README.md | 45 +- vendor/github.com/jinzhu/gorm/association.go | 42 +- vendor/github.com/jinzhu/gorm/callback.go | 53 +- .../github.com/jinzhu/gorm/callback_create.go | 91 ++- .../github.com/jinzhu/gorm/callback_delete.go | 18 +- .../github.com/jinzhu/gorm/callback_query.go | 20 +- .../jinzhu/gorm/callback_query_preload.go | 160 ++++- .../jinzhu/gorm/callback_row_query.go | 41 ++ .../github.com/jinzhu/gorm/callback_save.go | 130 +++- .../github.com/jinzhu/gorm/callback_update.go | 25 +- vendor/github.com/jinzhu/gorm/dialect.go | 79 ++- .../github.com/jinzhu/gorm/dialect_common.go | 95 ++- .../github.com/jinzhu/gorm/dialect_mysql.go | 163 ++++- .../jinzhu/gorm/dialect_postgres.go | 53 +- .../github.com/jinzhu/gorm/dialect_sqlite3.go | 15 +- .../github.com/jinzhu/gorm/docker-compose.yml | 30 + vendor/github.com/jinzhu/gorm/errors.go | 60 +- vendor/github.com/jinzhu/gorm/field.go | 10 +- vendor/github.com/jinzhu/gorm/interface.go | 9 +- .../jinzhu/gorm/join_table_handler.go | 51 +- vendor/github.com/jinzhu/gorm/logger.go | 116 ++-- vendor/github.com/jinzhu/gorm/main.go | 466 +++++++++----- vendor/github.com/jinzhu/gorm/model_struct.go | 351 +++++++---- vendor/github.com/jinzhu/gorm/naming.go | 124 ++++ vendor/github.com/jinzhu/gorm/scope.go | 569 ++++++++++++------ vendor/github.com/jinzhu/gorm/search.go | 88 ++- vendor/github.com/jinzhu/gorm/test_all.sh | 4 +- vendor/github.com/jinzhu/gorm/utils.go | 86 +-- vendor/github.com/jinzhu/gorm/wercker.yml | 149 +++++ vendor/github.com/jinzhu/inflection/README.md | 10 +- .../github.com/jinzhu/inflection/wercker.yml | 23 + vendor/modules.txt | 12 +- 37 files changed, 2344 insertions(+), 932 deletions(-) delete mode 100644 vendor/github.com/jinzhu/gorm/.codeclimate.yml delete mode 100644 vendor/github.com/jinzhu/gorm/CONTRIBUTING.md create mode 100644 vendor/github.com/jinzhu/gorm/callback_row_query.go create mode 100644 vendor/github.com/jinzhu/gorm/docker-compose.yml create mode 100644 vendor/github.com/jinzhu/gorm/naming.go create mode 100644 vendor/github.com/jinzhu/gorm/wercker.yml create mode 100644 vendor/github.com/jinzhu/inflection/wercker.yml diff --git a/go.mod b/go.mod index 58c65c96..3b6d3a38 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/gorilla/mux v1.8.0 github.com/gorilla/securecookie v0.0.0-20160422134519-667fe4e3466a github.com/honeycombio/beeline-go v1.11.1 - github.com/jinzhu/gorm v0.0.0-20160404144928-5174cc5c242a + github.com/jinzhu/gorm v1.9.16 github.com/nicklaw5/helix v1.25.0 github.com/nlopes/slack v0.0.0-20180905213137-8cf10c586222 github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d @@ -35,8 +35,6 @@ require ( github.com/antchfx/htmlquery v1.2.3 // indirect github.com/antchfx/xmlquery v1.2.4 // indirect github.com/antchfx/xpath v1.1.8 // indirect - github.com/denisenkom/go-mssqldb v0.0.0-20190915052044-aa4949efa320 // indirect - github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 // indirect github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a // indirect github.com/facebookgo/limitgroup v0.0.0-20150612190941-6abd8d71ec01 // indirect github.com/facebookgo/muster v0.0.0-20150708232844-fd3d7953fd52 // indirect @@ -51,8 +49,7 @@ require ( github.com/googleapis/gax-go/v2 v2.6.0 // indirect github.com/gorilla/websocket v1.4.2 // indirect github.com/honeycombio/libhoney-go v1.17.1 // indirect - github.com/jinzhu/inflection v0.0.0-20170102125226-1c35d901db3d // indirect - github.com/jinzhu/now v1.0.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect github.com/kennygrant/sanitize v1.2.4 // indirect github.com/klauspost/compress v1.15.9 // indirect github.com/lib/pq v1.10.7 // indirect diff --git a/go.sum b/go.sum index 9d50739e..28188dae 100644 --- a/go.sum +++ b/go.sum @@ -31,8 +31,8 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.0.0-20190915052044-aa4949efa320 h1:eCGfXWmAYTB+OAWSmiFlz4L/SIcX+Kf3g8iVRqVcfTY= -github.com/denisenkom/go-mssqldb v0.0.0-20190915052044-aa4949efa320/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= +github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd h1:83Wprp6ROGeiHFAP8WJdI2RoxALQYgdllERc3N5N2DM= +github.com/denisenkom/go-mssqldb v0.0.0-20191124224453-732737034ffd/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dineshappavoo/basex v0.0.0-20160618072718-f35bafba529c h1:ZMcYoGBFnMeWXBc1PahW5AJdc39BtpTkUWexQ3ugaZc= github.com/dineshappavoo/basex v0.0.0-20160618072718-f35bafba529c/go.mod h1:Kad2hux31v/IyD4Rf4wAwIyK48995rs3qAl9IUAhc2k= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -56,6 +56,7 @@ github.com/facebookgo/subset v0.0.0-20200203212716-c811ad88dec4/go.mod h1:5tD+ne github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= @@ -113,20 +114,22 @@ github.com/honeycombio/beeline-go v1.11.1/go.mod h1:VTUrnK49oSsWme20E+94XTn1qofD github.com/honeycombio/libhoney-go v1.17.1 h1:qOHGm5lqzj82O5RvsqTM0OhbEUxY+t9WoKPaD9FJJ5o= github.com/honeycombio/libhoney-go v1.17.1/go.mod h1:KwbcXkqUbH20x3MpfSt/kdvlog3FFdEnouqYD3XKXLY= github.com/jawher/mow.cli v1.1.0/go.mod h1:aNaQlc7ozF3vw6IJ2dHjp2ZFiA4ozMIYY6PyuRJwlUg= -github.com/jinzhu/gorm v0.0.0-20160404144928-5174cc5c242a h1:pfPxlCVlKqBRqHpyCxOIKhhB4ERpz02iadDpRVevLm4= -github.com/jinzhu/gorm v0.0.0-20160404144928-5174cc5c242a/go.mod h1:Vla75njaFJ8clLU1W44h34PjIkijhjHIYnZxMqCdxqo= -github.com/jinzhu/inflection v0.0.0-20170102125226-1c35d901db3d h1:jRQLvyVGL+iVtDElaEIDdKwpPqUIZJfzkNLV34htpEc= -github.com/jinzhu/inflection v0.0.0-20170102125226-1c35d901db3d/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/gorm v1.9.16 h1:+IyIjPEABKRpsu/F8OvDPy9fyQlgsg2luMV2ZIH5i5o= +github.com/jinzhu/gorm v1.9.16/go.mod h1:G3LB3wezTOWM2ITLzPxEXgSkOXAntiLHS7UdBefADcs= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.0.1 h1:HjfetcXq097iXP0uoPCdnM4Efp5/9MsM0/M+XOTeR3M= github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/kennygrant/sanitize v1.2.4 h1:gN25/otpP5vAsO2djbMhF/LQX6R7+O1TB4yv8NzpJ3o= github.com/kennygrant/sanitize v1.2.4/go.mod h1:LGsjYYtgxbetdg5owWB2mpgUL6e2nfw2eObZ0u0qvak= github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY= github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= +github.com/lib/pq v1.1.1/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lusis/slack-test v0.0.0-20190426140909-c40012f20018 h1:MNApn+Z+fIT4NPZopPfCc1obT6aY3SVM6DOctz1A9ZU= github.com/lusis/slack-test v0.0.0-20190426140909-c40012f20018/go.mod h1:sFlOUpQL1YcjhFVXhg1CG8ZASEs/Mf1oVb6H75JL/zg= +github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= github.com/nicklaw5/helix v1.25.0 h1:Mrz537izZVsGdM3I46uGAAlslj61frgkhS/9xQqyT/M= github.com/nicklaw5/helix v1.25.0/go.mod h1:yvXZFapT6afIoxnAvlWiJiUMsYnoHl7tNs+t0bloAMw= @@ -182,6 +185,7 @@ go.uber.org/zap v1.23.0/go.mod h1:D+nX8jyLsMHMYrln8A0rJjFt/T/9/bGgIhAqxv5URuY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191205180655-e7c4368fe9dd/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= @@ -198,6 +202,7 @@ golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200421231249-e086a090c8fd/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= diff --git a/vendor/github.com/jinzhu/gorm/.codeclimate.yml b/vendor/github.com/jinzhu/gorm/.codeclimate.yml deleted file mode 100644 index 51aba50c..00000000 --- a/vendor/github.com/jinzhu/gorm/.codeclimate.yml +++ /dev/null @@ -1,11 +0,0 @@ ---- -engines: - gofmt: - enabled: true - govet: - enabled: true - golint: - enabled: true -ratings: - paths: - - "**.go" diff --git a/vendor/github.com/jinzhu/gorm/.gitignore b/vendor/github.com/jinzhu/gorm/.gitignore index 01dc5ce0..117f92f5 100644 --- a/vendor/github.com/jinzhu/gorm/.gitignore +++ b/vendor/github.com/jinzhu/gorm/.gitignore @@ -1,2 +1,3 @@ documents +coverage.txt _book diff --git a/vendor/github.com/jinzhu/gorm/CONTRIBUTING.md b/vendor/github.com/jinzhu/gorm/CONTRIBUTING.md deleted file mode 100644 index c54d572d..00000000 --- a/vendor/github.com/jinzhu/gorm/CONTRIBUTING.md +++ /dev/null @@ -1,52 +0,0 @@ -# How to Contribute - -## Bug Report - -- Do a search on GitHub under Issues in case it has already been reported -- Submit __executable script__ or failing test pull request that could demonstrates the issue is *MUST HAVE* - -## Feature Request - -- Feature request with pull request is welcome -- Or it won't be implemented until I (other developers) find it is helpful for my (their) daily work - -## Pull Request - -- Prefer single commit pull request, that make the git history can be a bit easier to follow. -- New features need to be covered with tests to make sure your code works as expected, and won't be broken by others in future - -## Contributing to Documentation - -- You are welcome ;) -- You can help improve the README by making them more coherent, consistent or readable, and add more godoc documents to make people easier to follow. -- Blogs & Usage Guides & PPT also welcome, please add them to https://github.com/jinzhu/gorm/wiki/Guides - -### Executable script template - -```go -package main - -import ( - _ "github.com/mattn/go-sqlite3" - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - "github.com/jinzhu/gorm" -) - -var db gorm.DB - -func init() { - var err error - db, err = gorm.Open("sqlite3", "test.db") - // db, err := gorm.Open("postgres", "user=username dbname=password sslmode=disable") - // db, err := gorm.Open("mysql", "user:password@/dbname?charset=utf8&parseTime=True") - if err != nil { - panic(err) - } - db.LogMode(true) -} - -func main() { - // Your code -} -``` diff --git a/vendor/github.com/jinzhu/gorm/README.md b/vendor/github.com/jinzhu/gorm/README.md index c3f209c9..85588a79 100644 --- a/vendor/github.com/jinzhu/gorm/README.md +++ b/vendor/github.com/jinzhu/gorm/README.md @@ -1,46 +1,5 @@ # GORM -The fantastic ORM library for Golang, aims to be developer friendly. +GORM V2 moved to https://github.com/go-gorm/gorm -[![Join the chat at https://gitter.im/jinzhu/gorm](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/jinzhu/gorm?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -[![wercker status](https://app.wercker.com/status/0cb7bb1039e21b74f8274941428e0921/s/master "wercker status")](https://app.wercker.com/project/bykey/0cb7bb1039e21b74f8274941428e0921) -[![GoDoc](https://godoc.org/github.com/jinzhu/gorm?status.svg)](https://godoc.org/github.com/jinzhu/gorm) - -## Overview - -* Full-Featured ORM (almost) -* Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism) -* Callbacks (Before/After Create/Save/Update/Delete/Find) -* Preloading (eager loading) -* Transactions -* Composite Primary Key -* SQL Builder -* Auto Migrations -* Logger -* Extendable, write Plugins based on GORM callbacks -* Every feature comes with tests -* Developer Friendly - -## Getting Started - -* GORM Guides [jinzhu.github.com/gorm](https://jinzhu.github.io/gorm) - -## Upgrading To V1.0 - -* [CHANGELOG](https://jinzhu.github.io/gorm/changelog.html) - -# Author - -**jinzhu** - -* -* -* - -# Contributors - -https://github.com/jinzhu/gorm/graphs/contributors - -## License - -Released under the [MIT License](https://github.com/jinzhu/gorm/blob/master/License). +GORM V1 Doc https://v1.gorm.io/ diff --git a/vendor/github.com/jinzhu/gorm/association.go b/vendor/github.com/jinzhu/gorm/association.go index cd8fd912..a73344fe 100644 --- a/vendor/github.com/jinzhu/gorm/association.go +++ b/vendor/github.com/jinzhu/gorm/association.go @@ -22,6 +22,10 @@ func (association *Association) Find(value interface{}) *Association { // Append append new associations for many2many, has_many, replace current association for has_one, belongs_to func (association *Association) Append(values ...interface{}) *Association { + if association.Error != nil { + return association + } + if relationship := association.field.Relationship; relationship.Kind == "has_one" { return association.Replace(values...) } @@ -30,6 +34,10 @@ func (association *Association) Append(values ...interface{}) *Association { // Replace replace current associations with new one func (association *Association) Replace(values ...interface{}) *Association { + if association.Error != nil { + return association + } + var ( relationship = association.field.Relationship scope = association.scope @@ -55,31 +63,33 @@ func (association *Association) Replace(values ...interface{}) *Association { } else { // Polymorphic Relations if relationship.PolymorphicDBName != "" { - newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + newDB = newDB.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) } // Delete Relations except new created if len(values) > 0 { - var associationForeignFieldNames []string + var associationForeignFieldNames, associationForeignDBNames []string if relationship.Kind == "many_to_many" { // if many to many relations, get association fields name from association foreign keys associationScope := scope.New(reflect.New(field.Type()).Interface()) - for _, dbName := range relationship.AssociationForeignFieldNames { + for idx, dbName := range relationship.AssociationForeignFieldNames { if field, ok := associationScope.FieldByName(dbName); ok { associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + associationForeignDBNames = append(associationForeignDBNames, relationship.AssociationForeignDBNames[idx]) } } } else { - // If other relations, use primary keys + // If has one/many relations, use primary keys for _, field := range scope.New(reflect.New(field.Type()).Interface()).PrimaryFields() { associationForeignFieldNames = append(associationForeignFieldNames, field.Name) + associationForeignDBNames = append(associationForeignDBNames, field.DBName) } } newPrimaryKeys := scope.getColumnAsArray(associationForeignFieldNames, field.Interface()) if len(newPrimaryKeys) > 0 { - sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(newPrimaryKeys)) + sql := fmt.Sprintf("%v NOT IN (%v)", toQueryCondition(scope, associationForeignDBNames), toQueryMarks(newPrimaryKeys)) newDB = newDB.Where(sql, toQueryValues(newPrimaryKeys)...) } } @@ -97,7 +107,7 @@ func (association *Association) Replace(values ...interface{}) *Association { if sourcePrimaryKeys := scope.getColumnAsArray(sourceForeignFieldNames, scope.Value); len(sourcePrimaryKeys) > 0 { newDB = newDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(sourcePrimaryKeys)), toQueryValues(sourcePrimaryKeys)...) - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) } } else if relationship.Kind == "has_one" || relationship.Kind == "has_many" { // has_one or has_many relations, set foreign key to be nil (TODO or delete them?) @@ -118,6 +128,10 @@ func (association *Association) Replace(values ...interface{}) *Association { // Delete remove relationship between source & passed arguments, but won't delete those arguments func (association *Association) Delete(values ...interface{}) *Association { + if association.Error != nil { + return association + } + var ( relationship = association.field.Relationship scope = association.scope @@ -159,7 +173,7 @@ func (association *Association) Delete(values ...interface{}) *Association { sql := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(deletingPrimaryKeys)) newDB = newDB.Where(sql, toQueryValues(deletingPrimaryKeys)...) - association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB, relationship)) + association.setErr(relationship.JoinTableHandler.Delete(relationship.JoinTableHandler, newDB)) } else { var foreignKeyMap = map[string]interface{}{} for _, foreignKey := range relationship.ForeignDBNames { @@ -253,15 +267,16 @@ func (association *Association) Count() int { query = scope.DB() ) - if relationship.Kind == "many_to_many" { + switch relationship.Kind { + case "many_to_many": query = relationship.JoinTableHandler.JoinWith(relationship.JoinTableHandler, query, scope.Value) - } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { + case "has_many", "has_one": primaryKeys := scope.getColumnAsArray(relationship.AssociationForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)..., ) - } else if relationship.Kind == "belongs_to" { + case "belongs_to": primaryKeys := scope.getColumnAsArray(relationship.ForeignFieldNames, scope.Value) query = query.Where( fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relationship.AssociationForeignDBNames), toQueryMarks(primaryKeys)), @@ -272,11 +287,13 @@ func (association *Association) Count() int { if relationship.PolymorphicType != "" { query = query.Where( fmt.Sprintf("%v.%v = ?", scope.New(fieldValue).QuotedTableName(), scope.Quote(relationship.PolymorphicDBName)), - scope.TableName(), + relationship.PolymorphicValue, ) } - query.Model(fieldValue).Count(&count) + if err := query.Model(fieldValue).Count(&count).Error; err != nil { + association.Error = err + } return count } @@ -351,6 +368,7 @@ func (association *Association) saveAssociations(values ...interface{}) *Associa return association } +// setErr set error when the error is not nil. And return Association. func (association *Association) setErr(err error) *Association { if err != nil { association.Error = err diff --git a/vendor/github.com/jinzhu/gorm/callback.go b/vendor/github.com/jinzhu/gorm/callback.go index 93198a71..1f0e3c79 100644 --- a/vendor/github.com/jinzhu/gorm/callback.go +++ b/vendor/github.com/jinzhu/gorm/callback.go @@ -1,13 +1,11 @@ package gorm -import ( - "fmt" -) +import "fmt" // DefaultCallback default callbacks defined by gorm -var DefaultCallback = &Callback{} +var DefaultCallback = &Callback{logger: nopLogger{}} -// Callback is a struct that contains all CURD callbacks +// Callback is a struct that contains all CRUD callbacks // Field `creates` contains callbacks will be call when creating object // Field `updates` contains callbacks will be call when updating object // Field `deletes` contains callbacks will be call when deleting object @@ -15,6 +13,7 @@ var DefaultCallback = &Callback{} // Field `rowQueries` contains callbacks will be call when querying object with Row, Rows... // Field `processors` contains all callback processors, will be used to generate above callbacks in order type Callback struct { + logger logger creates []*func(scope *Scope) updates []*func(scope *Scope) deletes []*func(scope *Scope) @@ -25,6 +24,7 @@ type Callback struct { // CallbackProcessor contains callback informations type CallbackProcessor struct { + logger logger name string // current callback's name before string // register current callback before a callback after string // register current callback after a callback @@ -35,8 +35,9 @@ type CallbackProcessor struct { parent *Callback } -func (c *Callback) clone() *Callback { +func (c *Callback) clone(logger logger) *Callback { return &Callback{ + logger: logger, creates: c.creates, updates: c.updates, deletes: c.deletes, @@ -55,28 +56,28 @@ func (c *Callback) clone() *Callback { // scope.Err(errors.New("error")) // }) func (c *Callback) Create() *CallbackProcessor { - return &CallbackProcessor{kind: "create", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "create", parent: c} } // Update could be used to register callbacks for updating object, refer `Create` for usage func (c *Callback) Update() *CallbackProcessor { - return &CallbackProcessor{kind: "update", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "update", parent: c} } // Delete could be used to register callbacks for deleting object, refer `Create` for usage func (c *Callback) Delete() *CallbackProcessor { - return &CallbackProcessor{kind: "delete", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "delete", parent: c} } // Query could be used to register callbacks for querying objects with query methods like `Find`, `First`, `Related`, `Association`... // Refer `Create` for usage func (c *Callback) Query() *CallbackProcessor { - return &CallbackProcessor{kind: "query", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "query", parent: c} } // RowQuery could be used to register callbacks for querying objects with `Row`, `Rows`, refer `Create` for usage func (c *Callback) RowQuery() *CallbackProcessor { - return &CallbackProcessor{kind: "row_query", parent: c} + return &CallbackProcessor{logger: c.logger, kind: "row_query", parent: c} } // After insert a new callback after callback `callbackName`, refer `Callbacks.Create` @@ -93,6 +94,14 @@ func (cp *CallbackProcessor) Before(callbackName string) *CallbackProcessor { // Register a new callback, refer `Callbacks.Create` func (cp *CallbackProcessor) Register(callbackName string, callback func(scope *Scope)) { + if cp.kind == "row_query" { + if cp.before == "" && cp.after == "" && callbackName != "gorm:row_query" { + cp.logger.Print("info", fmt.Sprintf("Registering RowQuery callback %v without specify order with Before(), After(), applying Before('gorm:row_query') by default for compatibility...", callbackName)) + cp.before = "gorm:row_query" + } + } + + cp.logger.Print("info", fmt.Sprintf("[info] registering callback `%v` from %v", callbackName, fileWithLineNum())) cp.name = callbackName cp.processor = &callback cp.parent.processors = append(cp.parent.processors, cp) @@ -102,7 +111,7 @@ func (cp *CallbackProcessor) Register(callbackName string, callback func(scope * // Remove a registered callback // db.Callback().Create().Remove("gorm:update_time_stamp_when_create") func (cp *CallbackProcessor) Remove(callbackName string) { - fmt.Printf("[info] removing callback `%v` from %v\n", callbackName, fileWithLineNum()) + cp.logger.Print("info", fmt.Sprintf("[info] removing callback `%v` from %v", callbackName, fileWithLineNum())) cp.name = callbackName cp.remove = true cp.parent.processors = append(cp.parent.processors, cp) @@ -111,11 +120,11 @@ func (cp *CallbackProcessor) Remove(callbackName string) { // Replace a registered callback with new callback // db.Callback().Create().Replace("gorm:update_time_stamp_when_create", func(*Scope) { -// scope.SetColumn("Created", now) -// scope.SetColumn("Updated", now) +// scope.SetColumn("CreatedAt", now) +// scope.SetColumn("UpdatedAt", now) // }) func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *Scope)) { - fmt.Printf("[info] replacing callback `%v` from %v\n", callbackName, fileWithLineNum()) + cp.logger.Print("info", fmt.Sprintf("[info] replacing callback `%v` from %v", callbackName, fileWithLineNum())) cp.name = callbackName cp.processor = &callback cp.replace = true @@ -127,11 +136,15 @@ func (cp *CallbackProcessor) Replace(callbackName string, callback func(scope *S // db.Callback().Create().Get("gorm:create") func (cp *CallbackProcessor) Get(callbackName string) (callback func(scope *Scope)) { for _, p := range cp.parent.processors { - if p.name == callbackName && p.kind == cp.kind && !cp.remove { - return *p.processor + if p.name == callbackName && p.kind == cp.kind { + if p.remove { + callback = nil + } else { + callback = *p.processor + } } } - return nil + return } // getRIndex get right index from string slice @@ -154,7 +167,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { for _, cp := range cps { // show warning message the callback name already exists if index := getRIndex(allNames, cp.name); index > -1 && !cp.replace && !cp.remove { - fmt.Printf("[warning] duplicated callback `%v` from %v\n", cp.name, fileWithLineNum()) + cp.logger.Print("warning", fmt.Sprintf("[warning] duplicated callback `%v` from %v", cp.name, fileWithLineNum())) } allNames = append(allNames, cp.name) } @@ -208,7 +221,7 @@ func sortProcessors(cps []*CallbackProcessor) []*func(scope *Scope) { return sortedFuncs } -// reorder all registered processors, and reset CURD callbacks +// reorder all registered processors, and reset CRUD callbacks func (c *Callback) reorder() { var creates, updates, deletes, queries, rowQueries []*CallbackProcessor diff --git a/vendor/github.com/jinzhu/gorm/callback_create.go b/vendor/github.com/jinzhu/gorm/callback_create.go index e3cd2f0b..c4d25f37 100644 --- a/vendor/github.com/jinzhu/gorm/callback_create.go +++ b/vendor/github.com/jinzhu/gorm/callback_create.go @@ -31,9 +31,19 @@ func beforeCreateCallback(scope *Scope) { // updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating func updateTimeStampForCreateCallback(scope *Scope) { if !scope.HasError() { - now := NowFunc() - scope.SetColumn("CreatedAt", now) - scope.SetColumn("UpdatedAt", now) + now := scope.db.nowFunc() + + if createdAtField, ok := scope.FieldByName("CreatedAt"); ok { + if createdAtField.IsBlank { + createdAtField.Set(now) + } + } + + if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok { + if updatedAtField.IsBlank { + updatedAtField.Set(now) + } + } } } @@ -49,15 +59,13 @@ func createCallback(scope *Scope) { for _, field := range scope.Fields() { if scope.changeableField(field) { - if field.IsNormal { - if !field.IsPrimaryKey || !field.IsBlank { - if field.IsBlank && field.HasDefaultValue { - blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, field.DBName) - scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) - } else { - columns = append(columns, scope.Quote(field.DBName)) - placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) - } + if field.IsNormal && !field.IsIgnored { + if field.IsBlank && field.HasDefaultValue { + blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName)) + scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue) + } else if !field.IsPrimaryKey || !field.IsBlank { + columns = append(columns, scope.Quote(field.DBName)) + placeholders = append(placeholders, scope.AddToVars(field.Field.Interface())) } } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" { for _, foreignKey := range field.Relationship.ForeignDBNames { @@ -75,38 +83,53 @@ func createCallback(scope *Scope) { quotedTableName = scope.QuotedTableName() primaryField = scope.PrimaryField() extraOption string + insertModifier string ) if str, ok := scope.Get("gorm:insert_option"); ok { extraOption = fmt.Sprint(str) } + if str, ok := scope.Get("gorm:insert_modifier"); ok { + insertModifier = strings.ToUpper(fmt.Sprint(str)) + if insertModifier == "INTO" { + insertModifier = "" + } + } if primaryField != nil { returningColumn = scope.Quote(primaryField.DBName) } - lastInsertIDReturningSuffix := scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) + lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns) + var lastInsertIDReturningSuffix string + if lastInsertIDOutputInterstitial == "" { + lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn) + } if len(columns) == 0 { scope.Raw(fmt.Sprintf( - "INSERT INTO %v DEFAULT VALUES%v%v", + "INSERT%v INTO %v %v%v%v", + addExtraSpaceIfExist(insertModifier), quotedTableName, + scope.Dialect().DefaultValueStr(), addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(lastInsertIDReturningSuffix), )) } else { scope.Raw(fmt.Sprintf( - "INSERT INTO %v (%v) VALUES (%v)%v%v", + "INSERT%v INTO %v (%v)%v VALUES (%v)%v%v", + addExtraSpaceIfExist(insertModifier), scope.QuotedTableName(), strings.Join(columns, ","), + addExtraSpaceIfExist(lastInsertIDOutputInterstitial), strings.Join(placeholders, ","), addExtraSpaceIfExist(extraOption), addExtraSpaceIfExist(lastInsertIDReturningSuffix), )) } - // execute create sql - if lastInsertIDReturningSuffix == "" || primaryField == nil { + // execute create sql: no primaryField + if primaryField == nil { if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { // set rows affected count scope.db.RowsAffected, _ = result.RowsAffected() @@ -118,18 +141,48 @@ func createCallback(scope *Scope) { } } } - } else { + return + } + + // execute create sql: lastInsertID implemention for majority of dialects + if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" { + if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil { + // set rows affected count + scope.db.RowsAffected, _ = result.RowsAffected() + + // set primary value to primary field + if primaryField != nil && primaryField.IsBlank { + if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil { + scope.Err(primaryField.Set(primaryValue)) + } + } + } + return + } + + // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql) + if primaryField.Field.CanAddr() { if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil { + primaryField.IsBlank = false scope.db.RowsAffected = 1 } + } else { + scope.Err(ErrUnaddressable) } + return } } // forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object func forceReloadAfterCreateCallback(scope *Scope) { if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok { - scope.DB().New().Select(blankColumnsWithDefaultValue.([]string)).First(scope.Value) + db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string)) + for _, field := range scope.Fields() { + if field.IsPrimaryKey && !field.IsBlank { + db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface()) + } + } + db.Scan(scope.Value) } } diff --git a/vendor/github.com/jinzhu/gorm/callback_delete.go b/vendor/github.com/jinzhu/gorm/callback_delete.go index c8ffcc82..48b97acb 100644 --- a/vendor/github.com/jinzhu/gorm/callback_delete.go +++ b/vendor/github.com/jinzhu/gorm/callback_delete.go @@ -1,6 +1,9 @@ package gorm -import "fmt" +import ( + "errors" + "fmt" +) // Define callbacks for deleting func init() { @@ -13,6 +16,10 @@ func init() { // beforeDeleteCallback will invoke `BeforeDelete` method before deleting func beforeDeleteCallback(scope *Scope) { + if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { + scope.Err(errors.New("missing WHERE clause while deleting")) + return + } if !scope.HasError() { scope.CallMethod("BeforeDelete") } @@ -26,11 +33,14 @@ func deleteCallback(scope *Scope) { extraOption = fmt.Sprint(str) } - if !scope.Search.Unscoped && scope.HasColumn("DeletedAt") { + deletedAtField, hasDeletedAtField := scope.FieldByName("DeletedAt") + + if !scope.Search.Unscoped && hasDeletedAtField { scope.Raw(fmt.Sprintf( - "UPDATE %v SET deleted_at=%v%v%v", + "UPDATE %v SET %v=%v%v%v", scope.QuotedTableName(), - scope.AddToVars(NowFunc()), + scope.Quote(deletedAtField.DBName), + scope.AddToVars(scope.db.nowFunc()), addExtraSpaceIfExist(scope.CombinedConditionSql()), addExtraSpaceIfExist(extraOption), )).Exec() diff --git a/vendor/github.com/jinzhu/gorm/callback_query.go b/vendor/github.com/jinzhu/gorm/callback_query.go index 93782b1d..544afd63 100644 --- a/vendor/github.com/jinzhu/gorm/callback_query.go +++ b/vendor/github.com/jinzhu/gorm/callback_query.go @@ -15,6 +15,15 @@ func init() { // queryCallback used to query data from database func queryCallback(scope *Scope) { + if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { + return + } + + //we are only preloading relations, dont touch base model + if _, skip := scope.InstanceGet("gorm:only_preload"); skip { + return + } + defer scope.trace(NowFunc()) var ( @@ -30,7 +39,7 @@ func queryCallback(scope *Scope) { } if value, ok := scope.Get("gorm:query_destination"); ok { - results = reflect.Indirect(reflect.ValueOf(value)) + results = indirect(reflect.ValueOf(value)) } if kind := results.Kind(); kind == reflect.Slice { @@ -51,6 +60,11 @@ func queryCallback(scope *Scope) { if !scope.HasError() { scope.db.RowsAffected = 0 + + if str, ok := scope.Get("gorm:query_hint"); ok { + scope.SQL = fmt.Sprint(str) + scope.SQL + } + if str, ok := scope.Get("gorm:query_option"); ok { scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) } @@ -78,7 +92,9 @@ func queryCallback(scope *Scope) { } } - if scope.db.RowsAffected == 0 && !isSlice { + if err := rows.Err(); err != nil { + scope.Err(err) + } else if scope.db.RowsAffected == 0 && !isSlice { scope.Err(ErrRecordNotFound) } } diff --git a/vendor/github.com/jinzhu/gorm/callback_query_preload.go b/vendor/github.com/jinzhu/gorm/callback_query_preload.go index 5746f533..a936180a 100644 --- a/vendor/github.com/jinzhu/gorm/callback_query_preload.go +++ b/vendor/github.com/jinzhu/gorm/callback_query_preload.go @@ -4,11 +4,26 @@ import ( "errors" "fmt" "reflect" + "strconv" "strings" ) // preloadCallback used to preload associations func preloadCallback(scope *Scope) { + if _, skip := scope.InstanceGet("gorm:skip_query_callback"); skip { + return + } + + if ap, ok := scope.Get("gorm:auto_preload"); ok { + // If gorm:auto_preload IS NOT a bool then auto preload. + // Else if it IS a bool, use the value + if apb, ok := ap.(bool); !ok { + autoPreload(scope) + } else if apb { + autoPreload(scope) + } + } + if scope.Search.preload == nil || scope.HasError() { return } @@ -28,6 +43,10 @@ func preloadCallback(scope *Scope) { for idx, preloadField := range preloadFields { var currentPreloadConditions []interface{} + if currentScope == nil { + continue + } + // if not preloaded if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { @@ -67,9 +86,30 @@ func preloadCallback(scope *Scope) { // preload next level if idx < len(preloadFields)-1 { currentScope = currentScope.getColumnAsScope(preloadField) - currentFields = currentScope.Fields() + if currentScope != nil { + currentFields = currentScope.Fields() + } + } + } + } +} + +func autoPreload(scope *Scope) { + for _, field := range scope.Fields() { + if field.Relationship == nil { + continue + } + + if val, ok := field.TagSettingsGet("PRELOAD"); ok { + if preload, err := strconv.ParseBool(val); err != nil { + scope.Err(errors.New("invalid preload option")) + return + } else if !preload { + continue } } + + scope.Search.Preload(field.Name) } } @@ -104,8 +144,15 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) // find relations + query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) + values := toQueryValues(primaryKeys) + if relation.PolymorphicType != "" { + query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) + values = append(values, relation.PolymorphicValue) + } + results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) + scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) // assign find results var ( @@ -113,17 +160,23 @@ func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) indirectScopeValue = scope.IndirectValue() ) - for i := 0; i < resultsValue.Len(); i++ { - result := resultsValue.Index(i) - if indirectScopeValue.Kind() == reflect.Slice { - foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { - indirectValue.FieldByName(field.Name).Set(result) - break - } + if indirectScopeValue.Kind() == reflect.Slice { + foreignValuesToResults := make(map[string]reflect.Value) + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) + foreignValues := toString(getValueFromFields(result, relation.ForeignFieldNames)) + foreignValuesToResults[foreignValues] = result + } + for j := 0; j < indirectScopeValue.Len(); j++ { + indirectValue := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames)) + if result, found := foreignValuesToResults[valueString]; found { + indirectValue.FieldByName(field.Name).Set(result) } - } else { + } + } else { + for i := 0; i < resultsValue.Len(); i++ { + result := resultsValue.Index(i) scope.Err(field.Set(result)) } } @@ -143,8 +196,15 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) // find relations + query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) + values := toQueryValues(primaryKeys) + if relation.PolymorphicType != "" { + query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) + values = append(values, relation.PolymorphicValue) + } + results := makeSlice(field.Struct.Type) - scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) + scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) // assign find results var ( @@ -153,16 +213,21 @@ func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) ) if indirectScopeValue.Kind() == reflect.Slice { + preloadMap := make(map[string][]reflect.Value) for i := 0; i < resultsValue.Len(); i++ { result := resultsValue.Index(i) foreignValues := getValueFromFields(result, relation.ForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - if equalAsString(getValueFromFields(object, relation.AssociationForeignFieldNames), foreignValues) { - objectField := object.FieldByName(field.Name) - objectField.Set(reflect.Append(objectField, result)) - break - } + preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) + } + + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) + f := object.FieldByName(field.Name) + if results, ok := preloadMap[toString(objectRealValue)]; ok { + f.Set(reflect.Append(f, results...)) + } else { + f.Set(reflect.MakeSlice(f.Type(), 0, 0)) } } } else { @@ -193,13 +258,21 @@ func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{ indirectScopeValue = scope.IndirectValue() ) + foreignFieldToObjects := make(map[string][]*reflect.Value) + if indirectScopeValue.Kind() == reflect.Slice { + for j := 0; j < indirectScopeValue.Len(); j++ { + object := indirect(indirectScopeValue.Index(j)) + valueString := toString(getValueFromFields(object, relation.ForeignFieldNames)) + foreignFieldToObjects[valueString] = append(foreignFieldToObjects[valueString], &object) + } + } + for i := 0; i < resultsValue.Len(); i++ { result := resultsValue.Index(i) if indirectScopeValue.Kind() == reflect.Slice { - value := getValueFromFields(result, relation.AssociationForeignFieldNames) - for j := 0; j < indirectScopeValue.Len(); j++ { - object := indirect(indirectScopeValue.Index(j)) - if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { + valueString := toString(getValueFromFields(result, relation.AssociationForeignFieldNames)) + if objects, found := foreignFieldToObjects[valueString]; found { + for _, object := range objects { object.FieldByName(field.Name).Set(result) } } @@ -236,7 +309,12 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface // generate query with join table newScope := scope.New(reflect.New(fieldType).Interface()) - preloadDB = preloadDB.Table(newScope.TableName()).Select("*") + preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value) + + if len(preloadDB.search.selects) == 0 { + preloadDB = preloadDB.Select("*") + } + preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) // preload inline conditions @@ -266,6 +344,10 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface scope.scan(rows, columns, append(fields, joinTableFields...)) + scope.New(elem.Addr().Interface()). + InstanceSet("gorm:skip_query_callback", true). + callCallbacks(scope.db.parent.callbacks.queries) + var foreignKeys = make([]interface{}, len(sourceKeys)) // generate hashed forkey keys in join table for idx, joinTableField := range joinTableFields { @@ -282,10 +364,14 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface } } + if err := rows.Err(); err != nil { + scope.Err(err) + } + // assign find results var ( indirectScopeValue = scope.IndirectValue() - fieldsSourceMap = map[string]reflect.Value{} + fieldsSourceMap = map[string][]reflect.Value{} foreignFieldNames = []string{} ) @@ -298,13 +384,27 @@ func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface if indirectScopeValue.Kind() == reflect.Slice { for j := 0; j < indirectScopeValue.Len(); j++ { object := indirect(indirectScopeValue.Index(j)) - fieldsSourceMap[toString(getValueFromFields(object, foreignFieldNames))] = object.FieldByName(field.Name) + key := toString(getValueFromFields(object, foreignFieldNames)) + fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) } } else if indirectScopeValue.IsValid() { - fieldsSourceMap[toString(getValueFromFields(indirectScopeValue, foreignFieldNames))] = indirectScopeValue.FieldByName(field.Name) + key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) + fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) } - for source, link := range linkHash { - fieldsSourceMap[source].Set(reflect.Append(fieldsSourceMap[source], link...)) + for source, fields := range fieldsSourceMap { + for _, f := range fields { + //If not 0 this means Value is a pointer and we already added preloaded models to it + if f.Len() != 0 { + continue + } + + v := reflect.MakeSlice(f.Type(), 0, 0) + if len(linkHash[source]) > 0 { + v = reflect.Append(f, linkHash[source]...) + } + + f.Set(v) + } } } diff --git a/vendor/github.com/jinzhu/gorm/callback_row_query.go b/vendor/github.com/jinzhu/gorm/callback_row_query.go new file mode 100644 index 00000000..323b1605 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/callback_row_query.go @@ -0,0 +1,41 @@ +package gorm + +import ( + "database/sql" + "fmt" +) + +// Define callbacks for row query +func init() { + DefaultCallback.RowQuery().Register("gorm:row_query", rowQueryCallback) +} + +type RowQueryResult struct { + Row *sql.Row +} + +type RowsQueryResult struct { + Rows *sql.Rows + Error error +} + +// queryCallback used to query data from database +func rowQueryCallback(scope *Scope) { + if result, ok := scope.InstanceGet("row_query_result"); ok { + scope.prepareQuerySQL() + + if str, ok := scope.Get("gorm:query_hint"); ok { + scope.SQL = fmt.Sprint(str) + scope.SQL + } + + if str, ok := scope.Get("gorm:query_option"); ok { + scope.SQL += addExtraSpaceIfExist(fmt.Sprint(str)) + } + + if rowResult, ok := result.(*RowQueryResult); ok { + rowResult.Row = scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) + } else if rowsResult, ok := result.(*RowsQueryResult); ok { + rowsResult.Rows, rowsResult.Error = scope.SQLDB().Query(scope.SQL, scope.SQLVars...) + } + } +} diff --git a/vendor/github.com/jinzhu/gorm/callback_save.go b/vendor/github.com/jinzhu/gorm/callback_save.go index 5ffe53b9..3b4e0589 100644 --- a/vendor/github.com/jinzhu/gorm/callback_save.go +++ b/vendor/github.com/jinzhu/gorm/callback_save.go @@ -1,6 +1,9 @@ package gorm -import "reflect" +import ( + "reflect" + "strings" +) func beginTransactionCallback(scope *Scope) { scope.Begin() @@ -10,15 +13,74 @@ func commitOrRollbackTransactionCallback(scope *Scope) { scope.CommitOrRollback() } -func saveBeforeAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return +func saveAssociationCheck(scope *Scope, field *Field) (autoUpdate bool, autoCreate bool, saveReference bool, r *Relationship) { + checkTruth := func(value interface{}) bool { + if v, ok := value.(bool); ok && !v { + return false + } + + if v, ok := value.(string); ok { + v = strings.ToLower(v) + return v == "true" + } + + return true + } + + if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { + if r = field.Relationship; r != nil { + autoUpdate, autoCreate, saveReference = true, true, true + + if value, ok := scope.Get("gorm:save_associations"); ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + saveReference = autoUpdate + } else if value, ok := field.TagSettingsGet("SAVE_ASSOCIATIONS"); ok { + autoUpdate = checkTruth(value) + autoCreate = autoUpdate + saveReference = autoUpdate + } + + if value, ok := scope.Get("gorm:association_autoupdate"); ok { + autoUpdate = checkTruth(value) + } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOUPDATE"); ok { + autoUpdate = checkTruth(value) + } + + if value, ok := scope.Get("gorm:association_autocreate"); ok { + autoCreate = checkTruth(value) + } else if value, ok := field.TagSettingsGet("ASSOCIATION_AUTOCREATE"); ok { + autoCreate = checkTruth(value) + } + + if value, ok := scope.Get("gorm:association_save_reference"); ok { + saveReference = checkTruth(value) + } else if value, ok := field.TagSettingsGet("ASSOCIATION_SAVE_REFERENCE"); ok { + saveReference = checkTruth(value) + } + } } + + return +} + +func saveBeforeAssociationsCallback(scope *Scope) { for _, field := range scope.Fields() { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { - fieldValue := field.Field.Addr().Interface() + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) + + if relationship != nil && relationship.Kind == "belongs_to" { + fieldValue := field.Field.Addr().Interface() + newScope := scope.New(fieldValue) + + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(fieldValue).Error) + } + } else if autoUpdate { scope.Err(scope.NewDB().Save(fieldValue).Error) + } + + if saveReference { if len(relationship.ForeignFieldNames) != 0 { // set value's foreign key for idx, fieldName := range relationship.ForeignFieldNames { @@ -34,22 +96,20 @@ func saveBeforeAssociationsCallback(scope *Scope) { } func saveAfterAssociationsCallback(scope *Scope) { - if !scope.shouldSaveAssociations() { - return - } for _, field := range scope.Fields() { - if scope.changeableField(field) && !field.IsBlank && !field.IsIgnored { - if relationship := field.Relationship; relationship != nil && - (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { - value := field.Field - - switch value.Kind() { - case reflect.Slice: - for i := 0; i < value.Len(); i++ { - newDB := scope.NewDB() - elem := value.Index(i).Addr().Interface() - newScope := newDB.NewScope(elem) + autoUpdate, autoCreate, saveReference, relationship := saveAssociationCheck(scope, field) + + if relationship != nil && (relationship.Kind == "has_one" || relationship.Kind == "has_many" || relationship.Kind == "many_to_many") { + value := field.Field + switch value.Kind() { + case reflect.Slice: + for i := 0; i < value.Len(); i++ { + newDB := scope.NewDB() + elem := value.Index(i).Addr().Interface() + newScope := newDB.NewScope(elem) + + if saveReference { if relationship.JoinTableHandler == nil && len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] @@ -60,18 +120,29 @@ func saveAfterAssociationsCallback(scope *Scope) { } if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) + scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) } + } + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(newDB.Save(elem).Error) + } + } else if autoUpdate { scope.Err(newDB.Save(elem).Error) + } + if !scope.New(newScope.Value).PrimaryKeyZero() && saveReference { if joinTableHandler := relationship.JoinTableHandler; joinTableHandler != nil { scope.Err(joinTableHandler.Add(joinTableHandler, newDB, scope.Value, newScope.Value)) } } - default: - elem := value.Addr().Interface() - newScope := scope.New(elem) + } + default: + elem := value.Addr().Interface() + newScope := scope.New(elem) + + if saveReference { if len(relationship.ForeignFieldNames) != 0 { for idx, fieldName := range relationship.ForeignFieldNames { associationForeignName := relationship.AssociationForeignDBNames[idx] @@ -82,8 +153,15 @@ func saveAfterAssociationsCallback(scope *Scope) { } if relationship.PolymorphicType != "" { - scope.Err(newScope.SetColumn(relationship.PolymorphicType, scope.TableName())) + scope.Err(newScope.SetColumn(relationship.PolymorphicType, relationship.PolymorphicValue)) + } + } + + if newScope.PrimaryKeyZero() { + if autoCreate { + scope.Err(scope.NewDB().Save(elem).Error) } + } else if autoUpdate { scope.Err(scope.NewDB().Save(elem).Error) } } diff --git a/vendor/github.com/jinzhu/gorm/callback_update.go b/vendor/github.com/jinzhu/gorm/callback_update.go index aa27b5fb..699e534b 100644 --- a/vendor/github.com/jinzhu/gorm/callback_update.go +++ b/vendor/github.com/jinzhu/gorm/callback_update.go @@ -1,7 +1,9 @@ package gorm import ( + "errors" "fmt" + "sort" "strings" ) @@ -31,6 +33,10 @@ func assignUpdatingAttributesCallback(scope *Scope) { // beforeUpdateCallback will invoke `BeforeSave`, `BeforeUpdate` method before updating func beforeUpdateCallback(scope *Scope) { + if scope.DB().HasBlockGlobalUpdate() && !scope.hasConditions() { + scope.Err(errors.New("missing WHERE clause while updating")) + return + } if _, ok := scope.Get("gorm:update_column"); !ok { if !scope.HasError() { scope.CallMethod("BeforeSave") @@ -44,7 +50,7 @@ func beforeUpdateCallback(scope *Scope) { // updateTimeStampForUpdateCallback will set `UpdatedAt` when updating func updateTimeStampForUpdateCallback(scope *Scope) { if _, ok := scope.Get("gorm:update_column"); !ok { - scope.SetColumn("UpdatedAt", NowFunc()) + scope.SetColumn("UpdatedAt", scope.db.nowFunc()) } } @@ -54,14 +60,25 @@ func updateCallback(scope *Scope) { var sqls []string if updateAttrs, ok := scope.InstanceGet("gorm:update_attrs"); ok { - for column, value := range updateAttrs.(map[string]interface{}) { + // Sort the column names so that the generated SQL is the same every time. + updateMap := updateAttrs.(map[string]interface{}) + var columns []string + for c := range updateMap { + columns = append(columns, c) + } + sort.Strings(columns) + + for _, column := range columns { + value := updateMap[column] sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(column), scope.AddToVars(value))) } } else { for _, field := range scope.Fields() { if scope.changeableField(field) { - if !field.IsPrimaryKey && field.IsNormal { - sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + if !field.IsPrimaryKey && field.IsNormal && (field.Name != "CreatedAt" || !field.IsBlank) { + if !field.IsForeignKey || !field.IsBlank || !field.HasDefaultValue { + sqls = append(sqls, fmt.Sprintf("%v = %v", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + } } else if relationship := field.Relationship; relationship != nil && relationship.Kind == "belongs_to" { for _, foreignKey := range relationship.ForeignDBNames { if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) { diff --git a/vendor/github.com/jinzhu/gorm/dialect.go b/vendor/github.com/jinzhu/gorm/dialect.go index 6c9405da..749587f4 100644 --- a/vendor/github.com/jinzhu/gorm/dialect.go +++ b/vendor/github.com/jinzhu/gorm/dialect.go @@ -14,7 +14,7 @@ type Dialect interface { GetName() string // SetDB set db for dialect - SetDB(db *sql.DB) + SetDB(db SQLCommon) // BindVar return the placeholder for actual values in SQL statements, in many dbs it is "?", Postgres using $1 BindVar(i int) string @@ -33,18 +33,33 @@ type Dialect interface { HasTable(tableName string) bool // HasColumn check has column or not HasColumn(tableName string, columnName string) bool + // ModifyColumn modify column's type + ModifyColumn(tableName string, columnName string, typ string) error // LimitAndOffsetSQL return generated SQL with Limit and Offset, as mssql has special case - LimitAndOffsetSQL(limit, offset int) string + LimitAndOffsetSQL(limit, offset interface{}) (string, error) // SelectFromDummyTable return select values, for most dbs, `SELECT values` just works, mysql needs `SELECT value FROM DUAL` SelectFromDummyTable() string + // LastInsertIDOutputInterstitial most dbs support LastInsertId, but mssql needs to use `OUTPUT` + LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string // LastInsertIdReturningSuffix most dbs support LastInsertId, but postgres needs to use `RETURNING` LastInsertIDReturningSuffix(tableName, columnName string) string + // DefaultValueStr + DefaultValueStr() string + + // BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference + BuildKeyName(kind, tableName string, fields ...string) string + + // NormalizeIndexAndColumn returns valid index name and column name depending on each dialect + NormalizeIndexAndColumn(indexName, columnName string) (string, string) + + // CurrentDatabase return current database name + CurrentDatabase() string } var dialectsMap = map[string]Dialect{} -func newDialect(name string, db *sql.DB) Dialect { +func newDialect(name string, db SQLCommon) Dialect { if value, ok := dialectsMap[name]; ok { dialect := reflect.New(reflect.TypeOf(value).Elem()).Interface().(Dialect) dialect.SetDB(db) @@ -62,10 +77,20 @@ func RegisterDialect(name string, dialect Dialect) { dialectsMap[name] = dialect } -// ParseFieldStructForDialect parse field struct for dialect -func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { +// GetDialect gets the dialect for the specified dialect name +func GetDialect(name string) (dialect Dialect, ok bool) { + dialect, ok = dialectsMap[name] + return +} + +// ParseFieldStructForDialect get field's sql data type +var ParseFieldStructForDialect = func(field *StructField, dialect Dialect) (fieldValue reflect.Value, sqlType string, size int, additionalType string) { // Get redirected field type - var reflectType = field.Struct.Type + var ( + reflectType = field.Struct.Type + dataType, _ = field.TagSettingsGet("TYPE") + ) + for reflectType.Kind() == reflect.Ptr { reflectType = reflectType.Elem() } @@ -73,28 +98,50 @@ func ParseFieldStructForDialect(field *StructField) (fieldValue reflect.Value, s // Get redirected field value fieldValue = reflect.Indirect(reflect.New(reflectType)) + if gormDataType, ok := fieldValue.Interface().(interface { + GormDataType(Dialect) string + }); ok { + dataType = gormDataType.GormDataType(dialect) + } + // Get scanner's real value - var getScannerValue func(reflect.Value) - getScannerValue = func(value reflect.Value) { - fieldValue = value - if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { - getScannerValue(fieldValue.Field(0)) + if dataType == "" { + var getScannerValue func(reflect.Value) + getScannerValue = func(value reflect.Value) { + fieldValue = value + if _, isScanner := reflect.New(fieldValue.Type()).Interface().(sql.Scanner); isScanner && fieldValue.Kind() == reflect.Struct { + getScannerValue(fieldValue.Field(0)) + } } + getScannerValue(fieldValue) } - getScannerValue(fieldValue) // Default Size - if num, ok := field.TagSettings["SIZE"]; ok { + if num, ok := field.TagSettingsGet("SIZE"); ok { size, _ = strconv.Atoi(num) } else { size = 255 } // Default type from tag setting - additionalType = field.TagSettings["NOT NULL"] + " " + field.TagSettings["UNIQUE"] - if value, ok := field.TagSettings["DEFAULT"]; ok { + notNull, _ := field.TagSettingsGet("NOT NULL") + unique, _ := field.TagSettingsGet("UNIQUE") + additionalType = notNull + " " + unique + if value, ok := field.TagSettingsGet("DEFAULT"); ok { additionalType = additionalType + " DEFAULT " + value } - return fieldValue, field.TagSettings["TYPE"], size, strings.TrimSpace(additionalType) + if value, ok := field.TagSettingsGet("COMMENT"); ok { + additionalType = additionalType + " COMMENT " + value + } + + return fieldValue, dataType, size, strings.TrimSpace(additionalType) +} + +func currentDatabaseAndTable(dialect Dialect, tableName string) (string, string) { + if strings.Contains(tableName, ".") { + splitStrings := strings.SplitN(tableName, ".", 2) + return splitStrings[0], splitStrings[1] + } + return dialect.CurrentDatabase(), tableName } diff --git a/vendor/github.com/jinzhu/gorm/dialect_common.go b/vendor/github.com/jinzhu/gorm/dialect_common.go index f009271b..d549510c 100644 --- a/vendor/github.com/jinzhu/gorm/dialect_common.go +++ b/vendor/github.com/jinzhu/gorm/dialect_common.go @@ -1,15 +1,23 @@ package gorm import ( - "database/sql" "fmt" "reflect" + "regexp" + "strconv" "strings" "time" ) +var keyNameRegex = regexp.MustCompile("[^a-zA-Z0-9]+") + +// DefaultForeignKeyNamer contains the default foreign key name generator method +type DefaultForeignKeyNamer struct { +} + type commonDialect struct { - db *sql.DB + db SQLCommon + DefaultForeignKeyNamer } func init() { @@ -20,33 +28,40 @@ func (commonDialect) GetName() string { return "common" } -func (s *commonDialect) SetDB(db *sql.DB) { +func (s *commonDialect) SetDB(db SQLCommon) { s.db = db } func (commonDialect) BindVar(i int) string { - return "$$" // ? + return "$$$" // ? } func (commonDialect) Quote(key string) string { return fmt.Sprintf(`"%s"`, key) } -func (commonDialect) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) +func (s *commonDialect) fieldCanAutoIncrement(field *StructField) bool { + if value, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { + return strings.ToLower(value) != "false" + } + return field.IsPrimaryKey +} + +func (s *commonDialect) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) if sqlType == "" { switch dataValue.Kind() { case reflect.Bool: sqlType = "BOOLEAN" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if s.fieldCanAutoIncrement(field) { sqlType = "INTEGER AUTO_INCREMENT" } else { sqlType = "INTEGER" } case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok { + if s.fieldCanAutoIncrement(field) { sqlType = "BIGINT AUTO_INCREMENT" } else { sqlType = "BIGINT" @@ -86,7 +101,8 @@ func (commonDialect) DataTypeOf(field *StructField) string { func (s commonDialect) HasIndex(tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", s.currentDatabase(), tableName, indexName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.STATISTICS WHERE table_schema = ? AND table_name = ? AND index_name = ?", currentDatabase, tableName, indexName).Scan(&count) return count > 0 } @@ -101,28 +117,42 @@ func (s commonDialect) HasForeignKey(tableName string, foreignKeyName string) bo func (s commonDialect) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", s.currentDatabase(), tableName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLES WHERE table_schema = ? AND table_name = ?", currentDatabase, tableName).Scan(&count) return count > 0 } func (s commonDialect) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", s.currentDatabase(), tableName, columnName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, tableName, columnName).Scan(&count) return count > 0 } -func (s commonDialect) currentDatabase() (name string) { +func (s commonDialect) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v ALTER COLUMN %v TYPE %v", tableName, columnName, typ)) + return err +} + +func (s commonDialect) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return } -func (commonDialect) LimitAndOffsetSQL(limit, offset int) (sql string) { - if limit > 0 || offset > 0 { - if limit >= 0 { - sql += fmt.Sprintf(" LIMIT %d", limit) +// LimitAndOffsetSQL return generated SQL with Limit and Offset +func (s commonDialect) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { + if limit != nil { + if parsedLimit, err := s.parseInt(limit); err != nil { + return "", err + } else if parsedLimit >= 0 { + sql += fmt.Sprintf(" LIMIT %d", parsedLimit) } - if offset >= 0 { - sql += fmt.Sprintf(" OFFSET %d", offset) + } + if offset != nil { + if parsedOffset, err := s.parseInt(offset); err != nil { + return "", err + } else if parsedOffset >= 0 { + sql += fmt.Sprintf(" OFFSET %d", parsedOffset) } } return @@ -132,6 +162,35 @@ func (commonDialect) SelectFromDummyTable() string { return "" } +func (commonDialect) LastInsertIDOutputInterstitial(tableName, columnName string, columns []string) string { + return "" +} + func (commonDialect) LastInsertIDReturningSuffix(tableName, columnName string) string { return "" } + +func (commonDialect) DefaultValueStr() string { + return "DEFAULT VALUES" +} + +// BuildKeyName returns a valid key name (foreign key, index key) for the given table, field and reference +func (DefaultForeignKeyNamer) BuildKeyName(kind, tableName string, fields ...string) string { + keyName := fmt.Sprintf("%s_%s_%s", kind, tableName, strings.Join(fields, "_")) + keyName = keyNameRegex.ReplaceAllString(keyName, "_") + return keyName +} + +// NormalizeIndexAndColumn returns argument's index name and column name without doing anything +func (commonDialect) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { + return indexName, columnName +} + +func (commonDialect) parseInt(value interface{}) (int64, error) { + return strconv.ParseInt(fmt.Sprint(value), 0, 0) +} + +// IsByteArrayOrSlice returns true of the reflected value is an array or slice +func IsByteArrayOrSlice(value reflect.Value) bool { + return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) +} diff --git a/vendor/github.com/jinzhu/gorm/dialect_mysql.go b/vendor/github.com/jinzhu/gorm/dialect_mysql.go index 6fade59d..b4467ffa 100644 --- a/vendor/github.com/jinzhu/gorm/dialect_mysql.go +++ b/vendor/github.com/jinzhu/gorm/dialect_mysql.go @@ -1,12 +1,18 @@ package gorm import ( + "crypto/sha1" + "database/sql" "fmt" "reflect" + "regexp" "strings" "time" + "unicode/utf8" ) +var mysqlIndexRegex = regexp.MustCompile(`^(.+)\((\d+)\)$`) + type mysql struct { commonDialect } @@ -24,33 +30,59 @@ func (mysql) Quote(key string) string { } // Get Data Type for MySQL Dialect -func (mysql) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) +func (s *mysql) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) + + // MySQL allows only one auto increment column per table, and it must + // be a KEY column. + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok { + if _, ok = field.TagSettingsGet("INDEX"); !ok && !field.IsPrimaryKey { + field.TagSettingsDelete("AUTO_INCREMENT") + } + } if sqlType == "" { switch dataValue.Kind() { case reflect.Bool: sqlType = "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + case reflect.Int8: + if s.fieldCanAutoIncrement(field) { + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") + sqlType = "tinyint AUTO_INCREMENT" + } else { + sqlType = "tinyint" + } + case reflect.Int, reflect.Int16, reflect.Int32: + if s.fieldCanAutoIncrement(field) { + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int AUTO_INCREMENT" } else { sqlType = "int" } - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + case reflect.Uint8: + if s.fieldCanAutoIncrement(field) { + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") + sqlType = "tinyint unsigned AUTO_INCREMENT" + } else { + sqlType = "tinyint unsigned" + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uintptr: + if s.fieldCanAutoIncrement(field) { + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "int unsigned AUTO_INCREMENT" } else { sqlType = "int unsigned" } case reflect.Int64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint AUTO_INCREMENT" } else { sqlType = "bigint" } case reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigint unsigned AUTO_INCREMENT" } else { sqlType = "bigint unsigned" @@ -65,14 +97,19 @@ func (mysql) DataTypeOf(field *StructField) string { } case reflect.Struct: if _, ok := dataValue.Interface().(time.Time); ok { - if _, ok := field.TagSettings["NOT NULL"]; ok { - sqlType = "timestamp" + precision := "" + if p, ok := field.TagSettingsGet("PRECISION"); ok { + precision = fmt.Sprintf("(%s)", p) + } + + if _, ok := field.TagSettings["NOT NULL"]; ok || field.IsPrimaryKey { + sqlType = fmt.Sprintf("DATETIME%v", precision) } else { - sqlType = "timestamp NULL" + sqlType = fmt.Sprintf("DATETIME%v NULL", precision) } } default: - if _, ok := dataValue.Interface().([]byte); ok { + if IsByteArrayOrSlice(dataValue) { if size > 0 && size < 65532 { sqlType = fmt.Sprintf("varbinary(%d)", size) } else { @@ -83,7 +120,7 @@ func (mysql) DataTypeOf(field *StructField) string { } if sqlType == "" { - panic(fmt.Sprintf("invalid sql type %s (%s) for mysql", dataValue.Type().Name(), dataValue.Kind().String())) + panic(fmt.Sprintf("invalid sql type %s (%s) in field %s for mysql", dataValue.Type().Name(), dataValue.Kind().String(), field.Name)) } if strings.TrimSpace(additionalType) == "" { @@ -97,13 +134,76 @@ func (s mysql) RemoveIndex(tableName string, indexName string) error { return err } +func (s mysql) ModifyColumn(tableName string, columnName string, typ string) error { + _, err := s.db.Exec(fmt.Sprintf("ALTER TABLE %v MODIFY COLUMN %v %v", tableName, columnName, typ)) + return err +} + +func (s mysql) LimitAndOffsetSQL(limit, offset interface{}) (sql string, err error) { + if limit != nil { + parsedLimit, err := s.parseInt(limit) + if err != nil { + return "", err + } + if parsedLimit >= 0 { + sql += fmt.Sprintf(" LIMIT %d", parsedLimit) + + if offset != nil { + parsedOffset, err := s.parseInt(offset) + if err != nil { + return "", err + } + if parsedOffset >= 0 { + sql += fmt.Sprintf(" OFFSET %d", parsedOffset) + } + } + } + } + return +} + func (s mysql) HasForeignKey(tableName string, foreignKeyName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", s.currentDatabase(), tableName, foreignKeyName).Scan(&count) + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS WHERE CONSTRAINT_SCHEMA=? AND TABLE_NAME=? AND CONSTRAINT_NAME=? AND CONSTRAINT_TYPE='FOREIGN KEY'", currentDatabase, tableName, foreignKeyName).Scan(&count) return count > 0 } -func (s mysql) currentDatabase() (name string) { +func (s mysql) HasTable(tableName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + var name string + // allow mysql database name with '-' character + if err := s.db.QueryRow(fmt.Sprintf("SHOW TABLES FROM `%s` WHERE `Tables_in_%s` = ?", currentDatabase, currentDatabase), tableName).Scan(&name); err != nil { + if err == sql.ErrNoRows { + return false + } + panic(err) + } else { + return true + } +} + +func (s mysql) HasIndex(tableName string, indexName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + if rows, err := s.db.Query(fmt.Sprintf("SHOW INDEXES FROM `%s` FROM `%s` WHERE Key_name = ?", tableName, currentDatabase), indexName); err != nil { + panic(err) + } else { + defer rows.Close() + return rows.Next() + } +} + +func (s mysql) HasColumn(tableName string, columnName string) bool { + currentDatabase, tableName := currentDatabaseAndTable(&s, tableName) + if rows, err := s.db.Query(fmt.Sprintf("SHOW COLUMNS FROM `%s` FROM `%s` WHERE Field = ?", tableName, currentDatabase), columnName); err != nil { + panic(err) + } else { + defer rows.Close() + return rows.Next() + } +} + +func (s mysql) CurrentDatabase() (name string) { s.db.QueryRow("SELECT DATABASE()").Scan(&name) return } @@ -111,3 +211,36 @@ func (s mysql) currentDatabase() (name string) { func (mysql) SelectFromDummyTable() string { return "FROM DUAL" } + +func (s mysql) BuildKeyName(kind, tableName string, fields ...string) string { + keyName := s.commonDialect.BuildKeyName(kind, tableName, fields...) + if utf8.RuneCountInString(keyName) <= 64 { + return keyName + } + h := sha1.New() + h.Write([]byte(keyName)) + bs := h.Sum(nil) + + // sha1 is 40 characters, keep first 24 characters of destination + destRunes := []rune(keyNameRegex.ReplaceAllString(fields[0], "_")) + if len(destRunes) > 24 { + destRunes = destRunes[:24] + } + + return fmt.Sprintf("%s%x", string(destRunes), bs) +} + +// NormalizeIndexAndColumn returns index name and column name for specify an index prefix length if needed +func (mysql) NormalizeIndexAndColumn(indexName, columnName string) (string, string) { + submatch := mysqlIndexRegex.FindStringSubmatch(indexName) + if len(submatch) != 3 { + return indexName, columnName + } + indexName = submatch[1] + columnName = fmt.Sprintf("%s(%s)", columnName, submatch[2]) + return indexName, columnName +} + +func (mysql) DefaultValueStr() string { + return "VALUES()" +} diff --git a/vendor/github.com/jinzhu/gorm/dialect_postgres.go b/vendor/github.com/jinzhu/gorm/dialect_postgres.go index 09ac5961..d2df3131 100644 --- a/vendor/github.com/jinzhu/gorm/dialect_postgres.go +++ b/vendor/github.com/jinzhu/gorm/dialect_postgres.go @@ -1,6 +1,7 @@ package gorm import ( + "encoding/json" "fmt" "reflect" "strings" @@ -13,6 +14,7 @@ type postgres struct { func init() { RegisterDialect("postgres", &postgres{}) + RegisterDialect("cloudsqlpostgres", &postgres{}) } func (postgres) GetName() string { @@ -23,21 +25,23 @@ func (postgres) BindVar(i int) string { return fmt.Sprintf("$%v", i) } -func (postgres) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) +func (s *postgres) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) if sqlType == "" { switch dataValue.Kind() { case reflect.Bool: sqlType = "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uintptr: + if s.fieldCanAutoIncrement(field) { + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "serial" } else { sqlType = "integer" } - case reflect.Int64, reflect.Uint64: - if _, ok := field.TagSettings["AUTO_INCREMENT"]; ok || field.IsPrimaryKey { + case reflect.Int64, reflect.Uint32, reflect.Uint64: + if s.fieldCanAutoIncrement(field) { + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "bigserial" } else { sqlType = "bigint" @@ -45,7 +49,7 @@ func (postgres) DataTypeOf(field *StructField) string { case reflect.Float32, reflect.Float64: sqlType = "numeric" case reflect.String: - if _, ok := field.TagSettings["SIZE"]; !ok { + if _, ok := field.TagSettingsGet("SIZE"); !ok { size = 0 // if SIZE haven't been set, use `text` as the default type, as there are no performance different } @@ -63,10 +67,16 @@ func (postgres) DataTypeOf(field *StructField) string { sqlType = "hstore" } default: - if isByteArrayOrSlice(dataValue) { + if IsByteArrayOrSlice(dataValue) { sqlType = "bytea" - } else if isUUID(dataValue) { - sqlType = "uuid" + + if isUUID(dataValue) { + sqlType = "uuid" + } + + if isJSON(dataValue) { + sqlType = "jsonb" + } } } } @@ -83,33 +93,37 @@ func (postgres) DataTypeOf(field *StructField) string { func (s postgres) HasIndex(tableName string, indexName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2", tableName, indexName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM pg_indexes WHERE tablename = $1 AND indexname = $2 AND schemaname = CURRENT_SCHEMA()", tableName, indexName).Scan(&count) return count > 0 } func (s postgres) HasForeignKey(tableName string, foreignKeyName string) bool { var count int - s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", s.currentDatabase(), foreignKeyName).Scan(&count) + s.db.QueryRow("SELECT count(con.conname) FROM pg_constraint con WHERE $1::regclass::oid = con.conrelid AND con.conname = $2 AND con.contype='f'", tableName, foreignKeyName).Scan(&count) return count > 0 } func (s postgres) HasTable(tableName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE'", tableName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = $1 AND table_type = 'BASE TABLE' AND table_schema = CURRENT_SCHEMA()", tableName).Scan(&count) return count > 0 } func (s postgres) HasColumn(tableName string, columnName string) bool { var count int - s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2", tableName, columnName).Scan(&count) + s.db.QueryRow("SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_name = $1 AND column_name = $2 AND table_schema = CURRENT_SCHEMA()", tableName, columnName).Scan(&count) return count > 0 } -func (s postgres) currentDatabase() (name string) { +func (s postgres) CurrentDatabase() (name string) { s.db.QueryRow("SELECT CURRENT_DATABASE()").Scan(&name) return } +func (s postgres) LastInsertIDOutputInterstitial(tableName, key string, columns []string) string { + return "" +} + func (s postgres) LastInsertIDReturningSuffix(tableName, key string) string { return fmt.Sprintf("RETURNING %v.%v", tableName, key) } @@ -118,10 +132,6 @@ func (postgres) SupportLastInsertID() bool { return false } -func isByteArrayOrSlice(value reflect.Value) bool { - return (value.Kind() == reflect.Array || value.Kind() == reflect.Slice) && value.Type().Elem() == reflect.TypeOf(uint8(0)) -} - func isUUID(value reflect.Value) bool { if value.Kind() != reflect.Array || value.Type().Len() != 16 { return false @@ -130,3 +140,8 @@ func isUUID(value reflect.Value) bool { lower := strings.ToLower(typename) return "uuid" == lower || "guid" == lower } + +func isJSON(value reflect.Value) bool { + _, ok := value.Interface().(json.RawMessage) + return ok +} diff --git a/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go b/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go index 5c262aaf..5f96c363 100644 --- a/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go +++ b/vendor/github.com/jinzhu/gorm/dialect_sqlite3.go @@ -12,7 +12,6 @@ type sqlite3 struct { } func init() { - RegisterDialect("sqlite", &sqlite3{}) RegisterDialect("sqlite3", &sqlite3{}) } @@ -21,21 +20,23 @@ func (sqlite3) GetName() string { } // Get Data Type for Sqlite Dialect -func (sqlite3) DataTypeOf(field *StructField) string { - var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field) +func (s *sqlite3) DataTypeOf(field *StructField) string { + var dataValue, sqlType, size, additionalType = ParseFieldStructForDialect(field, s) if sqlType == "" { switch dataValue.Kind() { case reflect.Bool: sqlType = "bool" case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uintptr: - if field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "integer primary key autoincrement" } else { sqlType = "integer" } case reflect.Int64, reflect.Uint64: - if field.IsPrimaryKey { + if s.fieldCanAutoIncrement(field) { + field.TagSettingsSet("AUTO_INCREMENT", "AUTO_INCREMENT") sqlType = "integer primary key autoincrement" } else { sqlType = "bigint" @@ -53,7 +54,7 @@ func (sqlite3) DataTypeOf(field *StructField) string { sqlType = "datetime" } default: - if _, ok := dataValue.Interface().([]byte); ok { + if IsByteArrayOrSlice(dataValue) { sqlType = "blob" } } @@ -87,7 +88,7 @@ func (s sqlite3) HasColumn(tableName string, columnName string) bool { return count > 0 } -func (s sqlite3) currentDatabase() (name string) { +func (s sqlite3) CurrentDatabase() (name string) { var ( ifaces = make([]interface{}, 3) pointers = make([]*string, 3) diff --git a/vendor/github.com/jinzhu/gorm/docker-compose.yml b/vendor/github.com/jinzhu/gorm/docker-compose.yml new file mode 100644 index 00000000..79bf5fc3 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/docker-compose.yml @@ -0,0 +1,30 @@ +version: '3' + +services: + mysql: + image: 'mysql:latest' + ports: + - 9910:3306 + environment: + - MYSQL_DATABASE=gorm + - MYSQL_USER=gorm + - MYSQL_PASSWORD=gorm + - MYSQL_RANDOM_ROOT_PASSWORD="yes" + postgres: + image: 'postgres:latest' + ports: + - 9920:5432 + environment: + - POSTGRES_USER=gorm + - POSTGRES_DB=gorm + - POSTGRES_PASSWORD=gorm + mssql: + image: 'mcmoe/mssqldocker:latest' + ports: + - 9930:1433 + environment: + - ACCEPT_EULA=Y + - SA_PASSWORD=LoremIpsum86 + - MSSQL_DB=gorm + - MSSQL_USER=gorm + - MSSQL_PASSWORD=LoremIpsum86 diff --git a/vendor/github.com/jinzhu/gorm/errors.go b/vendor/github.com/jinzhu/gorm/errors.go index ce3a25c0..d5ef8d57 100644 --- a/vendor/github.com/jinzhu/gorm/errors.go +++ b/vendor/github.com/jinzhu/gorm/errors.go @@ -6,11 +6,11 @@ import ( ) var ( - // ErrRecordNotFound record not found error, happens when haven't find any matched data when looking up with a struct + // ErrRecordNotFound returns a "record not found error". Occurs only when attempting to query the database with a struct; querying with a slice won't return this error ErrRecordNotFound = errors.New("record not found") - // ErrInvalidSQL invalid SQL error, happens when you passed invalid SQL + // ErrInvalidSQL occurs when you attempt a query with invalid SQL ErrInvalidSQL = errors.New("invalid SQL") - // ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback` + // ErrInvalidTransaction occurs when you are trying to `Commit` or `Rollback` ErrInvalidTransaction = errors.New("no valid transaction") // ErrCantStartTransaction can't start transaction when you are trying to start one with `Begin` ErrCantStartTransaction = errors.New("can't start transaction") @@ -18,40 +18,54 @@ var ( ErrUnaddressable = errors.New("using unaddressable value") ) -type errorsInterface interface { - GetErrors() []error -} - // Errors contains all happened errors -type Errors struct { - errors []error +type Errors []error + +// IsRecordNotFoundError returns true if error contains a RecordNotFound error +func IsRecordNotFoundError(err error) bool { + if errs, ok := err.(Errors); ok { + for _, err := range errs { + if err == ErrRecordNotFound { + return true + } + } + } + return err == ErrRecordNotFound } -// GetErrors get all happened errors +// GetErrors gets all errors that have occurred and returns a slice of errors (Error type) func (errs Errors) GetErrors() []error { - return errs.errors + return errs } -// Add add an error -func (errs *Errors) Add(err error) { - if errors, ok := err.(errorsInterface); ok { - for _, err := range errors.GetErrors() { - errs.Add(err) +// Add adds an error to a given slice of errors +func (errs Errors) Add(newErrors ...error) Errors { + for _, err := range newErrors { + if err == nil { + continue } - } else { - for _, e := range errs.errors { - if err == e { - return + + if errors, ok := err.(Errors); ok { + errs = errs.Add(errors...) + } else { + ok = true + for _, e := range errs { + if err == e { + ok = false + } + } + if ok { + errs = append(errs, err) } } - errs.errors = append(errs.errors, err) } + return errs } -// Error format happened errors +// Error takes a slice of all errors that have occurred and returns it as a formatted string func (errs Errors) Error() string { var errors = []string{} - for _, e := range errs.errors { + for _, e := range errs { errors = append(errors, e.Error()) } return strings.Join(errors, "; ") diff --git a/vendor/github.com/jinzhu/gorm/field.go b/vendor/github.com/jinzhu/gorm/field.go index 11c410b0..acd06e20 100644 --- a/vendor/github.com/jinzhu/gorm/field.go +++ b/vendor/github.com/jinzhu/gorm/field.go @@ -2,6 +2,7 @@ package gorm import ( "database/sql" + "database/sql/driver" "errors" "fmt" "reflect" @@ -44,7 +45,14 @@ func (field *Field) Set(value interface{}) (err error) { if reflectValue.Type().ConvertibleTo(fieldValue.Type()) { fieldValue.Set(reflectValue.Convert(fieldValue.Type())) } else if scanner, ok := fieldValue.Addr().Interface().(sql.Scanner); ok { - err = scanner.Scan(reflectValue.Interface()) + v := reflectValue.Interface() + if valuer, ok := v.(driver.Valuer); ok { + if v, err = valuer.Value(); err == nil { + err = scanner.Scan(v) + } + } else { + err = scanner.Scan(v) + } } else { err = fmt.Errorf("could not convert argument of field %s from %s to %s", field.Name, reflectValue.Type(), fieldValue.Type()) } diff --git a/vendor/github.com/jinzhu/gorm/interface.go b/vendor/github.com/jinzhu/gorm/interface.go index 7b02aa66..fe649231 100644 --- a/vendor/github.com/jinzhu/gorm/interface.go +++ b/vendor/github.com/jinzhu/gorm/interface.go @@ -1,8 +1,12 @@ package gorm -import "database/sql" +import ( + "context" + "database/sql" +) -type sqlCommon interface { +// SQLCommon is the minimal database connection functionality gorm requires. Implemented by *sql.DB. +type SQLCommon interface { Exec(query string, args ...interface{}) (sql.Result, error) Prepare(query string) (*sql.Stmt, error) Query(query string, args ...interface{}) (*sql.Rows, error) @@ -11,6 +15,7 @@ type sqlCommon interface { type sqlDb interface { Begin() (*sql.Tx, error) + BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } type sqlTx interface { diff --git a/vendor/github.com/jinzhu/gorm/join_table_handler.go b/vendor/github.com/jinzhu/gorm/join_table_handler.go index 18c12a85..a036d46d 100644 --- a/vendor/github.com/jinzhu/gorm/join_table_handler.go +++ b/vendor/github.com/jinzhu/gorm/join_table_handler.go @@ -59,6 +59,7 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s s.TableName = tableName s.Source = JoinTableSource{ModelType: source} + s.Source.ForeignKeys = []JoinTableForeignKey{} for idx, dbName := range relationship.ForeignFieldNames { s.Source.ForeignKeys = append(s.Source.ForeignKeys, JoinTableForeignKey{ DBName: relationship.ForeignDBNames[idx], @@ -67,6 +68,7 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s } s.Destination = JoinTableSource{ModelType: destination} + s.Destination.ForeignKeys = []JoinTableForeignKey{} for idx, dbName := range relationship.AssociationForeignFieldNames { s.Destination.ForeignKeys = append(s.Destination.ForeignKeys, JoinTableForeignKey{ DBName: relationship.AssociationForeignDBNames[idx], @@ -77,41 +79,43 @@ func (s *JoinTableHandler) Setup(relationship *Relationship, tableName string, s // Table return join table's table name func (s JoinTableHandler) Table(db *DB) string { - return s.TableName + return DefaultTableNameHandler(db, s.TableName) } -func (s JoinTableHandler) getSearchMap(db *DB, sources ...interface{}) map[string]interface{} { - values := map[string]interface{}{} - +func (s JoinTableHandler) updateConditionMap(conditionMap map[string]interface{}, db *DB, joinTableSources []JoinTableSource, sources ...interface{}) { for _, source := range sources { scope := db.NewScope(source) modelType := scope.GetModelStruct().ModelType - if s.Source.ModelType == modelType { - for _, foreignKey := range s.Source.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() - } - } - } else if s.Destination.ModelType == modelType { - for _, foreignKey := range s.Destination.ForeignKeys { - if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { - values[foreignKey.DBName] = field.Field.Interface() + for _, joinTableSource := range joinTableSources { + if joinTableSource.ModelType == modelType { + for _, foreignKey := range joinTableSource.ForeignKeys { + if field, ok := scope.FieldByName(foreignKey.AssociationDBName); ok { + conditionMap[foreignKey.DBName] = field.Field.Interface() + } } + break } } } - return values } // Add create relationship in join table for source and destination func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source interface{}, destination interface{}) error { - scope := db.NewScope("") - searchMap := s.getSearchMap(db, source, destination) + var ( + scope = db.NewScope("") + conditionMap = map[string]interface{}{} + ) + + // Update condition map for source + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source}, source) + + // Update condition map for destination + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Destination}, destination) var assignColumns, binVars, conditions []string var values []interface{} - for key, value := range searchMap { + for key, value := range conditionMap { assignColumns = append(assignColumns, scope.Quote(key)) binVars = append(binVars, `?`) conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) @@ -139,12 +143,15 @@ func (s JoinTableHandler) Add(handler JoinTableHandlerInterface, db *DB, source // Delete delete relationship in join table for sources func (s JoinTableHandler) Delete(handler JoinTableHandlerInterface, db *DB, sources ...interface{}) error { var ( - scope = db.NewScope(nil) - conditions []string - values []interface{} + scope = db.NewScope(nil) + conditions []string + values []interface{} + conditionMap = map[string]interface{}{} ) - for key, value := range s.getSearchMap(db, sources...) { + s.updateConditionMap(conditionMap, db, []JoinTableSource{s.Source, s.Destination}, sources...) + + for key, value := range conditionMap { conditions = append(conditions, fmt.Sprintf("%v = ?", scope.Quote(key))) values = append(values, value) } diff --git a/vendor/github.com/jinzhu/gorm/logger.go b/vendor/github.com/jinzhu/gorm/logger.go index 2c4ccbbc..88e167dd 100644 --- a/vendor/github.com/jinzhu/gorm/logger.go +++ b/vendor/github.com/jinzhu/gorm/logger.go @@ -7,50 +7,62 @@ import ( "os" "reflect" "regexp" + "strconv" "time" "unicode" ) var ( - defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} - sqlRegexp = regexp.MustCompile(`(\$\d+)|\?`) + defaultLogger = Logger{log.New(os.Stdout, "\r\n", 0)} + sqlRegexp = regexp.MustCompile(`\?`) + numericPlaceHolderRegexp = regexp.MustCompile(`\$\d+`) ) -type logger interface { - Print(v ...interface{}) +func isPrintable(s string) bool { + for _, r := range s { + if !unicode.IsPrint(r) { + return false + } + } + return true } -// LogWriter log writer interface -type LogWriter interface { - Println(v ...interface{}) -} +var LogFormatter = func(values ...interface{}) (messages []interface{}) { + if len(values) > 1 { + var ( + sql string + formattedValues []string + level = values[0] + currentTime = "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" + source = fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) + ) -// Logger default logger -type Logger struct { - LogWriter -} + messages = []interface{}{source, currentTime} -// Print format & print log -func (logger Logger) Print(values ...interface{}) { - if len(values) > 1 { - level := values[0] - currentTime := "\n\033[33m[" + NowFunc().Format("2006-01-02 15:04:05") + "]\033[0m" - source := fmt.Sprintf("\033[35m(%v)\033[0m", values[1]) - messages := []interface{}{source, currentTime} + if len(values) == 2 { + //remove the line break + currentTime = currentTime[1:] + //remove the brackets + source = fmt.Sprintf("\033[35m%v\033[0m", values[1]) + + messages = []interface{}{currentTime, source} + } if level == "sql" { // duration messages = append(messages, fmt.Sprintf(" \033[36;1m[%.2fms]\033[0m ", float64(values[2].(time.Duration).Nanoseconds()/1e4)/100.0)) // sql - var sql string - var formattedValues []string for _, value := range values[4].([]interface{}) { indirectValue := reflect.Indirect(reflect.ValueOf(value)) if indirectValue.IsValid() { value = indirectValue.Interface() if t, ok := value.(time.Time); ok { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format(time.RFC3339))) + if t.IsZero() { + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", "0000-00-00 00:00:00")) + } else { + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", t.Format("2006-01-02 15:04:05"))) + } } else if b, ok := value.([]byte); ok { if str := string(b); isPrintable(str) { formattedValues = append(formattedValues, fmt.Sprintf("'%v'", str)) @@ -64,36 +76,66 @@ func (logger Logger) Print(values ...interface{}) { formattedValues = append(formattedValues, "NULL") } } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) + switch value.(type) { + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: + formattedValues = append(formattedValues, fmt.Sprintf("%v", value)) + default: + formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) + } } } else { - formattedValues = append(formattedValues, fmt.Sprintf("'%v'", value)) + formattedValues = append(formattedValues, "NULL") } } - var formattedValuesLength = len(formattedValues) - for index, value := range sqlRegexp.Split(values[3].(string), -1) { - sql += value - if index < formattedValuesLength { - sql += formattedValues[index] + // differentiate between $n placeholders or else treat like ? + if numericPlaceHolderRegexp.MatchString(values[3].(string)) { + sql = values[3].(string) + for index, value := range formattedValues { + placeholder := fmt.Sprintf(`\$%d([^\d]|$)`, index+1) + sql = regexp.MustCompile(placeholder).ReplaceAllString(sql, value+"$1") + } + } else { + formattedValuesLength := len(formattedValues) + for index, value := range sqlRegexp.Split(values[3].(string), -1) { + sql += value + if index < formattedValuesLength { + sql += formattedValues[index] + } } } messages = append(messages, sql) + messages = append(messages, fmt.Sprintf(" \n\033[36;31m[%v]\033[0m ", strconv.FormatInt(values[5].(int64), 10)+" rows affected or returned ")) } else { messages = append(messages, "\033[31;1m") messages = append(messages, values[2:]...) messages = append(messages, "\033[0m") } - logger.Println(messages...) } + + return } -func isPrintable(s string) bool { - for _, r := range s { - if !unicode.IsPrint(r) { - return false - } - } - return true +type logger interface { + Print(v ...interface{}) } + +// LogWriter log writer interface +type LogWriter interface { + Println(v ...interface{}) +} + +// Logger default logger +type Logger struct { + LogWriter +} + +// Print format & print log +func (logger Logger) Print(values ...interface{}) { + logger.Println(LogFormatter(values...)...) +} + +type nopLogger struct{} + +func (nopLogger) Print(values ...interface{}) {} diff --git a/vendor/github.com/jinzhu/gorm/main.go b/vendor/github.com/jinzhu/gorm/main.go index cd445555..466e80c3 100644 --- a/vendor/github.com/jinzhu/gorm/main.go +++ b/vendor/github.com/jinzhu/gorm/main.go @@ -1,32 +1,49 @@ package gorm import ( + "context" "database/sql" "errors" "fmt" "reflect" "strings" + "sync" "time" ) // DB contains information for current db connection type DB struct { - Value interface{} - Error error - RowsAffected int64 - callbacks *Callback - db sqlCommon - parent *DB - search *search - logMode int + sync.RWMutex + Value interface{} + Error error + RowsAffected int64 + + // single db + db SQLCommon + blockGlobalUpdate bool + logMode logModeValue logger logger - dialect Dialect - singularTable bool - source string - values map[string]interface{} - joinTableHandlers map[string]JoinTableHandler + search *search + values sync.Map + + // global db + parent *DB + callbacks *Callback + dialect Dialect + singularTable bool + + // function to be used to override the creating of a new timestamp + nowFuncOverride func() time.Time } +type logModeValue int + +const ( + defaultLogMode logModeValue = iota + noLogMode + detailedLogMode +) + // Open initialize a new db connection, need to import driver first, e.g: // // import _ "github.com/go-sql-driver/mysql" @@ -38,57 +55,50 @@ type DB struct { // // import _ "github.com/jinzhu/gorm/dialects/postgres" // // import _ "github.com/jinzhu/gorm/dialects/sqlite" // // import _ "github.com/jinzhu/gorm/dialects/mssql" -func Open(dialect string, args ...interface{}) (*DB, error) { - var db DB - var err error - +func Open(dialect string, args ...interface{}) (db *DB, err error) { if len(args) == 0 { err = errors.New("invalid database source") - } else { - var source string - var dbSQL sqlCommon - - switch value := args[0].(type) { - case string: - var driver = dialect - if len(args) == 1 { - source = value - } else if len(args) >= 2 { - driver = value - source = args[1].(string) - } - dbSQL, err = sql.Open(driver, source) - case sqlCommon: - source = reflect.Indirect(reflect.ValueOf(value)).FieldByName("dsn").String() - dbSQL = value - } - - db = DB{ - dialect: newDialect(dialect, dbSQL.(*sql.DB)), - logger: defaultLogger, - callbacks: DefaultCallback, - source: source, - values: map[string]interface{}{}, - db: dbSQL, + return nil, err + } + var source string + var dbSQL SQLCommon + var ownDbSQL bool + + switch value := args[0].(type) { + case string: + var driver = dialect + if len(args) == 1 { + source = value + } else if len(args) >= 2 { + driver = value + source = args[1].(string) } - db.parent = &db + dbSQL, err = sql.Open(driver, source) + ownDbSQL = true + case SQLCommon: + dbSQL = value + ownDbSQL = false + default: + return nil, fmt.Errorf("invalid database source: %v is not a valid type", value) + } - if err == nil { - err = db.DB().Ping() // Send a ping to make sure the database connection is alive. + db = &DB{ + db: dbSQL, + logger: defaultLogger, + callbacks: DefaultCallback, + dialect: newDialect(dialect, dbSQL), + } + db.parent = db + if err != nil { + return + } + // Send a ping to make sure the database connection is alive. + if d, ok := dbSQL.(*sql.DB); ok { + if err = d.Ping(); err != nil && ownDbSQL { + d.Close() } } - - return &db, err -} - -// Close close current db connection -func (s *DB) Close() error { - return s.parent.db.(*sql.DB).Close() -} - -// DB get `*sql.DB` from current connection -func (s *DB) DB() *sql.DB { - return s.db.(*sql.DB) + return } // New clone a new db connection without search conditions @@ -99,23 +109,43 @@ func (s *DB) New() *DB { return clone } -// NewScope create a scope for current operation -func (s *DB) NewScope(value interface{}) *Scope { - dbClone := s.clone() - dbClone.Value = value - return &Scope{db: dbClone, Search: dbClone.search.clone(), Value: value} +type closer interface { + Close() error +} + +// Close close current db connection. If database connection is not an io.Closer, returns an error. +func (s *DB) Close() error { + if db, ok := s.parent.db.(closer); ok { + return db.Close() + } + return errors.New("can't close current db") +} + +// DB get `*sql.DB` from current connection +// If the underlying database connection is not a *sql.DB, returns nil +func (s *DB) DB() *sql.DB { + db, ok := s.db.(*sql.DB) + if !ok { + panic("can't support full GORM on currently status, maybe this is a TX instance.") + } + return db } // CommonDB return the underlying `*sql.DB` or `*sql.Tx` instance, mainly intended to allow coexistence with legacy non-GORM code. -func (s *DB) CommonDB() sqlCommon { +func (s *DB) CommonDB() SQLCommon { return s.db } +// Dialect get dialect +func (s *DB) Dialect() Dialect { + return s.dialect +} + // Callback return `Callbacks` container, you could add/change/delete callbacks with it // db.Callback().Create().Register("update_created_at", updateCreated) // Refer https://jinzhu.github.io/gorm/development.html#callbacks func (s *DB) Callback() *Callback { - s.parent.callbacks = s.parent.callbacks.clone() + s.parent.callbacks = s.parent.callbacks.clone(s.logger) return s.parent.callbacks } @@ -127,20 +157,80 @@ func (s *DB) SetLogger(log logger) { // LogMode set log mode, `true` for detailed logs, `false` for no log, default, will only print error logs func (s *DB) LogMode(enable bool) *DB { if enable { - s.logMode = 2 + s.logMode = detailedLogMode } else { - s.logMode = 1 + s.logMode = noLogMode } return s } +// SetNowFuncOverride set the function to be used when creating a new timestamp +func (s *DB) SetNowFuncOverride(nowFuncOverride func() time.Time) *DB { + s.nowFuncOverride = nowFuncOverride + return s +} + +// Get a new timestamp, using the provided nowFuncOverride on the DB instance if set, +// otherwise defaults to the global NowFunc() +func (s *DB) nowFunc() time.Time { + if s.nowFuncOverride != nil { + return s.nowFuncOverride() + } + + return NowFunc() +} + +// BlockGlobalUpdate if true, generates an error on update/delete without where clause. +// This is to prevent eventual error with empty objects updates/deletions +func (s *DB) BlockGlobalUpdate(enable bool) *DB { + s.blockGlobalUpdate = enable + return s +} + +// HasBlockGlobalUpdate return state of block +func (s *DB) HasBlockGlobalUpdate() bool { + return s.blockGlobalUpdate +} + // SingularTable use singular table by default func (s *DB) SingularTable(enable bool) { - modelStructsMap = newModelStructsMap() + s.parent.Lock() + defer s.parent.Unlock() s.parent.singularTable = enable } -// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/curd.html#query +// NewScope create a scope for current operation +func (s *DB) NewScope(value interface{}) *Scope { + dbClone := s.clone() + dbClone.Value = value + scope := &Scope{db: dbClone, Value: value} + if s.search != nil { + scope.Search = s.search.clone() + } else { + scope.Search = &search{} + } + return scope +} + +// QueryExpr returns the query as SqlExpr object +func (s *DB) QueryExpr() *SqlExpr { + scope := s.NewScope(s.Value) + scope.InstanceSet("skip_bindvar", true) + scope.prepareQuerySQL() + + return Expr(scope.SQL, scope.SQLVars...) +} + +// SubQuery returns the query as sub query +func (s *DB) SubQuery() *SqlExpr { + scope := s.NewScope(s.Value) + scope.InstanceSet("skip_bindvar", true) + scope.prepareQuerySQL() + + return Expr(fmt.Sprintf("(%v)", scope.SQL), scope.SQLVars...) +} + +// Where return a new relation, filter records with given conditions, accepts `map`, `struct` or `string` as conditions, refer http://jinzhu.github.io/gorm/crud.html#query func (s *DB) Where(query interface{}, args ...interface{}) *DB { return s.clone().search.Where(query, args...).db } @@ -156,17 +246,20 @@ func (s *DB) Not(query interface{}, args ...interface{}) *DB { } // Limit specify the number of records to be retrieved -func (s *DB) Limit(limit int) *DB { +func (s *DB) Limit(limit interface{}) *DB { return s.clone().search.Limit(limit).db } // Offset specify the number of records to skip before starting to return the records -func (s *DB) Offset(offset int) *DB { +func (s *DB) Offset(offset interface{}) *DB { return s.clone().search.Offset(offset).db } // Order specify order when retrieve records from database, set reorder to `true` to overwrite defined conditions -func (s *DB) Order(value string, reorder ...bool) *DB { +// db.Order("name DESC") +// db.Order("name DESC", true) // reorder +// db.Order(gorm.Expr("name = ? DESC", "first")) // sql expression +func (s *DB) Order(value interface{}, reorder ...bool) *DB { return s.clone().search.Order(value, reorder...).db } @@ -187,7 +280,7 @@ func (s *DB) Group(query string) *DB { } // Having specify HAVING conditions for GROUP BY -func (s *DB) Having(query string, values ...interface{}) *DB { +func (s *DB) Having(query interface{}, values ...interface{}) *DB { return s.clone().search.Having(query, values...).db } @@ -209,7 +302,7 @@ func (s *DB) Joins(query string, args ...interface{}) *DB { // } // // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) -// Refer https://jinzhu.github.io/gorm/curd.html#scopes +// Refer https://jinzhu.github.io/gorm/crud.html#scopes func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { for _, f := range funcs { s = f(s) @@ -217,32 +310,40 @@ func (s *DB) Scopes(funcs ...func(*DB) *DB) *DB { return s } -// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/curd.html#soft-delete +// Unscoped return all record including deleted record, refer Soft Delete https://jinzhu.github.io/gorm/crud.html#soft-delete func (s *DB) Unscoped() *DB { return s.clone().search.unscoped().db } -// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate +// Attrs initialize struct with argument if record not found with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate func (s *DB) Attrs(attrs ...interface{}) *DB { return s.clone().search.Attrs(attrs...).db } -// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/curd.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/curd.html#firstorcreate +// Assign assign result with argument regardless it is found or not with `FirstOrInit` https://jinzhu.github.io/gorm/crud.html#firstorinit or `FirstOrCreate` https://jinzhu.github.io/gorm/crud.html#firstorcreate func (s *DB) Assign(attrs ...interface{}) *DB { return s.clone().search.Assign(attrs...).db } // First find first record that match given conditions, order by primary key func (s *DB) First(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) + newScope := s.NewScope(out) newScope.Search.Limit(1) + return newScope.Set("gorm:order_by_primary_key", "ASC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db } +// Take return a record that match given conditions, the order will depend on the database implementation +func (s *DB) Take(out interface{}, where ...interface{}) *DB { + newScope := s.NewScope(out) + newScope.Search.Limit(1) + return newScope.inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db +} + // Last find last record that match given conditions, order by primary key func (s *DB) Last(out interface{}, where ...interface{}) *DB { - newScope := s.clone().NewScope(out) + newScope := s.NewScope(out) newScope.Search.Limit(1) return newScope.Set("gorm:order_by_primary_key", "DESC"). inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db @@ -250,12 +351,17 @@ func (s *DB) Last(out interface{}, where ...interface{}) *DB { // Find find records that match given conditions func (s *DB) Find(out interface{}, where ...interface{}) *DB { - return s.clone().NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db + return s.NewScope(out).inlineCondition(where...).callCallbacks(s.parent.callbacks.queries).db +} + +//Preloads preloads relations, don`t touch out +func (s *DB) Preloads(out interface{}) *DB { + return s.NewScope(out).InstanceSet("gorm:only_preload", 1).callCallbacks(s.parent.callbacks.queries).db } // Scan scan value to a struct func (s *DB) Scan(dest interface{}) *DB { - return s.clone().NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db + return s.NewScope(s.Value).Set("gorm:query_destination", dest).callCallbacks(s.parent.callbacks.queries).db } // Row return `*sql.Row` with given conditions @@ -271,8 +377,8 @@ func (s *DB) Rows() (*sql.Rows, error) { // ScanRows scan `*sql.Rows` to give struct func (s *DB) ScanRows(rows *sql.Rows, result interface{}) error { var ( - clone = s.clone() - scope = clone.NewScope(result) + scope = s.NewScope(result) + clone = scope.db columns, err = rows.Columns() ) @@ -297,11 +403,11 @@ func (s *DB) Count(value interface{}) *DB { // Related get related associations func (s *DB) Related(value interface{}, foreignKeys ...string) *DB { - return s.clone().NewScope(s.Value).related(value, foreignKeys...).db + return s.NewScope(s.Value).related(value, foreignKeys...).db } // FirstOrInit find first matched record or initialize a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/curd.html#firstorinit +// https://jinzhu.github.io/gorm/crud.html#firstorinit func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { c := s.clone() if result := c.First(out, where...); result.Error != nil { @@ -316,41 +422,42 @@ func (s *DB) FirstOrInit(out interface{}, where ...interface{}) *DB { } // FirstOrCreate find first matched record or create a new one with given conditions (only works with struct, map conditions) -// https://jinzhu.github.io/gorm/curd.html#firstorcreate +// https://jinzhu.github.io/gorm/crud.html#firstorcreate func (s *DB) FirstOrCreate(out interface{}, where ...interface{}) *DB { c := s.clone() - if result := c.First(out, where...); result.Error != nil { + if result := s.First(out, where...); result.Error != nil { if !result.RecordNotFound() { return result } - c.AddError(c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db.Error) + return c.NewScope(out).inlineCondition(where...).initialize().callCallbacks(c.parent.callbacks.creates).db } else if len(c.search.assignAttrs) > 0 { - c.AddError(c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db.Error) + return c.NewScope(out).InstanceSet("gorm:update_interface", c.search.assignAttrs).callCallbacks(c.parent.callbacks.updates).db } return c } -// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update +// Update update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update +// WARNING when update with struct, GORM will not update fields that with zero value func (s *DB) Update(attrs ...interface{}) *DB { return s.Updates(toSearchableMap(attrs...), true) } -// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update +// Updates update attributes with callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) Updates(values interface{}, ignoreProtectedAttrs ...bool) *DB { - return s.clone().NewScope(s.Value). + return s.NewScope(s.Value). Set("gorm:ignore_protected_attrs", len(ignoreProtectedAttrs) > 0). InstanceSet("gorm:update_interface", values). callCallbacks(s.parent.callbacks.updates).db } -// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update +// UpdateColumn update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) UpdateColumn(attrs ...interface{}) *DB { return s.UpdateColumns(toSearchableMap(attrs...)) } -// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/curd.html#update +// UpdateColumns update attributes without callbacks, refer: https://jinzhu.github.io/gorm/crud.html#update func (s *DB) UpdateColumns(values interface{}) *DB { - return s.clone().NewScope(s.Value). + return s.NewScope(s.Value). Set("gorm:update_column", true). Set("gorm:save_associations", false). InstanceSet("gorm:update_interface", values). @@ -359,22 +466,27 @@ func (s *DB) UpdateColumns(values interface{}) *DB { // Save update value in database, if the value doesn't have primary key, will insert it func (s *DB) Save(value interface{}) *DB { - scope := s.clone().NewScope(value) - if scope.PrimaryKeyZero() { - return scope.callCallbacks(s.parent.callbacks.creates).db + scope := s.NewScope(value) + if !scope.PrimaryKeyZero() { + newDB := scope.callCallbacks(s.parent.callbacks.updates).db + if newDB.Error == nil && newDB.RowsAffected == 0 { + return s.New().Table(scope.TableName()).FirstOrCreate(value) + } + return newDB } - return scope.callCallbacks(s.parent.callbacks.updates).db + return scope.callCallbacks(s.parent.callbacks.creates).db } // Create insert the value into database func (s *DB) Create(value interface{}) *DB { - scope := s.clone().NewScope(value) + scope := s.NewScope(value) return scope.callCallbacks(s.parent.callbacks.creates).db } // Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition +// WARNING If model has DeletedAt field, GORM will only set field DeletedAt's value to current time func (s *DB) Delete(value interface{}, where ...interface{}) *DB { - return s.clone().NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db + return s.NewScope(value).inlineCondition(where...).callCallbacks(s.parent.callbacks.deletes).db } // Raw use raw sql as conditions, won't run it unless invoked by other methods @@ -385,8 +497,8 @@ func (s *DB) Raw(sql string, values ...interface{}) *DB { // Exec execute raw sql func (s *DB) Exec(sql string, values ...interface{}) *DB { - scope := s.clone().NewScope(nil) - generatedSQL := scope.buildWhereCondition(map[string]interface{}{"query": sql, "args": values}) + scope := s.NewScope(nil) + generatedSQL := scope.buildCondition(map[string]interface{}{"query": sql, "args": values}, true) generatedSQL = strings.TrimSuffix(strings.TrimPrefix(generatedSQL, "("), ")") scope.Raw(generatedSQL) return scope.Exec().db @@ -416,12 +528,46 @@ func (s *DB) Debug() *DB { return s.clone().LogMode(true) } -// Begin begin a transaction +// Transaction start a transaction as a block, +// return error will rollback, otherwise to commit. +func (s *DB) Transaction(fc func(tx *DB) error) (err error) { + + if _, ok := s.db.(*sql.Tx); ok { + return fc(s) + } + + panicked := true + tx := s.Begin() + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + tx.Rollback() + } + }() + + err = fc(tx) + + if err == nil { + err = tx.Commit().Error + } + + panicked = false + return +} + +// Begin begins a transaction func (s *DB) Begin() *DB { + return s.BeginTx(context.Background(), &sql.TxOptions{}) +} + +// BeginTx begins a transaction with options +func (s *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) *DB { c := s.clone() - if db, ok := c.db.(sqlDb); ok { - tx, err := db.Begin() - c.db = interface{}(tx).(sqlCommon) + if db, ok := c.db.(sqlDb); ok && db != nil { + tx, err := db.BeginTx(ctx, opts) + c.db = interface{}(tx).(SQLCommon) + + c.dialect.SetDB(c.db) c.AddError(err) } else { c.AddError(ErrCantStartTransaction) @@ -431,7 +577,8 @@ func (s *DB) Begin() *DB { // Commit commit a transaction func (s *DB) Commit() *DB { - if db, ok := s.db.(sqlTx); ok { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { s.AddError(db.Commit()) } else { s.AddError(ErrInvalidTransaction) @@ -441,8 +588,28 @@ func (s *DB) Commit() *DB { // Rollback rollback a transaction func (s *DB) Rollback() *DB { - if db, ok := s.db.(sqlTx); ok { - s.AddError(db.Rollback()) + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { + if err := db.Rollback(); err != nil && err != sql.ErrTxDone { + s.AddError(err) + } + } else { + s.AddError(ErrInvalidTransaction) + } + return s +} + +// RollbackUnlessCommitted rollback a transaction if it has not yet been +// committed. +func (s *DB) RollbackUnlessCommitted() *DB { + var emptySQLTx *sql.Tx + if db, ok := s.db.(sqlTx); ok && db != nil && db != emptySQLTx { + err := db.Rollback() + // Ignore the error indicating that the transaction has already + // been committed. + if err != sql.ErrTxDone { + s.AddError(err) + } } else { s.AddError(ErrInvalidTransaction) } @@ -451,7 +618,7 @@ func (s *DB) Rollback() *DB { // NewRecord check if value's primary key is blank func (s *DB) NewRecord(value interface{}) bool { - return s.clone().NewScope(value).PrimaryKeyZero() + return s.NewScope(value).PrimaryKeyZero() } // RecordNotFound check if returning ErrRecordNotFound error @@ -500,7 +667,7 @@ func (s *DB) DropTableIfExists(values ...interface{}) *DB { // HasTable check has table or not func (s *DB) HasTable(value interface{}) bool { var ( - scope = s.clone().NewScope(value) + scope = s.NewScope(value) tableName string ) @@ -526,14 +693,14 @@ func (s *DB) AutoMigrate(values ...interface{}) *DB { // ModifyColumn modify column to type func (s *DB) ModifyColumn(column string, typ string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.modifyColumn(column, typ) return scope.db } // DropColumn drop a column func (s *DB) DropColumn(column string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.dropColumn(column) return scope.db } @@ -554,7 +721,7 @@ func (s *DB) AddUniqueIndex(indexName string, columns ...string) *DB { // RemoveIndex remove index with name func (s *DB) RemoveIndex(indexName string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.removeIndex(indexName) return scope.db } @@ -562,15 +729,23 @@ func (s *DB) RemoveIndex(indexName string) *DB { // AddForeignKey Add foreign key to the given scope, e.g: // db.Model(&User{}).AddForeignKey("city_id", "cities(id)", "RESTRICT", "RESTRICT") func (s *DB) AddForeignKey(field string, dest string, onDelete string, onUpdate string) *DB { - scope := s.clone().NewScope(s.Value) + scope := s.NewScope(s.Value) scope.addForeignKey(field, dest, onDelete, onUpdate) return scope.db } +// RemoveForeignKey Remove foreign key from the given scope, e.g: +// db.Model(&User{}).RemoveForeignKey("city_id", "cities(id)") +func (s *DB) RemoveForeignKey(field string, dest string) *DB { + scope := s.clone().NewScope(s.Value) + scope.removeForeignKey(field, dest) + return scope.db +} + // Association start `Association Mode` to handler relations things easir in that mode, refer: https://jinzhu.github.io/gorm/associations.html#association-mode func (s *DB) Association(column string) *Association { var err error - scope := s.clone().NewScope(s.Value) + var scope = s.Set("gorm:association:source", s.Value).NewScope(s.Value) if primaryField := scope.PrimaryField(); primaryField.IsBlank { err = errors.New("primary key can't be nil") @@ -602,13 +777,13 @@ func (s *DB) Set(name string, value interface{}) *DB { // InstantSet instant set setting, will affect current db func (s *DB) InstantSet(name string, value interface{}) *DB { - s.values[name] = value + s.values.Store(name, value) return s } // Get get setting by name func (s *DB) Get(name string) (value interface{}, ok bool) { - value, ok = s.values[name] + value, ok = s.values.Load(name) return } @@ -617,7 +792,7 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join scope := s.NewScope(source) for _, field := range scope.GetModelStruct().StructFields { if field.Name == column || field.DBName == column { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { source := (&Scope{Value: source}).GetModelStruct().ModelType destination := (&Scope{Value: reflect.New(field.Struct.Type).Interface()}).GetModelStruct().ModelType handler.Setup(field.Relationship, many2many, source, destination) @@ -634,15 +809,15 @@ func (s *DB) SetJoinTableHandler(source interface{}, column string, handler Join func (s *DB) AddError(err error) error { if err != nil { if err != ErrRecordNotFound { - if s.logMode == 0 { - go s.print(fileWithLineNum(), err) + if s.logMode == defaultLogMode { + go s.print("error", fileWithLineNum(), err) } else { s.log(err) } - errors := Errors{errors: s.GetErrors()} - errors.Add(err) - if len(errors.GetErrors()) > 1 { + errors := Errors(s.GetErrors()) + errors = errors.Add(err) + if len(errors) > 1 { err = errors } } @@ -653,48 +828,59 @@ func (s *DB) AddError(err error) error { } // GetErrors get happened errors from the db -func (s *DB) GetErrors() (errors []error) { - if errs, ok := s.Error.(errorsInterface); ok { - return errs.GetErrors() +func (s *DB) GetErrors() []error { + if errs, ok := s.Error.(Errors); ok { + return errs } else if s.Error != nil { return []error{s.Error} } - return + return []error{} } //////////////////////////////////////////////////////////////////////////////// -// Private Methods For *gorm.DB +// Private Methods For DB //////////////////////////////////////////////////////////////////////////////// func (s *DB) clone() *DB { - db := DB{db: s.db, parent: s.parent, logger: s.logger, logMode: s.logMode, values: map[string]interface{}{}, Value: s.Value, Error: s.Error} - - for key, value := range s.values { - db.values[key] = value + db := &DB{ + db: s.db, + parent: s.parent, + logger: s.logger, + logMode: s.logMode, + Value: s.Value, + Error: s.Error, + blockGlobalUpdate: s.blockGlobalUpdate, + dialect: newDialect(s.dialect.GetName(), s.db), + nowFuncOverride: s.nowFuncOverride, } + s.values.Range(func(k, v interface{}) bool { + db.values.Store(k, v) + return true + }) + if s.search == nil { db.search = &search{limit: -1, offset: -1} } else { db.search = s.search.clone() } - db.search.db = &db - return &db + db.search.db = db + return db } func (s *DB) print(v ...interface{}) { - s.logger.(logger).Print(v...) + s.logger.Print(v...) } func (s *DB) log(v ...interface{}) { - if s != nil && s.logMode == 2 { + if s != nil && s.logMode == detailedLogMode { s.print(append([]interface{}{"log", fileWithLineNum()}, v...)...) } } func (s *DB) slog(sql string, t time.Time, vars ...interface{}) { - if s.logMode == 2 { - s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars) + if s.logMode == detailedLogMode { + s.print("sql", fileWithLineNum(), NowFunc().Sub(t), sql, vars, s.RowsAffected) } } diff --git a/vendor/github.com/jinzhu/gorm/model_struct.go b/vendor/github.com/jinzhu/gorm/model_struct.go index 6df615d1..57dbec38 100644 --- a/vendor/github.com/jinzhu/gorm/model_struct.go +++ b/vendor/github.com/jinzhu/gorm/model_struct.go @@ -17,39 +17,42 @@ var DefaultTableNameHandler = func(db *DB, defaultTableName string) string { return defaultTableName } -type safeModelStructsMap struct { - m map[reflect.Type]*ModelStruct - l *sync.RWMutex -} - -func (s *safeModelStructsMap) Set(key reflect.Type, value *ModelStruct) { - s.l.Lock() - defer s.l.Unlock() - s.m[key] = value -} +// lock for mutating global cached model metadata +var structsLock sync.Mutex -func (s *safeModelStructsMap) Get(key reflect.Type) *ModelStruct { - s.l.RLock() - defer s.l.RUnlock() - return s.m[key] -} - -func newModelStructsMap() *safeModelStructsMap { - return &safeModelStructsMap{l: new(sync.RWMutex), m: make(map[reflect.Type]*ModelStruct)} -} - -var modelStructsMap = newModelStructsMap() +// global cache of model metadata +var modelStructsMap sync.Map // ModelStruct model definition type ModelStruct struct { - PrimaryFields []*StructField - StructFields []*StructField - ModelType reflect.Type + PrimaryFields []*StructField + StructFields []*StructField + ModelType reflect.Type + defaultTableName string + l sync.Mutex } -// TableName get model's table name +// TableName returns model's table name func (s *ModelStruct) TableName(db *DB) string { + s.l.Lock() + defer s.l.Unlock() + + if s.defaultTableName == "" && db != nil && s.ModelType != nil { + // Set default table name + if tabler, ok := reflect.New(s.ModelType).Interface().(tabler); ok { + s.defaultTableName = tabler.TableName() + } else { + tableName := ToTableName(s.ModelType.Name()) + db.parent.RLock() + if db == nil || (db.parent != nil && !db.parent.singularTable) { + tableName = inflection.Plural(tableName) + } + db.parent.RUnlock() + s.defaultTableName = tableName + } + } + return DefaultTableNameHandler(db, s.defaultTableName) } @@ -68,24 +71,61 @@ type StructField struct { Struct reflect.StructField IsForeignKey bool Relationship *Relationship + + tagSettingsLock sync.RWMutex } -func (structField *StructField) clone() *StructField { - return &StructField{ - DBName: structField.DBName, - Name: structField.Name, - Names: structField.Names, - IsPrimaryKey: structField.IsPrimaryKey, - IsNormal: structField.IsNormal, - IsIgnored: structField.IsIgnored, - IsScanner: structField.IsScanner, - HasDefaultValue: structField.HasDefaultValue, - Tag: structField.Tag, - TagSettings: structField.TagSettings, - Struct: structField.Struct, - IsForeignKey: structField.IsForeignKey, - Relationship: structField.Relationship, +// TagSettingsSet Sets a tag in the tag settings map +func (sf *StructField) TagSettingsSet(key, val string) { + sf.tagSettingsLock.Lock() + defer sf.tagSettingsLock.Unlock() + sf.TagSettings[key] = val +} + +// TagSettingsGet returns a tag from the tag settings +func (sf *StructField) TagSettingsGet(key string) (string, bool) { + sf.tagSettingsLock.RLock() + defer sf.tagSettingsLock.RUnlock() + val, ok := sf.TagSettings[key] + return val, ok +} + +// TagSettingsDelete deletes a tag +func (sf *StructField) TagSettingsDelete(key string) { + sf.tagSettingsLock.Lock() + defer sf.tagSettingsLock.Unlock() + delete(sf.TagSettings, key) +} + +func (sf *StructField) clone() *StructField { + clone := &StructField{ + DBName: sf.DBName, + Name: sf.Name, + Names: sf.Names, + IsPrimaryKey: sf.IsPrimaryKey, + IsNormal: sf.IsNormal, + IsIgnored: sf.IsIgnored, + IsScanner: sf.IsScanner, + HasDefaultValue: sf.HasDefaultValue, + Tag: sf.Tag, + TagSettings: map[string]string{}, + Struct: sf.Struct, + IsForeignKey: sf.IsForeignKey, + } + + if sf.Relationship != nil { + relationship := *sf.Relationship + clone.Relationship = &relationship } + + // copy the struct field tagSettings, they should be read-locked while they are copied + sf.tagSettingsLock.Lock() + defer sf.tagSettingsLock.Unlock() + for key, value := range sf.TagSettings { + clone.TagSettings[key] = value + } + + return clone } // Relationship described the relationship between models @@ -93,6 +133,7 @@ type Relationship struct { Kind string PolymorphicType string PolymorphicDBName string + PolymorphicValue string ForeignFieldNames []string ForeignDBNames []string AssociationForeignFieldNames []string @@ -102,7 +143,7 @@ type Relationship struct { func getForeignField(column string, fields []*StructField) *StructField { for _, field := range fields { - if field.Name == column || field.DBName == column || field.DBName == ToDBName(column) { + if field.Name == column || field.DBName == column || field.DBName == ToColumnName(column) { return field } } @@ -111,6 +152,10 @@ func getForeignField(column string, fields []*StructField) *StructField { // GetModelStruct get value's model struct, relationships based on struct and tag definition func (scope *Scope) GetModelStruct() *ModelStruct { + return scope.getModelStruct(scope, make([]*StructField, 0)) +} + +func (scope *Scope) getModelStruct(rootScope *Scope, allFields []*StructField) *ModelStruct { var modelStruct ModelStruct // Scope value can't be nil if scope.Value == nil { @@ -128,23 +173,23 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Get Cached model struct - if value := modelStructsMap.Get(reflectType); value != nil { - return value + isSingularTable := false + if scope.db != nil && scope.db.parent != nil { + scope.db.parent.RLock() + isSingularTable = scope.db.parent.singularTable + scope.db.parent.RUnlock() } - modelStruct.ModelType = reflectType - - // Set default table name - if tabler, ok := reflect.New(reflectType).Interface().(tabler); ok { - modelStruct.defaultTableName = tabler.TableName() - } else { - tableName := ToDBName(reflectType.Name()) - if scope.db == nil || !scope.db.parent.singularTable { - tableName = inflection.Plural(tableName) - } - modelStruct.defaultTableName = tableName + hashKey := struct { + singularTable bool + reflectType reflect.Type + }{isSingularTable, reflectType} + if value, ok := modelStructsMap.Load(hashKey); ok && value != nil { + return value.(*ModelStruct) } + modelStruct.ModelType = reflectType + // Get all fields for i := 0; i < reflectType.NumField(); i++ { if fieldStruct := reflectType.Field(i); ast.IsExported(fieldStruct.Name) { @@ -157,15 +202,19 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // is ignored field - if fieldStruct.Tag.Get("sql") == "-" { + if _, ok := field.TagSettingsGet("-"); ok { field.IsIgnored = true } else { - if _, ok := field.TagSettings["PRIMARY_KEY"]; ok { + if _, ok := field.TagSettingsGet("PRIMARY_KEY"); ok { field.IsPrimaryKey = true modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, field) } - if _, ok := field.TagSettings["DEFAULT"]; ok { + if _, ok := field.TagSettingsGet("DEFAULT"); ok && !field.IsPrimaryKey { + field.HasDefaultValue = true + } + + if _, ok := field.TagSettingsGet("AUTO_INCREMENT"); ok && !field.IsPrimaryKey { field.HasDefaultValue = true } @@ -178,18 +227,45 @@ func (scope *Scope) GetModelStruct() *ModelStruct { if _, isScanner := fieldValue.(sql.Scanner); isScanner { // is scanner field.IsScanner, field.IsNormal = true, true + if indirectType.Kind() == reflect.Struct { + for i := 0; i < indirectType.NumField(); i++ { + for key, value := range parseTagSetting(indirectType.Field(i).Tag) { + if _, ok := field.TagSettingsGet(key); !ok { + field.TagSettingsSet(key, value) + } + } + } + } } else if _, isTime := fieldValue.(*time.Time); isTime { // is time field.IsNormal = true - } else if _, ok := field.TagSettings["EMBEDDED"]; ok || fieldStruct.Anonymous { + } else if _, ok := field.TagSettingsGet("EMBEDDED"); ok || fieldStruct.Anonymous { // is embedded struct - for _, subField := range scope.New(fieldValue).GetStructFields() { + for _, subField := range scope.New(fieldValue).getModelStruct(rootScope, allFields).StructFields { subField = subField.clone() subField.Names = append([]string{fieldStruct.Name}, subField.Names...) + if prefix, ok := field.TagSettingsGet("EMBEDDED_PREFIX"); ok { + subField.DBName = prefix + subField.DBName + } + if subField.IsPrimaryKey { - modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) + if _, ok := subField.TagSettingsGet("PRIMARY_KEY"); ok { + modelStruct.PrimaryFields = append(modelStruct.PrimaryFields, subField) + } else { + subField.IsPrimaryKey = false + } } + + if subField.Relationship != nil && subField.Relationship.JoinTableHandler != nil { + if joinTableHandler, ok := subField.Relationship.JoinTableHandler.(*JoinTableHandler); ok { + newJoinTableHandler := &JoinTableHandler{} + newJoinTableHandler.Setup(subField.Relationship, joinTableHandler.TableName, reflectType, joinTableHandler.Destination.ModelType) + subField.Relationship.JoinTableHandler = newJoinTableHandler + } + } + modelStruct.StructFields = append(modelStruct.StructFields, subField) + allFields = append(allFields, subField) } continue } else { @@ -205,12 +281,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct { elemType = field.Struct.Type ) - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - foreignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { + foreignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - associationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { + associationForeignKeys = strings.Split(foreignKey, ",") } for elemType.Kind() == reflect.Slice || elemType.Kind() == reflect.Ptr { @@ -218,40 +296,68 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } if elemType.Kind() == reflect.Struct { - if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { + if many2many, _ := field.TagSettingsGet("MANY2MANY"); many2many != "" { relationship.Kind = "many_to_many" - // if no foreign keys defined with tag - if len(foreignKeys) == 0 { - for _, field := range modelStruct.PrimaryFields { - foreignKeys = append(foreignKeys, field.DBName) + { // Foreign Keys for Source + joinTableDBNames := []string{} + + if foreignKey, _ := field.TagSettingsGet("JOINTABLE_FOREIGNKEY"); foreignKey != "" { + joinTableDBNames = strings.Split(foreignKey, ",") } - } - for _, foreignKey := range foreignKeys { - if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { - // source foreign keys (db names) - relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) - // join table foreign keys for source - joinTableDBName := ToDBName(reflectType.Name()) + "_" + foreignField.DBName - relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBName) + // if no foreign keys defined with tag + if len(foreignKeys) == 0 { + for _, field := range modelStruct.PrimaryFields { + foreignKeys = append(foreignKeys, field.DBName) + } } - } - // if no association foreign keys defined with tag - if len(associationForeignKeys) == 0 { - for _, field := range toScope.PrimaryFields() { - associationForeignKeys = append(associationForeignKeys, field.DBName) + for idx, foreignKey := range foreignKeys { + if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { + // source foreign keys (db names) + relationship.ForeignFieldNames = append(relationship.ForeignFieldNames, foreignField.DBName) + + // setup join table foreign keys for source + if len(joinTableDBNames) > idx { + // if defined join table's foreign key + relationship.ForeignDBNames = append(relationship.ForeignDBNames, joinTableDBNames[idx]) + } else { + defaultJointableForeignKey := ToColumnName(reflectType.Name()) + "_" + foreignField.DBName + relationship.ForeignDBNames = append(relationship.ForeignDBNames, defaultJointableForeignKey) + } + } } } - for _, name := range associationForeignKeys { - if field, ok := toScope.FieldByName(name); ok { - // association foreign keys (db names) - relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) - // join table foreign keys for association - joinTableDBName := ToDBName(elemType.Name()) + "_" + field.DBName - relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + { // Foreign Keys for Association (Destination) + associationJoinTableDBNames := []string{} + + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_JOINTABLE_FOREIGNKEY"); foreignKey != "" { + associationJoinTableDBNames = strings.Split(foreignKey, ",") + } + + // if no association foreign keys defined with tag + if len(associationForeignKeys) == 0 { + for _, field := range toScope.PrimaryFields() { + associationForeignKeys = append(associationForeignKeys, field.DBName) + } + } + + for idx, name := range associationForeignKeys { + if field, ok := toScope.FieldByName(name); ok { + // association foreign keys (db names) + relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, field.DBName) + + // setup join table foreign keys for association + if len(associationJoinTableDBNames) > idx { + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationJoinTableDBNames[idx]) + } else { + // join table foreign keys for association + joinTableDBName := ToColumnName(elemType.Name()) + "_" + field.DBName + relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, joinTableDBName) + } + } } } @@ -265,13 +371,19 @@ func (scope *Scope) GetModelStruct() *ModelStruct { var toFields = toScope.GetStructFields() relationship.Kind = "has_many" - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { // Dog has many toys, tag polymorphic is Owner, then associationType is Owner // Toy use OwnerID, OwnerType ('dogs') as foreign key if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { associationType = polymorphic relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName + // if Dog has multiple set of toys set name of the set (instead of default 'dogs') + if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { + relationship.PolymorphicValue = value + } else { + relationship.PolymorphicValue = scope.TableName() + } polymorphicType.IsForeignKey = true } } @@ -287,7 +399,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } else { // generate foreign keys from defined association foreign keys for _, scopeFieldName := range associationForeignKeys { - if foreignField := getForeignField(scopeFieldName, modelStruct.StructFields); foreignField != nil { + if foreignField := getForeignField(scopeFieldName, allFields); foreignField != nil { foreignKeys = append(foreignKeys, associationType+foreignField.Name) associationForeignKeys = append(associationForeignKeys, foreignField.Name) } @@ -299,13 +411,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, foreignKey := range foreignKeys { if strings.HasPrefix(foreignKey, associationType) { associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil { associationForeignKeys = append(associationForeignKeys, associationForeignKey) } } } if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} + associationForeignKeys = []string{rootScope.PrimaryKey()} } } else if len(foreignKeys) != len(associationForeignKeys) { scope.Err(errors.New("invalid foreign keys, should have same length")) @@ -315,9 +427,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if associationField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); associationField != nil { - // source foreign keys + if associationField := getForeignField(associationForeignKeys[idx], allFields); associationField != nil { + // mark field as foreignkey, use global lock to avoid race + structsLock.Lock() foreignField.IsForeignKey = true + structsLock.Unlock() + + // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, associationField.DBName) @@ -349,21 +465,29 @@ func (scope *Scope) GetModelStruct() *ModelStruct { tagAssociationForeignKeys []string ) - if foreignKey := field.TagSettings["FOREIGNKEY"]; foreignKey != "" { - tagForeignKeys = strings.Split(field.TagSettings["FOREIGNKEY"], ",") + if foreignKey, _ := field.TagSettingsGet("FOREIGNKEY"); foreignKey != "" { + tagForeignKeys = strings.Split(foreignKey, ",") } - if foreignKey := field.TagSettings["ASSOCIATIONFOREIGNKEY"]; foreignKey != "" { - tagAssociationForeignKeys = strings.Split(field.TagSettings["ASSOCIATIONFOREIGNKEY"], ",") + if foreignKey, _ := field.TagSettingsGet("ASSOCIATION_FOREIGNKEY"); foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") + } else if foreignKey, _ := field.TagSettingsGet("ASSOCIATIONFOREIGNKEY"); foreignKey != "" { + tagAssociationForeignKeys = strings.Split(foreignKey, ",") } - if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" { + if polymorphic, _ := field.TagSettingsGet("POLYMORPHIC"); polymorphic != "" { // Cat has one toy, tag polymorphic is Owner, then associationType is Owner // Toy use OwnerID, OwnerType ('cats') as foreign key if polymorphicType := getForeignField(polymorphic+"Type", toFields); polymorphicType != nil { associationType = polymorphic relationship.PolymorphicType = polymorphicType.Name relationship.PolymorphicDBName = polymorphicType.DBName + // if Cat has several different types of toys set name for each (instead of default 'cats') + if value, ok := field.TagSettingsGet("POLYMORPHIC_VALUE"); ok { + relationship.PolymorphicValue = value + } else { + relationship.PolymorphicValue = scope.TableName() + } polymorphicType.IsForeignKey = true } } @@ -383,7 +507,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } else { // generate foreign keys form association foreign keys for _, associationForeignKey := range tagAssociationForeignKeys { - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil { foreignKeys = append(foreignKeys, associationType+foreignField.Name) associationForeignKeys = append(associationForeignKeys, foreignField.Name) } @@ -395,13 +519,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for _, foreignKey := range foreignKeys { if strings.HasPrefix(foreignKey, associationType) { associationForeignKey := strings.TrimPrefix(foreignKey, associationType) - if foreignField := getForeignField(associationForeignKey, modelStruct.StructFields); foreignField != nil { + if foreignField := getForeignField(associationForeignKey, allFields); foreignField != nil { associationForeignKeys = append(associationForeignKeys, associationForeignKey) } } } if len(associationForeignKeys) == 0 && len(foreignKeys) == 1 { - associationForeignKeys = []string{scope.PrimaryKey()} + associationForeignKeys = []string{rootScope.PrimaryKey()} } } else if len(foreignKeys) != len(associationForeignKeys) { scope.Err(errors.New("invalid foreign keys, should have same length")) @@ -411,9 +535,13 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, toFields); foreignField != nil { - if scopeField := getForeignField(associationForeignKeys[idx], modelStruct.StructFields); scopeField != nil { + if scopeField := getForeignField(associationForeignKeys[idx], allFields); scopeField != nil { + // mark field as foreignkey, use global lock to avoid race + structsLock.Lock() foreignField.IsForeignKey = true - // source foreign keys + structsLock.Unlock() + + // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, scopeField.Name) relationship.AssociationForeignDBNames = append(relationship.AssociationForeignDBNames, scopeField.DBName) @@ -471,7 +599,10 @@ func (scope *Scope) GetModelStruct() *ModelStruct { for idx, foreignKey := range foreignKeys { if foreignField := getForeignField(foreignKey, modelStruct.StructFields); foreignField != nil { if associationField := getForeignField(associationForeignKeys[idx], toFields); associationField != nil { + // mark field as foreignkey, use global lock to avoid race + structsLock.Lock() foreignField.IsForeignKey = true + structsLock.Unlock() // association foreign keys relationship.AssociationForeignFieldNames = append(relationship.AssociationForeignFieldNames, associationField.Name) @@ -497,13 +628,14 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } // Even it is ignored, also possible to decode db value into the field - if value, ok := field.TagSettings["COLUMN"]; ok { + if value, ok := field.TagSettingsGet("COLUMN"); ok { field.DBName = value } else { - field.DBName = ToDBName(fieldStruct.Name) + field.DBName = ToColumnName(fieldStruct.Name) } modelStruct.StructFields = append(modelStruct.StructFields, field) + allFields = append(allFields, field) } } @@ -514,7 +646,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct { } } - modelStructsMap.Set(reflectType, &modelStruct) + modelStructsMap.Store(hashKey, &modelStruct) return &modelStruct } @@ -527,6 +659,9 @@ func (scope *Scope) GetStructFields() (fields []*StructField) { func parseTagSetting(tags reflect.StructTag) map[string]string { setting := map[string]string{} for _, str := range []string{tags.Get("sql"), tags.Get("gorm")} { + if str == "" { + continue + } tags := strings.Split(str, ";") for _, value := range tags { v := strings.Split(value, ":") diff --git a/vendor/github.com/jinzhu/gorm/naming.go b/vendor/github.com/jinzhu/gorm/naming.go new file mode 100644 index 00000000..6b0a4fdd --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/naming.go @@ -0,0 +1,124 @@ +package gorm + +import ( + "bytes" + "strings" +) + +// Namer is a function type which is given a string and return a string +type Namer func(string) string + +// NamingStrategy represents naming strategies +type NamingStrategy struct { + DB Namer + Table Namer + Column Namer +} + +// TheNamingStrategy is being initialized with defaultNamingStrategy +var TheNamingStrategy = &NamingStrategy{ + DB: defaultNamer, + Table: defaultNamer, + Column: defaultNamer, +} + +// AddNamingStrategy sets the naming strategy +func AddNamingStrategy(ns *NamingStrategy) { + if ns.DB == nil { + ns.DB = defaultNamer + } + if ns.Table == nil { + ns.Table = defaultNamer + } + if ns.Column == nil { + ns.Column = defaultNamer + } + TheNamingStrategy = ns +} + +// DBName alters the given name by DB +func (ns *NamingStrategy) DBName(name string) string { + return ns.DB(name) +} + +// TableName alters the given name by Table +func (ns *NamingStrategy) TableName(name string) string { + return ns.Table(name) +} + +// ColumnName alters the given name by Column +func (ns *NamingStrategy) ColumnName(name string) string { + return ns.Column(name) +} + +// ToDBName convert string to db name +func ToDBName(name string) string { + return TheNamingStrategy.DBName(name) +} + +// ToTableName convert string to table name +func ToTableName(name string) string { + return TheNamingStrategy.TableName(name) +} + +// ToColumnName convert string to db name +func ToColumnName(name string) string { + return TheNamingStrategy.ColumnName(name) +} + +var smap = newSafeMap() + +func defaultNamer(name string) string { + const ( + lower = false + upper = true + ) + + if v := smap.Get(name); v != "" { + return v + } + + if name == "" { + return "" + } + + var ( + value = commonInitialismsReplacer.Replace(name) + buf = bytes.NewBufferString("") + lastCase, currCase, nextCase, nextNumber bool + ) + + for i, v := range value[:len(value)-1] { + nextCase = bool(value[i+1] >= 'A' && value[i+1] <= 'Z') + nextNumber = bool(value[i+1] >= '0' && value[i+1] <= '9') + + if i > 0 { + if currCase == upper { + if lastCase == upper && (nextCase == upper || nextNumber == upper) { + buf.WriteRune(v) + } else { + if value[i-1] != '_' && value[i+1] != '_' { + buf.WriteRune('_') + } + buf.WriteRune(v) + } + } else { + buf.WriteRune(v) + if i == len(value)-2 && (nextCase == upper && nextNumber == lower) { + buf.WriteRune('_') + } + } + } else { + currCase = upper + buf.WriteRune(v) + } + lastCase = currCase + currCase = nextCase + } + + buf.WriteByte(value[len(value)-1]) + + s := strings.ToLower(buf.String()) + smap.Set(name, s) + return s +} diff --git a/vendor/github.com/jinzhu/gorm/scope.go b/vendor/github.com/jinzhu/gorm/scope.go index 844df85c..56c3d6e5 100644 --- a/vendor/github.com/jinzhu/gorm/scope.go +++ b/vendor/github.com/jinzhu/gorm/scope.go @@ -1,16 +1,15 @@ package gorm import ( + "bytes" "database/sql" "database/sql/driver" "errors" "fmt" + "reflect" "regexp" - "strconv" "strings" "time" - - "reflect" ) // Scope contain current operation's information when you perform any operation on the database @@ -58,18 +57,18 @@ func (scope *Scope) NewDB() *DB { } // SQLDB return *sql.DB -func (scope *Scope) SQLDB() sqlCommon { +func (scope *Scope) SQLDB() SQLCommon { return scope.db.db } // Dialect get dialect func (scope *Scope) Dialect() Dialect { - return scope.db.parent.dialect + return scope.db.dialect } // Quote used to quote string to escape them for database func (scope *Scope) Quote(str string) string { - if strings.Index(str, ".") != -1 { + if strings.Contains(str, ".") { newStrs := []string{} for _, str := range strings.Split(str, ".") { newStrs = append(newStrs, scope.Dialect().Quote(str)) @@ -116,6 +115,9 @@ func (scope *Scope) Fields() []*Field { if isStruct { fieldValue := indirectScopeValue for _, name := range structField.Names { + if fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + fieldValue.Set(reflect.New(fieldValue.Type().Elem())) + } fieldValue = reflect.Indirect(fieldValue).FieldByName(name) } fields = append(fields, &Field{StructField: structField, Field: fieldValue, IsBlank: isBlank(fieldValue)}) @@ -132,7 +134,7 @@ func (scope *Scope) Fields() []*Field { // FieldByName find `gorm.Field` with field name or db name func (scope *Scope) FieldByName(name string) (field *Field, ok bool) { var ( - dbName = ToDBName(name) + dbName = ToColumnName(name) mostMatchedField *Field ) @@ -223,7 +225,7 @@ func (scope *Scope) SetColumn(column interface{}, value interface{}) error { updateAttrs[field.DBName] = value return field.Set(value) } - if (field.DBName == dbName) || (field.Name == name && mostMatchedField == nil) { + if !field.IsIgnored && ((field.DBName == dbName) || (field.Name == name && mostMatchedField == nil)) { mostMatchedField = field } } @@ -253,15 +255,25 @@ func (scope *Scope) CallMethod(methodName string) { // AddToVars add value as sql's vars, used to prevent SQL injection func (scope *Scope) AddToVars(value interface{}) string { - if expr, ok := value.(*expr); ok { + _, skipBindVar := scope.InstanceGet("skip_bindvar") + + if expr, ok := value.(*SqlExpr); ok { exp := expr.expr for _, arg := range expr.args { - exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + if skipBindVar { + scope.AddToVars(arg) + } else { + exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + } } return exp } scope.SQLVars = append(scope.SQLVars, value) + + if skipBindVar { + return "?" + } return scope.Dialect().BindVar(len(scope.SQLVars)) } @@ -318,7 +330,7 @@ func (scope *Scope) TableName() string { // QuotedTableName return quoted table name func (scope *Scope) QuotedTableName() (name string) { if scope.Search != nil && len(scope.Search.tableName) > 0 { - if strings.Index(scope.Search.tableName, " ") != -1 { + if strings.Contains(scope.Search.tableName, " ") { return scope.Search.tableName } return scope.Quote(scope.Search.tableName) @@ -329,13 +341,18 @@ func (scope *Scope) QuotedTableName() (name string) { // CombinedConditionSql return combined condition sql func (scope *Scope) CombinedConditionSql() string { - return scope.joinsSQL() + scope.whereSQL() + scope.groupSQL() + + joinSQL := scope.joinsSQL() + whereSQL := scope.whereSQL() + if scope.Search.raw { + whereSQL = strings.TrimSuffix(strings.TrimPrefix(whereSQL, "WHERE ("), ")") + } + return joinSQL + whereSQL + scope.groupSQL() + scope.havingSQL() + scope.orderSQL() + scope.limitAndOffsetSQL() } // Raw set raw sql func (scope *Scope) Raw(sql string) *Scope { - scope.SQL = strings.Replace(sql, "$$", "?", -1) + scope.SQL = strings.Replace(sql, "$$$", "?", -1) return scope } @@ -385,8 +402,8 @@ func (scope *Scope) InstanceGet(name string) (interface{}, bool) { // Begin start a transaction func (scope *Scope) Begin() *Scope { if db, ok := scope.SQLDB().(sqlDb); ok { - if tx, err := db.Begin(); err == nil { - scope.db.db = interface{}(tx).(sqlCommon) + if tx, err := db.Begin(); scope.Err(err) == nil { + scope.db.db = interface{}(tx).(SQLCommon) scope.InstanceSet("gorm:started_transaction", true) } } @@ -442,7 +459,12 @@ func (scope *Scope) callMethod(methodName string, reflectValue reflect.Value) { } } -var columnRegexp = regexp.MustCompile("^[a-zA-Z]+(\\.[a-zA-Z]+)*$") // only match string like `name`, `users.name` +var ( + columnRegexp = regexp.MustCompile("^[a-zA-Z\\d]+(\\.[a-zA-Z\\d]+)*$") // only match string like `name`, `users.name` + isNumberRegexp = regexp.MustCompile("^\\s*\\d+\\s*$") // match if string is number + comparisonRegexp = regexp.MustCompile("(?i) (=|<>|(>|<)(=?)|LIKE|IS|IN) ") + countingQueryRegexp = regexp.MustCompile("(?i)^count(.+)$") +) func (scope *Scope) quoteIfPossible(str string) string { if columnRegexp.MatchString(str) { @@ -454,18 +476,20 @@ func (scope *Scope) quoteIfPossible(str string) string { func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { var ( ignored interface{} - selectFields []*Field values = make([]interface{}, len(columns)) + selectFields []*Field selectedColumnsMap = map[string]int{} - resetFields = map[*Field]int{} + resetFields = map[int]*Field{} ) for index, column := range columns { values[index] = &ignored selectFields = fields + offset := 0 if idx, ok := selectedColumnsMap[column]; ok { - selectFields = selectFields[idx+1:] + offset = idx + 1 + selectFields = selectFields[offset:] } for fieldIndex, field := range selectFields { @@ -476,18 +500,21 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { reflectValue := reflect.New(reflect.PtrTo(field.Struct.Type)) reflectValue.Elem().Set(field.Field.Addr()) values[index] = reflectValue.Interface() - resetFields[field] = index + resetFields[index] = field } - selectedColumnsMap[column] = fieldIndex - break + selectedColumnsMap[column] = offset + fieldIndex + + if field.IsNormal { + break + } } } } scope.Err(rows.Scan(values...)) - for field, index := range resetFields { + for index, field := range resetFields { if v := reflect.ValueOf(values[index]).Elem().Elem(); v.IsValid() { field.Field.Set(v) } @@ -495,136 +522,146 @@ func (scope *Scope) scan(rows *sql.Rows, columns []string, fields []*Field) { } func (scope *Scope) primaryCondition(value interface{}) string { - return fmt.Sprintf("(%v = %v)", scope.Quote(scope.PrimaryKey()), value) + return fmt.Sprintf("(%v.%v = %v)", scope.QuotedTableName(), scope.Quote(scope.PrimaryKey()), value) } -func (scope *Scope) buildWhereCondition(clause map[string]interface{}) (str string) { +func (scope *Scope) buildCondition(clause map[string]interface{}, include bool) (str string) { + var ( + quotedTableName = scope.QuotedTableName() + quotedPrimaryKey = scope.Quote(scope.PrimaryKey()) + equalSQL = "=" + inSQL = "IN" + ) + + // If building not conditions + if !include { + equalSQL = "<>" + inSQL = "NOT IN" + } + switch value := clause["query"].(type) { - case string: - // if string is number - if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { - return scope.primaryCondition(scope.AddToVars(value)) - } else if value != "" { - str = fmt.Sprintf("(%v)", value) - } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return scope.primaryCondition(scope.AddToVars(value)) + case sql.NullInt64: + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value.Int64) + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, value) case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string, []interface{}: - str = fmt.Sprintf("(%v IN (?))", scope.Quote(scope.PrimaryKey())) + if !include && reflect.ValueOf(value).Len() == 0 { + return + } + str = fmt.Sprintf("(%v.%v %s (?))", quotedTableName, quotedPrimaryKey, inSQL) clause["args"] = []interface{}{value} + case string: + if isNumberRegexp.MatchString(value) { + return fmt.Sprintf("(%v.%v %s %v)", quotedTableName, quotedPrimaryKey, equalSQL, scope.AddToVars(value)) + } + + if value != "" { + if !include { + if comparisonRegexp.MatchString(value) { + str = fmt.Sprintf("NOT (%v)", value) + } else { + str = fmt.Sprintf("(%v.%v NOT IN (?))", quotedTableName, scope.Quote(value)) + } + } else { + str = fmt.Sprintf("(%v)", value) + } + } case map[string]interface{}: var sqls []string for key, value := range value { if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(key), scope.AddToVars(value))) + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", quotedTableName, scope.Quote(key), equalSQL, scope.AddToVars(value))) } else { - sqls = append(sqls, fmt.Sprintf("(%v IS NULL)", scope.Quote(key))) + if !include { + sqls = append(sqls, fmt.Sprintf("(%v.%v IS NOT NULL)", quotedTableName, scope.Quote(key))) + } else { + sqls = append(sqls, fmt.Sprintf("(%v.%v IS NULL)", quotedTableName, scope.Quote(key))) + } } } return strings.Join(sqls, " AND ") case interface{}: var sqls []string - for _, field := range scope.New(value).Fields() { - if !field.IsIgnored && !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v = %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) + newScope := scope.New(value) + + if len(newScope.Fields()) == 0 { + scope.Err(fmt.Errorf("invalid query condition: %v", value)) + return + } + scopeQuotedTableName := newScope.QuotedTableName() + for _, field := range newScope.Fields() { + if !field.IsIgnored && !field.IsBlank && field.Relationship == nil { + sqls = append(sqls, fmt.Sprintf("(%v.%v %s %v)", scopeQuotedTableName, scope.Quote(field.DBName), equalSQL, scope.AddToVars(field.Field.Interface()))) } } return strings.Join(sqls, " AND ") + default: + scope.Err(fmt.Errorf("invalid query condition: %v", value)) + return } + replacements := []string{} args := clause["args"].([]interface{}) for _, arg := range args { + var err error switch reflect.ValueOf(arg).Kind() { case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) + if scanner, ok := interface{}(arg).(driver.Valuer); ok { + arg, err = scanner.Value() + replacements = append(replacements, scope.AddToVars(arg)) + } else if b, ok := arg.([]byte); ok { + replacements = append(replacements, scope.AddToVars(b)) + } else if as, ok := arg.([][]interface{}); ok { + var tempMarks []string + for _, a := range as { + var arrayMarks []string + for _, v := range a { + arrayMarks = append(arrayMarks, scope.AddToVars(v)) + } + + if len(arrayMarks) > 0 { + tempMarks = append(tempMarks, fmt.Sprintf("(%v)", strings.Join(arrayMarks, ","))) + } + } + + if len(tempMarks) > 0 { + replacements = append(replacements, strings.Join(tempMarks, ",")) + } } else if values := reflect.ValueOf(arg); values.Len() > 0 { var tempMarks []string for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + replacements = append(replacements, strings.Join(tempMarks, ",")) } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) + replacements = append(replacements, scope.AddToVars(Expr("NULL"))) } default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = valuer.Value() + arg, err = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + replacements = append(replacements, scope.AddToVars(arg)) } - } - return -} - -func (scope *Scope) buildNotCondition(clause map[string]interface{}) (str string) { - var notEqualSQL string - var primaryKey = scope.PrimaryKey() - switch value := clause["query"].(type) { - case string: - // is number - if regexp.MustCompile("^\\s*\\d+\\s*$").MatchString(value) { - id, _ := strconv.Atoi(value) - return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), id) - } else if regexp.MustCompile("(?i) (=|<>|>|<|LIKE|IS|IN) ").MatchString(value) { - str = fmt.Sprintf(" NOT (%v) ", value) - notEqualSQL = fmt.Sprintf("NOT (%v)", value) - } else { - str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(value)) - notEqualSQL = fmt.Sprintf("(%v <> ?)", scope.Quote(value)) + if err != nil { + scope.Err(err) } - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, sql.NullInt64: - return fmt.Sprintf("(%v <> %v)", scope.Quote(primaryKey), value) - case []int, []int8, []int16, []int32, []int64, []uint, []uint8, []uint16, []uint32, []uint64, []string: - if reflect.ValueOf(value).Len() > 0 { - str = fmt.Sprintf("(%v NOT IN (?))", scope.Quote(primaryKey)) - clause["args"] = []interface{}{value} - } - return "" - case map[string]interface{}: - var sqls []string - for key, value := range value { - if value != nil { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(key), scope.AddToVars(value))) - } else { - sqls = append(sqls, fmt.Sprintf("(%v IS NOT NULL)", scope.Quote(key))) - } - } - return strings.Join(sqls, " AND ") - case interface{}: - var sqls []string - for _, field := range scope.New(value).Fields() { - if !field.IsBlank { - sqls = append(sqls, fmt.Sprintf("(%v <> %v)", scope.Quote(field.DBName), scope.AddToVars(field.Field.Interface()))) - } - } - return strings.Join(sqls, " AND ") } - args := clause["args"].([]interface{}) - for _, arg := range args { - switch reflect.ValueOf(arg).Kind() { - case reflect.Slice: // For where("id in (?)", []int64{1,2}) - if bytes, ok := arg.([]byte); ok { - str = strings.Replace(str, "?", scope.AddToVars(bytes), 1) - } else if values := reflect.ValueOf(arg); values.Len() > 0 { - var tempMarks []string - for i := 0; i < values.Len(); i++ { - tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) - } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) - } else { - str = strings.Replace(str, "?", scope.AddToVars(Expr("NULL")), 1) - } - default: - if scanner, ok := interface{}(arg).(driver.Valuer); ok { - arg, _ = scanner.Value() - } - str = strings.Replace(notEqualSQL, "?", scope.AddToVars(arg), 1) + buff := bytes.NewBuffer([]byte{}) + i := 0 + for _, s := range str { + if s == '?' && len(replacements) > i { + buff.WriteString(replacements[i]) + i++ + } else { + buff.WriteRune(s) } } + + str = buff.String() + return } @@ -637,6 +674,7 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) } args := clause["args"].([]interface{}) + replacements := []string{} for _, arg := range args { switch reflect.ValueOf(arg).Kind() { case reflect.Slice: @@ -645,25 +683,40 @@ func (scope *Scope) buildSelectQuery(clause map[string]interface{}) (str string) for i := 0; i < values.Len(); i++ { tempMarks = append(tempMarks, scope.AddToVars(values.Index(i).Interface())) } - str = strings.Replace(str, "?", strings.Join(tempMarks, ","), 1) + replacements = append(replacements, strings.Join(tempMarks, ",")) default: if valuer, ok := interface{}(arg).(driver.Valuer); ok { arg, _ = valuer.Value() } - str = strings.Replace(str, "?", scope.AddToVars(arg), 1) + replacements = append(replacements, scope.AddToVars(arg)) + } + } + + buff := bytes.NewBuffer([]byte{}) + i := 0 + for pos, char := range str { + if str[pos] == '?' { + buff.WriteString(replacements[i]) + i++ + } else { + buff.WriteRune(char) } } + + str = buff.String() + return } func (scope *Scope) whereSQL() (sql string) { var ( quotedTableName = scope.QuotedTableName() + deletedAtField, hasDeletedAtField = scope.FieldByName("DeletedAt") primaryConditions, andConditions, orConditions []string ) - if !scope.Search.Unscoped && scope.HasColumn("deleted_at") { - sql := fmt.Sprintf("%v.deleted_at IS NULL", quotedTableName) + if !scope.Search.Unscoped && hasDeletedAtField { + sql := fmt.Sprintf("%v.%v IS NULL", quotedTableName, scope.Quote(deletedAtField.DBName)) primaryConditions = append(primaryConditions, sql) } @@ -675,19 +728,19 @@ func (scope *Scope) whereSQL() (sql string) { } for _, clause := range scope.Search.whereConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { andConditions = append(andConditions, sql) } } for _, clause := range scope.Search.orConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { orConditions = append(orConditions, sql) } } for _, clause := range scope.Search.notConditions { - if sql := scope.buildNotCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, false); sql != "" { andConditions = append(andConditions, sql) } } @@ -724,19 +777,29 @@ func (scope *Scope) selectSQL() string { } func (scope *Scope) orderSQL() string { - if len(scope.Search.orders) == 0 || scope.Search.countingQuery { + if len(scope.Search.orders) == 0 || scope.Search.ignoreOrderQuery { return "" } var orders []string for _, order := range scope.Search.orders { - orders = append(orders, scope.quoteIfPossible(order)) + if str, ok := order.(string); ok { + orders = append(orders, scope.quoteIfPossible(str)) + } else if expr, ok := order.(*SqlExpr); ok { + exp := expr.expr + for _, arg := range expr.args { + exp = strings.Replace(exp, "?", scope.AddToVars(arg), 1) + } + orders = append(orders, exp) + } } return " ORDER BY " + strings.Join(orders, ",") } func (scope *Scope) limitAndOffsetSQL() string { - return scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) + sql, err := scope.Dialect().LimitAndOffsetSQL(scope.Search.limit, scope.Search.offset) + scope.Err(err) + return sql } func (scope *Scope) groupSQL() string { @@ -753,7 +816,7 @@ func (scope *Scope) havingSQL() string { var andConditions []string for _, clause := range scope.Search.havingConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { andConditions = append(andConditions, sql) } } @@ -769,7 +832,7 @@ func (scope *Scope) havingSQL() string { func (scope *Scope) joinsSQL() string { var joinConditions []string for _, clause := range scope.Search.joinConditions { - if sql := scope.buildWhereCondition(clause); sql != "" { + if sql := scope.buildCondition(clause, true); sql != "" { joinConditions = append(joinConditions, strings.TrimSuffix(strings.TrimPrefix(sql, "("), ")")) } } @@ -779,7 +842,7 @@ func (scope *Scope) joinsSQL() string { func (scope *Scope) prepareQuerySQL() { if scope.Search.raw { - scope.Raw(strings.TrimSuffix(strings.TrimPrefix(scope.CombinedConditionSql(), " WHERE ("), ")")) + scope.Raw(scope.CombinedConditionSql()) } else { scope.Raw(fmt.Sprintf("SELECT %v FROM %v %v", scope.selectSQL(), scope.QuotedTableName(), scope.CombinedConditionSql())) } @@ -794,6 +857,14 @@ func (scope *Scope) inlineCondition(values ...interface{}) *Scope { } func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { + defer func() { + if err := recover(); err != nil { + if db, ok := scope.db.db.(sqlTx); ok { + db.Rollback() + } + panic(err) + } + }() for _, f := range funcs { (*f)(scope) if scope.skipLeft { @@ -803,7 +874,7 @@ func (scope *Scope) callCallbacks(funcs []*func(s *Scope)) *Scope { return scope } -func convertInterfaceToMap(values interface{}) map[string]interface{} { +func convertInterfaceToMap(values interface{}, withIgnoredField bool, db *DB) map[string]interface{} { var attrs = map[string]interface{}{} switch value := values.(type) { @@ -811,7 +882,7 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { return value case []interface{}: for _, v := range value { - for key, value := range convertInterfaceToMap(v) { + for key, value := range convertInterfaceToMap(v, withIgnoredField, db) { attrs[key] = value } } @@ -821,11 +892,11 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { switch reflectValue.Kind() { case reflect.Map: for _, key := range reflectValue.MapKeys() { - attrs[ToDBName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() + attrs[ToColumnName(key.Interface().(string))] = reflectValue.MapIndex(key).Interface() } default: - for _, field := range (&Scope{Value: values}).Fields() { - if !field.IsBlank { + for _, field := range (&Scope{Value: values, db: db}).Fields() { + if !field.IsBlank && (withIgnoredField || !field.IsIgnored) { attrs[field.DBName] = field.Field.Interface() } } @@ -836,28 +907,31 @@ func convertInterfaceToMap(values interface{}) map[string]interface{} { func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[string]interface{}, hasUpdate bool) { if scope.IndirectValue().Kind() != reflect.Struct { - return convertInterfaceToMap(value), true + return convertInterfaceToMap(value, false, scope.db), true } results = map[string]interface{}{} - for key, value := range convertInterfaceToMap(value) { - if field, ok := scope.FieldByName(key); ok && scope.changeableField(field) { - if _, ok := value.(*expr); ok { - hasUpdate = true - results[field.DBName] = value - } else { - err := field.Set(value) - if field.IsNormal { + for key, value := range convertInterfaceToMap(value, true, scope.db) { + if field, ok := scope.FieldByName(key); ok { + if scope.changeableField(field) { + if _, ok := value.(*SqlExpr); ok { hasUpdate = true - if err == ErrUnaddressable { - fmt.Println(err) - results[field.DBName] = value - } else { - results[field.DBName] = field.Field.Interface() + results[field.DBName] = value + } else { + err := field.Set(value) + if field.IsNormal && !field.IsIgnored { + hasUpdate = true + if err == ErrUnaddressable { + results[field.DBName] = value + } else { + results[field.DBName] = field.Field.Interface() + } } } } + } else { + results[key] = value } } return @@ -865,16 +939,22 @@ func (scope *Scope) updatedAttrsWithValues(value interface{}) (results map[strin func (scope *Scope) row() *sql.Row { defer scope.trace(NowFunc()) + + result := &RowQueryResult{} + scope.InstanceSet("row_query_result", result) scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - scope.prepareQuerySQL() - return scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...) + + return result.Row } func (scope *Scope) rows() (*sql.Rows, error) { defer scope.trace(NowFunc()) + + result := &RowsQueryResult{} + scope.InstanceSet("row_query_result", result) scope.callCallbacks(scope.db.parent.callbacks.rowQueries) - scope.prepareQuerySQL() - return scope.SQLDB().Query(scope.SQL, scope.SQLVars...) + + return result.Rows, result.Error } func (scope *Scope) initialize() *Scope { @@ -886,14 +966,38 @@ func (scope *Scope) initialize() *Scope { return scope } +func (scope *Scope) isQueryForColumn(query interface{}, column string) bool { + queryStr := strings.ToLower(fmt.Sprint(query)) + if queryStr == column { + return true + } + + if strings.HasSuffix(queryStr, "as "+column) { + return true + } + + if strings.HasSuffix(queryStr, "as "+scope.Quote(column)) { + return true + } + + return false +} + func (scope *Scope) pluck(column string, value interface{}) *Scope { dest := reflect.Indirect(reflect.ValueOf(value)) - scope.Search.Select(column) if dest.Kind() != reflect.Slice { scope.Err(fmt.Errorf("results should be a slice, not %s", dest.Kind())) return scope } + if dest.Len() > 0 { + dest.Set(reflect.Zero(dest.Type())) + } + + if query, ok := scope.Search.selects["query"]; !ok || !scope.isQueryForColumn(query, column) { + scope.Search.Select(column) + } + rows, err := scope.rows() if scope.Err(err) == nil { defer rows.Close() @@ -902,13 +1006,31 @@ func (scope *Scope) pluck(column string, value interface{}) *Scope { scope.Err(rows.Scan(elem)) dest.Set(reflect.Append(dest, reflect.ValueOf(elem).Elem())) } + + if err := rows.Err(); err != nil { + scope.Err(err) + } } return scope } func (scope *Scope) count(value interface{}) *Scope { - scope.Search.Select("count(*)") - scope.Search.countingQuery = true + if query, ok := scope.Search.selects["query"]; !ok || !countingQueryRegexp.MatchString(fmt.Sprint(query)) { + if len(scope.Search.group) != 0 { + if len(scope.Search.havingConditions) != 0 { + scope.prepareQuerySQL() + scope.Search = &search{} + scope.Search.Select("count(*)") + scope.Search.Table(fmt.Sprintf("( %s ) AS count_table", scope.SQL)) + } else { + scope.Search.Select("count(*) FROM ( SELECT count(*) as name ") + scope.Search.group += " ) AS count_table" + } + } else { + scope.Search.Select("count(*)") + } + } + scope.Search.ignoreOrderQuery = true scope.Err(scope.row().Scan(value)) return scope } @@ -949,15 +1071,9 @@ func (scope *Scope) changeableField(field *Field) bool { return true } -func (scope *Scope) shouldSaveAssociations() bool { - if saveAssociations, ok := scope.Get("gorm:save_associations"); ok && !saveAssociations.(bool) { - return false - } - return true && !scope.HasError() -} - func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { toScope := scope.db.NewScope(value) + tx := scope.db.Set("gorm:association:source", scope.Value) for _, foreignKey := range append(foreignKeys, toScope.typeName()+"Id", scope.typeName()+"Id") { fromField, _ := scope.FieldByName(foreignKey) @@ -967,36 +1083,34 @@ func (scope *Scope) related(value interface{}, foreignKeys ...string) *Scope { if relationship := fromField.Relationship; relationship != nil { if relationship.Kind == "many_to_many" { joinTableHandler := relationship.JoinTableHandler - scope.Err(joinTableHandler.JoinWith(joinTableHandler, toScope.db, scope.Value).Find(value).Error) + scope.Err(joinTableHandler.JoinWith(joinTableHandler, tx, scope.Value).Find(value).Error) } else if relationship.Kind == "belongs_to" { - query := toScope.db for idx, foreignKey := range relationship.ForeignDBNames { if field, ok := scope.FieldByName(foreignKey); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) + tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.AssociationForeignDBNames[idx])), field.Field.Interface()) } } - scope.Err(query.Find(value).Error) + scope.Err(tx.Find(value).Error) } else if relationship.Kind == "has_many" || relationship.Kind == "has_one" { - query := toScope.db for idx, foreignKey := range relationship.ForeignDBNames { if field, ok := scope.FieldByName(relationship.AssociationForeignDBNames[idx]); ok { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) + tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(foreignKey)), field.Field.Interface()) } } if relationship.PolymorphicType != "" { - query = query.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), scope.TableName()) + tx = tx.Where(fmt.Sprintf("%v = ?", scope.Quote(relationship.PolymorphicDBName)), relationship.PolymorphicValue) } - scope.Err(query.Find(value).Error) + scope.Err(tx.Find(value).Error) } } else { sql := fmt.Sprintf("%v = ?", scope.Quote(toScope.PrimaryKey())) - scope.Err(toScope.db.Where(sql, fromField.Field.Interface()).Find(value).Error) + scope.Err(tx.Where(sql, fromField.Field.Interface()).Find(value).Error) } return scope } else if toField != nil { sql := fmt.Sprintf("%v = ?", scope.Quote(toField.DBName)) - scope.Err(toScope.db.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) + scope.Err(tx.Where(sql, scope.PrimaryKeyValue()).Find(value).Error) return scope } } @@ -1011,7 +1125,7 @@ func (scope *Scope) getTableOptions() string { if !ok { return "" } - return tableOptions.(string) + return " " + tableOptions.(string) } func (scope *Scope) createJoinTable(field *StructField) { @@ -1026,7 +1140,8 @@ func (scope *Scope) createJoinTable(field *StructField) { if field, ok := scope.FieldByName(fieldName); ok { foreignKeyStruct := field.clone() foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" + foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") + foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") sqlTypes = append(sqlTypes, scope.Quote(relationship.ForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.ForeignDBNames[idx])) } @@ -1036,13 +1151,14 @@ func (scope *Scope) createJoinTable(field *StructField) { if field, ok := toScope.FieldByName(fieldName); ok { foreignKeyStruct := field.clone() foreignKeyStruct.IsPrimaryKey = false - foreignKeyStruct.TagSettings["IS_JOINTABLE_FOREIGNKEY"] = "true" + foreignKeyStruct.TagSettingsSet("IS_JOINTABLE_FOREIGNKEY", "true") + foreignKeyStruct.TagSettingsDelete("AUTO_INCREMENT") sqlTypes = append(sqlTypes, scope.Quote(relationship.AssociationForeignDBNames[idx])+" "+scope.Dialect().DataTypeOf(foreignKeyStruct)) primaryKeys = append(primaryKeys, scope.Quote(relationship.AssociationForeignDBNames[idx])) } } - scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v)) %s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) + scope.Err(scope.NewDB().Exec(fmt.Sprintf("CREATE TABLE %v (%v, PRIMARY KEY (%v))%s", scope.Quote(joinTable), strings.Join(sqlTypes, ","), strings.Join(primaryKeys, ","), scope.getTableOptions())).Error) } scope.NewDB().Table(joinTable).AutoMigrate(joinTableHandler) } @@ -1077,7 +1193,7 @@ func (scope *Scope) createTable() *Scope { primaryKeyStr = fmt.Sprintf(", PRIMARY KEY (%v)", strings.Join(primaryKeys, ",")) } - scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v) %s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() + scope.Raw(fmt.Sprintf("CREATE TABLE %v (%v %v)%s", scope.QuotedTableName(), strings.Join(tags, ","), primaryKeyStr, scope.getTableOptions())).Exec() scope.autoIndex() return scope @@ -1089,7 +1205,7 @@ func (scope *Scope) dropTable() *Scope { } func (scope *Scope) modifyColumn(column string, typ string) { - scope.Raw(fmt.Sprintf("ALTER TABLE %v MODIFY %v %v", scope.QuotedTableName(), scope.Quote(column), typ)).Exec() + scope.db.AddError(scope.Dialect().ModifyColumn(scope.QuotedTableName(), scope.Quote(column), typ)) } func (scope *Scope) dropColumn(column string) { @@ -1115,8 +1231,8 @@ func (scope *Scope) addIndex(unique bool, indexName string, column ...string) { } func (scope *Scope) addForeignKey(field string, dest string, onDelete string, onUpdate string) { - var keyName = fmt.Sprintf("%s_%s_%s_foreign", scope.TableName(), field, dest) - keyName = regexp.MustCompile("(_*[^a-zA-Z]+_*|_+)").ReplaceAllString(keyName, "_") + // Compatible with old generated key + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") if scope.Dialect().HasForeignKey(scope.TableName(), keyName) { return @@ -1125,6 +1241,22 @@ func (scope *Scope) addForeignKey(field string, dest string, onDelete string, on scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName), scope.quoteIfPossible(field), dest, onDelete, onUpdate)).Exec() } +func (scope *Scope) removeForeignKey(field string, dest string) { + keyName := scope.Dialect().BuildKeyName(scope.TableName(), field, dest, "foreign") + if !scope.Dialect().HasForeignKey(scope.TableName(), keyName) { + return + } + var mysql mysql + var query string + if scope.Dialect().GetName() == mysql.GetName() { + query = `ALTER TABLE %s DROP FOREIGN KEY %s;` + } else { + query = `ALTER TABLE %s DROP CONSTRAINT %s;` + } + + scope.Raw(fmt.Sprintf(query, scope.QuotedTableName(), scope.quoteIfPossible(keyName))).Exec() +} + func (scope *Scope) removeIndex(indexName string) { scope.Dialect().RemoveIndex(scope.TableName(), indexName) } @@ -1155,57 +1287,94 @@ func (scope *Scope) autoIndex() *Scope { var uniqueIndexes = map[string][]string{} for _, field := range scope.GetStructFields() { - if name, ok := field.TagSettings["INDEX"]; ok { - if name == "INDEX" { - name = fmt.Sprintf("idx_%v_%v", scope.TableName(), field.DBName) + if name, ok := field.TagSettingsGet("INDEX"); ok { + names := strings.Split(name, ",") + + for _, name := range names { + if name == "INDEX" || name == "" { + name = scope.Dialect().BuildKeyName("idx", scope.TableName(), field.DBName) + } + name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) + indexes[name] = append(indexes[name], column) } - indexes[name] = append(indexes[name], field.DBName) } - if name, ok := field.TagSettings["UNIQUE_INDEX"]; ok { - if name == "UNIQUE_INDEX" { - name = fmt.Sprintf("uix_%v_%v", scope.TableName(), field.DBName) + if name, ok := field.TagSettingsGet("UNIQUE_INDEX"); ok { + names := strings.Split(name, ",") + + for _, name := range names { + if name == "UNIQUE_INDEX" || name == "" { + name = scope.Dialect().BuildKeyName("uix", scope.TableName(), field.DBName) + } + name, column := scope.Dialect().NormalizeIndexAndColumn(name, field.DBName) + uniqueIndexes[name] = append(uniqueIndexes[name], column) } - uniqueIndexes[name] = append(uniqueIndexes[name], field.DBName) } } for name, columns := range indexes { - scope.NewDB().Model(scope.Value).AddIndex(name, columns...) + if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddIndex(name, columns...); db.Error != nil { + scope.db.AddError(db.Error) + } } for name, columns := range uniqueIndexes { - scope.NewDB().Model(scope.Value).AddUniqueIndex(name, columns...) + if db := scope.NewDB().Table(scope.TableName()).Model(scope.Value).AddUniqueIndex(name, columns...); db.Error != nil { + scope.db.AddError(db.Error) + } } return scope } func (scope *Scope) getColumnAsArray(columns []string, values ...interface{}) (results [][]interface{}) { + resultMap := make(map[string][]interface{}) for _, value := range values { - indirectValue := reflect.ValueOf(value) - for indirectValue.Kind() == reflect.Ptr { - indirectValue = indirectValue.Elem() - } + indirectValue := indirect(reflect.ValueOf(value)) switch indirectValue.Kind() { case reflect.Slice: for i := 0; i < indirectValue.Len(); i++ { var result []interface{} var object = indirect(indirectValue.Index(i)) + var hasValue = false for _, column := range columns { - result = append(result, object.FieldByName(column).Interface()) + field := object.FieldByName(column) + if hasValue || !isBlank(field) { + hasValue = true + } + result = append(result, field.Interface()) + } + + if hasValue { + h := fmt.Sprint(result...) + if _, exist := resultMap[h]; !exist { + resultMap[h] = result + } } - results = append(results, result) } case reflect.Struct: var result []interface{} + var hasValue = false for _, column := range columns { - result = append(result, indirectValue.FieldByName(column).Interface()) + field := indirectValue.FieldByName(column) + if hasValue || !isBlank(field) { + hasValue = true + } + result = append(result, field.Interface()) + } + + if hasValue { + h := fmt.Sprint(result...) + if _, exist := resultMap[h]; !exist { + resultMap[h] = result + } } - results = append(results, result) } } + for _, v := range resultMap { + results = append(results, v) + } return } @@ -1220,6 +1389,7 @@ func (scope *Scope) getColumnAsScope(column string) *Scope { fieldType = fieldType.Elem() } + resultsMap := map[interface{}]bool{} results := reflect.New(reflect.SliceOf(reflect.PtrTo(fieldType))).Elem() for i := 0; i < indirectScopeValue.Len(); i++ { @@ -1227,11 +1397,13 @@ func (scope *Scope) getColumnAsScope(column string) *Scope { if result.Kind() == reflect.Slice { for j := 0; j < result.Len(); j++ { - if elem := result.Index(j); elem.CanAddr() { + if elem := result.Index(j); elem.CanAddr() && resultsMap[elem.Addr()] != true { + resultsMap[elem.Addr()] = true results = reflect.Append(results, elem.Addr()) } } - } else if result.CanAddr() { + } else if result.CanAddr() && resultsMap[result.Addr()] != true { + resultsMap[result.Addr()] = true results = reflect.Append(results, result.Addr()) } } @@ -1244,3 +1416,10 @@ func (scope *Scope) getColumnAsScope(column string) *Scope { } return nil } + +func (scope *Scope) hasConditions() bool { + return !scope.PrimaryKeyZero() || + len(scope.Search.whereConditions) > 0 || + len(scope.Search.orConditions) > 0 || + len(scope.Search.notConditions) > 0 +} diff --git a/vendor/github.com/jinzhu/gorm/search.go b/vendor/github.com/jinzhu/gorm/search.go index 078bd429..52ae2efc 100644 --- a/vendor/github.com/jinzhu/gorm/search.go +++ b/vendor/github.com/jinzhu/gorm/search.go @@ -1,6 +1,8 @@ package gorm -import "fmt" +import ( + "fmt" +) type search struct { db *DB @@ -13,15 +15,15 @@ type search struct { assignAttrs []interface{} selects map[string]interface{} omits []string - orders []string + orders []interface{} preload []searchPreload - offset int - limit int + offset interface{} + limit interface{} group string tableName string raw bool Unscoped bool - countingQuery bool + ignoreOrderQuery bool } type searchPreload struct { @@ -30,7 +32,57 @@ type searchPreload struct { } func (s *search) clone() *search { - clone := *s + clone := search{ + db: s.db, + whereConditions: make([]map[string]interface{}, len(s.whereConditions)), + orConditions: make([]map[string]interface{}, len(s.orConditions)), + notConditions: make([]map[string]interface{}, len(s.notConditions)), + havingConditions: make([]map[string]interface{}, len(s.havingConditions)), + joinConditions: make([]map[string]interface{}, len(s.joinConditions)), + initAttrs: make([]interface{}, len(s.initAttrs)), + assignAttrs: make([]interface{}, len(s.assignAttrs)), + selects: s.selects, + omits: make([]string, len(s.omits)), + orders: make([]interface{}, len(s.orders)), + preload: make([]searchPreload, len(s.preload)), + offset: s.offset, + limit: s.limit, + group: s.group, + tableName: s.tableName, + raw: s.raw, + Unscoped: s.Unscoped, + ignoreOrderQuery: s.ignoreOrderQuery, + } + for i, value := range s.whereConditions { + clone.whereConditions[i] = value + } + for i, value := range s.orConditions { + clone.orConditions[i] = value + } + for i, value := range s.notConditions { + clone.notConditions[i] = value + } + for i, value := range s.havingConditions { + clone.havingConditions[i] = value + } + for i, value := range s.joinConditions { + clone.joinConditions[i] = value + } + for i, value := range s.initAttrs { + clone.initAttrs[i] = value + } + for i, value := range s.assignAttrs { + clone.assignAttrs[i] = value + } + for i, value := range s.omits { + clone.omits[i] = value + } + for i, value := range s.orders { + clone.orders[i] = value + } + for i, value := range s.preload { + clone.preload[i] = value + } return &clone } @@ -59,14 +111,12 @@ func (s *search) Assign(attrs ...interface{}) *search { return s } -func (s *search) Order(value string, reorder ...bool) *search { +func (s *search) Order(value interface{}, reorder ...bool) *search { if len(reorder) > 0 && reorder[0] { - if value != "" { - s.orders = []string{value} - } else { - s.orders = []string{} - } - } else if value != "" { + s.orders = []interface{}{} + } + + if value != nil && value != "" { s.orders = append(s.orders, value) } return s @@ -82,12 +132,12 @@ func (s *search) Omit(columns ...string) *search { return s } -func (s *search) Limit(limit int) *search { +func (s *search) Limit(limit interface{}) *search { s.limit = limit return s } -func (s *search) Offset(offset int) *search { +func (s *search) Offset(offset interface{}) *search { s.offset = offset return s } @@ -97,8 +147,12 @@ func (s *search) Group(query string) *search { return s } -func (s *search) Having(query string, values ...interface{}) *search { - s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) +func (s *search) Having(query interface{}, values ...interface{}) *search { + if val, ok := query.(*SqlExpr); ok { + s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": val.expr, "args": val.args}) + } else { + s.havingConditions = append(s.havingConditions, map[string]interface{}{"query": query, "args": values}) + } return s } diff --git a/vendor/github.com/jinzhu/gorm/test_all.sh b/vendor/github.com/jinzhu/gorm/test_all.sh index 6c5593b3..5cfb3321 100644 --- a/vendor/github.com/jinzhu/gorm/test_all.sh +++ b/vendor/github.com/jinzhu/gorm/test_all.sh @@ -1,5 +1,5 @@ -dialects=("postgres" "mysql" "sqlite") +dialects=("postgres" "mysql" "mssql" "sqlite") for dialect in "${dialects[@]}" ; do - GORM_DIALECT=${dialect} go test + DEBUG=false GORM_DIALECT=${dialect} go test done diff --git a/vendor/github.com/jinzhu/gorm/utils.go b/vendor/github.com/jinzhu/gorm/utils.go index dc69e804..d2ae9465 100644 --- a/vendor/github.com/jinzhu/gorm/utils.go +++ b/vendor/github.com/jinzhu/gorm/utils.go @@ -1,7 +1,6 @@ package gorm import ( - "bytes" "database/sql/driver" "fmt" "reflect" @@ -23,9 +22,12 @@ var NowFunc = func() time.Time { } // Copied from golint -var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UI", "UID", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} +var commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} var commonInitialismsReplacer *strings.Replacer +var goSrcRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*.go`) +var goTestRegexp = regexp.MustCompile(`jinzhu/gorm(@.*)?/.*test.go`) + func init() { var commonInitialismsForReplacer []string for _, initialism := range commonInitialisms { @@ -55,71 +57,16 @@ func newSafeMap() *safeMap { return &safeMap{l: new(sync.RWMutex), m: make(map[string]string)} } -var smap = newSafeMap() - -type strCase bool - -const ( - lower strCase = false - upper strCase = true -) - -// ToDBName convert string to db name -func ToDBName(name string) string { - if v := smap.Get(name); v != "" { - return v - } - - if name == "" { - return "" - } - - var ( - value = commonInitialismsReplacer.Replace(name) - buf = bytes.NewBufferString("") - lastCase, currCase, nextCase strCase - ) - - for i, v := range value[:len(value)-1] { - nextCase = strCase(value[i+1] >= 'A' && value[i+1] <= 'Z') - if i > 0 { - if currCase == upper { - if lastCase == upper && nextCase == upper { - buf.WriteRune(v) - } else { - if value[i-1] != '_' && value[i+1] != '_' { - buf.WriteRune('_') - } - buf.WriteRune(v) - } - } else { - buf.WriteRune(v) - } - } else { - currCase = upper - buf.WriteRune(v) - } - lastCase = currCase - currCase = nextCase - } - - buf.WriteByte(value[len(value)-1]) - - s := strings.ToLower(buf.String()) - smap.Set(name, s) - return s -} - // SQL expression -type expr struct { +type SqlExpr struct { expr string args []interface{} } // Expr generate raw SQL expression, for example: // DB.Model(&product).Update("price", gorm.Expr("price * ? + ?", 2, 100)) -func Expr(expression string, args ...interface{}) *expr { - return &expr{expr: expression, args: args} +func Expr(expression string, args ...interface{}) *SqlExpr { + return &SqlExpr{expr: expression, args: args} } func indirect(reflectValue reflect.Value) reflect.Value { @@ -171,7 +118,7 @@ func toQueryValues(values [][]interface{}) (results []interface{}) { func fileWithLineNum() string { for i := 2; i < 15; i++ { _, file, line, ok := runtime.Caller(i) - if ok && (!regexp.MustCompile(`jinzhu/gorm/.*.go`).MatchString(file) || regexp.MustCompile(`jinzhu/gorm/.*test.go`).MatchString(file)) { + if ok && (!goSrcRegexp.MatchString(file) || goTestRegexp.MatchString(file)) { return fmt.Sprintf("%v:%v", file, line) } } @@ -179,6 +126,21 @@ func fileWithLineNum() string { } func isBlank(value reflect.Value) bool { + switch value.Kind() { + case reflect.String: + return value.Len() == 0 + case reflect.Bool: + return !value.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return value.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return value.Uint() == 0 + case reflect.Float32, reflect.Float64: + return value.Float() == 0 + case reflect.Interface, reflect.Ptr: + return value.IsNil() + } + return reflect.DeepEqual(value.Interface(), reflect.Zero(value.Type()).Interface()) } @@ -244,7 +206,7 @@ func getValueFromFields(value reflect.Value, fieldNames []string) (results []int // as FieldByName could panic if indirectValue := reflect.Indirect(value); indirectValue.IsValid() { for _, fieldName := range fieldNames { - if fieldValue := indirectValue.FieldByName(fieldName); fieldValue.IsValid() { + if fieldValue := reflect.Indirect(indirectValue.FieldByName(fieldName)); fieldValue.IsValid() { result := fieldValue.Interface() if r, ok := result.(driver.Valuer); ok { result, _ = r.Value() diff --git a/vendor/github.com/jinzhu/gorm/wercker.yml b/vendor/github.com/jinzhu/gorm/wercker.yml new file mode 100644 index 00000000..1de947b8 --- /dev/null +++ b/vendor/github.com/jinzhu/gorm/wercker.yml @@ -0,0 +1,149 @@ +# use the default golang container from Docker Hub +box: golang + +services: + - name: mariadb + id: mariadb:latest + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql + id: mysql:latest + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql57 + id: mysql:5.7 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: mysql56 + id: mysql:5.6 + env: + MYSQL_DATABASE: gorm + MYSQL_USER: gorm + MYSQL_PASSWORD: gorm + MYSQL_RANDOM_ROOT_PASSWORD: "yes" + - name: postgres + id: postgres:latest + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres96 + id: postgres:9.6 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres95 + id: postgres:9.5 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres94 + id: postgres:9.4 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: postgres93 + id: postgres:9.3 + env: + POSTGRES_USER: gorm + POSTGRES_PASSWORD: gorm + POSTGRES_DB: gorm + - name: mssql + id: mcmoe/mssqldocker:latest + env: + ACCEPT_EULA: Y + SA_PASSWORD: LoremIpsum86 + MSSQL_DB: gorm + MSSQL_USER: gorm + MSSQL_PASSWORD: LoremIpsum86 + +# The steps that will be executed in the build pipeline +build: + # The steps that will be executed on build + steps: + # Sets the go workspace and places you package + # at the right place in the workspace tree + - setup-go-workspace + + # Gets the dependencies + - script: + name: go get + code: | + cd $WERCKER_SOURCE_DIR + go version + go get -t -v ./... + + # Build the project + - script: + name: go build + code: | + go build ./... + + # Test the project + - script: + name: test sqlite + code: | + go test -race -v ./... + + - script: + name: test mariadb + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mariadb:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... + + - script: + name: test mysql + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... + + - script: + name: test mysql5.7 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql57:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... + + - script: + name: test mysql5.6 + code: | + GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(mysql56:3306)/gorm?charset=utf8&parseTime=True" go test -race ./... + + - script: + name: test postgres + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... + + - script: + name: test postgres96 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres96 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... + + - script: + name: test postgres95 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres95 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... + + - script: + name: test postgres94 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres94 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... + + - script: + name: test postgres93 + code: | + GORM_DIALECT=postgres GORM_DSN="host=postgres93 user=gorm password=gorm DB.name=gorm port=5432 sslmode=disable" go test -race ./... + + - script: + name: codecov + code: | + go test -race -coverprofile=coverage.txt -covermode=atomic ./... + bash <(curl -s https://codecov.io/bash) diff --git a/vendor/github.com/jinzhu/inflection/README.md b/vendor/github.com/jinzhu/inflection/README.md index 4dd0f2d9..a3de3361 100644 --- a/vendor/github.com/jinzhu/inflection/README.md +++ b/vendor/github.com/jinzhu/inflection/README.md @@ -1,8 +1,9 @@ -Inflection -========= +# Inflection Inflection pluralizes and singularizes English nouns +[![wercker status](https://app.wercker.com/status/f8c7432b097d1f4ce636879670be0930/s/master "wercker status")](https://app.wercker.com/project/byKey/f8c7432b097d1f4ce636879670be0930) + ## Basic Usage ```go @@ -37,10 +38,9 @@ inflection.AddPlural("(bu)s$", "${1}ses") # "bus" => "buses" / "BUS" => "BUSES" inflection.AddSingular("(bus)(es)?$", "${1}") # "buses" => "bus" / "Buses" => "Bus" / "BUSES" => "BUS" ``` -## Supporting the project - -[![http://patreon.com/jinzhu](http://patreon_public_assets.s3.amazonaws.com/sized/becomeAPatronBanner.png)](http://patreon.com/jinzhu) +## Contributing +You can help to make the project better, check out [http://gorm.io/contribute.html](http://gorm.io/contribute.html) for things you can do. ## Author diff --git a/vendor/github.com/jinzhu/inflection/wercker.yml b/vendor/github.com/jinzhu/inflection/wercker.yml new file mode 100644 index 00000000..5e6ce981 --- /dev/null +++ b/vendor/github.com/jinzhu/inflection/wercker.yml @@ -0,0 +1,23 @@ +box: golang + +build: + steps: + - setup-go-workspace + + # Gets the dependencies + - script: + name: go get + code: | + go get + + # Build the project + - script: + name: go build + code: | + go build ./... + + # Test the project + - script: + name: go test + code: | + go test ./... diff --git a/vendor/modules.txt b/vendor/modules.txt index 90a2b316..8583f2ad 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -28,13 +28,9 @@ github.com/boltdb/bolt # github.com/bwmarrin/discordgo v0.25.0 ## explicit; go 1.13 github.com/bwmarrin/discordgo -# github.com/denisenkom/go-mssqldb v0.0.0-20190915052044-aa4949efa320 -## explicit; go 1.11 # github.com/dineshappavoo/basex v0.0.0-20160618072718-f35bafba529c ## explicit github.com/dineshappavoo/basex -# github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 -## explicit # github.com/facebookgo/clock v0.0.0-20150410010913-600d898af40a ## explicit github.com/facebookgo/clock @@ -119,15 +115,13 @@ github.com/honeycombio/beeline-go/wrappers/hnynethttp ## explicit; go 1.14 github.com/honeycombio/libhoney-go github.com/honeycombio/libhoney-go/transmission -# github.com/jinzhu/gorm v0.0.0-20160404144928-5174cc5c242a -## explicit +# github.com/jinzhu/gorm v1.9.16 +## explicit; go 1.12 github.com/jinzhu/gorm github.com/jinzhu/gorm/dialects/mysql -# github.com/jinzhu/inflection v0.0.0-20170102125226-1c35d901db3d +# github.com/jinzhu/inflection v1.0.0 ## explicit github.com/jinzhu/inflection -# github.com/jinzhu/now v1.0.1 -## explicit; go 1.12 # github.com/kennygrant/sanitize v1.2.4 ## explicit github.com/kennygrant/sanitize