Skip to content

Commit d4faf1b

Browse files
authored
update: allow config files set address validation (#73)
Add DialedAddressValidator support for JSON and Protobuf config files. Signed-off-by: Gaukas Wang <[email protected]>
1 parent 8979246 commit d4faf1b

File tree

6 files changed

+431
-68
lines changed

6 files changed

+431
-68
lines changed

address_validator.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package water
2+
3+
import (
4+
"errors"
5+
)
6+
7+
var (
8+
ErrAddressValidatorNotInitialized = errors.New("address validator not initialized properly")
9+
ErrAddressValidationDenied = errors.New("address validation denied")
10+
)
11+
12+
type addressValidator struct {
13+
catchAll bool
14+
allowlist map[string][]string // map[address]networks
15+
denylist map[string][]string // map[address]networks
16+
}
17+
18+
func (a *addressValidator) validate(network, address string) error {
19+
if a.catchAll {
20+
// only check denylist, otherwise allow
21+
if a.denylist == nil {
22+
return ErrAddressValidatorNotInitialized
23+
}
24+
25+
if deniedNetworks, ok := a.denylist[address]; ok {
26+
if deniedNetworks == nil {
27+
return ErrAddressValidatorNotInitialized
28+
}
29+
30+
for _, deniedNet := range deniedNetworks {
31+
if deniedNet == network {
32+
return ErrAddressValidationDenied
33+
}
34+
}
35+
}
36+
return nil
37+
} else {
38+
// only check allowlist, otherwise deny
39+
if a.allowlist == nil {
40+
return ErrAddressValidatorNotInitialized
41+
}
42+
43+
if allowedNetworks, ok := a.allowlist[address]; ok {
44+
if allowedNetworks == nil {
45+
return ErrAddressValidatorNotInitialized
46+
}
47+
48+
for _, allowedNet := range allowedNetworks {
49+
if allowedNet == network {
50+
return nil
51+
}
52+
}
53+
}
54+
return ErrAddressValidationDenied
55+
}
56+
}

address_validator_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package water
2+
3+
// package water instead of water_test to access unexported struct addressValidator and its unexported fields/methods
4+
5+
import "testing"
6+
7+
func Test_addressValidator_validate(t *testing.T) {
8+
var a addressValidator
9+
10+
// test catchAll with nil denylist
11+
a.catchAll = true
12+
13+
if err := a.validate("random net", "random address"); err != ErrAddressValidatorNotInitialized {
14+
t.Errorf("Expected ErrAddressValidatorNotInitialized, got %v", err)
15+
}
16+
17+
// test nil denylist entry
18+
a.denylist = map[string][]string{
19+
"denied address": nil,
20+
}
21+
22+
if err := a.validate("random net", "denied address"); err != ErrAddressValidatorNotInitialized {
23+
t.Errorf("Expected ErrAddressValidatorNotInitialized, got %v", err)
24+
}
25+
26+
// test denied address on denied network
27+
a.denylist["denied address"] = []string{"denied net"}
28+
29+
if err := a.validate("denied net", "denied address"); err != ErrAddressValidationDenied {
30+
t.Errorf("Expected ErrAddressValidationDenied, got %v", err)
31+
}
32+
33+
// test random network with denied address
34+
if err := a.validate("random net", "denied address"); err != nil {
35+
t.Errorf("Expected nil, got %v", err)
36+
}
37+
38+
// test random address on denied network
39+
if err := a.validate("denied net", "random address"); err != nil {
40+
t.Errorf("Expected nil, got %v", err)
41+
}
42+
43+
// test not catchAll with nil allowlist
44+
a.catchAll = false
45+
46+
if err := a.validate("random net", "random address"); err != ErrAddressValidatorNotInitialized {
47+
t.Errorf("Expected ErrAddressValidatorNotInitialized, got %v", err)
48+
}
49+
50+
// test nil allowlist entry
51+
a.allowlist = map[string][]string{
52+
"allowed address": nil,
53+
}
54+
55+
if err := a.validate("random net", "allowed address"); err != ErrAddressValidatorNotInitialized {
56+
t.Errorf("Expected ErrAddressValidatorNotInitialized, got %v", err)
57+
}
58+
59+
// test allowed address on allowed network
60+
a.allowlist["allowed address"] = []string{"allowed net"}
61+
62+
if err := a.validate("allowed net", "allowed address"); err != nil {
63+
t.Errorf("Expected nil, got %v", err)
64+
}
65+
66+
// test random network with allowed address
67+
if err := a.validate("random net", "allowed address"); err != ErrAddressValidationDenied {
68+
t.Errorf("Expected ErrAddressValidationDenied, got %v", err)
69+
}
70+
71+
// test random address on allowed network
72+
if err := a.validate("allowed net", "random address"); err != ErrAddressValidationDenied {
73+
t.Errorf("Expected ErrAddressValidationDenied, got %v", err)
74+
}
75+
}

config.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,16 @@ func (c *Config) UnmarshalJSON(data []byte) error {
204204
}
205205
}
206206

207+
if c.DialedAddressValidator == nil {
208+
a := &addressValidator{
209+
catchAll: confJson.Network.AddressValidation.CatchAll,
210+
allowlist: confJson.Network.AddressValidation.Allowlist,
211+
denylist: confJson.Network.AddressValidation.Denylist,
212+
}
213+
214+
c.DialedAddressValidator = a.validate
215+
}
216+
207217
if len(confJson.Network.Listener.Network) > 0 && len(confJson.Network.Listener.Address) > 0 {
208218
c.NetworkListener, err = net.Listen(confJson.Network.Listener.Network, confJson.Network.Listener.Address)
209219
if err != nil {
@@ -281,6 +291,31 @@ func (c *Config) UnmarshalProto(b []byte) error {
281291
c.TransportModuleConfig = TransportModuleConfigFromBytes(confProto.GetTransportModule().GetConfig())
282292
}
283293

294+
// Parse DialedAddressValidator if not already set
295+
if c.DialedAddressValidator == nil {
296+
a := &addressValidator{
297+
catchAll: confProto.GetNetwork().GetAddressValidation().GetCatchAll(),
298+
}
299+
300+
allowlist := confProto.GetNetwork().GetAddressValidation().GetAllowlist()
301+
if len(allowlist) > 0 {
302+
a.allowlist = make(map[string][]string)
303+
for k, v := range allowlist {
304+
a.allowlist[k] = v.GetNames()
305+
}
306+
}
307+
308+
denylist := confProto.GetNetwork().GetAddressValidation().GetDenylist()
309+
if len(denylist) > 0 {
310+
a.denylist = make(map[string][]string)
311+
for k, v := range denylist {
312+
a.denylist[k] = v.GetNames()
313+
}
314+
}
315+
316+
c.DialedAddressValidator = a.validate
317+
}
318+
284319
// Parse NetworkListener
285320
listenerNetwork, listenerAddress := confProto.GetNetwork().GetListener().GetNetwork(), confProto.GetNetwork().GetListener().GetAddress()
286321
if len(listenerNetwork) > 0 && len(listenerAddress) > 0 {

configbuilder/config.json.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ type ConfigJSON struct {
1212

1313
Network struct {
1414
// DialerFunc string `json:"dialer_func,omitempty"` // we have no good way to represent a func in JSON format yet
15+
AddressValidation struct {
16+
CatchAll bool `json:"catch_all,omitempty"` // If set, will allow all unspecified addresses. Otherwise, unspecified addresses will be rejected.
17+
Allowlist map[string][]string `json:"allowlist,omitempty"` // e.g. {"1.1.1.1:443": ["tcp", "udp"], "1.0.0.1:443": ["tcp"], ...}
18+
Denylist map[string][]string `json:"denylist,omitempty"` // e.g. {"1.0.0.0:80": ["udp"], ...}
19+
} `json:"address_validator,omitempty"`
1520
Listener struct {
1621
Network string `json:"network"` // e.g. "tcp"
1722
Address string `json:"address"` // e.g. "0.0.0.0:0"

0 commit comments

Comments
 (0)