diff --git a/cmd/run.go b/cmd/run.go index d6dfc51..3aaded5 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -43,8 +43,6 @@ var runCmd = &cobra.Command{ }) } - // TODO: `slog` library prints entire stack traces for errors which makes reading errors challenging. - os.Exit(runVM(args[0], func(ctx context.Context, i *vm.VM, fm *vm.FileManager) int { slog.Info("Mounting the device", "dev", vmMountDevName, "fs", fsType, "luks", luksFlag) @@ -59,14 +57,10 @@ var runCmd = &cobra.Command{ sharePWD, err := password.Generate(16, 10, 0, false, false) if err != nil { - slog.Error("Failed to generate ephemeral password for network file share", "error", err.Error()) + slog.Error("Failed to generate ephemeral password for the network file share", "error", err.Error()) return 1 } - shareURI := "ftp://linsk:" + sharePWD + "@localhost:" + fmt.Sprint(networkSharePort) - - fmt.Fprintf(os.Stderr, "================\n[Network File Share Config]\nThe network file share was started. Please use the credentials below to connect to the file server.\n\nType: FTP\nServer Address: ftp://localhost:%v\nUsername: linsk\nPassword: %v\n\nShare URI: %v\n================\n", networkSharePort, sharePWD, shareURI) - err = fm.StartFTP([]byte(sharePWD), networkSharePort+1, ftpPassivePortCount) if err != nil { slog.Error("Failed to start FTP server", "error", err.Error()) @@ -75,6 +69,10 @@ var runCmd = &cobra.Command{ slog.Info("Started the network share successfully", "type", "ftp") + shareURI := "ftp://linsk:" + sharePWD + "@localhost:" + fmt.Sprint(networkSharePort) + + fmt.Fprintf(os.Stderr, "================\n[Network File Share Config]\nThe network file share was started. Please use the credentials below to connect to the file server.\n\nType: FTP\nServer Address: ftp://localhost:%v\nUsername: linsk\nPassword: %v\n\nShare URI: %v\n================\n", networkSharePort, sharePWD, shareURI) + <-ctx.Done() return 0 }, ports, unrestrictedNetworkingFlag)) diff --git a/cmd/utils.go b/cmd/utils.go index 79192ec..9d40cda 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -42,12 +42,11 @@ func doUSBRootCheck() { ok, err := checkIfRoot() if err != nil { slog.Error("Failed to check whether the command is ran by root", "error", err.Error()) - os.Exit(1) + return } if !ok { - slog.Error("USB passthrough on your OS requires this program to be ran as root") - os.Exit(1) + slog.Warn("USB passthrough on your OS usually requires this program to be ran as root") } } diff --git a/vm/filemanager.go b/vm/filemanager.go index 75bda8b..63bb253 100644 --- a/vm/filemanager.go +++ b/vm/filemanager.go @@ -2,6 +2,7 @@ package vm import ( "bytes" + "context" "crypto/rand" "fmt" "log/slog" @@ -9,6 +10,7 @@ import ( "strings" "sync" "syscall" + "time" "github.com/AlexSSD7/linsk/utils" "github.com/alessio/shellescape" @@ -90,8 +92,6 @@ func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error { return errors.Wrap(err, "create new vm ssh session") } - defer func() { sess.Close() }() - stdinPipe, err := sess.StdinPipe() if err != nil { return errors.Wrap(err, "create vm ssh session stdin pipe") @@ -131,7 +131,45 @@ func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error { wErr = errors.Wrap(multierr.Combine(err, err2), "write password to stdin") }() - // TODO: Timeout for this command + defer func() { + // Clear the memory up. + { + for i := 0; i < len(pwd); i++ { + pwd[i] = 0 + } + + for i := 0; i < 4; i++ { + _, _ = rand.Read(pwd) + } + } + }() + + done := make(chan struct{}) + defer close(done) + + var timedOut bool + + go func() { + tm := func() { + timedOut = true + _ = sc.Close() + } + select { + case <-fm.vm.ctx.Done(): + tm() + case <-time.After(time.Second * 1): + tm() + case <-done: + } + }() + + checkTimeoutErr := func(err error) error { + if timedOut { + return fmt.Errorf("timed out (%w)", err) + } + + return err + } err = sess.Wait() if err != nil { @@ -139,22 +177,11 @@ func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error { fm.logger.Warn("Detected not enough memory to open a LUKS device, please allocate more memory using --vm-mem-alloc flag.") } - return utils.WrapErrWithLog(err, "wait for cryptsetup luksopen cmd to finish", stderrBuf.String()) + return checkTimeoutErr(utils.WrapErrWithLog(err, "wait for cryptsetup luksopen cmd to finish", stderrBuf.String())) } lg.Info("LUKS device opened successfully") - // Clear the memory up - { - for i := 0; i < len(pwd); i++ { - pwd[i] = 0 - } - - for i := 0; i < 16; i++ { - _, _ = rand.Read(pwd) - } - } - _ = stdinPipe.Close() wWG.Wait() @@ -194,15 +221,45 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error { fullDevPath = "/dev/mapper/" + luksDMName } + done := make(chan struct{}) + defer close(done) + + var timedOut bool + + go func() { + tm := func() { + timedOut = true + _ = sc.Close() + } + select { + case <-fm.vm.ctx.Done(): + tm() + case <-time.After(time.Second * 10): + tm() + case <-done: + } + }() + + checkTimeoutErr := func(err error) error { + if timedOut { + return fmt.Errorf("timed out (%w)", err) + } + + return err + } + _, err = runSSHCmd(sc, "mount -t "+shellescape.Quote(mo.FSType)+" "+shellescape.Quote(fullDevPath)+" /mnt") if err != nil { - return errors.Wrap(err, "run mount cmd") + return checkTimeoutErr(errors.Wrap(err, "run mount cmd")) } return nil } func (fm *FileManager) StartFTP(pwd []byte, passivePortStart uint16, passivePortCount uint16) error { + scpCtx, scpCtxCancel := context.WithTimeout(fm.vm.ctx, time.Second*5) + defer scpCtxCancel() + scpClient, err := fm.vm.DialSCP() if err != nil { return errors.Wrap(err, "dial scp") @@ -223,7 +280,7 @@ pasv_max_port=` + fmt.Sprint(passivePortStart+passivePortCount) + ` pasv_address=127.0.0.1 ` - err = scpClient.CopyFile(fm.vm.ctx, strings.NewReader(ftpdCfg), "/etc/vsftpd/vsftpd.conf", "0400") + err = scpClient.CopyFile(scpCtx, strings.NewReader(ftpdCfg), "/etc/vsftpd/vsftpd.conf", "0400") if err != nil { return errors.Wrap(err, "copy ftpd .conf file") } @@ -235,16 +292,43 @@ pasv_address=127.0.0.1 return errors.Wrap(err, "dial ssh") } + done := make(chan struct{}) + defer close(done) + + var timedOut bool + + go func() { + tm := func() { + timedOut = true + _ = sc.Close() + } + select { + case <-fm.vm.ctx.Done(): + tm() + case <-time.After(time.Second * 15): + tm() + case <-done: + } + }() + + checkTimeoutErr := func(err error) error { + if timedOut { + return fmt.Errorf("timed out (%w)", err) + } + + return err + } + defer func() { _ = sc.Close() }() _, err = runSSHCmd(sc, "rc-update add vsftpd && rc-service vsftpd start") if err != nil { - return errors.Wrap(err, "add and start ftpd service") + return checkTimeoutErr(errors.Wrap(err, "add and start ftpd service")) } sess, err := sc.NewSession() if err != nil { - return errors.Wrap(err, "create new ssh session") + return checkTimeoutErr(errors.Wrap(err, "create new ssh session")) } pwd = append(pwd, '\n') @@ -254,14 +338,12 @@ pasv_address=127.0.0.1 stdinPipe, err := sess.StdinPipe() if err != nil { - return errors.Wrap(err, "stdin pipe") + return checkTimeoutErr(errors.Wrap(err, "stdin pipe")) } - // TODO: Timeout for this command - err = sess.Start("passwd linsk") if err != nil { - return errors.Wrap(err, "start change user password cmd") + return checkTimeoutErr(errors.Wrap(err, "start change user password cmd")) } go func() { @@ -277,7 +359,7 @@ pasv_address=127.0.0.1 err = sess.Wait() if err != nil { - return utils.WrapErrWithLog(err, "wait for change user password cmd", stderr.String()) + return checkTimeoutErr(utils.WrapErrWithLog(err, "wait for change user password cmd", stderr.String())) } return nil diff --git a/vm/vm.go b/vm/vm.go index 11c6426..48dff19 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -424,7 +424,7 @@ func (vm *VM) writeSerial(b []byte) error { // and this will ensure that we do not write faster than the polling rate. for i := range b { _, err := vm.serialWrite.Write([]byte{b[i]}) - time.Sleep(time.Millisecond * 10) + time.Sleep(time.Millisecond) if err != nil { return errors.Wrapf(err, "write char #%v", i) }