Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions stage/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ func ParseStage(stage *Stage, stages Map) (*Stage, error) {
}
}
stages[stage.Id] = stage
err := processStreams(stage, stages)
if err != nil {
return nil, fmt.Errorf("failed to process streams for stage %s: %w", stage.Id, err)
}

for _, nextStagePath := range stage.NextStagePaths {
if nextStage, err := ParseStageFromFile(nextStagePath, stages); err != nil {
return nil, err
Expand Down Expand Up @@ -150,3 +155,57 @@ func checkStageLinks(stage *Stage) error {
}
return nil
}

// processStreams expands each stream spec on the stage into concrete stream
// stages, registers them in the stages map, links them as next stages of this
// stage, and assigns each instance a seed (custom or derived from RandSeed).
// On success the Streams slice is cleared so the specs are not expanded twice.
// Returns an error on an invalid spec or an unreadable stream file.
func processStreams(stage *Stage, stages Map) error {
	// No stream specs: the stage itself runs with the configured random seed.
	if len(stage.Streams) == 0 {
		stage.seed = stage.States.RandSeed
		return nil
	}

	for _, spec := range stage.Streams {
		if spec.StreamCount <= 0 {
			return fmt.Errorf("stream_count must be positive, got %d for stream %s", spec.StreamCount, spec.StreamPath)
		}

		if len(spec.Seeds) > 0 {
			// Seeds must be either a single shared value or exactly one per instance.
			if len(spec.Seeds) != 1 && len(spec.Seeds) != spec.StreamCount {
				return fmt.Errorf("seeds array length (%d) must be either 1 or equal to stream_count (%d) for stream %s",
					len(spec.Seeds), spec.StreamCount, spec.StreamPath)
			}
			stage.States.RandSeed = 0 // Disable random seed generation when custom seeds are provided
		}

		streamPath, err := spec.GetValidatedPath(stage.BaseDir)
		if err != nil {
			return err
		}
		for i := 0; i < spec.StreamCount; i++ {
			// Each instance gets its own Stage value read fresh from the file so
			// instances do not share mutable state.
			streamStage, err := ReadStageFromFile(streamPath)
			if err != nil {
				return fmt.Errorf("failed to read stream file %s: %w", streamPath, err)
			}

			// Set unique ID for this stream instance (1-based suffix).
			baseId := fileNameWithoutPathAndExt(streamPath)
			streamStage.Id = fmt.Sprintf("%s_stream_%d", baseId, i+1)

			// Set custom seed if configured for this instance.
			if seed, hasCustomSeed := spec.GetSeedForInstance(i); hasCustomSeed {
				streamStage.seed = seed
				log.Info().Str("stream_stage", streamStage.Id).Int64("custom_seed", seed).Int("instance", i+1).Msg("stream assigned custom seed")
			} else {
				// No seed configured: derive one as RandSeed + instance offset.
				// Use int64(i), not int64(i-1): the first instance (i == 0) should
				// receive the base seed itself; i-1 started at RandSeed-1, off by
				// one relative to the 1-based instance numbering used in the ID.
				streamStage.seed = stage.States.RandSeed + int64(i)
				log.Info().Str("stream_stage", streamStage.Id).Int64("generated_seed", streamStage.seed).Int64("base_seed", stage.States.RandSeed).Int("instance", i+1).Msg("stream assigned generated seed")
			}

			stages[streamStage.Id] = streamStage
			stage.NextStages = append(stage.NextStages, streamStage)
			// The stream stage must wait for its parent stage to finish first.
			streamStage.wgPrerequisites.Add(1)
		}
	}

	// Specs have been fully expanded; clear them so a second pass is a no-op.
	stage.Streams = nil

	return nil
}
11 changes: 7 additions & 4 deletions stage/mysql_run_recorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import (
"context"
"database/sql"
_ "embed"
_ "github.com/go-sql-driver/mysql"
"pbench/log"
"pbench/utils"

_ "github.com/go-sql-driver/mysql"
)

var (
Expand Down Expand Up @@ -65,7 +66,7 @@ VALUES (?, ?, ?, 0, 0, 0, ?)`

func (m *MySQLRunRecorder) RecordQuery(_ context.Context, s *Stage, result *QueryResult) {
recordNewQuery := `INSERT INTO pbench_queries (run_id, stage_id, query_file, query_index, query_id, sequence_no,
cold_run, succeeded, start_time, end_time, row_count, expected_row_count, duration_ms, info_url) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
cold_run, succeeded, start_time, end_time, row_count, expected_row_count, duration_ms, info_url, seed) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
var queryFile string
if result.Query.File != nil {
queryFile = *result.Query.File
Expand All @@ -83,11 +84,13 @@ cold_run, succeeded, start_time, end_time, row_count, expected_row_count, durati
result.RowCount, sql.NullInt32{
Int32: int32(result.Query.ExpectedRowCount),
Valid: result.Query.ExpectedRowCount >= 0,
}, result.Duration.Milliseconds(), result.InfoUrl)
}, result.Duration.Milliseconds(), result.InfoUrl, result.Seed)
log.Info().Str("stage_id", result.StageId).Stringer("start_time", result.StartTime).Stringer("end_time", result.EndTime).
Str("info_url", result.InfoUrl).Int64("seed", result.Seed).Msg("recorded query result to MySQL")
if err != nil {
log.Error().EmbedObject(result).Err(err).Msg("failed to send query summary to MySQL")
}
updateRunInfo := `UPDATE pbench_runs SET start_time = ?, queries_ran = queries_ran + 1, failed = ?, mismatch = ? WHERE run_id = ?`
updateRunInfo := `UPDATE pbench_runs SET start_time = ?, queries_ran = queries_ran + 1, failed = ? , mismatch = ? WHERE run_id = ?`
res, err := m.db.Exec(updateRunInfo, s.States.RunStartTime, m.failed, m.mismatch, m.runId)
if err != nil {
log.Error().Err(err).Str("run_name", s.States.RunName).Int64("run_id", m.runId).
Expand Down
4 changes: 3 additions & 1 deletion stage/result.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package stage

import (
"github.com/rs/zerolog"
"pbench/log"
"time"

"github.com/rs/zerolog"
)

type QueryResult struct {
StageId string
Seed int64
Query *Query
QueryId string
InfoUrl string
Expand Down
102 changes: 83 additions & 19 deletions stage/stage.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ type Stage struct {
// Use RandomlyExecuteUntil to specify a duration like "1h" or an integer as the number of queries should be executed
// before exiting.
RandomlyExecuteUntil *string `json:"randomly_execute_until,omitempty"`
// If NoRandomDuplicates is set to true, queries will not be repeated during random execution
// until all queries have been executed once. After that, the selection pool resets if more
// executions are needed.
NoRandomDuplicates *bool `json:"no_random_duplicates,omitempty"`
// If not set, the default is 1. The default value is set when the stage is run.
ColdRuns *int `json:"cold_runs,omitempty" validate:"omitempty,gte=0"`
// If not set, the default is 0.
Expand All @@ -87,6 +91,9 @@ type Stage struct {
// knob was not set to true.
SaveJson *bool `json:"save_json,omitempty"`
NextStagePaths []string `json:"next,omitempty"`
// StreamSpecs allows specifying streams to launch dynamically with custom counts and seeds
// Format: [{"stream_file_path": "path/to/stream.json", "stream_count": 5, "seeds": [123, 456]}]
Streams []Streams `json:"streams,omitempty"`

// BaseDir is set to the directory path of this stage's location. It is used to locate the descendant stages when
// their locations are specified using relative paths. It is not possible to set this in a stage definition json file.
Expand All @@ -101,6 +108,11 @@ type Stage struct {
// Client is by default passed down to descendant stages.
Client *presto.Client `json:"-"`

// Stream instance information for custom seeding and identification
// Descendant stages will **NOT** inherit this value from their parents so this is declared as a value not a pointer.
// Custom seed for this stage instance, nil if using default seeding
seed int64 `json:"-"`

// Convenient access to the expected row count array under the current schema.
expectedRowCountInCurrentSchema []int
// Convenient access to the catalog, schema, and timezone
Expand Down Expand Up @@ -150,6 +162,7 @@ func (s *Stage) Run(ctx context.Context) int {

go func() {
s.States.wgExitMainStage.Wait()
close(s.States.resultChan)
// wgExitMainStage goes down to 0 after all the goroutines finish. Then we exit the driver by
// closing the timeToExit channel, which will trigger the graceful shutdown process -
// (flushing the log file, writing the final time log summary, etc.).
Expand All @@ -174,20 +187,30 @@ func (s *Stage) Run(ctx context.Context) int {

for {
select {
case result := <-s.States.resultChan:
case result, ok := <-s.States.resultChan:
if !ok {
// resultChan closed: all results received, finalize and exit
s.States.RunFinishTime = time.Now()
for _, recorder := range s.States.runRecorders {
recorder.RecordRun(utils.GetCtxWithTimeout(time.Second*5), s, results)
}
return int(s.States.exitCode.Load())
}
results = append(results, result)
for _, recorder := range s.States.runRecorders {
recorder.RecordQuery(utils.GetCtxWithTimeout(time.Second*5), s, result)
}
case sig := <-timeToExit:
if sig != nil {
// Cancel the context and wait for the goroutines to exit.
s.States.AbortAll(fmt.Errorf(sig.String()))
case sig, ok := <-timeToExit:
if !ok {
// timeToExit channel closed, no more signals — continue to receive results
continue
}
s.States.RunFinishTime = time.Now()
for _, recorder := range s.States.runRecorders {
recorder.RecordRun(utils.GetCtxWithTimeout(time.Second*5), s, results)
if sig != nil {
// Received shutdown signal; cancel ongoing queries
log.Info().Msgf("Shutdown signal received: %v. Aborting queries...", sig)
s.States.AbortAll(fmt.Errorf("%s", sig.String()))
// Keep receiving results until resultChan is closed
}
return int(s.States.exitCode.Load())
}
Expand Down Expand Up @@ -237,8 +260,11 @@ func (s *Stage) run(ctx context.Context) (returnErr error) {
if preStageErr != nil {
return fmt.Errorf("pre-stage script execution failed: %w", preStageErr)
}
if len(s.Queries)+len(s.QueryFiles) > 0 {
if len(s.Queries)+len(s.QueryFiles)+len(s.Streams) > 0 {
if *s.RandomExecution {
if s.RandomlyExecuteUntil == nil {
return fmt.Errorf("randomly_execute_until must be set for random execution in stage %s", s.Id)
}
returnErr = s.runRandomly(ctx)
} else {
returnErr = s.runSequentially(ctx)
Expand Down Expand Up @@ -343,21 +369,57 @@ func (s *Stage) runRandomly(ctx context.Context) error {
return nil
}
}
r := rand.New(rand.NewSource(s.States.RandSeed))

r := rand.New(rand.NewSource(s.seed))
log.Info().Str("stream_id", s.Id).Int64("custom_seed", s.seed).Msg("initialized with seed")
s.States.RandSeedUsed = true
log.Info().Int64("seed", s.States.RandSeed).Msg("random source seeded")
randIndexUpperBound := len(s.Queries) + len(s.QueryFiles)
for i := 1; continueExecution(i); i++ {
idx := r.Intn(randIndexUpperBound)
if i <= s.States.RandSkip {
if i == s.States.RandSkip {
log.Info().Msgf("skipped %d random selections", i)

totalQueries := len(s.Queries) + len(s.QueryFiles)

// refreshIndices generates a new set of random indices for selecting queries.
// If NoRandomDuplicates is set to true, it generates a shuffled list of all indices.
// Otherwise, it generates a list of random indices with possible duplicates.
refreshIndices := func() []int {
indices := make([]int, totalQueries)
if s.NoRandomDuplicates != nil && *s.NoRandomDuplicates {
for i := 0; i < totalQueries; i++ {
indices[i] = i
}
r.Shuffle(len(indices), func(i, j int) {
indices[i], indices[j] = indices[j], indices[i]
})
} else {
for i := 0; i < totalQueries; i++ {
indices[i] = r.Intn(totalQueries)
}
}
return indices
}

executionCount := 1
var currentIndices []int
var indexPosition int

for continueExecution(executionCount) {
// Refresh indices when all queries have been used
if currentIndices == nil || indexPosition >= len(currentIndices) {
currentIndices = refreshIndices()
indexPosition = 0
}

idx := currentIndices[indexPosition]
indexPosition++

if executionCount <= s.States.RandSkip {
if executionCount == s.States.RandSkip {
log.Info().Msgf("skipped %d random selections", executionCount)
}
executionCount++
continue
}

if idx < len(s.Queries) {
// Run query embedded in the json file.
pseudoFileName := fmt.Sprintf("rand_%d", i)
pseudoFileName := fmt.Sprintf("rand_%d", executionCount)
if err := s.runQueries(ctx, s.Queries[idx:idx+1], &pseudoFileName, 0); err != nil {
return err
}
Expand All @@ -367,11 +429,12 @@ func (s *Stage) runRandomly(ctx context.Context) error {
if relPath, relErr := filepath.Rel(s.BaseDir, queryFile); relErr == nil {
fileAlias = relPath
}
fileAlias = fmt.Sprintf("rand_%d_%s", i, fileAlias)
fileAlias = fmt.Sprintf("rand_%d_%s", executionCount, fileAlias)
if err := s.runQueryFile(ctx, queryFile, nil, &fileAlias); err != nil {
return err
}
}
executionCount++
}
log.Info().Msg("random execution concluded.")
return nil
Expand Down Expand Up @@ -476,6 +539,7 @@ func (s *Stage) runQuery(ctx context.Context, query *Query) (result *QueryResult

result = &QueryResult{
StageId: s.Id,
Seed: s.seed,
Query: query,
StartTime: time.Now(),
}
Expand Down
10 changes: 10 additions & 0 deletions stage/stage_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ func (s *Stage) MergeWith(other *Stage) *Stage {
if other.RandomExecution != nil {
s.RandomExecution = other.RandomExecution
}
if other.NoRandomDuplicates != nil {
s.NoRandomDuplicates = other.NoRandomDuplicates
}
if other.RandomlyExecuteUntil != nil {
s.RandomlyExecuteUntil = other.RandomlyExecuteUntil
}
Expand All @@ -92,6 +95,7 @@ func (s *Stage) MergeWith(other *Stage) *Stage {
}
s.NextStagePaths = append(s.NextStagePaths, other.NextStagePaths...)
s.BaseDir = other.BaseDir
s.Streams = append(s.Streams, other.Streams...)

s.PreStageShellScripts = append(s.PreStageShellScripts, other.PreStageShellScripts...)
s.PostQueryShellScripts = append(s.PostQueryShellScripts, other.PostQueryShellScripts...)
Expand Down Expand Up @@ -194,6 +198,9 @@ func (s *Stage) setDefaults() {
if s.RandomExecution == nil {
s.RandomExecution = &falseValue
}
if s.NoRandomDuplicates == nil {
s.NoRandomDuplicates = &falseValue
}
if s.AbortOnError == nil {
s.AbortOnError = &falseValue
}
Expand Down Expand Up @@ -235,6 +242,9 @@ func (s *Stage) propagateStates() {
if nextStage.RandomExecution == nil {
nextStage.RandomExecution = s.RandomExecution
}
if nextStage.NoRandomDuplicates == nil {
nextStage.NoRandomDuplicates = s.NoRandomDuplicates
}
if nextStage.RandomlyExecuteUntil == nil {
nextStage.RandomlyExecuteUntil = s.RandomlyExecuteUntil
}
Expand Down
Loading