Skip to content

Commit a2abf6d

Browse files
Implemented token bucket rate limiter
1 parent 34fdeeb commit a2abf6d

File tree

6 files changed

+172
-0
lines changed

6 files changed

+172
-0
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,4 @@ Dmitry Kropachev <[email protected]>
141141
Oliver Boyle <[email protected]>
142142
Jackson Fleming <[email protected]>
143143
Sylwia Szunejko <[email protected]>
144+
Rostyslav Porokhnya <[email protected]>

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77
## [Unreleased]
88

99
### Added
10+
- Added Queries Rate Limiter which uses Token Bucket algorithm to the Session struct.
11+
Added RateLimiterConfig to the ClusterConfig struct. (#1731)
1012

1113
### Changed
1214

cluster.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,9 @@ type ClusterConfig struct {
239239

240240
// internal config for testing
241241
disableControlConn bool
242+
243+
// If Session has RateLimiterConfig then queries will be limited using RateLimiter
244+
RateLimiterConfig *RateLimiterConfig
242245
}
243246

244247
type Dialer interface {

rate_limiter.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package gocql
2+
3+
import (
4+
"sync"
5+
"time"
6+
)
7+
8+
// RateLimiterConfig holds the configuration parameters for the rate limiter, which uses Token Bucket approach.
9+
//
10+
// Fields:
11+
//
12+
// - rate: Allowed requests per second
13+
// - Burst: Maximum number of burst requests
14+
//
15+
// Example:
16+
// RateLimiterConfig{
17+
// rate: 300000,
18+
// burst: 150,
19+
// }
20+
type RateLimiterConfig struct {
21+
rate float64
22+
burst int
23+
}
24+
25+
type tokenBucket struct {
26+
rate float64
27+
burst int
28+
tokens int
29+
lastRefilled time.Time
30+
mu sync.Mutex
31+
}
32+
33+
func (tb *tokenBucket) refill() {
34+
tb.mu.Lock()
35+
defer tb.mu.Unlock()
36+
now := time.Now()
37+
tokensToAdd := int(tb.rate * now.Sub(tb.lastRefilled).Seconds())
38+
tb.tokens = min(tb.tokens+tokensToAdd, tb.burst)
39+
tb.lastRefilled = now
40+
}
41+
42+
func (tb *tokenBucket) Allow() bool {
43+
tb.refill()
44+
tb.mu.Lock()
45+
defer tb.mu.Unlock()
46+
if tb.tokens > 0 {
47+
tb.tokens--
48+
return true
49+
}
50+
return false
51+
}
52+
53+
func min(a, b int) int {
54+
if a < b {
55+
return a
56+
}
57+
return b
58+
}
59+
60+
type ConfigurableRateLimiter struct {
61+
tb tokenBucket
62+
}
63+
64+
func NewConfigurableRateLimiter(rate float64, burst int) *ConfigurableRateLimiter {
65+
tb := tokenBucket{
66+
rate: rate,
67+
burst: burst,
68+
tokens: burst,
69+
lastRefilled: time.Now(),
70+
}
71+
return &ConfigurableRateLimiter{tb}
72+
}
73+
74+
func (rl *ConfigurableRateLimiter) Allow() bool {
75+
return rl.tb.Allow()
76+
}

rate_limiter_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package gocql
2+
3+
import (
4+
"fmt"
5+
"sync"
6+
"testing"
7+
)
8+
9+
const queries = 100
10+
11+
const skipRateLimiterTestMsg = "Skipping rate limiter test, due to limit of simultaneously alive goroutines. Should be tested locally"
12+
13+
func TestRateLimiter50k(t *testing.T) {
14+
t.Skip(skipRateLimiterTestMsg)
15+
fmt.Println("Running rate limiter test with 50_000 workers")
16+
RunRateLimiterTest(t, 50_000)
17+
}
18+
19+
func TestRateLimiter100k(t *testing.T) {
20+
t.Skip(skipRateLimiterTestMsg)
21+
fmt.Println("Running rate limiter test with 100_000 workers")
22+
RunRateLimiterTest(t, 100_000)
23+
}
24+
25+
func TestRateLimiter200k(t *testing.T) {
26+
t.Skip(skipRateLimiterTestMsg)
27+
fmt.Println("Running rate limiter test with 200_000 workers")
28+
RunRateLimiterTest(t, 200_000)
29+
}
30+
31+
func RunRateLimiterTest(t *testing.T, workerCount int) {
32+
cluster := createCluster()
33+
cluster.RateLimiterConfig = &RateLimiterConfig{
34+
rate: 300000,
35+
burst: 100,
36+
}
37+
38+
session := createSessionFromCluster(cluster, t)
39+
defer session.Close()
40+
41+
execRelease(session.Query("drop keyspace if exists pargettest"))
42+
execRelease(session.Query("create keyspace pargettest with replication = {'class' : 'SimpleStrategy', 'replication_factor' : 1}"))
43+
execRelease(session.Query("drop table if exists pargettest.test"))
44+
execRelease(session.Query("create table pargettest.test (a text, b int, primary key(a))"))
45+
execRelease(session.Query("insert into pargettest.test (a, b) values ( 'a', 1)"))
46+
47+
var wg sync.WaitGroup
48+
49+
for i := 1; i <= workerCount; i++ {
50+
wg.Add(1)
51+
52+
go func() {
53+
defer wg.Done()
54+
for j := 0; j < queries; j++ {
55+
iterRelease(session.Query("select * from pargettest.test where a='a'"))
56+
}
57+
}()
58+
}
59+
60+
wg.Wait()
61+
}
62+
63+
func iterRelease(query *Query) {
64+
_, err := query.Iter().SliceMap()
65+
if err != nil {
66+
println(err.Error())
67+
panic(err)
68+
}
69+
query.Release()
70+
}
71+
72+
func execRelease(query *Query) {
73+
if err := query.Exec(); err != nil {
74+
println(err.Error())
75+
panic(err)
76+
}
77+
query.Release()
78+
}

session.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ type Session struct {
8383
isInitialized bool
8484

8585
logger StdLogger
86+
87+
rateLimiter *ConfigurableRateLimiter
8688
}
8789

8890
var queryPool = &sync.Pool{
@@ -168,6 +170,10 @@ func NewSession(cfg ClusterConfig) (*Session, error) {
168170
s.frameObserver = cfg.FrameHeaderObserver
169171
s.streamObserver = cfg.StreamObserver
170172

173+
if cfg.RateLimiterConfig != nil {
174+
s.rateLimiter = NewConfigurableRateLimiter(cfg.RateLimiterConfig.rate, cfg.RateLimiterConfig.burst)
175+
}
176+
171177
//Check the TLS Config before trying to connect to anything external
172178
connCfg, err := connConfig(&s.cfg)
173179
if err != nil {
@@ -432,6 +438,12 @@ func (s *Session) SetTrace(trace Tracer) {
432438
// value before the query is executed. Query is automatically prepared
433439
// if it has not previously been executed.
434440
func (s *Session) Query(stmt string, values ...interface{}) *Query {
441+
if s.rateLimiter != nil {
442+
for !s.rateLimiter.Allow() {
443+
time.Sleep(time.Millisecond * 50)
444+
}
445+
}
446+
435447
qry := queryPool.Get().(*Query)
436448
qry.session = s
437449
qry.stmt = stmt

0 commit comments

Comments
 (0)