Some restructuring

This commit is contained in:
AlexSSD7 2023-08-27 13:44:57 +01:00
commit ee447087f6
8 changed files with 131 additions and 110 deletions

145
vm/vm.go
View file

@ -26,7 +26,7 @@ import (
"golang.org/x/crypto/ssh"
)
type Instance struct {
type VM struct {
logger *slog.Logger
ctx context.Context
@ -51,11 +51,22 @@ type Instance struct {
canceled uint32
}
func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDevicePassthroughConfig, debug bool, extraPortForwardings []PortForwardingConfig, unrestrictedNetworking bool) (*Instance, error) {
alpineImagePath = filepath.Clean(alpineImagePath)
_, err := os.Stat(alpineImagePath)
type VMConfig struct {
CdromImagePath string
USBDevices []USBDevicePassthroughConfig
ExtraPortForwardingRules []PortForwardingRule
// Debug-related options.
DebugUnrestrictedNetworking bool
DebugShowDisplay bool
}
func NewVM(logger *slog.Logger, cfg VMConfig) (*VM, error) {
cdromImagePath := filepath.Clean(cfg.CdromImagePath)
_, err := os.Stat(cdromImagePath)
if err != nil {
return nil, errors.Wrap(err, "failed to stat alpine image path")
return nil, errors.Wrap(err, "stat cdrom image path")
}
sshPort, err := freeport.GetFreePort()
@ -70,13 +81,13 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
netdevOpts := "user,id=net0,hostfwd=tcp:127.0.0.1:" + fmt.Sprint(sshPort) + "-:22"
if !unrestrictedNetworking {
if !cfg.DebugUnrestrictedNetworking {
netdevOpts += ",restrict=on"
} else {
logger.Warn("Running with unsafe unrestricted networking")
}
for _, pf := range extraPortForwardings {
for _, pf := range cfg.ExtraPortForwardingRules {
hostIPStr := ""
if pf.HostIP != nil {
hostIPStr = pf.HostIP.String()
@ -87,16 +98,16 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
cmdArgs = append(cmdArgs, "-device", "e1000,netdev=net0", "-netdev", netdevOpts)
cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(alpineImagePath)+",format=qcow2,if=virtio", "-snapshot")
cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(cdromImagePath)+",format=qcow2,if=virtio", "-snapshot")
if !debug {
if !cfg.DebugShowDisplay {
cmdArgs = append(cmdArgs, "-display", "none")
}
if len(usbDevices) != 0 {
if len(cfg.USBDevices) != 0 {
cmdArgs = append(cmdArgs, "-usb", "-device", "nec-usb-xhci,id=xhci")
for _, dev := range usbDevices {
for _, dev := range cfg.USBDevices {
cmdArgs = append(cmdArgs, "-device", "usb-host,hostbus="+strconv.FormatUint(uint64(dev.HostBus), 10)+",hostport="+strconv.FormatUint(uint64(dev.HostPort), 10))
}
}
@ -120,7 +131,7 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
ctx, ctxCancel := context.WithCancel(context.Background())
vi := &Instance{
vm := &VM{
logger: logger,
ctx: ctx,
@ -137,17 +148,17 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
stderrBuf: stderrBuf,
}
vi.resetSerialStdout()
vm.resetSerialStdout()
return vi, nil
return vm, nil
}
func (vi *Instance) Run() error {
if atomic.AddUint32(&vi.disposed, 1) != 1 {
func (vm *VM) Run() error {
if atomic.AddUint32(&vm.disposed, 1) != 1 {
return fmt.Errorf("vm disposed")
}
err := vi.cmd.Start()
err := vm.cmd.Start()
if err != nil {
return errors.Wrap(err, "start qemu cmd")
}
@ -159,40 +170,40 @@ func (vi *Instance) Run() error {
globalErrsMu.Lock()
defer globalErrsMu.Unlock()
globalErrs = append(globalErrs, err, errors.Wrap(vi.Cancel(), "cancel on error"))
globalErrs = append(globalErrs, err, errors.Wrap(vm.Cancel(), "cancel on error"))
}
vi.logger.Info("Booting the VM")
vm.logger.Info("Booting the VM")
go func() {
_ = vi.runSerialReader()
_ = vi.Cancel()
_ = vm.runSerialReader()
_ = vm.Cancel()
}()
go func() {
err = vi.runVMLoginHandler()
err = vm.runVMLoginHandler()
if err != nil {
globalErrFn(errors.Wrap(err, "run vm login handler"))
return
}
vi.logger.Info("Setting the VM up")
vm.logger.Info("Setting the VM up")
sshSigner, err := vi.sshSetup()
sshSigner, err := vm.sshSetup()
if err != nil {
globalErrFn(errors.Wrap(err, "set up ssh"))
return
}
vi.logger.Debug("Set up SSH server successfully")
vm.logger.Debug("Set up SSH server successfully")
sshKeyScan, err := vi.scanSSHIdentity()
sshKeyScan, err := vm.scanSSHIdentity()
if err != nil {
globalErrFn(errors.Wrap(err, "scan ssh identity"))
return
}
vi.logger.Debug("Scanned SSH identity")
vm.logger.Debug("Scanned SSH identity")
knownHosts, err := ParseSSHKeyScan(sshKeyScan)
if err != nil {
@ -201,7 +212,7 @@ func (vi *Instance) Run() error {
return
}
vi.sshConf = &ssh.ClientConfig{
vm.sshConf = &ssh.ClientConfig{
User: "root",
HostKeyCallback: knownHosts,
Auth: []ssh.AuthMethod{
@ -211,111 +222,111 @@ func (vi *Instance) Run() error {
}
// This is to notify everyone waiting for SSH to be up that it's ready to go.
close(vi.sshReadyCh)
close(vm.sshReadyCh)
vi.logger.Info("The VM is ready")
vm.logger.Info("The VM is ready")
}()
_, err = vi.cmd.Process.Wait()
cancelErr := vi.Cancel()
_, err = vm.cmd.Process.Wait()
cancelErr := vm.Cancel()
if err != nil {
combinedErr := multierr.Combine(
errors.Wrap(err, "wait for cmd to finish execution"),
errors.Wrap(cancelErr, "cancel"),
)
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vi.stderrBuf.String()))
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vm.stderrBuf.String()))
}
combinedErr := multierr.Combine(
append(globalErrs, errors.Wrap(cancelErr, "cancel on exit"))...,
)
if combinedErr != nil {
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vi.stderrBuf.String()))
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vm.stderrBuf.String()))
}
return nil
}
func (vi *Instance) Cancel() error {
if atomic.AddUint32(&vi.canceled, 1) != 1 {
func (vm *VM) Cancel() error {
if atomic.AddUint32(&vm.canceled, 1) != 1 {
return nil
}
vi.logger.Warn("Canceling the VM context")
vm.logger.Warn("Canceling the VM context")
var gracefulOK bool
sc, err := vi.DialSSH()
sc, err := vm.DialSSH()
if err != nil {
if !errors.Is(err, ErrSSHUnavailable) {
vi.logger.Warn("Failed to dial VM ssh to do graceful shutdown", "error", err)
vm.logger.Warn("Failed to dial VM ssh to do graceful shutdown", "error", err)
}
} else {
_, err = runSSHCmd(sc, "poweroff")
_ = sc.Close()
if err != nil {
vi.logger.Warn("Could not power off the VM safely", "error", err)
vm.logger.Warn("Could not power off the VM safely", "error", err)
} else {
vi.logger.Info("Shutting the VM down safely")
vm.logger.Info("Shutting the VM down safely")
}
}
var interruptErr error
if !gracefulOK {
interruptErr = vi.cmd.Process.Signal(os.Interrupt)
interruptErr = vm.cmd.Process.Signal(os.Interrupt)
}
vi.ctxCancel()
vm.ctxCancel()
return multierr.Combine(
errors.Wrap(interruptErr, "interrupt cmd"),
errors.Wrap(vi.serialRead.Close(), "close serial read pipe"),
errors.Wrap(vi.serialWrite.Close(), "close serial write pipe"),
errors.Wrap(vm.serialRead.Close(), "close serial read pipe"),
errors.Wrap(vm.serialWrite.Close(), "close serial write pipe"),
)
}
func (vi *Instance) runSerialReader() error {
func (vm *VM) runSerialReader() error {
for {
raw, err := vi.serialReader.ReadBytes('\n')
raw, err := vm.serialReader.ReadBytes('\n')
if err != nil {
return errors.Wrap(err, "read from serial reader")
}
select {
case vi.serialStdoutCh <- raw:
case vm.serialStdoutCh <- raw:
default:
// Message gets discarded if the buffer is full.
}
}
}
func (vi *Instance) writeSerial(b []byte) error {
vi.serialWriteMu.Lock()
defer vi.serialWriteMu.Unlock()
func (vm *VM) writeSerial(b []byte) error {
vm.serialWriteMu.Lock()
defer vm.serialWriteMu.Unlock()
_, err := vi.serialWrite.Write(b)
_, err := vm.serialWrite.Write(b)
return err
}
func (vi *Instance) runVMLoginHandler() error {
func (vm *VM) runVMLoginHandler() error {
for {
select {
case <-vi.ctx.Done():
case <-vm.ctx.Done():
return nil
case <-time.After(time.Second):
peek, err := vi.serialReader.Peek(vi.serialReader.Buffered())
peek, err := vm.serialReader.Peek(vm.serialReader.Buffered())
if err != nil {
return errors.Wrap(err, "peek stdout")
}
if bytes.Contains(peek, []byte("login:")) {
err = vi.writeSerial([]byte("root\n"))
err = vm.writeSerial([]byte("root\n"))
if err != nil {
return errors.Wrap(err, "failed to stdio write login")
}
vi.logger.Debug("Logged into the VM serial")
vm.logger.Debug("Logged into the VM serial")
return nil
}
@ -323,24 +334,24 @@ func (vi *Instance) runVMLoginHandler() error {
}
}
func (vi *Instance) resetSerialStdout() {
vi.serialStdoutCh = make(chan []byte, 32)
func (vm *VM) resetSerialStdout() {
vm.serialStdoutCh = make(chan []byte, 32)
}
func (vi *Instance) DialSSH() (*ssh.Client, error) {
if vi.sshConf == nil {
func (vm *VM) DialSSH() (*ssh.Client, error) {
if vm.sshConf == nil {
return nil, ErrSSHUnavailable
}
return ssh.Dial("tcp", "localhost:"+fmt.Sprint(vi.sshMappedPort), vi.sshConf)
return ssh.Dial("tcp", "localhost:"+fmt.Sprint(vm.sshMappedPort), vm.sshConf)
}
func (vi *Instance) DialSCP() (*scp.Client, error) {
if vi.sshConf == nil {
func (vm *VM) DialSCP() (*scp.Client, error) {
if vm.sshConf == nil {
return nil, ErrSSHUnavailable
}
sc := scp.NewClient("localhost:"+fmt.Sprint(vi.sshMappedPort), vi.sshConf)
sc := scp.NewClient("localhost:"+fmt.Sprint(vm.sshMappedPort), vm.sshConf)
err := sc.Connect()
if err != nil {
return nil, err
@ -349,6 +360,6 @@ func (vi *Instance) DialSCP() (*scp.Client, error) {
return &sc, nil
}
func (vi *Instance) SSHUpNotifyChan() chan struct{} {
return vi.sshReadyCh
func (vm *VM) SSHUpNotifyChan() chan struct{} {
return vm.sshReadyCh
}