Some restructuring
This commit is contained in:
parent
34e66cb01c
commit
ee447087f6
8 changed files with 131 additions and 110 deletions
|
|
@ -18,7 +18,7 @@ var lsCmd = &cobra.Command{
|
|||
// Short: "",
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
os.Exit(runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) int {
|
||||
os.Exit(runVM(args[0], func(ctx context.Context, i *vm.VM, fm *vm.FileManager) int {
|
||||
lsblkOut, err := fm.Lsblk()
|
||||
if err != nil {
|
||||
slog.Error("Failed to list block devices in the VM", "error", err)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ var runCmd = &cobra.Command{
|
|||
|
||||
// TODO: `slog` library prints entire stack traces for errors which makes reading errors challenging.
|
||||
|
||||
os.Exit(runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) int {
|
||||
os.Exit(runVM(args[0], func(ctx context.Context, i *vm.VM, fm *vm.FileManager) int {
|
||||
err := fm.Mount(vmMountDevName, vm.MountOptions{
|
||||
FSType: fsType,
|
||||
LUKS: luksFlag,
|
||||
|
|
@ -59,7 +59,7 @@ var runCmd = &cobra.Command{
|
|||
|
||||
<-ctx.Done()
|
||||
return 0
|
||||
}, []vm.PortForwardingConfig{{
|
||||
}, []vm.PortForwardingRule{{
|
||||
HostIP: net.ParseIP("127.0.0.1"), // TODO: Make this changeable.
|
||||
HostPort: networkSharePort,
|
||||
VMPort: 445,
|
||||
|
|
|
|||
10
cmd/shell.go
10
cmd/shell.go
|
|
@ -23,23 +23,23 @@ var shellCmd = &cobra.Command{
|
|||
passthroughArg = args[0]
|
||||
}
|
||||
|
||||
var forwardPortsConfig []vm.PortForwardingConfig
|
||||
var forwardPortRules []vm.PortForwardingRule
|
||||
|
||||
for i, fp := range strings.Split(forwardPortsFlagStr, ",") {
|
||||
if fp == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
fpc, err := vm.ParsePortForwardString(fp)
|
||||
fpr, err := vm.ParsePortForwardString(fp)
|
||||
if err != nil {
|
||||
slog.Error("Failed to parse port forward string", "index", i, "value", fp, "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
forwardPortsConfig = append(forwardPortsConfig, fpc)
|
||||
forwardPortRules = append(forwardPortRules, fpr)
|
||||
}
|
||||
|
||||
os.Exit(runVM(passthroughArg, func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) int {
|
||||
os.Exit(runVM(passthroughArg, func(ctx context.Context, i *vm.VM, fm *vm.FileManager) int {
|
||||
sc, err := i.DialSSH()
|
||||
if err != nil {
|
||||
slog.Error("Failed to dial VM SSH", "error", err)
|
||||
|
|
@ -120,7 +120,7 @@ var shellCmd = &cobra.Command{
|
|||
}
|
||||
|
||||
return 0
|
||||
}, forwardPortsConfig, unrestrictedNetworkingFlag))
|
||||
}, forwardPortRules, unrestrictedNetworkingFlag))
|
||||
|
||||
return nil
|
||||
},
|
||||
|
|
|
|||
14
cmd/utils.go
14
cmd/utils.go
|
|
@ -37,7 +37,7 @@ func doRootCheck() {
|
|||
}
|
||||
}
|
||||
|
||||
func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.FileManager) int, forwardPorts []vm.PortForwardingConfig, unrestrictedNetworking bool) int {
|
||||
func runVM(passthroughArg string, fn func(context.Context, *vm.VM, *vm.FileManager) int, forwardPortsRules []vm.PortForwardingRule, unrestrictedNetworking bool) int {
|
||||
doRootCheck()
|
||||
|
||||
var passthroughConfig []vm.USBDevicePassthroughConfig
|
||||
|
|
@ -46,8 +46,18 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil
|
|||
passthroughConfig = []vm.USBDevicePassthroughConfig{getDevicePassthroughConfig(passthroughArg)}
|
||||
}
|
||||
|
||||
vmCfg := vm.VMConfig{
|
||||
CdromImagePath: "alpine-img/alpine.qcow2",
|
||||
|
||||
USBDevices: passthroughConfig,
|
||||
ExtraPortForwardingRules: forwardPortsRules,
|
||||
|
||||
DebugUnrestrictedNetworking: unrestrictedNetworking,
|
||||
DebugShowDisplay: vmDebugFlag,
|
||||
}
|
||||
|
||||
// TODO: Alpine image should be downloaded from somewhere.
|
||||
vi, err := vm.NewInstance(slog.Default().With("caller", "vm"), "alpine-img/alpine.qcow2", passthroughConfig, vmDebugFlag, forwardPorts, unrestrictedNetworking)
|
||||
vi, err := vm.NewVM(slog.Default().With("caller", "vm"), vmCfg)
|
||||
if err != nil {
|
||||
slog.Error("Failed to create vm instance", "error", err)
|
||||
os.Exit(1)
|
||||
|
|
|
|||
|
|
@ -22,19 +22,19 @@ import (
|
|||
type FileManager struct {
|
||||
logger *slog.Logger
|
||||
|
||||
vi *Instance
|
||||
vm *VM
|
||||
}
|
||||
|
||||
func NewFileManager(logger *slog.Logger, vi *Instance) *FileManager {
|
||||
func NewFileManager(logger *slog.Logger, vm *VM) *FileManager {
|
||||
return &FileManager{
|
||||
logger: logger,
|
||||
|
||||
vi: vi,
|
||||
vm: vm,
|
||||
}
|
||||
}
|
||||
|
||||
func (fm *FileManager) Init() error {
|
||||
sc, err := fm.vi.DialSSH()
|
||||
sc, err := fm.vm.DialSSH()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "dial vm ssh")
|
||||
}
|
||||
|
|
@ -55,7 +55,7 @@ func (fm *FileManager) Init() error {
|
|||
}
|
||||
|
||||
func (fm *FileManager) Lsblk() ([]byte, error) {
|
||||
sc, err := fm.vi.DialSSH()
|
||||
sc, err := fm.vm.DialSSH()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "dial vm ssh")
|
||||
}
|
||||
|
|
@ -180,7 +180,7 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error {
|
|||
return fmt.Errorf("fs type is empty")
|
||||
}
|
||||
|
||||
sc, err := fm.vi.DialSSH()
|
||||
sc, err := fm.vm.DialSSH()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "dial vm ssh")
|
||||
}
|
||||
|
|
@ -205,7 +205,7 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error {
|
|||
}
|
||||
|
||||
func (fm *FileManager) StartSMB(pwd []byte) error {
|
||||
scpClient, err := fm.vi.DialSCP()
|
||||
scpClient, err := fm.vm.DialSCP()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "dial scp")
|
||||
}
|
||||
|
|
@ -225,14 +225,14 @@ force user = linsk
|
|||
force group = linsk
|
||||
create mask = 0664`
|
||||
|
||||
err = scpClient.CopyFile(fm.vi.ctx, strings.NewReader(sambaCfg), "/etc/samba/smb.conf", "0400")
|
||||
err = scpClient.CopyFile(fm.vm.ctx, strings.NewReader(sambaCfg), "/etc/samba/smb.conf", "0400")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "copy samba config file")
|
||||
}
|
||||
|
||||
scpClient.Close()
|
||||
|
||||
sc, err := fm.vi.DialSSH()
|
||||
sc, err := fm.vm.DialSSH()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "dial ssh")
|
||||
}
|
||||
|
|
@ -269,11 +269,11 @@ create mask = 0664`
|
|||
go func() {
|
||||
_, err = stdinPipe.Write(pwd)
|
||||
if err != nil {
|
||||
fm.vi.logger.Error("Failed to write SMB password to smbpasswd stdin", "error", err)
|
||||
fm.vm.logger.Error("Failed to write SMB password to smbpasswd stdin", "error", err)
|
||||
}
|
||||
_, err = stdinPipe.Write(pwd)
|
||||
if err != nil {
|
||||
fm.vi.logger.Error("Failed to write repeated SMB password to smbpasswd stdin", "error", err)
|
||||
fm.vm.logger.Error("Failed to write repeated SMB password to smbpasswd stdin", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
|
|||
24
vm/ssh.go
24
vm/ssh.go
|
|
@ -50,10 +50,10 @@ func ParseSSHKeyScan(knownHosts []byte) (ssh.HostKeyCallback, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (vi *Instance) scanSSHIdentity() ([]byte, error) {
|
||||
vi.resetSerialStdout()
|
||||
func (vm *VM) scanSSHIdentity() ([]byte, error) {
|
||||
vm.resetSerialStdout()
|
||||
|
||||
err := vi.writeSerial([]byte(`ssh-keyscan -H localhost; echo "SERIAL STATUS: $?"; rm /root/.ash_history` + "\n"))
|
||||
err := vm.writeSerial([]byte(`ssh-keyscan -H localhost; echo "SERIAL STATUS: $?"; rm /root/.ash_history` + "\n"))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "write keyscan command to serial")
|
||||
}
|
||||
|
|
@ -64,11 +64,11 @@ func (vi *Instance) scanSSHIdentity() ([]byte, error) {
|
|||
|
||||
for {
|
||||
select {
|
||||
case <-vi.ctx.Done():
|
||||
return nil, vi.ctx.Err()
|
||||
case <-vm.ctx.Done():
|
||||
return nil, vm.ctx.Err()
|
||||
case <-time.After(time.Until(deadline)):
|
||||
return nil, fmt.Errorf("keyscan command timed out")
|
||||
case data := <-vi.serialStdoutCh:
|
||||
case data := <-vm.serialStdoutCh:
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
|
@ -91,8 +91,8 @@ func (vi *Instance) scanSSHIdentity() ([]byte, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (vi *Instance) sshSetup() (ssh.Signer, error) {
|
||||
vi.resetSerialStdout()
|
||||
func (vm *VM) sshSetup() (ssh.Signer, error) {
|
||||
vm.resetSerialStdout()
|
||||
|
||||
sshSigner, sshPublicKey, err := generateSSHKey()
|
||||
if err != nil {
|
||||
|
|
@ -101,7 +101,7 @@ func (vi *Instance) sshSetup() (ssh.Signer, error) {
|
|||
|
||||
cmd := `set -ex; do_setup () { sh -c "set -ex; ifconfig eth0 up; ifconfig lo up; udhcpc; mkdir -p ~/.ssh; echo ` + shellescape.Quote(string(sshPublicKey)) + ` > ~/.ssh/authorized_keys; rc-update add sshd; rc-service sshd start"; echo "SERIAL"" ""STATUS: $?"; }; do_setup` + "\n"
|
||||
|
||||
err = vi.writeSerial([]byte(cmd))
|
||||
err = vm.writeSerial([]byte(cmd))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "write ssh setup serial command")
|
||||
}
|
||||
|
|
@ -112,11 +112,11 @@ func (vi *Instance) sshSetup() (ssh.Signer, error) {
|
|||
|
||||
for {
|
||||
select {
|
||||
case <-vi.ctx.Done():
|
||||
return nil, vi.ctx.Err()
|
||||
case <-vm.ctx.Done():
|
||||
return nil, vm.ctx.Err()
|
||||
case <-time.After(time.Until(deadline)):
|
||||
return nil, fmt.Errorf("setup command timed out %v", getLogErrMsg(stdOutErrBuf.String()))
|
||||
case data := <-vi.serialStdoutCh:
|
||||
case data := <-vm.serialStdoutCh:
|
||||
prefix := []byte("SERIAL STATUS: ")
|
||||
stdOutErrBuf.Write(data)
|
||||
if bytes.HasPrefix(data, prefix) {
|
||||
|
|
|
|||
20
vm/types.go
20
vm/types.go
|
|
@ -14,28 +14,28 @@ type USBDevicePassthroughConfig struct {
|
|||
HostPort uint8
|
||||
}
|
||||
|
||||
type PortForwardingConfig struct {
|
||||
type PortForwardingRule struct {
|
||||
HostIP net.IP
|
||||
HostPort uint16
|
||||
VMPort uint16
|
||||
}
|
||||
|
||||
func ParsePortForwardString(s string) (PortForwardingConfig, error) {
|
||||
func ParsePortForwardString(s string) (PortForwardingRule, error) {
|
||||
split := strings.Split(s, ":")
|
||||
switch len(split) {
|
||||
case 2:
|
||||
// <HOST PORT>:<VM PORT>
|
||||
hostPort, err := strconv.ParseUint(split[0], 10, 16)
|
||||
if err != nil {
|
||||
return PortForwardingConfig{}, errors.Wrap(err, "parse host port")
|
||||
return PortForwardingRule{}, errors.Wrap(err, "parse host port")
|
||||
}
|
||||
|
||||
vmPort, err := strconv.ParseUint(split[1], 10, 16)
|
||||
if err != nil {
|
||||
return PortForwardingConfig{}, errors.Wrap(err, "parse vm port")
|
||||
return PortForwardingRule{}, errors.Wrap(err, "parse vm port")
|
||||
}
|
||||
|
||||
return PortForwardingConfig{
|
||||
return PortForwardingRule{
|
||||
HostPort: uint16(hostPort),
|
||||
VMPort: uint16(vmPort),
|
||||
}, nil
|
||||
|
|
@ -43,25 +43,25 @@ func ParsePortForwardString(s string) (PortForwardingConfig, error) {
|
|||
// <HOST IP>:<HOST PORT>:<VM PORT>
|
||||
hostIP := net.ParseIP(split[0])
|
||||
if hostIP == nil {
|
||||
return PortForwardingConfig{}, fmt.Errorf("bad host ip")
|
||||
return PortForwardingRule{}, fmt.Errorf("bad host ip")
|
||||
}
|
||||
|
||||
hostPort, err := strconv.ParseUint(split[1], 10, 16)
|
||||
if err != nil {
|
||||
return PortForwardingConfig{}, errors.Wrap(err, "parse host port")
|
||||
return PortForwardingRule{}, errors.Wrap(err, "parse host port")
|
||||
}
|
||||
|
||||
vmPort, err := strconv.ParseUint(split[2], 10, 16)
|
||||
if err != nil {
|
||||
return PortForwardingConfig{}, errors.Wrap(err, "parse vm port")
|
||||
return PortForwardingRule{}, errors.Wrap(err, "parse vm port")
|
||||
}
|
||||
|
||||
return PortForwardingConfig{
|
||||
return PortForwardingRule{
|
||||
HostIP: hostIP,
|
||||
HostPort: uint16(hostPort),
|
||||
VMPort: uint16(vmPort),
|
||||
}, nil
|
||||
default:
|
||||
return PortForwardingConfig{}, fmt.Errorf("bad split by ':' length: want 2 or 3, have %v", len(split))
|
||||
return PortForwardingRule{}, fmt.Errorf("bad split by ':' length: want 2 or 3, have %v", len(split))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
145
vm/vm.go
145
vm/vm.go
|
|
@ -26,7 +26,7 @@ import (
|
|||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
type VM struct {
|
||||
logger *slog.Logger
|
||||
|
||||
ctx context.Context
|
||||
|
|
@ -51,11 +51,22 @@ type Instance struct {
|
|||
canceled uint32
|
||||
}
|
||||
|
||||
func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDevicePassthroughConfig, debug bool, extraPortForwardings []PortForwardingConfig, unrestrictedNetworking bool) (*Instance, error) {
|
||||
alpineImagePath = filepath.Clean(alpineImagePath)
|
||||
_, err := os.Stat(alpineImagePath)
|
||||
type VMConfig struct {
|
||||
CdromImagePath string
|
||||
|
||||
USBDevices []USBDevicePassthroughConfig
|
||||
ExtraPortForwardingRules []PortForwardingRule
|
||||
|
||||
// Debug-related options.
|
||||
DebugUnrestrictedNetworking bool
|
||||
DebugShowDisplay bool
|
||||
}
|
||||
|
||||
func NewVM(logger *slog.Logger, cfg VMConfig) (*VM, error) {
|
||||
cdromImagePath := filepath.Clean(cfg.CdromImagePath)
|
||||
_, err := os.Stat(cdromImagePath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to stat alpine image path")
|
||||
return nil, errors.Wrap(err, "stat cdrom image path")
|
||||
}
|
||||
|
||||
sshPort, err := freeport.GetFreePort()
|
||||
|
|
@ -70,13 +81,13 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
|
|||
|
||||
netdevOpts := "user,id=net0,hostfwd=tcp:127.0.0.1:" + fmt.Sprint(sshPort) + "-:22"
|
||||
|
||||
if !unrestrictedNetworking {
|
||||
if !cfg.DebugUnrestrictedNetworking {
|
||||
netdevOpts += ",restrict=on"
|
||||
} else {
|
||||
logger.Warn("Running with unsafe unrestricted networking")
|
||||
}
|
||||
|
||||
for _, pf := range extraPortForwardings {
|
||||
for _, pf := range cfg.ExtraPortForwardingRules {
|
||||
hostIPStr := ""
|
||||
if pf.HostIP != nil {
|
||||
hostIPStr = pf.HostIP.String()
|
||||
|
|
@ -87,16 +98,16 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
|
|||
|
||||
cmdArgs = append(cmdArgs, "-device", "e1000,netdev=net0", "-netdev", netdevOpts)
|
||||
|
||||
cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(alpineImagePath)+",format=qcow2,if=virtio", "-snapshot")
|
||||
cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(cdromImagePath)+",format=qcow2,if=virtio", "-snapshot")
|
||||
|
||||
if !debug {
|
||||
if !cfg.DebugShowDisplay {
|
||||
cmdArgs = append(cmdArgs, "-display", "none")
|
||||
}
|
||||
|
||||
if len(usbDevices) != 0 {
|
||||
if len(cfg.USBDevices) != 0 {
|
||||
cmdArgs = append(cmdArgs, "-usb", "-device", "nec-usb-xhci,id=xhci")
|
||||
|
||||
for _, dev := range usbDevices {
|
||||
for _, dev := range cfg.USBDevices {
|
||||
cmdArgs = append(cmdArgs, "-device", "usb-host,hostbus="+strconv.FormatUint(uint64(dev.HostBus), 10)+",hostport="+strconv.FormatUint(uint64(dev.HostPort), 10))
|
||||
}
|
||||
}
|
||||
|
|
@ -120,7 +131,7 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
|
|||
|
||||
ctx, ctxCancel := context.WithCancel(context.Background())
|
||||
|
||||
vi := &Instance{
|
||||
vm := &VM{
|
||||
logger: logger,
|
||||
|
||||
ctx: ctx,
|
||||
|
|
@ -137,17 +148,17 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
|
|||
stderrBuf: stderrBuf,
|
||||
}
|
||||
|
||||
vi.resetSerialStdout()
|
||||
vm.resetSerialStdout()
|
||||
|
||||
return vi, nil
|
||||
return vm, nil
|
||||
}
|
||||
|
||||
func (vi *Instance) Run() error {
|
||||
if atomic.AddUint32(&vi.disposed, 1) != 1 {
|
||||
func (vm *VM) Run() error {
|
||||
if atomic.AddUint32(&vm.disposed, 1) != 1 {
|
||||
return fmt.Errorf("vm disposed")
|
||||
}
|
||||
|
||||
err := vi.cmd.Start()
|
||||
err := vm.cmd.Start()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "start qemu cmd")
|
||||
}
|
||||
|
|
@ -159,40 +170,40 @@ func (vi *Instance) Run() error {
|
|||
globalErrsMu.Lock()
|
||||
defer globalErrsMu.Unlock()
|
||||
|
||||
globalErrs = append(globalErrs, err, errors.Wrap(vi.Cancel(), "cancel on error"))
|
||||
globalErrs = append(globalErrs, err, errors.Wrap(vm.Cancel(), "cancel on error"))
|
||||
}
|
||||
|
||||
vi.logger.Info("Booting the VM")
|
||||
vm.logger.Info("Booting the VM")
|
||||
|
||||
go func() {
|
||||
_ = vi.runSerialReader()
|
||||
_ = vi.Cancel()
|
||||
_ = vm.runSerialReader()
|
||||
_ = vm.Cancel()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err = vi.runVMLoginHandler()
|
||||
err = vm.runVMLoginHandler()
|
||||
if err != nil {
|
||||
globalErrFn(errors.Wrap(err, "run vm login handler"))
|
||||
return
|
||||
}
|
||||
|
||||
vi.logger.Info("Setting the VM up")
|
||||
vm.logger.Info("Setting the VM up")
|
||||
|
||||
sshSigner, err := vi.sshSetup()
|
||||
sshSigner, err := vm.sshSetup()
|
||||
if err != nil {
|
||||
globalErrFn(errors.Wrap(err, "set up ssh"))
|
||||
return
|
||||
}
|
||||
|
||||
vi.logger.Debug("Set up SSH server successfully")
|
||||
vm.logger.Debug("Set up SSH server successfully")
|
||||
|
||||
sshKeyScan, err := vi.scanSSHIdentity()
|
||||
sshKeyScan, err := vm.scanSSHIdentity()
|
||||
if err != nil {
|
||||
globalErrFn(errors.Wrap(err, "scan ssh identity"))
|
||||
return
|
||||
}
|
||||
|
||||
vi.logger.Debug("Scanned SSH identity")
|
||||
vm.logger.Debug("Scanned SSH identity")
|
||||
|
||||
knownHosts, err := ParseSSHKeyScan(sshKeyScan)
|
||||
if err != nil {
|
||||
|
|
@ -201,7 +212,7 @@ func (vi *Instance) Run() error {
|
|||
return
|
||||
}
|
||||
|
||||
vi.sshConf = &ssh.ClientConfig{
|
||||
vm.sshConf = &ssh.ClientConfig{
|
||||
User: "root",
|
||||
HostKeyCallback: knownHosts,
|
||||
Auth: []ssh.AuthMethod{
|
||||
|
|
@ -211,111 +222,111 @@ func (vi *Instance) Run() error {
|
|||
}
|
||||
|
||||
// This is to notify everyone waiting for SSH to be up that it's ready to go.
|
||||
close(vi.sshReadyCh)
|
||||
close(vm.sshReadyCh)
|
||||
|
||||
vi.logger.Info("The VM is ready")
|
||||
vm.logger.Info("The VM is ready")
|
||||
}()
|
||||
|
||||
_, err = vi.cmd.Process.Wait()
|
||||
cancelErr := vi.Cancel()
|
||||
_, err = vm.cmd.Process.Wait()
|
||||
cancelErr := vm.Cancel()
|
||||
if err != nil {
|
||||
combinedErr := multierr.Combine(
|
||||
errors.Wrap(err, "wait for cmd to finish execution"),
|
||||
errors.Wrap(cancelErr, "cancel"),
|
||||
)
|
||||
|
||||
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vi.stderrBuf.String()))
|
||||
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vm.stderrBuf.String()))
|
||||
}
|
||||
|
||||
combinedErr := multierr.Combine(
|
||||
append(globalErrs, errors.Wrap(cancelErr, "cancel on exit"))...,
|
||||
)
|
||||
if combinedErr != nil {
|
||||
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vi.stderrBuf.String()))
|
||||
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vm.stderrBuf.String()))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (vi *Instance) Cancel() error {
|
||||
if atomic.AddUint32(&vi.canceled, 1) != 1 {
|
||||
func (vm *VM) Cancel() error {
|
||||
if atomic.AddUint32(&vm.canceled, 1) != 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
vi.logger.Warn("Canceling the VM context")
|
||||
vm.logger.Warn("Canceling the VM context")
|
||||
|
||||
var gracefulOK bool
|
||||
|
||||
sc, err := vi.DialSSH()
|
||||
sc, err := vm.DialSSH()
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrSSHUnavailable) {
|
||||
vi.logger.Warn("Failed to dial VM ssh to do graceful shutdown", "error", err)
|
||||
vm.logger.Warn("Failed to dial VM ssh to do graceful shutdown", "error", err)
|
||||
}
|
||||
} else {
|
||||
_, err = runSSHCmd(sc, "poweroff")
|
||||
_ = sc.Close()
|
||||
if err != nil {
|
||||
vi.logger.Warn("Could not power off the VM safely", "error", err)
|
||||
vm.logger.Warn("Could not power off the VM safely", "error", err)
|
||||
} else {
|
||||
vi.logger.Info("Shutting the VM down safely")
|
||||
vm.logger.Info("Shutting the VM down safely")
|
||||
}
|
||||
}
|
||||
|
||||
var interruptErr error
|
||||
|
||||
if !gracefulOK {
|
||||
interruptErr = vi.cmd.Process.Signal(os.Interrupt)
|
||||
interruptErr = vm.cmd.Process.Signal(os.Interrupt)
|
||||
}
|
||||
|
||||
vi.ctxCancel()
|
||||
vm.ctxCancel()
|
||||
return multierr.Combine(
|
||||
errors.Wrap(interruptErr, "interrupt cmd"),
|
||||
errors.Wrap(vi.serialRead.Close(), "close serial read pipe"),
|
||||
errors.Wrap(vi.serialWrite.Close(), "close serial write pipe"),
|
||||
errors.Wrap(vm.serialRead.Close(), "close serial read pipe"),
|
||||
errors.Wrap(vm.serialWrite.Close(), "close serial write pipe"),
|
||||
)
|
||||
}
|
||||
|
||||
func (vi *Instance) runSerialReader() error {
|
||||
func (vm *VM) runSerialReader() error {
|
||||
for {
|
||||
raw, err := vi.serialReader.ReadBytes('\n')
|
||||
raw, err := vm.serialReader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read from serial reader")
|
||||
}
|
||||
|
||||
select {
|
||||
case vi.serialStdoutCh <- raw:
|
||||
case vm.serialStdoutCh <- raw:
|
||||
default:
|
||||
// Message gets discarded if the buffer is full.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (vi *Instance) writeSerial(b []byte) error {
|
||||
vi.serialWriteMu.Lock()
|
||||
defer vi.serialWriteMu.Unlock()
|
||||
func (vm *VM) writeSerial(b []byte) error {
|
||||
vm.serialWriteMu.Lock()
|
||||
defer vm.serialWriteMu.Unlock()
|
||||
|
||||
_, err := vi.serialWrite.Write(b)
|
||||
_, err := vm.serialWrite.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (vi *Instance) runVMLoginHandler() error {
|
||||
func (vm *VM) runVMLoginHandler() error {
|
||||
for {
|
||||
select {
|
||||
case <-vi.ctx.Done():
|
||||
case <-vm.ctx.Done():
|
||||
return nil
|
||||
case <-time.After(time.Second):
|
||||
peek, err := vi.serialReader.Peek(vi.serialReader.Buffered())
|
||||
peek, err := vm.serialReader.Peek(vm.serialReader.Buffered())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "peek stdout")
|
||||
}
|
||||
|
||||
if bytes.Contains(peek, []byte("login:")) {
|
||||
err = vi.writeSerial([]byte("root\n"))
|
||||
err = vm.writeSerial([]byte("root\n"))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to stdio write login")
|
||||
}
|
||||
|
||||
vi.logger.Debug("Logged into the VM serial")
|
||||
vm.logger.Debug("Logged into the VM serial")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -323,24 +334,24 @@ func (vi *Instance) runVMLoginHandler() error {
|
|||
}
|
||||
}
|
||||
|
||||
func (vi *Instance) resetSerialStdout() {
|
||||
vi.serialStdoutCh = make(chan []byte, 32)
|
||||
func (vm *VM) resetSerialStdout() {
|
||||
vm.serialStdoutCh = make(chan []byte, 32)
|
||||
}
|
||||
|
||||
func (vi *Instance) DialSSH() (*ssh.Client, error) {
|
||||
if vi.sshConf == nil {
|
||||
func (vm *VM) DialSSH() (*ssh.Client, error) {
|
||||
if vm.sshConf == nil {
|
||||
return nil, ErrSSHUnavailable
|
||||
}
|
||||
|
||||
return ssh.Dial("tcp", "localhost:"+fmt.Sprint(vi.sshMappedPort), vi.sshConf)
|
||||
return ssh.Dial("tcp", "localhost:"+fmt.Sprint(vm.sshMappedPort), vm.sshConf)
|
||||
}
|
||||
|
||||
func (vi *Instance) DialSCP() (*scp.Client, error) {
|
||||
if vi.sshConf == nil {
|
||||
func (vm *VM) DialSCP() (*scp.Client, error) {
|
||||
if vm.sshConf == nil {
|
||||
return nil, ErrSSHUnavailable
|
||||
}
|
||||
|
||||
sc := scp.NewClient("localhost:"+fmt.Sprint(vi.sshMappedPort), vi.sshConf)
|
||||
sc := scp.NewClient("localhost:"+fmt.Sprint(vm.sshMappedPort), vm.sshConf)
|
||||
err := sc.Connect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -349,6 +360,6 @@ func (vi *Instance) DialSCP() (*scp.Client, error) {
|
|||
return &sc, nil
|
||||
}
|
||||
|
||||
func (vi *Instance) SSHUpNotifyChan() chan struct{} {
|
||||
return vi.sshReadyCh
|
||||
func (vm *VM) SSHUpNotifyChan() chan struct{} {
|
||||
return vm.sshReadyCh
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue