Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions internal/balancer/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoi
)
}()

connections := endpointsToConnections(b.pool, newest)
connections := conn.EndpointsToConnections(b.pool, newest)
for _, c := range connections {
b.pool.Allow(ctx, c)
c.Endpoint().Touch()
Expand Down Expand Up @@ -453,12 +453,3 @@ func (b *Balancer) nextConn(ctx context.Context) (c conn.Conn, err error) {

return c, nil
}

func endpointsToConnections(p *conn.Pool, endpoints []endpoint.Endpoint) []conn.Conn {
conns := make([]conn.Conn, 0, len(endpoints))
for _, e := range endpoints {
conns = append(conns, p.Get(e))
}

return conns
}
49 changes: 49 additions & 0 deletions internal/balancer/balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
"google.golang.org/grpc/test/bufconn"

"github.com/ydb-platform/ydb-go-sdk/v3/config"
balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
)

func TestBalancer_discoveryConn(t *testing.T) {
Expand Down Expand Up @@ -71,6 +74,52 @@ func TestBalancer_discoveryConn(t *testing.T) {
require.NoError(t, err)
}

func TestApplyDiscoveredEndpoints(t *testing.T) {
ctx := context.Background()

cfg := config.New()
pool := conn.NewPool(ctx, cfg)
defer func() { _ = pool.Release(ctx) }()

b := &Balancer{
driverConfig: cfg,
pool: pool,
balancerConfig: balancerConfig.Config{},
}

initial := newConnectionsState(nil, b.balancerConfig.Filter, balancerConfig.Info{}, b.balancerConfig.AllowFallback)
b.connectionsState.Store(initial)

e1 := endpoint.New("e1.example:2135", endpoint.WithIPV6([]string{"2001:db8::1"}), endpoint.WithID(1))
e2 := endpoint.New("e2.example:2135", endpoint.WithIPV6([]string{"2001:db8::2"}), endpoint.WithID(2))

// call with two endpoints
b.applyDiscoveredEndpoints(ctx, []endpoint.Endpoint{e1, e2}, "")

// connectionsState should be updated and reflect the endpoints
after := b.connections()
require.NotNil(t, after)
all := after.All()
require.Equal(t, 2, len(all))
require.Equal(t, e1.Address(), all[0].Address())
require.Equal(t, e1.NodeID(), all[0].NodeID())
require.Equal(t, e2.Address(), all[1].Address())
require.Equal(t, e2.NodeID(), all[1].NodeID())

// partially replace endpoints
e3 := endpoint.New("e3.example:2135", endpoint.WithIPV6([]string{"2001:db8::3"}), endpoint.WithID(1))
b.applyDiscoveredEndpoints(ctx, []endpoint.Endpoint{e2, e3}, "")
// connectionsState should be updated and reflect the endpoints
after = b.connections()
require.NotNil(t, after)
all = after.All()
require.Equal(t, 2, len(all))
require.Equal(t, e2.Address(), all[0].Address())
require.Equal(t, e2.NodeID(), all[0].NodeID())
require.Equal(t, e3.Address(), all[1].Address())
require.Equal(t, e3.NodeID(), all[1].NodeID())
}

// Mock resolver
//

Expand Down
9 changes: 9 additions & 0 deletions internal/conn/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ type Pool struct {
done chan struct{}
}

func EndpointsToConnections(p *Pool, endpoints []endpoint.Endpoint) []Conn {
conns := make([]Conn, 0, len(endpoints))
for _, e := range endpoints {
conns = append(conns, p.Get(e))
}

return conns
}

func (p *Pool) DialTimeout() time.Duration {
return p.config.DialTimeout()
}
Expand Down
139 changes: 139 additions & 0 deletions internal/conn/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,142 @@ func TestPool_ConnParker(t *testing.T) {
require.Equal(t, 3, tickCount)
})
}

