diff --git a/go.mod b/go.mod index d290d4c5..faf471ce 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( golang.org/x/crypto v0.0.0-20200604202706-70a84ac30bf9 golang.org/x/net v0.0.0-20200602114024-627f9648deb9 golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a + golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd gopkg.in/eapache/channels.v1 v1.1.0 gopkg.in/yaml.v2 v2.3.0 ) diff --git a/rules/process_other.go b/rules/process_other.go index c45ccc4c..a392250b 100644 --- a/rules/process_other.go +++ b/rules/process_other.go @@ -1,4 +1,4 @@ -// +build !darwin,!linux +// +build !darwin,!linux,!windows // +build !freebsd !amd64 package rules diff --git a/rules/process_windows.go b/rules/process_windows.go new file mode 100644 index 00000000..b1e4b931 --- /dev/null +++ b/rules/process_windows.go @@ -0,0 +1,286 @@ +package rules + +import ( + "errors" + "fmt" + "net" + "path/filepath" + "strconv" + "strings" + "sync" + "syscall" + "unsafe" + + "github.com/Dreamacro/clash/common/cache" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" + + "golang.org/x/sys/windows" +) + +const ( + tcpTableFunc = "GetExtendedTcpTable" + tcpTablePidConn = 4 + udpTableFunc = "GetExtendedUdpTable" + udpTablePid = 1 + queryProcNameFunc = "QueryFullProcessImageNameW" +) + +var ( + processCache = cache.NewLRUCache(cache.WithAge(2), cache.WithSize(64)) + errNotFound = errors.New("process not found") + matchMeta = func(p *Process, m *C.Metadata) bool { return false } + + getExTcpTable uintptr + getExUdpTable uintptr + queryProcName uintptr + + once sync.Once +) + +func initWin32API() error { + h, err := windows.LoadLibrary("iphlpapi.dll") + if err != nil { + return fmt.Errorf("LoadLibrary iphlpapi.dll failed: %s", err.Error()) + } + + getExTcpTable, err = windows.GetProcAddress(h, tcpTableFunc) + if err != nil { + return fmt.Errorf("GetProcAddress of %s failed: %s", tcpTableFunc, err.Error()) + } + + getExUdpTable, err = windows.GetProcAddress(h, udpTableFunc) + if err != nil { + return fmt.Errorf("GetProcAddress of %s failed: %s", udpTableFunc, err.Error()) + } + + h, err = windows.LoadLibrary("kernel32.dll") + if err != nil { + return fmt.Errorf("LoadLibrary kernel32.dll failed: %s", err.Error()) + } + + queryProcName, err = windows.GetProcAddress(h, queryProcNameFunc) + if err != nil { + return fmt.Errorf("GetProcAddress of %s failed: %s", queryProcNameFunc, err.Error()) + } + + return nil +} + +type Process struct { + adapter string + process string +} + +func (p *Process) RuleType() C.RuleType { + return C.Process +} + +func (p *Process) Adapter() string { + return p.adapter +} + +func (p *Process) Payload() string { + return p.process +} + +func (p *Process) ShouldResolveIP() bool { + return false +} + +func match(p *Process, metadata *C.Metadata) bool { + key := fmt.Sprintf("%s:%s:%s", metadata.NetWork.String(), metadata.SrcIP.String(), metadata.SrcPort) + cached, hit := processCache.Get(key) + if !hit { + processName, err := resolveProcessName(metadata) + if err != nil { + log.Debugln("[%s] Resolve process of %s failed: %s", C.Process.String(), key, err.Error()) + } + + processCache.Set(key, processName) + cached = processName + } + return strings.EqualFold(cached.(string), p.process) +} + +func (p *Process) Match(metadata *C.Metadata) bool { + return matchMeta(p, metadata) +} + +func NewProcess(process string, adapter string) (*Process, error) { + once.Do(func() { + err := initWin32API() + if err != nil { + log.Errorln("Initialize PROCESS-NAME failed: %s", err.Error()) + log.Warnln("All PROCESS-NAMES rules will be skiped") + return + } + matchMeta = match + }) + return &Process{ + adapter: adapter, + process: process, + }, nil +} + +func resolveProcessName(metadata *C.Metadata) (string, error) { + ip := metadata.SrcIP + family := windows.AF_INET + if ip.To4() == nil { + family = windows.AF_INET6 + } + + var class int + var fn uintptr + switch metadata.NetWork { + case C.TCP: + fn = getExTcpTable + class = tcpTablePidConn + case C.UDP: + fn = getExUdpTable + class = udpTablePid + default: + return "", ErrInvalidNetwork + } + + srcPort, err := strconv.Atoi(metadata.SrcPort) + if err != nil { + return "", err + } + + buf, err := getTransportTable(fn, family, class) + if err != nil { + return "", err + } + + s := newSearcher(family == windows.AF_INET, metadata.NetWork == C.TCP) + + pid, err := s.Search(buf, ip, uint16(srcPort)) + if err != nil { + return "", err + } + return getExecPathFromPID(pid) +} + +type searcher struct { + itemSize int + port int + ip int + ipSize int + pid int + tcpState int +} + +func (s *searcher) Search(b []byte, ip net.IP, port uint16) (uint32, error) { + n := int(readNativeUint32(b[:4])) + itemSize := s.itemSize + for i := 0; i < n; i++ { + row := b[4+itemSize*i : 4+itemSize*(i+1)] + + if s.tcpState >= 0 { + tcpState := readNativeUint32(row[s.tcpState : s.tcpState+4]) + // MIB_TCP_STATE_ESTAB, only check established connections for TCP + if tcpState != 5 { + continue + } + } + + // according to MSDN, only the lower 16 bits of dwLocalPort are used and the port number is in network endian. + // this field can be illustrated as follows depends on different machine endianess: + // little endian: [ MSB LSB 0 0 ] interpret as native uint32 is ((LSB<<8)|MSB) + // big endian: [ 0 0 MSB LSB ] interpret as native uint32 is ((MSB<<8)|LSB) + // so we need an syscall.Ntohs on the lower 16 bits after read the port as native uint32 + srcPort := syscall.Ntohs(uint16(readNativeUint32(row[s.port : s.port+4]))) + if srcPort != port { + continue + } + + srcIP := net.IP(row[s.ip : s.ip+s.ipSize]) + if !ip.Equal(srcIP) { + continue + } + + pid := readNativeUint32(row[s.pid : s.pid+4]) + return pid, nil + } + return 0, errNotFound +} + +func newSearcher(isV4, isTCP bool) *searcher { + var itemSize, port, ip, ipSize, pid int + tcpState := -1 + switch { + case isV4 && isTCP: + // struct MIB_TCPROW_OWNER_PID + itemSize, port, ip, ipSize, pid, tcpState = 24, 8, 4, 4, 20, 0 + case isV4 && !isTCP: + // struct MIB_UDPROW_OWNER_PID + itemSize, port, ip, ipSize, pid = 12, 4, 0, 4, 8 + case !isV4 && isTCP: + // struct MIB_TCP6ROW_OWNER_PID + itemSize, port, ip, ipSize, pid, tcpState = 56, 20, 0, 16, 52, 48 + case !isV4 && !isTCP: + // struct MIB_UDP6ROW_OWNER_PID + itemSize, port, ip, ipSize, pid = 28, 20, 0, 16, 24 + } + + return &searcher{ + itemSize: itemSize, + port: port, + ip: ip, + ipSize: ipSize, + pid: pid, + tcpState: tcpState, + } +} + +func getTransportTable(fn uintptr, family int, class int) ([]byte, error) { + for size, buf := uint32(8), make([]byte, 8); ; { + ptr := unsafe.Pointer(&buf[0]) + err, _, _ := syscall.Syscall6(fn, 6, uintptr(ptr), uintptr(unsafe.Pointer(&size)), 0, uintptr(family), uintptr(class), 0) + + switch err { + case 0: + return buf, nil + case uintptr(syscall.ERROR_INSUFFICIENT_BUFFER): + buf = make([]byte, size) + default: + return nil, fmt.Errorf("syscall error: %d", err) + } + } +} + +func readNativeUint32(b []byte) uint32 { + return *(*uint32)(unsafe.Pointer(&b[0])) +} + +func getExecPathFromPID(pid uint32) (string, error) { + // kernel process starts with a colon in order to distinguish with normal processes + switch pid { + case 0: + // reserved pid for system idle process + return ":System Idle Process", nil + case 4: + // reserved pid for windows kernel image + return ":System", nil + } + h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, pid) + if err != nil { + return "", err + } + defer windows.CloseHandle(h) + + buf := make([]uint16, syscall.MAX_LONG_PATH) + size := uint32(len(buf)) + r1, _, err := syscall.Syscall6( + queryProcName, 4, + uintptr(h), + uintptr(1), + uintptr(unsafe.Pointer(&buf[0])), + uintptr(unsafe.Pointer(&size)), + 0, 0) + if r1 == 0 { + return "", err + } + return filepath.Base(syscall.UTF16ToString(buf[:size])), nil +}