Some restructuring
This commit is contained in:
parent
34e66cb01c
commit
ee447087f6
8 changed files with 131 additions and 110 deletions
|
|
@ -18,7 +18,7 @@ var lsCmd = &cobra.Command{
|
||||||
// Short: "",
|
// Short: "",
|
||||||
Args: cobra.ExactArgs(1),
|
Args: cobra.ExactArgs(1),
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
os.Exit(runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) int {
|
os.Exit(runVM(args[0], func(ctx context.Context, i *vm.VM, fm *vm.FileManager) int {
|
||||||
lsblkOut, err := fm.Lsblk()
|
lsblkOut, err := fm.Lsblk()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("Failed to list block devices in the VM", "error", err)
|
slog.Error("Failed to list block devices in the VM", "error", err)
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ var runCmd = &cobra.Command{
|
||||||
|
|
||||||
// TODO: `slog` library prints entire stack traces for errors which makes reading errors challenging.
|
// TODO: `slog` library prints entire stack traces for errors which makes reading errors challenging.
|
||||||
|
|
||||||
os.Exit(runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) int {
|
os.Exit(runVM(args[0], func(ctx context.Context, i *vm.VM, fm *vm.FileManager) int {
|
||||||
err := fm.Mount(vmMountDevName, vm.MountOptions{
|
err := fm.Mount(vmMountDevName, vm.MountOptions{
|
||||||
FSType: fsType,
|
FSType: fsType,
|
||||||
LUKS: luksFlag,
|
LUKS: luksFlag,
|
||||||
|
|
@ -59,7 +59,7 @@ var runCmd = &cobra.Command{
|
||||||
|
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
return 0
|
return 0
|
||||||
}, []vm.PortForwardingConfig{{
|
}, []vm.PortForwardingRule{{
|
||||||
HostIP: net.ParseIP("127.0.0.1"), // TODO: Make this changeable.
|
HostIP: net.ParseIP("127.0.0.1"), // TODO: Make this changeable.
|
||||||
HostPort: networkSharePort,
|
HostPort: networkSharePort,
|
||||||
VMPort: 445,
|
VMPort: 445,
|
||||||
|
|
|
||||||
10
cmd/shell.go
10
cmd/shell.go
|
|
@ -23,23 +23,23 @@ var shellCmd = &cobra.Command{
|
||||||
passthroughArg = args[0]
|
passthroughArg = args[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
var forwardPortsConfig []vm.PortForwardingConfig
|
var forwardPortRules []vm.PortForwardingRule
|
||||||
|
|
||||||
for i, fp := range strings.Split(forwardPortsFlagStr, ",") {
|
for i, fp := range strings.Split(forwardPortsFlagStr, ",") {
|
||||||
if fp == "" {
|
if fp == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
fpc, err := vm.ParsePortForwardString(fp)
|
fpr, err := vm.ParsePortForwardString(fp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("Failed to parse port forward string", "index", i, "value", fp, "error", err)
|
slog.Error("Failed to parse port forward string", "index", i, "value", fp, "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
forwardPortsConfig = append(forwardPortsConfig, fpc)
|
forwardPortRules = append(forwardPortRules, fpr)
|
||||||
}
|
}
|
||||||
|
|
||||||
os.Exit(runVM(passthroughArg, func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) int {
|
os.Exit(runVM(passthroughArg, func(ctx context.Context, i *vm.VM, fm *vm.FileManager) int {
|
||||||
sc, err := i.DialSSH()
|
sc, err := i.DialSSH()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("Failed to dial VM SSH", "error", err)
|
slog.Error("Failed to dial VM SSH", "error", err)
|
||||||
|
|
@ -120,7 +120,7 @@ var shellCmd = &cobra.Command{
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
}, forwardPortsConfig, unrestrictedNetworkingFlag))
|
}, forwardPortRules, unrestrictedNetworkingFlag))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
|
|
||||||
14
cmd/utils.go
14
cmd/utils.go
|
|
@ -37,7 +37,7 @@ func doRootCheck() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.FileManager) int, forwardPorts []vm.PortForwardingConfig, unrestrictedNetworking bool) int {
|
func runVM(passthroughArg string, fn func(context.Context, *vm.VM, *vm.FileManager) int, forwardPortsRules []vm.PortForwardingRule, unrestrictedNetworking bool) int {
|
||||||
doRootCheck()
|
doRootCheck()
|
||||||
|
|
||||||
var passthroughConfig []vm.USBDevicePassthroughConfig
|
var passthroughConfig []vm.USBDevicePassthroughConfig
|
||||||
|
|
@ -46,8 +46,18 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil
|
||||||
passthroughConfig = []vm.USBDevicePassthroughConfig{getDevicePassthroughConfig(passthroughArg)}
|
passthroughConfig = []vm.USBDevicePassthroughConfig{getDevicePassthroughConfig(passthroughArg)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
vmCfg := vm.VMConfig{
|
||||||
|
CdromImagePath: "alpine-img/alpine.qcow2",
|
||||||
|
|
||||||
|
USBDevices: passthroughConfig,
|
||||||
|
ExtraPortForwardingRules: forwardPortsRules,
|
||||||
|
|
||||||
|
DebugUnrestrictedNetworking: unrestrictedNetworking,
|
||||||
|
DebugShowDisplay: vmDebugFlag,
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Alpine image should be downloaded from somewhere.
|
// TODO: Alpine image should be downloaded from somewhere.
|
||||||
vi, err := vm.NewInstance(slog.Default().With("caller", "vm"), "alpine-img/alpine.qcow2", passthroughConfig, vmDebugFlag, forwardPorts, unrestrictedNetworking)
|
vi, err := vm.NewVM(slog.Default().With("caller", "vm"), vmCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("Failed to create vm instance", "error", err)
|
slog.Error("Failed to create vm instance", "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|
|
||||||
|
|
@ -22,19 +22,19 @@ import (
|
||||||
type FileManager struct {
|
type FileManager struct {
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
|
|
||||||
vi *Instance
|
vm *VM
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFileManager(logger *slog.Logger, vi *Instance) *FileManager {
|
func NewFileManager(logger *slog.Logger, vm *VM) *FileManager {
|
||||||
return &FileManager{
|
return &FileManager{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
|
||||||
vi: vi,
|
vm: vm,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fm *FileManager) Init() error {
|
func (fm *FileManager) Init() error {
|
||||||
sc, err := fm.vi.DialSSH()
|
sc, err := fm.vm.DialSSH()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "dial vm ssh")
|
return errors.Wrap(err, "dial vm ssh")
|
||||||
}
|
}
|
||||||
|
|
@ -55,7 +55,7 @@ func (fm *FileManager) Init() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fm *FileManager) Lsblk() ([]byte, error) {
|
func (fm *FileManager) Lsblk() ([]byte, error) {
|
||||||
sc, err := fm.vi.DialSSH()
|
sc, err := fm.vm.DialSSH()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "dial vm ssh")
|
return nil, errors.Wrap(err, "dial vm ssh")
|
||||||
}
|
}
|
||||||
|
|
@ -180,7 +180,7 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error {
|
||||||
return fmt.Errorf("fs type is empty")
|
return fmt.Errorf("fs type is empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
sc, err := fm.vi.DialSSH()
|
sc, err := fm.vm.DialSSH()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "dial vm ssh")
|
return errors.Wrap(err, "dial vm ssh")
|
||||||
}
|
}
|
||||||
|
|
@ -205,7 +205,7 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fm *FileManager) StartSMB(pwd []byte) error {
|
func (fm *FileManager) StartSMB(pwd []byte) error {
|
||||||
scpClient, err := fm.vi.DialSCP()
|
scpClient, err := fm.vm.DialSCP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "dial scp")
|
return errors.Wrap(err, "dial scp")
|
||||||
}
|
}
|
||||||
|
|
@ -225,14 +225,14 @@ force user = linsk
|
||||||
force group = linsk
|
force group = linsk
|
||||||
create mask = 0664`
|
create mask = 0664`
|
||||||
|
|
||||||
err = scpClient.CopyFile(fm.vi.ctx, strings.NewReader(sambaCfg), "/etc/samba/smb.conf", "0400")
|
err = scpClient.CopyFile(fm.vm.ctx, strings.NewReader(sambaCfg), "/etc/samba/smb.conf", "0400")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "copy samba config file")
|
return errors.Wrap(err, "copy samba config file")
|
||||||
}
|
}
|
||||||
|
|
||||||
scpClient.Close()
|
scpClient.Close()
|
||||||
|
|
||||||
sc, err := fm.vi.DialSSH()
|
sc, err := fm.vm.DialSSH()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "dial ssh")
|
return errors.Wrap(err, "dial ssh")
|
||||||
}
|
}
|
||||||
|
|
@ -269,11 +269,11 @@ create mask = 0664`
|
||||||
go func() {
|
go func() {
|
||||||
_, err = stdinPipe.Write(pwd)
|
_, err = stdinPipe.Write(pwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fm.vi.logger.Error("Failed to write SMB password to smbpasswd stdin", "error", err)
|
fm.vm.logger.Error("Failed to write SMB password to smbpasswd stdin", "error", err)
|
||||||
}
|
}
|
||||||
_, err = stdinPipe.Write(pwd)
|
_, err = stdinPipe.Write(pwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fm.vi.logger.Error("Failed to write repeated SMB password to smbpasswd stdin", "error", err)
|
fm.vm.logger.Error("Failed to write repeated SMB password to smbpasswd stdin", "error", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|
|
||||||
24
vm/ssh.go
24
vm/ssh.go
|
|
@ -50,10 +50,10 @@ func ParseSSHKeyScan(knownHosts []byte) (ssh.HostKeyCallback, error) {
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vi *Instance) scanSSHIdentity() ([]byte, error) {
|
func (vm *VM) scanSSHIdentity() ([]byte, error) {
|
||||||
vi.resetSerialStdout()
|
vm.resetSerialStdout()
|
||||||
|
|
||||||
err := vi.writeSerial([]byte(`ssh-keyscan -H localhost; echo "SERIAL STATUS: $?"; rm /root/.ash_history` + "\n"))
|
err := vm.writeSerial([]byte(`ssh-keyscan -H localhost; echo "SERIAL STATUS: $?"; rm /root/.ash_history` + "\n"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "write keyscan command to serial")
|
return nil, errors.Wrap(err, "write keyscan command to serial")
|
||||||
}
|
}
|
||||||
|
|
@ -64,11 +64,11 @@ func (vi *Instance) scanSSHIdentity() ([]byte, error) {
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-vi.ctx.Done():
|
case <-vm.ctx.Done():
|
||||||
return nil, vi.ctx.Err()
|
return nil, vm.ctx.Err()
|
||||||
case <-time.After(time.Until(deadline)):
|
case <-time.After(time.Until(deadline)):
|
||||||
return nil, fmt.Errorf("keyscan command timed out")
|
return nil, fmt.Errorf("keyscan command timed out")
|
||||||
case data := <-vi.serialStdoutCh:
|
case data := <-vm.serialStdoutCh:
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
@ -91,8 +91,8 @@ func (vi *Instance) scanSSHIdentity() ([]byte, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vi *Instance) sshSetup() (ssh.Signer, error) {
|
func (vm *VM) sshSetup() (ssh.Signer, error) {
|
||||||
vi.resetSerialStdout()
|
vm.resetSerialStdout()
|
||||||
|
|
||||||
sshSigner, sshPublicKey, err := generateSSHKey()
|
sshSigner, sshPublicKey, err := generateSSHKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -101,7 +101,7 @@ func (vi *Instance) sshSetup() (ssh.Signer, error) {
|
||||||
|
|
||||||
cmd := `set -ex; do_setup () { sh -c "set -ex; ifconfig eth0 up; ifconfig lo up; udhcpc; mkdir -p ~/.ssh; echo ` + shellescape.Quote(string(sshPublicKey)) + ` > ~/.ssh/authorized_keys; rc-update add sshd; rc-service sshd start"; echo "SERIAL"" ""STATUS: $?"; }; do_setup` + "\n"
|
cmd := `set -ex; do_setup () { sh -c "set -ex; ifconfig eth0 up; ifconfig lo up; udhcpc; mkdir -p ~/.ssh; echo ` + shellescape.Quote(string(sshPublicKey)) + ` > ~/.ssh/authorized_keys; rc-update add sshd; rc-service sshd start"; echo "SERIAL"" ""STATUS: $?"; }; do_setup` + "\n"
|
||||||
|
|
||||||
err = vi.writeSerial([]byte(cmd))
|
err = vm.writeSerial([]byte(cmd))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "write ssh setup serial command")
|
return nil, errors.Wrap(err, "write ssh setup serial command")
|
||||||
}
|
}
|
||||||
|
|
@ -112,11 +112,11 @@ func (vi *Instance) sshSetup() (ssh.Signer, error) {
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-vi.ctx.Done():
|
case <-vm.ctx.Done():
|
||||||
return nil, vi.ctx.Err()
|
return nil, vm.ctx.Err()
|
||||||
case <-time.After(time.Until(deadline)):
|
case <-time.After(time.Until(deadline)):
|
||||||
return nil, fmt.Errorf("setup command timed out %v", getLogErrMsg(stdOutErrBuf.String()))
|
return nil, fmt.Errorf("setup command timed out %v", getLogErrMsg(stdOutErrBuf.String()))
|
||||||
case data := <-vi.serialStdoutCh:
|
case data := <-vm.serialStdoutCh:
|
||||||
prefix := []byte("SERIAL STATUS: ")
|
prefix := []byte("SERIAL STATUS: ")
|
||||||
stdOutErrBuf.Write(data)
|
stdOutErrBuf.Write(data)
|
||||||
if bytes.HasPrefix(data, prefix) {
|
if bytes.HasPrefix(data, prefix) {
|
||||||
|
|
|
||||||
20
vm/types.go
20
vm/types.go
|
|
@ -14,28 +14,28 @@ type USBDevicePassthroughConfig struct {
|
||||||
HostPort uint8
|
HostPort uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
type PortForwardingConfig struct {
|
type PortForwardingRule struct {
|
||||||
HostIP net.IP
|
HostIP net.IP
|
||||||
HostPort uint16
|
HostPort uint16
|
||||||
VMPort uint16
|
VMPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParsePortForwardString(s string) (PortForwardingConfig, error) {
|
func ParsePortForwardString(s string) (PortForwardingRule, error) {
|
||||||
split := strings.Split(s, ":")
|
split := strings.Split(s, ":")
|
||||||
switch len(split) {
|
switch len(split) {
|
||||||
case 2:
|
case 2:
|
||||||
// <HOST PORT>:<VM PORT>
|
// <HOST PORT>:<VM PORT>
|
||||||
hostPort, err := strconv.ParseUint(split[0], 10, 16)
|
hostPort, err := strconv.ParseUint(split[0], 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return PortForwardingConfig{}, errors.Wrap(err, "parse host port")
|
return PortForwardingRule{}, errors.Wrap(err, "parse host port")
|
||||||
}
|
}
|
||||||
|
|
||||||
vmPort, err := strconv.ParseUint(split[1], 10, 16)
|
vmPort, err := strconv.ParseUint(split[1], 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return PortForwardingConfig{}, errors.Wrap(err, "parse vm port")
|
return PortForwardingRule{}, errors.Wrap(err, "parse vm port")
|
||||||
}
|
}
|
||||||
|
|
||||||
return PortForwardingConfig{
|
return PortForwardingRule{
|
||||||
HostPort: uint16(hostPort),
|
HostPort: uint16(hostPort),
|
||||||
VMPort: uint16(vmPort),
|
VMPort: uint16(vmPort),
|
||||||
}, nil
|
}, nil
|
||||||
|
|
@ -43,25 +43,25 @@ func ParsePortForwardString(s string) (PortForwardingConfig, error) {
|
||||||
// <HOST IP>:<HOST PORT>:<VM PORT>
|
// <HOST IP>:<HOST PORT>:<VM PORT>
|
||||||
hostIP := net.ParseIP(split[0])
|
hostIP := net.ParseIP(split[0])
|
||||||
if hostIP == nil {
|
if hostIP == nil {
|
||||||
return PortForwardingConfig{}, fmt.Errorf("bad host ip")
|
return PortForwardingRule{}, fmt.Errorf("bad host ip")
|
||||||
}
|
}
|
||||||
|
|
||||||
hostPort, err := strconv.ParseUint(split[1], 10, 16)
|
hostPort, err := strconv.ParseUint(split[1], 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return PortForwardingConfig{}, errors.Wrap(err, "parse host port")
|
return PortForwardingRule{}, errors.Wrap(err, "parse host port")
|
||||||
}
|
}
|
||||||
|
|
||||||
vmPort, err := strconv.ParseUint(split[2], 10, 16)
|
vmPort, err := strconv.ParseUint(split[2], 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return PortForwardingConfig{}, errors.Wrap(err, "parse vm port")
|
return PortForwardingRule{}, errors.Wrap(err, "parse vm port")
|
||||||
}
|
}
|
||||||
|
|
||||||
return PortForwardingConfig{
|
return PortForwardingRule{
|
||||||
HostIP: hostIP,
|
HostIP: hostIP,
|
||||||
HostPort: uint16(hostPort),
|
HostPort: uint16(hostPort),
|
||||||
VMPort: uint16(vmPort),
|
VMPort: uint16(vmPort),
|
||||||
}, nil
|
}, nil
|
||||||
default:
|
default:
|
||||||
return PortForwardingConfig{}, fmt.Errorf("bad split by ':' length: want 2 or 3, have %v", len(split))
|
return PortForwardingRule{}, fmt.Errorf("bad split by ':' length: want 2 or 3, have %v", len(split))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
145
vm/vm.go
145
vm/vm.go
|
|
@ -26,7 +26,7 @@ import (
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Instance struct {
|
type VM struct {
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
|
@ -51,11 +51,22 @@ type Instance struct {
|
||||||
canceled uint32
|
canceled uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDevicePassthroughConfig, debug bool, extraPortForwardings []PortForwardingConfig, unrestrictedNetworking bool) (*Instance, error) {
|
type VMConfig struct {
|
||||||
alpineImagePath = filepath.Clean(alpineImagePath)
|
CdromImagePath string
|
||||||
_, err := os.Stat(alpineImagePath)
|
|
||||||
|
USBDevices []USBDevicePassthroughConfig
|
||||||
|
ExtraPortForwardingRules []PortForwardingRule
|
||||||
|
|
||||||
|
// Debug-related options.
|
||||||
|
DebugUnrestrictedNetworking bool
|
||||||
|
DebugShowDisplay bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewVM(logger *slog.Logger, cfg VMConfig) (*VM, error) {
|
||||||
|
cdromImagePath := filepath.Clean(cfg.CdromImagePath)
|
||||||
|
_, err := os.Stat(cdromImagePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "failed to stat alpine image path")
|
return nil, errors.Wrap(err, "stat cdrom image path")
|
||||||
}
|
}
|
||||||
|
|
||||||
sshPort, err := freeport.GetFreePort()
|
sshPort, err := freeport.GetFreePort()
|
||||||
|
|
@ -70,13 +81,13 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
|
||||||
|
|
||||||
netdevOpts := "user,id=net0,hostfwd=tcp:127.0.0.1:" + fmt.Sprint(sshPort) + "-:22"
|
netdevOpts := "user,id=net0,hostfwd=tcp:127.0.0.1:" + fmt.Sprint(sshPort) + "-:22"
|
||||||
|
|
||||||
if !unrestrictedNetworking {
|
if !cfg.DebugUnrestrictedNetworking {
|
||||||
netdevOpts += ",restrict=on"
|
netdevOpts += ",restrict=on"
|
||||||
} else {
|
} else {
|
||||||
logger.Warn("Running with unsafe unrestricted networking")
|
logger.Warn("Running with unsafe unrestricted networking")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, pf := range extraPortForwardings {
|
for _, pf := range cfg.ExtraPortForwardingRules {
|
||||||
hostIPStr := ""
|
hostIPStr := ""
|
||||||
if pf.HostIP != nil {
|
if pf.HostIP != nil {
|
||||||
hostIPStr = pf.HostIP.String()
|
hostIPStr = pf.HostIP.String()
|
||||||
|
|
@ -87,16 +98,16 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
|
||||||
|
|
||||||
cmdArgs = append(cmdArgs, "-device", "e1000,netdev=net0", "-netdev", netdevOpts)
|
cmdArgs = append(cmdArgs, "-device", "e1000,netdev=net0", "-netdev", netdevOpts)
|
||||||
|
|
||||||
cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(alpineImagePath)+",format=qcow2,if=virtio", "-snapshot")
|
cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(cdromImagePath)+",format=qcow2,if=virtio", "-snapshot")
|
||||||
|
|
||||||
if !debug {
|
if !cfg.DebugShowDisplay {
|
||||||
cmdArgs = append(cmdArgs, "-display", "none")
|
cmdArgs = append(cmdArgs, "-display", "none")
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(usbDevices) != 0 {
|
if len(cfg.USBDevices) != 0 {
|
||||||
cmdArgs = append(cmdArgs, "-usb", "-device", "nec-usb-xhci,id=xhci")
|
cmdArgs = append(cmdArgs, "-usb", "-device", "nec-usb-xhci,id=xhci")
|
||||||
|
|
||||||
for _, dev := range usbDevices {
|
for _, dev := range cfg.USBDevices {
|
||||||
cmdArgs = append(cmdArgs, "-device", "usb-host,hostbus="+strconv.FormatUint(uint64(dev.HostBus), 10)+",hostport="+strconv.FormatUint(uint64(dev.HostPort), 10))
|
cmdArgs = append(cmdArgs, "-device", "usb-host,hostbus="+strconv.FormatUint(uint64(dev.HostBus), 10)+",hostport="+strconv.FormatUint(uint64(dev.HostPort), 10))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -120,7 +131,7 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
|
||||||
|
|
||||||
ctx, ctxCancel := context.WithCancel(context.Background())
|
ctx, ctxCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
vi := &Instance{
|
vm := &VM{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
|
@ -137,17 +148,17 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
|
||||||
stderrBuf: stderrBuf,
|
stderrBuf: stderrBuf,
|
||||||
}
|
}
|
||||||
|
|
||||||
vi.resetSerialStdout()
|
vm.resetSerialStdout()
|
||||||
|
|
||||||
return vi, nil
|
return vm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vi *Instance) Run() error {
|
func (vm *VM) Run() error {
|
||||||
if atomic.AddUint32(&vi.disposed, 1) != 1 {
|
if atomic.AddUint32(&vm.disposed, 1) != 1 {
|
||||||
return fmt.Errorf("vm disposed")
|
return fmt.Errorf("vm disposed")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := vi.cmd.Start()
|
err := vm.cmd.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "start qemu cmd")
|
return errors.Wrap(err, "start qemu cmd")
|
||||||
}
|
}
|
||||||
|
|
@ -159,40 +170,40 @@ func (vi *Instance) Run() error {
|
||||||
globalErrsMu.Lock()
|
globalErrsMu.Lock()
|
||||||
defer globalErrsMu.Unlock()
|
defer globalErrsMu.Unlock()
|
||||||
|
|
||||||
globalErrs = append(globalErrs, err, errors.Wrap(vi.Cancel(), "cancel on error"))
|
globalErrs = append(globalErrs, err, errors.Wrap(vm.Cancel(), "cancel on error"))
|
||||||
}
|
}
|
||||||
|
|
||||||
vi.logger.Info("Booting the VM")
|
vm.logger.Info("Booting the VM")
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_ = vi.runSerialReader()
|
_ = vm.runSerialReader()
|
||||||
_ = vi.Cancel()
|
_ = vm.Cancel()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err = vi.runVMLoginHandler()
|
err = vm.runVMLoginHandler()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
globalErrFn(errors.Wrap(err, "run vm login handler"))
|
globalErrFn(errors.Wrap(err, "run vm login handler"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vi.logger.Info("Setting the VM up")
|
vm.logger.Info("Setting the VM up")
|
||||||
|
|
||||||
sshSigner, err := vi.sshSetup()
|
sshSigner, err := vm.sshSetup()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
globalErrFn(errors.Wrap(err, "set up ssh"))
|
globalErrFn(errors.Wrap(err, "set up ssh"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vi.logger.Debug("Set up SSH server successfully")
|
vm.logger.Debug("Set up SSH server successfully")
|
||||||
|
|
||||||
sshKeyScan, err := vi.scanSSHIdentity()
|
sshKeyScan, err := vm.scanSSHIdentity()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
globalErrFn(errors.Wrap(err, "scan ssh identity"))
|
globalErrFn(errors.Wrap(err, "scan ssh identity"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vi.logger.Debug("Scanned SSH identity")
|
vm.logger.Debug("Scanned SSH identity")
|
||||||
|
|
||||||
knownHosts, err := ParseSSHKeyScan(sshKeyScan)
|
knownHosts, err := ParseSSHKeyScan(sshKeyScan)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -201,7 +212,7 @@ func (vi *Instance) Run() error {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
vi.sshConf = &ssh.ClientConfig{
|
vm.sshConf = &ssh.ClientConfig{
|
||||||
User: "root",
|
User: "root",
|
||||||
HostKeyCallback: knownHosts,
|
HostKeyCallback: knownHosts,
|
||||||
Auth: []ssh.AuthMethod{
|
Auth: []ssh.AuthMethod{
|
||||||
|
|
@ -211,111 +222,111 @@ func (vi *Instance) Run() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// This is to notify everyone waiting for SSH to be up that it's ready to go.
|
// This is to notify everyone waiting for SSH to be up that it's ready to go.
|
||||||
close(vi.sshReadyCh)
|
close(vm.sshReadyCh)
|
||||||
|
|
||||||
vi.logger.Info("The VM is ready")
|
vm.logger.Info("The VM is ready")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err = vi.cmd.Process.Wait()
|
_, err = vm.cmd.Process.Wait()
|
||||||
cancelErr := vi.Cancel()
|
cancelErr := vm.Cancel()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
combinedErr := multierr.Combine(
|
combinedErr := multierr.Combine(
|
||||||
errors.Wrap(err, "wait for cmd to finish execution"),
|
errors.Wrap(err, "wait for cmd to finish execution"),
|
||||||
errors.Wrap(cancelErr, "cancel"),
|
errors.Wrap(cancelErr, "cancel"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vi.stderrBuf.String()))
|
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vm.stderrBuf.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
combinedErr := multierr.Combine(
|
combinedErr := multierr.Combine(
|
||||||
append(globalErrs, errors.Wrap(cancelErr, "cancel on exit"))...,
|
append(globalErrs, errors.Wrap(cancelErr, "cancel on exit"))...,
|
||||||
)
|
)
|
||||||
if combinedErr != nil {
|
if combinedErr != nil {
|
||||||
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vi.stderrBuf.String()))
|
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vm.stderrBuf.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vi *Instance) Cancel() error {
|
func (vm *VM) Cancel() error {
|
||||||
if atomic.AddUint32(&vi.canceled, 1) != 1 {
|
if atomic.AddUint32(&vm.canceled, 1) != 1 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
vi.logger.Warn("Canceling the VM context")
|
vm.logger.Warn("Canceling the VM context")
|
||||||
|
|
||||||
var gracefulOK bool
|
var gracefulOK bool
|
||||||
|
|
||||||
sc, err := vi.DialSSH()
|
sc, err := vm.DialSSH()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, ErrSSHUnavailable) {
|
if !errors.Is(err, ErrSSHUnavailable) {
|
||||||
vi.logger.Warn("Failed to dial VM ssh to do graceful shutdown", "error", err)
|
vm.logger.Warn("Failed to dial VM ssh to do graceful shutdown", "error", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_, err = runSSHCmd(sc, "poweroff")
|
_, err = runSSHCmd(sc, "poweroff")
|
||||||
_ = sc.Close()
|
_ = sc.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
vi.logger.Warn("Could not power off the VM safely", "error", err)
|
vm.logger.Warn("Could not power off the VM safely", "error", err)
|
||||||
} else {
|
} else {
|
||||||
vi.logger.Info("Shutting the VM down safely")
|
vm.logger.Info("Shutting the VM down safely")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var interruptErr error
|
var interruptErr error
|
||||||
|
|
||||||
if !gracefulOK {
|
if !gracefulOK {
|
||||||
interruptErr = vi.cmd.Process.Signal(os.Interrupt)
|
interruptErr = vm.cmd.Process.Signal(os.Interrupt)
|
||||||
}
|
}
|
||||||
|
|
||||||
vi.ctxCancel()
|
vm.ctxCancel()
|
||||||
return multierr.Combine(
|
return multierr.Combine(
|
||||||
errors.Wrap(interruptErr, "interrupt cmd"),
|
errors.Wrap(interruptErr, "interrupt cmd"),
|
||||||
errors.Wrap(vi.serialRead.Close(), "close serial read pipe"),
|
errors.Wrap(vm.serialRead.Close(), "close serial read pipe"),
|
||||||
errors.Wrap(vi.serialWrite.Close(), "close serial write pipe"),
|
errors.Wrap(vm.serialWrite.Close(), "close serial write pipe"),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vi *Instance) runSerialReader() error {
|
func (vm *VM) runSerialReader() error {
|
||||||
for {
|
for {
|
||||||
raw, err := vi.serialReader.ReadBytes('\n')
|
raw, err := vm.serialReader.ReadBytes('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "read from serial reader")
|
return errors.Wrap(err, "read from serial reader")
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case vi.serialStdoutCh <- raw:
|
case vm.serialStdoutCh <- raw:
|
||||||
default:
|
default:
|
||||||
// Message gets discarded if the buffer is full.
|
// Message gets discarded if the buffer is full.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vi *Instance) writeSerial(b []byte) error {
|
func (vm *VM) writeSerial(b []byte) error {
|
||||||
vi.serialWriteMu.Lock()
|
vm.serialWriteMu.Lock()
|
||||||
defer vi.serialWriteMu.Unlock()
|
defer vm.serialWriteMu.Unlock()
|
||||||
|
|
||||||
_, err := vi.serialWrite.Write(b)
|
_, err := vm.serialWrite.Write(b)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vi *Instance) runVMLoginHandler() error {
|
func (vm *VM) runVMLoginHandler() error {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-vi.ctx.Done():
|
case <-vm.ctx.Done():
|
||||||
return nil
|
return nil
|
||||||
case <-time.After(time.Second):
|
case <-time.After(time.Second):
|
||||||
peek, err := vi.serialReader.Peek(vi.serialReader.Buffered())
|
peek, err := vm.serialReader.Peek(vm.serialReader.Buffered())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "peek stdout")
|
return errors.Wrap(err, "peek stdout")
|
||||||
}
|
}
|
||||||
|
|
||||||
if bytes.Contains(peek, []byte("login:")) {
|
if bytes.Contains(peek, []byte("login:")) {
|
||||||
err = vi.writeSerial([]byte("root\n"))
|
err = vm.writeSerial([]byte("root\n"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "failed to stdio write login")
|
return errors.Wrap(err, "failed to stdio write login")
|
||||||
}
|
}
|
||||||
|
|
||||||
vi.logger.Debug("Logged into the VM serial")
|
vm.logger.Debug("Logged into the VM serial")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -323,24 +334,24 @@ func (vi *Instance) runVMLoginHandler() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vi *Instance) resetSerialStdout() {
|
func (vm *VM) resetSerialStdout() {
|
||||||
vi.serialStdoutCh = make(chan []byte, 32)
|
vm.serialStdoutCh = make(chan []byte, 32)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vi *Instance) DialSSH() (*ssh.Client, error) {
|
func (vm *VM) DialSSH() (*ssh.Client, error) {
|
||||||
if vi.sshConf == nil {
|
if vm.sshConf == nil {
|
||||||
return nil, ErrSSHUnavailable
|
return nil, ErrSSHUnavailable
|
||||||
}
|
}
|
||||||
|
|
||||||
return ssh.Dial("tcp", "localhost:"+fmt.Sprint(vi.sshMappedPort), vi.sshConf)
|
return ssh.Dial("tcp", "localhost:"+fmt.Sprint(vm.sshMappedPort), vm.sshConf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vi *Instance) DialSCP() (*scp.Client, error) {
|
func (vm *VM) DialSCP() (*scp.Client, error) {
|
||||||
if vi.sshConf == nil {
|
if vm.sshConf == nil {
|
||||||
return nil, ErrSSHUnavailable
|
return nil, ErrSSHUnavailable
|
||||||
}
|
}
|
||||||
|
|
||||||
sc := scp.NewClient("localhost:"+fmt.Sprint(vi.sshMappedPort), vi.sshConf)
|
sc := scp.NewClient("localhost:"+fmt.Sprint(vm.sshMappedPort), vm.sshConf)
|
||||||
err := sc.Connect()
|
err := sc.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -349,6 +360,6 @@ func (vi *Instance) DialSCP() (*scp.Client, error) {
|
||||||
return &sc, nil
|
return &sc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (vi *Instance) SSHUpNotifyChan() chan struct{} {
|
func (vm *VM) SSHUpNotifyChan() chan struct{} {
|
||||||
return vi.sshReadyCh
|
return vm.sshReadyCh
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue