Some restructuring
This commit is contained in:
parent
34e66cb01c
commit
ee447087f6
8 changed files with 131 additions and 110 deletions
145
vm/vm.go
145
vm/vm.go
|
|
@ -26,7 +26,7 @@ import (
|
|||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
type VM struct {
|
||||
logger *slog.Logger
|
||||
|
||||
ctx context.Context
|
||||
|
|
@ -51,11 +51,22 @@ type Instance struct {
|
|||
canceled uint32
|
||||
}
|
||||
|
||||
func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDevicePassthroughConfig, debug bool, extraPortForwardings []PortForwardingConfig, unrestrictedNetworking bool) (*Instance, error) {
|
||||
alpineImagePath = filepath.Clean(alpineImagePath)
|
||||
_, err := os.Stat(alpineImagePath)
|
||||
type VMConfig struct {
|
||||
CdromImagePath string
|
||||
|
||||
USBDevices []USBDevicePassthroughConfig
|
||||
ExtraPortForwardingRules []PortForwardingRule
|
||||
|
||||
// Debug-related options.
|
||||
DebugUnrestrictedNetworking bool
|
||||
DebugShowDisplay bool
|
||||
}
|
||||
|
||||
func NewVM(logger *slog.Logger, cfg VMConfig) (*VM, error) {
|
||||
cdromImagePath := filepath.Clean(cfg.CdromImagePath)
|
||||
_, err := os.Stat(cdromImagePath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to stat alpine image path")
|
||||
return nil, errors.Wrap(err, "stat cdrom image path")
|
||||
}
|
||||
|
||||
sshPort, err := freeport.GetFreePort()
|
||||
|
|
@ -70,13 +81,13 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
|
|||
|
||||
netdevOpts := "user,id=net0,hostfwd=tcp:127.0.0.1:" + fmt.Sprint(sshPort) + "-:22"
|
||||
|
||||
if !unrestrictedNetworking {
|
||||
if !cfg.DebugUnrestrictedNetworking {
|
||||
netdevOpts += ",restrict=on"
|
||||
} else {
|
||||
logger.Warn("Running with unsafe unrestricted networking")
|
||||
}
|
||||
|
||||
for _, pf := range extraPortForwardings {
|
||||
for _, pf := range cfg.ExtraPortForwardingRules {
|
||||
hostIPStr := ""
|
||||
if pf.HostIP != nil {
|
||||
hostIPStr = pf.HostIP.String()
|
||||
|
|
@ -87,16 +98,16 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
|
|||
|
||||
cmdArgs = append(cmdArgs, "-device", "e1000,netdev=net0", "-netdev", netdevOpts)
|
||||
|
||||
cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(alpineImagePath)+",format=qcow2,if=virtio", "-snapshot")
|
||||
cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(cdromImagePath)+",format=qcow2,if=virtio", "-snapshot")
|
||||
|
||||
if !debug {
|
||||
if !cfg.DebugShowDisplay {
|
||||
cmdArgs = append(cmdArgs, "-display", "none")
|
||||
}
|
||||
|
||||
if len(usbDevices) != 0 {
|
||||
if len(cfg.USBDevices) != 0 {
|
||||
cmdArgs = append(cmdArgs, "-usb", "-device", "nec-usb-xhci,id=xhci")
|
||||
|
||||
for _, dev := range usbDevices {
|
||||
for _, dev := range cfg.USBDevices {
|
||||
cmdArgs = append(cmdArgs, "-device", "usb-host,hostbus="+strconv.FormatUint(uint64(dev.HostBus), 10)+",hostport="+strconv.FormatUint(uint64(dev.HostPort), 10))
|
||||
}
|
||||
}
|
||||
|
|
@ -120,7 +131,7 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
|
|||
|
||||
ctx, ctxCancel := context.WithCancel(context.Background())
|
||||
|
||||
vi := &Instance{
|
||||
vm := &VM{
|
||||
logger: logger,
|
||||
|
||||
ctx: ctx,
|
||||
|
|
@ -137,17 +148,17 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
|
|||
stderrBuf: stderrBuf,
|
||||
}
|
||||
|
||||
vi.resetSerialStdout()
|
||||
vm.resetSerialStdout()
|
||||
|
||||
return vi, nil
|
||||
return vm, nil
|
||||
}
|
||||
|
||||
func (vi *Instance) Run() error {
|
||||
if atomic.AddUint32(&vi.disposed, 1) != 1 {
|
||||
func (vm *VM) Run() error {
|
||||
if atomic.AddUint32(&vm.disposed, 1) != 1 {
|
||||
return fmt.Errorf("vm disposed")
|
||||
}
|
||||
|
||||
err := vi.cmd.Start()
|
||||
err := vm.cmd.Start()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "start qemu cmd")
|
||||
}
|
||||
|
|
@ -159,40 +170,40 @@ func (vi *Instance) Run() error {
|
|||
globalErrsMu.Lock()
|
||||
defer globalErrsMu.Unlock()
|
||||
|
||||
globalErrs = append(globalErrs, err, errors.Wrap(vi.Cancel(), "cancel on error"))
|
||||
globalErrs = append(globalErrs, err, errors.Wrap(vm.Cancel(), "cancel on error"))
|
||||
}
|
||||
|
||||
vi.logger.Info("Booting the VM")
|
||||
vm.logger.Info("Booting the VM")
|
||||
|
||||
go func() {
|
||||
_ = vi.runSerialReader()
|
||||
_ = vi.Cancel()
|
||||
_ = vm.runSerialReader()
|
||||
_ = vm.Cancel()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err = vi.runVMLoginHandler()
|
||||
err = vm.runVMLoginHandler()
|
||||
if err != nil {
|
||||
globalErrFn(errors.Wrap(err, "run vm login handler"))
|
||||
return
|
||||
}
|
||||
|
||||
vi.logger.Info("Setting the VM up")
|
||||
vm.logger.Info("Setting the VM up")
|
||||
|
||||
sshSigner, err := vi.sshSetup()
|
||||
sshSigner, err := vm.sshSetup()
|
||||
if err != nil {
|
||||
globalErrFn(errors.Wrap(err, "set up ssh"))
|
||||
return
|
||||
}
|
||||
|
||||
vi.logger.Debug("Set up SSH server successfully")
|
||||
vm.logger.Debug("Set up SSH server successfully")
|
||||
|
||||
sshKeyScan, err := vi.scanSSHIdentity()
|
||||
sshKeyScan, err := vm.scanSSHIdentity()
|
||||
if err != nil {
|
||||
globalErrFn(errors.Wrap(err, "scan ssh identity"))
|
||||
return
|
||||
}
|
||||
|
||||
vi.logger.Debug("Scanned SSH identity")
|
||||
vm.logger.Debug("Scanned SSH identity")
|
||||
|
||||
knownHosts, err := ParseSSHKeyScan(sshKeyScan)
|
||||
if err != nil {
|
||||
|
|
@ -201,7 +212,7 @@ func (vi *Instance) Run() error {
|
|||
return
|
||||
}
|
||||
|
||||
vi.sshConf = &ssh.ClientConfig{
|
||||
vm.sshConf = &ssh.ClientConfig{
|
||||
User: "root",
|
||||
HostKeyCallback: knownHosts,
|
||||
Auth: []ssh.AuthMethod{
|
||||
|
|
@ -211,111 +222,111 @@ func (vi *Instance) Run() error {
|
|||
}
|
||||
|
||||
// This is to notify everyone waiting for SSH to be up that it's ready to go.
|
||||
close(vi.sshReadyCh)
|
||||
close(vm.sshReadyCh)
|
||||
|
||||
vi.logger.Info("The VM is ready")
|
||||
vm.logger.Info("The VM is ready")
|
||||
}()
|
||||
|
||||
_, err = vi.cmd.Process.Wait()
|
||||
cancelErr := vi.Cancel()
|
||||
_, err = vm.cmd.Process.Wait()
|
||||
cancelErr := vm.Cancel()
|
||||
if err != nil {
|
||||
combinedErr := multierr.Combine(
|
||||
errors.Wrap(err, "wait for cmd to finish execution"),
|
||||
errors.Wrap(cancelErr, "cancel"),
|
||||
)
|
||||
|
||||
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vi.stderrBuf.String()))
|
||||
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vm.stderrBuf.String()))
|
||||
}
|
||||
|
||||
combinedErr := multierr.Combine(
|
||||
append(globalErrs, errors.Wrap(cancelErr, "cancel on exit"))...,
|
||||
)
|
||||
if combinedErr != nil {
|
||||
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vi.stderrBuf.String()))
|
||||
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vm.stderrBuf.String()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vi *Instance) Cancel() error {
|
||||
if atomic.AddUint32(&vi.canceled, 1) != 1 {
|
||||
func (vm *VM) Cancel() error {
|
||||
if atomic.AddUint32(&vm.canceled, 1) != 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
vi.logger.Warn("Canceling the VM context")
|
||||
vm.logger.Warn("Canceling the VM context")
|
||||
|
||||
var gracefulOK bool
|
||||
|
||||
sc, err := vi.DialSSH()
|
||||
sc, err := vm.DialSSH()
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrSSHUnavailable) {
|
||||
vi.logger.Warn("Failed to dial VM ssh to do graceful shutdown", "error", err)
|
||||
vm.logger.Warn("Failed to dial VM ssh to do graceful shutdown", "error", err)
|
||||
}
|
||||
} else {
|
||||
_, err = runSSHCmd(sc, "poweroff")
|
||||
_ = sc.Close()
|
||||
if err != nil {
|
||||
vi.logger.Warn("Could not power off the VM safely", "error", err)
|
||||
vm.logger.Warn("Could not power off the VM safely", "error", err)
|
||||
} else {
|
||||
vi.logger.Info("Shutting the VM down safely")
|
||||
vm.logger.Info("Shutting the VM down safely")
|
||||
}
|
||||
}
|
||||
|
||||
var interruptErr error
|
||||
|
||||
if !gracefulOK {
|
||||
interruptErr = vi.cmd.Process.Signal(os.Interrupt)
|
||||
interruptErr = vm.cmd.Process.Signal(os.Interrupt)
|
||||
}
|
||||
|
||||
vi.ctxCancel()
|
||||
vm.ctxCancel()
|
||||
return multierr.Combine(
|
||||
errors.Wrap(interruptErr, "interrupt cmd"),
|
||||
errors.Wrap(vi.serialRead.Close(), "close serial read pipe"),
|
||||
errors.Wrap(vi.serialWrite.Close(), "close serial write pipe"),
|
||||
errors.Wrap(vm.serialRead.Close(), "close serial read pipe"),
|
||||
errors.Wrap(vm.serialWrite.Close(), "close serial write pipe"),
|
||||
)
|
||||
}
|
||||
|
||||
func (vi *Instance) runSerialReader() error {
|
||||
func (vm *VM) runSerialReader() error {
|
||||
for {
|
||||
raw, err := vi.serialReader.ReadBytes('\n')
|
||||
raw, err := vm.serialReader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read from serial reader")
|
||||
}
|
||||
|
||||
select {
|
||||
case vi.serialStdoutCh <- raw:
|
||||
case vm.serialStdoutCh <- raw:
|
||||
default:
|
||||
// Message gets discarded if the buffer is full.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (vi *Instance) writeSerial(b []byte) error {
|
||||
vi.serialWriteMu.Lock()
|
||||
defer vi.serialWriteMu.Unlock()
|
||||
func (vm *VM) writeSerial(b []byte) error {
|
||||
vm.serialWriteMu.Lock()
|
||||
defer vm.serialWriteMu.Unlock()
|
||||
|
||||
_, err := vi.serialWrite.Write(b)
|
||||
_, err := vm.serialWrite.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (vi *Instance) runVMLoginHandler() error {
|
||||
func (vm *VM) runVMLoginHandler() error {
|
||||
for {
|
||||
select {
|
||||
case <-vi.ctx.Done():
|
||||
case <-vm.ctx.Done():
|
||||
return nil
|
||||
case <-time.After(time.Second):
|
||||
peek, err := vi.serialReader.Peek(vi.serialReader.Buffered())
|
||||
peek, err := vm.serialReader.Peek(vm.serialReader.Buffered())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "peek stdout")
|
||||
}
|
||||
|
||||
if bytes.Contains(peek, []byte("login:")) {
|
||||
err = vi.writeSerial([]byte("root\n"))
|
||||
err = vm.writeSerial([]byte("root\n"))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to stdio write login")
|
||||
}
|
||||
|
||||
vi.logger.Debug("Logged into the VM serial")
|
||||
vm.logger.Debug("Logged into the VM serial")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -323,24 +334,24 @@ func (vi *Instance) runVMLoginHandler() error {
|
|||
}
|
||||
}
|
||||
|
||||
func (vi *Instance) resetSerialStdout() {
|
||||
vi.serialStdoutCh = make(chan []byte, 32)
|
||||
func (vm *VM) resetSerialStdout() {
|
||||
vm.serialStdoutCh = make(chan []byte, 32)
|
||||
}
|
||||
|
||||
func (vi *Instance) DialSSH() (*ssh.Client, error) {
|
||||
if vi.sshConf == nil {
|
||||
func (vm *VM) DialSSH() (*ssh.Client, error) {
|
||||
if vm.sshConf == nil {
|
||||
return nil, ErrSSHUnavailable
|
||||
}
|
||||
|
||||
return ssh.Dial("tcp", "localhost:"+fmt.Sprint(vi.sshMappedPort), vi.sshConf)
|
||||
return ssh.Dial("tcp", "localhost:"+fmt.Sprint(vm.sshMappedPort), vm.sshConf)
|
||||
}
|
||||
|
||||
func (vi *Instance) DialSCP() (*scp.Client, error) {
|
||||
if vi.sshConf == nil {
|
||||
func (vm *VM) DialSCP() (*scp.Client, error) {
|
||||
if vm.sshConf == nil {
|
||||
return nil, ErrSSHUnavailable
|
||||
}
|
||||
|
||||
sc := scp.NewClient("localhost:"+fmt.Sprint(vi.sshMappedPort), vi.sshConf)
|
||||
sc := scp.NewClient("localhost:"+fmt.Sprint(vm.sshMappedPort), vm.sshConf)
|
||||
err := sc.Connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -349,6 +360,6 @@ func (vi *Instance) DialSCP() (*scp.Client, error) {
|
|||
return &sc, nil
|
||||
}
|
||||
|
||||
func (vi *Instance) SSHUpNotifyChan() chan struct{} {
|
||||
return vi.sshReadyCh
|
||||
func (vm *VM) SSHUpNotifyChan() chan struct{} {
|
||||
return vm.sshReadyCh
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue