From b0565e52ebe4dbfbf062fcf72e7e5a07a05f9f50 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Sun, 1 Dec 2024 22:23:22 +0100 Subject: [PATCH 001/110] wip --- connection.go | 28 -------- connectionwitherror.go | 112 +++++++++++++++++++++++++++++++ {impl => db}/update.go | 27 +++++--- {impl => db}/upsert.go | 15 +++-- errors.go | 129 ------------------------------------ impl/connection.go | 20 ------ impl/transaction.go | 20 ------ mockconn/connection.go | 23 +------ mockconn/connection_test.go | 26 +++++--- pqconn/connection.go | 20 ------ pqconn/transaction.go | 20 ------ transaction_test.go | 7 ++ 12 files changed, 163 insertions(+), 284 deletions(-) create mode 100644 connectionwitherror.go rename {impl => db}/update.go (76%) rename {impl => db}/upsert.go (73%) diff --git a/connection.go b/connection.go index 054f43e..da8bf5a 100644 --- a/connection.go +++ b/connection.go @@ -63,34 +63,6 @@ type Connection interface { // Exec executes a query with optional args. Exec(query string, args ...any) error - // Update table rows(s) with values using the where statement with passed in args starting at $1. - Update(table string, values Values, where string, args ...any) error - - // UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 - // and returning a single row with the columns specified in returning argument. - UpdateReturningRow(table string, values Values, returning, where string, args ...any) RowScanner - - // UpdateReturningRows updates table rows with values using the where statement with passed in args starting at $1 - // and returning multiple rows with the columns specified in returning argument. - UpdateReturningRows(table string, values Values, returning, where string, args ...any) RowsScanner - - // UpdateStruct updates a row in a table using the exported fields - // of rowStruct which have a `db` tag that is not "-". - // If restrictToColumns are provided, then only struct fields with a `db` tag - // matching any of the passed column names will be used. - // The struct must have at least one field with a `db` tag value having a ",pk" suffix - // to mark primary key column(s). - UpdateStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error - - // UpsertStruct upserts a row to table using the exported fields - // of rowStruct which have a `db` tag that is not "-". - // If restrictToColumns are provided, then only struct fields with a `db` tag - // matching any of the passed column names will be used. - // The struct must have at least one field with a `db` tag value having a ",pk" suffix - // to mark primary key column(s). - // If inserting conflicts on the primary key column(s), then an update is performed. - UpsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error - // QueryRow queries a single row and returns a RowScanner for the results. QueryRow(query string, args ...any) RowScanner diff --git a/connectionwitherror.go b/connectionwitherror.go new file mode 100644 index 0000000..87e7e71 --- /dev/null +++ b/connectionwitherror.go @@ -0,0 +1,112 @@ +package sqldb + +import ( + "context" + "database/sql" + "fmt" + "time" +) + +// ConnectionWithError returns a dummy Connection +// where all methods return the passed error. +func ConnectionWithError(ctx context.Context, err error) Connection { + if err == nil { + panic("ConnectionWithError needs an error") + } + return connectionWithError{ctx, err} +} + +type connectionWithError struct { + ctx context.Context + err error +} + +func (e connectionWithError) Context() context.Context { return e.ctx } + +func (e connectionWithError) WithContext(ctx context.Context) Connection { + return connectionWithError{ctx: ctx, err: e.err} +} + +func (e connectionWithError) WithStructFieldMapper(namer StructFieldMapper) Connection { + return e +} + +func (e connectionWithError) StructFieldMapper() StructFieldMapper { + return DefaultStructFieldMapping +} + +func (e connectionWithError) Ping(time.Duration) error { + return e.err +} + +func (e connectionWithError) Stats() sql.DBStats { + return sql.DBStats{} +} + +func (e connectionWithError) Config() *Config { + return &Config{Err: e.err} +} + +func (e connectionWithError) Placeholder(paramIndex int) string { + return fmt.Sprintf("?%d", paramIndex+1) +} + +func (e connectionWithError) ValidateColumnName(name string) error { + return e.err +} + +func (e connectionWithError) Exec(query string, args ...any) error { + return e.err +} + +func (e connectionWithError) QueryRow(query string, args ...any) RowScanner { + return RowScannerWithError(e.err) +} + +func (e connectionWithError) QueryRows(query string, args ...any) RowsScanner { + return RowsScannerWithError(e.err) +} + +func (e connectionWithError) IsTransaction() bool { + return false +} + +func (e connectionWithError) TransactionNo() uint64 { + return 0 +} + +func (ce connectionWithError) TransactionOptions() (*sql.TxOptions, bool) { + return nil, false +} + +func (e connectionWithError) Begin(opts *sql.TxOptions, no uint64) (Connection, error) { + return nil, e.err +} + +func (e connectionWithError) Commit() error { + return e.err +} + +func (e connectionWithError) Rollback() error { + return e.err +} + +func (e connectionWithError) Transaction(opts *sql.TxOptions, txFunc func(tx Connection) error) error { + return e.err +} + +func (e connectionWithError) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { + return e.err +} + +func (e connectionWithError) UnlistenChannel(channel string) error { + return e.err +} + +func (e connectionWithError) IsListeningOnChannel(channel string) bool { + return false +} + +func (e connectionWithError) Close() error { + return e.err +} diff --git a/impl/update.go b/db/update.go similarity index 76% rename from impl/update.go rename to db/update.go index e5f6ca1..9c987d6 100644 --- a/impl/update.go +++ b/db/update.go @@ -1,6 +1,7 @@ -package impl +package db import ( + "context" "fmt" "reflect" "slices" @@ -10,22 +11,24 @@ import ( ) // Update table rows(s) with values using the where statement with passed in args starting at $1. -func Update(conn sqldb.Connection, table string, values sqldb.Values, where, argFmt string, args []any) error { +func Update(ctx context.Context, table string, values sqldb.Values, where string, args ...any) error { if len(values) == 0 { return fmt.Errorf("Update table %s: no values passed", table) } + conn := Conn(ctx) query, vals := buildUpdateQuery(table, values, where, args) err := conn.Exec(query, vals...) - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + return WrapNonNilErrorWithQuery(err, query, vals) } // UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 // and returning a single row with the columns specified in returning argument. -func UpdateReturningRow(conn sqldb.Connection, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { +func UpdateReturningRow(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { if len(values) == 0 { return sqldb.RowScannerWithError(fmt.Errorf("UpdateReturningRow table %s: no values passed", table)) } + conn := Conn(ctx) query, vals := buildUpdateQuery(table, values, where, args) query += " RETURNING " + returning @@ -34,17 +37,18 @@ func UpdateReturningRow(conn sqldb.Connection, table string, values sqldb.Values // UpdateReturningRows updates table rows with values using the where statement with passed in args starting at $1 // and returning multiple rows with the columns specified in returning argument. -func UpdateReturningRows(conn sqldb.Connection, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { +func UpdateReturningRows(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { if len(values) == 0 { return sqldb.RowsScannerWithError(fmt.Errorf("UpdateReturningRows table %s: no values passed", table)) } + conn := Conn(ctx) query, vals := buildUpdateQuery(table, values, where, args) query += " RETURNING " + returning return conn.QueryRows(query, vals...) } -func buildUpdateQuery(table string, values sqldb.Values, where string, args []any) (string, []any) { +func buildUpdateQuery(table string, values sqldb.Values, where string, args ...any) (string, []any) { names, vals := values.Sorted() var query strings.Builder @@ -60,12 +64,13 @@ func buildUpdateQuery(table string, values sqldb.Values, where string, args []an return query.String(), append(args, vals...) } -// UpdateStruct updates a row of table using the exported fields +// UpdateStruct updates a row in a table using the exported fields // of rowStruct which have a `db` tag that is not "-". -// Struct fields with a `db` tag matching any of the passed ignoreColumns will not be used. // If restrictToColumns are provided, then only struct fields with a `db` tag // matching any of the passed column names will be used. -func UpdateStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { +// The struct must have at least one field with a `db` tag value having a ",pk" suffix +// to mark primary key column(s). +func UpdateStruct(ctx context.Context, table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { v := reflect.ValueOf(rowStruct) for v.Kind() == reflect.Ptr && !v.IsNil() { v = v.Elem() @@ -107,7 +112,9 @@ func UpdateStruct(conn sqldb.Connection, table string, rowStruct any, namer sqld query := b.String() + conn := Conn(ctx) + err := conn.Exec(query, vals...) - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + return WrapNonNilErrorWithQuery(err, query, vals) } diff --git a/impl/upsert.go b/db/upsert.go similarity index 73% rename from impl/upsert.go rename to db/upsert.go index ebbbfbd..05e82ad 100644 --- a/impl/upsert.go +++ b/db/upsert.go @@ -1,21 +1,24 @@ -package impl +package db import ( + "context" "fmt" "reflect" "slices" "strings" "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/impl" ) // UpsertStruct upserts a row to table using the exported fields // of rowStruct which have a `db` tag that is not "-". -// Struct fields with a `db` tag matching any of the passed ignoreColumns will not be used. // If restrictToColumns are provided, then only struct fields with a `db` tag // matching any of the passed column names will be used. -// If inserting conflicts on pkColumn, then an update of the existing row is performed. -func UpsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqldb.StructFieldMapper, argFmt string, ignoreColumns []sqldb.ColumnFilter) error { +// The struct must have at least one field with a `db` tag value having a ",pk" suffix +// to mark primary key column(s). +// If inserting conflicts on the primary key column(s), then an update is performed. +func UpsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { v := reflect.ValueOf(rowStruct) for v.Kind() == reflect.Ptr && !v.IsNil() { v = v.Elem() @@ -27,7 +30,7 @@ func UpsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqld return fmt.Errorf("UpsertStruct to table %s: expected struct but got %T", table, rowStruct) } - columns, pkCols, vals := ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) + columns, pkCols, vals := impl.ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) if len(pkCols) == 0 { return fmt.Errorf("UpsertStruct of table %s: %s has no mapped primary key field", table, v.Type()) } @@ -57,6 +60,8 @@ func UpsertStruct(conn sqldb.Connection, table string, rowStruct any, namer sqld } query := b.String() + conn := Conn(ctx) + err := conn.Exec(query, vals...) return WrapNonNilErrorWithQuery(err, query, argFmt, vals) diff --git a/errors.go b/errors.go index 3427ca8..f106fa5 100644 --- a/errors.go +++ b/errors.go @@ -1,11 +1,8 @@ package sqldb import ( - "context" "database/sql" "errors" - "fmt" - "time" ) var ( @@ -167,132 +164,6 @@ func (e ErrExclusionViolation) Unwrap() error { return ErrIntegrityConstraintViolation{Constraint: e.Constraint} } -// ConnectionWithError - -// ConnectionWithError returns a dummy Connection -// where all methods return the passed error. -func ConnectionWithError(ctx context.Context, err error) Connection { - if err == nil { - panic("ConnectionWithError needs an error") - } - return connectionWithError{ctx, err} -} - -type connectionWithError struct { - ctx context.Context - err error -} - -func (e connectionWithError) Context() context.Context { return e.ctx } - -func (e connectionWithError) WithContext(ctx context.Context) Connection { - return connectionWithError{ctx: ctx, err: e.err} -} - -func (e connectionWithError) WithStructFieldMapper(namer StructFieldMapper) Connection { - return e -} - -func (e connectionWithError) StructFieldMapper() StructFieldMapper { - return DefaultStructFieldMapping -} - -func (e connectionWithError) Ping(time.Duration) error { - return e.err -} - -func (e connectionWithError) Stats() sql.DBStats { - return sql.DBStats{} -} - -func (e connectionWithError) Config() *Config { - return &Config{Err: e.err} -} - -func (e connectionWithError) Placeholder(paramIndex int) string { - return fmt.Sprintf("$%d", paramIndex+1) -} - -func (e connectionWithError) ValidateColumnName(name string) error { - return e.err -} - -func (e connectionWithError) Exec(query string, args ...any) error { - return e.err -} - -func (e connectionWithError) Update(table string, values Values, where string, args ...any) error { - return e.err -} - -func (e connectionWithError) UpdateReturningRow(table string, values Values, returning, where string, args ...any) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) UpdateReturningRows(table string, values Values, returning, where string, args ...any) RowsScanner { - return RowsScannerWithError(e.err) -} - -func (e connectionWithError) UpdateStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) UpsertStruct(table string, rowStruct any, ignoreColumns ...ColumnFilter) error { - return e.err -} - -func (e connectionWithError) QueryRow(query string, args ...any) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) QueryRows(query string, args ...any) RowsScanner { - return RowsScannerWithError(e.err) -} - -func (e connectionWithError) IsTransaction() bool { - return false -} - -func (e connectionWithError) TransactionNo() uint64 { - return 0 -} - -func (ce connectionWithError) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (e connectionWithError) Begin(opts *sql.TxOptions, no uint64) (Connection, error) { - return nil, e.err -} - -func (e connectionWithError) Commit() error { - return e.err -} - -func (e connectionWithError) Rollback() error { - return e.err -} - -func (e connectionWithError) Transaction(opts *sql.TxOptions, txFunc func(tx Connection) error) error { - return e.err -} - -func (e connectionWithError) ListenOnChannel(channel string, onNotify OnNotifyFunc, onUnlisten OnUnlistenFunc) error { - return e.err -} - -func (e connectionWithError) UnlistenChannel(channel string) error { - return e.err -} - -func (e connectionWithError) IsListeningOnChannel(channel string) bool { - return false -} - -func (e connectionWithError) Close() error { - return e.err -} - // RowScannerWithError // RowScannerWithError returns a dummy RowScanner diff --git a/impl/connection.go b/impl/connection.go index f461d96..9d495d8 100644 --- a/impl/connection.go +++ b/impl/connection.go @@ -91,26 +91,6 @@ func (conn *connection) Exec(query string, args ...any) error { return WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) } -func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { - return Update(conn, table, values, where, conn.argFmt, args) -} - -func (conn *connection) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { diff --git a/impl/transaction.go b/impl/transaction.go index d4fed91..a1f5cfe 100644 --- a/impl/transaction.go +++ b/impl/transaction.go @@ -72,26 +72,6 @@ func (conn *transaction) Exec(query string, args ...any) error { return WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) } -func (conn *transaction) Update(table string, values sqldb.Values, where string, args ...any) error { - return Update(conn, table, values, where, conn.parent.argFmt, args) -} - -func (conn *transaction) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.parent.argFmt, ignoreColumns) -} - -func (conn *transaction) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.parent.argFmt, ignoreColumns) -} - func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { diff --git a/mockconn/connection.go b/mockconn/connection.go index e25c4d7..d92f406 100644 --- a/mockconn/connection.go +++ b/mockconn/connection.go @@ -8,10 +8,9 @@ import ( "time" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" ) -var DefaultArgFmt = "$%d" +var DefaultArgFmt = "?%d" func New(ctx context.Context, queryWriter io.Writer, rowsProvider RowsProvider) sqldb.Connection { return &connection{ @@ -91,26 +90,6 @@ func (conn *connection) Exec(query string, args ...any) error { return nil } -func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { - return impl.Update(conn, table, values, where, conn.argFmt, args) -} - -func (conn *connection) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return impl.UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return impl.UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - -func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, conn.argFmt, ignoreColumns) -} - func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { if conn.ctx.Err() != nil { return sqldb.RowScannerWithError(conn.ctx.Err()) diff --git a/mockconn/connection_test.go b/mockconn/connection_test.go index 9842bf2..7eaf762 100644 --- a/mockconn/connection_test.go +++ b/mockconn/connection_test.go @@ -139,6 +139,7 @@ func TestUpdateQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), conn) str := "Hello World!" values := sqldb.Values{ @@ -153,13 +154,13 @@ func TestUpdateQuery(t *testing.T) { } expected := `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1` - err := conn.Update("public.table", values, "id = $1", 1) + err := db.Update(ctx, "public.table", values, "id = $1", 1) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `UPDATE public.table SET "bool"=$3,"bools"=$4,"created_at"=$5,"int"=$6,"nil_ptr"=$7,"str"=$8,"str_ptr"=$9,"untagged_field"=$10 WHERE a = $1 AND b = $2` - err = conn.Update("public.table", values, "a = $1 AND b = $2", 1, 2) + err = db.Update(ctx, "public.table", values, "a = $1 AND b = $2", 1, 2) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } @@ -168,6 +169,7 @@ func TestUpdateReturningQuery(t *testing.T) { queryOutput := bytes.NewBuffer(nil) naming := &sqldb.TaggedStructFieldMapping{NameTag: "db", Ignore: "-", UntaggedNameFunc: sqldb.ToSnakeCase} conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), conn) str := "Hello World!" values := sqldb.Values{ @@ -182,13 +184,13 @@ func TestUpdateReturningQuery(t *testing.T) { } expected := `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1 RETURNING *` - err := conn.UpdateReturningRow("public.table", values, "*", "id = $1", 1).Scan() + err := db.UpdateReturningRow(ctx, "public.table", values, "*", "id = $1", 1).Scan() assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1 RETURNING created_at,untagged_field` - err = conn.UpdateReturningRows("public.table", values, "created_at,untagged_field", "id = $1", 1, 2).ScanSlice(nil) + err = db.UpdateReturningRows(ctx, "public.table", values, "created_at,untagged_field", "id = $1", 1, 2).ScanSlice(nil) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } @@ -204,23 +206,24 @@ func TestUpdateStructQuery(t *testing.T) { UntaggedNameFunc: sqldb.ToSnakeCase, } conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), conn) row := new(testRow) expected := `UPDATE public.table SET "int"=$2,"bool"=$3,"str"=$4,"str_ptr"=$5,"nil_ptr"=$6,"untagged_field"=$7,"created_at"=$8,"bools"=$9 WHERE "id"=$1` - err := conn.UpdateStruct("public.table", row) + err := db.UpdateStruct(ctx, "public.table", row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `UPDATE public.table SET "bool"=$2,"str"=$3,"created_at"=$4 WHERE "id"=$1` - err = conn.UpdateStruct("public.table", row, sqldb.OnlyColumns("id", "bool", "str", "created_at")) + err = db.UpdateStruct(ctx, "public.table", row, sqldb.OnlyColumns("id", "bool", "str", "created_at")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() expected = `UPDATE public.table SET "int"=$2,"bool"=$3,"str_ptr"=$4,"nil_ptr"=$5,"created_at"=$6 WHERE "id"=$1` - err = conn.UpdateStruct("public.table", row, sqldb.IgnoreColumns("untagged_field", "str", "bools")) + err = db.UpdateStruct(ctx, "public.table", row, sqldb.IgnoreColumns("untagged_field", "str", "bools")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } @@ -236,12 +239,13 @@ func TestUpsertStructQuery(t *testing.T) { UntaggedNameFunc: sqldb.ToSnakeCase, } conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), conn) row := new(testRow) expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9)` + ` ON CONFLICT("id") DO UPDATE SET "int"=$2,"bool"=$3,"str"=$4,"str_ptr"=$5,"nil_ptr"=$6,"untagged_field"=$7,"created_at"=$8,"bools"=$9` - err := conn.UpsertStruct("public.table", row) + err := db.UpsertStruct(ctx, "public.table", row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } @@ -265,11 +269,12 @@ func TestUpsertStructMultiPKQuery(t *testing.T) { UntaggedNameFunc: sqldb.ToSnakeCase, } conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), conn) row := new(multiPrimaryKeyRow) expected := `INSERT INTO public.multi_pk("first_id","second_id","third_id","created_at") VALUES($1,$2,$3,$4) ON CONFLICT("first_id","second_id","third_id") DO UPDATE SET "created_at"=$4` - err := conn.UpsertStruct("public.multi_pk", row) + err := db.UpsertStruct(ctx, "public.multi_pk", row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } @@ -285,11 +290,12 @@ func TestUpdateStructMultiPKQuery(t *testing.T) { UntaggedNameFunc: sqldb.ToSnakeCase, } conn := New(context.Background(), queryOutput, nil).WithStructFieldMapper(naming) + ctx := db.ContextWithConn(context.Background(), conn) row := new(multiPrimaryKeyRow) expected := `UPDATE public.multi_pk SET "created_at"=$4 WHERE "first_id"=$1 AND "second_id"=$2 AND "third_id"=$3` - err := conn.UpdateStruct("public.multi_pk", row) + err := db.UpdateStruct(ctx, "public.multi_pk", row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } diff --git a/pqconn/connection.go b/pqconn/connection.go index 55a801d..29bf405 100644 --- a/pqconn/connection.go +++ b/pqconn/connection.go @@ -119,26 +119,6 @@ func (conn *connection) Exec(query string, args ...any) error { return nil } -func (conn *connection) Update(table string, values sqldb.Values, where string, args ...any) error { - return impl.Update(conn, table, values, where, argFmt, args) -} - -func (conn *connection) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return impl.UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return impl.UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *connection) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { impl.WrapArrayArgs(args) rows, err := conn.db.QueryContext(conn.ctx, query, args...) diff --git a/pqconn/transaction.go b/pqconn/transaction.go index 78bc987..68746ad 100644 --- a/pqconn/transaction.go +++ b/pqconn/transaction.go @@ -72,26 +72,6 @@ func (conn *transaction) Exec(query string, args ...any) error { return impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) } -func (conn *transaction) Update(table string, values sqldb.Values, where string, args ...any) error { - return impl.Update(conn, table, values, where, argFmt, args) -} - -func (conn *transaction) UpdateReturningRow(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - return impl.UpdateReturningRow(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateReturningRows(table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - return impl.UpdateReturningRows(conn, table, values, returning, where, args) -} - -func (conn *transaction) UpdateStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpdateStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - -func (conn *transaction) UpsertStruct(table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { - return impl.UpsertStruct(conn, table, rowStruct, conn.structFieldNamer, argFmt, ignoreColumns) -} - func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { impl.WrapArrayArgs(args) rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) diff --git a/transaction_test.go b/transaction_test.go index c62c57c..88b4615 100644 --- a/transaction_test.go +++ b/transaction_test.go @@ -120,3 +120,10 @@ func TestCheckTxOptionsCompatibility(t *testing.T) { }) } } + +func TestNextTransactionNo(t *testing.T) { + // Always returns >= 1 + if NextTransactionNo() < 1 { + t.Fatal("NextTransactionNo() < 1") + } +} From 844985f25fae02bbf7843bb12c76dea78266f5bc Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Fri, 13 Dec 2024 14:19:21 +0100 Subject: [PATCH 002/110] remove Connection.IsTransaction --- connection.go | 3 --- connectionwitherror.go | 4 ---- db/conn.go | 2 +- db/transaction.go | 8 ++++---- db/transaction_test.go | 8 ++++---- db/update.go | 18 ++++++++++++------ db/upsert.go | 14 ++++++++------ impl/connection.go | 4 ---- impl/transaction.go | 4 ---- mockconn/connection.go | 4 ---- mockconn/transaction.go | 4 ---- pqconn/connection.go | 4 ---- pqconn/transaction.go | 4 ---- 13 files changed, 29 insertions(+), 52 deletions(-) diff --git a/connection.go b/connection.go index da8bf5a..4707b02 100644 --- a/connection.go +++ b/connection.go @@ -69,9 +69,6 @@ type Connection interface { // QueryRows queries multiple rows and returns a RowsScanner for the results. QueryRows(query string, args ...any) RowsScanner - // IsTransaction returns if the connection is a transaction - IsTransaction() bool - // TransactionNo returns the globally unique number of the transaction // or zero if the connection is not a transaction. // Implementations should use the package function NextTransactionNo diff --git a/connectionwitherror.go b/connectionwitherror.go index 87e7e71..724096f 100644 --- a/connectionwitherror.go +++ b/connectionwitherror.go @@ -67,10 +67,6 @@ func (e connectionWithError) QueryRows(query string, args ...any) RowsScanner { return RowsScannerWithError(e.err) } -func (e connectionWithError) IsTransaction() bool { - return false -} - func (e connectionWithError) TransactionNo() uint64 { return 0 } diff --git a/db/conn.go b/db/conn.go index 281e248..4649947 100644 --- a/db/conn.go +++ b/db/conn.go @@ -50,5 +50,5 @@ func ContextWithConn(ctx context.Context, conn sqldb.Connection) context.Context // or the default connection if the context has none, // is a transaction. func IsTransaction(ctx context.Context) bool { - return Conn(ctx).IsTransaction() + return Conn(ctx).TransactionNo() != 0 } diff --git a/db/transaction.go b/db/transaction.go index fb85419..0d476ed 100644 --- a/db/transaction.go +++ b/db/transaction.go @@ -19,7 +19,7 @@ func ValidateWithinTransaction(ctx context.Context) error { if err := conn.Config().Err; err != nil { return err } - if !conn.IsTransaction() { + if conn.TransactionNo() == 0 { return sqldb.ErrNotWithinTransaction } return nil @@ -32,7 +32,7 @@ func ValidateNotWithinTransaction(ctx context.Context) error { if err := conn.Config().Err; err != nil { return err } - if conn.IsTransaction() { + if conn.TransactionNo() != 0 { return sqldb.ErrWithinTransaction } return nil @@ -105,7 +105,7 @@ func Transaction(ctx context.Context, txFunc func(context.Context) error) error // Because of the retryable nature, please be careful with the size of the transaction and the retry cost. func SerializedTransaction(ctx context.Context, txFunc func(context.Context) error) error { // Pass nested serialized transactions through - if Conn(ctx).IsTransaction() { + if IsTransaction(ctx) { if ctx.Value(&serializedTransactionCtxKey) == nil { return errors.New("SerializedTransaction called from within a non-serialized transaction") } @@ -166,7 +166,7 @@ func TransactionReadOnly(ctx context.Context, txFunc func(context.Context) error // they should behandled by the parent Transaction function. func TransactionSavepoint(ctx context.Context, txFunc func(context.Context) error) error { conn := Conn(ctx) - if !conn.IsTransaction() { + if conn.TransactionNo() == 0 { // If not already in a transaction, then execute txFunc // within a as transaction instead of using savepoints: return Transaction(ctx, txFunc) diff --git a/db/transaction_test.go b/db/transaction_test.go index a6e16ac..5333ee0 100644 --- a/db/transaction_test.go +++ b/db/transaction_test.go @@ -13,7 +13,7 @@ func TestSerializedTransaction(t *testing.T) { globalConn = mockconn.New(context.Background(), os.Stdout, nil) expectSerialized := func(ctx context.Context) error { - if !Conn(ctx).IsTransaction() { + if !IsTransaction(ctx) { panic("not in transaction") } if ctx.Value(&serializedTransactionCtxKey) == nil { @@ -23,7 +23,7 @@ func TestSerializedTransaction(t *testing.T) { } expectSerializedWithError := func(ctx context.Context) error { - if !Conn(ctx).IsTransaction() { + if !IsTransaction(ctx) { panic("not in transaction") } if ctx.Value(&serializedTransactionCtxKey) == nil { @@ -67,7 +67,7 @@ func TestTransaction(t *testing.T) { globalConn = mockconn.New(context.Background(), os.Stdout, nil) expectNonSerialized := func(ctx context.Context) error { - if !Conn(ctx).IsTransaction() { + if !IsTransaction(ctx) { panic("not in transaction") } if ctx.Value(&serializedTransactionCtxKey) != nil { @@ -77,7 +77,7 @@ func TestTransaction(t *testing.T) { } expectNonSerializedWithError := func(ctx context.Context) error { - if !Conn(ctx).IsTransaction() { + if !IsTransaction(ctx) { panic("not in transaction") } if ctx.Value(&serializedTransactionCtxKey) != nil { diff --git a/db/update.go b/db/update.go index 9c987d6..d97a260 100644 --- a/db/update.go +++ b/db/update.go @@ -8,6 +8,7 @@ import ( "strings" sqldb "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/impl" ) // Update table rows(s) with values using the where statement with passed in args starting at $1. @@ -19,7 +20,10 @@ func Update(ctx context.Context, table string, values sqldb.Values, where string query, vals := buildUpdateQuery(table, values, where, args) err := conn.Exec(query, vals...) - return WrapNonNilErrorWithQuery(err, query, vals) + if err != nil { + return wrapErrorWithQuery(err, query, vals, conn) + } + return nil } // UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 @@ -82,7 +86,9 @@ func UpdateStruct(ctx context.Context, table string, rowStruct any, ignoreColumn return fmt.Errorf("UpdateStruct of table %s: expected struct but got %T", table, rowStruct) } - columns, pkCols, vals := ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) + conn := Conn(ctx) + + columns, pkCols, vals := impl.ReflectStructValues(v, conn.StructFieldMapper(), append(ignoreColumns, sqldb.IgnoreReadOnly)) if len(pkCols) == 0 { return fmt.Errorf("UpdateStruct of table %s: %s has no mapped primary key field", table, v.Type()) } @@ -112,9 +118,9 @@ func UpdateStruct(ctx context.Context, table string, rowStruct any, ignoreColumn query := b.String() - conn := Conn(ctx) - err := conn.Exec(query, vals...) - - return WrapNonNilErrorWithQuery(err, query, vals) + if err != nil { + return wrapErrorWithQuery(err, query, vals, conn) + } + return nil } diff --git a/db/upsert.go b/db/upsert.go index 05e82ad..fe0e392 100644 --- a/db/upsert.go +++ b/db/upsert.go @@ -30,13 +30,15 @@ func UpsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumn return fmt.Errorf("UpsertStruct to table %s: expected struct but got %T", table, rowStruct) } - columns, pkCols, vals := impl.ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) + conn := Conn(ctx) + + columns, pkCols, vals := impl.ReflectStructValues(v, conn.StructFieldMapper(), append(ignoreColumns, sqldb.IgnoreReadOnly)) if len(pkCols) == 0 { return fmt.Errorf("UpsertStruct of table %s: %s has no mapped primary key field", table, v.Type()) } var b strings.Builder - writeInsertQuery(&b, table, argFmt, columns) + writeInsertQuery(&b, table, columns, conn) b.WriteString(` ON CONFLICT(`) for i, pkCol := range pkCols { if i > 0 { @@ -60,9 +62,9 @@ func UpsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumn } query := b.String() - conn := Conn(ctx) - err := conn.Exec(query, vals...) - - return WrapNonNilErrorWithQuery(err, query, argFmt, vals) + if err != nil { + return wrapErrorWithQuery(err, query, vals, conn) + } + return nil } diff --git a/impl/connection.go b/impl/connection.go index 9d495d8..e3e5348 100644 --- a/impl/connection.go +++ b/impl/connection.go @@ -109,10 +109,6 @@ func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { return NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, conn.argFmt, args) } -func (conn *connection) IsTransaction() bool { - return false -} - func (conn *connection) TransactionNo() uint64 { return 0 } diff --git a/impl/transaction.go b/impl/transaction.go index a1f5cfe..0ce15b2 100644 --- a/impl/transaction.go +++ b/impl/transaction.go @@ -90,10 +90,6 @@ func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner return NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, conn.parent.argFmt, args) } -func (conn *transaction) IsTransaction() bool { - return true -} - func (conn *transaction) TransactionNo() uint64 { return conn.no } diff --git a/mockconn/connection.go b/mockconn/connection.go index d92f406..833c49b 100644 --- a/mockconn/connection.go +++ b/mockconn/connection.go @@ -116,10 +116,6 @@ func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { return conn.rowsProvider.QueryRows(conn.structFieldNamer, query, args...) } -func (conn *connection) IsTransaction() bool { - return false -} - func (conn *connection) TransactionNo() uint64 { return 0 } diff --git a/mockconn/transaction.go b/mockconn/transaction.go index 5611e23..1a4216f 100644 --- a/mockconn/transaction.go +++ b/mockconn/transaction.go @@ -27,10 +27,6 @@ func (conn transaction) WithContext(ctx context.Context) sqldb.Connection { } } -func (conn transaction) IsTransaction() bool { - return true -} - func (conn transaction) TransactionNo() uint64 { return conn.no } diff --git a/pqconn/connection.go b/pqconn/connection.go index 29bf405..0746ef5 100644 --- a/pqconn/connection.go +++ b/pqconn/connection.go @@ -139,10 +139,6 @@ func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { return impl.NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, argFmt, args) } -func (conn *connection) IsTransaction() bool { - return false -} - func (conn *connection) TransactionNo() uint64 { return 0 } diff --git a/pqconn/transaction.go b/pqconn/transaction.go index 68746ad..4a5ed85 100644 --- a/pqconn/transaction.go +++ b/pqconn/transaction.go @@ -92,10 +92,6 @@ func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner return impl.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, argFmt, args) } -func (conn *transaction) IsTransaction() bool { - return true -} - func (conn *transaction) TransactionNo() uint64 { return conn.no } From c475b4c14dd6167fa4e701a5cac3dd31d3aa926b Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Fri, 13 Dec 2024 15:40:02 +0100 Subject: [PATCH 003/110] ListenerConnection and fix tests --- connection.go | 12 ++++++--- db/listen.go | 44 ++++++++++++++++++++++++++++++++ db/update.go | 18 ++++++------- db/upsert.go | 4 +-- impl/connection.go | 13 ---------- impl/transaction.go | 14 ----------- mockconn/connection_test.go | 50 ++++++++++++++++++++----------------- 7 files changed, 90 insertions(+), 65 deletions(-) create mode 100644 db/listen.go diff --git a/connection.go b/connection.go index 4707b02..1a5f584 100644 --- a/connection.go +++ b/connection.go @@ -99,6 +99,14 @@ type Connection interface { // is not within a transaction. Rollback() error + // Close the connection. + // Transactions will be rolled back. + Close() error +} + +type ListenerConnection interface { + Connection + // ListenOnChannel will call onNotify for every channel notification // and onUnlisten if the channel gets unlistened // or the listener connection gets closed for some reason. @@ -115,8 +123,4 @@ type Connection interface { // IsListeningOnChannel returns if a channel is listened to. IsListeningOnChannel(channel string) bool - - // Close the connection. - // Transactions will be rolled back. - Close() error } diff --git a/db/listen.go b/db/listen.go new file mode 100644 index 0000000..3c82424 --- /dev/null +++ b/db/listen.go @@ -0,0 +1,44 @@ +package db + +import ( + "context" + "errors" + "fmt" + + "github.com/domonda/go-sqldb" +) + +// ListenOnChannel will call onNotify for every channel notification +// and onUnlisten if the channel gets unlistened +// or the listener connection gets closed for some reason. +// It is valid to pass nil for onNotify or onUnlisten to not get those callbacks. +// Note that the callbacks are called in sequence from a single go routine, +// so callbacks should offload long running or potentially blocking code to other go routines. +// Panics from callbacks will be recovered and logged. +func ListenOnChannel(ctx context.Context, channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) error { + conn, ok := Conn(ctx).(sqldb.ListenerConnection) + if !ok { + return fmt.Errorf("notifications %w", errors.ErrUnsupported) + } + return conn.ListenOnChannel(channel, onNotify, onUnlisten) +} + +// UnlistenChannel will stop listening on the channel. +// An error is returned, when the channel was not listened to +// or the listener connection is closed. +func UnlistenChannel(ctx context.Context, channel string) error { + conn, ok := Conn(ctx).(sqldb.ListenerConnection) + if !ok { + return fmt.Errorf("notifications %w", errors.ErrUnsupported) + } + return conn.UnlistenChannel(channel) +} + +// IsListeningOnChannel returns if a channel is listened to. +func IsListeningOnChannel(ctx context.Context, channel string) bool { + conn, ok := Conn(ctx).(sqldb.ListenerConnection) + if !ok { + return false + } + return conn.IsListeningOnChannel(channel) +} diff --git a/db/update.go b/db/update.go index d97a260..84b3a21 100644 --- a/db/update.go +++ b/db/update.go @@ -18,7 +18,7 @@ func Update(ctx context.Context, table string, values sqldb.Values, where string } conn := Conn(ctx) - query, vals := buildUpdateQuery(table, values, where, args) + query, vals := buildUpdateQuery(table, values, where, args, conn) err := conn.Exec(query, vals...) if err != nil { return wrapErrorWithQuery(err, query, vals, conn) @@ -34,7 +34,7 @@ func UpdateReturningRow(ctx context.Context, table string, values sqldb.Values, } conn := Conn(ctx) - query, vals := buildUpdateQuery(table, values, where, args) + query, vals := buildUpdateQuery(table, values, where, args, conn) query += " RETURNING " + returning return conn.QueryRow(query, vals...) } @@ -47,21 +47,21 @@ func UpdateReturningRows(ctx context.Context, table string, values sqldb.Values, } conn := Conn(ctx) - query, vals := buildUpdateQuery(table, values, where, args) + query, vals := buildUpdateQuery(table, values, where, args, conn) query += " RETURNING " + returning return conn.QueryRows(query, vals...) } -func buildUpdateQuery(table string, values sqldb.Values, where string, args ...any) (string, []any) { +func buildUpdateQuery(table string, values sqldb.Values, where string, args []any, argFmt sqldb.PlaceholderFormatter) (string, []any) { names, vals := values.Sorted() var query strings.Builder - fmt.Fprintf(&query, `UPDATE %s SET `, table) + fmt.Fprintf(&query, `UPDATE %s SET`, table) for i := range names { if i > 0 { query.WriteByte(',') } - fmt.Fprintf(&query, `"%s"=$%d`, names[i], 1+len(args)+i) + fmt.Fprintf(&query, ` "%s"=%s`, names[i], argFmt.Placeholder(len(args)+i)) } fmt.Fprintf(&query, ` WHERE %s`, where) @@ -94,7 +94,7 @@ func UpdateStruct(ctx context.Context, table string, rowStruct any, ignoreColumn } var b strings.Builder - fmt.Fprintf(&b, `UPDATE %s SET `, table) + fmt.Fprintf(&b, `UPDATE %s SET`, table) first := true for i := range columns { if slices.Contains(pkCols, i) { @@ -105,7 +105,7 @@ func UpdateStruct(ctx context.Context, table string, rowStruct any, ignoreColumn } else { b.WriteByte(',') } - fmt.Fprintf(&b, `"%s"=$%d`, columns[i], i+1) + fmt.Fprintf(&b, ` "%s"=%s`, columns[i], conn.Placeholder(i)) } b.WriteString(` WHERE `) @@ -113,7 +113,7 @@ func UpdateStruct(ctx context.Context, table string, rowStruct any, ignoreColumn if i > 0 { b.WriteString(` AND `) } - fmt.Fprintf(&b, `"%s"=$%d`, columns[pkCol], i+1) + fmt.Fprintf(&b, `"%s"=%s`, columns[pkCol], conn.Placeholder(i)) } query := b.String() diff --git a/db/upsert.go b/db/upsert.go index fe0e392..bded8c2 100644 --- a/db/upsert.go +++ b/db/upsert.go @@ -47,7 +47,7 @@ func UpsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumn fmt.Fprintf(&b, `"%s"`, columns[pkCol]) } - b.WriteString(`) DO UPDATE SET `) + b.WriteString(`) DO UPDATE SET`) first := true for i := range columns { if slices.Contains(pkCols, i) { @@ -58,7 +58,7 @@ func UpsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumn } else { b.WriteByte(',') } - fmt.Fprintf(&b, `"%s"=$%d`, columns[i], i+1) + fmt.Fprintf(&b, ` "%s"=%s`, columns[i], conn.Placeholder(i)) } query := b.String() diff --git a/impl/connection.go b/impl/connection.go index e3e5348..d68c4ad 100644 --- a/impl/connection.go +++ b/impl/connection.go @@ -3,7 +3,6 @@ package impl import ( "context" "database/sql" - "errors" "fmt" "time" @@ -133,18 +132,6 @@ func (conn *connection) Rollback() error { return sqldb.ErrNotWithinTransaction } -func (conn *connection) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { - return fmt.Errorf("notifications %w", errors.ErrUnsupported) -} - -func (conn *connection) UnlistenChannel(channel string) (err error) { - return fmt.Errorf("notifications %w", errors.ErrUnsupported) -} - -func (conn *connection) IsListeningOnChannel(channel string) bool { - return false -} - func (conn *connection) Close() error { return conn.db.Close() } diff --git a/impl/transaction.go b/impl/transaction.go index 0ce15b2..959348b 100644 --- a/impl/transaction.go +++ b/impl/transaction.go @@ -3,8 +3,6 @@ package impl import ( "context" "database/sql" - "errors" - "fmt" "time" "github.com/domonda/go-sqldb" @@ -114,18 +112,6 @@ func (conn *transaction) Rollback() error { return conn.tx.Rollback() } -func (conn *transaction) ListenOnChannel(channel string, onNotify sqldb.OnNotifyFunc, onUnlisten sqldb.OnUnlistenFunc) (err error) { - return fmt.Errorf("notifications %w", errors.ErrUnsupported) -} - -func (conn *transaction) UnlistenChannel(channel string) (err error) { - return fmt.Errorf("notifications %w", errors.ErrUnsupported) -} - -func (conn *transaction) IsListeningOnChannel(channel string) bool { - return false -} - func (conn *transaction) Close() error { return conn.Rollback() } diff --git a/mockconn/connection_test.go b/mockconn/connection_test.go index 7eaf762..cece73e 100644 --- a/mockconn/connection_test.go +++ b/mockconn/connection_test.go @@ -52,13 +52,13 @@ func TestInsertQuery(t *testing.T) { "bools": pq.BoolArray{true, false}, } - expected := `INSERT INTO public.table("bool","bools","created_at","id","int","nil_ptr","str","str_ptr","untagged_field") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9)` + expected := `INSERT INTO public.table("bool","bools","created_at","id","int","nil_ptr","str","str_ptr","untagged_field") VALUES(?1,?2,?3,?4,?5,?6,?7,?8,?9)` err := db.Insert(ctx, "public.table", values) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `INSERT INTO public.table("bool","bools","created_at","id","int","nil_ptr","str","str_ptr","untagged_field") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9) ON CONFLICT (id) DO NOTHING RETURNING TRUE` + expected = `INSERT INTO public.table("bool","bools","created_at","id","int","nil_ptr","str","str_ptr","untagged_field") VALUES(?1,?2,?3,?4,?5,?6,?7,?8,?9) ON CONFLICT (id) DO NOTHING RETURNING TRUE` inserted, err := db.InsertUnique(ctx, "public.table", values, "id") assert.NoError(t, err) assert.True(t, inserted) @@ -80,19 +80,19 @@ func TestInsertStructQuery(t *testing.T) { row := new(testRow) - expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9)` + expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES(?1,?2,?3,?4,?5,?6,?7,?8,?9)` err := db.InsertStruct(ctx, "public.table", row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `INSERT INTO public.table("id","untagged_field","bools") VALUES($1,$2,$3)` + expected = `INSERT INTO public.table("id","untagged_field","bools") VALUES(?1,?2,?3)` err = db.InsertStruct(ctx, "public.table", row, sqldb.OnlyColumns("id", "untagged_field", "bools")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `INSERT INTO public.table("id","int","bool","str","str_ptr","bools") VALUES($1,$2,$3,$4,$5,$6)` + expected = `INSERT INTO public.table("id","int","bool","str","str_ptr","bools") VALUES(?1,?2,?3,?4,?5,?6)` err = db.InsertStruct(ctx, "public.table", row, sqldb.IgnoreColumns("nil_ptr", "untagged_field", "created_at")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) @@ -114,21 +114,21 @@ func TestInsertUniqueStructQuery(t *testing.T) { row := new(testRow) - expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9) ON CONFLICT (id) DO NOTHING RETURNING TRUE` + expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES(?1,?2,?3,?4,?5,?6,?7,?8,?9) ON CONFLICT (id) DO NOTHING RETURNING TRUE` inserted, err := db.InsertUniqueStruct(ctx, "public.table", row, "(id)") assert.NoError(t, err) assert.True(t, inserted) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `INSERT INTO public.table("id","untagged_field","bools") VALUES($1,$2,$3) ON CONFLICT (id, untagged_field) DO NOTHING RETURNING TRUE` + expected = `INSERT INTO public.table("id","untagged_field","bools") VALUES(?1,?2,?3) ON CONFLICT (id, untagged_field) DO NOTHING RETURNING TRUE` inserted, err = db.InsertUniqueStruct(ctx, "public.table", row, "(id, untagged_field)", sqldb.OnlyColumns("id", "untagged_field", "bools")) assert.NoError(t, err) assert.True(t, inserted) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `INSERT INTO public.table("id","int","bool","str","str_ptr","bools") VALUES($1,$2,$3,$4,$5,$6) ON CONFLICT (id) DO NOTHING RETURNING TRUE` + expected = `INSERT INTO public.table("id","int","bool","str","str_ptr","bools") VALUES(?1,?2,?3,?4,?5,?6) ON CONFLICT (id) DO NOTHING RETURNING TRUE` inserted, err = db.InsertUniqueStruct(ctx, "public.table", row, "(id)", sqldb.IgnoreColumns("nil_ptr", "untagged_field", "created_at")) assert.NoError(t, err) assert.True(t, inserted) @@ -153,14 +153,16 @@ func TestUpdateQuery(t *testing.T) { "bools": pq.BoolArray{true, false}, } - expected := `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1` - err := db.Update(ctx, "public.table", values, "id = $1", 1) + // Passing one varidic arg as ?1, moves the index of the rest of the args by 1 + expected := `UPDATE public.table SET "bool"=?2, "bools"=?3, "created_at"=?4, "int"=?5, "nil_ptr"=?6, "str"=?7, "str_ptr"=?8, "untagged_field"=?9 WHERE id = ?1` + err := db.Update(ctx, "public.table", values, "id = ?1", 1) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `UPDATE public.table SET "bool"=$3,"bools"=$4,"created_at"=$5,"int"=$6,"nil_ptr"=$7,"str"=$8,"str_ptr"=$9,"untagged_field"=$10 WHERE a = $1 AND b = $2` - err = db.Update(ctx, "public.table", values, "a = $1 AND b = $2", 1, 2) + // Passing two varidic args as ?1 and ?2, moves the index of the rest of the args by 2 + expected = `UPDATE public.table SET "bool"=?3, "bools"=?4, "created_at"=?5, "int"=?6, "nil_ptr"=?7, "str"=?8, "str_ptr"=?9, "untagged_field"=?10 WHERE a = ?1 AND b = ?2` + err = db.Update(ctx, "public.table", values, "a = ?1 AND b = ?2", 1, 2) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } @@ -183,14 +185,16 @@ func TestUpdateReturningQuery(t *testing.T) { "bools": pq.BoolArray{true, false}, } - expected := `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1 RETURNING *` - err := db.UpdateReturningRow(ctx, "public.table", values, "*", "id = $1", 1).Scan() + // Passing one varidic arg as ?1, moves the index of the rest of the args by 1 + expected := `UPDATE public.table SET "bool"=?2, "bools"=?3, "created_at"=?4, "int"=?5, "nil_ptr"=?6, "str"=?7, "str_ptr"=?8, "untagged_field"=?9 WHERE id = ?1 RETURNING *` + err := db.UpdateReturningRow(ctx, "public.table", values, "*", "id = ?1", 1).Scan() assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `UPDATE public.table SET "bool"=$2,"bools"=$3,"created_at"=$4,"int"=$5,"nil_ptr"=$6,"str"=$7,"str_ptr"=$8,"untagged_field"=$9 WHERE id = $1 RETURNING created_at,untagged_field` - err = db.UpdateReturningRows(ctx, "public.table", values, "created_at,untagged_field", "id = $1", 1, 2).ScanSlice(nil) + // Passing two varidic args as ?1 and ?2, moves the index of the rest of the args by 2 + expected = `UPDATE public.table SET "bool"=?3, "bools"=?4, "created_at"=?5, "int"=?6, "nil_ptr"=?7, "str"=?8, "str_ptr"=?9, "untagged_field"=?10 WHERE id = ?1 RETURNING created_at,untagged_field` + err = db.UpdateReturningRows(ctx, "public.table", values, "created_at,untagged_field", "id = ?1", 1, 2).ScanSlice(nil) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) } @@ -210,19 +214,19 @@ func TestUpdateStructQuery(t *testing.T) { row := new(testRow) - expected := `UPDATE public.table SET "int"=$2,"bool"=$3,"str"=$4,"str_ptr"=$5,"nil_ptr"=$6,"untagged_field"=$7,"created_at"=$8,"bools"=$9 WHERE "id"=$1` + expected := `UPDATE public.table SET "int"=?2, "bool"=?3, "str"=?4, "str_ptr"=?5, "nil_ptr"=?6, "untagged_field"=?7, "created_at"=?8, "bools"=?9 WHERE "id"=?1` err := db.UpdateStruct(ctx, "public.table", row) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `UPDATE public.table SET "bool"=$2,"str"=$3,"created_at"=$4 WHERE "id"=$1` + expected = `UPDATE public.table SET "bool"=?2, "str"=?3, "created_at"=?4 WHERE "id"=?1` err = db.UpdateStruct(ctx, "public.table", row, sqldb.OnlyColumns("id", "bool", "str", "created_at")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) queryOutput.Reset() - expected = `UPDATE public.table SET "int"=$2,"bool"=$3,"str_ptr"=$4,"nil_ptr"=$5,"created_at"=$6 WHERE "id"=$1` + expected = `UPDATE public.table SET "int"=?2, "bool"=?3, "str_ptr"=?4, "nil_ptr"=?5, "created_at"=?6 WHERE "id"=?1` err = db.UpdateStruct(ctx, "public.table", row, sqldb.IgnoreColumns("untagged_field", "str", "bools")) assert.NoError(t, err) assert.Equal(t, expected, queryOutput.String()) @@ -242,8 +246,8 @@ func TestUpsertStructQuery(t *testing.T) { ctx := db.ContextWithConn(context.Background(), conn) row := new(testRow) - expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES($1,$2,$3,$4,$5,$6,$7,$8,$9)` + - ` ON CONFLICT("id") DO UPDATE SET "int"=$2,"bool"=$3,"str"=$4,"str_ptr"=$5,"nil_ptr"=$6,"untagged_field"=$7,"created_at"=$8,"bools"=$9` + expected := `INSERT INTO public.table("id","int","bool","str","str_ptr","nil_ptr","untagged_field","created_at","bools") VALUES(?1,?2,?3,?4,?5,?6,?7,?8,?9)` + + ` ON CONFLICT("id") DO UPDATE SET "int"=?2, "bool"=?3, "str"=?4, "str_ptr"=?5, "nil_ptr"=?6, "untagged_field"=?7, "created_at"=?8, "bools"=?9` err := db.UpsertStruct(ctx, "public.table", row) assert.NoError(t, err) @@ -272,7 +276,7 @@ func TestUpsertStructMultiPKQuery(t *testing.T) { ctx := db.ContextWithConn(context.Background(), conn) row := new(multiPrimaryKeyRow) - expected := `INSERT INTO public.multi_pk("first_id","second_id","third_id","created_at") VALUES($1,$2,$3,$4) ON CONFLICT("first_id","second_id","third_id") DO UPDATE SET "created_at"=$4` + expected := `INSERT INTO public.multi_pk("first_id","second_id","third_id","created_at") VALUES(?1,?2,?3,?4) ON CONFLICT("first_id","second_id","third_id") DO UPDATE SET "created_at"=?4` err := db.UpsertStruct(ctx, "public.multi_pk", row) assert.NoError(t, err) @@ -293,7 +297,7 @@ func TestUpdateStructMultiPKQuery(t *testing.T) { ctx := db.ContextWithConn(context.Background(), conn) row := new(multiPrimaryKeyRow) - expected := `UPDATE public.multi_pk SET "created_at"=$4 WHERE "first_id"=$1 AND "second_id"=$2 AND "third_id"=$3` + expected := `UPDATE public.multi_pk SET "created_at"=?4 WHERE "first_id"=?1 AND "second_id"=?2 AND "third_id"=?3` err := db.UpdateStruct(ctx, "public.multi_pk", row) assert.NoError(t, err) From 0edd2db39e2face4ba5457a0e7ea68f056693e46 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Fri, 13 Dec 2024 16:00:24 +0100 Subject: [PATCH 004/110] TransactionOptions --- connection.go | 17 ++++++----------- connectionwitherror.go | 10 +++------- db/conn.go | 3 ++- db/transaction.go | 6 +++--- db/utils.go | 6 +++--- impl/connection.go | 10 +++------- impl/transaction.go | 10 +++------- mockconn/connection.go | 10 +++------- mockconn/transaction.go | 10 +++------- pqconn/connection.go | 10 +++------- pqconn/transaction.go | 10 +++------- transaction.go | 4 ++-- 12 files changed, 37 insertions(+), 69 deletions(-) diff --git a/connection.go b/connection.go index 1a5f584..cf35dd5 100644 --- a/connection.go +++ b/connection.go @@ -69,16 +69,11 @@ type Connection interface { // QueryRows queries multiple rows and returns a RowsScanner for the results. QueryRows(query string, args ...any) RowsScanner - // TransactionNo returns the globally unique number of the transaction - // or zero if the connection is not a transaction. - // Implementations should use the package function NextTransactionNo - // to aquire a new number in a threadsafe way. - TransactionNo() uint64 - - // TransactionOptions returns the sql.TxOptions of the - // current transaction and true as second result value, - // or false if the connection is not a transaction. - TransactionOptions() (*sql.TxOptions, bool) + // TransactionInfo returns the number and sql.TxOptions + // of the connection's transaction, + // or zero and nil if the connection is not + // in a transaction. + TransactionInfo() (no uint64, opts *sql.TxOptions) // Begin a new transaction. // If the connection is already a transaction then a brand @@ -87,7 +82,7 @@ type Connection interface { // Connection.TransactionNo method. // Implementations should use the package function NextTransactionNo // to aquire a new number in a threadsafe way. - Begin(opts *sql.TxOptions, no uint64) (Connection, error) + Begin(no uint64, opts *sql.TxOptions) (Connection, error) // Commit the current transaction. // Returns ErrNotWithinTransaction if the connection diff --git a/connectionwitherror.go b/connectionwitherror.go index 724096f..e5ed2fc 100644 --- a/connectionwitherror.go +++ b/connectionwitherror.go @@ -67,15 +67,11 @@ func (e connectionWithError) QueryRows(query string, args ...any) RowsScanner { return RowsScannerWithError(e.err) } -func (e connectionWithError) TransactionNo() uint64 { - return 0 +func (ce connectionWithError) TransactionInfo() (no uint64, opts *sql.TxOptions) { + return 0, nil } -func (ce connectionWithError) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (e connectionWithError) Begin(opts *sql.TxOptions, no uint64) (Connection, error) { +func (e connectionWithError) Begin(no uint64, opts *sql.TxOptions) (Connection, error) { return nil, e.err } diff --git a/db/conn.go b/db/conn.go index 4649947..0f6e3d1 100644 --- a/db/conn.go +++ b/db/conn.go @@ -50,5 +50,6 @@ func ContextWithConn(ctx context.Context, conn sqldb.Connection) context.Context // or the default connection if the context has none, // is a transaction. func IsTransaction(ctx context.Context) bool { - return Conn(ctx).TransactionNo() != 0 + tx, _ := Conn(ctx).TransactionInfo() + return tx != 0 } diff --git a/db/transaction.go b/db/transaction.go index 0d476ed..2719e02 100644 --- a/db/transaction.go +++ b/db/transaction.go @@ -19,7 +19,7 @@ func ValidateWithinTransaction(ctx context.Context) error { if err := conn.Config().Err; err != nil { return err } - if conn.TransactionNo() == 0 { + if tx, _ := conn.TransactionInfo(); tx == 0 { return sqldb.ErrNotWithinTransaction } return nil @@ -32,7 +32,7 @@ func ValidateNotWithinTransaction(ctx context.Context) error { if err := conn.Config().Err; err != nil { return err } - if conn.TransactionNo() != 0 { + if tx, _ := conn.TransactionInfo(); tx != 0 { return sqldb.ErrWithinTransaction } return nil @@ -166,7 +166,7 @@ func TransactionReadOnly(ctx context.Context, txFunc func(context.Context) error // they should behandled by the parent Transaction function. func TransactionSavepoint(ctx context.Context, txFunc func(context.Context) error) error { conn := Conn(ctx) - if conn.TransactionNo() == 0 { + if tx, _ := conn.TransactionInfo(); tx == 0 { // If not already in a transaction, then execute txFunc // within a as transaction instead of using savepoints: return Transaction(ctx, txFunc) diff --git a/db/utils.go b/db/utils.go index 3d2c210..f5665fa 100644 --- a/db/utils.go +++ b/db/utils.go @@ -31,15 +31,15 @@ func IsOtherThanErrNoRows(err error) bool { // and the current time of the database using `select now()` // or an error if the time could not be queried. func DebugPrintConn(ctx context.Context, args ...any) { - opts, isTx := Conn(ctx).TransactionOptions() - if isTx { + conn := Conn(ctx) + if txNo, opts := conn.TransactionInfo(); txNo != 0 { args = append(args, "SQL-Transaction") if optsStr := TxOptionsString(opts); optsStr != "" { args = append(args, "Isolation", optsStr) } } var t time.Time - err := Conn(ctx).QueryRow("SELECT CURRENT_TIMESTAMP").Scan(&t) + err := conn.QueryRow("SELECT CURRENT_TIMESTAMP").Scan(&t) if err == nil { args = append(args, "CURRENT_TIMESTAMP:", t) } else { diff --git a/impl/connection.go b/impl/connection.go index d68c4ad..6b1d533 100644 --- a/impl/connection.go +++ b/impl/connection.go @@ -108,15 +108,11 @@ func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { return NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, conn.argFmt, args) } -func (conn *connection) TransactionNo() uint64 { - return 0 +func (conn *connection) TransactionInfo() (no uint64, opts *sql.TxOptions) { + return 0, nil } -func (conn *connection) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (conn *connection) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection, error) { +func (conn *connection) Begin(no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { tx, err := conn.db.BeginTx(conn.ctx, opts) if err != nil { return nil, err diff --git a/impl/transaction.go b/impl/transaction.go index 959348b..2a1f21f 100644 --- a/impl/transaction.go +++ b/impl/transaction.go @@ -88,15 +88,11 @@ func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner return NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, conn.parent.argFmt, args) } -func (conn *transaction) TransactionNo() uint64 { - return conn.no +func (conn *transaction) TransactionInfo() (no uint64, opts *sql.TxOptions) { + return conn.no, conn.opts } -func (conn *transaction) TransactionOptions() (*sql.TxOptions, bool) { - return conn.opts, true -} - -func (conn *transaction) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection, error) { +func (conn *transaction) Begin(no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { tx, err := conn.parent.db.BeginTx(conn.parent.ctx, opts) if err != nil { return nil, err diff --git a/mockconn/connection.go b/mockconn/connection.go index 833c49b..8c88f44 100644 --- a/mockconn/connection.go +++ b/mockconn/connection.go @@ -116,15 +116,11 @@ func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { return conn.rowsProvider.QueryRows(conn.structFieldNamer, query, args...) } -func (conn *connection) TransactionNo() uint64 { - return 0 +func (conn *connection) TransactionInfo() (no uint64, opts *sql.TxOptions) { + return 0, nil } -func (conn *connection) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (conn *connection) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection, error) { +func (conn *connection) Begin(no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { if conn.queryWriter != nil { fmt.Fprint(conn.queryWriter, "BEGIN") } diff --git a/mockconn/transaction.go b/mockconn/transaction.go index 1a4216f..8676d7c 100644 --- a/mockconn/transaction.go +++ b/mockconn/transaction.go @@ -27,15 +27,11 @@ func (conn transaction) WithContext(ctx context.Context) sqldb.Connection { } } -func (conn transaction) TransactionNo() uint64 { - return conn.no +func (conn transaction) TransactionInfo() (no uint64, opts *sql.TxOptions) { + return conn.no, conn.opts } -func (conn transaction) TransactionOptions() (*sql.TxOptions, bool) { - return conn.opts, true -} - -func (conn transaction) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection, error) { +func (conn transaction) Begin(no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { if conn.queryWriter != nil { fmt.Fprint(conn.queryWriter, "BEGIN") } diff --git a/pqconn/connection.go b/pqconn/connection.go index 0746ef5..4f67f4e 100644 --- a/pqconn/connection.go +++ b/pqconn/connection.go @@ -139,15 +139,11 @@ func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { return impl.NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, argFmt, args) } -func (conn *connection) TransactionNo() uint64 { - return 0 +func (conn *connection) TransactionInfo() (no uint64, opts *sql.TxOptions) { + return 0, nil } -func (conn *connection) TransactionOptions() (*sql.TxOptions, bool) { - return nil, false -} - -func (conn *connection) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection, error) { +func (conn *connection) Begin(no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { tx, err := conn.db.BeginTx(conn.ctx, opts) if err != nil { return nil, err diff --git a/pqconn/transaction.go b/pqconn/transaction.go index 4a5ed85..df2c525 100644 --- a/pqconn/transaction.go +++ b/pqconn/transaction.go @@ -92,15 +92,11 @@ func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner return impl.NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, argFmt, args) } -func (conn *transaction) TransactionNo() uint64 { - return conn.no +func (conn *transaction) TransactionInfo() (no uint64, opts *sql.TxOptions) { + return conn.no, conn.opts } -func (conn *transaction) TransactionOptions() (*sql.TxOptions, bool) { - return conn.opts, true -} - -func (conn *transaction) Begin(opts *sql.TxOptions, no uint64) (sqldb.Connection, error) { +func (conn *transaction) Begin(no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { tx, err := conn.parent.db.BeginTx(conn.parent.ctx, opts) if err != nil { return nil, err diff --git a/transaction.go b/transaction.go index 9cae3e3..d00836b 100644 --- a/transaction.go +++ b/transaction.go @@ -27,7 +27,7 @@ func NextTransactionNo() uint64 { // Errors and panics from txFunc will rollback the transaction if parentConn was not already a transaction. // Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. func Transaction(parentConn Connection, opts *sql.TxOptions, txFunc func(tx Connection) error) (err error) { - if parentOpts, parentIsTx := parentConn.TransactionOptions(); parentIsTx { + if parentTxNo, parentOpts := parentConn.TransactionInfo(); parentTxNo != 0 { err = CheckTxOptionsCompatibility(parentOpts, opts, parentConn.Config().DefaultIsolationLevel) if err != nil { return err @@ -44,7 +44,7 @@ func Transaction(parentConn Connection, opts *sql.TxOptions, txFunc func(tx Conn // Recovered panics are re-paniced and rollback errors after a panic are logged with ErrLogger. func IsolatedTransaction(parentConn Connection, opts *sql.TxOptions, txFunc func(tx Connection) error) (err error) { txNo := NextTransactionNo() - tx, e := parentConn.Begin(opts, txNo) + tx, e := parentConn.Begin(txNo, opts) if e != nil { return fmt.Errorf("Transaction %d Begin error: %w", txNo, e) } From 153ee3c9f466d8c5cbb985b89df0b2001779f6ef Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Fri, 13 Dec 2024 16:03:29 +0100 Subject: [PATCH 005/110] transaction number must not be zero --- impl/connection.go | 4 ++++ impl/transaction.go | 4 ++++ mockconn/connection.go | 4 ++++ mockconn/transaction.go | 4 ++++ pqconn/connection.go | 4 ++++ pqconn/transaction.go | 4 ++++ 6 files changed, 24 insertions(+) diff --git a/impl/connection.go b/impl/connection.go index 6b1d533..e960068 100644 --- a/impl/connection.go +++ b/impl/connection.go @@ -3,6 +3,7 @@ package impl import ( "context" "database/sql" + "errors" "fmt" "time" @@ -113,6 +114,9 @@ func (conn *connection) TransactionInfo() (no uint64, opts *sql.TxOptions) { } func (conn *connection) Begin(no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { + if no == 0 { + return nil, errors.New("transaction number must not be zero") + } tx, err := conn.db.BeginTx(conn.ctx, opts) if err != nil { return nil, err diff --git a/impl/transaction.go b/impl/transaction.go index 2a1f21f..1340866 100644 --- a/impl/transaction.go +++ b/impl/transaction.go @@ -3,6 +3,7 @@ package impl import ( "context" "database/sql" + "errors" "time" "github.com/domonda/go-sqldb" @@ -93,6 +94,9 @@ func (conn *transaction) TransactionInfo() (no uint64, opts *sql.TxOptions) { } func (conn *transaction) Begin(no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { + if no == 0 { + return nil, errors.New("transaction number must not be zero") + } tx, err := conn.parent.db.BeginTx(conn.parent.ctx, opts) if err != nil { return nil, err diff --git a/mockconn/connection.go b/mockconn/connection.go index 8c88f44..8cf1dde 100644 --- a/mockconn/connection.go +++ b/mockconn/connection.go @@ -3,6 +3,7 @@ package mockconn import ( "context" "database/sql" + "errors" "fmt" "io" "time" @@ -121,6 +122,9 @@ func (conn *connection) TransactionInfo() (no uint64, opts *sql.TxOptions) { } func (conn *connection) Begin(no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { + if no == 0 { + return nil, errors.New("transaction number must not be zero") + } if conn.queryWriter != nil { fmt.Fprint(conn.queryWriter, "BEGIN") } diff --git a/mockconn/transaction.go b/mockconn/transaction.go index 8676d7c..f5606fe 100644 --- a/mockconn/transaction.go +++ b/mockconn/transaction.go @@ -3,6 +3,7 @@ package mockconn import ( "context" "database/sql" + "errors" "fmt" "github.com/domonda/go-sqldb" @@ -32,6 +33,9 @@ func (conn transaction) TransactionInfo() (no uint64, opts *sql.TxOptions) { } func (conn transaction) Begin(no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { + if no == 0 { + return nil, errors.New("transaction number must not be zero") + } if conn.queryWriter != nil { fmt.Fprint(conn.queryWriter, "BEGIN") } diff --git a/pqconn/connection.go b/pqconn/connection.go index 4f67f4e..bb6922f 100644 --- a/pqconn/connection.go +++ b/pqconn/connection.go @@ -3,6 +3,7 @@ package pqconn import ( "context" "database/sql" + "errors" "fmt" "time" @@ -144,6 +145,9 @@ func (conn *connection) TransactionInfo() (no uint64, opts *sql.TxOptions) { } func (conn *connection) Begin(no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { + if no == 0 { + return nil, errors.New("transaction number must not be zero") + } tx, err := conn.db.BeginTx(conn.ctx, opts) if err != nil { return nil, err diff --git a/pqconn/transaction.go b/pqconn/transaction.go index df2c525..5ba2696 100644 --- a/pqconn/transaction.go +++ b/pqconn/transaction.go @@ -3,6 +3,7 @@ package pqconn import ( "context" "database/sql" + "errors" "time" "github.com/domonda/go-sqldb" @@ -97,6 +98,9 @@ func (conn *transaction) TransactionInfo() (no uint64, opts *sql.TxOptions) { } func (conn *transaction) Begin(no uint64, opts *sql.TxOptions) (sqldb.Connection, error) { + if no == 0 { + return nil, errors.New("transaction number must not be zero") + } tx, err := conn.parent.db.BeginTx(conn.parent.ctx, opts) if err != nil { return nil, err From d96a875fef6220c5f6a5ed37e6e424d4226a7359 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Fri, 13 Dec 2024 17:16:46 +0100 Subject: [PATCH 006/110] added Connection.Query --- connection.go | 3 +++ connectionwitherror.go | 4 ++++ impl/connection.go | 6 +++++- impl/transaction.go | 4 ++++ mockconn/connection.go | 7 +++++++ mockconn/rowsprovider.go | 1 + mockconn/singlerowprovider.go | 4 ++++ pqconn/connection.go | 9 +++++++++ pqconn/transaction.go | 9 +++++++++ rows.go | 31 +++++++++++++++++++++++++++++++ 10 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 rows.go diff --git a/connection.go b/connection.go index cf35dd5..963f31f 100644 --- a/connection.go +++ b/connection.go @@ -63,6 +63,9 @@ type Connection interface { // Exec executes a query with optional args. Exec(query string, args ...any) error + // Query queries rows with optional args. + Query(query string, args ...any) (Rows, error) + // QueryRow queries a single row and returns a RowScanner for the results. QueryRow(query string, args ...any) RowScanner diff --git a/connectionwitherror.go b/connectionwitherror.go index e5ed2fc..f7eb7da 100644 --- a/connectionwitherror.go +++ b/connectionwitherror.go @@ -59,6 +59,10 @@ func (e connectionWithError) Exec(query string, args ...any) error { return e.err } +func (e connectionWithError) Query(query string, args ...any) (Rows, error) { + return nil, e.err +} + func (e connectionWithError) QueryRow(query string, args ...any) RowScanner { return RowScannerWithError(e.err) } diff --git a/impl/connection.go b/impl/connection.go index e960068..a9c1698 100644 --- a/impl/connection.go +++ b/impl/connection.go @@ -88,7 +88,11 @@ func (conn *connection) ValidateColumnName(name string) error { func (conn *connection) Exec(query string, args ...any) error { _, err := conn.db.ExecContext(conn.ctx, query, args...) - return WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) + return err +} + +func (conn *connection) Query(query string, args ...any) (sqldb.Rows, error) { + return conn.db.QueryContext(conn.ctx, query, args...) } func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { diff --git a/impl/transaction.go b/impl/transaction.go index 1340866..246de8c 100644 --- a/impl/transaction.go +++ b/impl/transaction.go @@ -71,6 +71,10 @@ func (conn *transaction) Exec(query string, args ...any) error { return WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) } +func (conn *transaction) Query(query string, args ...any) (sqldb.Rows, error) { + return conn.tx.QueryContext(conn.parent.ctx, query, args...) +} + func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { diff --git a/mockconn/connection.go b/mockconn/connection.go index 8cf1dde..03bb5be 100644 --- a/mockconn/connection.go +++ b/mockconn/connection.go @@ -91,6 +91,13 @@ func (conn *connection) Exec(query string, args ...any) error { return nil } +func (conn *connection) Query(query string, args ...any) (sqldb.Rows, error) { + if err := conn.ctx.Err(); err != nil { + return nil, err + } + return conn.rowsProvider.Query(conn.structFieldNamer, query, args...) +} + func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { if conn.ctx.Err() != nil { return sqldb.RowScannerWithError(conn.ctx.Err()) diff --git a/mockconn/rowsprovider.go b/mockconn/rowsprovider.go index 9cf8394..cad4449 100644 --- a/mockconn/rowsprovider.go +++ b/mockconn/rowsprovider.go @@ -5,6 +5,7 @@ import ( ) type RowsProvider interface { + Query(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) (sqldb.Rows, error) QueryRow(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowScanner QueryRows(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowsScanner } diff --git a/mockconn/singlerowprovider.go b/mockconn/singlerowprovider.go index 8c8e39e..ed02839 100644 --- a/mockconn/singlerowprovider.go +++ b/mockconn/singlerowprovider.go @@ -20,6 +20,10 @@ type singleRowProvider struct { argFmt string } +func (p *singleRowProvider) Query(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) (sqldb.Rows, error) { + panic("TODO") +} + func (p *singleRowProvider) QueryRow(structFieldNamer sqldb.StructFieldMapper, query string, args ...any) sqldb.RowScanner { return impl.NewRowScanner(impl.RowAsRows(p.row), structFieldNamer, query, p.argFmt, args) } diff --git a/pqconn/connection.go b/pqconn/connection.go index bb6922f..8b6bb61 100644 --- a/pqconn/connection.go +++ b/pqconn/connection.go @@ -120,6 +120,15 @@ func (conn *connection) Exec(query string, args ...any) error { return nil } +func (conn *connection) Query(query string, args ...any) (sqldb.Rows, error) { + impl.WrapArrayArgs(args) + rows, err := conn.db.QueryContext(conn.ctx, query, args...) + if err != nil { + return nil, wrapKnownErrors(err) + } + return rows, nil +} + func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { impl.WrapArrayArgs(args) rows, err := conn.db.QueryContext(conn.ctx, query, args...) diff --git a/pqconn/transaction.go b/pqconn/transaction.go index 5ba2696..5c54530 100644 --- a/pqconn/transaction.go +++ b/pqconn/transaction.go @@ -73,6 +73,15 @@ func (conn *transaction) Exec(query string, args ...any) error { return impl.WrapNonNilErrorWithQuery(err, query, argFmt, args) } +func (conn *transaction) Query(query string, args ...any) (sqldb.Rows, error) { + impl.WrapArrayArgs(args) + rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) + if err != nil { + return nil, wrapKnownErrors(err) + } + return rows, nil +} + func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { impl.WrapArrayArgs(args) rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) diff --git a/rows.go b/rows.go new file mode 100644 index 0000000..13715b5 --- /dev/null +++ b/rows.go @@ -0,0 +1,31 @@ +package sqldb + +// Rows is an interface with the methods of sql.Rows. +// Allows mocking for tests without an SQL driver. +type Rows interface { + // Columns returns the column names. + Columns() ([]string, error) + + // Scan copies the columns in the current row into the values pointed + // at by dest. The number of values in dest must be the same as the + // number of columns in Rows. + Scan(dest ...any) error + + // Close closes the Rows, preventing further enumeration. If Next is called + // and returns false and there are no further result sets, + // the Rows are closed automatically and it will suffice to check the + // result of Err. Close is idempotent and does not affect the result of Err. + Close() error + + // Next prepares the next result row for reading with the Scan method. It + // returns true on success, or false if there is no next result row or an error + // happened while preparing it. Err should be consulted to distinguish between + // the two cases. + // + // Every call to Scan, even the first one, must be preceded by a call to Next. + Next() bool + + // Err returns the error, if any, that was encountered during iteration. + // Err may be called after an explicit or implicit Close. + Err() error +} From 0209603737415c8c0b17334f3bc8e8d86fa8724b Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Mon, 16 Dec 2024 11:44:58 +0100 Subject: [PATCH 007/110] added db.TableForStruct --- db/table.go | 30 +++++++++++++++++++ db/table_test.go | 78 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+) create mode 100644 db/table.go create mode 100644 db/table_test.go diff --git a/db/table.go b/db/table.go new file mode 100644 index 0000000..3bb0571 --- /dev/null +++ b/db/table.go @@ -0,0 +1,30 @@ +package db + +import ( + "fmt" + "reflect" +) + +type Table struct{} + +var typeOfTable = reflect.TypeFor[Table]() + +func TableForStruct(t reflect.Type, tagKey string) (table string, err error) { + if t.Kind() != reflect.Struct { + return "", fmt.Errorf("db.StructTable: %s is not a struct", t) + } + if tagKey == "" { + return "", fmt.Errorf("db.StructTable: tagKey is empty") + } + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Anonymous && field.Type == typeOfTable { + table = field.Tag.Get(tagKey) + if table == "" { + return "", fmt.Errorf("db.StructTable: embedded db.Table has no tag '%s'", tagKey) + } + return table, nil + } + } + return "", fmt.Errorf("db.StructTable: struct type %s has no embedded db.Table field", t) +} diff --git a/db/table_test.go b/db/table_test.go new file mode 100644 index 0000000..25a213f --- /dev/null +++ b/db/table_test.go @@ -0,0 +1,78 @@ +package db + +import ( + "reflect" + "testing" +) + +func TestTableForStruct(t *testing.T) { + tests := []struct { + name string + t reflect.Type + tagKey string + wantTable string + wantErr bool + }{ + { + name: "OK", + t: reflect.TypeFor[struct { + Table `db:"table_name"` + }](), + tagKey: "db", + wantTable: "table_name", + }, + { + name: "more struct fields", + t: reflect.TypeFor[struct { + ID int `db:"id"` + Table `db:"table_name"` + Value string `db:"value"` + }](), + tagKey: "db", + wantTable: "table_name", + }, + // Error cases + { + name: "empty", + t: reflect.TypeFor[struct{}](), + tagKey: "db", + wantErr: true, + }, + { + name: "no tagKey", + t: reflect.TypeFor[struct { + Table + }](), + tagKey: "db", + wantErr: true, + }, + { + name: "wrong tagKey", + t: reflect.TypeFor[struct { + Table `json:"table_name"` + }](), + tagKey: "db", + wantErr: true, + }, + { + name: "named field", + t: reflect.TypeFor[struct { + Table Table `db:"table_name"` + }](), + tagKey: "db", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotTable, err := TableForStruct(tt.t, tt.tagKey) + if (err != nil) != tt.wantErr { + t.Errorf("TableForStruct(%s, %#v) error = %v, wantErr %v", tt.t, tt.tagKey, err, tt.wantErr) + return + } + if gotTable != tt.wantTable { + t.Errorf("TableForStruct(%s, %#v) = %v, want %v", tt.t, tt.tagKey, gotTable, tt.wantTable) + } + }) + } +} From 1b97cb8c72b901907be5398647642f14d5432da6 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Mon, 16 Dec 2024 16:50:41 +0100 Subject: [PATCH 008/110] StructWithTableName interface --- db/insert.go | 83 +++++++++++++++------------ db/table.go | 30 ---------- structfieldmapping.go | 7 +++ tablename.go | 47 +++++++++++++++ db/table_test.go => tablename_test.go | 20 +++---- 5 files changed, 111 insertions(+), 76 deletions(-) delete mode 100644 db/table.go create mode 100644 tablename.go rename db/table_test.go => tablename_test.go (78%) diff --git a/db/insert.go b/db/insert.go index be33483..706ac75 100644 --- a/db/insert.go +++ b/db/insert.go @@ -10,42 +10,6 @@ import ( "github.com/domonda/go-sqldb/impl" ) -func writeInsertQuery(w *strings.Builder, table string, names []string, format sqldb.PlaceholderFormatter) { - fmt.Fprintf(w, `INSERT INTO %s(`, table) - for i, name := range names { - if i > 0 { - w.WriteByte(',') - } - w.WriteByte('"') - w.WriteString(name) - w.WriteByte('"') - } - w.WriteString(`) VALUES(`) - for i := range names { - if i > 0 { - w.WriteByte(',') - } - w.WriteString(format.Placeholder(i)) - } - w.WriteByte(')') -} - -func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) { - v := reflect.ValueOf(rowStruct) - for v.Kind() == reflect.Ptr && !v.IsNil() { - v = v.Elem() - } - switch { - case v.Kind() == reflect.Ptr && v.IsNil(): - return nil, nil, fmt.Errorf("InsertStruct into table %s: can't insert nil", table) - case v.Kind() != reflect.Struct: - return nil, nil, fmt.Errorf("InsertStruct into table %s: expected struct but got %T", table, rowStruct) - } - - columns, _, vals = impl.ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) - return columns, vals, nil -} - // Insert a new row into table using the values. func Insert(ctx context.Context, table string, values sqldb.Values) error { if len(values) == 0 { @@ -126,6 +90,17 @@ func InsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumn return nil } +// InsertStructWithTableName inserts a new row into table using the connection's +// StructFieldMapper to map struct fields to column names. +// Optional ColumnFilter can be passed to ignore mapped columns. +func InsertStructWithTableName(ctx context.Context, row sqldb.StructWithTableName, ignoreColumns ...sqldb.ColumnFilter) error { + table, err := Conn(ctx).StructFieldMapper().TableNameForStruct(reflect.TypeOf(row)) + if err != nil { + return err + } + return InsertStruct(ctx, table, row, ignoreColumns...) +} + // InsertUniqueStruct inserts a new row into table using the connection's // StructFieldMapper to map struct fields to column names. // Optional ColumnFilter can be passed to ignore mapped columns. @@ -177,3 +152,39 @@ func InsertStructs(ctx context.Context, table string, rowStructs any, ignoreColu return nil }) } + +func writeInsertQuery(w *strings.Builder, table string, names []string, format sqldb.PlaceholderFormatter) { + fmt.Fprintf(w, `INSERT INTO %s(`, table) + for i, name := range names { + if i > 0 { + w.WriteByte(',') + } + w.WriteByte('"') + w.WriteString(name) + w.WriteByte('"') + } + w.WriteString(`) VALUES(`) + for i := range names { + if i > 0 { + w.WriteByte(',') + } + w.WriteString(format.Placeholder(i)) + } + w.WriteByte(')') +} + +func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) { + v := reflect.ValueOf(rowStruct) + for v.Kind() == reflect.Ptr && !v.IsNil() { + v = v.Elem() + } + switch { + case v.Kind() == reflect.Ptr && v.IsNil(): + return nil, nil, fmt.Errorf("InsertStruct into table %s: can't insert nil", table) + case v.Kind() != reflect.Struct: + return nil, nil, fmt.Errorf("InsertStruct into table %s: expected struct but got %T", table, rowStruct) + } + + columns, _, vals = impl.ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) + return columns, vals, nil +} diff --git a/db/table.go b/db/table.go deleted file mode 100644 index 3bb0571..0000000 --- a/db/table.go +++ /dev/null @@ -1,30 +0,0 @@ -package db - -import ( - "fmt" - "reflect" -) - -type Table struct{} - -var typeOfTable = reflect.TypeFor[Table]() - -func TableForStruct(t reflect.Type, tagKey string) (table string, err error) { - if t.Kind() != reflect.Struct { - return "", fmt.Errorf("db.StructTable: %s is not a struct", t) - } - if tagKey == "" { - return "", fmt.Errorf("db.StructTable: tagKey is empty") - } - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - if field.Anonymous && field.Type == typeOfTable { - table = field.Tag.Get(tagKey) - if table == "" { - return "", fmt.Errorf("db.StructTable: embedded db.Table has no tag '%s'", tagKey) - } - return table, nil - } - } - return "", fmt.Errorf("db.StructTable: struct type %s has no embedded db.Table field", t) -} diff --git a/structfieldmapping.go b/structfieldmapping.go index 533c972..c2a6e55 100644 --- a/structfieldmapping.go +++ b/structfieldmapping.go @@ -35,6 +35,9 @@ const ( // StructFieldMapper is used to map struct type fields to column names // and indicate special column properies via flags. type StructFieldMapper interface { + // TableNameForStruct returns the table name for a struct type + TableNameForStruct(t reflect.Type) (table string, err error) + // MapStructField returns the column name for a reflected struct field // and flags for special column properies. // If false is returned for use then the field is not mapped. @@ -83,6 +86,10 @@ type TaggedStructFieldMapping struct { UntaggedNameFunc func(fieldName string) string } +func (m *TaggedStructFieldMapping) TableNameForStruct(t reflect.Type) (table string, err error) { + return TableNameForStruct(t, m.NameTag) +} + func (m *TaggedStructFieldMapping) MapStructField(field reflect.StructField) (table, column string, flags FieldFlag, use bool) { if field.Anonymous { column, hasTag := field.Tag.Lookup(m.NameTag) diff --git a/tablename.go b/tablename.go new file mode 100644 index 0000000..2043ecd --- /dev/null +++ b/tablename.go @@ -0,0 +1,47 @@ +package sqldb + +import ( + "fmt" + "reflect" +) + +// StructWithTableName is a marker interface for structs +// that embed a TableName field to specify the table name. +type StructWithTableName interface { + HasTableName() +} + +// TableName implements the StructWithTableName marker interface +var _ StructWithTableName = TableName{} + +// TableName is an empty struct that can be embedded in other structs +// to specify the table name for the struct using a struct tag. +type TableName struct{} + +// HasTableName implements the StructWithTableName interface +func (TableName) HasTableName() {} + +func TableNameForStruct(t reflect.Type, tagKey string) (table string, err error) { + structType := t + for structType.Kind() == reflect.Pointer { + structType = structType.Elem() + } + if structType.Kind() != reflect.Struct { + return "", fmt.Errorf("db.StructTable: %s is not a struct or pointer to a struct", t) + } + if tagKey == "" { + return "", fmt.Errorf("db.StructTable: tagKey is empty") + } + tableNameType := reflect.TypeFor[TableName]() + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + if field.Anonymous && field.Type == tableNameType { + table = field.Tag.Get(tagKey) + if table == "" { + return "", fmt.Errorf("db.StructTable: embedded db.Table has no tag '%s'", tagKey) + } + return table, nil + } + } + return "", fmt.Errorf("db.StructTable: struct type %s has no embedded db.Table field", t) +} diff --git a/db/table_test.go b/tablename_test.go similarity index 78% rename from db/table_test.go rename to tablename_test.go index 25a213f..92446cd 100644 --- a/db/table_test.go +++ b/tablename_test.go @@ -1,11 +1,11 @@ -package db +package sqldb import ( "reflect" "testing" ) -func TestTableForStruct(t *testing.T) { +func TestTableNameForStruct(t *testing.T) { tests := []struct { name string t reflect.Type @@ -16,7 +16,7 @@ func TestTableForStruct(t *testing.T) { { name: "OK", t: reflect.TypeFor[struct { - Table `db:"table_name"` + TableName `db:"table_name"` }](), tagKey: "db", wantTable: "table_name", @@ -24,9 +24,9 @@ func TestTableForStruct(t *testing.T) { { name: "more struct fields", t: reflect.TypeFor[struct { - ID int `db:"id"` - Table `db:"table_name"` - Value string `db:"value"` + ID int `db:"id"` + TableName `db:"table_name"` + Value string `db:"value"` }](), tagKey: "db", wantTable: "table_name", @@ -41,7 +41,7 @@ func TestTableForStruct(t *testing.T) { { name: "no tagKey", t: reflect.TypeFor[struct { - Table + TableName }](), tagKey: "db", wantErr: true, @@ -49,7 +49,7 @@ func TestTableForStruct(t *testing.T) { { name: "wrong tagKey", t: reflect.TypeFor[struct { - Table `json:"table_name"` + TableName `json:"table_name"` }](), tagKey: "db", wantErr: true, @@ -57,7 +57,7 @@ func TestTableForStruct(t *testing.T) { { name: "named field", t: reflect.TypeFor[struct { - Table Table `db:"table_name"` + Table TableName `db:"table_name"` }](), tagKey: "db", wantErr: true, @@ -65,7 +65,7 @@ func TestTableForStruct(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotTable, err := TableForStruct(tt.t, tt.tagKey) + gotTable, err := TableNameForStruct(tt.t, tt.tagKey) if (err != nil) != tt.wantErr { t.Errorf("TableForStruct(%s, %#v) error = %v, wantErr %v", tt.t, tt.tagKey, err, tt.wantErr) return From 23339b66124b74a3d5cdf70f0ac762ebc811aa10 Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Tue, 17 Dec 2024 07:25:26 +0100 Subject: [PATCH 009/110] StructFieldMapper.MapStructField does not return table name --- db/query.go | 34 +++++++++++++--------------------- impl/reflectstruct.go | 4 ++-- mockconn/row.go | 2 +- structfieldmapping.go | 18 ++++++++---------- structfieldmapping_test.go | 18 +++++++----------- 5 files changed, 31 insertions(+), 45 deletions(-) diff --git a/db/query.go b/db/query.go index 791ca33..e95cd64 100644 --- a/db/query.go +++ b/db/query.go @@ -119,7 +119,7 @@ func QueryRowStructOrNil[S any](ctx context.Context, query string, args ...any) // and scan it into a struct of type S that must have tagged fields // with primary key flags to identify the primary key column names // for the passed pkValue+pkValues and a table name. -func GetRow[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { +func GetRow[S sqldb.StructWithTableName](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { // Using explicit first pkValue value // to not be able to compile without any value pkValues = append([]any{pkValue}, pkValues...) @@ -128,7 +128,11 @@ func GetRow[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, e return nil, fmt.Errorf("expected struct template type instead of %s", t) } conn := Conn(ctx) - table, pkColumns, err := pkColumnsOfStruct(conn, t) + table, err := conn.StructFieldMapper().TableNameForStruct(t) + if err != nil { + return nil, err + } + pkColumns, err := pkColumnsOfStruct(conn, t) if err != nil { return nil, err } @@ -153,7 +157,7 @@ func GetRow[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, e // for the passed pkValue+pkValues and a table name. // Returns nil as row and error if no row could be found with the // passed pkValue+pkValues. -func GetRowOrNil[S any](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { +func GetRowOrNil[S sqldb.StructWithTableName](ctx context.Context, pkValue any, pkValues ...any) (row *S, err error) { row, err = GetRow[S](ctx, pkValue, pkValues...) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -164,41 +168,29 @@ func GetRowOrNil[S any](ctx context.Context, pkValue any, pkValues ...any) (row return row, nil } -func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (table string, columns []string, err error) { +func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (columns []string, err error) { mapper := conn.StructFieldMapper() for i := 0; i < t.NumField(); i++ { field := t.Field(i) - fieldTable, column, flags, ok := mapper.MapStructField(field) + column, flags, ok := mapper.MapStructField(field) if !ok { continue } - if fieldTable != "" && fieldTable != table { - if table != "" { - return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) - } - table = fieldTable - } if column == "" { - fieldTable, columnsEmbed, err := pkColumnsOfStruct(conn, field.Type) + columnsEmbed, err := pkColumnsOfStruct(conn, field.Type) if err != nil { - return "", nil, err - } - if fieldTable != "" && fieldTable != table { - if table != "" { - return "", nil, fmt.Errorf("table name not unique (%s vs %s) in struct %s", table, fieldTable, t) - } - table = fieldTable + return nil, err } columns = append(columns, columnsEmbed...) } else if flags.PrimaryKey() { if err = conn.ValidateColumnName(column); err != nil { - return "", nil, fmt.Errorf("%w in struct field %s.%s", err, t, field.Name) + return nil, fmt.Errorf("%w in struct field %s.%s", err, t, field.Name) } columns = append(columns, column) } } - return table, columns, nil + return columns, nil } // QueryStructSlice returns queried rows as slice of the generic type S diff --git a/impl/reflectstruct.go b/impl/reflectstruct.go index f224d71..d1ca3d0 100644 --- a/impl/reflectstruct.go +++ b/impl/reflectstruct.go @@ -13,7 +13,7 @@ import ( func ReflectStructValues(structVal reflect.Value, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, pkCols []int, values []any) { for i := 0; i < structVal.NumField(); i++ { fieldType := structVal.Type().Field(i) - _, column, flags, use := namer.MapStructField(fieldType) + column, flags, use := namer.MapStructField(fieldType) if !use { continue } @@ -75,7 +75,7 @@ func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFiel structType := structVal.Type() for i := 0; i < structType.NumField(); i++ { field := structType.Field(i) - _, column, _, use := namer.MapStructField(field) + column, _, use := namer.MapStructField(field) if !use { continue } diff --git a/mockconn/row.go b/mockconn/row.go index 579f754..1e51fb9 100644 --- a/mockconn/row.go +++ b/mockconn/row.go @@ -38,7 +38,7 @@ func (r *Row) Columns() ([]string, error) { columns := make([]string, r.rowStructVal.NumField()) for i := range columns { field := r.rowStructVal.Type().Field(i) - _, columns[i], _, _ = r.columnNamer.MapStructField(field) + columns[i], _, _ = r.columnNamer.MapStructField(field) } return columns, nil } diff --git a/structfieldmapping.go b/structfieldmapping.go index c2a6e55..dee3ab4 100644 --- a/structfieldmapping.go +++ b/structfieldmapping.go @@ -43,7 +43,7 @@ type StructFieldMapper interface { // If false is returned for use then the field is not mapped. // An empty name and true for use indicates an embedded struct // field whose fields should be recursively mapped. - MapStructField(field reflect.StructField) (table, column string, flags FieldFlag, use bool) + MapStructField(field reflect.StructField) (column string, flags FieldFlag, use bool) } // NewTaggedStructFieldMapping returns a default mapping. @@ -90,23 +90,23 @@ func (m *TaggedStructFieldMapping) TableNameForStruct(t reflect.Type) (table str return TableNameForStruct(t, m.NameTag) } -func (m *TaggedStructFieldMapping) MapStructField(field reflect.StructField) (table, column string, flags FieldFlag, use bool) { +func (m *TaggedStructFieldMapping) MapStructField(field reflect.StructField) (column string, flags FieldFlag, use bool) { if field.Anonymous { column, hasTag := field.Tag.Lookup(m.NameTag) if !hasTag { // Embedded struct fields are ok if not tagged with IgnoreName - return "", "", 0, true + return "", 0, true } if i := strings.IndexByte(column, ','); i != -1 { column = column[:i] } // Embedded struct fields are ok if not tagged with IgnoreName - return "", "", 0, column != m.Ignore + return "", 0, column != m.Ignore } if !field.IsExported() { // Not exported struct fields that are not // anonymously embedded structs are not ok - return "", "", 0, false + return "", 0, false } tag, hasTag := field.Tag.Lookup(m.NameTag) @@ -118,13 +118,11 @@ func (m *TaggedStructFieldMapping) MapStructField(field reflect.StructField) (ta continue } // Follow on parts are flags - flag, value, _ := strings.Cut(part, "=") - switch flag { + switch part { case "": // Ignore empty flags case m.PrimaryKey: flags |= FieldFlagPrimaryKey - table = value case m.ReadOnly: flags |= FieldFlagReadOnly case m.Default: @@ -136,9 +134,9 @@ func (m *TaggedStructFieldMapping) MapStructField(field reflect.StructField) (ta } if column == "" || column == m.Ignore { - return "", "", 0, false + return "", 0, false } - return table, column, flags, true + return column, flags, true } func (n TaggedStructFieldMapping) String() string { diff --git a/structfieldmapping_test.go b/structfieldmapping_test.go index 4fcece5..6b0b7c8 100644 --- a/structfieldmapping_test.go +++ b/structfieldmapping_test.go @@ -37,10 +37,10 @@ func TestTaggedStructFieldMapping_StructFieldName(t *testing.T) { } type AnonymousEmbedded struct{} var s struct { - Index int `db:"index,pk=public.my_table"` // Field(0) - IndexB int `db:"index_b,pk"` // Field(1) - Str string `db:"named_str"` // Field(2) - ReadOnly bool `db:"read_only,readonly"` // Field(3) + Index int `db:"index,pk"` // Field(0) + IndexB int `db:"index_b,pk"` // Field(1) + Str string `db:"named_str"` // Field(2) + ReadOnly bool `db:"read_only,readonly"` // Field(3) UntaggedField bool // Field(4) Ignore bool `db:"-"` // Field(5) PKReadOnly int `db:"pk_read_only,pk,readonly"` // Field(6) @@ -53,13 +53,12 @@ func TestTaggedStructFieldMapping_StructFieldName(t *testing.T) { tests := []struct { name string structField reflect.StructField - wantTable string wantColumn string wantFlags FieldFlag wantOk bool }{ - {name: "index", structField: st.Field(0), wantTable: "public.my_table", wantColumn: "index", wantFlags: FieldFlagPrimaryKey, wantOk: true}, - {name: "index_b", structField: st.Field(1), wantTable: "", wantColumn: "index_b", wantFlags: FieldFlagPrimaryKey, wantOk: true}, + {name: "index", structField: st.Field(0), wantColumn: "index", wantFlags: FieldFlagPrimaryKey, wantOk: true}, + {name: "index_b", structField: st.Field(1), wantColumn: "index_b", wantFlags: FieldFlagPrimaryKey, wantOk: true}, {name: "named_str", structField: st.Field(2), wantColumn: "named_str", wantFlags: 0, wantOk: true}, {name: "read_only", structField: st.Field(3), wantColumn: "read_only", wantFlags: FieldFlagReadOnly, wantOk: true}, {name: "untagged_field", structField: st.Field(4), wantColumn: "untagged_field", wantFlags: 0, wantOk: true}, @@ -71,10 +70,7 @@ func TestTaggedStructFieldMapping_StructFieldName(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotTable, gotColumn, gotFlags, gotOk := naming.MapStructField(tt.structField) - if gotTable != tt.wantTable { - t.Errorf("TaggedStructFieldMapping.MapStructField(%q) gotTable = %q, want %q", tt.structField.Name, gotTable, tt.wantTable) - } + gotColumn, gotFlags, gotOk := naming.MapStructField(tt.structField) if gotColumn != tt.wantColumn { t.Errorf("TaggedStructFieldMapping.MapStructField(%q) gotColumn = %q, want %q", tt.structField.Name, gotColumn, tt.wantColumn) } From b46366e33acf2689107e43d18022f645719dca6a Mon Sep 17 00:00:00 2001 From: Erik Unger Date: Fri, 20 Dec 2024 20:24:47 +0100 Subject: [PATCH 010/110] refactor --- connection.go | 24 ++--- connectionwitherror.go | 16 +--- db/errors.go | 25 +++-- {impl => db}/foreachrow.go | 8 +- {impl => db}/foreachrow_test.go | 2 +- db/insert.go | 45 +++++---- db/multirowscanner.go | 164 ++++++++++++++++++++++++++++++++ db/query.go | 56 ++++++----- {impl => db}/reflectstruct.go | 13 +-- db/rowscanner.go | 130 +++++++++++++++++++++++++ db/scanresult.go | 54 +++++++++++ {impl => db}/scanstruct.go | 7 +- {impl => db}/scanstruct_test.go | 2 +- db/transaction_test.go | 2 + db/update.go | 53 +++++------ db/upsert.go | 3 +- db/utils.go | 3 +- errors.go | 74 -------------- examples/user_demo/user_demo.go | 2 +- impl/arrays.go | 12 +++ impl/connection.go | 46 +++++---- impl/insert.go | 2 + impl/now.go | 15 --- impl/rows.go | 41 -------- impl/rowscanner.go | 129 ------------------------- impl/rowsscanner.go | 154 ------------------------------ impl/scanresult.go | 54 ----------- impl/transaction.go | 46 +++++---- information/primarykeys.go | 17 ++-- information/table.go | 22 ++--- mockconn/connection.go | 58 +++++------ mockconn/connection_test.go | 20 ++-- mockconn/onetimerowsprovider.go | 32 +++---- mockconn/row.go | 6 +- mockconn/row_test.go | 2 +- mockconn/rows.go | 2 +- mockconn/rows_test.go | 2 +- mockconn/rowsprovider.go | 6 +- mockconn/singlerowprovider.go | 50 +++++----- pqconn/connection.go | 54 +++++------ pqconn/transaction.go | 54 +++++------ impl/row.go => row.go | 5 +- rows.go | 14 +++ rowscanner.go | 2 + rowsscanner.go | 2 + structfieldmapping.go | 32 +++---- structfieldmapping_test.go | 10 +- 47 files changed, 741 insertions(+), 831 deletions(-) rename {impl => db}/foreachrow.go (96%) rename {impl => db}/foreachrow_test.go (98%) create mode 100644 db/multirowscanner.go rename {impl => db}/reflectstruct.go (88%) create mode 100644 db/rowscanner.go create mode 100644 db/scanresult.go rename {impl => db}/scanstruct.go (79%) rename {impl => db}/scanstruct_test.go (99%) delete mode 100644 impl/now.go delete mode 100644 impl/rows.go delete mode 100644 impl/rowscanner.go delete mode 100644 impl/rowsscanner.go delete mode 100644 impl/scanresult.go rename impl/row.go => row.go (73%) diff --git a/connection.go b/connection.go index 963f31f..16693a5 100644 --- a/connection.go +++ b/connection.go @@ -26,20 +26,17 @@ type PlaceholderFormatter interface { type Connection interface { PlaceholderFormatter - // Context that all connection operations use. - // See also WithContext. + // TODO remove Context() context.Context - // WithContext returns a connection that uses the passed - // context for its operations. + // TODO remove WithContext(ctx context.Context) Connection - // WithStructFieldMapper returns a copy of the connection - // that will use the passed StructFieldMapper. - WithStructFieldMapper(StructFieldMapper) Connection + // TODO remove + WithStructFieldMapper(StructReflector) Connection - // StructFieldMapper used by methods of this Connection. - StructFieldMapper() StructFieldMapper + // TODO remove + StructReflector() StructReflector // Ping returns an error if the database // does not answer on this connection @@ -64,13 +61,8 @@ type Connection interface { Exec(query string, args ...any) error // Query queries rows with optional args. - Query(query string, args ...any) (Rows, error) - - // QueryRow queries a single row and returns a RowScanner for the results. - QueryRow(query string, args ...any) RowScanner - - // QueryRows queries multiple rows and returns a RowsScanner for the results. - QueryRows(query string, args ...any) RowsScanner + // Any error will be returned by the Rows.Err method. + Query(query string, args ...any) Rows // TransactionInfo returns the number and sql.TxOptions // of the connection's transaction, diff --git a/connectionwitherror.go b/connectionwitherror.go index f7eb7da..766c011 100644 --- a/connectionwitherror.go +++ b/connectionwitherror.go @@ -27,11 +27,11 @@ func (e connectionWithError) WithContext(ctx context.Context) Connection { return connectionWithError{ctx: ctx, err: e.err} } -func (e connectionWithError) WithStructFieldMapper(namer StructFieldMapper) Connection { +func (e connectionWithError) WithStructFieldMapper(namer StructReflector) Connection { return e } -func (e connectionWithError) StructFieldMapper() StructFieldMapper { +func (e connectionWithError) StructReflector() StructReflector { return DefaultStructFieldMapping } @@ -59,16 +59,8 @@ func (e connectionWithError) Exec(query string, args ...any) error { return e.err } -func (e connectionWithError) Query(query string, args ...any) (Rows, error) { - return nil, e.err -} - -func (e connectionWithError) QueryRow(query string, args ...any) RowScanner { - return RowScannerWithError(e.err) -} - -func (e connectionWithError) QueryRows(query string, args ...any) RowsScanner { - return RowsScannerWithError(e.err) +func (e connectionWithError) Query(query string, args ...any) Rows { + return RowsErr(e.err) } func (ce connectionWithError) TransactionInfo() (no uint64, opts *sql.TxOptions) { diff --git a/db/errors.go b/db/errors.go index 7b761cf..837023c 100644 --- a/db/errors.go +++ b/db/errors.go @@ -1,27 +1,24 @@ package db import ( + "errors" "fmt" "github.com/domonda/go-sqldb" "github.com/domonda/go-sqldb/impl" ) -// // WrapNonNilErrorWithQuery wraps non nil errors with a formatted query -// // if the error was not already wrapped with a query. -// // If the passed error is nil, then nil will be returned. -// func WrapNonNilErrorWithQuery(err error, query string, args []any, argFmt sqldb.PlaceholderFormatter) error { -// if err == nil { -// return nil -// } -// var wrapped errWithQuery -// if errors.As(err, &wrapped) { -// return err // already wrapped -// } -// return errWithQuery{err, query, args, argFmt} -// } - +// wrapErrorWithQuery wraps an errors with a formatted query +// if the error was not already wrapped with a query. +// If the passed error is nil, then nil will be returned. func wrapErrorWithQuery(err error, query string, args []any, argFmt sqldb.PlaceholderFormatter) error { + if err == nil { + return nil + } + var wrapped errWithQuery + if errors.As(err, &wrapped) { + return err // already wrapped + } return errWithQuery{err, query, args, argFmt} } diff --git a/impl/foreachrow.go b/db/foreachrow.go similarity index 96% rename from impl/foreachrow.go rename to db/foreachrow.go index ccaaf24..14f25b9 100644 --- a/impl/foreachrow.go +++ b/db/foreachrow.go @@ -1,4 +1,4 @@ -package impl +package db import ( "context" @@ -7,8 +7,6 @@ import ( "fmt" "reflect" "time" - - sqldb "github.com/domonda/go-sqldb" ) var ( @@ -31,7 +29,7 @@ var ( // If a non nil error is returned from the callback, then this error // is returned immediately by this function without scanning further rows. // In case of zero rows, no error will be returned. -func ForEachRowCallFunc(ctx context.Context, callback any) (f func(sqldb.RowScanner) error, err error) { +func ForEachRowCallFunc(ctx context.Context, callback any) (f func(*RowScanner) error, err error) { val := reflect.ValueOf(callback) typ := val.Type() if typ.Kind() != reflect.Func { @@ -76,7 +74,7 @@ func ForEachRowCallFunc(ctx context.Context, callback any) (f func(sqldb.RowScan return nil, fmt.Errorf("ForEachRowCall callback function result must be of type error: %s", typ) } - f = func(row sqldb.RowScanner) (err error) { + f = func(row *RowScanner) (err error) { // First scan row scannedValPtrs := make([]any, typ.NumIn()-firstArg) for i := range scannedValPtrs { diff --git a/impl/foreachrow_test.go b/db/foreachrow_test.go similarity index 98% rename from impl/foreachrow_test.go rename to db/foreachrow_test.go index 7509553..e3491ec 100644 --- a/impl/foreachrow_test.go +++ b/db/foreachrow_test.go @@ -1,4 +1,4 @@ -package impl +package db import ( "testing" diff --git a/db/insert.go b/db/insert.go index 706ac75..8179b57 100644 --- a/db/insert.go +++ b/db/insert.go @@ -7,7 +7,6 @@ import ( "strings" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" ) // Insert a new row into table using the values. @@ -46,7 +45,7 @@ func InsertUnique(ctx context.Context, table string, values sqldb.Values, onConf writeInsertQuery(&query, table, names, conn) fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) - err = conn.QueryRow(query.String(), vals...).Scan(&inserted) + inserted, err = QueryValue[bool](ctx, query.String(), vals...) err = sqldb.ReplaceErrNoRows(err, nil) if err != nil { return false, wrapErrorWithQuery(err, query.String(), vals, conn) @@ -54,28 +53,28 @@ func InsertUnique(ctx context.Context, table string, values sqldb.Values, onConf return inserted, err } -// InsertReturning inserts a new row into table using values -// and returns values from the inserted row listed in returning. -func InsertReturning(ctx context.Context, table string, values sqldb.Values, returning string) sqldb.RowScanner { - if len(values) == 0 { - return sqldb.RowScannerWithError(fmt.Errorf("InsertReturning into table %s: no values", table)) - } - conn := Conn(ctx) - - var query strings.Builder - names, vals := values.Sorted() - writeInsertQuery(&query, table, names, conn) - query.WriteString(" RETURNING ") - query.WriteString(returning) - return conn.QueryRow(query.String(), vals...) // TODO wrap error with query -} +// // InsertReturning inserts a new row into table using values +// // and returns values from the inserted row listed in returning. +// func InsertReturning(ctx context.Context, table string, values sqldb.Values, returning string) sqldb.RowScanner { +// if len(values) == 0 { +// return sqldb.RowScannerWithError(fmt.Errorf("InsertReturning into table %s: no values", table)) +// } +// conn := Conn(ctx) + +// var query strings.Builder +// names, vals := values.Sorted() +// writeInsertQuery(&query, table, names, conn) +// query.WriteString(" RETURNING ") +// query.WriteString(returning) +// return conn.QueryRow(query.String(), vals...) // TODO wrap error with query +// } // InsertStruct inserts a new row into table using the connection's // StructFieldMapper to map struct fields to column names. // Optional ColumnFilter can be passed to ignore mapped columns. func InsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumns ...sqldb.ColumnFilter) error { conn := Conn(ctx) - columns, vals, err := insertStructValues(table, rowStruct, conn.StructFieldMapper(), ignoreColumns) + columns, vals, err := insertStructValues(table, rowStruct, conn.StructReflector(), ignoreColumns) if err != nil { return err } @@ -94,7 +93,7 @@ func InsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumn // StructFieldMapper to map struct fields to column names. // Optional ColumnFilter can be passed to ignore mapped columns. func InsertStructWithTableName(ctx context.Context, row sqldb.StructWithTableName, ignoreColumns ...sqldb.ColumnFilter) error { - table, err := Conn(ctx).StructFieldMapper().TableNameForStruct(reflect.TypeOf(row)) + table, err := Conn(ctx).StructReflector().TableNameForStruct(reflect.TypeOf(row)) if err != nil { return err } @@ -108,7 +107,7 @@ func InsertStructWithTableName(ctx context.Context, row sqldb.StructWithTableNam // and returns if a row was inserted. func InsertUniqueStruct(ctx context.Context, table string, rowStruct any, onConflict string, ignoreColumns ...sqldb.ColumnFilter) (inserted bool, err error) { conn := Conn(ctx) - columns, vals, err := insertStructValues(table, rowStruct, conn.StructFieldMapper(), ignoreColumns) + columns, vals, err := insertStructValues(table, rowStruct, conn.StructReflector(), ignoreColumns) if err != nil { return false, err } @@ -121,7 +120,7 @@ func InsertUniqueStruct(ctx context.Context, table string, rowStruct any, onConf writeInsertQuery(&query, table, columns, conn) fmt.Fprintf(&query, " ON CONFLICT (%s) DO NOTHING RETURNING TRUE", onConflict) - err = conn.QueryRow(query.String(), vals...).Scan(&inserted) + inserted, err = QueryValue[bool](ctx, query.String(), vals...) err = sqldb.ReplaceErrNoRows(err, nil) if err != nil { return false, wrapErrorWithQuery(err, query.String(), vals, conn) @@ -173,7 +172,7 @@ func writeInsertQuery(w *strings.Builder, table string, names []string, format s w.WriteByte(')') } -func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) { +func insertStructValues(table string, rowStruct any, namer sqldb.StructReflector, ignoreColumns []sqldb.ColumnFilter) (columns []string, vals []any, err error) { v := reflect.ValueOf(rowStruct) for v.Kind() == reflect.Ptr && !v.IsNil() { v = v.Elem() @@ -185,6 +184,6 @@ func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapp return nil, nil, fmt.Errorf("InsertStruct into table %s: expected struct but got %T", table, rowStruct) } - columns, _, vals = impl.ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) + columns, _, vals = ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) return columns, vals, nil } diff --git a/db/multirowscanner.go b/db/multirowscanner.go new file mode 100644 index 0000000..3b2989e --- /dev/null +++ b/db/multirowscanner.go @@ -0,0 +1,164 @@ +package db + +import ( + "context" + "errors" + "fmt" + "reflect" + + "github.com/domonda/go-sqldb" +) + +// ScanRowsAsSlice scans all srcRows as slice into dest. +// +// The sqlRows must either have only one column compatible with the element type of the slice, +// or in case of multiple columns the slice element type must be a struct or struct pointer +// so that every column maps on exactly one struct field using the passed reflector. +// +// In case of single column rows, nil must be passed for reflector. +// +// The function closes the sqlRows. +// +// TODO two different functions for single column and multi column rows? +func ScanRowsAsSlice(ctx context.Context, sqlRows sqldb.Rows, reflector sqldb.StructReflector, dest any) error { + defer sqlRows.Close() + + destVal := reflect.ValueOf(dest) + if destVal.Kind() != reflect.Ptr { + return fmt.Errorf("scan dest is not a pointer but %s", destVal.Type()) + } + if destVal.IsNil() { + return errors.New("scan dest is nil") + } + slice := destVal.Elem() + if slice.Kind() != reflect.Slice { + return fmt.Errorf("scan dest is not pointer to slice but %s", destVal.Type()) + } + sliceElemType := slice.Type().Elem() + + newSlice := reflect.MakeSlice(slice.Type(), 0, 32) + + for sqlRows.Next() { + if ctx.Err() != nil { + return ctx.Err() + } + + newSlice = reflect.Append(newSlice, reflect.Zero(sliceElemType)) + target := newSlice.Index(newSlice.Len() - 1).Addr() + if reflector != nil { + err := scanStruct(sqlRows, reflector, target.Interface()) + if err != nil { + return err + } + } else { + err := sqlRows.Scan(target.Interface()) + if err != nil { + return err + } + } + } + if sqlRows.Err() != nil { + return sqlRows.Err() + } + + // Assign newSlice if there were no errors + if newSlice.Len() == 0 { + slice.SetLen(0) + } else { + slice.Set(newSlice) + } + + return nil +} + +/* +// MultiRowScanner +type MultiRowScanner struct { + ctx context.Context // ctx is checked for every row and passed through to callbacks + rows sqldb.Rows + reflector sqldb.StructReflector + argFmt sqldb.PlaceholderFormatter // for error wrapping + query string // for error wrapping + args []any // for error wrapping +} + +func NewMultiRowScanner(ctx context.Context, rows sqldb.Rows, reflector sqldb.StructReflector, argFmt sqldb.PlaceholderFormatter, query string, args []any) *MultiRowScanner { + return &MultiRowScanner{ctx, rows, reflector, argFmt, query, args} +} + +func (s *MultiRowScanner) Columns() ([]string, error) { + cols, err := s.rows.Columns() + if err != nil { + return nil, wrapErrorWithQuery(err, s.query, s.args, s.argFmt) + } + return cols, nil +} + +func (s *MultiRowScanner) ScanSlice(dest any) error { + err := ScanRowsAsSlice(s.ctx, s.rows, dest, nil) + if err != nil { + return wrapErrorWithQuery(err, s.query, s.args, s.argFmt) + } + return nil +} + +// TODO is ScanStructSlice needed besides ScanSlice? +func (s *MultiRowScanner) ScanStructSlice(dest any) error { + err := ScanRowsAsSlice(s.ctx, s.rows, dest, s.reflector) + if err != nil { + return wrapErrorWithQuery(err, s.query, s.args, s.argFmt) + } + return nil +} + +// func (s *MultiRowScanner) ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) { +// cols, err := s.rows.Columns() +// if err != nil { +// return nil, err +// } +// if headerRow { +// rows = [][]string{cols} +// } +// stringScannablePtrs := make([]any, len(cols)) +// err = s.ForEachRow(func(rowScanner sqldb.RowScanner) error { +// row := make([]string, len(cols)) +// for i := range stringScannablePtrs { +// stringScannablePtrs[i] = (*sqldb.StringScannable)(&row[i]) +// } +// err := rowScanner.Scan(stringScannablePtrs...) +// if err != nil { +// return err +// } +// rows = append(rows, row) +// return nil +// }) +// return rows, err +// } + +// func (s *MultiRowScanner) ForEachRow(callback func(*RowScanner) error) (err error) { +// defer func() { +// err = errors.Join(err, s.rows.Close()) +// err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) +// }() + +// for s.rows.Next() { +// if s.ctx.Err() != nil { +// return s.ctx.Err() +// } + +// err := callback(CurrentRowScanner{s.rows, s.reflector}) +// if err != nil { +// return err +// } +// } +// return s.rows.Err() +// } + +// func (s *MultiRowScanner) ForEachRowCall(callback any) error { +// forEachRowFunc, err := ForEachRowCallFunc(s.ctx, callback) +// if err != nil { +// return err +// } +// return s.ForEachRow(forEachRowFunc) +// } +*/ diff --git a/db/query.go b/db/query.go index e95cd64..25de083 100644 --- a/db/query.go +++ b/db/query.go @@ -21,7 +21,7 @@ import ( // Useful for getting the timestamp of a // SQL transaction for use in Go code. func CurrentTimestamp(ctx context.Context) time.Time { - t, err := QueryValue[time.Time](ctx, "SELECT CURRENT_TIMESTAMP") + t, err := QueryValue[time.Time](ctx, `SELECT CURRENT_TIMESTAMP`) if err != nil { return time.Now() } @@ -30,22 +30,31 @@ func CurrentTimestamp(ctx context.Context) time.Time { // Exec executes a query with optional args. func Exec(ctx context.Context, query string, args ...any) error { - return Conn(ctx).Exec(query, args...) + conn := Conn(ctx) + err := conn.Exec(query, args...) + if err != nil { + return wrapErrorWithQuery(err, query, args, conn) + } + return nil } // QueryRow queries a single row and returns a RowScanner for the results. -func QueryRow(ctx context.Context, query string, args ...any) sqldb.RowScanner { - return Conn(ctx).QueryRow(query, args...) +func QueryRow(ctx context.Context, query string, args ...any) *RowScanner { + conn := Conn(ctx) + rows := conn.Query(query, args...) + return NewRowScanner(rows, conn.StructReflector(), conn, query, args) } -// QueryRows queries multiple rows and returns a RowsScanner for the results. -func QueryRows(ctx context.Context, query string, args ...any) sqldb.RowsScanner { - return Conn(ctx).QueryRows(query, args...) -} +// // QueryRows queries multiple rows and returns a RowsScanner for the results. +// func QueryRows(ctx context.Context, query string, args ...any) *MultiRowScanner { +// conn := Conn(ctx) +// rows := conn.Query(query, args...) +// return NewMultiRowScanner(ctx, rows, conn.StructReflector(), conn, query, args) +// } // QueryValue queries a single value of type T. func QueryValue[T any](ctx context.Context, query string, args ...any) (value T, err error) { - err = Conn(ctx).QueryRow(query, args...).Scan(&value) + err = QueryRow(ctx, query, args...).Scan(&value) if err != nil { return *new(T), err } @@ -56,7 +65,7 @@ func QueryValue[T any](ctx context.Context, query string, args ...any) (value T, // In case of an sql.ErrNoRows error, errNoRows will be called // and its result returned together with the default value for T. func QueryValueReplaceErrNoRows[T any](ctx context.Context, errNoRows func() error, query string, args ...any) (value T, err error) { - err = Conn(ctx).QueryRow(query, args...).Scan(&value) + err = QueryRow(ctx, query, args...).Scan(&value) if err != nil { if errors.Is(err, sql.ErrNoRows) && errNoRows != nil { return *new(T), errNoRows() @@ -69,7 +78,7 @@ func QueryValueReplaceErrNoRows[T any](ctx context.Context, errNoRows func() err // QueryValueOr queries a single value of type T // or returns the passed defaultValue in case of sql.ErrNoRows. func QueryValueOr[T any](ctx context.Context, defaultValue T, query string, args ...any) (value T, err error) { - err = Conn(ctx).QueryRow(query, args...).Scan(&value) + err = QueryRow(ctx, query, args...).Scan(&value) if err != nil { if errors.Is(err, sql.ErrNoRows) { return defaultValue, nil @@ -81,9 +90,12 @@ func QueryValueOr[T any](ctx context.Context, defaultValue T, query string, args // QueryRowStruct queries a row and scans it as struct. func QueryRowStruct[S any](ctx context.Context, query string, args ...any) (row *S, err error) { - err = Conn(ctx).QueryRow(query, args...).ScanStruct(&row) + conn := Conn(ctx) + rows := conn.Query(query, args...) + defer rows.Close() + err = scanStruct(rows, conn.StructReflector(), &row) if err != nil { - return nil, err + return nil, wrapErrorWithQuery(err, query, args, conn) } return row, nil } @@ -92,7 +104,7 @@ func QueryRowStruct[S any](ctx context.Context, query string, args ...any) (row // In case of an sql.ErrNoRows error, errNoRows will be called // and its result returned as error together with nil as row. func QueryRowStructReplaceErrNoRows[S any](ctx context.Context, errNoRows func() error, query string, args ...any) (row *S, err error) { - err = Conn(ctx).QueryRow(query, args...).ScanStruct(&row) + row, err = QueryRowStruct[S](ctx, query, args...) if err != nil { if errors.Is(err, sql.ErrNoRows) && errNoRows != nil { return nil, errNoRows() @@ -105,7 +117,7 @@ func QueryRowStructReplaceErrNoRows[S any](ctx context.Context, errNoRows func() // QueryRowStructOrNil queries a row and scans it as struct // or returns nil in case of sql.ErrNoRows. func QueryRowStructOrNil[S any](ctx context.Context, query string, args ...any) (row *S, err error) { - err = Conn(ctx).QueryRow(query, args...).ScanStruct(&row) + row, err = QueryRowStruct[S](ctx, query, args...) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -128,7 +140,7 @@ func GetRow[S sqldb.StructWithTableName](ctx context.Context, pkValue any, pkVal return nil, fmt.Errorf("expected struct template type instead of %s", t) } conn := Conn(ctx) - table, err := conn.StructFieldMapper().TableNameForStruct(t) + table, err := conn.StructReflector().TableNameForStruct(t) if err != nil { return nil, err } @@ -144,11 +156,7 @@ func GetRow[S sqldb.StructWithTableName](ctx context.Context, pkValue any, pkVal for i := 1; i < len(pkColumns); i++ { fmt.Fprintf(&query, ` AND "%s" = $%d`, pkColumns[i], i+1) //#nosec G104 } - err = conn.QueryRow(query.String(), pkValues...).ScanStruct(&row) - if err != nil { - return nil, err - } - return row, nil + return QueryRowStruct[S](ctx, query.String(), pkValues...) } // GetRowOrNil uses the passed pkValue+pkValues to query a table row @@ -169,7 +177,7 @@ func GetRowOrNil[S sqldb.StructWithTableName](ctx context.Context, pkValue any, } func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (columns []string, err error) { - mapper := conn.StructFieldMapper() + mapper := conn.StructReflector() for i := 0; i < t.NumField(); i++ { field := t.Field(i) column, flags, ok := mapper.MapStructField(field) @@ -196,7 +204,9 @@ func pkColumnsOfStruct(conn sqldb.Connection, t reflect.Type) (columns []string, // QueryStructSlice returns queried rows as slice of the generic type S // which must be a struct or a pointer to a struct. func QueryStructSlice[S any](ctx context.Context, query string, args ...any) (rows []S, err error) { - err = Conn(ctx).QueryRows(query, args...).ScanStructSlice(&rows) + conn := Conn(ctx) + sqlRows := conn.Query(query, args...) + err = ScanRowsAsSlice(ctx, sqlRows, conn.StructReflector(), &rows) if err != nil { return nil, err } diff --git a/impl/reflectstruct.go b/db/reflectstruct.go similarity index 88% rename from impl/reflectstruct.go rename to db/reflectstruct.go index d1ca3d0..0d2c477 100644 --- a/impl/reflectstruct.go +++ b/db/reflectstruct.go @@ -1,4 +1,4 @@ -package impl +package db import ( "errors" @@ -8,9 +8,10 @@ import ( "strings" "github.com/domonda/go-sqldb" + "github.com/domonda/go-sqldb/impl" ) -func ReflectStructValues(structVal reflect.Value, namer sqldb.StructFieldMapper, ignoreColumns []sqldb.ColumnFilter) (columns []string, pkCols []int, values []any) { +func ReflectStructValues(structVal reflect.Value, namer sqldb.StructReflector, ignoreColumns []sqldb.ColumnFilter) (columns []string, pkCols []int, values []any) { for i := 0; i < structVal.NumField(); i++ { fieldType := structVal.Type().Field(i) column, flags, use := namer.MapStructField(fieldType) @@ -43,7 +44,7 @@ func ReflectStructValues(structVal reflect.Value, namer sqldb.StructFieldMapper, return columns, pkCols, values } -func ReflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string) (pointers []any, err error) { +func ReflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructReflector, columns []string) (pointers []any, err error) { if len(columns) == 0 { return nil, errors.New("no columns") } @@ -71,7 +72,7 @@ func ReflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFiel return pointers, nil } -func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string, pointers []any) error { +func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructReflector, columns []string, pointers []any) error { structType := structVal.Type() for i := 0; i < structType.NumField(); i++ { field := structType.Field(i) @@ -103,8 +104,8 @@ func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFiel // If field is a slice or array that does not implement sql.Scanner // and it's not a string scannable []byte type underneath // then wrap it with WrapForArray to make it scannable - if NeedsArrayWrappingForScanning(fieldValue) { - pointer = WrapArray(pointer) + if impl.NeedsArrayWrappingForScanning(fieldValue) { + pointer = impl.WrapArray(pointer) } pointers[colIndex] = pointer } diff --git a/db/rowscanner.go b/db/rowscanner.go new file mode 100644 index 0000000..665fd73 --- /dev/null +++ b/db/rowscanner.go @@ -0,0 +1,130 @@ +package db + +import ( + "database/sql" + "errors" + + sqldb "github.com/domonda/go-sqldb" +) + +// RowScanner implements sqldb.RowScanner for a sql.Row +type RowScanner struct { + rows sqldb.Rows + reflector sqldb.StructReflector // for ScanStruct + argFmt sqldb.PlaceholderFormatter // for error wrapping + query string // for error wrapping + args []any // for error wrapping +} + +func NewRowScanner(rows sqldb.Rows, reflector sqldb.StructReflector, argFmt sqldb.PlaceholderFormatter, query string, args []any) *RowScanner { + return &RowScanner{rows, reflector, argFmt, query, args} +} + +func (s *RowScanner) Columns() ([]string, error) { + cols, err := s.rows.Columns() + if err != nil { + return nil, wrapErrorWithQuery(err, s.query, s.args, s.argFmt) + } + return cols, nil +} + +func (s *RowScanner) Scan(dest ...any) (err error) { + defer func() { + err = errors.Join(err, s.rows.Close()) + if err != nil { + err = wrapErrorWithQuery(err, s.query, s.args, s.argFmt) + } + }() + + if len(dest) == 0 { + return errors.New("RowScanner.Scan called with no destination arguments") + } + // Check if there was an error even before preparing the row with Next() + if s.rows.Err() != nil { + return s.rows.Err() + } + if !s.rows.Next() { + // Error during preparing the row with Next() + if s.rows.Err() != nil { + return s.rows.Err() + } + return sql.ErrNoRows + } + + return s.rows.Scan(dest...) +} + +// TODO integrate ScanStruct into Scan +func (s *RowScanner) ScanStruct(dest any) (err error) { + defer func() { + err = errors.Join(err, s.rows.Close()) + if err != nil { + err = wrapErrorWithQuery(err, s.query, s.args, s.argFmt) + } + }() + + // Check if there was an error even before preparing the row with Next() + if s.rows.Err() != nil { + return s.rows.Err() + } + if !s.rows.Next() { + // Error during preparing the row with Next() + if s.rows.Err() != nil { + return s.rows.Err() + } + return sql.ErrNoRows + } + + return scanStruct(s.rows, s.reflector, dest) +} + +// ScanValues returns the values of a row exactly how they are +// passed from the database driver to an `sql.Scanner`. +// Byte slices will be copied. +func (s *RowScanner) ScanValues() (vals []any, err error) { + cols, err := s.Columns() + if err != nil { + return nil, err + } + var ( + anys = make([]sqldb.AnyValue, len(cols)) + result = make([]any, len(cols)) + ) + // result elements hold pointer to sqldb.AnyValue for scanning + for i := range result { + result[i] = &anys[i] + } + err = s.Scan(result...) + if err != nil { + return nil, err + } + // don't return pointers to sqldb.AnyValue + // but what internal value has been scanned + for i := range result { + result[i] = anys[i].Val + } + return result, nil +} + +// ScanStrings scans the values of a row as strings. +// Byte slices will be interpreted as strings, +// nil (SQL NULL) will be converted to an empty string, +// all other types are converted with `fmt.Sprint`. +func (s *RowScanner) ScanStrings() (vals []string, err error) { + cols, err := s.Columns() + if err != nil { + return nil, err + } + var ( + result = make([]string, len(cols)) + resultPtrs = make([]any, len(cols)) + ) + for i := range resultPtrs { + resultPtrs[i] = (*sqldb.StringScannable)(&result[i]) + } + err = s.Scan(resultPtrs...) + if err != nil { + return nil, err + } + return result, nil +} diff --git a/db/scanresult.go b/db/scanresult.go new file mode 100644 index 0000000..6da0d91 --- /dev/null +++ b/db/scanresult.go @@ -0,0 +1,54 @@ +package db + +// TODO move to RowScanner ? + +// // scanValues returns the values of a row exactly how they are +// // passed from the database driver to an sql.Scanner. +// // Byte slices will be copied. +// func scanValues(src sqldb.Rows) ([]any, error) { +// cols, err := src.Columns() +// if err != nil { +// return nil, err +// } +// var ( +// anys = make([]sqldb.AnyValue, len(cols)) +// result = make([]any, len(cols)) +// ) +// // result elements hold pointer to sqldb.AnyValue for scanning +// for i := range result { +// result[i] = &anys[i] +// } +// err = src.Scan(result...) +// if err != nil { +// return nil, err +// } +// // don't return pointers to sqldb.AnyValue +// // but what internal value has been scanned +// for i := range result { +// result[i] = anys[i].Val +// } +// return result, nil +// } + +// // scanStrings scans the values of a row as strings. +// // Byte slices will be interpreted as strings, +// // nil (SQL NULL) will be converted to an empty string, +// // all other types are converted with fmt.Sprint. +// func scanStrings(src sqldb.Rows) ([]string, error) { +// cols, err := src.Columns() +// if err != nil { +// return nil, err +// } +// var ( +// result = make([]string, len(cols)) +// resultPtrs = make([]any, len(cols)) +// ) +// for i := range resultPtrs { +// resultPtrs[i] = (*sqldb.StringScannable)(&result[i]) +// } +// err = src.Scan(resultPtrs...) +// if err != nil { +// return nil, err +// } +// return result, nil +// } diff --git a/impl/scanstruct.go b/db/scanstruct.go similarity index 79% rename from impl/scanstruct.go rename to db/scanstruct.go index ce2dd56..4158830 100644 --- a/impl/scanstruct.go +++ b/db/scanstruct.go @@ -1,4 +1,4 @@ -package impl +package db import ( "fmt" @@ -7,7 +7,8 @@ import ( sqldb "github.com/domonda/go-sqldb" ) -func ScanStruct(srcRow Row, destStruct any, namer sqldb.StructFieldMapper) error { +// scanStruct scans the srcRow into the destStruct using the reflector. +func scanStruct(srcRow sqldb.Row, reflector sqldb.StructReflector, destStruct any) error { v := reflect.ValueOf(destStruct) for v.Kind() == reflect.Ptr && !v.IsNil() { v = v.Elem() @@ -35,7 +36,7 @@ func ScanStruct(srcRow Row, destStruct any, namer sqldb.StructFieldMapper) error return err } - fieldPointers, err := ReflectStructColumnPointers(v, namer, columns) + fieldPointers, err := ReflectStructColumnPointers(v, reflector, columns) if err != nil { return fmt.Errorf("ScanStruct: %w", err) } diff --git a/impl/scanstruct_test.go b/db/scanstruct_test.go similarity index 99% rename from impl/scanstruct_test.go rename to db/scanstruct_test.go index 787e76e..f37b206 100644 --- a/impl/scanstruct_test.go +++ b/db/scanstruct_test.go @@ -1,4 +1,4 @@ -package impl +package db // func TestGetStructFieldIndices(t *testing.T) { // type DeepEmbeddedStruct struct { diff --git a/db/transaction_test.go b/db/transaction_test.go index 5333ee0..32e133f 100644 --- a/db/transaction_test.go +++ b/db/transaction_test.go @@ -1,5 +1,6 @@ package db +/* import ( "context" "errors" @@ -116,3 +117,4 @@ func TestTransaction(t *testing.T) { }) } } +*/ diff --git a/db/update.go b/db/update.go index 84b3a21..b1a3b7a 100644 --- a/db/update.go +++ b/db/update.go @@ -8,7 +8,6 @@ import ( "strings" sqldb "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" ) // Update table rows(s) with values using the where statement with passed in args starting at $1. @@ -26,31 +25,31 @@ func Update(ctx context.Context, table string, values sqldb.Values, where string return nil } -// UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 -// and returning a single row with the columns specified in returning argument. -func UpdateReturningRow(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { - if len(values) == 0 { - return sqldb.RowScannerWithError(fmt.Errorf("UpdateReturningRow table %s: no values passed", table)) - } - conn := Conn(ctx) - - query, vals := buildUpdateQuery(table, values, where, args, conn) - query += " RETURNING " + returning - return conn.QueryRow(query, vals...) -} - -// UpdateReturningRows updates table rows with values using the where statement with passed in args starting at $1 -// and returning multiple rows with the columns specified in returning argument. -func UpdateReturningRows(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { - if len(values) == 0 { - return sqldb.RowsScannerWithError(fmt.Errorf("UpdateReturningRows table %s: no values passed", table)) - } - conn := Conn(ctx) - - query, vals := buildUpdateQuery(table, values, where, args, conn) - query += " RETURNING " + returning - return conn.QueryRows(query, vals...) -} +// // UpdateReturningRow updates a table row with values using the where statement with passed in args starting at $1 +// // and returning a single row with the columns specified in returning argument. +// func UpdateReturningRow(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowScanner { +// if len(values) == 0 { +// return sqldb.RowScannerWithError(fmt.Errorf("UpdateReturningRow table %s: no values passed", table)) +// } +// conn := Conn(ctx) + +// query, vals := buildUpdateQuery(table, values, where, args, conn) +// query += " RETURNING " + returning +// return conn.QueryRow(query, vals...) +// } + +// // UpdateReturningRows updates table rows with values using the where statement with passed in args starting at $1 +// // and returning multiple rows with the columns specified in returning argument. +// func UpdateReturningRows(ctx context.Context, table string, values sqldb.Values, returning, where string, args ...any) sqldb.RowsScanner { +// if len(values) == 0 { +// return sqldb.RowsScannerWithError(fmt.Errorf("UpdateReturningRows table %s: no values passed", table)) +// } +// conn := Conn(ctx) + +// query, vals := buildUpdateQuery(table, values, where, args, conn) +// query += " RETURNING " + returning +// return conn.QueryRows(query, vals...) +// } func buildUpdateQuery(table string, values sqldb.Values, where string, args []any, argFmt sqldb.PlaceholderFormatter) (string, []any) { names, vals := values.Sorted() @@ -88,7 +87,7 @@ func UpdateStruct(ctx context.Context, table string, rowStruct any, ignoreColumn conn := Conn(ctx) - columns, pkCols, vals := impl.ReflectStructValues(v, conn.StructFieldMapper(), append(ignoreColumns, sqldb.IgnoreReadOnly)) + columns, pkCols, vals := ReflectStructValues(v, conn.StructReflector(), append(ignoreColumns, sqldb.IgnoreReadOnly)) if len(pkCols) == 0 { return fmt.Errorf("UpdateStruct of table %s: %s has no mapped primary key field", table, v.Type()) } diff --git a/db/upsert.go b/db/upsert.go index bded8c2..6da0923 100644 --- a/db/upsert.go +++ b/db/upsert.go @@ -8,7 +8,6 @@ import ( "strings" "github.com/domonda/go-sqldb" - "github.com/domonda/go-sqldb/impl" ) // UpsertStruct upserts a row to table using the exported fields @@ -32,7 +31,7 @@ func UpsertStruct(ctx context.Context, table string, rowStruct any, ignoreColumn conn := Conn(ctx) - columns, pkCols, vals := impl.ReflectStructValues(v, conn.StructFieldMapper(), append(ignoreColumns, sqldb.IgnoreReadOnly)) + columns, pkCols, vals := ReflectStructValues(v, conn.StructReflector(), append(ignoreColumns, sqldb.IgnoreReadOnly)) if len(pkCols) == 0 { return fmt.Errorf("UpsertStruct of table %s: %s has no mapped primary key field", table, v.Type()) } diff --git a/db/utils.go b/db/utils.go index f5665fa..e4aabf3 100644 --- a/db/utils.go +++ b/db/utils.go @@ -38,8 +38,7 @@ func DebugPrintConn(ctx context.Context, args ...any) { args = append(args, "Isolation", optsStr) } } - var t time.Time - err := conn.QueryRow("SELECT CURRENT_TIMESTAMP").Scan(&t) + t, err := QueryValue[time.Time](ctx, "SELECT CURRENT_TIMESTAMP") if err == nil { args = append(args, "CURRENT_TIMESTAMP:", t) } else { diff --git a/errors.go b/errors.go index f106fa5..5b622b5 100644 --- a/errors.go +++ b/errors.go @@ -5,12 +5,6 @@ import ( "errors" ) -var ( - _ Connection = connectionWithError{} - _ RowScanner = rowScannerWithError{} - _ RowsScanner = rowsScannerWithError{} -) - // ReplaceErrNoRows returns the passed replacement error // if errors.Is(err, sql.ErrNoRows), // else the passed err is returned unchanged. @@ -163,71 +157,3 @@ func (e ErrExclusionViolation) Error() string { func (e ErrExclusionViolation) Unwrap() error { return ErrIntegrityConstraintViolation{Constraint: e.Constraint} } - -// RowScannerWithError - -// RowScannerWithError returns a dummy RowScanner -// where all methods return the passed error. -func RowScannerWithError(err error) RowScanner { - return rowScannerWithError{err} -} - -type rowScannerWithError struct { - err error -} - -func (e rowScannerWithError) Scan(dest ...any) error { - return e.err -} - -func (e rowScannerWithError) ScanStruct(dest any) error { - return e.err -} - -func (e rowScannerWithError) ScanValues() ([]any, error) { - return nil, e.err -} - -func (e rowScannerWithError) ScanStrings() ([]string, error) { - return nil, e.err -} - -func (e rowScannerWithError) Columns() ([]string, error) { - return nil, e.err -} - -// RowsScannerWithError - -// RowsScannerWithError returns a dummy RowsScanner -// where all methods return the passed error. -func RowsScannerWithError(err error) RowsScanner { - return rowsScannerWithError{err} -} - -type rowsScannerWithError struct { - err error -} - -func (e rowsScannerWithError) ScanSlice(dest any) error { - return e.err -} - -func (e rowsScannerWithError) ScanStructSlice(dest any) error { - return e.err -} - -func (e rowsScannerWithError) Columns() ([]string, error) { - return nil, e.err -} - -func (e rowsScannerWithError) ScanAllRowsAsStrings(headerRow bool) ([][]string, error) { - return nil, e.err -} - -func (e rowsScannerWithError) ForEachRow(callback func(RowScanner) error) error { - return e.err -} - -func (e rowsScannerWithError) ForEachRowCall(callback any) error { - return e.err -} diff --git a/examples/user_demo/user_demo.go b/examples/user_demo/user_demo.go index b543f2c..df4e4c4 100644 --- a/examples/user_demo/user_demo.go +++ b/examples/user_demo/user_demo.go @@ -45,7 +45,7 @@ func main() { panic(err) } - conn = conn.WithStructFieldMapper(&sqldb.TaggedStructFieldMapping{ + conn = conn.WithStructFieldMapper(&sqldb.TaggedStructReflector{ NameTag: "col", Ignore: "ignore", UntaggedNameFunc: sqldb.ToSnakeCase, diff --git a/impl/arrays.go b/impl/arrays.go index 3a2fcb1..af1c928 100644 --- a/impl/arrays.go +++ b/impl/arrays.go @@ -1,13 +1,25 @@ package impl import ( + "context" "database/sql" "database/sql/driver" "reflect" + "time" "github.com/lib/pq" ) +var ( + typeOfError = reflect.TypeFor[error]() + typeOfContext = reflect.TypeFor[context.Context]() + typeOfSQLScanner = reflect.TypeFor[sql.Scanner]() + typeOfDriverValuer = reflect.TypeFor[driver.Valuer]() + typeOfTime = reflect.TypeFor[time.Time]() + typeOfByte = reflect.TypeFor[byte]() + typeOfByteSlice = reflect.TypeFor[[]byte]() +) + type ValuerScanner interface { driver.Valuer sql.Scanner diff --git a/impl/connection.go b/impl/connection.go index a9c1698..c7553e0 100644 --- a/impl/connection.go +++ b/impl/connection.go @@ -29,7 +29,7 @@ type connection struct { ctx context.Context db *sql.DB config *sqldb.Config - structFieldNamer sqldb.StructFieldMapper + structFieldNamer sqldb.StructReflector argFmt string validateColumnName func(string) error } @@ -50,13 +50,13 @@ func (conn *connection) WithContext(ctx context.Context) sqldb.Connection { return c } -func (conn *connection) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { +func (conn *connection) WithStructFieldMapper(namer sqldb.StructReflector) sqldb.Connection { c := conn.clone() c.structFieldNamer = namer return c } -func (conn *connection) StructFieldMapper() sqldb.StructFieldMapper { +func (conn *connection) StructReflector() sqldb.StructReflector { return conn.structFieldNamer } @@ -91,27 +91,31 @@ func (conn *connection) Exec(query string, args ...any) error { return err } -func (conn *connection) Query(query string, args ...any) (sqldb.Rows, error) { - return conn.db.QueryContext(conn.ctx, query, args...) -} - -func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { - rows, err := conn.db.QueryContext(conn.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) - return sqldb.RowScannerWithError(err) - } - return NewRowScanner(rows, conn.structFieldNamer, query, conn.argFmt, args) -} - -func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { +func (conn *connection) Query(query string, args ...any) sqldb.Rows { rows, err := conn.db.QueryContext(conn.ctx, query, args...) if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) - return sqldb.RowsScannerWithError(err) + return sqldb.RowsErr(err) } - return NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, conn.argFmt, args) -} + return rows +} + +// func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner { +// rows, err := conn.db.QueryContext(conn.ctx, query, args...) +// if err != nil { +// err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) +// return sqldb.RowScannerWithError(err) +// } +// return NewRowScanner(rows, conn.structFieldNamer, query, conn.argFmt, args) +// } + +// func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner { +// rows, err := conn.db.QueryContext(conn.ctx, query, args...) +// if err != nil { +// err = WrapNonNilErrorWithQuery(err, query, conn.argFmt, args) +// return sqldb.RowsScannerWithError(err) +// } +// return NewRowsScanner(conn.ctx, rows, conn.structFieldNamer, query, conn.argFmt, args) +// } func (conn *connection) TransactionInfo() (no uint64, opts *sql.TxOptions) { return 0, nil diff --git a/impl/insert.go b/impl/insert.go index 8254b4f..3ca7bf4 100644 --- a/impl/insert.go +++ b/impl/insert.go @@ -1,5 +1,6 @@ package impl +/* import ( "fmt" "reflect" @@ -142,3 +143,4 @@ func insertStructValues(table string, rowStruct any, namer sqldb.StructFieldMapp columns, _, vals = ReflectStructValues(v, namer, append(ignoreColumns, sqldb.IgnoreReadOnly)) return columns, vals, nil } +*/ diff --git a/impl/now.go b/impl/now.go deleted file mode 100644 index 4a1bd2f..0000000 --- a/impl/now.go +++ /dev/null @@ -1,15 +0,0 @@ -package impl - -import ( - "time" - - "github.com/domonda/go-sqldb" -) - -func Now(conn sqldb.Connection) (now time.Time, err error) { - err = conn.QueryRow(`select now()`).Scan(&now) - if err != nil { - return time.Time{}, err - } - return now, nil -} diff --git a/impl/rows.go b/impl/rows.go deleted file mode 100644 index 84ab136..0000000 --- a/impl/rows.go +++ /dev/null @@ -1,41 +0,0 @@ -package impl - -// Rows is an interface with the methods of sql.Rows -// that are needed for ScanSlice. -// Allows mocking for tests without an SQL driver. -type Rows interface { - Row - - // Close closes the Rows, preventing further enumeration. If Next is called - // and returns false and there are no further result sets, - // the Rows are closed automatically and it will suffice to check the - // result of Err. Close is idempotent and does not affect the result of Err. - Close() error - - // Next prepares the next result row for reading with the Scan method. It - // returns true on success, or false if there is no next result row or an error - // happened while preparing it. Err should be consulted to distinguish between - // the two cases. - // - // Every call to Scan, even the first one, must be preceded by a call to Next. - Next() bool - - // Err returns the error, if any, that was encountered during iteration. - // Err may be called after an explicit or implicit Close. - Err() error -} - -// RowAsRows implements the methods of Rows for a Row as no-ops. -// Note that Next() always returns true leading to an endless loop -// if used to scan multiple rows. -func RowAsRows(row Row) Rows { - return rowAsRows{Row: row} -} - -type rowAsRows struct { - Row -} - -func (rowAsRows) Close() error { return nil } -func (rowAsRows) Next() bool { return true } -func (rowAsRows) Err() error { return nil } diff --git a/impl/rowscanner.go b/impl/rowscanner.go deleted file mode 100644 index 6bc5826..0000000 --- a/impl/rowscanner.go +++ /dev/null @@ -1,129 +0,0 @@ -package impl - -import ( - "database/sql" - "errors" - - sqldb "github.com/domonda/go-sqldb" -) - -var ( - _ sqldb.RowScanner = &RowScanner{} - _ sqldb.RowScanner = CurrentRowScanner{} - _ sqldb.RowScanner = SingleRowScanner{} -) - -// RowScanner implements sqldb.RowScanner for a sql.Row -type RowScanner struct { - rows Rows - structFieldNamer sqldb.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping -} - -func NewRowScanner(rows Rows, structFieldNamer sqldb.StructFieldMapper, query, argFmt string, args []any) *RowScanner { - return &RowScanner{rows, structFieldNamer, query, argFmt, args} -} - -func (s *RowScanner) Scan(dest ...any) (err error) { - defer func() { - err = errors.Join(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - if s.rows.Err() != nil { - return s.rows.Err() - } - if !s.rows.Next() { - if s.rows.Err() != nil { - return s.rows.Err() - } - return sql.ErrNoRows - } - - return s.rows.Scan(dest...) -} - -func (s *RowScanner) ScanStruct(dest any) (err error) { - defer func() { - err = errors.Join(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - if s.rows.Err() != nil { - return s.rows.Err() - } - if !s.rows.Next() { - if s.rows.Err() != nil { - return s.rows.Err() - } - return sql.ErrNoRows - } - - return ScanStruct(s.rows, dest, s.structFieldNamer) -} - -func (s *RowScanner) ScanValues() ([]any, error) { - return ScanValues(s.rows) -} - -func (s *RowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.rows) -} - -func (s *RowScanner) Columns() ([]string, error) { - return s.rows.Columns() -} - -// CurrentRowScanner calls Rows.Scan without Rows.Next and Rows.Close -type CurrentRowScanner struct { - Rows Rows - StructFieldMapper sqldb.StructFieldMapper -} - -func (s CurrentRowScanner) Scan(dest ...any) error { - return s.Rows.Scan(dest...) -} - -func (s CurrentRowScanner) ScanStruct(dest any) error { - return ScanStruct(s.Rows, dest, s.StructFieldMapper) -} - -func (s CurrentRowScanner) ScanValues() ([]any, error) { - return ScanValues(s.Rows) -} - -func (s CurrentRowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.Rows) -} - -func (s CurrentRowScanner) Columns() ([]string, error) { - return s.Rows.Columns() -} - -// SingleRowScanner always uses the same Row -type SingleRowScanner struct { - Row Row - StructFieldMapper sqldb.StructFieldMapper -} - -func (s SingleRowScanner) Scan(dest ...any) error { - return s.Row.Scan(dest...) -} - -func (s SingleRowScanner) ScanStruct(dest any) error { - return ScanStruct(s.Row, dest, s.StructFieldMapper) -} - -func (s SingleRowScanner) ScanValues() ([]any, error) { - return ScanValues(s.Row) -} - -func (s SingleRowScanner) ScanStrings() ([]string, error) { - return ScanStrings(s.Row) -} - -func (s SingleRowScanner) Columns() ([]string, error) { - return s.Row.Columns() -} diff --git a/impl/rowsscanner.go b/impl/rowsscanner.go deleted file mode 100644 index 339a0e7..0000000 --- a/impl/rowsscanner.go +++ /dev/null @@ -1,154 +0,0 @@ -package impl - -import ( - "context" - "errors" - "fmt" - "reflect" - - sqldb "github.com/domonda/go-sqldb" -) - -var _ sqldb.RowsScanner = &RowsScanner{} - -// RowsScanner implements sqldb.RowsScanner with Rows -type RowsScanner struct { - ctx context.Context // ctx is checked for every row and passed through to callbacks - rows Rows - structFieldNamer sqldb.StructFieldMapper - query string // for error wrapping - argFmt string // for error wrapping - args []any // for error wrapping -} - -func NewRowsScanner(ctx context.Context, rows Rows, structFieldNamer sqldb.StructFieldMapper, query, argFmt string, args []any) *RowsScanner { - return &RowsScanner{ctx, rows, structFieldNamer, query, argFmt, args} -} - -func (s *RowsScanner) ScanSlice(dest any) error { - err := ScanRowsAsSlice(s.ctx, s.rows, dest, nil) - if err != nil { - return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) - } - return nil -} - -func (s *RowsScanner) ScanStructSlice(dest any) error { - err := ScanRowsAsSlice(s.ctx, s.rows, dest, s.structFieldNamer) - if err != nil { - return fmt.Errorf("%w from query: %s", err, FormatQuery(s.query, s.argFmt, s.args...)) - } - return nil -} - -func (s *RowsScanner) Columns() ([]string, error) { - return s.rows.Columns() -} - -func (s *RowsScanner) ScanAllRowsAsStrings(headerRow bool) (rows [][]string, err error) { - cols, err := s.rows.Columns() - if err != nil { - return nil, err - } - if headerRow { - rows = [][]string{cols} - } - stringScannablePtrs := make([]any, len(cols)) - err = s.ForEachRow(func(rowScanner sqldb.RowScanner) error { - row := make([]string, len(cols)) - for i := range stringScannablePtrs { - stringScannablePtrs[i] = (*sqldb.StringScannable)(&row[i]) - } - err := rowScanner.Scan(stringScannablePtrs...) - if err != nil { - return err - } - rows = append(rows, row) - return nil - }) - return rows, err -} - -func (s *RowsScanner) ForEachRow(callback func(sqldb.RowScanner) error) (err error) { - defer func() { - err = errors.Join(err, s.rows.Close()) - err = WrapNonNilErrorWithQuery(err, s.query, s.argFmt, s.args) - }() - - for s.rows.Next() { - if s.ctx.Err() != nil { - return s.ctx.Err() - } - - err := callback(CurrentRowScanner{s.rows, s.structFieldNamer}) - if err != nil { - return err - } - } - return s.rows.Err() -} - -func (s *RowsScanner) ForEachRowCall(callback any) error { - forEachRowFunc, err := ForEachRowCallFunc(s.ctx, callback) - if err != nil { - return err - } - return s.ForEachRow(forEachRowFunc) -} - -// ScanRowsAsSlice scans all srcRows as slice into dest. -// The rows must either have only one column compatible with the element type of the slice, -// or if multiple columns are returned then the slice element type must me a struct or struction pointer -// so that every column maps on exactly one struct field using structFieldNamer. -// In case of single column rows, nil must be passed for structFieldNamer. -// ScanRowsAsSlice calls srcRows.Close(). -func ScanRowsAsSlice(ctx context.Context, srcRows Rows, dest any, structFieldNamer sqldb.StructFieldMapper) error { - defer srcRows.Close() - - destVal := reflect.ValueOf(dest) - if destVal.Kind() != reflect.Ptr { - return fmt.Errorf("scan dest is not a pointer but %s", destVal.Type()) - } - if destVal.IsNil() { - return errors.New("scan dest is nil") - } - slice := destVal.Elem() - if slice.Kind() != reflect.Slice { - return fmt.Errorf("scan dest is not pointer to slice but %s", destVal.Type()) - } - sliceElemType := slice.Type().Elem() - - newSlice := reflect.MakeSlice(slice.Type(), 0, 32) - - for srcRows.Next() { - if ctx.Err() != nil { - return ctx.Err() - } - - newSlice = reflect.Append(newSlice, reflect.Zero(sliceElemType)) - target := newSlice.Index(newSlice.Len() - 1).Addr() - if structFieldNamer != nil { - err := ScanStruct(srcRows, target.Interface(), structFieldNamer) - if err != nil { - return err - } - } else { - err := srcRows.Scan(target.Interface()) - if err != nil { - return err - } - } - } - if srcRows.Err() != nil { - return srcRows.Err() - } - - // Assign newSlice if there were no errors - if newSlice.Len() == 0 { - slice.SetLen(0) - } else { - slice.Set(newSlice) - } - - return nil -} diff --git a/impl/scanresult.go b/impl/scanresult.go deleted file mode 100644 index d8a4bc3..0000000 --- a/impl/scanresult.go +++ /dev/null @@ -1,54 +0,0 @@ -package impl - -import "github.com/domonda/go-sqldb" - -// ScanValues returns the values of a row exactly how they are -// passed from the database driver to an sql.Scanner. -// Byte slices will be copied. -func ScanValues(src Row) ([]any, error) { - cols, err := src.Columns() - if err != nil { - return nil, err - } - var ( - anys = make([]sqldb.AnyValue, len(cols)) - result = make([]any, len(cols)) - ) - // result elements hold pointer to sqldb.AnyValue for scanning - for i := range result { - result[i] = &anys[i] - } - err = src.Scan(result...) - if err != nil { - return nil, err - } - // don't return pointers to sqldb.AnyValue - // but what internal value has been scanned - for i := range result { - result[i] = anys[i].Val - } - return result, nil -} - -// ScanStrings scans the values of a row as strings. -// Byte slices will be interpreted as strings, -// nil (SQL NULL) will be converted to an empty string, -// all other types are converted with fmt.Sprint. -func ScanStrings(src Row) ([]string, error) { - cols, err := src.Columns() - if err != nil { - return nil, err - } - var ( - result = make([]string, len(cols)) - resultPtrs = make([]any, len(cols)) - ) - for i := range resultPtrs { - resultPtrs[i] = (*sqldb.StringScannable)(&result[i]) - } - err = src.Scan(resultPtrs...) - if err != nil { - return nil, err - } - return result, nil -} diff --git a/impl/transaction.go b/impl/transaction.go index 246de8c..c8b56c9 100644 --- a/impl/transaction.go +++ b/impl/transaction.go @@ -16,7 +16,7 @@ type transaction struct { tx *sql.Tx opts *sql.TxOptions no uint64 - structFieldNamer sqldb.StructFieldMapper + structFieldNamer sqldb.StructReflector } func newTransaction(parent *connection, tx *sql.Tx, opts *sql.TxOptions, no uint64) *transaction { @@ -45,13 +45,13 @@ func (conn *transaction) WithContext(ctx context.Context) sqldb.Connection { return newTransaction(parent, conn.tx, conn.opts, conn.no) } -func (conn *transaction) WithStructFieldMapper(namer sqldb.StructFieldMapper) sqldb.Connection { +func (conn *transaction) WithStructFieldMapper(namer sqldb.StructReflector) sqldb.Connection { c := conn.clone() c.structFieldNamer = namer return c } -func (conn *transaction) StructFieldMapper() sqldb.StructFieldMapper { +func (conn *transaction) StructReflector() sqldb.StructReflector { return conn.structFieldNamer } @@ -71,27 +71,31 @@ func (conn *transaction) Exec(query string, args ...any) error { return WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) } -func (conn *transaction) Query(query string, args ...any) (sqldb.Rows, error) { - return conn.tx.QueryContext(conn.parent.ctx, query, args...) -} - -func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { - rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) - if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) - return sqldb.RowScannerWithError(err) - } - return NewRowScanner(rows, conn.structFieldNamer, query, conn.parent.argFmt, args) -} - -func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { +func (conn *transaction) Query(query string, args ...any) sqldb.Rows { rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) if err != nil { - err = WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) - return sqldb.RowsScannerWithError(err) + return sqldb.RowsErr(err) } - return NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, conn.parent.argFmt, args) -} + return rows +} + +// func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner { +// rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) +// if err != nil { +// err = WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) +// return sqldb.RowScannerWithError(err) +// } +// return NewRowScanner(rows, conn.structFieldNamer, query, conn.parent.argFmt, args) +// } + +// func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner { +// rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...) +// if err != nil { +// err = WrapNonNilErrorWithQuery(err, query, conn.parent.argFmt, args) +// return sqldb.RowsScannerWithError(err) +// } +// return NewRowsScanner(conn.parent.ctx, rows, conn.structFieldNamer, query, conn.parent.argFmt, args) +// } func (conn *transaction) TransactionInfo() (no uint64, opts *sql.TxOptions) { return conn.no, conn.opts diff --git a/information/primarykeys.go b/information/primarykeys.go index 87f6ffb..d672c86 100644 --- a/information/primarykeys.go +++ b/information/primarykeys.go @@ -115,18 +115,17 @@ type TableRowWithPrimaryKey struct { func GetTableRowsWithPrimaryKey(ctx context.Context, pkCols []PrimaryKeyColumn, pk any) (tableRows []TableRowWithPrimaryKey, err error) { defer errs.WrapWithFuncParams(&err, ctx, pkCols, pk) - conn := db.Conn(ctx) for _, col := range pkCols { - query := fmt.Sprintf(`select * from %s where "%s" = $1`, col.Table, col.Column) - row := conn.QueryRows(query, pk) - cols, err := row.Columns() + query := fmt.Sprintf(`SELECT * FROM %s WHERE "%s" = $1`, col.Table, col.Column) + rows := db.QueryRows(ctx, query, pk) + cols, err := rows.Columns() if err != nil { if errors.Is(err, sql.ErrNoRows) { continue } return nil, err } - vals, err := row.ScanAllRowsAsStrings(false) + vals, err := rows.ScanAllRowsAsStrings(false) if err != nil { if errors.Is(err, sql.ErrNoRows) { continue @@ -154,7 +153,7 @@ var RenderUUIDPrimaryKeyRefsHTML = http.HandlerFunc(func(writer http.ResponseWri pk, err := uu.IDFromString(request.URL.Query().Get("pk")) if err != nil { title = "Primary Key UUID" - mainContent = ` + mainContent = /*html*/ `
@@ -222,7 +221,7 @@ var RenderUUIDPrimaryKeyRefsHTML = http.HandlerFunc(func(writer http.ResponseWri writer.Write(buf.Bytes()) //#nosec G104 }) -var htmlTemplate = ` +var htmlTemplate = /*html*/ ` @@ -254,9 +253,9 @@ var htmlTemplate = ` ` -const StyleAllMonospace = `` +const StyleAllMonospace = /*html*/ `` -const StyleDefaultTable = ``} + style = []string{ + StyleAllMonospace, + StyleDefaultTable, + ``, + } ) pk, err := uu.IDFromString(request.URL.Query().Get("pk")) if err != nil { @@ -175,7 +173,7 @@ var RenderUUIDPrimaryKeyRefsHTML = http.HandlerFunc(func(writer http.ResponseWri return !tableRows[i].ForeignKey && tableRows[j].ForeignKey }) var b strings.Builder - fmt.Fprintf(&b, "

", pk) + fmt.Fprintf(&b, "

", pk) for _, tableRow := range tableRows { //#nosec fmt.Fprintf(&b, "

%s

", html.EscapeString(tableRow.Table)) fmt.Fprintf(&b, "") @@ -267,7 +265,7 @@ const StyleDefaultTable = /*html*/ `