diff --git a/appender.go b/appender.go
index 835aa4e6..3e4863a9 100644
--- a/appender.go
+++ b/appender.go
@@ -3,25 +3,42 @@ package duckdb

 import (
 	"database/sql/driver"
 	"errors"
+	"runtime"
+	"sync"

 	"github.com/marcboeker/go-duckdb/mapping"
 )

-// Appender holds the DuckDB appender. It allows efficient bulk loading into a DuckDB database.
-type Appender struct {
-	conn     *Conn
-	schema   string
-	table    string
-	appender mapping.Appender
-	closed   bool
-
-	// The chunk to append to.
-	chunk DataChunk
-	// The column types of the table to append to.
-	types []mapping.LogicalType
-	// The number of appended rows.
-	rowCount int
-}
+type (
+	appenderSource[T any] struct {
+		Source T
+	}
+
+	rowAppenderSource           = appenderSource[RowTableSource]
+	parallelRowAppenderSource   = appenderSource[ParallelRowTableSource]
+	chunkAppenderSource         = appenderSource[ChunkTableSource]
+	parallelChunkAppenderSource = appenderSource[ParallelChunkTableSource]
+
+	// AppenderSource is an opaque handle to a table source that can be appended
+	// to a table. Use the NewAppender*Source constructors to obtain one.
+	AppenderSource interface {
+		_secret()
+	}
+
+	// Appender holds the DuckDB appender. It allows efficient bulk loading into a DuckDB database.
+	Appender struct {
+		conn     *Conn
+		schema   string
+		table    string
+		appender mapping.Appender
+		closed   bool
+
+		// The chunk to append to.
+		chunk DataChunk
+		// The column types of the table to append to.
+		types []mapping.LogicalType
+		// The number of appended rows.
+		rowCount int
+	}
+)

 // NewAppenderFromConn returns a new Appender for the default catalog from a DuckDB driver connection.
 func NewAppenderFromConn(driverConn driver.Conn, schema, table string) (*Appender, error) {
@@ -145,6 +162,72 @@ func (a *Appender) AppendRow(args ...driver.Value) error {
 	return nil
 }

+// AppendTableSource appends all rows produced by the table source to the table.
+func (a *Appender) AppendTableSource(s AppenderSource) error {
+	lock := &sync.Mutex{}
+	// The appender always writes to every column of the table,
+	// so the projection must be the identity mapping.
+	columnCount := mapping.AppenderColumnCount(a.appender)
+	projection := make([]int, 0, columnCount)
+	for i := mapping.IdxT(0); i < columnCount; i++ {
+		projection = append(projection, int(i))
+	}
+	var x any = s
+	switch s := x.(type) {
+	case rowAppenderSource:
+		s.Source.Init()
+		err := appenderRowThread(&parallelRowTSWrapper{s.Source}, lock, a.types, a.appender, projection)
+		if err != nil {
+			return err
+		}
+	case parallelRowAppenderSource:
+		wg := sync.WaitGroup{}
+
+		info := s.Source.Init()
+		// MaxThreads == 0 means: use the default thread count.
+		threads := runtime.GOMAXPROCS(-1)
+		if info.MaxThreads > 0 {
+			threads = min(threads, info.MaxThreads)
+		}
+		// Collect per-goroutine errors without racing on a shared variable.
+		errs := make([]error, threads)
+		for i := range threads {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				errs[i] = appenderRowThread(s.Source, lock, a.types, a.appender, projection)
+			}()
+		}
+		wg.Wait()
+		if err := errors.Join(errs...); err != nil {
+			return err
+		}
+	case chunkAppenderSource:
+		s.Source.Init()
+		err := appenderChunkThread(&parallelChunkTSWrapper{s.Source}, lock, a.types, a.appender)
+		if err != nil {
+			return err
+		}
+	case parallelChunkAppenderSource:
+		wg := sync.WaitGroup{}

+		info := s.Source.Init()
+		// MaxThreads == 0 means: use the default thread count.
+		threads := runtime.GOMAXPROCS(-1)
+		if info.MaxThreads > 0 {
+			threads = min(threads, info.MaxThreads)
+		}
+		// Collect per-goroutine errors without racing on a shared variable.
+		errs := make([]error, threads)
+		for i := range threads {
+			wg.Add(1)
+			go func() {
+				defer wg.Done()
+				errs[i] = appenderChunkThread(s.Source, lock, a.types, a.appender)
+			}()
+		}
+		wg.Wait()
+		if err := errors.Join(errs...); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (a *Appender) appendRowSlice(args []driver.Value) error {
 	// Early-out, if the number of args does not match the column count.
 	if len(args) != len(a.types) {
@@ -193,3 +276,95 @@ func destroyTypeSlice(slice []mapping.LogicalType) {
 		mapping.DestroyLogicalType(&t)
 	}
 }
+
+func appenderRowThread(s ParallelRowTableSource, lock *sync.Mutex, types []mapping.LogicalType, duckdbAppender mapping.Appender, projection []int) error {
+	maxSize := GetDataChunkCapacity()
+	lstate := s.NewLocalState()
+	var chunk DataChunk
+	err := chunk.initFromTypes(types, true)
+	if err != nil {
+		return err
+	}
+
+	for {
+		row := Row{
+			chunk:      &chunk,
+			projection: projection,
+		}
+		var next bool
+		for row.r = 0; row.r < mapping.IdxT(maxSize); row.r++ {
+			next, err = s.FillRow(lstate, row)
+			if err != nil {
+				chunk.close()
+				return err
+			}
+			if !next {
+				break
+			}
+		}
+
+		mapping.DataChunkSetSize(chunk.chunk, row.r)
+
+		lock.Lock()
+		state := mapping.AppendDataChunk(duckdbAppender, chunk.chunk)
+		if state == mapping.StateError {
+			// Release the lock and the chunk before returning the error.
+			err = getDuckDBError(mapping.AppenderError(duckdbAppender))
+			lock.Unlock()
+			chunk.close()
+			return err
+		}
+		lock.Unlock()
+		if !next {
+			break
+		}
+		chunk.reset(true)
+	}
+	chunk.close()
+	return nil
+}
+
+func appenderChunkThread(s ParallelChunkTableSource, lock *sync.Mutex, types []mapping.LogicalType, duckdbAppender mapping.Appender) error {
+	lstate := s.NewLocalState()
+	var chunk DataChunk
+	err := chunk.initFromTypes(types, true)
+	if err != nil {
+		return err
+	}
+
+	for {
+		err = s.FillChunk(lstate, chunk)
+		if err != nil {
+			chunk.close()
+			return err
+		}
+
+		if chunk.GetSize() == 0 {
+			break
+		}
+
+		lock.Lock()
+		state := mapping.AppendDataChunk(duckdbAppender, chunk.chunk)
+		if state == mapping.StateError {
+			// Release the lock and the chunk before returning the error.
+			err = getDuckDBError(mapping.AppenderError(duckdbAppender))
+			lock.Unlock()
+			chunk.close()
+			return err
+		}
+		lock.Unlock()
+		chunk.reset(true)
+	}
+	chunk.close()
+	return nil
+}
+
+// _secret prevents implementations of AppenderSource outside this package.
+func (a appenderSource[T]) _secret() {}
+
+// NewAppenderRowSource wraps a RowTableSource for use with Appender.AppendTableSource.
+func NewAppenderRowSource(source RowTableSource) AppenderSource {
+	return rowAppenderSource{Source: source}
+}
+
+// NewAppenderParallelRowSource wraps a ParallelRowTableSource for use with Appender.AppendTableSource.
+func NewAppenderParallelRowSource(source ParallelRowTableSource) AppenderSource {
+	return parallelRowAppenderSource{Source: source}
+}
+
+// NewAppenderChunkSource wraps a ChunkTableSource for use with Appender.AppendTableSource.
+func NewAppenderChunkSource(source ChunkTableSource) AppenderSource {
+	return chunkAppenderSource{Source: source}
+}
+
+// NewAppenderParallelChunkSource wraps a ParallelChunkTableSource for use with Appender.AppendTableSource.
+func NewAppenderParallelChunkSource(source ParallelChunkTableSource) AppenderSource {
+	return parallelChunkAppenderSource{Source: source}
+}
diff --git a/appender_test.go b/appender_test.go
index a36edb87..5d56c481 100644
--- a/appender_test.go
+++ b/appender_test.go
@@ -10,6 +10,7 @@ import (
 	"math/rand"
 	"os"
 	"reflect"
+	"sync"
 	"testing"
 	"time"

@@ -91,6 +92,8 @@ type resultRow struct {
 	mixList []any
 }

+const benchmarkRowsToAppend = 2048 * 128
+
 func castList[T any](val []any) []T {
 	res := make([]T, len(val))
 	for i, v := range val {
@@ -160,7 +163,7 @@ func TestAppendChunks(t *testing.T) {
 	}

 	rowsToAppend := make([]row, rowCount)
-	for i := 0; i < rowCount; i++ {
+	for i := range rowCount {
 		rowsToAppend[i] = row{ID: int64(i), UInt8: uint8(randInt(0, 255))}
 		require.NoError(t, a.AppendRow(rowsToAppend[i].ID, rowsToAppend[i].UInt8))
 	}
@@ -190,7 +193,7 @@ func TestAppenderList(t *testing.T) {
 	defer cleanupAppender(t, c, db, conn, a)

 	rowsToAppend := make([]nestedDataRow, 10)
-	for i := 0; i < 10; i++ {
+	for i := range 10 {
 		rowsToAppend[i].stringList = []string{"a", "b", "c"}
 		rowsToAppend[i].intList = []int32{1, 2, 3}
 	}
@@ -222,7 +225,7 @@ func TestAppenderArray(t *testing.T) {

 	count := 10
 	expected := Composite[[3]string]{[3]string{"a", "b", "c"}}
-	for i := 0; i < count; i++ {
+	for range count {
 		require.NoError(t, a.AppendRow([]string{"a", "b", "c"}))
 		require.NoError(t,
 			a.AppendRow(expected.Get()))
 	}
@@ -800,7 +803,7 @@ func TestAppendToCatalog(t *testing.T) {
 	conn := openConnWrapper(t, db, context.Background())
 	defer closeConnWrapper(t, conn)

-	err = conn.Raw(func(anyConn interface{}) error {
+	err = conn.Raw(func(anyConn any) error {
 		driverConn := anyConn.(driver.Conn)
 		a, innerErr := NewAppender(driverConn, "other", "", "test")
 		require.NoError(t, innerErr)
@@ -858,7 +861,7 @@ func TestAppenderWithJSON(t *testing.T) {
 	defer cleanupAppender(t, c, db, conn, a)

 	for _, jsonInput := range jsonInputs {
-		var jsonData map[string]interface{}
+		var jsonData map[string]any
 		err := json.Unmarshal(jsonInput, &jsonData)
 		require.NoError(t, err)
 		require.NoError(t, a.AppendRow(jsonData["c1"], jsonData["l1"], jsonData["s1"], jsonData["l2"]))
@@ -874,10 +877,10 @@ func TestAppenderWithJSON(t *testing.T) {
 	i := 0
 	for res.Next() {
 		var (
-			c1 interface{}
-			l1 interface{}
-			s1 interface{}
-			l2 interface{}
+			c1 any
+			l1 any
+			s1 any
+			l2 any
 		)
 		err := res.Scan(&c1, &l1, &s1, &l2)
 		require.NoError(t, err)
@@ -957,13 +960,258 @@ func TestAppenderAppendDataChunk(t *testing.T) {
 	defer cleanupAppender(t, c, db, conn, a)

 	// Add enough rows to overflow several chunks.
-	for i := 0; i < GetDataChunkCapacity()*3; i++ {
+	for i := range GetDataChunkCapacity() * 3 {
 		require.NoError(t, a.AppendRow(i, Union{Value: "str2", Tag: "s"}))
 		require.NoError(t, a.AppendRow(i, nil))
 	}
 	require.NoError(t, a.Flush())
 }

+type (
+	appStructTableUDF struct {
+		n     int64
+		count int64
+	}
+
+	appParStructTableUDF struct {
+		lock    *sync.Mutex
+		claimed int64
+		n       int64
+	}
+)
+
+func (udf *appStructTableUDF) ColumnInfos() []ColumnInfo {
+	t, _ := NewTypeInfo(TYPE_BIGINT)
+	t2, _ := NewTypeInfo(TYPE_UTINYINT)
+	return []ColumnInfo{
+		{Name: "id", T: t},
+		{Name: "uint8", T: t2},
+	}
+}
+
+func (udf *appStructTableUDF) Init() {}
+
+func (udf *appStructTableUDF) FillRow(row Row) (bool, error) {
+	if udf.count >= udf.n {
+		return false, nil
+	}
+	udf.count++
+	err := SetRowValue(row, 0, udf.count)
+	if err != nil {
+		return true, err
+	}
+	err = SetRowValue(row, 1, udf.count)
+	return true, err
+}
+
+func (udf *appStructTableUDF) GetTypes() []any {
+	return []any{int64(0), uint8(0)}
+}
+
+func (udf *appStructTableUDF) GetValue(r, c int) any {
+	if c == 0 {
+		return int64(r + 1)
+	}
+	return uint8(r + 1)
+}
+
+func (udf *appStructTableUDF) Cardinality() *CardinalityInfo {
+	return nil
+}
+
+func (udf *appParStructTableUDF) ColumnInfos() []ColumnInfo {
+	// Both columns must be declared to match the target table and FillRow.
+	t, _ := NewTypeInfo(TYPE_BIGINT)
+	t2, _ := NewTypeInfo(TYPE_UTINYINT)
+	return []ColumnInfo{
+		{Name: "id", T: t},
+		{Name: "uint8", T: t2},
+	}
+}
+
+func (udf *appParStructTableUDF) Init() ParallelTableSourceInfo {
+	return ParallelTableSourceInfo{MaxThreads: 8}
+}
+
+func (udf *appParStructTableUDF) NewLocalState() any {
+	return &parallelIncTableLocalState{
+		start: 0,
+		end:   -1,
+	}
+}
+
+func (udf *appParStructTableUDF) FillRow(localState any, row Row) (bool, error) {
+	state := localState.(*parallelIncTableLocalState)

+	if state.start >= state.end {
+		// Claim a new work unit.
+		udf.lock.Lock()
+		remaining := udf.n - udf.claimed

+		if remaining <= 0 {
+			// No more work.
+			udf.lock.Unlock()
+			return false, nil
+		} else if remaining >= 2048 {
+			// Clamp each work unit to one data chunk.
+			remaining = 2048
+		}

+		state.start = udf.claimed
+		udf.claimed += remaining
+		state.end = udf.claimed
+		udf.lock.Unlock()
+	}

+	state.start++
+	err := SetRowValue(row, 0, state.start)
+	if err != nil {
+		return true, err
+	}
+	err = SetRowValue(row, 1, state.start)
+	return true, err
+}
+
+func (udf *appParStructTableUDF) GetValue(r, c int) any {
+	if c == 0 {
+		return int64(r + 1)
+	}
+	return uint8(r + 1)
+}
+
+func (udf *appParStructTableUDF) GetTypes() []any {
+	return []any{int64(0), uint8(0)}
+}
+
+func (udf *appParStructTableUDF) Cardinality() *CardinalityInfo {
+	return nil
+}
+
+func TestAppendParallelRowSource(t *testing.T) {
+	t.Parallel()
+	sc := `
+	CREATE TABLE test (
+		id BIGINT,
+		uint8 UTINYINT
+	)`

+	c, db, con, a := prepareAppender(t, sc)

+	f := appParStructTableUDF{
+		lock:    &sync.Mutex{},
+		claimed: 0,
+		n:       3000,
+	}

+	err := a.AppendTableSource(NewAppenderParallelRowSource(&f))
+	require.NoError(t, err)

+	err = a.Flush()
+	require.NoError(t, err)

+	// Verify results.
+	res, err := sql.OpenDB(c).QueryContext(context.Background(), `SELECT * FROM test ORDER BY id`)
+	require.NoError(t, err)

+	values := f.GetTypes()
+	args := make([]any, len(values))
+	for i := range values {
+		args[i] = &values[i]
+	}

+	count := 0
+	for r := 0; res.Next(); r++ {
+		require.NoError(t, res.Scan(args...))
+		for i, value := range values {
+			expected := f.GetValue(r, i)
+			require.Equal(t, expected, value, "incorrect value", r, i)
+		}
+		count++
+	}
+	require.Equal(t, int(f.n), count)
+	cleanupAppender(t, c, db, con, a)
+}
+
+func TestAppendParallelRowSourceSingle(t *testing.T) {
+	t.Parallel()
+	sc := `
+	CREATE TABLE test (
+		id BIGINT,
+	)`

+	c, db, con, a := prepareAppender(t, sc)

+	f := parallelIncTableUDF{
+		lock: &sync.Mutex{},
+		n:    3000,
+	}

+	err := a.AppendTableSource(NewAppenderParallelRowSource(&f))
+	require.NoError(t, err)

+	err = a.Flush()
+	require.NoError(t, err)

+	// Verify results.
+	res, err := sql.OpenDB(c).QueryContext(context.Background(), `SELECT * FROM test ORDER BY id`)
+	require.NoError(t, err)

+	values := f.GetTypes()
+	args := make([]any, len(values))
+	for i := range values {
+		args[i] = &values[i]
+	}

+	count := 0
+	for r := 0; res.Next(); r++ {
+		require.NoError(t, res.Scan(args...))
+		for i, value := range values {
+			expected := f.GetValue(r, i)
+			require.Equal(t, expected, value, "incorrect value", r, i)
+		}
+		count++
+	}
+	require.Equal(t, int(f.n), count)
+	cleanupAppender(t, c, db, con, a)
+}
+
+func TestAppendRowSource(t *testing.T) {
+	t.Parallel()
+	sc := `
+	CREATE TABLE test (
+		id BIGINT,
+		uint8 UTINYINT
+	)`
+	c, db, con, a := prepareAppender(t, sc)

+	f := appStructTableUDF{
+		n: 3000,
+	}

+	err := a.AppendTableSource(NewAppenderRowSource(&f))
+	require.NoError(t, err)

+	err = a.Flush()
+	require.NoError(t, err)

+	// Verify results.
+	res, err := sql.OpenDB(c).QueryContext(context.Background(), `SELECT * FROM test ORDER BY id`)
+	require.NoError(t, err)

+	values := f.GetTypes()
+	args := make([]any, len(values))
+	for i := range values {
+		args[i] = &values[i]
+	}

+	count := 0
+	for r := 0; res.Next(); r++ {
+		require.NoError(t, res.Scan(args...))
+		for i, value := range values {
+			expected := f.GetValue(r, i)
+			require.Equal(t, expected, value, "incorrect value", r, i)
+		}
+		count++
+	}
+	require.Equal(t, int(f.n), count)
+	cleanupAppender(t, c, db, con, a)
+}
+
 func BenchmarkAppenderNested(b *testing.B) {
 	c, db, conn, a := prepareAppender(b, createNestedDataTableSQL)
 	defer cleanupAppender(b, c, db, conn, a)
@@ -971,11 +1219,9 @@ func BenchmarkAppenderNested(b *testing.B) {
 	const rowCount = 600
 	rowsToAppend := prepareNestedData(rowCount)

-	b.ResetTimer()
-	for n := 0; n < b.N; n++ {
+	for b.Loop() {
 		appendNestedData(b, a, rowsToAppend)
 	}
-	b.StopTimer()
 }

 const createNestedDataTableSQL = `
@@ -1030,7 +1276,7 @@ func prepareNestedData(rowCount int) []nestedDataRow {
 	}

 	rowsToAppend := make([]nestedDataRow, rowCount)
-	for i := 0; i < rowCount; i++ {
+	for i := range rowCount {
 		rowsToAppend[i].ID = int64(i)
 		rowsToAppend[i].stringList = []string{"a", "b", "c"}
 		rowsToAppend[i].intList = []int32{1, 2, 3}
@@ -1075,3 +1321,145 @@ func appendNestedData[T require.TestingT](t T, a *Appender, rowsToAppend []neste
 	}
 	require.NoError(t, a.Flush())
 }
+
+var types = map[reflect.Type]string{
+	reflect.TypeFor[int8](): "TINYINT",
+}
+
+func benchmarkAppenderSingle[T any](v T) func(*testing.B) {
+	return func(b *testing.B) {
+		if _, ok := types[reflect.TypeFor[T]()]; !ok {
+			b.Fatal("Type not defined in table:", reflect.TypeFor[T]())
+		}
+		tableSQL := fmt.Sprintf(createSingleTableSQL, types[reflect.TypeFor[T]()])
+		c, db, con, a := prepareAppender(b, tableSQL)

+		for b.Loop() {
+			for range benchmarkRowsToAppend {
+				// Avoid require here; its overhead dominates the benchmark loop.
+				err := a.AppendRow(v)
+				if err != nil {
+					b.Error(err)
+				}
+			}
+		}
+		cleanupAppender(b, c, db, con, a)
+	}
+}
+
+func benchmarkAppenderRowSingle[T any](_ T) func(*testing.B) {
+	return func(b *testing.B) {
+		if _, ok := types[reflect.TypeFor[T]()]; !ok {
+			b.Fatal("Type not defined in table:", reflect.TypeFor[T]())
+		}
+		tableSQL := fmt.Sprintf(createSingleTableSQL, types[reflect.TypeFor[T]()])
+		c, db, con, a := prepareAppender(b, tableSQL)

+		for b.Loop() {
+			f := incTableUDF{
+				n: benchmarkRowsToAppend,
+			}
+			err := a.AppendTableSource(NewAppenderRowSource(&f))
+			if err != nil {
+				b.Error(err)
+			}
+		}
+		cleanupAppender(b, c, db, con, a)
+	}
+}
+
+func benchmarkAppenderParallelRowSingle[T any](_ T) func(*testing.B) {
+	return func(b *testing.B) {
+		if _, ok := types[reflect.TypeFor[T]()]; !ok {
+			b.Fatal("Type not defined in table:", reflect.TypeFor[T]())
+		}
+		tableSQL := fmt.Sprintf(createSingleTableSQL, types[reflect.TypeFor[T]()])
+		c, db, con, a := prepareAppender(b, tableSQL)

+		for b.Loop() {
+			f := parallelIncTableUDF{
+				lock: &sync.Mutex{},
+				n:    benchmarkRowsToAppend,
+			}
+			err := a.AppendTableSource(NewAppenderParallelRowSource(&f))
+			if err != nil {
+				b.Error(err)
+			}
+		}
+		cleanupAppender(b, c, db, con, a)
+	}
+}
+
+func benchmarkAppenderChunkSingle[T any](_ T) func(*testing.B) {
+	return func(b *testing.B) {
+		if _, ok := types[reflect.TypeFor[T]()]; !ok {
+			b.Fatal("Type not defined in table:", reflect.TypeFor[T]())
+		}
+		tableSQL := fmt.Sprintf(createSingleTableSQL,
types[reflect.TypeFor[T]()]) + c, db, con, a := prepareAppender(b, tableSQL) + + for b.Loop() { + f := chunkIncTableUDF{ + n: benchmarkRowsToAppend, + } + err := a.AppendTableSource(NewAppenderChunkSource(&f)) + if err != nil { + b.Error(err) + } + } + cleanupAppender(b, c, db, con, a) + } +} + +func benchmarkAppenderParallelChunkSingle[T any](_ T) func(*testing.B) { + return func(b *testing.B) { + if _, ok := types[reflect.TypeFor[T]()]; !ok { + b.Fatal("Type not defined in table:", reflect.TypeFor[T]()) + } + tableSQL := fmt.Sprintf(createSingleTableSQL, types[reflect.TypeFor[T]()]) + c, db, con, a := prepareAppender(b, tableSQL) + + for b.Loop() { + f := parallelChunkIncTableUDF{ + lock: &sync.Mutex{}, + n: benchmarkRowsToAppend, + } + err := a.AppendTableSource(NewAppenderParallelChunkSource(&f)) + if err != nil { + b.Error(err) + } + } + cleanupAppender(b, c, db, con, a) + } +} + +func BenchmarkAppenderSingle(b *testing.B) { + b.Run("int8", benchmarkAppenderSingle[int8](0)) +} + +func BenchmarkAppenderRowSingle(b *testing.B) { + b.Run("int8", benchmarkAppenderRowSingle[int8](0)) +} + +func BenchmarkAppenderParallelRowSingle(b *testing.B) { + b.Run("int8", benchmarkAppenderParallelRowSingle[int8](0)) +} + +func BenchmarkAppenderChunkSingle(b *testing.B) { + b.Run("int8", benchmarkAppenderChunkSingle[int8](0)) +} + +func BenchmarkAppenderParallelChunkSingle(b *testing.B) { + b.Run("int8", benchmarkAppenderParallelChunkSingle[int8](0)) +} + +const createSingleTableSQL = ` + CREATE TABLE test ( + nested_int_list %s, + ) +` diff --git a/data_chunk.go b/data_chunk.go index d769362f..d8749f75 100644 --- a/data_chunk.go +++ b/data_chunk.go @@ -2,6 +2,8 @@ package duckdb import "C" import ( + "sync" + "github.com/marcboeker/go-duckdb/mapping" ) @@ -18,9 +20,9 @@ type DataChunk struct { } // GetDataChunkCapacity returns the capacity of a data chunk. -func GetDataChunkCapacity() int { +var GetDataChunkCapacity = sync.OnceValue(func() int { return int(mapping.VectorSize()) -} +}) // GetSize returns the internal size of the data chunk. func (chunk *DataChunk) GetSize() int { diff --git a/row.go b/row.go index b5efacb4..628b7bd8 100644 --- a/row.go +++ b/row.go @@ -24,8 +24,8 @@ func SetRowValue[T any](row Row, colIdx int, val T) error { if projectedIdx < 0 || projectedIdx >= len(row.chunk.columns) { return nil } - vec := row.chunk.columns[projectedIdx] - return setVectorVal(&vec, row.r, val) + vec := &row.chunk.columns[projectedIdx] + return setVectorVal(vec, row.r, val) } // SetRowValue sets the value at colIdx to val. Returns an error on failure. diff --git a/table_source.go b/table_source.go new file mode 100644 index 00000000..3730d7db --- /dev/null +++ b/table_source.go @@ -0,0 +1,141 @@ +package duckdb + +type ( + tableSource interface { + // ColumnInfos returns column information for each column of the table function. + ColumnInfos() []ColumnInfo + // Cardinality returns the cardinality information of the table function. + // Optionally, if no cardinality exists, it may return nil. + Cardinality() *CardinalityInfo + } + + parallelTableSource interface { + tableSource + // Init the table source. + // Additionally, it returns information for the parallelism-aware table source. + Init() ParallelTableSourceInfo + // NewLocalState returns a thread-local execution state. + // It must return a pointer or a reference type for correct state updates. + // go-duckdb does not prevent non-reference values. 
+		NewLocalState() any
+	}

+	sequentialTableSource interface {
+		tableSource
+		// Init the table source.
+		Init()
+	}

+	// A RowTableSource represents anything that produces rows in a non-vectorised way.
+	// The cardinality is requested before function initialization.
+	// After initializing the RowTableSource, go-duckdb requests the rows.
+	// It sequentially calls the FillRow method with a single thread.
+	RowTableSource interface {
+		sequentialTableSource
+		// FillRow takes a Row and fills it with values.
+		// It returns true if there are more rows to fill.
+		FillRow(Row) (bool, error)
+	}

+	// A ParallelRowTableSource represents anything that produces rows in a non-vectorised way.
+	// The cardinality is requested before function initialization.
+	// After initializing the ParallelRowTableSource, go-duckdb requests the rows.
+	// It simultaneously calls the FillRow method with multiple threads.
+	// If ParallelTableSourceInfo.MaxThreads is greater than one, FillRow must use synchronization
+	// primitives to avoid race conditions.
+	ParallelRowTableSource interface {
+		parallelTableSource
+		// FillRow takes a Row and fills it with values.
+		// It returns true if there are more rows to fill.
+		FillRow(any, Row) (bool, error)
+	}

+	// A ChunkTableSource represents anything that produces rows in a vectorised way.
+	// The cardinality is requested before function initialization.
+	// After initializing the ChunkTableSource, go-duckdb requests the rows.
+	// It sequentially calls the FillChunk method with a single thread.
+	ChunkTableSource interface {
+		sequentialTableSource
+		// FillChunk takes a Chunk and fills it with values.
+		// Set the chunk size to 0 to signal the end of the source.
+		FillChunk(DataChunk) error
+	}

+	// A ParallelChunkTableSource represents anything that produces rows in a vectorised way.
+	// The cardinality is requested before function initialization.
+	// After initializing the ParallelChunkTableSource, go-duckdb requests the rows.
+	// It simultaneously calls the FillChunk method with multiple threads.
+	// If ParallelTableSourceInfo.MaxThreads is greater than one, FillChunk must use synchronization
+	// primitives to avoid race conditions.
+	ParallelChunkTableSource interface {
+		parallelTableSource
+		// FillChunk takes a Chunk and fills it with values.
+		// Set the chunk size to 0 to signal the end of the source.
+		FillChunk(any, DataChunk) error
+	}

+	// parallelRowTSWrapper wraps a sequential table source for a parallel context with MaxThreads = 1.
+	parallelRowTSWrapper struct {
+		s RowTableSource
+	}

+	// parallelChunkTSWrapper wraps a sequential table source for a parallel context with MaxThreads = 1.
+	parallelChunkTSWrapper struct {
+		s ChunkTableSource
+	}

+	// ParallelTableSourceInfo contains information for initializing a parallelism-aware table source.
+	ParallelTableSourceInfo struct {
+		// MaxThreads is the maximum number of threads on which to run the table source function.
+		// If set to 0, it uses DuckDB's default thread configuration.
+ MaxThreads int + } +) + +// ParallelRow wrapper +func (s parallelRowTSWrapper) ColumnInfos() []ColumnInfo { + return s.s.ColumnInfos() +} + +func (s parallelRowTSWrapper) Cardinality() *CardinalityInfo { + return s.s.Cardinality() +} + +func (s parallelRowTSWrapper) Init() ParallelTableSourceInfo { + s.s.Init() + return ParallelTableSourceInfo{ + MaxThreads: 1, + } +} + +func (s parallelRowTSWrapper) NewLocalState() any { + return struct{}{} +} + +func (s parallelRowTSWrapper) FillRow(ls any, chunk Row) (bool, error) { + return s.s.FillRow(chunk) +} + +// ParallelChunk wrapper +func (s parallelChunkTSWrapper) ColumnInfos() []ColumnInfo { + return s.s.ColumnInfos() +} + +func (s parallelChunkTSWrapper) Cardinality() *CardinalityInfo { + return s.s.Cardinality() +} + +func (s parallelChunkTSWrapper) Init() ParallelTableSourceInfo { + s.s.Init() + return ParallelTableSourceInfo{ + MaxThreads: 1, + } +} + +func (s parallelChunkTSWrapper) NewLocalState() any { + return struct{}{} +} + +func (s parallelChunkTSWrapper) FillChunk(ls any, chunk DataChunk) error { + return s.s.FillChunk(chunk) +} diff --git a/table_udf.go b/table_udf.go index a52279e4..d71f5121 100644 --- a/table_udf.go +++ b/table_udf.go @@ -6,17 +6,13 @@ package duckdb void table_udf_bind_row(void *); void table_udf_bind_chunk(void *); -void table_udf_bind_parallel_row(void *); -void table_udf_bind_parallel_chunk(void *); typedef void (*table_udf_bind_t)(void *); void table_udf_init(void *); -void table_udf_init_parallel(void *); void table_udf_local_init(void *); typedef void (*table_udf_init_t)(void *); -void table_udf_row_callback(void *, void *); -void table_udf_chunk_callback(void *, void *); +void table_udf_callback(void *, void *); typedef void (*table_udf_callback_t)(void *, void *, void *); void table_udf_delete_callback(void *); @@ -51,91 +47,11 @@ type ( Exact bool } - // ParallelTableSourceInfo contains information for initializing a parallelism-aware table source. - ParallelTableSourceInfo struct { - // MaxThreads is the maximum number of threads on which to run the table source function. - // If set to 0, it uses DuckDB's default thread configuration. - MaxThreads int - } - tableFunctionData struct { fun any projection []int } - tableSource interface { - // ColumnInfos returns column information for each column of the table function. - ColumnInfos() []ColumnInfo - // Cardinality returns the cardinality information of the table function. - // Optionally, if no cardinality exists, it may return nil. - Cardinality() *CardinalityInfo - } - - parallelTableSource interface { - tableSource - // Init the table source. - // Additionally, it returns information for the parallelism-aware table source. - Init() ParallelTableSourceInfo - // NewLocalState returns a thread-local execution state. - // It must return a pointer or a reference type for correct state updates. - // go-duckdb does not prevent non-reference values. - NewLocalState() any - } - - sequentialTableSource interface { - tableSource - // Init the table source. - Init() - } - - // A RowTableSource represents anything that produces rows in a non-vectorised way. - // The cardinality is requested before function initialization. - // After initializing the RowTableSource, go-duckdb requests the rows. - // It sequentially calls the FillRow method with a single thread. - RowTableSource interface { - sequentialTableSource - // FillRow takes a Row and fills it with values. - // It returns true, if there are more rows to fill. 
-		FillRow(Row) (bool, error)
-	}
-
-	// A ParallelRowTableSource represents anything that produces rows in a non-vectorised way.
-	// The cardinality is requested before function initialization.
-	// After initializing the ParallelRowTableSource, go-duckdb requests the rows.
-	// It simultaneously calls the FillRow method with multiple threads.
-	// If ParallelTableSourceInfo.MaxThreads is greater than one, FillRow must use synchronisation
-	// primitives to avoid race conditions.
-	ParallelRowTableSource interface {
-		parallelTableSource
-		// FillRow takes a Row and fills it with values.
-		// It returns true, if there are more rows to fill.
-		FillRow(any, Row) (bool, error)
-	}
-
-	// A ChunkTableSource represents anything that produces rows in a vectorised way.
-	// The cardinality is requested before function initialization.
-	// After initializing the ChunkTableSource, go-duckdb requests the rows.
-	// It sequentially calls the FillChunk method with a single thread.
-	ChunkTableSource interface {
-		sequentialTableSource
-		// FillChunk takes a Chunk and fills it with values.
-		// It returns true, if there are more chunks to fill.
-		FillChunk(DataChunk) error
-	}
-
-	// A ParallelChunkTableSource represents anything that produces rows in a vectorised way.
-	// The cardinality is requested before function initialization.
-	// After initializing the ParallelChunkTableSource, go-duckdb requests the rows.
-	// It simultaneously calls the FillChunk method with multiple threads.
-	// If ParallelTableSourceInfo.MaxThreads is greater than one, FillChunk must use synchronization
-	// primitives to avoid race conditions.
-	ParallelChunkTableSource interface {
-		parallelTableSource
-		// FillChunk takes a Chunk and fills it with values.
-		// It returns true, if there are more chunks to fill.
-		FillChunk(any, DataChunk) error
-	}
-
 	// TableFunctionConfig contains any information passed to DuckDB when registering the table function.
 	TableFunctionConfig struct {
 		// The Arguments of the table function.
@@ -150,6 +66,12 @@
 		RowTableFunction | ParallelRowTableFunction | ChunkTableFunction | ParallelChunkTableFunction
 	}

+	// parallelTableFunction is a type constraint covering the parallel table function types:
+	// ParallelRowTableFunction and ParallelChunkTableFunction.
+	parallelTableFunction interface {
+		ParallelRowTableFunction | ParallelChunkTableFunction
+	}
+
 	// A RowTableFunction is a type which can be bound to return a RowTableSource.
 	RowTableFunction = tableFunction[RowTableSource]
 	// A ParallelRowTableFunction is a type which can be bound to return a ParallelRowTableSource.
@@ -167,6 +89,26 @@
 	}
 )

+// wrapRowTF lifts a sequential RowTableFunction into a single-threaded ParallelRowTableFunction.
+func wrapRowTF(f RowTableFunction) ParallelRowTableFunction {
+	return ParallelRowTableFunction{
+		Config: f.Config,
+		BindArguments: func(named map[string]any, args ...any) (ParallelRowTableSource, error) {
+			rts, err := f.BindArguments(named, args...)
+			if err != nil {
+				return nil, err
+			}
+			return parallelRowTSWrapper{s: rts}, nil
+		},
+	}
+}
+
+// wrapChunkTF lifts a sequential ChunkTableFunction into a single-threaded ParallelChunkTableFunction.
+func wrapChunkTF(f ChunkTableFunction) ParallelChunkTableFunction {
+	return ParallelChunkTableFunction{
+		Config: f.Config,
+		BindArguments: func(named map[string]any, args ...any) (ParallelChunkTableSource, error) {
+			cts, err := f.BindArguments(named, args...)
+			if err != nil {
+				return nil, err
+			}
+			return parallelChunkTSWrapper{s: cts}, nil
+		},
+	}
+}
+
 func isRowIdColumn(i mapping.IdxT) bool {
 	// FIXME: Replace this with mapping.IsRowIdColumn(i) / virtual column changes, once available in the C API.
return i == 18446744073709551615 @@ -186,21 +128,11 @@ func (tfd *tableFunctionData) setColumnCount(info mapping.InitInfo) { //export table_udf_bind_row func table_udf_bind_row(infoPtr unsafe.Pointer) { - udfBindTyped[RowTableSource](infoPtr) + udfBindTyped[ParallelRowTableSource](infoPtr) } //export table_udf_bind_chunk func table_udf_bind_chunk(infoPtr unsafe.Pointer) { - udfBindTyped[ChunkTableSource](infoPtr) -} - -//export table_udf_bind_parallel_row -func table_udf_bind_parallel_row(infoPtr unsafe.Pointer) { - udfBindTyped[ParallelRowTableSource](infoPtr) -} - -//export table_udf_bind_parallel_chunk -func table_udf_bind_parallel_chunk(infoPtr unsafe.Pointer) { udfBindTyped[ParallelChunkTableSource](infoPtr) } @@ -279,20 +211,11 @@ func udfBindTyped[T tableSource](infoPtr unsafe.Pointer) { //export table_udf_init func table_udf_init(infoPtr unsafe.Pointer) { - info := mapping.InitInfo{Ptr: infoPtr} - instance := getPinned[tableFunctionData](mapping.InitGetBindData(info)) - instance.setColumnCount(info) - instance.fun.(sequentialTableSource).Init() -} - -//export table_udf_init_parallel -func table_udf_init_parallel(infoPtr unsafe.Pointer) { info := mapping.InitInfo{Ptr: infoPtr} instance := getPinned[tableFunctionData](mapping.InitGetBindData(info)) instance.setColumnCount(info) initData := instance.fun.(parallelTableSource).Init() - maxThreads := initData.MaxThreads - mapping.InitSetMaxThreads(info, mapping.IdxT(maxThreads)) + mapping.InitSetMaxThreads(info, mapping.IdxT(initData.MaxThreads)) } //export table_udf_local_init @@ -309,8 +232,8 @@ func table_udf_local_init(infoPtr unsafe.Pointer) { mapping.InitSetInitData(info, unsafe.Pointer(&h), deleteCallbackPtr) } -//export table_udf_row_callback -func table_udf_row_callback(infoPtr unsafe.Pointer, outputPtr unsafe.Pointer) { +//export table_udf_callback +func table_udf_callback(infoPtr unsafe.Pointer, outputPtr unsafe.Pointer) { info := mapping.FunctionInfo{Ptr: infoPtr} output := mapping.DataChunk{Ptr: outputPtr} @@ -323,28 +246,17 @@ func table_udf_row_callback(infoPtr unsafe.Pointer, outputPtr unsafe.Pointer) { return } - row := Row{ - chunk: &chunk, - projection: instance.projection, - } - maxSize := mapping.IdxT(GetDataChunkCapacity()) + localState := getPinned[any](mapping.FunctionGetLocalInitData(info)) switch fun := instance.fun.(type) { - case RowTableSource: - // At the end of the loop row.r must be the index of the last row. - for row.r = 0; row.r < maxSize; row.r++ { - next, errRow := fun.FillRow(row) - if errRow != nil { - mapping.FunctionSetError(info, errRow.Error()) - break - } - if !next { - break - } - } case ParallelRowTableSource: + row := Row{ + chunk: &chunk, + projection: instance.projection, + } + maxSize := mapping.IdxT(GetDataChunkCapacity()) + // At the end of the loop row.r must be the index of the last row. 
- localState := getPinned[any](mapping.FunctionGetLocalInitData(info)) for row.r = 0; row.r < maxSize; row.r++ { next, errRow := fun.FillRow(localState, row) if errRow != nil { @@ -355,33 +267,12 @@ func table_udf_row_callback(infoPtr unsafe.Pointer, outputPtr unsafe.Pointer) { break } } - } - mapping.DataChunkSetSize(output, row.r) -} - -//export table_udf_chunk_callback -func table_udf_chunk_callback(infoPtr unsafe.Pointer, outputPtr unsafe.Pointer) { - info := mapping.FunctionInfo{Ptr: infoPtr} - output := mapping.DataChunk{Ptr: outputPtr} - - instance := getPinned[tableFunctionData](mapping.FunctionGetBindData(info)) - - var chunk DataChunk - err := chunk.initFromDuckDataChunk(output, true) - if err != nil { - mapping.FunctionSetError(info, err.Error()) - return - } - - switch fun := instance.fun.(type) { - case ChunkTableSource: - err = fun.FillChunk(chunk) + mapping.DataChunkSetSize(output, row.r) case ParallelChunkTableSource: - localState := getPinned[any](mapping.FunctionGetLocalInitData(info)) err = fun.FillChunk(localState, chunk) - } - if err != nil { - mapping.FunctionSetError(info, err.Error()) + if err != nil { + mapping.FunctionSetError(info, err.Error()) + } } } @@ -398,6 +289,24 @@ func RegisterTableUDF[TFT TableFunction](conn *sql.Conn, name string, f TFT) err if name == "" { return getError(errAPI, errTableUDFNoName) } + + // normalise the function + var x any = f + switch tableFunc := x.(type) { + case RowTableFunction: + return registerParallelTableUDF(conn, name, wrapRowTF(tableFunc)) + case ChunkTableFunction: + return registerParallelTableUDF(conn, name, wrapChunkTF(tableFunc)) + case ParallelRowTableFunction: + return registerParallelTableUDF(conn, name, tableFunc) + case ParallelChunkTableFunction: + return registerParallelTableUDF(conn, name, tableFunc) + default: + return getError(errInternal, nil) + } +} + +func registerParallelTableUDF[TFT parallelTableFunction](conn *sql.Conn, name string, f TFT) error { function := mapping.CreateTableFunction() mapping.TableFunctionSetName(function, name) @@ -417,70 +326,30 @@ func RegisterTableUDF[TFT TableFunction](conn *sql.Conn, name string, f TFT) err mapping.TableFunctionSupportsProjectionPushdown(function, true) - // Set the config. 
- var x any = f - switch tableFunc := x.(type) { - case RowTableFunction: - initCallbackPtr := unsafe.Pointer(C.table_udf_init_t(C.table_udf_init)) - mapping.TableFunctionSetInit(function, initCallbackPtr) - - bindCallbackPtr := unsafe.Pointer(C.table_udf_bind_t(C.table_udf_bind_row)) - mapping.TableFunctionSetBind(function, bindCallbackPtr) - - callbackPtr := unsafe.Pointer(C.table_udf_callback_t(C.table_udf_row_callback)) - mapping.TableFunctionSetFunction(function, callbackPtr) - - config = tableFunc.Config - if tableFunc.BindArguments == nil { - return getError(errAPI, errTableUDFMissingBindArgs) - } - - case ChunkTableFunction: - initCallbackPtr := unsafe.Pointer(C.table_udf_init_t(C.table_udf_init)) - mapping.TableFunctionSetInit(function, initCallbackPtr) - - bindCallbackPtr := unsafe.Pointer(C.table_udf_bind_t(C.table_udf_bind_chunk)) - mapping.TableFunctionSetBind(function, bindCallbackPtr) + initCallbackPtr := unsafe.Pointer(C.table_udf_init_t(C.table_udf_init)) + mapping.TableFunctionSetInit(function, initCallbackPtr) - callbackPtr := unsafe.Pointer(C.table_udf_callback_t(C.table_udf_chunk_callback)) - mapping.TableFunctionSetFunction(function, callbackPtr) + localInitCallbackPtr := unsafe.Pointer(C.table_udf_init_t(C.table_udf_local_init)) + mapping.TableFunctionSetLocalInit(function, localInitCallbackPtr) - config = tableFunc.Config - if tableFunc.BindArguments == nil { - return getError(errAPI, errTableUDFMissingBindArgs) - } + callbackPtr := unsafe.Pointer(C.table_udf_callback_t(C.table_udf_callback)) + mapping.TableFunctionSetFunction(function, callbackPtr) + var x any = f + switch tableFunc := x.(type) { case ParallelRowTableFunction: - initCallbackPtr := unsafe.Pointer(C.table_udf_init_t(C.table_udf_init_parallel)) - mapping.TableFunctionSetInit(function, initCallbackPtr) - - bindCallbackPtr := unsafe.Pointer(C.table_udf_bind_t(C.table_udf_bind_parallel_row)) + bindCallbackPtr := unsafe.Pointer(C.table_udf_bind_t(C.table_udf_bind_row)) mapping.TableFunctionSetBind(function, bindCallbackPtr) - callbackPtr := unsafe.Pointer(C.table_udf_callback_t(C.table_udf_row_callback)) - mapping.TableFunctionSetFunction(function, callbackPtr) - - localInitCallbackPtr := unsafe.Pointer(C.table_udf_init_t(C.table_udf_local_init)) - mapping.TableFunctionSetLocalInit(function, localInitCallbackPtr) - config = tableFunc.Config if tableFunc.BindArguments == nil { return getError(errAPI, errTableUDFMissingBindArgs) } case ParallelChunkTableFunction: - initCallbackPtr := unsafe.Pointer(C.table_udf_init_t(C.table_udf_init_parallel)) - mapping.TableFunctionSetInit(function, initCallbackPtr) - - bindCallbackPtr := unsafe.Pointer(C.table_udf_bind_t(C.table_udf_bind_parallel_chunk)) + bindCallbackPtr := unsafe.Pointer(C.table_udf_bind_t(C.table_udf_bind_chunk)) mapping.TableFunctionSetBind(function, bindCallbackPtr) - callbackPtr := unsafe.Pointer(C.table_udf_callback_t(C.table_udf_chunk_callback)) - mapping.TableFunctionSetFunction(function, callbackPtr) - - localInitCallbackPtr := unsafe.Pointer(C.table_udf_init_t(C.table_udf_local_init)) - mapping.TableFunctionSetLocalInit(function, localInitCallbackPtr) - config = tableFunc.Config if tableFunc.BindArguments == nil { return getError(errAPI, errTableUDFMissingBindArgs) diff --git a/table_udf_test.go b/table_udf_test.go index db247b9e..546507c7 100644 --- a/table_udf_test.go +++ b/table_udf_test.go @@ -385,7 +385,6 @@ func (udf *parallelIncTableUDF) FillRow(localState any, row Row) (bool, error) { udf.claimed += remaining udf.lock.Unlock() } - 
 	state.start++
 	err := SetRowValue(row, 0, state.start)
 	return true, err
@@ -444,7 +443,7 @@ func (udf *parallelChunkIncTableUDF) FillChunk(localState any, chunk DataChunk)
 		if remaining <= 0 {
 			// No more work.
 			udf.lock.Unlock()
-			return nil
+			return chunk.SetSize(0)
 		} else if remaining >= 2048 {
 			remaining = 2048
 		}
@@ -454,7 +453,7 @@ func (udf *parallelChunkIncTableUDF) FillChunk(localState any, chunk DataChunk)
 	udf.lock.Unlock()

 	for i := 0; i < int(remaining); i++ {
-		err := chunk.SetValue(0, i, int64(i)+state.start+1)
+		err := SetChunkValue(chunk, 0, i, int64(i)+state.start+1)
 		if err != nil {
 			return err
 		}
@@ -712,7 +711,7 @@ func (udf *chunkIncTableUDF) FillChunk(chunk DataChunk) error {
 		return err
 	}
 	udf.count++
-	err := chunk.SetValue(0, i, udf.count)
+	err := SetChunkValue(chunk, 0, i, udf.count)
 	if err != nil {
 		return err
 	}
diff --git a/vector_setters.go b/vector_setters.go
index 2b9b8d49..3ecd64f1 100644
--- a/vector_setters.go
+++ b/vector_setters.go
@@ -18,7 +18,7 @@ type fnSetVectorValue func(vec *vector, rowIdx mapping.IdxT, val any) error
 func (vec *vector) setNull(rowIdx mapping.IdxT) {
 	mapping.ValiditySetRowInvalid(vec.maskPtr, rowIdx)
 	if vec.Type == TYPE_STRUCT || vec.Type == TYPE_UNION {
-		for i := 0; i < len(vec.childVectors); i++ {
+		for i := range vec.childVectors {
 			vec.childVectors[i].setNull(rowIdx)
 		}
 	}
@@ -311,7 +311,7 @@ func setStruct[S any](vec *vector, rowIdx mapping.IdxT, val S) error {
 	rv := reflect.ValueOf(val)
 	structType := rv.Type()

-	for i := 0; i < structType.NumField(); i++ {
+	for i := range structType.NumField() {
 		if !rv.Field(i).CanInterface() {
 			continue
 		}
@@ -326,7 +326,7 @@ func setStruct[S any](vec *vector, rowIdx mapping.IdxT, val S) error {
 		}
 	}

-	for i := 0; i < len(vec.childVectors); i++ {
+	for i := range vec.childVectors {
 		child := &vec.childVectors[i]
 		name := vec.structEntries[i].Name()
 		v, ok := m[name]
@@ -395,7 +395,8 @@ func setUUID[S any](vec *vector, rowIdx mapping.IdxT, val S) error {
 		if len(v) != uuidLength {
 			return castError(reflect.TypeOf(val).String(), reflect.TypeOf(uuid).String())
 		}
+		// TODO: test performance with copy instead of a loop.
-		for i := 0; i < uuidLength; i++ {
+		for i := range uuidLength {
 			uuid[i] = v[i]
 		}
 	default:
@@ -471,14 +472,9 @@ func setUnion[S any](vec *vector, rowIdx mapping.IdxT, val S) error {
 }

 func setVectorVal[S any](vec *vector, rowIdx mapping.IdxT, val S) error {
-	name, inMap := unsupportedTypeToStringMap[vec.Type]
-	if inMap {
-		return unsupportedTypeError(name)
-	}
-
 	switch vec.Type {
 	case TYPE_BOOLEAN:
-		return setBool[S](vec, rowIdx, val)
+		return setBool(vec, rowIdx, val)
 	case TYPE_TINYINT:
 		return setNumeric[S, int8](vec, rowIdx, val)
 	case TYPE_SMALLINT:
@@ -502,33 +498,38 @@ func setVectorVal[S any](vec *vector, rowIdx mapping.IdxT, val S) error {
 	case TYPE_TIMESTAMP, TYPE_TIMESTAMP_S, TYPE_TIMESTAMP_MS, TYPE_TIMESTAMP_NS, TYPE_TIMESTAMP_TZ:
 		return setTS(vec, rowIdx, val)
 	case TYPE_DATE:
-		return setDate[S](vec, rowIdx, val)
+		return setDate(vec, rowIdx, val)
 	case TYPE_TIME, TYPE_TIME_TZ:
-		return setTime[S](vec, rowIdx, val)
+		return setTime(vec, rowIdx, val)
 	case TYPE_INTERVAL:
-		return setInterval[S](vec, rowIdx, val)
+		return setInterval(vec, rowIdx, val)
 	case TYPE_HUGEINT:
-		return setHugeint[S](vec, rowIdx, val)
+		return setHugeint(vec, rowIdx, val)
 	case TYPE_VARCHAR:
-		return setBytes[S](vec, rowIdx, val)
+		return setBytes(vec, rowIdx, val)
 	case TYPE_BLOB:
-		return setBytes[S](vec, rowIdx, val)
+		return setBytes(vec, rowIdx, val)
 	case TYPE_DECIMAL:
-		return setDecimal[S](vec,
rowIdx, val) + return setDecimal(vec, rowIdx, val) case TYPE_ENUM: - return setEnum[S](vec, rowIdx, val) + return setEnum(vec, rowIdx, val) case TYPE_LIST: - return setList[S](vec, rowIdx, val) + return setList(vec, rowIdx, val) case TYPE_STRUCT: - return setStruct[S](vec, rowIdx, val) + return setStruct(vec, rowIdx, val) case TYPE_MAP, TYPE_ARRAY: // FIXME: Is this already supported? And tested? return unsupportedTypeError(unsupportedTypeToStringMap[vec.Type]) case TYPE_UUID: - return setUUID[S](vec, rowIdx, val) + return setUUID(vec, rowIdx, val) case TYPE_UNION: - return setUnion[S](vec, rowIdx, val) + return setUnion(vec, rowIdx, val) default: + name, inMap := unsupportedTypeToStringMap[vec.Type] + if inMap { + return unsupportedTypeError(name) + } + return unsupportedTypeError(unknownTypeErrMsg) } }
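
---

Usage note for reviewers: below is a minimal end-to-end sketch of the sequential path of the new appender API (`NewAppenderRowSource` + `AppendTableSource`). It is illustrative only; the connector scaffolding, table name, and row count are invented for the example, while `NewConnector`, `NewAppenderFromConn`, `SetRowValue`, `ColumnInfo`, and `TYPE_BIGINT` are the existing or newly added package API.

```go
package main

import (
	"context"
	"database/sql"
	"log"

	"github.com/marcboeker/go-duckdb"
)

// idSource emits the integers 1..n as a single BIGINT column.
type idSource struct {
	n, count int64
}

// ColumnInfos is required by the table source interface; AppendTableSource
// itself derives the column types from the target table.
func (s *idSource) ColumnInfos() []duckdb.ColumnInfo {
	t, _ := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
	return []duckdb.ColumnInfo{{Name: "id", T: t}}
}

func (s *idSource) Cardinality() *duckdb.CardinalityInfo { return nil }

func (s *idSource) Init() {}

// FillRow writes one row and reports whether more rows remain.
func (s *idSource) FillRow(row duckdb.Row) (bool, error) {
	if s.count >= s.n {
		return false, nil
	}
	s.count++
	return true, duckdb.SetRowValue(row, 0, s.count)
}

func main() {
	ctx := context.Background()
	connector, err := duckdb.NewConnector("", nil)
	if err != nil {
		log.Fatal(err)
	}

	// Create the target table through a database/sql handle on the same connector.
	db := sql.OpenDB(connector)
	if _, err = db.Exec(`CREATE TABLE ids (id BIGINT)`); err != nil {
		log.Fatal(err)
	}

	conn, err := connector.Connect(ctx)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	a, err := duckdb.NewAppenderFromConn(conn, "", "ids")
	if err != nil {
		log.Fatal(err)
	}
	if err = a.AppendTableSource(duckdb.NewAppenderRowSource(&idSource{n: 1_000_000})); err != nil {
		log.Fatal(err)
	}
	if err = a.Close(); err != nil {
		log.Fatal(err)
	}
}
```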
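For the parallel variant, each worker goroutine receives its own local state from `NewLocalState` and claims ranges of work under a mutex, which is the same claim-a-range pattern the new tests use. A sketch under the same assumptions as above (`rangeSource` is invented; the work-unit size of 2048 matches one data chunk, and `MaxThreads: 4` is arbitrary):

```go
package main

import (
	"sync"

	"github.com/marcboeker/go-duckdb"
)

// rangeSource emits 1..n from multiple goroutines.
type rangeSource struct {
	mu      sync.Mutex
	n       int64
	claimed int64
}

// rangeState is the per-worker local state: a claimed half-open range [next, end).
type rangeState struct{ next, end int64 }

func (s *rangeSource) ColumnInfos() []duckdb.ColumnInfo {
	t, _ := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
	return []duckdb.ColumnInfo{{Name: "id", T: t}}
}

func (s *rangeSource) Cardinality() *duckdb.CardinalityInfo { return nil }

func (s *rangeSource) Init() duckdb.ParallelTableSourceInfo {
	return duckdb.ParallelTableSourceInfo{MaxThreads: 4}
}

func (s *rangeSource) NewLocalState() any { return &rangeState{} }

func (s *rangeSource) FillRow(local any, row duckdb.Row) (bool, error) {
	st := local.(*rangeState)
	if st.next >= st.end {
		// Claim the next work unit under the lock; rows within the
		// claimed range are then filled without further locking.
		s.mu.Lock()
		unit := min(s.n-s.claimed, 2048)
		if unit <= 0 {
			s.mu.Unlock()
			return false, nil // All work has been claimed.
		}
		st.next = s.claimed
		s.claimed += unit
		st.end = s.claimed
		s.mu.Unlock()
	}
	st.next++
	return true, duckdb.SetRowValue(row, 0, st.next)
}

// Usage: err := a.AppendTableSource(duckdb.NewAppenderParallelRowSource(&rangeSource{n: 1 << 20}))
```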
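Finally, the chunk-based path fills whole vectors at a time; per the new `FillChunk` contract, setting the chunk size to 0 signals that the source is exhausted. Hedged the same way: `onesSource` and the constant value are invented for illustration, while `SetChunkValue`, `GetDataChunkCapacity`, and `DataChunk.SetSize` are the package API used elsewhere in this diff. The snippet shares the imports of the previous sketch.

```go
// onesSource produces n rows of the constant BIGINT value 1, one chunk at a time.
type onesSource struct {
	n, produced int
}

func (s *onesSource) ColumnInfos() []duckdb.ColumnInfo {
	t, _ := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT)
	return []duckdb.ColumnInfo{{Name: "v", T: t}}
}

func (s *onesSource) Cardinality() *duckdb.CardinalityInfo { return nil }

func (s *onesSource) Init() {}

func (s *onesSource) FillChunk(chunk duckdb.DataChunk) error {
	size := min(s.n-s.produced, duckdb.GetDataChunkCapacity())
	if size <= 0 {
		// A chunk size of 0 tells the appender the source is exhausted.
		return chunk.SetSize(0)
	}
	for i := range size {
		if err := duckdb.SetChunkValue(chunk, 0, i, int64(1)); err != nil {
			return err
		}
	}
	s.produced += size
	return chunk.SetSize(size)
}

// Usage: err := a.AppendTableSource(duckdb.NewAppenderChunkSource(&onesSource{n: 10_000}))
```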