mirror of
https://github.com/MetaCubeX/mihomo.git
synced 2024-11-16 11:42:43 +08:00
[Refactor] gvisor support hijack dns list
dns-hijack: - 1.1.1.1 - 8.8.8.8:53 - tcp://1.1.1.1:53 - udp://223.5.5.5 - 10.0.0.1:5353
This commit is contained in:
parent
64869d0f17
commit
4ab986cccb
46
common/net/tcpip.go
Normal file
46
common/net/tcpip.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func SplitNetworkType(s string) (string, string, error) {
|
||||
var (
|
||||
shecme string
|
||||
hostPort string
|
||||
)
|
||||
result := strings.Split(s, "://")
|
||||
if len(result) == 2 {
|
||||
shecme = result[0]
|
||||
hostPort = result[1]
|
||||
} else if len(result) == 1 {
|
||||
hostPort = result[0]
|
||||
} else {
|
||||
return "", "", fmt.Errorf("tcp/udp style error")
|
||||
}
|
||||
|
||||
if len(shecme) == 0 {
|
||||
shecme = "udp"
|
||||
}
|
||||
|
||||
if shecme != "tcp" && shecme != "udp" {
|
||||
return "", "", fmt.Errorf("scheme should be tcp:// or udp://")
|
||||
} else {
|
||||
return shecme, hostPort, nil
|
||||
}
|
||||
}
|
||||
|
||||
func SplitHostPort(s string) (host, port string, hasPort bool, err error) {
|
||||
temp := s
|
||||
hasPort = true
|
||||
|
||||
if !strings.Contains(s, ":") && !strings.Contains(s, "]:") {
|
||||
temp += ":0"
|
||||
hasPort = false
|
||||
}
|
||||
|
||||
host, port, err = net.SplitHostPort(temp)
|
||||
return
|
||||
}
|
|
@ -34,10 +34,10 @@ import (
|
|||
const nicID tcpip.NICID = 1
|
||||
|
||||
type gvisorAdapter struct {
|
||||
device dev.TunDevice
|
||||
ipstack *stack.Stack
|
||||
dnsServers []*DNSServer
|
||||
udpIn chan<- *inbound.PacketAdapter
|
||||
device dev.TunDevice
|
||||
ipstack *stack.Stack
|
||||
dnsServer *DNSServer
|
||||
udpIn chan<- *inbound.PacketAdapter
|
||||
|
||||
stackName string
|
||||
autoRoute bool
|
||||
|
@ -47,7 +47,7 @@ type gvisorAdapter struct {
|
|||
writeHandle *channel.NotificationHandle
|
||||
}
|
||||
|
||||
// GvisorAdapter create GvisorAdapter
|
||||
// NewAdapter GvisorAdapter create GvisorAdapter
|
||||
func NewAdapter(device dev.TunDevice, conf config.Tun, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) (ipstack.TunAdapter, error) {
|
||||
ipstack := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
|
@ -132,7 +132,7 @@ func (t *gvisorAdapter) AutoRoute() bool {
|
|||
|
||||
// Close close the TunAdapter
|
||||
func (t *gvisorAdapter) Close() {
|
||||
t.StopAllDNSServer()
|
||||
t.StopDNSServer()
|
||||
if t.ipstack != nil {
|
||||
t.ipstack.Close()
|
||||
}
|
||||
|
|
|
@ -2,13 +2,14 @@ package gvisor
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"net"
|
||||
|
||||
Common "github.com/Dreamacro/clash/common/net"
|
||||
"github.com/Dreamacro/clash/dns"
|
||||
"github.com/Dreamacro/clash/log"
|
||||
D "github.com/miekg/dns"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
|
@ -23,15 +24,33 @@ var (
|
|||
ipv6Zero = tcpip.Address(net.IPv6zero.To16())
|
||||
)
|
||||
|
||||
type ListenerWrap struct {
|
||||
net.Listener
|
||||
listener net.Listener
|
||||
}
|
||||
|
||||
func (l *ListenerWrap) Accept() (conn net.Conn, err error) {
|
||||
conn, err = l.listener.Accept()
|
||||
log.Debugln("[DNS] hijack tcp:%s", l.Addr())
|
||||
return
|
||||
}
|
||||
|
||||
func (l *ListenerWrap) Close() error {
|
||||
return l.listener.Close()
|
||||
}
|
||||
|
||||
func (l *ListenerWrap) Addr() net.Addr {
|
||||
return l.listener.Addr()
|
||||
}
|
||||
|
||||
// DNSServer is DNS Server listening on tun devcice
|
||||
type DNSServer struct {
|
||||
*dns.Server
|
||||
resolver *dns.Resolver
|
||||
|
||||
stack *stack.Stack
|
||||
tcpListener net.Listener
|
||||
udpEndpoint *dnsEndpoint
|
||||
udpEndpointID *stack.TransportEndpointID
|
||||
dnsServers []*dns.Server
|
||||
tcpListeners []net.Listener
|
||||
resolver *dns.Resolver
|
||||
stack *stack.Stack
|
||||
udpEndpoints []*dnsEndpoint
|
||||
udpEndpointIDs []*stack.TransportEndpointID
|
||||
tcpip.NICID
|
||||
}
|
||||
|
||||
|
@ -66,6 +85,7 @@ func (e *dnsEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pack
|
|||
var msg D.Msg
|
||||
msg.Unpack(pkt.Data().AsRange().ToOwnedView())
|
||||
writer := dnsResponseWriter{s: e.stack, pkt: pkt, id: id}
|
||||
log.Debugln("[DNS] hijack udp:%s:%d", id.LocalAddress.String(), id.LocalPort)
|
||||
go e.server.ServeDNS(&writer, &msg)
|
||||
}
|
||||
|
||||
|
@ -129,167 +149,276 @@ func (w *dnsResponseWriter) Close() error {
|
|||
}
|
||||
|
||||
// CreateDNSServer create a dns server on given netstack
|
||||
func CreateDNSServer(s *stack.Stack, resolver *dns.Resolver, mapper *dns.ResolverEnhancer, ip net.IP, port int, nicID tcpip.NICID) (*DNSServer, error) {
|
||||
var v4 bool
|
||||
func CreateDNSServer(s *stack.Stack, resolver *dns.Resolver, mapper *dns.ResolverEnhancer, dnsHijack []net.Addr, nicID tcpip.NICID) (*DNSServer, error) {
|
||||
var err error
|
||||
|
||||
address := tcpip.FullAddress{NIC: nicID, Port: uint16(port)}
|
||||
var protocol tcpip.NetworkProtocolNumber
|
||||
if ip.To4() != nil {
|
||||
v4 = true
|
||||
address.Addr = tcpip.Address(ip.To4())
|
||||
protocol = ipv4.ProtocolNumber
|
||||
|
||||
} else {
|
||||
v4 = false
|
||||
address.Addr = tcpip.Address(ip.To16())
|
||||
protocol = ipv6.ProtocolNumber
|
||||
}
|
||||
protocolAddr := tcpip.ProtocolAddress{
|
||||
Protocol: protocol,
|
||||
AddressWithPrefix: address.Addr.WithPrefix(),
|
||||
}
|
||||
// netstack will only reassemble IP fragments when its' dest ip address is registered in NIC.endpoints
|
||||
if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
|
||||
log.Errorln("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
|
||||
}
|
||||
|
||||
if address.Addr == ipv4Zero || address.Addr == ipv6Zero {
|
||||
address.Addr = ""
|
||||
}
|
||||
|
||||
handler := dns.NewHandler(resolver, mapper)
|
||||
serverIn := &dns.Server{}
|
||||
serverIn.SetHandler(handler)
|
||||
|
||||
// UDP DNS
|
||||
id := &stack.TransportEndpointID{
|
||||
LocalAddress: address.Addr,
|
||||
LocalPort: uint16(port),
|
||||
RemotePort: 0,
|
||||
RemoteAddress: "",
|
||||
}
|
||||
|
||||
// TransportEndpoint for DNS
|
||||
endpoint := &dnsEndpoint{
|
||||
stack: s,
|
||||
uniqueID: s.UniqueID(),
|
||||
server: serverIn,
|
||||
}
|
||||
|
||||
if tcpiperr := s.RegisterTransportEndpoint(
|
||||
[]tcpip.NetworkProtocolNumber{
|
||||
ipv4.ProtocolNumber,
|
||||
ipv6.ProtocolNumber,
|
||||
},
|
||||
udp.ProtocolNumber,
|
||||
*id,
|
||||
endpoint,
|
||||
ports.Flags{LoadBalanced: true}, // it's actually the SO_REUSEPORT. Not sure it take effect.
|
||||
nicID); tcpiperr != nil {
|
||||
log.Errorln("Unable to start UDP DNS on tun: %v", tcpiperr.String())
|
||||
}
|
||||
|
||||
// TCP DNS
|
||||
var tcpListener net.Listener
|
||||
if v4 {
|
||||
tcpListener, err = gonet.ListenTCP(s, address, ipv4.ProtocolNumber)
|
||||
} else {
|
||||
tcpListener, err = gonet.ListenTCP(s, address, ipv6.ProtocolNumber)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("can not listen on tun: %v", err)
|
||||
tcpDnsArr := make([]net.TCPAddr, 0, len(dnsHijack))
|
||||
udpDnsArr := make([]net.UDPAddr, 0, len(dnsHijack))
|
||||
for _, d := range dnsHijack {
|
||||
switch d.(type) {
|
||||
case *net.TCPAddr:
|
||||
{
|
||||
tcpDnsArr = append(tcpDnsArr, *d.(*net.TCPAddr))
|
||||
break
|
||||
}
|
||||
case *net.UDPAddr:
|
||||
{
|
||||
udpDnsArr = append(udpDnsArr, *d.(*net.UDPAddr))
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
endpoints, ids := hijackUdpDns(udpDnsArr, s, serverIn)
|
||||
tcpListeners, dnsServers := hijackTcpDns(tcpDnsArr, s, serverIn)
|
||||
server := &DNSServer{
|
||||
Server: serverIn,
|
||||
resolver: resolver,
|
||||
stack: s,
|
||||
tcpListener: tcpListener,
|
||||
udpEndpoint: endpoint,
|
||||
udpEndpointID: id,
|
||||
NICID: nicID,
|
||||
resolver: resolver,
|
||||
stack: s,
|
||||
udpEndpoints: endpoints,
|
||||
udpEndpointIDs: ids,
|
||||
NICID: nicID,
|
||||
tcpListeners: tcpListeners,
|
||||
}
|
||||
server.SetHandler(handler)
|
||||
server.Server.Server = &D.Server{Listener: tcpListener, Handler: server}
|
||||
|
||||
go func() {
|
||||
server.ActivateAndServe()
|
||||
}()
|
||||
server.dnsServers = dnsServers
|
||||
|
||||
return server, err
|
||||
}
|
||||
|
||||
func hijackUdpDns(dnsArr []net.UDPAddr, s *stack.Stack, serverIn *dns.Server) ([]*dnsEndpoint, []*stack.TransportEndpointID) {
|
||||
endpoints := make([]*dnsEndpoint, len(dnsArr))
|
||||
ids := make([]*stack.TransportEndpointID, len(dnsArr))
|
||||
for i, dns := range dnsArr {
|
||||
port := dns.Port
|
||||
ip := dns.IP
|
||||
address := tcpip.FullAddress{NIC: nicID, Port: uint16(port)}
|
||||
var protocol tcpip.NetworkProtocolNumber
|
||||
if ip.To4() != nil {
|
||||
address.Addr = tcpip.Address(ip.To4())
|
||||
protocol = ipv4.ProtocolNumber
|
||||
|
||||
} else {
|
||||
address.Addr = tcpip.Address(ip.To16())
|
||||
protocol = ipv6.ProtocolNumber
|
||||
}
|
||||
|
||||
protocolAddr := tcpip.ProtocolAddress{
|
||||
Protocol: protocol,
|
||||
AddressWithPrefix: address.Addr.WithPrefix(),
|
||||
}
|
||||
|
||||
// netstack will only reassemble IP fragments when its' dest ip address is registered in NIC.endpoints
|
||||
if err := s.AddProtocolAddress(nicID, protocolAddr, stack.AddressProperties{}); err != nil {
|
||||
log.Errorln("AddProtocolAddress(%d, %+v, {}): %s", nicID, protocolAddr, err)
|
||||
}
|
||||
|
||||
if address.Addr == ipv4Zero || address.Addr == ipv6Zero {
|
||||
address.Addr = ""
|
||||
}
|
||||
|
||||
// UDP DNS
|
||||
id := &stack.TransportEndpointID{
|
||||
LocalAddress: address.Addr,
|
||||
LocalPort: uint16(port),
|
||||
RemotePort: 0,
|
||||
RemoteAddress: "",
|
||||
}
|
||||
|
||||
// TransportEndpoint for DNS
|
||||
endpoint := &dnsEndpoint{
|
||||
stack: s,
|
||||
uniqueID: s.UniqueID(),
|
||||
server: serverIn,
|
||||
}
|
||||
|
||||
if tcpiperr := s.RegisterTransportEndpoint(
|
||||
[]tcpip.NetworkProtocolNumber{
|
||||
ipv4.ProtocolNumber,
|
||||
ipv6.ProtocolNumber,
|
||||
},
|
||||
udp.ProtocolNumber,
|
||||
*id,
|
||||
endpoint,
|
||||
ports.Flags{LoadBalanced: true}, // it's actually the SO_REUSEPORT. Not sure it take effect.
|
||||
nicID); tcpiperr != nil {
|
||||
log.Errorln("Unable to start UDP DNS on tun: %v", tcpiperr.String())
|
||||
}
|
||||
|
||||
ids[i] = id
|
||||
endpoints[i] = endpoint
|
||||
}
|
||||
|
||||
return endpoints, ids
|
||||
}
|
||||
|
||||
func hijackTcpDns(dnsArr []net.TCPAddr, s *stack.Stack, serverIn *dns.Server) ([]net.Listener, []*dns.Server) {
|
||||
tcpListeners := make([]net.Listener, len(dnsArr))
|
||||
dnsServers := make([]*dns.Server, len(dnsArr))
|
||||
|
||||
for i, dnsAddr := range dnsArr {
|
||||
var tcpListener net.Listener
|
||||
var v4 bool
|
||||
var err error
|
||||
port := dnsAddr.Port
|
||||
ip := dnsAddr.IP
|
||||
address := tcpip.FullAddress{NIC: nicID, Port: uint16(port)}
|
||||
if ip.To4() != nil {
|
||||
address.Addr = tcpip.Address(ip.To4())
|
||||
v4 = true
|
||||
} else {
|
||||
address.Addr = tcpip.Address(ip.To16())
|
||||
v4 = false
|
||||
}
|
||||
|
||||
if v4 {
|
||||
tcpListener, err = gonet.ListenTCP(s, address, ipv4.ProtocolNumber)
|
||||
} else {
|
||||
tcpListener, err = gonet.ListenTCP(s, address, ipv6.ProtocolNumber)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Errorln("can not listen on tun: %v, hijack tcp[%s] failed", err, dnsAddr)
|
||||
} else {
|
||||
tcpListeners[i] = tcpListener
|
||||
server := &D.Server{Listener: &ListenerWrap{
|
||||
listener: tcpListener,
|
||||
}, Handler: serverIn}
|
||||
dnsServer := dns.Server{}
|
||||
dnsServer.Server = server
|
||||
go dnsServer.ActivateAndServe()
|
||||
dnsServers[i] = &dnsServer
|
||||
}
|
||||
|
||||
}
|
||||
//
|
||||
//for _, listener := range tcpListeners {
|
||||
// server := &D.Server{Listener: listener, Handler: serverIn}
|
||||
//
|
||||
// dnsServers = append(dnsServers, &dnsServer)
|
||||
// go dnsServer.ActivateAndServe()
|
||||
//}
|
||||
|
||||
return tcpListeners, dnsServers
|
||||
}
|
||||
|
||||
// Stop stop the DNS Server on tun
|
||||
func (s *DNSServer) Stop() {
|
||||
// shutdown TCP DNS Server
|
||||
s.Server.Shutdown()
|
||||
// remove TCP endpoint from stack
|
||||
if s.Listener != nil {
|
||||
s.Listener.Close()
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < len(s.udpEndpointIDs); i++ {
|
||||
ep := s.udpEndpoints[i]
|
||||
id := s.udpEndpointIDs[i]
|
||||
// remove udp endpoint from stack
|
||||
s.stack.UnregisterTransportEndpoint(
|
||||
[]tcpip.NetworkProtocolNumber{
|
||||
ipv4.ProtocolNumber,
|
||||
ipv6.ProtocolNumber,
|
||||
},
|
||||
udp.ProtocolNumber,
|
||||
*id,
|
||||
ep,
|
||||
ports.Flags{LoadBalanced: true}, // should match the RegisterTransportEndpoint
|
||||
s.NICID)
|
||||
}
|
||||
|
||||
for _, server := range s.dnsServers {
|
||||
server.Shutdown()
|
||||
}
|
||||
|
||||
for _, listener := range s.tcpListeners {
|
||||
listener.Close()
|
||||
}
|
||||
// remove udp endpoint from stack
|
||||
s.stack.UnregisterTransportEndpoint(
|
||||
[]tcpip.NetworkProtocolNumber{
|
||||
ipv4.ProtocolNumber,
|
||||
ipv6.ProtocolNumber,
|
||||
},
|
||||
udp.ProtocolNumber,
|
||||
*s.udpEndpointID,
|
||||
s.udpEndpoint,
|
||||
ports.Flags{LoadBalanced: true}, // should match the RegisterTransportEndpoint
|
||||
s.NICID)
|
||||
}
|
||||
|
||||
// DnsHijack return the listening address of DNS Server
|
||||
func (t *gvisorAdapter) DnsHijack() []string {
|
||||
results := make([]string, len(t.dnsServers))
|
||||
for i, dnsServer := range t.dnsServers {
|
||||
id := dnsServer.udpEndpointID
|
||||
results[i] = fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort)
|
||||
dnsHijackArr := make([]string, len(t.dnsServer.udpEndpoints))
|
||||
for _, id := range t.dnsServer.udpEndpointIDs {
|
||||
dnsHijackArr = append(dnsHijackArr, fmt.Sprintf("%s:%d", id.LocalAddress.String(), id.LocalPort))
|
||||
}
|
||||
|
||||
return results
|
||||
return dnsHijackArr
|
||||
}
|
||||
|
||||
func (t *gvisorAdapter) StopAllDNSServer() {
|
||||
for _, dnsServer := range t.dnsServers {
|
||||
dnsServer.Stop()
|
||||
}
|
||||
func (t *gvisorAdapter) StopDNSServer() {
|
||||
t.dnsServer.Stop()
|
||||
log.Debugln("tun DNS server stoped")
|
||||
t.dnsServers = nil
|
||||
t.dnsServer = nil
|
||||
}
|
||||
|
||||
// ReCreateDNSServer recreate the DNS Server on tun
|
||||
func (t *gvisorAdapter) ReCreateDNSServer(resolver *dns.Resolver, mapper *dns.ResolverEnhancer, addrs []string) error {
|
||||
t.StopAllDNSServer()
|
||||
func (t *gvisorAdapter) ReCreateDNSServer(resolver *dns.Resolver, mapper *dns.ResolverEnhancer, dnsHijackArr []string) error {
|
||||
t.StopDNSServer()
|
||||
|
||||
if resolver == nil {
|
||||
return fmt.Errorf("failed to create DNS server on tun: resolver not provided")
|
||||
}
|
||||
|
||||
if len(addrs) == 0 {
|
||||
if len(dnsHijackArr) == 0 {
|
||||
return fmt.Errorf("failed to create DNS server on tun: len(addrs) == 0")
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
var err error
|
||||
_, port, err := net.SplitHostPort(addr)
|
||||
if port == "0" || port == "" || err != nil {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
var addrs []net.Addr
|
||||
for _, addr := range dnsHijackArr {
|
||||
var (
|
||||
addrType string
|
||||
hostPort string
|
||||
)
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
addrType, hostPort, err = Common.SplitNetworkType(addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
server, err := CreateDNSServer(t.ipstack, resolver, mapper, udpAddr.IP, udpAddr.Port, nicID)
|
||||
if err != nil {
|
||||
return err
|
||||
var (
|
||||
host, port string
|
||||
hasPort bool
|
||||
)
|
||||
|
||||
host, port, hasPort, err = Common.SplitHostPort(hostPort)
|
||||
if !hasPort {
|
||||
port = "53"
|
||||
}
|
||||
t.dnsServers = append(t.dnsServers, server)
|
||||
log.Infoln("Tun DNS server listening at: %s, fake ip enabled: %v", addr, mapper.FakeIPEnabled())
|
||||
|
||||
switch addrType {
|
||||
case "udp", "":
|
||||
{
|
||||
var udpDNS *net.UDPAddr
|
||||
udpDNS, err = net.ResolveUDPAddr("udp", net.JoinHostPort(host, port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addrs = append(addrs, udpDNS)
|
||||
break
|
||||
}
|
||||
case "tcp":
|
||||
{
|
||||
var tcpDNS *net.TCPAddr
|
||||
tcpDNS, err = net.ResolveTCPAddr("tcp", net.JoinHostPort(host, port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addrs = append(addrs, tcpDNS)
|
||||
break
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("unspported dns scheme:%s", addrType)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
server, err := CreateDNSServer(t.ipstack, resolver, mapper, addrs, nicID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.dnsServer = server
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user