chore: update proxy's udpConn when received a new packet

This commit is contained in:
wwqgtxx 2023-06-03 21:40:09 +08:00
parent 2af758e5f1
commit 63b5387164
11 changed files with 80 additions and 28 deletions

26
component/nat/proxy.go Normal file
View File

@ -0,0 +1,26 @@
package nat
import (
"net"
"github.com/Dreamacro/clash/common/atomic"
C "github.com/Dreamacro/clash/constant"
)
type writeBackProxy struct {
wb atomic.TypedValue[C.WriteBack]
}
func (w *writeBackProxy) WriteBack(b []byte, addr net.Addr) (n int, err error) {
return w.wb.Load().WriteBack(b, addr)
}
func (w *writeBackProxy) UpdateWriteBack(wb C.WriteBack) {
w.wb.Store(wb)
}
func NewWriteBackProxy(wb C.WriteBack) C.WriteBackProxy {
w := &writeBackProxy{}
w.UpdateWriteBack(wb)
return w
}

View File

@ -13,22 +13,24 @@ type Table struct {
type Entry struct { type Entry struct {
PacketConn C.PacketConn PacketConn C.PacketConn
WriteBackProxy C.WriteBackProxy
LocalUDPConnMap sync.Map LocalUDPConnMap sync.Map
} }
func (t *Table) Set(key string, e C.PacketConn) { func (t *Table) Set(key string, e C.PacketConn, w C.WriteBackProxy) {
t.mapping.Store(key, &Entry{ t.mapping.Store(key, &Entry{
PacketConn: e, PacketConn: e,
WriteBackProxy: w,
LocalUDPConnMap: sync.Map{}, LocalUDPConnMap: sync.Map{},
}) })
} }
func (t *Table) Get(key string) C.PacketConn { func (t *Table) Get(key string) (C.PacketConn, C.WriteBackProxy) {
entry, exist := t.getEntry(key) entry, exist := t.getEntry(key)
if !exist { if !exist {
return nil return nil, nil
} }
return entry.PacketConn return entry.PacketConn, entry.WriteBackProxy
} }
func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) { func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) {

View File

@ -217,7 +217,7 @@ type UDPPacket interface {
// - variable source IP/Port is important to STUN // - variable source IP/Port is important to STUN
// - if addr is not provided, WriteBack will write out UDP packet with SourceIP/Port equals to original Target, // - if addr is not provided, WriteBack will write out UDP packet with SourceIP/Port equals to original Target,
// this is important when using Fake-IP. // this is important when using Fake-IP.
WriteBack(b []byte, addr net.Addr) (n int, err error) WriteBack
// Drop call after packet is used, could recycle buffer in this function. // Drop call after packet is used, could recycle buffer in this function.
Drop() Drop()
@ -236,10 +236,19 @@ type PacketAdapter interface {
Metadata() *Metadata Metadata() *Metadata
} }
type NatTable interface { type WriteBack interface {
Set(key string, e PacketConn) WriteBack(b []byte, addr net.Addr) (n int, err error)
}
Get(key string) PacketConn type WriteBackProxy interface {
WriteBack
UpdateWriteBack(wb WriteBack)
}
type NatTable interface {
Set(key string, e PacketConn, w WriteBackProxy)
Get(key string) (PacketConn, WriteBackProxy)
GetOrCreateLock(key string) (*sync.Cond, bool) GetOrCreateLock(key string) (*sync.Cond, bool)

View File

@ -58,7 +58,7 @@ func (l *UDPListener) LocalAddr() net.Addr {
return l.packetConn.LocalAddr() return l.packetConn.LocalAddr()
} }
func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, put func(), addr net.Addr) { func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, put func(), addr net.Addr, additions ...inbound.Addition) {
tgtAddr := socks5.SplitAddr(buf) tgtAddr := socks5.SplitAddr(buf)
if tgtAddr == nil { if tgtAddr == nil {
// Unresolved UDP packet, return buffer to the pool // Unresolved UDP packet, return buffer to the pool
@ -77,7 +77,7 @@ func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, pu
put: put, put: put,
} }
select { select {
case in <- inbound.NewPacket(target, packet, C.SHADOWSOCKS): case in <- inbound.NewPacket(target, packet, C.SHADOWSOCKS, additions...):
default: default:
} }
} }

View File

@ -38,7 +38,9 @@ func (c *packet) LocalAddr() net.Addr {
func (c *packet) Drop() { func (c *packet) Drop() {
if c.put != nil { if c.put != nil {
c.put() c.put()
c.put = nil
} }
c.payload = nil
} }
func (c *packet) InAddr() net.Addr { func (c *packet) InAddr() net.Addr {

View File

@ -4,7 +4,7 @@ import (
"net" "net"
"github.com/Dreamacro/clash/adapter/inbound" "github.com/Dreamacro/clash/adapter/inbound"
"github.com/Dreamacro/clash/common/pool" N "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/sockopt" "github.com/Dreamacro/clash/common/sockopt"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
@ -53,36 +53,40 @@ func NewUDP(addr string, in chan<- C.PacketAdapter, additions ...inbound.Additio
packetConn: l, packetConn: l,
addr: addr, addr: addr,
} }
conn := N.NewEnhancePacketConn(l)
go func() { go func() {
for { for {
buf := pool.Get(pool.UDPBufferSize) data, put, remoteAddr, err := conn.WaitReadFrom()
n, remoteAddr, err := l.ReadFrom(buf)
if err != nil { if err != nil {
pool.Put(buf) if put != nil {
put()
}
if sl.closed { if sl.closed {
break break
} }
continue continue
} }
handleSocksUDP(l, in, buf[:n], remoteAddr, additions...) handleSocksUDP(l, in, data, put, remoteAddr, additions...)
} }
}() }()
return sl, nil return sl, nil
} }
func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, addr net.Addr, additions ...inbound.Addition) { func handleSocksUDP(pc net.PacketConn, in chan<- C.PacketAdapter, buf []byte, put func(), addr net.Addr, additions ...inbound.Addition) {
target, payload, err := socks5.DecodeUDPPacket(buf) target, payload, err := socks5.DecodeUDPPacket(buf)
if err != nil { if err != nil {
// Unresolved UDP packet, return buffer to the pool // Unresolved UDP packet, return buffer to the pool
pool.Put(buf) if put != nil {
put()
}
return return
} }
packet := &packet{ packet := &packet{
pc: pc, pc: pc,
rAddr: addr, rAddr: addr,
payload: payload, payload: payload,
bufRef: buf, put: put,
} }
select { select {
case in <- inbound.NewPacket(target, packet, C.SOCKS5, additions...): case in <- inbound.NewPacket(target, packet, C.SOCKS5, additions...):

View File

@ -3,7 +3,6 @@ package socks
import ( import (
"net" "net"
"github.com/Dreamacro/clash/common/pool"
"github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/socks5"
) )
@ -11,7 +10,7 @@ type packet struct {
pc net.PacketConn pc net.PacketConn
rAddr net.Addr rAddr net.Addr
payload []byte payload []byte
bufRef []byte put func()
} }
func (c *packet) Data() []byte { func (c *packet) Data() []byte {
@ -33,7 +32,11 @@ func (c *packet) LocalAddr() net.Addr {
} }
func (c *packet) Drop() { func (c *packet) Drop() {
pool.Put(c.bufRef) if c.put != nil {
c.put()
c.put = nil
}
c.payload = nil
} }
func (c *packet) InAddr() net.Addr { func (c *packet) InAddr() net.Addr {

View File

@ -41,7 +41,8 @@ func (c *packet) LocalAddr() net.Addr {
} }
func (c *packet) Drop() { func (c *packet) Drop() {
pool.Put(c.buf) _ = pool.Put(c.buf)
c.buf = nil
} }
func (c *packet) InAddr() net.Addr { func (c *packet) InAddr() net.Addr {

View File

@ -27,7 +27,8 @@ func (c *packet) LocalAddr() net.Addr {
} }
func (c *packet) Drop() { func (c *packet) Drop() {
pool.Put(c.payload) _ = pool.Put(c.payload)
c.payload = nil
} }
func (c *packet) InAddr() net.Addr { func (c *packet) InAddr() net.Addr {

View File

@ -26,7 +26,7 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata
return nil return nil
} }
func handleUDPToLocal(packet C.UDPPacket, pc N.EnhancePacketConn, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) { func handleUDPToLocal(writeBack C.WriteBack, pc N.EnhancePacketConn, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) {
defer func() { defer func() {
_ = pc.Close() _ = pc.Close()
closeAllLocalCoon(key) closeAllLocalCoon(key)
@ -59,7 +59,7 @@ func handleUDPToLocal(packet C.UDPPacket, pc N.EnhancePacketConn, key string, oA
log.Warnln("server return a [%T](%s) which isn't a *net.UDPAddr, force replace to (%s), this may be caused by a wrongly implemented server", from, from, oAddrPort) log.Warnln("server return a [%T](%s) which isn't a *net.UDPAddr, force replace to (%s), this may be caused by a wrongly implemented server", from, from, oAddrPort)
} }
_, err = packet.WriteBack(data, fromUDPAddr) _, err = writeBack.WriteBack(data, fromUDPAddr)
if put != nil { if put != nil {
put() put()
} }

View File

@ -303,8 +303,11 @@ func handleUDPConn(packet C.PacketAdapter) {
key := packet.LocalAddr().String() key := packet.LocalAddr().String()
handle := func() bool { handle := func() bool {
pc := natTable.Get(key) pc, proxy := natTable.Get(key)
if pc != nil { if pc != nil {
if proxy != nil {
proxy.UpdateWriteBack(packet)
}
_ = handleUDPToRemote(packet, pc, metadata) _ = handleUDPToRemote(packet, pc, metadata)
return true return true
} }
@ -384,9 +387,10 @@ func handleUDPConn(packet C.PacketAdapter) {
} }
oAddrPort := metadata.AddrPort() oAddrPort := metadata.AddrPort()
natTable.Set(key, pc) writeBackProxy := nat.NewWriteBackProxy(packet)
natTable.Set(key, pc, writeBackProxy)
go handleUDPToLocal(packet, pc, key, oAddrPort, fAddr) go handleUDPToLocal(writeBackProxy, pc, key, oAddrPort, fAddr)
handle() handle()
}() }()