fix: wireguard can't be auto closed

This commit is contained in:
wwqgtxx 2024-04-24 11:07:22 +08:00
parent b2280c85b7
commit 2f8f139f7c
2 changed files with 169 additions and 122 deletions

View File

@ -38,9 +38,17 @@ type WireGuard struct {
device *device.Device
tunDevice wireguard.Device
dialer proxydialer.SingDialer
init func(ctx context.Context) error
resolver *dns.Resolver
refP *refProxyAdapter
initOk atomic.Bool
initMutex sync.Mutex
initErr error
option WireGuardOption
connectAddr M.Socksaddr
localPrefixes []netip.Prefix
closeCh chan struct{} // for test
}
type WireGuardOption struct {
@ -141,19 +149,6 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
}
runtime.SetFinalizer(outbound, closeWireGuard)
resolv := func(ctx context.Context, address M.Socksaddr) (netip.AddrPort, error) {
if address.Addr.IsValid() {
return address.AddrPort(), nil
}
udpAddr, err := resolveUDPAddrWithPrefer(ctx, "udp", address.String(), outbound.prefer)
if err != nil {
return netip.AddrPort{}, err
}
// net.ResolveUDPAddr maybe return 4in6 address, so unmap at here
addrPort := udpAddr.AddrPort()
return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil
}
var reserved [3]uint8
if len(option.Reserved) > 0 {
if len(option.Reserved) != 3 {
@ -162,29 +157,28 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
copy(reserved[:], option.Reserved)
}
var isConnect bool
var connectAddr M.Socksaddr
if len(option.Peers) < 2 {
isConnect = true
if len(option.Peers) == 1 {
connectAddr = option.Peers[0].Addr()
outbound.connectAddr = option.Peers[0].Addr()
} else {
connectAddr = option.Addr()
outbound.connectAddr = option.Addr()
}
}
outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, connectAddr.AddrPort(), reserved)
outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, outbound.connectAddr.AddrPort(), reserved)
localPrefixes, err := option.Prefixes()
var err error
outbound.localPrefixes, err = option.Prefixes()
if err != nil {
return nil, err
}
var privateKey string
{
bytes, err := base64.StdEncoding.DecodeString(option.PrivateKey)
if err != nil {
return nil, E.Cause(err, "decode private key")
}
privateKey = hex.EncodeToString(bytes)
option.PrivateKey = hex.EncodeToString(bytes)
}
if len(option.Peers) > 0 {
@ -230,110 +224,16 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
option.PreSharedKey = hex.EncodeToString(bytes)
}
}
var (
initOk atomic.Bool
initMutex sync.Mutex
initErr error
)
outbound.init = func(ctx context.Context) error {
if initOk.Load() {
return nil
}
initMutex.Lock()
defer initMutex.Unlock()
// double check like sync.Once
if initOk.Load() {
return nil
}
if initErr != nil {
return initErr
}
outbound.bind.ResetReservedForEndpoint()
ipcConf := "private_key=" + privateKey
if len(option.Peers) > 0 {
for i, peer := range option.Peers {
destination, err := resolv(ctx, peer.Addr())
if err != nil {
// !!! do not set initErr here !!!
// let us can retry domain resolve in next time
return E.Cause(err, "resolve endpoint domain for peer ", i)
}
ipcConf += "\npublic_key=" + peer.PublicKey
ipcConf += "\nendpoint=" + destination.String()
if peer.PreSharedKey != "" {
ipcConf += "\npreshared_key=" + peer.PreSharedKey
}
for _, allowedIP := range peer.AllowedIPs {
ipcConf += "\nallowed_ip=" + allowedIP
}
if len(peer.Reserved) > 0 {
copy(reserved[:], option.Reserved)
outbound.bind.SetReservedForEndpoint(destination, reserved)
}
}
} else {
ipcConf += "\npublic_key=" + option.PublicKey
destination, err := resolv(ctx, connectAddr)
if err != nil {
// !!! do not set initErr here !!!
// let us can retry domain resolve in next time
return E.Cause(err, "resolve endpoint domain")
}
outbound.bind.SetConnectAddr(destination)
ipcConf += "\nendpoint=" + destination.String()
if option.PreSharedKey != "" {
ipcConf += "\npreshared_key=" + option.PreSharedKey
}
var has4, has6 bool
for _, address := range localPrefixes {
if address.Addr().Is4() {
has4 = true
} else {
has6 = true
}
}
if has4 {
ipcConf += "\nallowed_ip=0.0.0.0/0"
}
if has6 {
ipcConf += "\nallowed_ip=::/0"
}
}
if option.PersistentKeepalive != 0 {
ipcConf += fmt.Sprintf("\npersistent_keepalive_interval=%d", option.PersistentKeepalive)
}
if debug.Enabled {
log.SingLogger.Trace(fmt.Sprintf("[WG](%s) created wireguard ipc conf: \n %s", option.Name, ipcConf))
}
err = outbound.device.IpcSet(ipcConf)
if err != nil {
initErr = E.Cause(err, "setup wireguard")
return initErr
}
err = outbound.tunDevice.Start()
if err != nil {
initErr = err
return initErr
}
initOk.Store(true)
return nil
}
outbound.option = option
mtu := option.MTU
if mtu == 0 {
mtu = 1408
}
if len(localPrefixes) == 0 {
if len(outbound.localPrefixes) == 0 {
return nil, E.New("missing local address")
}
outbound.tunDevice, err = wireguard.NewStackDevice(localPrefixes, uint32(mtu))
outbound.tunDevice, err = wireguard.NewStackDevice(outbound.localPrefixes, uint32(mtu))
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
@ -347,7 +247,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
}, option.Workers)
var has6 bool
for _, address := range localPrefixes {
for _, address := range outbound.localPrefixes {
if !address.Addr().Unmap().Is4() {
has6 = true
break
@ -373,11 +273,117 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
return outbound, nil
}
func (w *WireGuard) resolve(ctx context.Context, address M.Socksaddr) (netip.AddrPort, error) {
if address.Addr.IsValid() {
return address.AddrPort(), nil
}
udpAddr, err := resolveUDPAddrWithPrefer(ctx, "udp", address.String(), w.prefer)
if err != nil {
return netip.AddrPort{}, err
}
// net.ResolveUDPAddr maybe return 4in6 address, so unmap at here
addrPort := udpAddr.AddrPort()
return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil
}
func (w *WireGuard) init(ctx context.Context) error {
if w.initOk.Load() {
return nil
}
w.initMutex.Lock()
defer w.initMutex.Unlock()
// double check like sync.Once
if w.initOk.Load() {
return nil
}
if w.initErr != nil {
return w.initErr
}
w.bind.ResetReservedForEndpoint()
ipcConf := "private_key=" + w.option.PrivateKey
if len(w.option.Peers) > 0 {
for i, peer := range w.option.Peers {
destination, err := w.resolve(ctx, peer.Addr())
if err != nil {
// !!! do not set initErr here !!!
// let us can retry domain resolve in next time
return E.Cause(err, "resolve endpoint domain for peer ", i)
}
ipcConf += "\npublic_key=" + peer.PublicKey
ipcConf += "\nendpoint=" + destination.String()
if peer.PreSharedKey != "" {
ipcConf += "\npreshared_key=" + peer.PreSharedKey
}
for _, allowedIP := range peer.AllowedIPs {
ipcConf += "\nallowed_ip=" + allowedIP
}
if len(peer.Reserved) > 0 {
var reserved [3]uint8
copy(reserved[:], w.option.Reserved)
w.bind.SetReservedForEndpoint(destination, reserved)
}
}
} else {
ipcConf += "\npublic_key=" + w.option.PublicKey
destination, err := w.resolve(ctx, w.connectAddr)
if err != nil {
// !!! do not set initErr here !!!
// let us can retry domain resolve in next time
return E.Cause(err, "resolve endpoint domain")
}
w.bind.SetConnectAddr(destination)
ipcConf += "\nendpoint=" + destination.String()
if w.option.PreSharedKey != "" {
ipcConf += "\npreshared_key=" + w.option.PreSharedKey
}
var has4, has6 bool
for _, address := range w.localPrefixes {
if address.Addr().Is4() {
has4 = true
} else {
has6 = true
}
}
if has4 {
ipcConf += "\nallowed_ip=0.0.0.0/0"
}
if has6 {
ipcConf += "\nallowed_ip=::/0"
}
}
if w.option.PersistentKeepalive != 0 {
ipcConf += fmt.Sprintf("\npersistent_keepalive_interval=%d", w.option.PersistentKeepalive)
}
if debug.Enabled {
log.SingLogger.Trace(fmt.Sprintf("[WG](%s) created wireguard ipc conf: \n %s", w.option.Name, ipcConf))
}
err := w.device.IpcSet(ipcConf)
if err != nil {
w.initErr = E.Cause(err, "setup wireguard")
return w.initErr
}
err = w.tunDevice.Start()
if err != nil {
w.initErr = err
return w.initErr
}
w.initOk.Store(true)
return nil
}
func closeWireGuard(w *WireGuard) {
if w.device != nil {
w.device.Close()
}
_ = common.Close(w.tunDevice)
if w.closeCh != nil {
close(w.closeCh)
}
}
func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
@ -416,9 +422,6 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat
if err = w.init(ctx); err != nil {
return nil, err
}
if err != nil {
return nil, err
}
if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" {
r := resolver.DefaultResolver
if w.resolver != nil {

View File

@ -0,0 +1,44 @@
//go:build with_gvisor
package outbound
import (
"context"
"runtime"
"testing"
"time"
)
func TestWireGuardGC(t *testing.T) {
option := WireGuardOption{}
option.Server = "162.159.192.1"
option.Port = 2408
option.PrivateKey = "iOx7749AdqH3IqluG7+0YbGKd0m1mcEXAfGRzpy9rG8="
option.PublicKey = "bmXOC+F1FxEMF9dyiK2H5/1SUtzH0JuVo51h2wPfgyo="
option.Ip = "172.16.0.2"
option.Ipv6 = "2606:4700:110:8d29:be92:3a6a:f4:c437"
option.Reserved = []uint8{51, 69, 125}
wg, err := NewWireGuard(option)
if err != nil {
t.Error(err)
}
closeCh := make(chan struct{})
wg.closeCh = closeCh
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err = wg.init(ctx)
if err != nil {
t.Error(err)
}
// must do a small sleep before test GC
// because it maybe deadlocks if w.device.Close call too fast after w.device.Start
time.Sleep(10 * time.Millisecond)
wg = nil
runtime.GC()
select {
case <-closeCh:
return
case <-ctx.Done():
t.Error("timeout not GC")
}
}