diff --git a/cmd/ls.go b/cmd/ls.go index dbb498b..17e7aa9 100644 --- a/cmd/ls.go +++ b/cmd/ls.go @@ -18,15 +18,16 @@ var lsCmd = &cobra.Command{ // Short: "", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) { + os.Exit(runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) int { lsblkOut, err := fm.Lsblk() if err != nil { slog.Error("Failed to list block devices in the VM", "error", err) - os.Exit(1) + return 1 } fmt.Print(string(lsblkOut)) - }) + return 0 + })) return nil }, diff --git a/cmd/root.go b/cmd/root.go index f88f0ce..a671ccf 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -22,9 +22,14 @@ func Execute() { } } +var vmDebugFlag bool + func init() { slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil))) rootCmd.AddCommand(lsCmd) rootCmd.AddCommand(runCmd) + rootCmd.AddCommand(shellCmd) + + rootCmd.PersistentFlags().BoolVar(&vmDebugFlag, "vmdebug", false, "Enable VM debug mode. This will open an accessible VM monitor. You can log in with root user and no password.") } diff --git a/cmd/run.go b/cmd/run.go index 47723c5..28feea5 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "os" "github.com/AlexSSD7/linsk/vm" "github.com/spf13/cobra" @@ -20,19 +21,20 @@ var runCmd = &cobra.Command{ // TODO: `slog` library prints entire stack traces for errors which makes reading errors challenging. - runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) { + os.Exit(runVM(args[0], func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) int { err := fm.Mount(vmMountDevName, vm.MountOptions{ FSType: fsType, LUKS: luksFlag, }) if err != nil { slog.Error("Failed to mount the disk inside the VM", "error", err) - return + return 1 } fmt.Println("Mounted! Now sleeping") <-ctx.Done() - }) + return 0 + })) return nil }, diff --git a/cmd/shell.go b/cmd/shell.go new file mode 100644 index 0000000..c573f78 --- /dev/null +++ b/cmd/shell.go @@ -0,0 +1,110 @@ +package cmd + +import ( + "context" + "log/slog" + "os" + + "github.com/AlexSSD7/linsk/vm" + "github.com/spf13/cobra" + "golang.org/x/crypto/ssh" + "golang.org/x/term" +) + +var shellCmd = &cobra.Command{ + Use: "shell", + // TODO: Fill this + // Short: "", + Args: cobra.RangeArgs(0, 1), + RunE: func(cmd *cobra.Command, args []string) error { + var passthroughArg string + if len(args) > 0 { + passthroughArg = args[0] + } + + os.Exit(runVM(passthroughArg, func(ctx context.Context, i *vm.Instance, fm *vm.FileManager) int { + sc, err := i.DialSSH() + if err != nil { + slog.Error("Failed to dial VM SSH", "error", err) + return 1 + } + + defer func() { _ = sc.Close() }() + + sess, err := sc.NewSession() + if err != nil { + slog.Error("Failed to create new VM SSH session", "error", err) + return 1 + } + + defer func() { _ = sess.Close() }() + + termFD := int(os.Stdin.Fd()) + termState, err := term.MakeRaw(termFD) + if err != nil { + slog.Error("Failed to make raw terminal", "error", err) + return 1 + } + + defer func() { + err := term.Restore(termFD, termState) + if err != nil { + slog.Error("Failed to restore terminal", "error", err) + } + }() + + termWidth, termHeight, err := term.GetSize(termFD) + if err != nil { + slog.Error("Failed to get terminal size", "error", err) + return 1 + } + + termModes := ssh.TerminalModes{ + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, + } + + term := os.Getenv("TERM") + if term == "" { + term = "xterm-256color" + } + + err = sess.RequestPty(term, termHeight, termWidth, termModes) + if err != nil { + slog.Error("Failed to request VM SSH pty", "error", err) + return 1 + } + + sess.Stdin = os.Stdin + sess.Stdout = os.Stdout + sess.Stderr = os.Stderr + + err = sess.Shell() + if err != nil { + slog.Error("Start VM SSH shell", "error", err) + return 1 + } + + doneCh := make(chan struct{}, 1) + + go func() { + err = sess.Wait() + if err != nil { + slog.Error("Failed to wait for VM SSH session to finish", "error", err) + } + + doneCh <- struct{}{} + }() + + select { + case <-ctx.Done(): + case <-doneCh: + } + + return 0 + })) + + return nil + }, +} diff --git a/cmd/utils.go b/cmd/utils.go index 8e5c0a8..f7efc3f 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -35,13 +35,17 @@ func doRootCheck() { } } -func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.FileManager)) *vm.Instance { +func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.FileManager) int) int { doRootCheck() - passthroughConfig := getDevicePassthroughConfig(passthroughArg) + var passthroughConfig []vm.USBDevicePassthroughConfig + + if passthroughArg != "" { + passthroughConfig = []vm.USBDevicePassthroughConfig{getDevicePassthroughConfig(passthroughArg)} + } // TODO: Alpine image should be downloaded from somewhere. - vi, err := vm.NewInstance(slog.Default().With("caller", "vm"), "alpine-img/alpine.qcow2", []vm.USBDevicePassthroughConfig{passthroughConfig}, true) + vi, err := vm.NewInstance(slog.Default().With("caller", "vm"), "alpine-img/alpine.qcow2", passthroughConfig, vmDebugFlag) if err != nil { slog.Error("Failed to create vm instance", "error", err) os.Exit(1) @@ -104,7 +108,7 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil os.Exit(1) } - fn(ctx, vi, fm) + exitCode := fn(ctx, vi, fm) err = vi.Cancel() if err != nil { @@ -123,7 +127,7 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil default: } - return nil + return exitCode } } } diff --git a/vm/ssh.go b/vm/ssh.go index df1b552..9be6298 100644 --- a/vm/ssh.go +++ b/vm/ssh.go @@ -52,7 +52,7 @@ func ParseSSHKeyScan(knownHosts []byte) (ssh.HostKeyCallback, error) { func (vi *Instance) scanSSHIdentity() ([]byte, error) { vi.resetSerialStdout() - err := vi.writeSerial([]byte(`ssh-keyscan -H localhost; echo "SERIAL STATUS: $?"` + "\n")) + err := vi.writeSerial([]byte(`ssh-keyscan -H localhost; echo "SERIAL STATUS: $?"; rm /root/.ash_history` + "\n")) if err != nil { return nil, errors.Wrap(err, "write keyscan command to serial") }