linsk/vm/vm.go

415 lines
9.3 KiB
Go
Raw Normal View History

2023-08-25 15:12:19 +01:00
package vm
import (
"bufio"
"bytes"
"context"
"encoding/hex"
2023-08-25 15:12:19 +01:00
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"runtime"
"sync"
"sync/atomic"
2023-08-25 19:55:11 +01:00
"syscall"
2023-08-25 15:12:19 +01:00
"time"
2023-08-25 16:54:58 +01:00
"log/slog"
2023-08-27 15:53:44 +01:00
"github.com/AlexSSD7/linsk/utils"
2023-08-25 15:12:19 +01:00
"github.com/alessio/shellescape"
2023-08-26 16:26:35 +01:00
"github.com/bramvdbogaerde/go-scp"
2023-08-25 15:12:19 +01:00
"github.com/phayes/freeport"
"github.com/pkg/errors"
"go.uber.org/multierr"
"golang.org/x/crypto/ssh"
)
2023-08-27 13:44:57 +01:00
type VM struct {
2023-08-25 16:54:58 +01:00
logger *slog.Logger
2023-08-25 15:12:19 +01:00
ctx context.Context
ctxCancel context.CancelFunc
cmd *exec.Cmd
sshMappedPort uint16
sshConf *ssh.ClientConfig
sshReadyCh chan struct{}
2023-08-27 15:30:51 +01:00
installSSH bool
2023-08-25 15:12:19 +01:00
serialRead *io.PipeReader
serialReader *bufio.Reader
serialWrite *io.PipeWriter
serialWriteMu sync.Mutex
stderrBuf *bytes.Buffer
serialStdoutCh chan []byte
// These are to be interacted with using `atomic` package
disposed uint32
canceled uint32
}
2023-08-27 15:30:51 +01:00
// DriveConfig describes one disk image attached to the VM.
type DriveConfig struct {
	// Path is the location of the disk image on the host.
	Path string
	// SnapshotMode, when set, makes NewVM add "snapshot=on" to the drive
	// options passed to QEMU.
	SnapshotMode bool
}
2023-08-27 13:44:57 +01:00
type VMConfig struct {
CdromImagePath string
2023-08-27 15:53:44 +01:00
Drives []DriveConfig
2023-08-27 13:44:57 +01:00
USBDevices []USBDevicePassthroughConfig
ExtraPortForwardingRules []PortForwardingRule
2023-08-27 15:30:51 +01:00
// Mostly debug-related options.
UnrestrictedNetworking bool
ShowDisplay bool
InstallBaseUtilities bool
2023-08-27 13:44:57 +01:00
}
func NewVM(logger *slog.Logger, cfg VMConfig) (*VM, error) {
cdromImagePath := filepath.Clean(cfg.CdromImagePath)
_, err := os.Stat(cdromImagePath)
2023-08-25 15:12:19 +01:00
if err != nil {
2023-08-27 13:44:57 +01:00
return nil, errors.Wrap(err, "stat cdrom image path")
2023-08-25 15:12:19 +01:00
}
sshPort, err := freeport.GetFreePort()
if err != nil {
return nil, errors.Wrap(err, "get free port for ssh server")
}
// TODO: Configurable memory allocation
2023-08-27 16:54:35 +01:00
cmdArgs := []string{"-serial", "stdio", "-m", "2048", "-smp", fmt.Sprint(runtime.NumCPU())}
baseCmd := "qemu-system"
switch runtime.GOARCH {
case "amd64":
cmdArgs = append(cmdArgs, "-accel", "kvm")
baseCmd += "-x86_64"
case "arm64":
// TODO: EFI firmware path is temporary, for dev purposes only.
cmdArgs = append(cmdArgs, "-accel", "hvf", "-bios", "/opt/homebrew/Cellar/qemu/8.1.0/share/qemu/edk2-aarch64-code.fd", "-M", "virt,highmem=off", "-cpu", "cortex-a57")
baseCmd += "-aarch64"
default:
return nil, fmt.Errorf("arch '%v' is not supported", runtime.GOARCH)
}
2023-08-26 11:57:12 +01:00
netdevOpts := "user,id=net0,hostfwd=tcp:127.0.0.1:" + fmt.Sprint(sshPort) + "-:22"
2023-08-27 15:30:51 +01:00
if !cfg.UnrestrictedNetworking {
2023-08-26 16:43:04 +01:00
netdevOpts += ",restrict=on"
} else {
logger.Warn("Running with unsafe unrestricted networking")
}
2023-08-27 13:44:57 +01:00
for _, pf := range cfg.ExtraPortForwardingRules {
2023-08-26 11:57:12 +01:00
hostIPStr := ""
if pf.HostIP != nil {
hostIPStr = pf.HostIP.String()
}
netdevOpts += ",hostfwd=tcp:" + hostIPStr + ":" + fmt.Sprint(pf.HostPort) + "-:" + fmt.Sprint(pf.VMPort)
}
cmdArgs = append(cmdArgs, "-device", "e1000,netdev=net0", "-netdev", netdevOpts)
2023-08-25 15:12:19 +01:00
2023-08-27 15:30:51 +01:00
if !cfg.ShowDisplay {
2023-08-25 15:12:19 +01:00
cmdArgs = append(cmdArgs, "-display", "none")
} else if runtime.GOARCH == "arm64" {
// No video is configured by default in ARM. This will enable it.
// TODO: This doesn't really work on arm64. It just shows a blank viewer.
cmdArgs = append(cmdArgs, "-device", "virtio-gpu-device")
2023-08-25 15:12:19 +01:00
}
2023-08-27 13:44:57 +01:00
if len(cfg.USBDevices) != 0 {
cmdArgs = append(cmdArgs, "-device", "nec-usb-xhci,id=xhci")
2023-08-25 15:12:19 +01:00
2023-08-27 13:44:57 +01:00
for _, dev := range cfg.USBDevices {
cmdArgs = append(cmdArgs, "-device", "usb-host,vendorid=0x"+hex.EncodeToString(utils.Uint16ToBytesBE(dev.VendorID))+",productid=0x"+hex.EncodeToString(utils.Uint16ToBytesBE(dev.ProductID)))
2023-08-25 15:12:19 +01:00
}
}
2023-08-27 15:30:51 +01:00
for i, extraDrive := range cfg.Drives {
_, err = os.Stat(extraDrive.Path)
if err != nil {
return nil, errors.Wrapf(err, "stat extra drive #%v path", i)
}
driveArgs := "file=" + shellescape.Quote(extraDrive.Path) + ",format=qcow2,if=virtio"
if extraDrive.SnapshotMode {
2023-08-27 15:53:44 +01:00
driveArgs += ",snapshot=on"
2023-08-27 15:30:51 +01:00
}
cmdArgs = append(cmdArgs, "-drive", driveArgs)
}
2023-08-27 15:53:44 +01:00
// We're not using clean `cdromImagePath` here because it is set to "."
// when the original string is empty.
if cfg.CdromImagePath != "" {
2023-08-27 15:30:51 +01:00
cmdArgs = append(cmdArgs, "-boot", "d", "-cdrom", cdromImagePath)
}
if cfg.InstallBaseUtilities && !cfg.UnrestrictedNetworking {
return nil, fmt.Errorf("cannot install base utilities with unrestricted networking disabled")
}
2023-08-25 15:12:19 +01:00
sysRead, userWrite := io.Pipe()
userRead, sysWrite := io.Pipe()
cmd := exec.Command(baseCmd, cmdArgs...)
cmd.Stdin = sysRead
cmd.Stdout = sysWrite
stderrBuf := bytes.NewBuffer(nil)
cmd.Stderr = stderrBuf
2023-08-25 19:55:11 +01:00
// This is to prevent Ctrl+C propagating to the child process.
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
}
2023-08-25 15:12:19 +01:00
userReader := bufio.NewReader(userRead)
ctx, ctxCancel := context.WithCancel(context.Background())
2023-08-27 13:44:57 +01:00
vm := &VM{
2023-08-25 15:12:19 +01:00
logger: logger,
ctx: ctx,
ctxCancel: ctxCancel,
cmd: cmd,
sshMappedPort: uint16(sshPort),
sshReadyCh: make(chan struct{}),
2023-08-27 15:30:51 +01:00
installSSH: cfg.InstallBaseUtilities,
2023-08-25 15:12:19 +01:00
serialRead: userRead,
serialReader: userReader,
serialWrite: userWrite,
stderrBuf: stderrBuf,
}
2023-08-27 13:44:57 +01:00
vm.resetSerialStdout()
2023-08-25 15:12:19 +01:00
2023-08-27 13:44:57 +01:00
return vm, nil
2023-08-25 15:12:19 +01:00
}
2023-08-27 13:44:57 +01:00
func (vm *VM) Run() error {
if atomic.AddUint32(&vm.disposed, 1) != 1 {
2023-08-25 15:12:19 +01:00
return fmt.Errorf("vm disposed")
}
2023-08-27 13:44:57 +01:00
err := vm.cmd.Start()
2023-08-25 15:12:19 +01:00
if err != nil {
return errors.Wrap(err, "start qemu cmd")
}
var globalErrsMu sync.Mutex
var globalErrs []error
globalErrFn := func(err error) {
globalErrsMu.Lock()
defer globalErrsMu.Unlock()
2023-08-27 13:44:57 +01:00
globalErrs = append(globalErrs, err, errors.Wrap(vm.Cancel(), "cancel on error"))
2023-08-25 15:12:19 +01:00
}
2023-08-27 13:44:57 +01:00
vm.logger.Info("Booting the VM")
2023-08-25 15:12:19 +01:00
go func() {
2023-08-27 13:44:57 +01:00
_ = vm.runSerialReader()
_ = vm.Cancel()
2023-08-25 15:12:19 +01:00
}()
go func() {
2023-08-27 13:44:57 +01:00
err = vm.runVMLoginHandler()
2023-08-25 15:12:19 +01:00
if err != nil {
globalErrFn(errors.Wrap(err, "run vm login handler"))
return
}
2023-08-27 13:44:57 +01:00
vm.logger.Info("Setting the VM up")
2023-08-25 16:54:58 +01:00
2023-08-27 13:44:57 +01:00
sshSigner, err := vm.sshSetup()
2023-08-25 15:12:19 +01:00
if err != nil {
globalErrFn(errors.Wrap(err, "set up ssh"))
return
}
2023-08-27 13:44:57 +01:00
vm.logger.Debug("Set up SSH server successfully")
2023-08-25 15:12:19 +01:00
2023-08-27 13:44:57 +01:00
sshKeyScan, err := vm.scanSSHIdentity()
2023-08-25 15:12:19 +01:00
if err != nil {
globalErrFn(errors.Wrap(err, "scan ssh identity"))
return
}
2023-08-27 13:44:57 +01:00
vm.logger.Debug("Scanned SSH identity")
2023-08-25 15:12:19 +01:00
knownHosts, err := ParseSSHKeyScan(sshKeyScan)
if err != nil {
// TODO: Test what actually happens in inline critical errors like this.
globalErrFn(errors.Wrap(err, "parse ssh key scan"))
return
}
2023-08-27 13:44:57 +01:00
vm.sshConf = &ssh.ClientConfig{
2023-08-25 15:12:19 +01:00
User: "root",
HostKeyCallback: knownHosts,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(sshSigner),
},
Timeout: time.Second * 5,
}
// This is to notify everyone waiting for SSH to be up that it's ready to go.
2023-08-27 13:44:57 +01:00
close(vm.sshReadyCh)
2023-08-25 15:12:19 +01:00
2023-08-27 13:44:57 +01:00
vm.logger.Info("The VM is ready")
2023-08-25 15:12:19 +01:00
}()
2023-08-27 13:44:57 +01:00
_, err = vm.cmd.Process.Wait()
cancelErr := vm.Cancel()
2023-08-25 15:12:19 +01:00
if err != nil {
combinedErr := multierr.Combine(
errors.Wrap(err, "wait for cmd to finish execution"),
errors.Wrap(cancelErr, "cancel"),
)
2023-08-27 15:53:44 +01:00
return fmt.Errorf("%w %v", combinedErr, utils.GetLogErrMsg(vm.stderrBuf.String()))
2023-08-25 15:12:19 +01:00
}
combinedErr := multierr.Combine(
append(globalErrs, errors.Wrap(cancelErr, "cancel on exit"))...,
)
2023-08-25 16:54:58 +01:00
if combinedErr != nil {
2023-08-27 15:53:44 +01:00
return fmt.Errorf("%w %v", combinedErr, utils.GetLogErrMsg(vm.stderrBuf.String()))
2023-08-25 16:54:58 +01:00
}
2023-08-25 15:12:19 +01:00
2023-08-25 16:54:58 +01:00
return nil
2023-08-25 15:12:19 +01:00
}
2023-08-27 13:44:57 +01:00
func (vm *VM) Cancel() error {
if atomic.AddUint32(&vm.canceled, 1) != 1 {
2023-08-25 15:12:19 +01:00
return nil
}
2023-08-27 13:44:57 +01:00
vm.logger.Warn("Canceling the VM context")
2023-08-25 19:55:11 +01:00
var gracefulOK bool
2023-08-27 13:44:57 +01:00
sc, err := vm.DialSSH()
2023-08-25 19:55:11 +01:00
if err != nil {
if !errors.Is(err, ErrSSHUnavailable) {
2023-08-27 13:44:57 +01:00
vm.logger.Warn("Failed to dial VM ssh to do graceful shutdown", "error", err)
2023-08-25 19:55:11 +01:00
}
} else {
_, err = runSSHCmd(sc, "poweroff")
_ = sc.Close()
if err != nil {
2023-08-27 13:44:57 +01:00
vm.logger.Warn("Could not power off the VM safely", "error", err)
2023-08-25 19:55:11 +01:00
} else {
2023-08-27 13:44:57 +01:00
vm.logger.Info("Shutting the VM down safely")
2023-08-25 19:55:11 +01:00
}
}
var interruptErr error
if !gracefulOK {
2023-08-27 13:44:57 +01:00
interruptErr = vm.cmd.Process.Signal(os.Interrupt)
2023-08-25 19:55:11 +01:00
}
2023-08-27 13:44:57 +01:00
vm.ctxCancel()
2023-08-25 15:12:19 +01:00
return multierr.Combine(
2023-08-25 19:55:11 +01:00
errors.Wrap(interruptErr, "interrupt cmd"),
2023-08-27 13:44:57 +01:00
errors.Wrap(vm.serialRead.Close(), "close serial read pipe"),
errors.Wrap(vm.serialWrite.Close(), "close serial write pipe"),
2023-08-25 15:12:19 +01:00
)
}
2023-08-27 13:44:57 +01:00
func (vm *VM) runSerialReader() error {
2023-08-25 15:12:19 +01:00
for {
2023-08-27 13:44:57 +01:00
raw, err := vm.serialReader.ReadBytes('\n')
2023-08-25 15:12:19 +01:00
if err != nil {
return errors.Wrap(err, "read from serial reader")
}
select {
2023-08-27 13:44:57 +01:00
case vm.serialStdoutCh <- raw:
2023-08-25 15:12:19 +01:00
default:
// Message gets discarded if the buffer is full.
}
}
}
2023-08-27 13:44:57 +01:00
func (vm *VM) writeSerial(b []byte) error {
vm.serialWriteMu.Lock()
defer vm.serialWriteMu.Unlock()
2023-08-25 15:12:19 +01:00
2023-08-27 13:44:57 +01:00
_, err := vm.serialWrite.Write(b)
2023-08-25 15:12:19 +01:00
return err
}
2023-08-27 13:44:57 +01:00
func (vm *VM) runVMLoginHandler() error {
2023-08-25 15:12:19 +01:00
for {
select {
2023-08-27 13:44:57 +01:00
case <-vm.ctx.Done():
2023-08-25 15:12:19 +01:00
return nil
case <-time.After(time.Second):
2023-08-27 13:44:57 +01:00
peek, err := vm.serialReader.Peek(vm.serialReader.Buffered())
2023-08-25 15:12:19 +01:00
if err != nil {
return errors.Wrap(err, "peek stdout")
}
if bytes.Contains(peek, []byte("login:")) {
2023-08-27 13:44:57 +01:00
err = vm.writeSerial([]byte("root\n"))
2023-08-25 15:12:19 +01:00
if err != nil {
return errors.Wrap(err, "failed to stdio write login")
}
2023-08-27 13:44:57 +01:00
vm.logger.Debug("Logged into the VM serial")
2023-08-25 15:12:19 +01:00
return nil
}
}
}
}
2023-08-27 13:44:57 +01:00
func (vm *VM) resetSerialStdout() {
vm.serialStdoutCh = make(chan []byte, 32)
2023-08-25 15:12:19 +01:00
}
2023-08-27 13:44:57 +01:00
func (vm *VM) DialSSH() (*ssh.Client, error) {
if vm.sshConf == nil {
2023-08-25 15:12:19 +01:00
return nil, ErrSSHUnavailable
}
2023-08-27 13:44:57 +01:00
return ssh.Dial("tcp", "localhost:"+fmt.Sprint(vm.sshMappedPort), vm.sshConf)
2023-08-25 15:12:19 +01:00
}
2023-08-27 13:44:57 +01:00
func (vm *VM) DialSCP() (*scp.Client, error) {
if vm.sshConf == nil {
2023-08-26 16:26:35 +01:00
return nil, ErrSSHUnavailable
}
2023-08-27 13:44:57 +01:00
sc := scp.NewClient("localhost:"+fmt.Sprint(vm.sshMappedPort), vm.sshConf)
2023-08-26 16:26:35 +01:00
err := sc.Connect()
if err != nil {
return nil, err
}
return &sc, nil
}
2023-08-27 13:44:57 +01:00
func (vm *VM) SSHUpNotifyChan() chan struct{} {
return vm.sshReadyCh
2023-08-25 15:12:19 +01:00
}