Skip to content

Commit a6afffb

Browse files
committed
chore: add test for vmess/vless inbound
1 parent 3d2cb99 commit a6afffb

File tree

4 files changed

+658
-10
lines changed

4 files changed

+658
-10
lines changed

common/net/tls.go

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package net
33
import (
44
"crypto/rand"
55
"crypto/rsa"
6+
"crypto/sha256"
67
"crypto/tls"
78
"crypto/x509"
9+
"encoding/hex"
810
"encoding/pem"
911
"fmt"
1012
"math/big"
@@ -16,7 +18,11 @@ type Path interface {
1618

1719
func ParseCert(certificate, privateKey string, path Path) (tls.Certificate, error) {
1820
if certificate == "" && privateKey == "" {
19-
return newRandomTLSKeyPair()
21+
var err error
22+
certificate, privateKey, _, err = NewRandomTLSKeyPair()
23+
if err != nil {
24+
return tls.Certificate{}, err
25+
}
2026
}
2127
cert, painTextErr := tls.X509KeyPair([]byte(certificate), []byte(privateKey))
2228
if painTextErr == nil {
@@ -32,10 +38,10 @@ func ParseCert(certificate, privateKey string, path Path) (tls.Certificate, erro
3238
return cert, nil
3339
}
3440

35-
func newRandomTLSKeyPair() (tls.Certificate, error) {
41+
func NewRandomTLSKeyPair() (certificate string, privateKey string, fingerprint string, err error) {
3642
key, err := rsa.GenerateKey(rand.Reader, 2048)
3743
if err != nil {
38-
return tls.Certificate{}, err
44+
return
3945
}
4046
template := x509.Certificate{SerialNumber: big.NewInt(1)}
4147
certDER, err := x509.CreateCertificate(
@@ -45,14 +51,15 @@ func newRandomTLSKeyPair() (tls.Certificate, error) {
4551
&key.PublicKey,
4652
key)
4753
if err != nil {
48-
return tls.Certificate{}, err
54+
return
4955
}
50-
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
51-
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
52-
53-
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
56+
cert, err := x509.ParseCertificate(certDER)
5457
if err != nil {
55-
return tls.Certificate{}, err
58+
return
5659
}
57-
return tlsCert, nil
60+
hash := sha256.Sum256(cert.Raw)
61+
fingerprint = hex.EncodeToString(hash[:])
62+
privateKey = string(pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}))
63+
certificate = string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}))
64+
return
5865
}

listener/inbound/common_test.go

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
package inbound_test
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"crypto/tls"
7+
"encoding/base64"
8+
"fmt"
9+
"io"
10+
"net"
11+
"net/http"
12+
"net/netip"
13+
"sync"
14+
"testing"
15+
"time"
16+
17+
N "github.com/metacubex/mihomo/common/net"
18+
"github.com/metacubex/mihomo/component/ca"
19+
"github.com/metacubex/mihomo/component/generater"
20+
C "github.com/metacubex/mihomo/constant"
21+
22+
"github.com/go-chi/chi/v5"
23+
"github.com/go-chi/render"
24+
"github.com/stretchr/testify/assert"
25+
)
26+
27+
var tlsCertificate, tlsPrivateKey, tlsFingerprint, _ = N.NewRandomTLSKeyPair()
28+
var tlsConfigCert, _ = tls.X509KeyPair([]byte(tlsCertificate), []byte(tlsPrivateKey))
29+
var tlsConfig = &tls.Config{Certificates: []tls.Certificate{tlsConfigCert}, NextProtos: []string{"h2", "http/1.1"}}
30+
var tlsClientConfig, _ = ca.GetTLSConfig(nil, tlsFingerprint, "", "")
31+
var realityPrivateKey, realityPublickey string
32+
var realityDest = "itunes.apple.com"
33+
var realityShortid = "10f897e26c4b9478"
34+
35+
func init() {
36+
privateKey, err := generater.GeneratePrivateKey()
37+
if err != nil {
38+
panic(err)
39+
}
40+
publicKey := privateKey.PublicKey()
41+
realityPrivateKey = base64.RawURLEncoding.EncodeToString(privateKey[:])
42+
realityPublickey = base64.RawURLEncoding.EncodeToString(publicKey[:])
43+
}
44+
45+
type TestTunnel struct {
46+
HandleTCPConnFn func(conn net.Conn, metadata *C.Metadata)
47+
HandleUDPPacketFn func(packet C.UDPPacket, metadata *C.Metadata)
48+
NatTableFn func() C.NatTable
49+
CloseFn func() error
50+
DoTestFn func(t *testing.T, proxy C.ProxyAdapter)
51+
}
52+
53+
func (tt *TestTunnel) HandleTCPConn(conn net.Conn, metadata *C.Metadata) {
54+
tt.HandleTCPConnFn(conn, metadata)
55+
}
56+
57+
func (tt *TestTunnel) HandleUDPPacket(packet C.UDPPacket, metadata *C.Metadata) {
58+
tt.HandleUDPPacketFn(packet, metadata)
59+
}
60+
61+
func (tt *TestTunnel) NatTable() C.NatTable {
62+
return tt.NatTableFn()
63+
}
64+
65+
func (tt *TestTunnel) Close() error {
66+
return tt.CloseFn()
67+
}
68+
69+
func (tt *TestTunnel) DoTest(t *testing.T, proxy C.ProxyAdapter) {
70+
tt.DoTestFn(t, proxy)
71+
}
72+
73+
type TestTunnelListener struct {
74+
ch chan net.Conn
75+
ctx context.Context
76+
cancel context.CancelFunc
77+
addr net.Addr
78+
}
79+
80+
func (t *TestTunnelListener) Accept() (net.Conn, error) {
81+
select {
82+
case conn, ok := <-t.ch:
83+
if !ok {
84+
return nil, net.ErrClosed
85+
}
86+
return conn, nil
87+
case <-t.ctx.Done():
88+
return nil, t.ctx.Err()
89+
}
90+
}
91+
92+
func (t *TestTunnelListener) Close() error {
93+
t.cancel()
94+
return nil
95+
}
96+
97+
func (t *TestTunnelListener) Addr() net.Addr {
98+
return t.addr
99+
}
100+
101+
type WaitCloseConn struct {
102+
net.Conn
103+
ch chan struct{}
104+
once sync.Once
105+
}
106+
107+
func (c *WaitCloseConn) Close() error {
108+
err := c.Conn.Close()
109+
c.once.Do(func() {
110+
close(c.ch)
111+
})
112+
return err
113+
}
114+
115+
var _ C.Tunnel = (*TestTunnel)(nil)
116+
var _ net.Listener = (*TestTunnelListener)(nil)
117+
118+
type HttpTestConfig struct {
119+
RemoteAddr netip.AddrPort
120+
HttpPath string
121+
HttpData []byte
122+
}
123+
124+
func NewHttpTestTunnel() *TestTunnel {
125+
httpData := make([]byte, 10240)
126+
rand.Read(httpData)
127+
config := &HttpTestConfig{
128+
HttpPath: "/inbound_test",
129+
HttpData: httpData,
130+
RemoteAddr: netip.MustParseAddrPort("1.2.3.4:443"),
131+
}
132+
ctx, cancel := context.WithCancel(context.Background())
133+
ln := &TestTunnelListener{ch: make(chan net.Conn), ctx: ctx, cancel: cancel, addr: net.TCPAddrFromAddrPort(config.RemoteAddr)}
134+
135+
r := chi.NewRouter()
136+
r.Get(config.HttpPath, func(w http.ResponseWriter, r *http.Request) {
137+
render.Data(w, r, config.HttpData)
138+
})
139+
go http.Serve(ln, r)
140+
tunnel := &TestTunnel{
141+
HandleTCPConnFn: func(conn net.Conn, metadata *C.Metadata) {
142+
defer conn.Close()
143+
if metadata.AddrPort() != config.RemoteAddr && metadata.Host != realityDest {
144+
return // not match, just return
145+
}
146+
c := &WaitCloseConn{
147+
Conn: conn,
148+
ch: make(chan struct{}),
149+
}
150+
ln.ch <- tls.Server(c, tlsConfig)
151+
<-c.ch
152+
},
153+
CloseFn: ln.Close,
154+
DoTestFn: func(t *testing.T, proxy C.ProxyAdapter) {
155+
ctx, cancel := context.WithTimeout(ctx, time.Second)
156+
defer cancel()
157+
158+
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://%s%s", config.RemoteAddr, config.HttpPath), nil)
159+
assert.Nil(t, err)
160+
req = req.WithContext(ctx)
161+
162+
metadata := &C.Metadata{
163+
NetWork: C.TCP,
164+
DstIP: config.RemoteAddr.Addr(),
165+
DstPort: config.RemoteAddr.Port(),
166+
}
167+
instance, err := proxy.DialContext(ctx, metadata)
168+
assert.Nil(t, err)
169+
defer instance.Close()
170+
171+
transport := &http.Transport{
172+
DialContext: func(context.Context, string, string) (net.Conn, error) {
173+
return instance, nil
174+
},
175+
// from http.DefaultTransport
176+
MaxIdleConns: 100,
177+
IdleConnTimeout: 90 * time.Second,
178+
TLSHandshakeTimeout: 10 * time.Second,
179+
ExpectContinueTimeout: 1 * time.Second,
180+
// for our self-signed cert
181+
TLSClientConfig: tlsClientConfig,
182+
// open http2
183+
ForceAttemptHTTP2: true,
184+
}
185+
186+
client := http.Client{
187+
Timeout: 30 * time.Second,
188+
Transport: transport,
189+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
190+
return http.ErrUseLastResponse
191+
},
192+
}
193+
194+
defer client.CloseIdleConnections()
195+
196+
resp, err := client.Do(req)
197+
assert.Nil(t, err)
198+
199+
defer resp.Body.Close()
200+
201+
assert.Equal(t, http.StatusOK, resp.StatusCode)
202+
203+
data, err := io.ReadAll(resp.Body)
204+
assert.Nil(t, err)
205+
assert.Equal(t, config.HttpData, data)
206+
},
207+
}
208+
return tunnel
209+
}

0 commit comments

Comments
 (0)