diff --git a/scan.go b/scan.go index b00a2fe..0e5a73d 100644 --- a/scan.go +++ b/scan.go @@ -18,63 +18,72 @@ type scanIterator struct { segments int keysOnly bool - doneWG sync.WaitGroup - + doneWG sync.WaitGroup resultChan chan query.Result closeOnce sync.Once ctx context.Context cancel context.CancelFunc + errChan chan error } func (s *scanIterator) trySend(result query.Result) bool { - log.Debugw("sending scan result", "Result", result) select { case <-s.ctx.Done(): return true case s.resultChan <- result: + return false } - return false } func (s *scanIterator) worker(ctx context.Context, segment int64, totalSegments int64) { defer s.doneWG.Done() defer log.Debug("scan worker done") log.Debug("scan worker starting") + var exclusiveStartKey map[string]*dynamodb.AttributeValue for { - req := &dynamodb.ScanInput{ - TableName: &s.tableName, - Segment: &segment, - TotalSegments: &totalSegments, - ExclusiveStartKey: exclusiveStartKey, - } + select { + case <-ctx.Done(): + s.errChan <- ctx.Err() + return + default: + req := &dynamodb.ScanInput{ + TableName: &s.tableName, + Segment: &segment, + TotalSegments: &totalSegments, + ExclusiveStartKey: exclusiveStartKey, + } - if s.indexName != "" { - req.IndexName = &s.indexName - } + if s.indexName != "" { + req.IndexName = &s.indexName + } - if s.keysOnly { - req.ProjectionExpression = aws.String(attrNameKey) - } + if s.keysOnly { + req.ProjectionExpression = aws.String(attrNameKey) + } - log.Debugw("scanning", "Req", req) - res, err := s.ddbClient.ScanWithContext(s.ctx, req) - if err != nil { - if s.trySend(query.Result{Error: err}) { - return + log.Debugw("scanning", "Req", req) + res, err := s.ddbClient.ScanWithContext(s.ctx, req) + if err != nil { + if s.trySend(query.Result{Error: err}) { + s.errChan <- err + return + } } - } - for _, itemMap := range res.Items { - log.Debugw("scan got items", "NumItems", len(res.Items)) - result := itemMapToQueryResult(itemMap, s.keysOnly) - if s.trySend(result) { + + for _, itemMap := range res.Items { + log.Debugw("scan got items", "NumItems", len(res.Items)) + result := itemMapToQueryResult(itemMap, s.keysOnly) + if s.trySend(result) { + return + } + } + + if res.LastEvaluatedKey == nil { return } + exclusiveStartKey = res.LastEvaluatedKey } - if res.LastEvaluatedKey == nil { - return - } - exclusiveStartKey = res.LastEvaluatedKey } } @@ -95,29 +104,51 @@ func itemMapToQueryResult(itemMap map[string]*dynamodb.AttributeValue, keysOnly func (s *scanIterator) start(ctx context.Context) { s.ctx, s.cancel = context.WithCancel(ctx) s.resultChan = make(chan query.Result) + s.errChan = make(chan error, s.segments) s.doneWG.Add(s.segments) + totalSegments := int64(s.segments) for i := 0; i < s.segments; i++ { segment := int64(i) - go s.worker(ctx, segment, totalSegments) + go s.worker(s.ctx, segment, totalSegments) } - // Don't wait on the Close() method to be called to close the chan; - // close it as soon as there are no more results, so that Next() will return false. - // If Close() is called, it races with this, hence the use of sync.Once. + go func() { s.doneWG.Wait() - s.closeOnce.Do(func() { close(s.resultChan) }) + s.closeOnce.Do(func() { + close(s.errChan) + close(s.resultChan) + }) }() } func (s *scanIterator) Next() (query.Result, bool) { + select { + case <-s.ctx.Done(): + return query.Result{Error: s.ctx.Err()}, false + default: + } + + select { + case err, ok := <-s.errChan: + if ok && err != nil { + return query.Result{Error: err}, false + } + default: + } + result, ok := <-s.resultChan return result, ok } func (s *scanIterator) Close() error { - s.cancel() - s.doneWG.Wait() - s.closeOnce.Do(func() { close(s.resultChan) }) + s.closeOnce.Do(func() { + if s.cancel != nil { + s.cancel() + } + s.doneWG.Wait() + close(s.errChan) + close(s.resultChan) + }) return nil }