diff --git a/adapter/outbound/wireguard.go b/adapter/outbound/wireguard.go index 1044c8ec..976f3959 100644 --- a/adapter/outbound/wireguard.go +++ b/adapter/outbound/wireguard.go @@ -38,9 +38,17 @@ type WireGuard struct { device *device.Device tunDevice wireguard.Device dialer proxydialer.SingDialer - init func(ctx context.Context) error resolver *dns.Resolver refP *refProxyAdapter + + initOk atomic.Bool + initMutex sync.Mutex + initErr error + option WireGuardOption + connectAddr M.Socksaddr + localPrefixes []netip.Prefix + + closeCh chan struct{} // for test } type WireGuardOption struct { @@ -141,19 +149,6 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { } runtime.SetFinalizer(outbound, closeWireGuard) - resolv := func(ctx context.Context, address M.Socksaddr) (netip.AddrPort, error) { - if address.Addr.IsValid() { - return address.AddrPort(), nil - } - udpAddr, err := resolveUDPAddrWithPrefer(ctx, "udp", address.String(), outbound.prefer) - if err != nil { - return netip.AddrPort{}, err - } - // net.ResolveUDPAddr maybe return 4in6 address, so unmap at here - addrPort := udpAddr.AddrPort() - return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil - } - var reserved [3]uint8 if len(option.Reserved) > 0 { if len(option.Reserved) != 3 { @@ -162,29 +157,28 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { copy(reserved[:], option.Reserved) } var isConnect bool - var connectAddr M.Socksaddr if len(option.Peers) < 2 { isConnect = true if len(option.Peers) == 1 { - connectAddr = option.Peers[0].Addr() + outbound.connectAddr = option.Peers[0].Addr() } else { - connectAddr = option.Addr() + outbound.connectAddr = option.Addr() } } - outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, connectAddr.AddrPort(), reserved) + outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, outbound.connectAddr.AddrPort(), reserved) - localPrefixes, err := option.Prefixes() + var err error + outbound.localPrefixes, err = option.Prefixes() if err != nil { return nil, err } - var privateKey string { bytes, err := base64.StdEncoding.DecodeString(option.PrivateKey) if err != nil { return nil, E.Cause(err, "decode private key") } - privateKey = hex.EncodeToString(bytes) + option.PrivateKey = hex.EncodeToString(bytes) } if len(option.Peers) > 0 { @@ -230,110 +224,16 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { option.PreSharedKey = hex.EncodeToString(bytes) } } - - var ( - initOk atomic.Bool - initMutex sync.Mutex - initErr error - ) - - outbound.init = func(ctx context.Context) error { - if initOk.Load() { - return nil - } - initMutex.Lock() - defer initMutex.Unlock() - // double check like sync.Once - if initOk.Load() { - return nil - } - if initErr != nil { - return initErr - } - - outbound.bind.ResetReservedForEndpoint() - ipcConf := "private_key=" + privateKey - if len(option.Peers) > 0 { - for i, peer := range option.Peers { - destination, err := resolv(ctx, peer.Addr()) - if err != nil { - // !!! do not set initErr here !!! - // let us can retry domain resolve in next time - return E.Cause(err, "resolve endpoint domain for peer ", i) - } - ipcConf += "\npublic_key=" + peer.PublicKey - ipcConf += "\nendpoint=" + destination.String() - if peer.PreSharedKey != "" { - ipcConf += "\npreshared_key=" + peer.PreSharedKey - } - for _, allowedIP := range peer.AllowedIPs { - ipcConf += "\nallowed_ip=" + allowedIP - } - if len(peer.Reserved) > 0 { - copy(reserved[:], option.Reserved) - outbound.bind.SetReservedForEndpoint(destination, reserved) - } - } - } else { - ipcConf += "\npublic_key=" + option.PublicKey - destination, err := resolv(ctx, connectAddr) - if err != nil { - // !!! do not set initErr here !!! - // let us can retry domain resolve in next time - return E.Cause(err, "resolve endpoint domain") - } - outbound.bind.SetConnectAddr(destination) - ipcConf += "\nendpoint=" + destination.String() - if option.PreSharedKey != "" { - ipcConf += "\npreshared_key=" + option.PreSharedKey - } - var has4, has6 bool - for _, address := range localPrefixes { - if address.Addr().Is4() { - has4 = true - } else { - has6 = true - } - } - if has4 { - ipcConf += "\nallowed_ip=0.0.0.0/0" - } - if has6 { - ipcConf += "\nallowed_ip=::/0" - } - } - - if option.PersistentKeepalive != 0 { - ipcConf += fmt.Sprintf("\npersistent_keepalive_interval=%d", option.PersistentKeepalive) - } - - if debug.Enabled { - log.SingLogger.Trace(fmt.Sprintf("[WG](%s) created wireguard ipc conf: \n %s", option.Name, ipcConf)) - } - err = outbound.device.IpcSet(ipcConf) - if err != nil { - initErr = E.Cause(err, "setup wireguard") - return initErr - } - - err = outbound.tunDevice.Start() - if err != nil { - initErr = err - return initErr - } - - initOk.Store(true) - return nil - } + outbound.option = option mtu := option.MTU if mtu == 0 { mtu = 1408 } - if len(localPrefixes) == 0 { + if len(outbound.localPrefixes) == 0 { return nil, E.New("missing local address") } - outbound.tunDevice, err = wireguard.NewStackDevice(localPrefixes, uint32(mtu)) + outbound.tunDevice, err = wireguard.NewStackDevice(outbound.localPrefixes, uint32(mtu)) if err != nil { return nil, E.Cause(err, "create WireGuard device") } @@ -347,7 +247,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { }, option.Workers) var has6 bool - for _, address := range localPrefixes { + for _, address := range outbound.localPrefixes { if !address.Addr().Unmap().Is4() { has6 = true break @@ -373,11 +273,117 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) { return outbound, nil } +func (w *WireGuard) resolve(ctx context.Context, address M.Socksaddr) (netip.AddrPort, error) { + if address.Addr.IsValid() { + return address.AddrPort(), nil + } + udpAddr, err := resolveUDPAddrWithPrefer(ctx, "udp", address.String(), w.prefer) + if err != nil { + return netip.AddrPort{}, err + } + // net.ResolveUDPAddr maybe return 4in6 address, so unmap at here + addrPort := udpAddr.AddrPort() + return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil +} + +func (w *WireGuard) init(ctx context.Context) error { + if w.initOk.Load() { + return nil + } + w.initMutex.Lock() + defer w.initMutex.Unlock() + // double check like sync.Once + if w.initOk.Load() { + return nil + } + if w.initErr != nil { + return w.initErr + } + + w.bind.ResetReservedForEndpoint() + ipcConf := "private_key=" + w.option.PrivateKey + if len(w.option.Peers) > 0 { + for i, peer := range w.option.Peers { + destination, err := w.resolve(ctx, peer.Addr()) + if err != nil { + // !!! do not set initErr here !!! + // let us can retry domain resolve in next time + return E.Cause(err, "resolve endpoint domain for peer ", i) + } + ipcConf += "\npublic_key=" + peer.PublicKey + ipcConf += "\nendpoint=" + destination.String() + if peer.PreSharedKey != "" { + ipcConf += "\npreshared_key=" + peer.PreSharedKey + } + for _, allowedIP := range peer.AllowedIPs { + ipcConf += "\nallowed_ip=" + allowedIP + } + if len(peer.Reserved) > 0 { + var reserved [3]uint8 + copy(reserved[:], w.option.Reserved) + w.bind.SetReservedForEndpoint(destination, reserved) + } + } + } else { + ipcConf += "\npublic_key=" + w.option.PublicKey + destination, err := w.resolve(ctx, w.connectAddr) + if err != nil { + // !!! do not set initErr here !!! + // let us can retry domain resolve in next time + return E.Cause(err, "resolve endpoint domain") + } + w.bind.SetConnectAddr(destination) + ipcConf += "\nendpoint=" + destination.String() + if w.option.PreSharedKey != "" { + ipcConf += "\npreshared_key=" + w.option.PreSharedKey + } + var has4, has6 bool + for _, address := range w.localPrefixes { + if address.Addr().Is4() { + has4 = true + } else { + has6 = true + } + } + if has4 { + ipcConf += "\nallowed_ip=0.0.0.0/0" + } + if has6 { + ipcConf += "\nallowed_ip=::/0" + } + } + + if w.option.PersistentKeepalive != 0 { + ipcConf += fmt.Sprintf("\npersistent_keepalive_interval=%d", w.option.PersistentKeepalive) + } + + if debug.Enabled { + log.SingLogger.Trace(fmt.Sprintf("[WG](%s) created wireguard ipc conf: \n %s", w.option.Name, ipcConf)) + } + err := w.device.IpcSet(ipcConf) + if err != nil { + w.initErr = E.Cause(err, "setup wireguard") + return w.initErr + } + + err = w.tunDevice.Start() + if err != nil { + w.initErr = err + return w.initErr + } + + w.initOk.Store(true) + return nil +} + func closeWireGuard(w *WireGuard) { if w.device != nil { w.device.Close() } _ = common.Close(w.tunDevice) + if w.closeCh != nil { + close(w.closeCh) + } } func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) { @@ -416,9 +422,6 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat if err = w.init(ctx); err != nil { return nil, err } - if err != nil { - return nil, err - } if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" { r := resolver.DefaultResolver if w.resolver != nil { diff --git a/adapter/outbound/wireguard_test.go b/adapter/outbound/wireguard_test.go new file mode 100644 index 00000000..20dbdbdd --- /dev/null +++ b/adapter/outbound/wireguard_test.go @@ -0,0 +1,44 @@ +//go:build with_gvisor + +package outbound + +import ( + "context" + "runtime" + "testing" + "time" +) + +func TestWireGuardGC(t *testing.T) { + option := WireGuardOption{} + option.Server = "162.159.192.1" + option.Port = 2408 + option.PrivateKey = "iOx7749AdqH3IqluG7+0YbGKd0m1mcEXAfGRzpy9rG8=" + option.PublicKey = "bmXOC+F1FxEMF9dyiK2H5/1SUtzH0JuVo51h2wPfgyo=" + option.Ip = "172.16.0.2" + option.Ipv6 = "2606:4700:110:8d29:be92:3a6a:f4:c437" + option.Reserved = []uint8{51, 69, 125} + wg, err := NewWireGuard(option) + if err != nil { + t.Error(err) + } + closeCh := make(chan struct{}) + wg.closeCh = closeCh + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + err = wg.init(ctx) + if err != nil { + t.Error(err) + } + // must do a small sleep before test GC + // because it maybe deadlocks if w.device.Close call too fast after w.device.Start + time.Sleep(10 * time.Millisecond) + wg = nil + runtime.GC() + select { + case <-closeCh: + return + case <-ctx.Done(): + t.Error("timeout not GC") + } +}