From 003b562e4864990d6de9a2ee65d18f317d0e3d51 Mon Sep 17 00:00:00 2001 From: AlexSSD7 Date: Tue, 29 Aug 2023 15:31:17 +0100 Subject: [PATCH] Raw block device passthrough support --- cmd/ls.go | 30 +++++++++++++++---- cmd/utils.go | 6 ++-- vm/os_specifics.go | 27 +++++++++++++++++ vm/passthrough.go | 15 ++++++++++ vm/ssh.go | 4 +-- vm/types.go | 5 ---- vm/vm.go | 72 +++++++++++++++++++++++++++++++++++++++------- 7 files changed, 133 insertions(+), 26 deletions(-) create mode 100644 vm/passthrough.go diff --git a/cmd/ls.go b/cmd/ls.go index 3c8aa38..f6804a4 100644 --- a/cmd/ls.go +++ b/cmd/ls.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "os" + "path/filepath" "strconv" "strings" @@ -35,7 +36,7 @@ var lsCmd = &cobra.Command{ }, } -func getDevicePassthroughConfig(val string) vm.USBDevicePassthroughConfig { +func getDevicePassthroughConfig(val string) vm.PassthroughConfig { valSplit := strings.Split(val, ":") if want, have := 2, len(valSplit); want != have { slog.Error("Bad device passthrough syntax", "error", fmt.Errorf("wrong items split by ':' count: want %v, have %v", want, have)) @@ -62,14 +63,33 @@ func getDevicePassthroughConfig(val string) vm.USBDevicePassthroughConfig { os.Exit(1) } - return vm.USBDevicePassthroughConfig{ - VendorID: uint16(vendorID), - ProductID: uint16(productID), + return vm.PassthroughConfig{ + USB: []vm.USBDevicePassthroughConfig{{ + VendorID: uint16(vendorID), + ProductID: uint16(productID), + }}, } + case "dev": + devPath := filepath.Clean(valSplit[1]) + stat, err := os.Stat(devPath) + if err != nil { + slog.Error("Failed to stat the device path", "error", err.Error(), "path", devPath) + os.Exit(1) + } + + isDev := stat.Mode()&os.ModeDevice != 0 + if !isDev { + slog.Error("Provided path is not a path to a valid block device", "path", devPath, "file-mode", stat.Mode()) + os.Exit(1) + } + + return vm.PassthroughConfig{Block: []vm.BlockDevicePassthroughConfig{{ + Path: devPath, + }}} default: slog.Error("Unknown device passthrough type", "value", valSplit[0]) os.Exit(1) // This unreachable code is required to compile. - return vm.USBDevicePassthroughConfig{} + return vm.PassthroughConfig{} } } diff --git a/cmd/utils.go b/cmd/utils.go index 9d40cda..9617dbd 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -51,10 +51,10 @@ func doUSBRootCheck() { } func runVM(passthroughArg string, fn func(context.Context, *vm.VM, *vm.FileManager) int, forwardPortsRules []vm.PortForwardingRule, unrestrictedNetworking bool) int { - var passthroughConfig []vm.USBDevicePassthroughConfig + var passthroughConfig vm.PassthroughConfig if passthroughArg != "" { - passthroughConfig = []vm.USBDevicePassthroughConfig{getDevicePassthroughConfig(passthroughArg)} + passthroughConfig = getDevicePassthroughConfig(passthroughArg) doUSBRootCheck() } @@ -66,7 +66,7 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.VM, *vm.FileManag MemoryAlloc: vmMemAllocFlag, - USBDevices: passthroughConfig, + PassthroughConfig: passthroughConfig, ExtraPortForwardingRules: forwardPortsRules, OSUpTimeout: time.Duration(vmOSUpTimeoutFlag) * time.Second, diff --git a/vm/os_specifics.go b/vm/os_specifics.go index e23e1a7..3dd2871 100644 --- a/vm/os_specifics.go +++ b/vm/os_specifics.go @@ -4,7 +4,11 @@ package vm import ( "os/exec" + "path/filepath" + "strings" "syscall" + + "github.com/pkg/errors" ) func prepareVMCmd(cmd *exec.Cmd) { @@ -17,3 +21,26 @@ func prepareVMCmd(cmd *exec.Cmd) { func terminateProcess(pid int) error { return syscall.Kill(-pid, syscall.SIGTERM) } + +// This is never used except for a band-aid that would check +// that there are no double-mounts. +func checkDeviceSeemsMounted(devPathPrefix string) (bool, error) { + absDevPathPrefix, err := filepath.Abs(devPathPrefix) + if err != nil { + return false, errors.Wrap(err, "get abs path") + } + + mounts, err := exec.Command("mount").Output() + if err != nil { + return false, errors.Wrap(err, "run mount command") + } + + for _, line := range strings.Split(string(mounts), "\n") { + // I know, I know, this is a rare band-aid. + if strings.HasPrefix(line, devPathPrefix) || strings.HasPrefix(line, absDevPathPrefix) { + return true, nil + } + } + + return false, nil +} diff --git a/vm/passthrough.go b/vm/passthrough.go new file mode 100644 index 0000000..96473c5 --- /dev/null +++ b/vm/passthrough.go @@ -0,0 +1,15 @@ +package vm + +type USBDevicePassthroughConfig struct { + VendorID uint16 + ProductID uint16 +} + +type BlockDevicePassthroughConfig struct { + Path string +} + +type PassthroughConfig struct { + USB []USBDevicePassthroughConfig + Block []BlockDevicePassthroughConfig +} diff --git a/vm/ssh.go b/vm/ssh.go index bf66feb..edb7a3f 100644 --- a/vm/ssh.go +++ b/vm/ssh.go @@ -103,10 +103,10 @@ func (vm *VM) sshSetup() (ssh.Signer, error) { installSSHDCmd := "" if vm.installSSH { - installSSHDCmd = "apk add openssh; " + installSSHDCmd = "ifconfig eth0 up && ifconfig lo up && udhcpc; apk add openssh; " } - cmd := `do_setup () { sh -c "set -ex; setup-alpine -q; ` + installSSHDCmd + `mkdir -p ~/.ssh; echo ` + shellescape.Quote(string(sshPublicKey)) + ` > ~/.ssh/authorized_keys; rc-update add sshd; rc-service sshd start"; echo "SERIAL"" ""STATUS: $?"; }; do_setup` + "\n" + cmd := `do_setup () { sh -c "set -ex; ` + installSSHDCmd + `mkdir -p ~/.ssh; echo ` + shellescape.Quote(string(sshPublicKey)) + ` > ~/.ssh/authorized_keys; rc-update add sshd; rc-service sshd start"; echo "SERIAL"" ""STATUS: $?"; }; do_setup` + "\n" err = vm.writeSerial([]byte(cmd)) if err != nil { diff --git a/vm/types.go b/vm/types.go index b9401b3..428c66b 100644 --- a/vm/types.go +++ b/vm/types.go @@ -9,11 +9,6 @@ import ( "github.com/pkg/errors" ) -type USBDevicePassthroughConfig struct { - VendorID uint16 - ProductID uint16 -} - type PortForwardingRule struct { HostIP net.IP HostPort uint16 diff --git a/vm/vm.go b/vm/vm.go index c019f17..9d627a8 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -54,6 +54,8 @@ type VM struct { // These are to be interacted with using `atomic` package disposed uint32 canceled uint32 + + originalCfg VMConfig } type DriveConfig struct { @@ -67,7 +69,7 @@ type VMConfig struct { MemoryAlloc uint32 // In KiB. - USBDevices []USBDevicePassthroughConfig + PassthroughConfig PassthroughConfig ExtraPortForwardingRules []PortForwardingRule // Timeouts @@ -149,26 +151,43 @@ func NewVM(logger *slog.Logger, cfg VMConfig) (*VM, error) { cmdArgs = append(cmdArgs, "-device", "virtio-gpu-device") } - if len(cfg.USBDevices) != 0 { - cmdArgs = append(cmdArgs, "-device", "nec-usb-xhci") - - for _, dev := range cfg.USBDevices { - cmdArgs = append(cmdArgs, "-device", "usb-host,vendorid=0x"+hex.EncodeToString(utils.Uint16ToBytesBE(dev.VendorID))+",productid=0x"+hex.EncodeToString(utils.Uint16ToBytesBE(dev.ProductID))) - } - } - for i, extraDrive := range cfg.Drives { _, err = os.Stat(extraDrive.Path) if err != nil { return nil, errors.Wrapf(err, "stat extra drive #%v path", i) } - driveArgs := "file=" + shellescape.Quote(extraDrive.Path) + ",format=qcow2,if=virtio" + driveArgs := "file=" + shellescape.Quote(extraDrive.Path) + ",format=qcow2,if=none,id=disk" + fmt.Sprint(i) if extraDrive.SnapshotMode { driveArgs += ",snapshot=on" } - cmdArgs = append(cmdArgs, "-drive", driveArgs) + cmdArgs = append(cmdArgs, "-drive", driveArgs, "-device", "virtio-blk-pci,drive=disk"+fmt.Sprint(i)+",bootindex="+fmt.Sprint(i)) + } + + if len(cfg.PassthroughConfig.USB) != 0 { + cmdArgs = append(cmdArgs, "-device", "nec-usb-xhci") + + for _, dev := range cfg.PassthroughConfig.USB { + cmdArgs = append(cmdArgs, "-device", "usb-host,vendorid=0x"+hex.EncodeToString(utils.Uint16ToBytesBE(dev.VendorID))+",productid=0x"+hex.EncodeToString(utils.Uint16ToBytesBE(dev.ProductID))) + } + } + + for _, dev := range cfg.PassthroughConfig.Block { + // It's always a user's responsibility to ensure that no drives are mounted + // in both host and guest system. This should serve as the last resort. + { + seemsMounted, err := checkDeviceSeemsMounted(dev.Path) + if err != nil { + return nil, errors.Wrapf(err, "check whether device seems to be mounted (path '%v')", dev.Path) + } + + if seemsMounted { + return nil, fmt.Errorf("device '%v' is already mounted in the host system", dev.Path) + } + } + + cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(dev.Path)+",format=raw,aio=native,cache=none") } // We're not using clean `cdromImagePath` here because it is set to "." @@ -235,6 +254,8 @@ func NewVM(logger *slog.Logger, cfg VMConfig) (*VM, error) { osUpTimeout: osUpTimeout, sshUpTimeout: sshUpTimeout, + + originalCfg: cfg, } vm.resetSerialStdout() @@ -252,6 +273,8 @@ func (vm *VM) Run() error { return errors.Wrap(err, "start qemu cmd") } + go vm.runPeriodicHostMountChecker() + var globalErrsMu sync.Mutex var globalErrs []error @@ -499,3 +522,30 @@ func (vm *VM) DialSCP() (*scp.Client, error) { func (vm *VM) SSHUpNotifyChan() chan struct{} { return vm.sshReadyCh } + +// It's always a user's responsibility to ensure that no drives are mounted +// in both host and guest system. This should serve as the last resort. +func (vm *VM) runPeriodicHostMountChecker() { + if len(vm.originalCfg.PassthroughConfig.Block) == 0 { + return + } + + for { + select { + case <-vm.ctx.Done(): + return + case <-time.After(time.Second): + for _, dev := range vm.originalCfg.PassthroughConfig.Block { + seemsMounted, err := checkDeviceSeemsMounted(dev.Path) + if err != nil { + vm.logger.Warn("Failed to check if a passed device seems to be mounted", "dev-path", dev.Path) + continue + } + + if seemsMounted { + panic(fmt.Sprintf("CRITICAL: Passed-through device '%v' appears to have been mounted on the host OS. Forcefully exiting now to prevent data corruption.", dev.Path)) + } + } + } + } +}