feat: Attempts to send request with first payload on VLESS

This commit is contained in:
H1JK 2023-02-10 10:03:37 +08:00
parent 24419551a9
commit 3fd3d83029

View File

@ -6,6 +6,8 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
"time"
"github.com/Dreamacro/clash/common/buf" "github.com/Dreamacro/clash/common/buf"
N "github.com/Dreamacro/clash/common/net" N "github.com/Dreamacro/clash/common/net"
@ -21,6 +23,10 @@ type Conn struct {
id *uuid.UUID id *uuid.UUID
addons *Addons addons *Addons
received bool received bool
handshake chan struct{}
handshakeMutex sync.Mutex
err error
} }
func (vc *Conn) Read(b []byte) (int, error) { func (vc *Conn) Read(b []byte) (int, error) {
@ -47,7 +53,41 @@ func (vc *Conn) ReadBuffer(buffer *buf.Buffer) error {
return vc.ExtendedConn.ReadBuffer(buffer) return vc.ExtendedConn.ReadBuffer(buffer)
} }
func (vc *Conn) sendRequest() (err error) { func (vc *Conn) Write(p []byte) (int, error) {
select {
case <-vc.handshake:
default:
err := vc.sendRequest(p)
if err != nil {
return 0, err
}
}
return vc.ExtendedConn.Write(p)
}
func (vc *Conn) WriteBuffer(buffer *buf.Buffer) error {
select {
case <-vc.handshake:
default:
err := vc.sendRequest(buffer.Bytes())
if err != nil {
return err
}
}
return vc.ExtendedConn.WriteBuffer(buffer)
}
func (vc *Conn) sendRequest(p []byte) (err error) {
vc.handshakeMutex.Lock()
defer vc.handshakeMutex.Unlock()
select {
case <-vc.handshake:
return vc.err
default:
}
defer close(vc.handshake)
requestLen := 1 // protocol version requestLen := 1 // protocol version
requestLen += 16 // UUID requestLen += 16 // UUID
requestLen += 1 // addons length requestLen += 1 // addons length
@ -65,6 +105,8 @@ func (vc *Conn) sendRequest() (err error) {
requestLen += 1 // addr type requestLen += 1 // addr type
requestLen += len(vc.dst.Addr) requestLen += len(vc.dst.Addr)
} }
requestLen += len(p)
_buffer := buf.StackNewSize(requestLen) _buffer := buf.StackNewSize(requestLen)
defer buf.KeepAlive(_buffer) defer buf.KeepAlive(_buffer)
buffer := buf.Dup(_buffer) buffer := buf.Dup(_buffer)
@ -93,25 +135,26 @@ func (vc *Conn) sendRequest() (err error) {
) )
} }
buf.Must(buf.Error(buffer.Write(p)))
_, err = vc.ExtendedConn.Write(buffer.Bytes()) _, err = vc.ExtendedConn.Write(buffer.Bytes())
return return
} }
func (vc *Conn) recvResponse() error { func (vc *Conn) recvResponse() error {
var err error
var buf [1]byte var buf [1]byte
_, err = io.ReadFull(vc.ExtendedConn, buf[:]) _, vc.err = io.ReadFull(vc.ExtendedConn, buf[:])
if err != nil { if vc.err != nil {
return err return vc.err
} }
if buf[0] != Version { if buf[0] != Version {
return errors.New("unexpected response version") return errors.New("unexpected response version")
} }
_, err = io.ReadFull(vc.ExtendedConn, buf[:]) _, vc.err = io.ReadFull(vc.ExtendedConn, buf[:])
if err != nil { if vc.err != nil {
return err return vc.err
} }
length := int64(buf[0]) length := int64(buf[0])
@ -132,6 +175,7 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
ExtendedConn: N.NewExtendedConn(conn), ExtendedConn: N.NewExtendedConn(conn),
id: client.uuid, id: client.uuid,
dst: dst, dst: dst,
handshake: make(chan struct{}),
} }
if !dst.UDP && client.Addons != nil { if !dst.UDP && client.Addons != nil {
@ -155,8 +199,12 @@ func newConn(conn net.Conn, client *Client, dst *DstAddr) (*Conn, error) {
} }
} }
if err := c.sendRequest(); err != nil { go func() {
return nil, err select {
} case <-c.handshake:
case <-time.After(200 * time.Millisecond):
_ = c.sendRequest(nil)
}
}()
return c, nil return c, nil
} }