Some restructuring

This commit is contained in:
AlexSSD7 2023-08-27 13:44:57 +01:00
commit ee447087f6
8 changed files with 131 additions and 110 deletions

View file

@ -18,7 +18,7 @@ var lsCmd = &cobra.Command{
// Short: "", // Short: "",
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
os.Exit(runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) int { os.Exit(runVM(args[0], func(ctx context.Context, i *vm.VM, fm *vm.FileManager) int {
lsblkOut, err := fm.Lsblk() lsblkOut, err := fm.Lsblk()
if err != nil { if err != nil {
slog.Error("Failed to list block devices in the VM", "error", err) slog.Error("Failed to list block devices in the VM", "error", err)

View file

@ -29,7 +29,7 @@ var runCmd = &cobra.Command{
// TODO: `slog` library prints entire stack traces for errors which makes reading errors challenging. // TODO: `slog` library prints entire stack traces for errors which makes reading errors challenging.
os.Exit(runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) int { os.Exit(runVM(args[0], func(ctx context.Context, i *vm.VM, fm *vm.FileManager) int {
err := fm.Mount(vmMountDevName, vm.MountOptions{ err := fm.Mount(vmMountDevName, vm.MountOptions{
FSType: fsType, FSType: fsType,
LUKS: luksFlag, LUKS: luksFlag,
@ -59,7 +59,7 @@ var runCmd = &cobra.Command{
<-ctx.Done() <-ctx.Done()
return 0 return 0
}, []vm.PortForwardingConfig{{ }, []vm.PortForwardingRule{{
HostIP: net.ParseIP("127.0.0.1"), // TODO: Make this changeable. HostIP: net.ParseIP("127.0.0.1"), // TODO: Make this changeable.
HostPort: networkSharePort, HostPort: networkSharePort,
VMPort: 445, VMPort: 445,

View file

@ -23,23 +23,23 @@ var shellCmd = &cobra.Command{
passthroughArg = args[0] passthroughArg = args[0]
} }
var forwardPortsConfig []vm.PortForwardingConfig var forwardPortRules []vm.PortForwardingRule
for i, fp := range strings.Split(forwardPortsFlagStr, ",") { for i, fp := range strings.Split(forwardPortsFlagStr, ",") {
if fp == "" { if fp == "" {
continue continue
} }
fpc, err := vm.ParsePortForwardString(fp) fpr, err := vm.ParsePortForwardString(fp)
if err != nil { if err != nil {
slog.Error("Failed to parse port forward string", "index", i, "value", fp, "error", err) slog.Error("Failed to parse port forward string", "index", i, "value", fp, "error", err)
os.Exit(1) os.Exit(1)
} }
forwardPortsConfig = append(forwardPortsConfig, fpc) forwardPortRules = append(forwardPortRules, fpr)
} }
os.Exit(runVM(passthroughArg, func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) int { os.Exit(runVM(passthroughArg, func(ctx context.Context, i *vm.VM, fm *vm.FileManager) int {
sc, err := i.DialSSH() sc, err := i.DialSSH()
if err != nil { if err != nil {
slog.Error("Failed to dial VM SSH", "error", err) slog.Error("Failed to dial VM SSH", "error", err)
@ -120,7 +120,7 @@ var shellCmd = &cobra.Command{
} }
return 0 return 0
}, forwardPortsConfig, unrestrictedNetworkingFlag)) }, forwardPortRules, unrestrictedNetworkingFlag))
return nil return nil
}, },

View file

@ -37,7 +37,7 @@ func doRootCheck() {
} }
} }
func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.FileManager) int, forwardPorts []vm.PortForwardingConfig, unrestrictedNetworking bool) int { func runVM(passthroughArg string, fn func(context.Context, *vm.VM, *vm.FileManager) int, forwardPortsRules []vm.PortForwardingRule, unrestrictedNetworking bool) int {
doRootCheck() doRootCheck()
var passthroughConfig []vm.USBDevicePassthroughConfig var passthroughConfig []vm.USBDevicePassthroughConfig
@ -46,8 +46,18 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil
passthroughConfig = []vm.USBDevicePassthroughConfig{getDevicePassthroughConfig(passthroughArg)} passthroughConfig = []vm.USBDevicePassthroughConfig{getDevicePassthroughConfig(passthroughArg)}
} }
vmCfg := vm.VMConfig{
CdromImagePath: "alpine-img/alpine.qcow2",
USBDevices: passthroughConfig,
ExtraPortForwardingRules: forwardPortsRules,
DebugUnrestrictedNetworking: unrestrictedNetworking,
DebugShowDisplay: vmDebugFlag,
}
// TODO: Alpine image should be downloaded from somewhere. // TODO: Alpine image should be downloaded from somewhere.
vi, err := vm.NewInstance(slog.Default().With("caller", "vm"), "alpine-img/alpine.qcow2", passthroughConfig, vmDebugFlag, forwardPorts, unrestrictedNetworking) vi, err := vm.NewVM(slog.Default().With("caller", "vm"), vmCfg)
if err != nil { if err != nil {
slog.Error("Failed to create vm instance", "error", err) slog.Error("Failed to create vm instance", "error", err)
os.Exit(1) os.Exit(1)

View file

@ -22,19 +22,19 @@ import (
type FileManager struct { type FileManager struct {
logger *slog.Logger logger *slog.Logger
vi *Instance vm *VM
} }
func NewFileManager(logger *slog.Logger, vi *Instance) *FileManager { func NewFileManager(logger *slog.Logger, vm *VM) *FileManager {
return &FileManager{ return &FileManager{
logger: logger, logger: logger,
vi: vi, vm: vm,
} }
} }
func (fm *FileManager) Init() error { func (fm *FileManager) Init() error {
sc, err := fm.vi.DialSSH() sc, err := fm.vm.DialSSH()
if err != nil { if err != nil {
return errors.Wrap(err, "dial vm ssh") return errors.Wrap(err, "dial vm ssh")
} }
@ -55,7 +55,7 @@ func (fm *FileManager) Init() error {
} }
func (fm *FileManager) Lsblk() ([]byte, error) { func (fm *FileManager) Lsblk() ([]byte, error) {
sc, err := fm.vi.DialSSH() sc, err := fm.vm.DialSSH()
if err != nil { if err != nil {
return nil, errors.Wrap(err, "dial vm ssh") return nil, errors.Wrap(err, "dial vm ssh")
} }
@ -180,7 +180,7 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error {
return fmt.Errorf("fs type is empty") return fmt.Errorf("fs type is empty")
} }
sc, err := fm.vi.DialSSH() sc, err := fm.vm.DialSSH()
if err != nil { if err != nil {
return errors.Wrap(err, "dial vm ssh") return errors.Wrap(err, "dial vm ssh")
} }
@ -205,7 +205,7 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error {
} }
func (fm *FileManager) StartSMB(pwd []byte) error { func (fm *FileManager) StartSMB(pwd []byte) error {
scpClient, err := fm.vi.DialSCP() scpClient, err := fm.vm.DialSCP()
if err != nil { if err != nil {
return errors.Wrap(err, "dial scp") return errors.Wrap(err, "dial scp")
} }
@ -225,14 +225,14 @@ force user = linsk
force group = linsk force group = linsk
create mask = 0664` create mask = 0664`
err = scpClient.CopyFile(fm.vi.ctx, strings.NewReader(sambaCfg), "/etc/samba/smb.conf", "0400") err = scpClient.CopyFile(fm.vm.ctx, strings.NewReader(sambaCfg), "/etc/samba/smb.conf", "0400")
if err != nil { if err != nil {
return errors.Wrap(err, "copy samba config file") return errors.Wrap(err, "copy samba config file")
} }
scpClient.Close() scpClient.Close()
sc, err := fm.vi.DialSSH() sc, err := fm.vm.DialSSH()
if err != nil { if err != nil {
return errors.Wrap(err, "dial ssh") return errors.Wrap(err, "dial ssh")
} }
@ -269,11 +269,11 @@ create mask = 0664`
go func() { go func() {
_, err = stdinPipe.Write(pwd) _, err = stdinPipe.Write(pwd)
if err != nil { if err != nil {
fm.vi.logger.Error("Failed to write SMB password to smbpasswd stdin", "error", err) fm.vm.logger.Error("Failed to write SMB password to smbpasswd stdin", "error", err)
} }
_, err = stdinPipe.Write(pwd) _, err = stdinPipe.Write(pwd)
if err != nil { if err != nil {
fm.vi.logger.Error("Failed to write repeated SMB password to smbpasswd stdin", "error", err) fm.vm.logger.Error("Failed to write repeated SMB password to smbpasswd stdin", "error", err)
} }
}() }()

View file

@ -50,10 +50,10 @@ func ParseSSHKeyScan(knownHosts []byte) (ssh.HostKeyCallback, error) {
}, nil }, nil
} }
func (vi *Instance) scanSSHIdentity() ([]byte, error) { func (vm *VM) scanSSHIdentity() ([]byte, error) {
vi.resetSerialStdout() vm.resetSerialStdout()
err := vi.writeSerial([]byte(`ssh-keyscan -H localhost; echo "SERIAL STATUS: $?"; rm /root/.ash_history` + "\n")) err := vm.writeSerial([]byte(`ssh-keyscan -H localhost; echo "SERIAL STATUS: $?"; rm /root/.ash_history` + "\n"))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "write keyscan command to serial") return nil, errors.Wrap(err, "write keyscan command to serial")
} }
@ -64,11 +64,11 @@ func (vi *Instance) scanSSHIdentity() ([]byte, error) {
for { for {
select { select {
case <-vi.ctx.Done(): case <-vm.ctx.Done():
return nil, vi.ctx.Err() return nil, vm.ctx.Err()
case <-time.After(time.Until(deadline)): case <-time.After(time.Until(deadline)):
return nil, fmt.Errorf("keyscan command timed out") return nil, fmt.Errorf("keyscan command timed out")
case data := <-vi.serialStdoutCh: case data := <-vm.serialStdoutCh:
if len(data) == 0 { if len(data) == 0 {
continue continue
} }
@ -91,8 +91,8 @@ func (vi *Instance) scanSSHIdentity() ([]byte, error) {
} }
} }
func (vi *Instance) sshSetup() (ssh.Signer, error) { func (vm *VM) sshSetup() (ssh.Signer, error) {
vi.resetSerialStdout() vm.resetSerialStdout()
sshSigner, sshPublicKey, err := generateSSHKey() sshSigner, sshPublicKey, err := generateSSHKey()
if err != nil { if err != nil {
@ -101,7 +101,7 @@ func (vi *Instance) sshSetup() (ssh.Signer, error) {
cmd := `set -ex; do_setup () { sh -c "set -ex; ifconfig eth0 up; ifconfig lo up; udhcpc; mkdir -p ~/.ssh; echo ` + shellescape.Quote(string(sshPublicKey)) + ` > ~/.ssh/authorized_keys; rc-update add sshd; rc-service sshd start"; echo "SERIAL"" ""STATUS: $?"; }; do_setup` + "\n" cmd := `set -ex; do_setup () { sh -c "set -ex; ifconfig eth0 up; ifconfig lo up; udhcpc; mkdir -p ~/.ssh; echo ` + shellescape.Quote(string(sshPublicKey)) + ` > ~/.ssh/authorized_keys; rc-update add sshd; rc-service sshd start"; echo "SERIAL"" ""STATUS: $?"; }; do_setup` + "\n"
err = vi.writeSerial([]byte(cmd)) err = vm.writeSerial([]byte(cmd))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "write ssh setup serial command") return nil, errors.Wrap(err, "write ssh setup serial command")
} }
@ -112,11 +112,11 @@ func (vi *Instance) sshSetup() (ssh.Signer, error) {
for { for {
select { select {
case <-vi.ctx.Done(): case <-vm.ctx.Done():
return nil, vi.ctx.Err() return nil, vm.ctx.Err()
case <-time.After(time.Until(deadline)): case <-time.After(time.Until(deadline)):
return nil, fmt.Errorf("setup command timed out %v", getLogErrMsg(stdOutErrBuf.String())) return nil, fmt.Errorf("setup command timed out %v", getLogErrMsg(stdOutErrBuf.String()))
case data := <-vi.serialStdoutCh: case data := <-vm.serialStdoutCh:
prefix := []byte("SERIAL STATUS: ") prefix := []byte("SERIAL STATUS: ")
stdOutErrBuf.Write(data) stdOutErrBuf.Write(data)
if bytes.HasPrefix(data, prefix) { if bytes.HasPrefix(data, prefix) {

View file

@ -14,28 +14,28 @@ type USBDevicePassthroughConfig struct {
HostPort uint8 HostPort uint8
} }
type PortForwardingConfig struct { type PortForwardingRule struct {
HostIP net.IP HostIP net.IP
HostPort uint16 HostPort uint16
VMPort uint16 VMPort uint16
} }
func ParsePortForwardString(s string) (PortForwardingConfig, error) { func ParsePortForwardString(s string) (PortForwardingRule, error) {
split := strings.Split(s, ":") split := strings.Split(s, ":")
switch len(split) { switch len(split) {
case 2: case 2:
// <HOST PORT>:<VM PORT> // <HOST PORT>:<VM PORT>
hostPort, err := strconv.ParseUint(split[0], 10, 16) hostPort, err := strconv.ParseUint(split[0], 10, 16)
if err != nil { if err != nil {
return PortForwardingConfig{}, errors.Wrap(err, "parse host port") return PortForwardingRule{}, errors.Wrap(err, "parse host port")
} }
vmPort, err := strconv.ParseUint(split[1], 10, 16) vmPort, err := strconv.ParseUint(split[1], 10, 16)
if err != nil { if err != nil {
return PortForwardingConfig{}, errors.Wrap(err, "parse vm port") return PortForwardingRule{}, errors.Wrap(err, "parse vm port")
} }
return PortForwardingConfig{ return PortForwardingRule{
HostPort: uint16(hostPort), HostPort: uint16(hostPort),
VMPort: uint16(vmPort), VMPort: uint16(vmPort),
}, nil }, nil
@ -43,25 +43,25 @@ func ParsePortForwardString(s string) (PortForwardingConfig, error) {
// <HOST IP>:<HOST PORT>:<VM PORT> // <HOST IP>:<HOST PORT>:<VM PORT>
hostIP := net.ParseIP(split[0]) hostIP := net.ParseIP(split[0])
if hostIP == nil { if hostIP == nil {
return PortForwardingConfig{}, fmt.Errorf("bad host ip") return PortForwardingRule{}, fmt.Errorf("bad host ip")
} }
hostPort, err := strconv.ParseUint(split[1], 10, 16) hostPort, err := strconv.ParseUint(split[1], 10, 16)
if err != nil { if err != nil {
return PortForwardingConfig{}, errors.Wrap(err, "parse host port") return PortForwardingRule{}, errors.Wrap(err, "parse host port")
} }
vmPort, err := strconv.ParseUint(split[2], 10, 16) vmPort, err := strconv.ParseUint(split[2], 10, 16)
if err != nil { if err != nil {
return PortForwardingConfig{}, errors.Wrap(err, "parse vm port") return PortForwardingRule{}, errors.Wrap(err, "parse vm port")
} }
return PortForwardingConfig{ return PortForwardingRule{
HostIP: hostIP, HostIP: hostIP,
HostPort: uint16(hostPort), HostPort: uint16(hostPort),
VMPort: uint16(vmPort), VMPort: uint16(vmPort),
}, nil }, nil
default: default:
return PortForwardingConfig{}, fmt.Errorf("bad split by ':' length: want 2 or 3, have %v", len(split)) return PortForwardingRule{}, fmt.Errorf("bad split by ':' length: want 2 or 3, have %v", len(split))
} }
} }

145
vm/vm.go
View file

@ -26,7 +26,7 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type Instance struct { type VM struct {
logger *slog.Logger logger *slog.Logger
ctx context.Context ctx context.Context
@ -51,11 +51,22 @@ type Instance struct {
canceled uint32 canceled uint32
} }
func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDevicePassthroughConfig, debug bool, extraPortForwardings []PortForwardingConfig, unrestrictedNetworking bool) (*Instance, error) { type VMConfig struct {
alpineImagePath = filepath.Clean(alpineImagePath) CdromImagePath string
_, err := os.Stat(alpineImagePath)
USBDevices []USBDevicePassthroughConfig
ExtraPortForwardingRules []PortForwardingRule
// Debug-related options.
DebugUnrestrictedNetworking bool
DebugShowDisplay bool
}
func NewVM(logger *slog.Logger, cfg VMConfig) (*VM, error) {
cdromImagePath := filepath.Clean(cfg.CdromImagePath)
_, err := os.Stat(cdromImagePath)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to stat alpine image path") return nil, errors.Wrap(err, "stat cdrom image path")
} }
sshPort, err := freeport.GetFreePort() sshPort, err := freeport.GetFreePort()
@ -70,13 +81,13 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
netdevOpts := "user,id=net0,hostfwd=tcp:127.0.0.1:" + fmt.Sprint(sshPort) + "-:22" netdevOpts := "user,id=net0,hostfwd=tcp:127.0.0.1:" + fmt.Sprint(sshPort) + "-:22"
if !unrestrictedNetworking { if !cfg.DebugUnrestrictedNetworking {
netdevOpts += ",restrict=on" netdevOpts += ",restrict=on"
} else { } else {
logger.Warn("Running with unsafe unrestricted networking") logger.Warn("Running with unsafe unrestricted networking")
} }
for _, pf := range extraPortForwardings { for _, pf := range cfg.ExtraPortForwardingRules {
hostIPStr := "" hostIPStr := ""
if pf.HostIP != nil { if pf.HostIP != nil {
hostIPStr = pf.HostIP.String() hostIPStr = pf.HostIP.String()
@ -87,16 +98,16 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
cmdArgs = append(cmdArgs, "-device", "e1000,netdev=net0", "-netdev", netdevOpts) cmdArgs = append(cmdArgs, "-device", "e1000,netdev=net0", "-netdev", netdevOpts)
cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(alpineImagePath)+",format=qcow2,if=virtio", "-snapshot") cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(cdromImagePath)+",format=qcow2,if=virtio", "-snapshot")
if !debug { if !cfg.DebugShowDisplay {
cmdArgs = append(cmdArgs, "-display", "none") cmdArgs = append(cmdArgs, "-display", "none")
} }
if len(usbDevices) != 0 { if len(cfg.USBDevices) != 0 {
cmdArgs = append(cmdArgs, "-usb", "-device", "nec-usb-xhci,id=xhci") cmdArgs = append(cmdArgs, "-usb", "-device", "nec-usb-xhci,id=xhci")
for _, dev := range usbDevices { for _, dev := range cfg.USBDevices {
cmdArgs = append(cmdArgs, "-device", "usb-host,hostbus="+strconv.FormatUint(uint64(dev.HostBus), 10)+",hostport="+strconv.FormatUint(uint64(dev.HostPort), 10)) cmdArgs = append(cmdArgs, "-device", "usb-host,hostbus="+strconv.FormatUint(uint64(dev.HostBus), 10)+",hostport="+strconv.FormatUint(uint64(dev.HostPort), 10))
} }
} }
@ -120,7 +131,7 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
ctx, ctxCancel := context.WithCancel(context.Background()) ctx, ctxCancel := context.WithCancel(context.Background())
vi := &Instance{ vm := &VM{
logger: logger, logger: logger,
ctx: ctx, ctx: ctx,
@ -137,17 +148,17 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
stderrBuf: stderrBuf, stderrBuf: stderrBuf,
} }
vi.resetSerialStdout() vm.resetSerialStdout()
return vi, nil return vm, nil
} }
func (vi *Instance) Run() error { func (vm *VM) Run() error {
if atomic.AddUint32(&vi.disposed, 1) != 1 { if atomic.AddUint32(&vm.disposed, 1) != 1 {
return fmt.Errorf("vm disposed") return fmt.Errorf("vm disposed")
} }
err := vi.cmd.Start() err := vm.cmd.Start()
if err != nil { if err != nil {
return errors.Wrap(err, "start qemu cmd") return errors.Wrap(err, "start qemu cmd")
} }
@ -159,40 +170,40 @@ func (vi *Instance) Run() error {
globalErrsMu.Lock() globalErrsMu.Lock()
defer globalErrsMu.Unlock() defer globalErrsMu.Unlock()
globalErrs = append(globalErrs, err, errors.Wrap(vi.Cancel(), "cancel on error")) globalErrs = append(globalErrs, err, errors.Wrap(vm.Cancel(), "cancel on error"))
} }
vi.logger.Info("Booting the VM") vm.logger.Info("Booting the VM")
go func() { go func() {
_ = vi.runSerialReader() _ = vm.runSerialReader()
_ = vi.Cancel() _ = vm.Cancel()
}() }()
go func() { go func() {
err = vi.runVMLoginHandler() err = vm.runVMLoginHandler()
if err != nil { if err != nil {
globalErrFn(errors.Wrap(err, "run vm login handler")) globalErrFn(errors.Wrap(err, "run vm login handler"))
return return
} }
vi.logger.Info("Setting the VM up") vm.logger.Info("Setting the VM up")
sshSigner, err := vi.sshSetup() sshSigner, err := vm.sshSetup()
if err != nil { if err != nil {
globalErrFn(errors.Wrap(err, "set up ssh")) globalErrFn(errors.Wrap(err, "set up ssh"))
return return
} }
vi.logger.Debug("Set up SSH server successfully") vm.logger.Debug("Set up SSH server successfully")
sshKeyScan, err := vi.scanSSHIdentity() sshKeyScan, err := vm.scanSSHIdentity()
if err != nil { if err != nil {
globalErrFn(errors.Wrap(err, "scan ssh identity")) globalErrFn(errors.Wrap(err, "scan ssh identity"))
return return
} }
vi.logger.Debug("Scanned SSH identity") vm.logger.Debug("Scanned SSH identity")
knownHosts, err := ParseSSHKeyScan(sshKeyScan) knownHosts, err := ParseSSHKeyScan(sshKeyScan)
if err != nil { if err != nil {
@ -201,7 +212,7 @@ func (vi *Instance) Run() error {
return return
} }
vi.sshConf = &ssh.ClientConfig{ vm.sshConf = &ssh.ClientConfig{
User: "root", User: "root",
HostKeyCallback: knownHosts, HostKeyCallback: knownHosts,
Auth: []ssh.AuthMethod{ Auth: []ssh.AuthMethod{
@ -211,111 +222,111 @@ func (vi *Instance) Run() error {
} }
// This is to notify everyone waiting for SSH to be up that it's ready to go. // This is to notify everyone waiting for SSH to be up that it's ready to go.
close(vi.sshReadyCh) close(vm.sshReadyCh)
vi.logger.Info("The VM is ready") vm.logger.Info("The VM is ready")
}() }()
_, err = vi.cmd.Process.Wait() _, err = vm.cmd.Process.Wait()
cancelErr := vi.Cancel() cancelErr := vm.Cancel()
if err != nil { if err != nil {
combinedErr := multierr.Combine( combinedErr := multierr.Combine(
errors.Wrap(err, "wait for cmd to finish execution"), errors.Wrap(err, "wait for cmd to finish execution"),
errors.Wrap(cancelErr, "cancel"), errors.Wrap(cancelErr, "cancel"),
) )
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vi.stderrBuf.String())) return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vm.stderrBuf.String()))
} }
combinedErr := multierr.Combine( combinedErr := multierr.Combine(
append(globalErrs, errors.Wrap(cancelErr, "cancel on exit"))..., append(globalErrs, errors.Wrap(cancelErr, "cancel on exit"))...,
) )
if combinedErr != nil { if combinedErr != nil {
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vi.stderrBuf.String())) return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vm.stderrBuf.String()))
} }
return nil return nil
} }
func (vi *Instance) Cancel() error { func (vm *VM) Cancel() error {
if atomic.AddUint32(&vi.canceled, 1) != 1 { if atomic.AddUint32(&vm.canceled, 1) != 1 {
return nil return nil
} }
vi.logger.Warn("Canceling the VM context") vm.logger.Warn("Canceling the VM context")
var gracefulOK bool var gracefulOK bool
sc, err := vi.DialSSH() sc, err := vm.DialSSH()
if err != nil { if err != nil {
if !errors.Is(err, ErrSSHUnavailable) { if !errors.Is(err, ErrSSHUnavailable) {
vi.logger.Warn("Failed to dial VM ssh to do graceful shutdown", "error", err) vm.logger.Warn("Failed to dial VM ssh to do graceful shutdown", "error", err)
} }
} else { } else {
_, err = runSSHCmd(sc, "poweroff") _, err = runSSHCmd(sc, "poweroff")
_ = sc.Close() _ = sc.Close()
if err != nil { if err != nil {
vi.logger.Warn("Could not power off the VM safely", "error", err) vm.logger.Warn("Could not power off the VM safely", "error", err)
} else { } else {
vi.logger.Info("Shutting the VM down safely") vm.logger.Info("Shutting the VM down safely")
} }
} }
var interruptErr error var interruptErr error
if !gracefulOK { if !gracefulOK {
interruptErr = vi.cmd.Process.Signal(os.Interrupt) interruptErr = vm.cmd.Process.Signal(os.Interrupt)
} }
vi.ctxCancel() vm.ctxCancel()
return multierr.Combine( return multierr.Combine(
errors.Wrap(interruptErr, "interrupt cmd"), errors.Wrap(interruptErr, "interrupt cmd"),
errors.Wrap(vi.serialRead.Close(), "close serial read pipe"), errors.Wrap(vm.serialRead.Close(), "close serial read pipe"),
errors.Wrap(vi.serialWrite.Close(), "close serial write pipe"), errors.Wrap(vm.serialWrite.Close(), "close serial write pipe"),
) )
} }
func (vi *Instance) runSerialReader() error { func (vm *VM) runSerialReader() error {
for { for {
raw, err := vi.serialReader.ReadBytes('\n') raw, err := vm.serialReader.ReadBytes('\n')
if err != nil { if err != nil {
return errors.Wrap(err, "read from serial reader") return errors.Wrap(err, "read from serial reader")
} }
select { select {
case vi.serialStdoutCh <- raw: case vm.serialStdoutCh <- raw:
default: default:
// Message gets discarded if the buffer is full. // Message gets discarded if the buffer is full.
} }
} }
} }
func (vi *Instance) writeSerial(b []byte) error { func (vm *VM) writeSerial(b []byte) error {
vi.serialWriteMu.Lock() vm.serialWriteMu.Lock()
defer vi.serialWriteMu.Unlock() defer vm.serialWriteMu.Unlock()
_, err := vi.serialWrite.Write(b) _, err := vm.serialWrite.Write(b)
return err return err
} }
func (vi *Instance) runVMLoginHandler() error { func (vm *VM) runVMLoginHandler() error {
for { for {
select { select {
case <-vi.ctx.Done(): case <-vm.ctx.Done():
return nil return nil
case <-time.After(time.Second): case <-time.After(time.Second):
peek, err := vi.serialReader.Peek(vi.serialReader.Buffered()) peek, err := vm.serialReader.Peek(vm.serialReader.Buffered())
if err != nil { if err != nil {
return errors.Wrap(err, "peek stdout") return errors.Wrap(err, "peek stdout")
} }
if bytes.Contains(peek, []byte("login:")) { if bytes.Contains(peek, []byte("login:")) {
err = vi.writeSerial([]byte("root\n")) err = vm.writeSerial([]byte("root\n"))
if err != nil { if err != nil {
return errors.Wrap(err, "failed to stdio write login") return errors.Wrap(err, "failed to stdio write login")
} }
vi.logger.Debug("Logged into the VM serial") vm.logger.Debug("Logged into the VM serial")
return nil return nil
} }
@ -323,24 +334,24 @@ func (vi *Instance) runVMLoginHandler() error {
} }
} }
func (vi *Instance) resetSerialStdout() { func (vm *VM) resetSerialStdout() {
vi.serialStdoutCh = make(chan []byte, 32) vm.serialStdoutCh = make(chan []byte, 32)
} }
func (vi *Instance) DialSSH() (*ssh.Client, error) { func (vm *VM) DialSSH() (*ssh.Client, error) {
if vi.sshConf == nil { if vm.sshConf == nil {
return nil, ErrSSHUnavailable return nil, ErrSSHUnavailable
} }
return ssh.Dial("tcp", "localhost:"+fmt.Sprint(vi.sshMappedPort), vi.sshConf) return ssh.Dial("tcp", "localhost:"+fmt.Sprint(vm.sshMappedPort), vm.sshConf)
} }
func (vi *Instance) DialSCP() (*scp.Client, error) { func (vm *VM) DialSCP() (*scp.Client, error) {
if vi.sshConf == nil { if vm.sshConf == nil {
return nil, ErrSSHUnavailable return nil, ErrSSHUnavailable
} }
sc := scp.NewClient("localhost:"+fmt.Sprint(vi.sshMappedPort), vi.sshConf) sc := scp.NewClient("localhost:"+fmt.Sprint(vm.sshMappedPort), vm.sshConf)
err := sc.Connect() err := sc.Connect()
if err != nil { if err != nil {
return nil, err return nil, err
@ -349,6 +360,6 @@ func (vi *Instance) DialSCP() (*scp.Client, error) {
return &sc, nil return &sc, nil
} }
func (vi *Instance) SSHUpNotifyChan() chan struct{} { func (vm *VM) SSHUpNotifyChan() chan struct{} {
return vi.sshReadyCh return vm.sshReadyCh
} }