diff --git a/nettap/alloc.go b/nettap/alloc.go new file mode 100644 index 0000000..711ac48 --- /dev/null +++ b/nettap/alloc.go @@ -0,0 +1,77 @@ +package nettap + +import ( + "fmt" + + "github.com/pkg/errors" + "github.com/shirou/gopsutil/process" +) + +type NetTapAlloc struct { + TapName string + PID int +} + +func (a *NetTapAlloc) Validate() error { + err := ValidateTapName(a.TapName) + if err != nil { + return errors.Wrap(err, "validate tap name") + } + + if a.PID == 0 { + return fmt.Errorf("pid is zero") + } + + if a.PID > int(a.PID) { + return fmt.Errorf("pid int32 overflow (%v)", a.PID) + } + + return nil +} + +// The taps removed slice always returns the taps removed, even after +// an error has occurred sometime while deleting non-first interfaces. +func (tm *TapManager) PruneTaps(knownAllocs []NetTapAlloc) ([]string, error) { + var tapsRemoved []string + + for i, alloc := range knownAllocs { + err := alloc.Validate() + if err != nil { + return tapsRemoved, errors.Wrapf(err, "validate alloc #%v", i) + } + } + + runningPids, err := process.Pids() + if err != nil { + return tapsRemoved, errors.Wrap(err, "get running pids") + } + + runningPidsMap := make(map[int32]struct{}) + for _, pid := range runningPids { + runningPidsMap[pid] = struct{}{} + } + + var tapsToRemove []string + + for _, alloc := range knownAllocs { + if _, exists := runningPidsMap[int32(alloc.PID)]; !exists { + tm.logger.Info("Found a dangling network tap", "name", alloc.TapName, "pid", alloc.PID) + tapsToRemove = append(tapsToRemove, alloc.TapName) + } + } + + for _, tapToRemove := range tapsToRemove { + err = tm.DeleteTap(tapToRemove) + if err != nil { + if errors.Is(err, ErrTapNotFound) { + tm.logger.Warn("Attempted to prune a network tap that doesn't exist, skipping", "name", tapToRemove) + } else { + return tapsRemoved, errors.Wrapf(err, "delete tap '%v'", tapToRemove) + } + } + + tapsRemoved = append(tapsRemoved, tapToRemove) + } + + return tapsRemoved, nil +} diff --git a/nettap/errors.go b/nettap/errors.go new file mode 100644 index 0000000..c1dfc93 --- /dev/null +++ b/nettap/errors.go @@ -0,0 +1,7 @@ +package nettap + +import "errors" + +var ( + ErrTapNotFound = errors.New("tap not found") +) diff --git a/nettap/nettap.go b/nettap/nettap.go new file mode 100644 index 0000000..76812e8 --- /dev/null +++ b/nettap/nettap.go @@ -0,0 +1,179 @@ +package nettap + +import ( + "bytes" + "crypto/rand" + "fmt" + "net" + "os" + "os/exec" + "regexp" + "strings" + "time" + + "log/slog" + + "github.com/AlexSSD7/linsk/utils" + "github.com/alessio/shellescape" + "github.com/google/uuid" + "github.com/pkg/errors" +) + +type TapManager struct { + logger *slog.Logger + + tapctlPath string +} + +func NewTapManager(logger *slog.Logger) (*TapManager, error) { + tapctlPath := `C:\Program Files\OpenVPN\bin\tapctl.exe` + _, err := os.Stat(tapctlPath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + logger.Warn("Required OpenVPN tap networking Windows drivers do not appear to be installed. The easiest way to get them is to install OpenVPN: https://openvpn.net/community-downloads/") + } + return nil, errors.Wrapf(err, "stat tapctl path '%v'", tapctlPath) + } + + return &TapManager{ + logger: logger, + + tapctlPath: tapctlPath, + }, nil +} + +// We need some sort of format to avoid conflicting with other Windows interfaces. +var tapNameRegexp = regexp.MustCompile(`^LinskTap-\d+$`) + +func NewRandomTapName() string { + return fmt.Sprintf("LinskTap-%v", time.Now().UnixNano()) +} + +func (tm *TapManager) CreateNewTap(tapName string) error { + err := ValidateTapName(tapName) + if err != nil { + return errors.Wrap(err, "validate tap name") + } + + out, err := exec.Command(tm.tapctlPath, "create", "--name", tapName).CombinedOutput() + if err != nil { + return errors.Wrapf(err, "exec tapctl create cmd (out '%v')", utils.ClearUnprintableChars(string(out), false)) + } + + tm.logger.Info("Created network tap", "name", tapName) + + return nil +} + +func ValidateTapName(s string) error { + if !tapNameRegexp.MatchString(s) { + return fmt.Errorf("invalid tap name '%v'", s) + } + + return nil +} + +func (tm *TapManager) DeleteTap(name string) error { + stderr := bytes.NewBuffer(nil) + cmd := exec.Command(tm.tapctlPath, "list") + cmd.Stderr = stderr + tapList, err := cmd.Output() + if err != nil { + return errors.Wrapf(err, "exec tapctl list cmd (out '%v')", utils.ClearUnprintableChars(stderr.String(), false)) + } + + for _, line := range strings.Split(string(tapList), "\n") { + if line == "" { + continue + } + + line = strings.ReplaceAll(line, "\t", " ") + line = utils.ClearUnprintableChars(line, false) + + split := strings.Split(line, " ") + if want, have := 2, len(split); want > have { + return fmt.Errorf("bad tap list item split length: want %v > have %v (line '%v')", want, have, line) + } + + if name != split[1] { + continue + } + + lineTapUUIDStr := strings.TrimPrefix(split[0], "{") + lineTapUUIDStr = strings.TrimSuffix(lineTapUUIDStr, "}") + lineTapUUID, err := uuid.Parse(lineTapUUIDStr) + if err != nil { + return errors.Wrapf(err, "parse found line tap uuid (value '%v', line '%v')", lineTapUUIDStr, line) + } + + deleteOut, err := exec.Command(tm.tapctlPath, "delete", "{"+lineTapUUID.String()+"}").CombinedOutput() + if err != nil { + return errors.Wrapf(err, "exec tapctl delete (out '%v')", utils.ClearUnprintableChars(string(deleteOut), false)) + } + + tm.logger.Info("Deleted network tap", "name", name) + + return nil + } + + return ErrTapNotFound +} + +func (tm *TapManager) ConfigureNet(tapName string, hostCIDR string) error { + err := ValidateTapName(tapName) + if err != nil { + return errors.Wrap(err, "validate tap name") + } + + ip, _, err := net.ParseCIDR(hostCIDR) + if err != nil { + return errors.Wrap(err, "parse cidr") + } + + if !utils.IsIPv6IP(ip) { + return fmt.Errorf("ipv6 is accepted only (have '%v')", ip) + } + + out, err := exec.Command("netsh", "interface", "ipv6", "set", "address", shellescape.Quote(tapName), shellescape.Quote(hostCIDR)).CombinedOutput() + if err != nil { + return errors.Wrapf(err, "exec netsh cmd (out '%v')", utils.ClearUnprintableChars(string(out), false)) + } + + tm.logger.Info("Configured network tap", "name", tapName, "cidr", hostCIDR) + + return nil +} + +type TapNet struct { + HostIP net.IP + GuestIP net.IP + + HostCIDR string + GuestCIDR string +} + +func GenerateNet() (TapNet, error) { + // This is a Linsk internal network IPv6 prefix. + hostIP := []byte(net.ParseIP("fe8f:5980:3253:7df4:0f4b:6db1::")) + _, err := rand.Read(hostIP[len(hostIP)-4:]) + if err != nil { + return TapNet{}, errors.Wrap(err, "random read") + } + + // Put the last bit to zero. + hostIP[len(hostIP)-1] &= 0xfe + + guestIP := make([]byte, len(hostIP)) + copy(guestIP, hostIP) + + // Put the last bit to one. + guestIP[len(hostIP)-1] |= 0x1 + + return TapNet{ + HostIP: hostIP, + GuestIP: guestIP, + + HostCIDR: fmt.Sprintf("%v/127", net.IP(hostIP)), + GuestCIDR: fmt.Sprintf("%v/127", net.IP(guestIP)), + }, nil +}