diff --git a/component/nat/proxy.go b/component/nat/proxy.go new file mode 100644 index 00000000..29ff3c81 --- /dev/null +++ b/component/nat/proxy.go @@ -0,0 +1,26 @@ +package nat + +import ( + "net" + + "github.com/Dreamacro/clash/common/atomic" + C "github.com/Dreamacro/clash/constant" +) + +type writeBackProxy struct { + wb atomic.TypedValue[C.WriteBack] +} + +func (w *writeBackProxy) WriteBack(b []byte, addr net.Addr) (n int, err error) { + return w.wb.Load().WriteBack(b, addr) +} + +func (w *writeBackProxy) UpdateWriteBack(wb C.WriteBack) { + w.wb.Store(wb) +} + +func NewWriteBackProxy(wb C.WriteBack) C.WriteBackProxy { + w := &writeBackProxy{} + w.UpdateWriteBack(wb) + return w +} diff --git a/component/nat/table.go b/component/nat/table.go index 5dcd91ed..adc6eace 100644 --- a/component/nat/table.go +++ b/component/nat/table.go @@ -13,22 +13,24 @@ type Table struct { type Entry struct { PacketConn C.PacketConn + WriteBackProxy C.WriteBackProxy LocalUDPConnMap sync.Map } -func (t *Table) Set(key string, e C.PacketConn) { +func (t *Table) Set(key string, e C.PacketConn, w C.WriteBackProxy) { t.mapping.Store(key, &Entry{ PacketConn: e, + WriteBackProxy: w, LocalUDPConnMap: sync.Map{}, }) } -func (t *Table) Get(key string) C.PacketConn { +func (t *Table) Get(key string) (C.PacketConn, C.WriteBackProxy) { entry, exist := t.getEntry(key) if !exist { - return nil + return nil, nil } - return entry.PacketConn + return entry.PacketConn, entry.WriteBackProxy } func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) { diff --git a/constant/adapters.go b/constant/adapters.go index 12579685..39b7d6eb 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -217,7 +217,7 @@ type UDPPacket interface { // - variable source IP/Port is important to STUN // - if addr is not provided, WriteBack will write out UDP packet with SourceIP/Port equals to original Target, // this is important when using Fake-IP. - WriteBack(b []byte, addr net.Addr) (n int, err error) + WriteBack // Drop call after packet is used, could recycle buffer in this function. Drop() @@ -236,10 +236,19 @@ type PacketAdapter interface { Metadata() *Metadata } -type NatTable interface { - Set(key string, e PacketConn) +type WriteBack interface { + WriteBack(b []byte, addr net.Addr) (n int, err error) +} - Get(key string) PacketConn +type WriteBackProxy interface { + WriteBack + UpdateWriteBack(wb WriteBack) +} + +type NatTable interface { + Set(key string, e PacketConn, w WriteBackProxy) + + Get(key string) (PacketConn, WriteBackProxy) GetOrCreateLock(key string) (*sync.Cond, bool) diff --git a/listener/shadowsocks/udp.go b/listener/shadowsocks/udp.go index 4efafa60..af610431 100644 --- a/listener/shadowsocks/udp.go +++ b/listener/shadowsocks/udp.go @@ -58,7 +58,7 @@ func (l *UDPListener) LocalAddr() net.Addr { return l.packetConn.LocalAddr() } -func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, put func(), addr net.Addr) { +func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, put func(), addr net.Addr, additions ...inbound.Addition) { tgtAddr := socks5.SplitAddr(buf) if tgtAddr == nil { // Unresolved UDP packet, return buffer to the pool @@ -77,7 +77,7 @@ func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, pu put: put, } select { - case in <- inbound.NewPacket(target, packet, C.SHADOWSOCKS): + case in <- inbound.NewPacket(target, packet, C.SHADOWSOCKS, additions...): default: } } diff --git a/listener/shadowsocks/utils.go b/listener/shadowsocks/utils.go index c34c5cd0..a732cbbe 100644 --- a/listener/shadowsocks/utils.go +++ b/listener/shadowsocks/utils.go @@ -38,7 +38,9 @@ func (c *packet) LocalAddr() net.Addr { func (c *packet) Drop() { if c.put != nil { c.put() + c.put = nil } + c.payload = nil } func (c *packet) InAddr() net.Addr { diff --git a/listener/socks/udp.go b/listener/socks/udp.go index f375dade..31858f74 100644 --- a/listener/socks/udp.go +++ b/listener/socks/udp.go @@ -4,7 +4,7 @@ import ( "net" "github.com/Dreamacro/clash/adapter/inbound" - "github.com/Dreamacro/clash/common/pool" + N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/sockopt" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" @@ -53,36 +53,40 @@ func NewUDP(addr string, in chan<- C.PacketAdapter, additions ...inbound.Additio packetConn: l, addr: addr, } + conn := N.NewEnhancePacketConn(l) go func() { for { - buf := pool.Get(pool.UDPBufferSize) - n, remoteAddr, err := l.ReadFrom(buf) + data, put, remoteAddr, err := conn.WaitReadFrom() if err != nil { - pool.Put(buf) + if put != nil { + put() + } if sl.closed { break } continue } - handleSocksUDP(l, in, buf[:n], remoteAddr, additions...) + handleSocksUDP(l, in, data, put, remoteAddr, additions...) } }() return sl, nil } -func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, addr net.Addr, additions ...inbound.Addition) { +func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, put func(), addr net.Addr, additions ...inbound.Addition) { target, payload, err := socks5.DecodeUDPPacket(buf) if err != nil { // Unresolved UDP packet, return buffer to the pool - pool.Put(buf) + if put != nil { + put() + } return } packet := &packet{ pc: pc, rAddr: addr, payload: payload, - bufRef: buf, + put: put, } select { case in <- inbound.NewPacket(target, packet, C.SOCKS5, additions...): diff --git a/listener/socks/utils.go b/listener/socks/utils.go index 4c53b9e5..3456b595 100644 --- a/listener/socks/utils.go +++ b/listener/socks/utils.go @@ -3,7 +3,6 @@ package socks import ( "net" - "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/transport/socks5" ) @@ -11,7 +10,7 @@ type packet struct { pc net.PacketConn rAddr net.Addr payload []byte - bufRef []byte + put func() } func (c *packet) Data() []byte { @@ -33,7 +32,11 @@ func (c *packet) LocalAddr() net.Addr { } func (c *packet) Drop() { - pool.Put(c.bufRef) + if c.put != nil { + c.put() + c.put = nil + } + c.payload = nil } func (c *packet) InAddr() net.Addr { diff --git a/listener/tproxy/packet.go b/listener/tproxy/packet.go index 4967adc6..2966fd2e 100644 --- a/listener/tproxy/packet.go +++ b/listener/tproxy/packet.go @@ -41,7 +41,8 @@ func (c *packet) LocalAddr() net.Addr { } func (c *packet) Drop() { - pool.Put(c.buf) + _ = pool.Put(c.buf) + c.buf = nil } func (c *packet) InAddr() net.Addr { diff --git a/listener/tunnel/packet.go b/listener/tunnel/packet.go index 602f7675..35601e38 100644 --- a/listener/tunnel/packet.go +++ b/listener/tunnel/packet.go @@ -27,7 +27,8 @@ func (c *packet) LocalAddr() net.Addr { } func (c *packet) Drop() { - pool.Put(c.payload) + _ = pool.Put(c.payload) + c.payload = nil } func (c *packet) InAddr() net.Addr { diff --git a/tunnel/connection.go b/tunnel/connection.go index b130f79a..38dbfa65 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -26,7 +26,7 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata return nil } -func handleUDPToLocal(packet C.UDPPacket, pc N.EnhancePacketConn, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) { +func handleUDPToLocal(writeBack C.WriteBack, pc N.EnhancePacketConn, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) { defer func() { _ = pc.Close() closeAllLocalCoon(key) @@ -59,7 +59,7 @@ func handleUDPToLocal(packet C.UDPPacket, pc N.EnhancePacketConn, key string, oA log.Warnln("server return a [%T](%s) which isn't a *net.UDPAddr, force replace to (%s), this may be caused by a wrongly implemented server", from, from, oAddrPort) } - _, err = packet.WriteBack(data, fromUDPAddr) + _, err = writeBack.WriteBack(data, fromUDPAddr) if put != nil { put() } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 4e00aca2..cbbcaa75 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -303,8 +303,11 @@ func handleUDPConn(packet C.PacketAdapter) { key := packet.LocalAddr().String() handle := func() bool { - pc := natTable.Get(key) + pc, proxy := natTable.Get(key) if pc != nil { + if proxy != nil { + proxy.UpdateWriteBack(packet) + } _ = handleUDPToRemote(packet, pc, metadata) return true } @@ -384,9 +387,10 @@ func handleUDPConn(packet C.PacketAdapter) { } oAddrPort := metadata.AddrPort() - natTable.Set(key, pc) + writeBackProxy := nat.NewWriteBackProxy(packet) + natTable.Set(key, pc, writeBackProxy) - go handleUDPToLocal(packet, pc, key, oAddrPort, fAddr) + go handleUDPToLocal(writeBackProxy, pc, key, oAddrPort, fAddr) handle() }()