func TestEndpointsToConnections(t *testing.T) {
t.Run("CreatesConnectionsForEndpoints", func(t *testing.T) {
ctx := context.Background()
config := &mockConfig{
dialTimeout: 5 * time.Second,
connectionTTL: 0,
}
pool := NewPool(ctx, config)
defer func() {
_ = pool.Release(ctx)
}()

require.Equal(t, 0, pool.conns.Len())

e1 := endpoint.New("e1:2135")
e2 := endpoint.New("e2:2135")

conns := EndpointsToConnections(pool, []endpoint.Endpoint{e1, e2})

require.Len(t, conns, 2)
require.Equal(t, 2, pool.conns.Len())

require.Equal(t, pool.Get(e1), conns[0])
require.Equal(t, pool.Get(e2), conns[1])
})

t.Run("ReusesExistingConnections", func(t *testing.T) {
ctx := context.Background()
config := &mockConfig{
dialTimeout: 5 * time.Second,
connectionTTL: 0,
}
pool := NewPool(ctx, config)
defer func() {
_ = pool.Release(ctx)
}()

e := endpoint.New("reuse:2135")

existing := pool.Get(e)
require.NotNil(t, existing)

initialLen := pool.conns.Len()

conns := EndpointsToConnections(pool, []endpoint.Endpoint{e})

require.Len(t, conns, 1)
require.Equal(t, existing, conns[0])

require.Equal(t, initialLen, pool.conns.Len())
})

t.Run("IPv6AndHostOverrideUniqueKeys", func(t *testing.T) {
ctx := context.Background()
config := &mockConfig{
dialTimeout: 5 * time.Second,
connectionTTL: 0,
}
pool := NewPool(ctx, config)
defer func() {
_ = pool.Release(ctx)
}()

// ensure empty pool
require.Equal(t, 0, pool.conns.Len())

// address is a dns-style host:port, ipv6 provides resolved ip used in Key.Address()
e1 := endpoint.New("example.com:2135", endpoint.WithIPV6([]string{"2001:db8::1"}))
// if node is rebooted with different ssl name override, we need a different connection
e2 := endpoint.New(
"example.com:2135",
endpoint.WithIPV6([]string{"2001:db8::1"}),
endpoint.WithSslTargetNameOverride("override"),
)
// different ipv6 -> different Address()
e3 := endpoint.New("example.com:2135", endpoint.WithIPV6([]string{"2001:db8::2"}), endpoint.WithID(2))
// same ipv6 as e1 but different NodeID -> different Key.NodeID
e4 := endpoint.New("example.com:2135", endpoint.WithIPV6([]string{"2001:db8::1"}), endpoint.WithID(1))

endpoints := []endpoint.Endpoint{e1, e2, e3, e4}
conns := EndpointsToConnections(pool, endpoints)

require.Len(t, conns, len(endpoints))
require.Equal(t, 4, pool.conns.Len())

for i, e := range endpoints {
got := conns[i]
require.NotNil(t, got)
require.Equal(t, pool.Get(e), got)
cc, ok := pool.conns.Get(e.Key())
require.True(t, ok)
require.Equal(t, cc, got)
}

require.Equal(t, e2.Key().HostOverride, "override")
require.Equal(t, e4.Key().NodeID, uint32(1))
})

t.Run("AddNewEndpointAndNodeIDVariation", func(t *testing.T) {
ctx := context.Background()
config := &mockConfig{
dialTimeout: 5 * time.Second,
connectionTTL: 0,
}
pool := NewPool(ctx, config)
defer func() {
_ = pool.Release(ctx)
}()

// initial two endpoints with IPv6 and distinct NodeIDs
e1 := endpoint.New("e1.example:2135", endpoint.WithIPV6([]string{"2001:db8::1"}), endpoint.WithID(1))
e2 := endpoint.New("e2.example:2135", endpoint.WithIPV6([]string{"2001:db8::2"}), endpoint.WithID(2))

// create initial connections
initialConns := EndpointsToConnections(pool, []endpoint.Endpoint{e1, e2})
require.Len(t, initialConns, 2)
require.Equal(t, 2, pool.conns.Len())
require.Equal(t, pool.Get(e1), initialConns[0])
require.Equal(t, pool.Get(e2), initialConns[1])

// add a new unique endpoint e3 -> pool should grow
e3 := endpoint.New("e3.example:2135", endpoint.WithIPV6([]string{"2001:db8::3"}), endpoint.WithID(3))
connsAfterE3 := EndpointsToConnections(pool, []endpoint.Endpoint{e1, e2, e3})
require.Len(t, connsAfterE3, 3)
require.Equal(t, 3, pool.conns.Len())
require.Equal(t, pool.Get(e3), connsAfterE3[2])

// now use same address as e1 but different NodeID (and same ipv6) -> should create new conn
e1DifferentNode := endpoint.New("e1.example:2135", endpoint.WithIPV6([]string{"2001:db8::1"}), endpoint.WithID(99))
connsAfterNodeChange := EndpointsToConnections(pool, []endpoint.Endpoint{e1DifferentNode})
require.Len(t, connsAfterNodeChange, 1)
// pool size must increase by one
require.Equal(t, 4, pool.conns.Len())
// returned conn corresponds to the new endpoint key
require.Equal(t, pool.Get(e1DifferentNode), connsAfterNodeChange[0])
require.Equal(t, pool.Get(e1), initialConns[0])
})
}
Loading