LUKS support + other things

This commit is contained in:
AlexSSD7 2023-08-25 19:55:11 +01:00
commit 4369fb82dd
8 changed files with 201 additions and 22 deletions

View file

@ -1,3 +1,4 @@
openssh
lvm2
util-linux
util-linux
cryptsetup

View file

@ -21,7 +21,7 @@ var lsCmd = &cobra.Command{
runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) {
lsblkOut, err := fm.Lsblk()
if err != nil {
slog.Error("Failed to list block devices in the VM", "error", err.Error())
slog.Error("Failed to list block devices in the VM", "error", err)
os.Exit(1)
}
@ -35,7 +35,7 @@ var lsCmd = &cobra.Command{
func getDevicePassthroughConfig(val string) vm.USBDevicePassthroughConfig {
valSplit := strings.Split(val, ":")
if want, have := 2, len(valSplit); want != have {
slog.Error("Bad device passthrough syntax", "error", fmt.Errorf("wrong items split by ':' count: want %v, have %v", want, have).Error())
slog.Error("Bad device passthrough syntax", "error", fmt.Errorf("wrong items split by ':' count: want %v, have %v", want, have))
os.Exit(1)
}
@ -43,7 +43,7 @@ func getDevicePassthroughConfig(val string) vm.USBDevicePassthroughConfig {
case "usb":
usbValsSplit := strings.Split(valSplit[1], ",")
if want, have := 2, len(usbValsSplit); want != have {
slog.Error("Bad USB device passthrough syntax", "error", fmt.Errorf("wrong args split by ',' count: want %v, have %v", want, have).Error())
slog.Error("Bad USB device passthrough syntax", "error", fmt.Errorf("wrong args split by ',' count: want %v, have %v", want, have))
os.Exit(1)
}

View file

@ -18,8 +18,13 @@ var runCmd = &cobra.Command{
vmMountDevName := args[1]
fsType := args[2]
// TODO: `slog` library prints entire stack traces for errors which makes reading errors challenging.
runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) {
err := fm.Mount(vmMountDevName, vm.MountOptions{FSType: fsType})
err := fm.Mount(vmMountDevName, vm.MountOptions{
FSType: fsType,
LUKS: luksFlag,
})
if err != nil {
slog.Error("Failed to mount the disk inside the VM", "error", err)
return
@ -32,3 +37,9 @@ var runCmd = &cobra.Command{
return nil
},
}
var luksFlag bool
func init() {
runCmd.Flags().BoolVarP(&luksFlag, "luks", "l", false, "Use cryptsetup to open a LUKS volume (password will be prompted)")
}

View file

@ -3,8 +3,10 @@ package cmd
import (
"context"
"os"
"os/signal"
"os/user"
"sync"
"syscall"
"log/slog"
@ -23,7 +25,7 @@ func checkIfRoot() (bool, error) {
func doRootCheck() {
ok, err := checkIfRoot()
if err != nil {
slog.Error("Failed to check whether the command is ran by root", "error", err.Error())
slog.Error("Failed to check whether the command is ran by root", "error", err)
os.Exit(1)
}
@ -39,9 +41,9 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil
passthroughConfig := getDevicePassthroughConfig(passthroughArg)
// TODO: Alpine image should be downloaded from somewhere.
vi, err := vm.NewInstance(slog.Default(), "alpine-img/alpine.qcow2", []vm.USBDevicePassthroughConfig{passthroughConfig}, true)
vi, err := vm.NewInstance(slog.Default().With("caller", "vm"), "alpine-img/alpine.qcow2", []vm.USBDevicePassthroughConfig{passthroughConfig}, true)
if err != nil {
slog.Error("Failed to create vm instance", "error", err.Error())
slog.Error("Failed to create vm instance", "error", err)
os.Exit(1)
}
@ -49,6 +51,10 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil
var wg sync.WaitGroup
ctx, ctxCancel := context.WithCancel(context.Background())
defer ctxCancel()
interrupt := make(chan os.Signal, 2)
signal.Notify(interrupt, syscall.SIGTERM, syscall.SIGINT)
wg.Add(1)
go func() {
@ -59,17 +65,42 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil
runErrCh <- err
}()
fm := vm.NewFileManager(vi)
go func() {
for i := 0; ; i++ {
select {
case <-ctx.Done():
signal.Reset()
return
case sig := <-interrupt:
lg := slog.With("signal", sig)
if i == 0 {
lg.Warn("Caught interrupt, safely shutting down")
} else if i < 10 {
lg.Warn("Caught subsequent interrupt, please interrupt n more times to panic", "n", 10-i)
} else {
panic("force interrupt")
}
err := vi.Cancel()
if err != nil {
lg.Warn("Failed to cancel VM context", "error", err)
}
}
}
}()
fm := vm.NewFileManager(slog.Default().With("caller", "file-manager"), vi)
for {
select {
case err := <-runErrCh:
slog.Error("Failed to start the VM", "error", err.Error())
slog.Error("Failed to start the VM", "error", err)
os.Exit(1)
case <-vi.SSHUpNotifyChan():
err := fm.Init()
if err != nil {
slog.Error("Failed to initialize File Manager", "error", err.Error())
slog.Error("Failed to initialize File Manager", "error", err)
os.Exit(1)
}
@ -77,7 +108,7 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil
err = vi.Cancel()
if err != nil {
slog.Error("Failed to cancel VM context", "error", err.Error())
slog.Error("Failed to cancel VM context", "error", err)
os.Exit(1)
}
@ -86,7 +117,7 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil
select {
case err := <-runErrCh:
if err != nil {
slog.Error("Failed to run the VM", "error", err.Error())
slog.Error("Failed to run the VM", "error", err)
os.Exit(1)
}
default:

1
go.mod
View file

@ -9,6 +9,7 @@ require (
github.com/spf13/cobra v1.7.0
go.uber.org/multierr v1.11.0
golang.org/x/crypto v0.12.0
golang.org/x/term v0.11.0
)
require (

View file

@ -2,36 +2,50 @@ package vm
import (
"bytes"
"crypto/rand"
"fmt"
"log/slog"
"os"
"path/filepath"
"sync"
"syscall"
"github.com/AlexSSD7/vldisk/utils"
"github.com/alessio/shellescape"
"github.com/pkg/errors"
"go.uber.org/multierr"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
)
type FileManager struct {
logger *slog.Logger
vi *Instance
}
func NewFileManager(vi *Instance) *FileManager {
func NewFileManager(logger *slog.Logger, vi *Instance) *FileManager {
return &FileManager{
logger: logger,
vi: vi,
}
}
func (fm *FileManager) Init() error {
c, err := fm.vi.DialSSH()
sc, err := fm.vi.DialSSH()
if err != nil {
return errors.Wrap(err, "dial vm ssh")
}
_, err = runSSHCmd(c, "apk add util-linux lvm2")
defer func() { _ = sc.Close() }()
_, err = runSSHCmd(sc, "apk add util-linux lvm2")
if err != nil {
return errors.Wrap(err, "install utilities")
}
_, err = runSSHCmd(c, "vgchange -ay")
_, err = runSSHCmd(sc, "vgchange -ay")
if err != nil {
return errors.Wrap(err, "run vgchange cmd")
}
@ -45,11 +59,15 @@ func (fm *FileManager) Lsblk() ([]byte, error) {
return nil, errors.Wrap(err, "dial vm ssh")
}
defer func() { _ = sc.Close() }()
sess, err := sc.NewSession()
if err != nil {
return nil, errors.Wrap(err, "create new vm ssh session")
}
defer func() { _ = sess.Close() }()
ret := new(bytes.Buffer)
sess.Stdout = ret
@ -64,6 +82,84 @@ func (fm *FileManager) Lsblk() ([]byte, error) {
type MountOptions struct {
FSType string
LUKS bool
}
const luksDMName = "cryptmnt"
func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error {
lg := fm.logger.With("vm-path", fullDevPath)
sess, err := sc.NewSession()
if err != nil {
return errors.Wrap(err, "create new vm ssh session")
}
defer func() { sess.Close() }()
stdinPipe, err := sess.StdinPipe()
if err != nil {
return errors.Wrap(err, "create vm ssh session stdin pipe")
}
stderrBuf := bytes.NewBuffer(nil)
sess.Stderr = stderrBuf
err = sess.Start("cryptsetup luksOpen " + shellescape.Quote(fullDevPath) + " " + luksDMName)
if err != nil {
return errors.Wrap(err, "start cryptsetup luksopen cmd")
}
lg.Info("Attempting to open LUKS device")
_, err = os.Stderr.Write([]byte("Enter Password: "))
if err != nil {
return errors.Wrap(err, "write prompt to stderr")
}
pwd, err := term.ReadPassword(int(syscall.Stdin))
if err != nil {
return errors.Wrap(err, "read luks password")
}
fmt.Print("\n")
var wErr error
var wWG sync.WaitGroup
wWG.Add(1)
go func() {
defer wWG.Done()
_, err := stdinPipe.Write(pwd)
_, err2 := stdinPipe.Write([]byte("\n"))
wErr = errors.Wrap(multierr.Combine(err, err2), "write password to stdin")
}()
// TODO: Timeout for this command
err = sess.Wait()
if err != nil {
return wrapErrWithLog(err, "wait for cryptsetup luksopen cmd to finish", stderrBuf.String())
}
lg.Info("LUKS device opened successfully")
// Clear the memory up
{
for i := 0; i < len(pwd); i++ {
pwd[i] = 0
}
for i := 0; i < 16; i++ {
_, _ = rand.Read(pwd)
}
}
_ = stdinPipe.Close()
wWG.Wait()
return wErr
}
func (fm *FileManager) Mount(devName string, mo MountOptions) error {
@ -88,12 +184,18 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error {
return errors.Wrap(err, "dial vm ssh")
}
sess, err := sc.NewSession()
if err != nil {
return errors.Wrap(err, "create new vm ssh session")
defer func() { _ = sc.Close() }()
if mo.LUKS {
err = fm.luksOpen(sc, fullDevPath)
if err != nil {
return errors.Wrap(err, "luks open")
}
fullDevPath = "/dev/mapper/" + luksDMName
}
err = sess.Run("mount -t " + shellescape.Quote(mo.FSType) + " " + shellescape.Quote(fullDevPath) + " /mnt")
_, err = runSSHCmd(sc, "mount -t "+shellescape.Quote(mo.FSType)+" "+shellescape.Quote(fullDevPath)+" /mnt")
if err != nil {
return errors.Wrap(err, "run mount cmd")
}

View file

@ -153,6 +153,8 @@ func runSSHCmd(c *ssh.Client, cmd string) ([]byte, error) {
return nil, errors.Wrap(err, "create new vm ssh session")
}
defer func() { _ = sess.Close() }()
stdout := bytes.NewBuffer(nil)
stderr := bytes.NewBuffer(nil)

View file

@ -13,6 +13,7 @@ import (
"strconv"
"sync"
"sync/atomic"
"syscall"
"time"
"log/slog"
@ -98,6 +99,11 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
stderrBuf := bytes.NewBuffer(nil)
cmd.Stderr = stderrBuf
// This is to prevent Ctrl+C propagating to the child process.
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
}
userReader := bufio.NewReader(userRead)
ctx, ctxCancel := context.WithCancel(context.Background())
@ -224,9 +230,34 @@ func (vi *Instance) Cancel() error {
return nil
}
vi.logger.Warn("Canceling the VM context")
var gracefulOK bool
sc, err := vi.DialSSH()
if err != nil {
if !errors.Is(err, ErrSSHUnavailable) {
vi.logger.Warn("Failed to dial VM ssh to do graceful shutdown", "error", err)
}
} else {
_, err = runSSHCmd(sc, "poweroff")
_ = sc.Close()
if err != nil {
vi.logger.Warn("Could not power off the VM safely", "error", err)
} else {
vi.logger.Info("Shutting the VM down safely")
}
}
var interruptErr error
if !gracefulOK {
interruptErr = vi.cmd.Process.Signal(os.Interrupt)
}
vi.ctxCancel()
return multierr.Combine(
errors.Wrap(vi.cmd.Process.Signal(os.Interrupt), "cancel cmd"),
errors.Wrap(interruptErr, "interrupt cmd"),
errors.Wrap(vi.serialRead.Close(), "close serial read pipe"),
errors.Wrap(vi.serialWrite.Close(), "close serial write pipe"),
)