From 4369fb82dd56bbddf55069605a0444111eda09c8 Mon Sep 17 00:00:00 2001 From: AlexSSD7 Date: Fri, 25 Aug 2023 19:55:11 +0100 Subject: [PATCH] LUKS support + other things --- alpine-img/img/packages | 3 +- cmd/ls.go | 6 +- cmd/run.go | 13 ++++- cmd/utils.go | 47 +++++++++++++--- go.mod | 1 + vm/filemanager.go | 118 +++++++++++++++++++++++++++++++++++++--- vm/ssh.go | 2 + vm/vm.go | 33 ++++++++++- 8 files changed, 201 insertions(+), 22 deletions(-) diff --git a/alpine-img/img/packages b/alpine-img/img/packages index f42cfd1..afbc8f3 100644 --- a/alpine-img/img/packages +++ b/alpine-img/img/packages @@ -1,3 +1,4 @@ openssh lvm2 -util-linux \ No newline at end of file +util-linux +cryptsetup \ No newline at end of file diff --git a/cmd/ls.go b/cmd/ls.go index 8f147bc..b6046ab 100644 --- a/cmd/ls.go +++ b/cmd/ls.go @@ -21,7 +21,7 @@ var lsCmd = &cobra.Command{ runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) { lsblkOut, err := fm.Lsblk() if err != nil { - slog.Error("Failed to list block devices in the VM", "error", err.Error()) + slog.Error("Failed to list block devices in the VM", "error", err) os.Exit(1) } @@ -35,7 +35,7 @@ var lsCmd = &cobra.Command{ func getDevicePassthroughConfig(val string) vm.USBDevicePassthroughConfig { valSplit := strings.Split(val, ":") if want, have := 2, len(valSplit); want != have { - slog.Error("Bad device passthrough syntax", "error", fmt.Errorf("wrong items split by ':' count: want %v, have %v", want, have).Error()) + slog.Error("Bad device passthrough syntax", "error", fmt.Errorf("wrong items split by ':' count: want %v, have %v", want, have)) os.Exit(1) } @@ -43,7 +43,7 @@ func getDevicePassthroughConfig(val string) vm.USBDevicePassthroughConfig { case "usb": usbValsSplit := strings.Split(valSplit[1], ",") if want, have := 2, len(usbValsSplit); want != have { - slog.Error("Bad USB device passthrough syntax", "error", fmt.Errorf("wrong args split by ',' count: want %v, have %v", want, have).Error()) + slog.Error("Bad USB device passthrough syntax", "error", fmt.Errorf("wrong args split by ',' count: want %v, have %v", want, have)) os.Exit(1) } diff --git a/cmd/run.go b/cmd/run.go index 6981d5d..e581e8b 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -18,8 +18,13 @@ var runCmd = &cobra.Command{ vmMountDevName := args[1] fsType := args[2] + // TODO: `slog` library prints entire stack traces for errors which makes reading errors challenging. + runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) { - err := fm.Mount(vmMountDevName, vm.MountOptions{FSType: fsType}) + err := fm.Mount(vmMountDevName, vm.MountOptions{ + FSType: fsType, + LUKS: luksFlag, + }) if err != nil { slog.Error("Failed to mount the disk inside the VM", "error", err) return @@ -32,3 +37,9 @@ var runCmd = &cobra.Command{ return nil }, } + +var luksFlag bool + +func init() { + runCmd.Flags().BoolVarP(&luksFlag, "luks", "l", false, "Use cryptsetup to open a LUKS volume (password will be prompted)") +} diff --git a/cmd/utils.go b/cmd/utils.go index 2845be4..0fe8996 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -3,8 +3,10 @@ package cmd import ( "context" "os" + "os/signal" "os/user" "sync" + "syscall" "log/slog" @@ -23,7 +25,7 @@ func checkIfRoot() (bool, error) { func doRootCheck() { ok, err := checkIfRoot() if err != nil { - slog.Error("Failed to check whether the command is ran by root", "error", err.Error()) + slog.Error("Failed to check whether the command is ran by root", "error", err) os.Exit(1) } @@ -39,9 +41,9 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil passthroughConfig := getDevicePassthroughConfig(passthroughArg) // TODO: Alpine image should be downloaded from somewhere. - vi, err := vm.NewInstance(slog.Default(), "alpine-img/alpine.qcow2", []vm.USBDevicePassthroughConfig{passthroughConfig}, true) + vi, err := vm.NewInstance(slog.Default().With("caller", "vm"), "alpine-img/alpine.qcow2", []vm.USBDevicePassthroughConfig{passthroughConfig}, true) if err != nil { - slog.Error("Failed to create vm instance", "error", err.Error()) + slog.Error("Failed to create vm instance", "error", err) os.Exit(1) } @@ -49,6 +51,10 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil var wg sync.WaitGroup ctx, ctxCancel := context.WithCancel(context.Background()) + defer ctxCancel() + + interrupt := make(chan os.Signal, 2) + signal.Notify(interrupt, syscall.SIGTERM, syscall.SIGINT) wg.Add(1) go func() { @@ -59,17 +65,42 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil runErrCh <- err }() - fm := vm.NewFileManager(vi) + go func() { + for i := 0; ; i++ { + select { + case <-ctx.Done(): + signal.Reset() + return + case sig := <-interrupt: + lg := slog.With("signal", sig) + + if i == 0 { + lg.Warn("Caught interrupt, safely shutting down") + } else if i < 10 { + lg.Warn("Caught subsequent interrupt, please interrupt n more times to panic", "n", 10-i) + } else { + panic("force interrupt") + } + + err := vi.Cancel() + if err != nil { + lg.Warn("Failed to cancel VM context", "error", err) + } + } + } + }() + + fm := vm.NewFileManager(slog.Default().With("caller", "file-manager"), vi) for { select { case err := <-runErrCh: - slog.Error("Failed to start the VM", "error", err.Error()) + slog.Error("Failed to start the VM", "error", err) os.Exit(1) case <-vi.SSHUpNotifyChan(): err := fm.Init() if err != nil { - slog.Error("Failed to initialize File Manager", "error", err.Error()) + slog.Error("Failed to initialize File Manager", "error", err) os.Exit(1) } @@ -77,7 +108,7 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil err = vi.Cancel() if err != nil { - slog.Error("Failed to cancel VM context", "error", err.Error()) + slog.Error("Failed to cancel VM context", "error", err) os.Exit(1) } @@ -86,7 +117,7 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil select { case err := <-runErrCh: if err != nil { - slog.Error("Failed to run the VM", "error", err.Error()) + slog.Error("Failed to run the VM", "error", err) os.Exit(1) } default: diff --git a/go.mod b/go.mod index 3608c05..d0e6332 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/spf13/cobra v1.7.0 go.uber.org/multierr v1.11.0 golang.org/x/crypto v0.12.0 + golang.org/x/term v0.11.0 ) require ( diff --git a/vm/filemanager.go b/vm/filemanager.go index 3a6f395..8f23dca 100644 --- a/vm/filemanager.go +++ b/vm/filemanager.go @@ -2,36 +2,50 @@ package vm import ( "bytes" + "crypto/rand" "fmt" + "log/slog" + "os" "path/filepath" + "sync" + "syscall" "github.com/AlexSSD7/vldisk/utils" "github.com/alessio/shellescape" "github.com/pkg/errors" + "go.uber.org/multierr" + "golang.org/x/crypto/ssh" + "golang.org/x/term" ) type FileManager struct { + logger *slog.Logger + vi *Instance } -func NewFileManager(vi *Instance) *FileManager { +func NewFileManager(logger *slog.Logger, vi *Instance) *FileManager { return &FileManager{ + logger: logger, + vi: vi, } } func (fm *FileManager) Init() error { - c, err := fm.vi.DialSSH() + sc, err := fm.vi.DialSSH() if err != nil { return errors.Wrap(err, "dial vm ssh") } - _, err = runSSHCmd(c, "apk add util-linux lvm2") + defer func() { _ = sc.Close() }() + + _, err = runSSHCmd(sc, "apk add util-linux lvm2") if err != nil { return errors.Wrap(err, "install utilities") } - _, err = runSSHCmd(c, "vgchange -ay") + _, err = runSSHCmd(sc, "vgchange -ay") if err != nil { return errors.Wrap(err, "run vgchange cmd") } @@ -45,11 +59,15 @@ func (fm *FileManager) Lsblk() ([]byte, error) { return nil, errors.Wrap(err, "dial vm ssh") } + defer func() { _ = sc.Close() }() + sess, err := sc.NewSession() if err != nil { return nil, errors.Wrap(err, "create new vm ssh session") } + defer func() { _ = sess.Close() }() + ret := new(bytes.Buffer) sess.Stdout = ret @@ -64,6 +82,84 @@ func (fm *FileManager) Lsblk() ([]byte, error) { type MountOptions struct { FSType string + LUKS bool +} + +const luksDMName = "cryptmnt" + +func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error { + lg := fm.logger.With("vm-path", fullDevPath) + + sess, err := sc.NewSession() + if err != nil { + return errors.Wrap(err, "create new vm ssh session") + } + + defer func() { sess.Close() }() + + stdinPipe, err := sess.StdinPipe() + if err != nil { + return errors.Wrap(err, "create vm ssh session stdin pipe") + } + + stderrBuf := bytes.NewBuffer(nil) + sess.Stderr = stderrBuf + + err = sess.Start("cryptsetup luksOpen " + shellescape.Quote(fullDevPath) + " " + luksDMName) + if err != nil { + return errors.Wrap(err, "start cryptsetup luksopen cmd") + } + + lg.Info("Attempting to open LUKS device") + + _, err = os.Stderr.Write([]byte("Enter Password: ")) + if err != nil { + return errors.Wrap(err, "write prompt to stderr") + } + + pwd, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil { + return errors.Wrap(err, "read luks password") + } + + fmt.Print("\n") + + var wErr error + var wWG sync.WaitGroup + + wWG.Add(1) + go func() { + defer wWG.Done() + + _, err := stdinPipe.Write(pwd) + _, err2 := stdinPipe.Write([]byte("\n")) + wErr = errors.Wrap(multierr.Combine(err, err2), "write password to stdin") + }() + + // TODO: Timeout for this command + + err = sess.Wait() + if err != nil { + return wrapErrWithLog(err, "wait for cryptsetup luksopen cmd to finish", stderrBuf.String()) + } + + lg.Info("LUKS device opened successfully") + + // Clear the memory up + { + for i := 0; i < len(pwd); i++ { + pwd[i] = 0 + } + + for i := 0; i < 16; i++ { + _, _ = rand.Read(pwd) + } + } + + _ = stdinPipe.Close() + wWG.Wait() + + return wErr } func (fm *FileManager) Mount(devName string, mo MountOptions) error { @@ -88,12 +184,18 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error { return errors.Wrap(err, "dial vm ssh") } - sess, err := sc.NewSession() - if err != nil { - return errors.Wrap(err, "create new vm ssh session") + defer func() { _ = sc.Close() }() + + if mo.LUKS { + err = fm.luksOpen(sc, fullDevPath) + if err != nil { + return errors.Wrap(err, "luks open") + } + + fullDevPath = "/dev/mapper/" + luksDMName } - err = sess.Run("mount -t " + shellescape.Quote(mo.FSType) + " " + shellescape.Quote(fullDevPath) + " /mnt") + _, err = runSSHCmd(sc, "mount -t "+shellescape.Quote(mo.FSType)+" "+shellescape.Quote(fullDevPath)+" /mnt") if err != nil { return errors.Wrap(err, "run mount cmd") } diff --git a/vm/ssh.go b/vm/ssh.go index d116728..df1b552 100644 --- a/vm/ssh.go +++ b/vm/ssh.go @@ -153,6 +153,8 @@ func runSSHCmd(c *ssh.Client, cmd string) ([]byte, error) { return nil, errors.Wrap(err, "create new vm ssh session") } + defer func() { _ = sess.Close() }() + stdout := bytes.NewBuffer(nil) stderr := bytes.NewBuffer(nil) diff --git a/vm/vm.go b/vm/vm.go index 09d3290..56cb8af 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -13,6 +13,7 @@ import ( "strconv" "sync" "sync/atomic" + "syscall" "time" "log/slog" @@ -98,6 +99,11 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe stderrBuf := bytes.NewBuffer(nil) cmd.Stderr = stderrBuf + // This is to prevent Ctrl+C propagating to the child process. + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + } + userReader := bufio.NewReader(userRead) ctx, ctxCancel := context.WithCancel(context.Background()) @@ -224,9 +230,34 @@ func (vi *Instance) Cancel() error { return nil } + vi.logger.Warn("Canceling the VM context") + + var gracefulOK bool + + sc, err := vi.DialSSH() + if err != nil { + if !errors.Is(err, ErrSSHUnavailable) { + vi.logger.Warn("Failed to dial VM ssh to do graceful shutdown", "error", err) + } + } else { + _, err = runSSHCmd(sc, "poweroff") + _ = sc.Close() + if err != nil { + vi.logger.Warn("Could not power off the VM safely", "error", err) + } else { + vi.logger.Info("Shutting the VM down safely") + } + } + + var interruptErr error + + if !gracefulOK { + interruptErr = vi.cmd.Process.Signal(os.Interrupt) + } + vi.ctxCancel() return multierr.Combine( - errors.Wrap(vi.cmd.Process.Signal(os.Interrupt), "cancel cmd"), + errors.Wrap(interruptErr, "interrupt cmd"), errors.Wrap(vi.serialRead.Close(), "close serial read pipe"), errors.Wrap(vi.serialWrite.Close(), "close serial write pipe"), )