Fix: trojan split udp packet

This commit is contained in:
gVisor bot 2020-03-20 00:02:05 +08:00
parent 11b67b19e6
commit 25396eaa34
2 changed files with 102 additions and 25 deletions

View File

@ -178,7 +178,7 @@ func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr,
} }
command = buf[1] command = buf[1]
addr, err = readAddr(rw, buf) addr, err = ReadAddr(rw, buf)
if err != nil { if err != nil {
return return
} }
@ -260,10 +260,10 @@ func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (
return nil, err return nil, err
} }
return readAddr(rw, buf) return ReadAddr(rw, buf)
} }
func readAddr(r io.Reader, b []byte) (Addr, error) { func ReadAddr(r io.Reader, b []byte) (Addr, error) {
if len(b) < MaxAddrLen { if len(b) < MaxAddrLen {
return nil, io.ErrShortBuffer return nil, io.ErrShortBuffer
} }

View File

@ -7,12 +7,18 @@ import (
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"errors" "errors"
"io"
"net" "net"
"sync" "sync"
"github.com/Dreamacro/clash/component/socks5" "github.com/Dreamacro/clash/component/socks5"
) )
const (
// max packet length
maxLength = 8192
)
var ( var (
defaultALPN = []string{"h2", "http/1.1"} defaultALPN = []string{"h2", "http/1.1"}
crlf = []byte{'\r', '\n'} crlf = []byte{'\r', '\n'}
@ -62,7 +68,7 @@ func (t *Trojan) StreamConn(conn net.Conn) (net.Conn, error) {
return tlsConn, nil return tlsConn, nil
} }
func (t *Trojan) WriteHeader(conn net.Conn, command Command, socks5Addr []byte) error { func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error {
buf := bufPool.Get().(*bytes.Buffer) buf := bufPool.Get().(*bytes.Buffer)
defer buf.Reset() defer buf.Reset()
defer bufPool.Put(buf) defer bufPool.Put(buf)
@ -74,15 +80,17 @@ func (t *Trojan) WriteHeader(conn net.Conn, command Command, socks5Addr []byte)
buf.Write(socks5Addr) buf.Write(socks5Addr)
buf.Write(crlf) buf.Write(crlf)
_, err := conn.Write(buf.Bytes()) _, err := w.Write(buf.Bytes())
return err return err
} }
func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn { func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn {
return &PacketConn{conn} return &PacketConn{
Conn: conn,
}
} }
func WritePacket(conn net.Conn, socks5Addr, payload []byte) (int, error) { func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
buf := bufPool.Get().(*bytes.Buffer) buf := bufPool.Get().(*bytes.Buffer)
defer buf.Reset() defer buf.Reset()
defer bufPool.Put(buf) defer bufPool.Put(buf)
@ -92,26 +100,67 @@ func WritePacket(conn net.Conn, socks5Addr, payload []byte) (int, error) {
buf.Write(crlf) buf.Write(crlf)
buf.Write(payload) buf.Write(payload)
return conn.Write(buf.Bytes()) return w.Write(buf.Bytes())
} }
func DecodePacket(payload []byte) (net.Addr, []byte, error) { func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
addr := socks5.SplitAddr(payload) if len(payload) <= maxLength {
if addr == nil { return writePacket(w, socks5Addr, payload)
return nil, nil, errors.New("split addr error")
} }
buf := payload[len(addr):] offset := 0
if len(buf) <= 4 { total := len(payload)
return nil, nil, errors.New("packet invalid") for {
cursor := offset + maxLength
if cursor > total {
cursor = total
}
n, err := writePacket(w, socks5Addr, payload[offset:cursor])
if err != nil {
return offset + n, err
}
offset = cursor
if offset == total {
break
}
} }
length := binary.BigEndian.Uint16(buf[:2]) return total, nil
if len(buf) < 4+int(length) { }
return nil, nil, errors.New("packet invalid")
func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, int, error) {
addr, err := socks5.ReadAddr(r, payload)
if err != nil {
return nil, 0, 0, errors.New("read addr error")
}
uAddr := addr.UDPAddr()
if _, err = io.ReadFull(r, payload[:2]); err != nil {
return nil, 0, 0, errors.New("read length error")
} }
return addr.UDPAddr(), buf[4 : 4+length], nil total := int(binary.BigEndian.Uint16(payload[:2]))
if total > maxLength {
return nil, 0, 0, errors.New("packet invalid")
}
// read crlf
if _, err = io.ReadFull(r, payload[:2]); err != nil {
return nil, 0, 0, errors.New("read crlf error")
}
length := len(payload)
if total < length {
length = total
}
if _, err = io.ReadFull(r, payload[:length]); err != nil {
return nil, 0, 0, errors.New("read packet error")
}
return uAddr, length, total - length, nil
} }
func New(option *Option) *Trojan { func New(option *Option) *Trojan {
@ -120,6 +169,9 @@ func New(option *Option) *Trojan {
type PacketConn struct { type PacketConn struct {
net.Conn net.Conn
remain int
rAddr net.Addr
mux sync.Mutex
} }
func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
@ -127,14 +179,39 @@ func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
} }
func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, err := pc.Conn.Read(b) pc.mux.Lock()
addr, payload, err := DecodePacket(b) defer pc.mux.Unlock()
if err != nil { if pc.remain != 0 {
return n, nil, err length := len(b)
if pc.remain < length {
length = pc.remain
}
n, err := pc.Conn.Read(b[:length])
if err != nil {
return 0, nil, err
}
pc.remain -= n
addr := pc.rAddr
if pc.remain == 0 {
pc.rAddr = nil
}
return n, addr, nil
} }
copy(b, payload) addr, n, remain, err := ReadPacket(pc.Conn, b)
return len(payload), addr, nil if err != nil {
return 0, nil, err
}
if remain != 0 {
pc.remain = remain
pc.rAddr = addr
}
return n, addr, nil
} }
func hexSha224(data []byte) []byte { func hexSha224(data []byte) []byte {