diff --git a/cmd/root.go b/cmd/root.go index ddb50a0..f5f9f74 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,6 +1,7 @@ package cmd import ( + "fmt" "os" "path/filepath" "runtime" @@ -26,12 +27,19 @@ func Execute() { } } -var vmDebugFlag bool -var unrestrictedNetworkingFlag bool -var vmMemAllocFlag uint32 -var vmSSHSetupTimeoutFlag uint32 -var vmOSUpTimeoutFlag uint32 -var dataDirFlag string +var ( + vmDebugFlag bool + unrestrictedNetworkingFlag bool + vmMemAllocFlag uint32 + vmSSHSetupTimeoutFlag uint32 + vmOSUpTimeoutFlag uint32 + dataDirFlag string +) + +const ( + defaultMemAlloc = 512 + defaultMemAllocLUKS = 2048 +) func init() { slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil))) @@ -45,7 +53,7 @@ func init() { rootCmd.PersistentFlags().BoolVar(&vmDebugFlag, "vm-debug", false, "Enables the VM debug mode. This will open an accessible VM monitor. You can log in with root user and no password.") rootCmd.PersistentFlags().BoolVar(&unrestrictedNetworkingFlag, "vm-unrestricted-networking", false, "Enables unrestricted networking. This will allow the VM to connect to the internet.") - rootCmd.PersistentFlags().Uint32Var(&vmMemAllocFlag, "vm-mem-alloc", 512, "Specifies the VM memory allocation in KiB") + rootCmd.PersistentFlags().Uint32Var(&vmMemAllocFlag, "vm-mem-alloc", defaultMemAlloc, fmt.Sprintf("Specifies the VM memory allocation in KiB (the default is %v in LUKS mode)", defaultMemAllocLUKS)) rootCmd.PersistentFlags().Uint32Var(&vmOSUpTimeoutFlag, "vm-os-up-timeout", 30, "Specifies the VM OS-up timeout in seconds.") rootCmd.PersistentFlags().Uint32Var(&vmSSHSetupTimeoutFlag, "vm-ssh-setup-timeout", 60, "Specifies the VM SSH server setup timeout in seconds. This cannot be lower than the OS-up timeout.") diff --git a/cmd/run.go b/cmd/run.go index 617e5b2..747cac8 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -45,6 +45,15 @@ var runCmd = &cobra.Command{ os.Exit(1) } + if luksFlag && !allowLUKSLowMemoryFlag { + if vmMemAllocFlag < 2048 { + if vmMemAllocFlag != defaultMemAlloc { + slog.Warn("Enforcing minimum LUKS memory allocation. Please add --allow-luks-low-memory to disable this.", "min", vmMemAllocFlag, "specified", vmMemAllocFlag) + } + vmMemAllocFlag = defaultMemAllocLUKS + } + } + os.Exit(runVM(args[0], func(ctx context.Context, i *vm.VM, fm *vm.FileManager, tapCtx *share.NetTapRuntimeContext) int { slog.Info("Mounting the device", "dev", vmMountDevName, "fs", fsType, "luks", luksFlag) @@ -83,14 +92,18 @@ var runCmd = &cobra.Command{ }, } -var luksFlag bool -var shareListenIPFlag string -var ftpExtIPFlag string -var shareBackendFlag string -var smbUseExternAddrFlag bool +var ( + luksFlag bool + allowLUKSLowMemoryFlag bool + shareListenIPFlag string + ftpExtIPFlag string + shareBackendFlag string + smbUseExternAddrFlag bool +) func init() { runCmd.Flags().BoolVarP(&luksFlag, "luks", "l", false, "Use cryptsetup to open a LUKS volume (password will be prompted).") + runCmd.Flags().BoolVar(&allowLUKSLowMemoryFlag, "allow-luks-low-memory", false, "Allow VM memory allocation lower than 2048 MiB when LUKS is enabled.") var defaultShareType string switch runtime.GOOS { diff --git a/share/defaults.go b/share/defaults.go index cdecd28..0feb327 100644 --- a/share/defaults.go +++ b/share/defaults.go @@ -6,7 +6,7 @@ import ( ) func IsSMBExtModeDefault() bool { - return runtime.GOOS == "windows1" + return runtime.GOOS == "windows" } var defaultListenIP = net.ParseIP("127.0.0.1") diff --git a/sshutil/ssh.go b/sshutil/ssh.go index 5f6c8f6..b79d485 100644 --- a/sshutil/ssh.go +++ b/sshutil/ssh.go @@ -50,13 +50,13 @@ func RunSSHCmd(ctx context.Context, sc *ssh.Client, cmd string) ([]byte, error) } func NewSSHSession(ctx context.Context, timeout time.Duration, sc *ssh.Client, fn func(*ssh.Session) error) error { - return NewSSHSessionWithDelayedTimeout(ctx, timeout, sc, func(sess *ssh.Session, startTimeout func()) error { - startTimeout() + return NewSSHSessionWithDelayedTimeout(ctx, timeout, sc, func(sess *ssh.Session, startTimeout func(preTimeout func())) error { + startTimeout(nil) return fn(sess) }) } -func NewSSHSessionWithDelayedTimeout(ctx context.Context, timeout time.Duration, sc *ssh.Client, fn func(sess *ssh.Session, startTimeout func()) error) error { +func NewSSHSessionWithDelayedTimeout(ctx context.Context, timeout time.Duration, sc *ssh.Client, fn func(sess *ssh.Session, startTimeout func(preTimeout func())) error) error { s, err := sc.NewSession() if err != nil { return errors.Wrap(err, "create new ssh session") @@ -77,12 +77,13 @@ func NewSSHSessionWithDelayedTimeout(ctx context.Context, timeout time.Duration, } }() - err = fn(s, func() { + err = fn(s, func(preTimeout func()) { // Now start a thread which will close the session // down when the timeout hits. go func() { select { case <-time.After(timeout): + preTimeout() timedOut = true _ = sc.Close() case <-done: diff --git a/vm/filemanager.go b/vm/filemanager.go index 8243d87..2f187c0 100644 --- a/vm/filemanager.go +++ b/vm/filemanager.go @@ -76,7 +76,7 @@ const luksDMName = "cryptmnt" func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error { lg := fm.logger.With("vm-path", fullDevPath) - return sshutil.NewSSHSessionWithDelayedTimeout(fm.vm.ctx, time.Second*15, sc, func(sess *ssh.Session, startTimeout func()) error { + return sshutil.NewSSHSessionWithDelayedTimeout(fm.vm.ctx, time.Second*15, sc, func(sess *ssh.Session, startTimeout func(preTimeout func())) error { stdinPipe, err := sess.StdinPipe() if err != nil { return errors.Wrap(err, "create vm ssh session stdin pipe") @@ -106,7 +106,9 @@ func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error { // We start the timeout countdown now only to avoid timing out // while the user is entering the password, or shortly after that. - startTimeout() + startTimeout(func() { + lg.Warn("LUKS open command timed out. If you are using large-memory key derivation function, try increasing the VM memory allocation using --vm-mem-alloc flag.") + }) var wErr error var wWG sync.WaitGroup @@ -136,7 +138,7 @@ func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error { err = sess.Wait() if err != nil { if strings.Contains(stderrBuf.String(), "Not enough available memory to open a keyslot.") { - fm.logger.Warn("Detected not enough memory to open a LUKS device, please allocate more memory using --vm-mem-alloc flag") + fm.logger.Warn("Detected not enough memory to open a LUKS device, please allocate more memory using --vm-mem-alloc flag.") } return utils.WrapErrWithLog(err, "wait for cryptsetup luksopen cmd to finish", stderrBuf.String())