diff --git a/adapters/outbound/vmess.go b/adapters/outbound/vmess.go index ea2bce5d..e45844ae 100644 --- a/adapters/outbound/vmess.go +++ b/adapters/outbound/vmess.go @@ -38,6 +38,8 @@ type VmessOption struct { AlterID int `proxy:"alterId"` Cipher string `proxy:"cipher"` TLS bool `proxy:"tls,omitempty"` + Network string `proxy:"network,omitempty"` + WSPath string `proxy:"ws-path,omitempty"` } func (ss *Vmess) Name() string { @@ -54,17 +56,20 @@ func (ss *Vmess) Generator(metadata *C.Metadata) (adapter C.ProxyAdapter, err er return nil, fmt.Errorf("%s connect error", ss.server) } tcpKeepAlive(c) - c = ss.client.New(c, parseVmessAddr(metadata)) + c, err = ss.client.New(c, parseVmessAddr(metadata)) return &VmessAdapter{conn: c}, err } func NewVmess(option VmessOption) (*Vmess, error) { security := strings.ToLower(option.Cipher) client, err := vmess.NewClient(vmess.Config{ - UUID: option.UUID, - AlterID: uint16(option.AlterID), - Security: security, - TLS: option.TLS, + UUID: option.UUID, + AlterID: uint16(option.AlterID), + Security: security, + TLS: option.TLS, + Host: fmt.Sprintf("%s:%d", option.Server, option.Port), + NetWork: option.Network, + WebSocketPath: option.WSPath, }) if err != nil { return nil, err diff --git a/component/vmess/vmess.go b/component/vmess/vmess.go index 8b294b94..23dfc8f5 100644 --- a/component/vmess/vmess.go +++ b/component/vmess/vmess.go @@ -5,9 +5,12 @@ import ( "fmt" "math/rand" "net" + "net/url" "runtime" + "time" "github.com/gofrs/uuid" + "github.com/gorilla/websocket" ) // Version of vmess @@ -62,27 +65,69 @@ type DstAddr struct { // Client is vmess connection generator type Client struct { - user []*ID - uuid *uuid.UUID - security Security - tls bool + user []*ID + uuid *uuid.UUID + security Security + tls bool + host string + websocket bool + websocketPath string } // Config of vmess type Config struct { - UUID string - AlterID uint16 - Security string - TLS bool + UUID string + AlterID uint16 + Security string + TLS bool + Host string + NetWork string + WebSocketPath string } // New return a Conn with net.Conn and DstAddr -func (c *Client) New(conn net.Conn, dst *DstAddr) net.Conn { +func (c *Client) New(conn net.Conn, dst *DstAddr) (net.Conn, error) { r := rand.Intn(len(c.user)) - if c.tls { + if c.websocket { + dialer := &websocket.Dialer{ + NetDial: func(network, addr string) (net.Conn, error) { + return conn, nil + }, + ReadBufferSize: 4 * 1024, + WriteBufferSize: 4 * 1024, + HandshakeTimeout: time.Second * 8, + } + scheme := "ws" + if c.tls { + scheme = "wss" + } + + host, port, err := net.SplitHostPort(c.host) + if (scheme == "ws" && port != "80") || (scheme == "wss" && port != "443") { + host = c.host + } + + uri := url.URL{ + Scheme: scheme, + Host: host, + Path: c.websocketPath, + } + + wsConn, resp, err := dialer.Dial(uri.String(), nil) + if err != nil { + var reason string + if resp != nil { + reason = resp.Status + } + println(uri.String(), err.Error()) + return nil, fmt.Errorf("Dial %s error: %s", host, reason) + } + + conn = newWebsocketConn(wsConn, conn.RemoteAddr()) + } else if c.tls { conn = tls.Client(conn, tlsConfig) } - return newConn(conn, c.user[r], dst, c.security) + return newConn(conn, c.user[r], dst, c.security), nil } // NewClient return Client instance @@ -108,10 +153,18 @@ func NewClient(config Config) (*Client, error) { default: return nil, fmt.Errorf("Unknown security type: %s", config.Security) } + + if config.NetWork != "" && config.NetWork != "ws" { + return nil, fmt.Errorf("Unknown network type: %s", config.NetWork) + } + return &Client{ - user: newAlterIDs(newID(&uid), config.AlterID), - uuid: &uid, - security: security, - tls: config.TLS, + user: newAlterIDs(newID(&uid), config.AlterID), + uuid: &uid, + security: security, + tls: config.TLS, + host: config.Host, + websocket: config.NetWork == "ws", + websocketPath: config.WebSocketPath, }, nil } diff --git a/component/vmess/websocket.go b/component/vmess/websocket.go new file mode 100644 index 00000000..21e45814 --- /dev/null +++ b/component/vmess/websocket.go @@ -0,0 +1,99 @@ +package vmess + +import ( + "fmt" + "io" + "net" + "strings" + "time" + + "github.com/gorilla/websocket" +) + +type websocketConn struct { + conn *websocket.Conn + reader io.Reader + remoteAddr net.Addr +} + +// Read implements net.Conn.Read() +func (wsc *websocketConn) Read(b []byte) (int, error) { + for { + reader, err := wsc.getReader() + if err != nil { + return 0, err + } + + nBytes, err := reader.Read(b) + if err == io.EOF { + wsc.reader = nil + continue + } + return nBytes, err + } +} + +// Write implements io.Writer. +func (wsc *websocketConn) Write(b []byte) (int, error) { + if err := wsc.conn.WriteMessage(websocket.BinaryMessage, b); err != nil { + return 0, err + } + return len(b), nil +} + +func (wsc *websocketConn) Close() error { + var errors []string + if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil { + errors = append(errors, err.Error()) + } + if err := wsc.conn.Close(); err != nil { + errors = append(errors, err.Error()) + } + if len(errors) > 0 { + return fmt.Errorf("Failed to close connection: %s", strings.Join(errors, ",")) + } + return nil +} + +func (wsc *websocketConn) getReader() (io.Reader, error) { + if wsc.reader != nil { + return wsc.reader, nil + } + + _, reader, err := wsc.conn.NextReader() + if err != nil { + return nil, err + } + wsc.reader = reader + return reader, nil +} + +func (wsc *websocketConn) LocalAddr() net.Addr { + return wsc.conn.LocalAddr() +} + +func (wsc *websocketConn) RemoteAddr() net.Addr { + return wsc.remoteAddr +} + +func (wsc *websocketConn) SetDeadline(t time.Time) error { + if err := wsc.SetReadDeadline(t); err != nil { + return err + } + return wsc.SetWriteDeadline(t) +} + +func (wsc *websocketConn) SetReadDeadline(t time.Time) error { + return wsc.conn.SetReadDeadline(t) +} + +func (wsc *websocketConn) SetWriteDeadline(t time.Time) error { + return wsc.conn.SetWriteDeadline(t) +} + +func newWebsocketConn(conn *websocket.Conn, remoteAddr net.Addr) net.Conn { + return &websocketConn{ + conn: conn, + remoteAddr: remoteAddr, + } +} diff --git a/go.mod b/go.mod index f3a32a09..2e536bf5 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/go-chi/cors v1.0.0 github.com/go-chi/render v1.0.1 github.com/gofrs/uuid v3.1.0+incompatible + github.com/gorilla/websocket v1.4.0 github.com/oschwald/geoip2-golang v1.2.1 github.com/oschwald/maxminddb-golang v1.3.0 // indirect github.com/sirupsen/logrus v1.1.0 diff --git a/go.sum b/go.sum index a8a3132c..185a7fdf 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/go-chi/render v1.0.1 h1:4/5tis2cKaNdnv9zFLfXzcquC9HbeZgCnxGnKrltBS8= github.com/go-chi/render v1.0.1/go.mod h1:pq4Rr7HbnsdaeHagklXub+p6Wd16Af5l9koip1OvJns= github.com/gofrs/uuid v3.1.0+incompatible h1:q2rtkjaKT4YEr6E1kamy0Ha4RtepWlQBedyHx0uzKwA= github.com/gofrs/uuid v3.1.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q= +github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/konsorten/go-windows-terminal-sequences v0.0.0-20180402223658-b729f2633dfe h1:CHRGQ8V7OlCYtwaKPJi3iA7J+YdNKdo8j7nG5IgDhjs= github.com/konsorten/go-windows-terminal-sequences v0.0.0-20180402223658-b729f2633dfe/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/oschwald/geoip2-golang v1.2.1 h1:3iz+jmeJc6fuCyWeKgtXSXu7+zvkxJbHFXkMT5FVebU=