From 0870f8113af168bf8a237025007087b985b223ba Mon Sep 17 00:00:00 2001 From: AlexSSD7 Date: Tue, 29 Aug 2023 14:24:18 +0100 Subject: [PATCH] Polish code --- cmd/imgbuilder/builder/build.go | 4 +- cmd/run.go | 2 +- sshutil/ssh.go | 97 ++++++++++++ sshutil/unix.go | 57 +++++++ utils/utils.go | 8 +- vm/filemanager.go | 256 +++++++++----------------------- vm/os_specifics_windows.go | 1 + vm/ssh.go | 50 ++----- vm/types.go | 4 +- vm/vm.go | 5 +- 10 files changed, 246 insertions(+), 238 deletions(-) create mode 100644 sshutil/ssh.go create mode 100644 sshutil/unix.go diff --git a/cmd/imgbuilder/builder/build.go b/cmd/imgbuilder/builder/build.go index 45e15b7..4d8a10d 100644 --- a/cmd/imgbuilder/builder/build.go +++ b/cmd/imgbuilder/builder/build.go @@ -38,7 +38,7 @@ func NewBuildContext(logger *slog.Logger, baseISOPath string, outPath string, sh return nil, errors.Wrap(err, "stat output file") } - // File doesn't exist. Continuing. + // File doesn't exist. Going forward with creating a new .qcow2 image. } else { return nil, fmt.Errorf("output file already exists") } @@ -183,8 +183,6 @@ func runAlpineSetupCmd(sc *ssh.Client, pkgs []string) error { return errors.Wrap(err, "new session") } - // TODO: Timeout for this command. - stderr := bytes.NewBuffer(nil) sess.Stderr = stderr diff --git a/cmd/run.go b/cmd/run.go index 3aaded5..d7c5fbe 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -61,7 +61,7 @@ var runCmd = &cobra.Command{ return 1 } - err = fm.StartFTP([]byte(sharePWD), networkSharePort+1, ftpPassivePortCount) + err = fm.StartFTP(sharePWD, networkSharePort+1, ftpPassivePortCount) if err != nil { slog.Error("Failed to start FTP server", "error", err.Error()) return 1 diff --git a/sshutil/ssh.go b/sshutil/ssh.go new file mode 100644 index 0000000..5f6c8f6 --- /dev/null +++ b/sshutil/ssh.go @@ -0,0 +1,97 @@ +package sshutil + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "fmt" + "time" + + "github.com/AlexSSD7/linsk/utils" + "github.com/pkg/errors" + "golang.org/x/crypto/ssh" +) + +func GenerateSSHKey() (ssh.Signer, []byte, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, nil, errors.Wrap(err, "generate rsa private key") + } + + signer, err := ssh.NewSignerFromKey(privateKey) + if err != nil { + return nil, nil, errors.Wrap(err, "create signer from key") + } + + return signer, ssh.MarshalAuthorizedKey(signer.PublicKey()), nil +} + +func RunSSHCmd(ctx context.Context, sc *ssh.Client, cmd string) ([]byte, error) { + var ret []byte + err := NewSSHSession(ctx, time.Second*15, sc, func(sess *ssh.Session) error { + stdout := bytes.NewBuffer(nil) + stderr := bytes.NewBuffer(nil) + + sess.Stdout = stdout + sess.Stderr = stderr + + err := sess.Run(cmd) + if err != nil { + return utils.WrapErrWithLog(err, "run cmd", stderr.String()) + } + + ret = stdout.Bytes() + + return nil + }) + + return ret, err +} + +func NewSSHSession(ctx context.Context, timeout time.Duration, sc *ssh.Client, fn func(*ssh.Session) error) error { + return NewSSHSessionWithDelayedTimeout(ctx, timeout, sc, func(sess *ssh.Session, startTimeout func()) error { + startTimeout() + return fn(sess) + }) +} + +func NewSSHSessionWithDelayedTimeout(ctx context.Context, timeout time.Duration, sc *ssh.Client, fn func(sess *ssh.Session, startTimeout func()) error) error { + s, err := sc.NewSession() + if err != nil { + return errors.Wrap(err, "create new ssh session") + } + + done := make(chan struct{}) + defer close(done) + + var timedOut bool + + // Start a thread to handle context cancelation. + go func() { + select { + case <-ctx.Done(): + timedOut = true + _ = sc.Close() + case <-done: + } + }() + + err = fn(s, func() { + // Now start a thread which will close the session + // down when the timeout hits. + go func() { + select { + case <-time.After(timeout): + timedOut = true + _ = sc.Close() + case <-done: + } + }() + }) + if timedOut { + return fmt.Errorf("timed out (%w)", err) + } + + return err +} diff --git a/sshutil/unix.go b/sshutil/unix.go new file mode 100644 index 0000000..a2f1970 --- /dev/null +++ b/sshutil/unix.go @@ -0,0 +1,57 @@ +package sshutil + +import ( + "bytes" + "context" + "fmt" + "time" + + "github.com/AlexSSD7/linsk/utils" + "github.com/alessio/shellescape" + "github.com/pkg/errors" + "go.uber.org/multierr" + "golang.org/x/crypto/ssh" +) + +func ChangeUnixPass(ctx context.Context, sc *ssh.Client, user string, pwd string) error { + if !utils.ValidateUnixUsername(user) { + return fmt.Errorf("invalid unix username") + } + + return NewSSHSession(ctx, time.Second*10, sc, func(sess *ssh.Session) error { + stderr := bytes.NewBuffer(nil) + sess.Stderr = stderr + + stdinPipe, err := sess.StdinPipe() + if err != nil { + return errors.Wrap(err, "stdin pipe") + } + + err = sess.Start("passwd " + shellescape.Quote(user)) + if err != nil { + return errors.Wrap(err, "start change user password cmd") + } + + pwdBytes := []byte(pwd + "\n") + defer func() { + // Clearing the memory up for security. + for i := range pwdBytes { + pwdBytes[i] = 0 + } + }() + + go func() { + // Writing the password. We're doing this two times + // as we need to confirm the password. + _, _ = stdinPipe.Write(pwdBytes) + _, _ = stdinPipe.Write(pwdBytes) + }() + + err = sess.Wait() + if err != nil { + return multierr.Combine(utils.WrapErrWithLog(err, "wait for change user password cmd", stderr.String())) + } + + return nil + }) +} diff --git a/utils/utils.go b/utils/utils.go index 0bc71cd..11a8463 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -21,7 +21,7 @@ func ClearUnprintableChars(s string, allowNewlines bool) string { }, s) } -var devNameRegexp = regexp.MustCompile("^[0-9a-z_-]+$") +var devNameRegexp = regexp.MustCompile(`^[0-9a-z_-]+$`) func ValidateDevName(s string) bool { // Allow mapped devices. @@ -30,6 +30,12 @@ func ValidateDevName(s string) bool { return devNameRegexp.MatchString(s) } +var unixUsernameRegexp = regexp.MustCompile(`^[a-z_]([a-z0-9_-]{0,31}|[a-z0-9_-]{0,30}\$)$`) + +func ValidateUnixUsername(s string) bool { + return unixUsernameRegexp.MatchString(s) +} + func Uint16ToBytesBE(v uint16) []byte { b := make([]byte, 2) binary.BigEndian.PutUint16(b, v) diff --git a/vm/filemanager.go b/vm/filemanager.go index 63bb253..9b99757 100644 --- a/vm/filemanager.go +++ b/vm/filemanager.go @@ -12,6 +12,7 @@ import ( "syscall" "time" + "github.com/AlexSSD7/linsk/sshutil" "github.com/AlexSSD7/linsk/utils" "github.com/alessio/shellescape" "github.com/pkg/errors" @@ -42,7 +43,7 @@ func (fm *FileManager) Init() error { defer func() { _ = sc.Close() }() - _, err = runSSHCmd(sc, "vgchange -ay") + _, err = sshutil.RunSSHCmd(fm.vm.ctx, sc, "vgchange -ay") if err != nil { return errors.Wrap(err, "run vgchange cmd") } @@ -56,25 +57,12 @@ func (fm *FileManager) Lsblk() ([]byte, error) { return nil, errors.Wrap(err, "dial vm ssh") } - defer func() { _ = sc.Close() }() - - sess, err := sc.NewSession() - if err != nil { - return nil, errors.Wrap(err, "create new vm ssh session") - } - - defer func() { _ = sess.Close() }() - - ret := new(bytes.Buffer) - - sess.Stdout = ret - - err = sess.Run("lsblk -o NAME,SIZE,FSTYPE,LABEL -e 7,11,2,253") + ret, err := sshutil.RunSSHCmd(fm.vm.ctx, sc, "lsblk -o NAME,SIZE,FSTYPE,LABEL -e 7,11,2,253") if err != nil { return nil, errors.Wrap(err, "run lsblk") } - return ret.Bytes(), nil + return ret, nil } type MountOptions struct { @@ -87,105 +75,79 @@ const luksDMName = "cryptmnt" func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error { lg := fm.logger.With("vm-path", fullDevPath) - sess, err := sc.NewSession() - if err != nil { - return errors.Wrap(err, "create new vm ssh session") - } + return sshutil.NewSSHSessionWithDelayedTimeout(fm.vm.ctx, time.Second*15, sc, func(sess *ssh.Session, startTimeout func()) error { + stdinPipe, err := sess.StdinPipe() + if err != nil { + return errors.Wrap(err, "create vm ssh session stdin pipe") + } - stdinPipe, err := sess.StdinPipe() - if err != nil { - return errors.Wrap(err, "create vm ssh session stdin pipe") - } + stderrBuf := bytes.NewBuffer(nil) + sess.Stderr = stderrBuf - stderrBuf := bytes.NewBuffer(nil) - sess.Stderr = stderrBuf + err = sess.Start("cryptsetup luksOpen " + shellescape.Quote(fullDevPath) + " " + luksDMName) + if err != nil { + return errors.Wrap(err, "start cryptsetup luksopen cmd") + } - err = sess.Start("cryptsetup luksOpen " + shellescape.Quote(fullDevPath) + " " + luksDMName) - if err != nil { - return errors.Wrap(err, "start cryptsetup luksopen cmd") - } + lg.Info("Attempting to open a LUKS device") - lg.Info("Attempting to open a LUKS device") + _, err = os.Stderr.Write([]byte("Enter Password: ")) + if err != nil { + return errors.Wrap(err, "write prompt to stderr") + } - _, err = os.Stderr.Write([]byte("Enter Password: ")) - if err != nil { - return errors.Wrap(err, "write prompt to stderr") - } + pwd, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil { + return errors.Wrap(err, "read luks password") + } - pwd, err := term.ReadPassword(int(syscall.Stdin)) - if err != nil { - return errors.Wrap(err, "read luks password") - } + fmt.Print("\n") - fmt.Print("\n") + // We start the timeout countdown now only to avoid timing out + // while the user is entering the password, or shortly after that. + startTimeout() - var wErr error - var wWG sync.WaitGroup + var wErr error + var wWG sync.WaitGroup - wWG.Add(1) - go func() { - defer wWG.Done() + wWG.Add(1) + go func() { + defer wWG.Done() - _, err := stdinPipe.Write(pwd) - _, err2 := stdinPipe.Write([]byte("\n")) - wErr = errors.Wrap(multierr.Combine(err, err2), "write password to stdin") - }() + _, err := stdinPipe.Write(pwd) + _, err2 := stdinPipe.Write([]byte("\n")) + wErr = errors.Wrap(multierr.Combine(err, err2), "write password to stdin") + }() - defer func() { - // Clear the memory up. - { - for i := 0; i < len(pwd); i++ { - pwd[i] = 0 - } + defer func() { + // Clear the memory up for security. + { + for i := 0; i < len(pwd); i++ { + pwd[i] = 0 + } - for i := 0; i < 4; i++ { + // This is my paranoia. + _, _ = rand.Read(pwd) _, _ = rand.Read(pwd) } - } - }() + }() - done := make(chan struct{}) - defer close(done) + err = sess.Wait() + if err != nil { + if strings.Contains(stderrBuf.String(), "Not enough available memory to open a keyslot.") { + fm.logger.Warn("Detected not enough memory to open a LUKS device, please allocate more memory using --vm-mem-alloc flag.") + } - var timedOut bool - - go func() { - tm := func() { - timedOut = true - _ = sc.Close() - } - select { - case <-fm.vm.ctx.Done(): - tm() - case <-time.After(time.Second * 1): - tm() - case <-done: - } - }() - - checkTimeoutErr := func(err error) error { - if timedOut { - return fmt.Errorf("timed out (%w)", err) + return utils.WrapErrWithLog(err, "wait for cryptsetup luksopen cmd to finish", stderrBuf.String()) } - return err - } + lg.Info("LUKS device opened successfully") - err = sess.Wait() - if err != nil { - if strings.Contains(stderrBuf.String(), "Not enough available memory to open a keyslot.") { - fm.logger.Warn("Detected not enough memory to open a LUKS device, please allocate more memory using --vm-mem-alloc flag.") - } + _ = stdinPipe.Close() + wWG.Wait() - return checkTimeoutErr(utils.WrapErrWithLog(err, "wait for cryptsetup luksopen cmd to finish", stderrBuf.String())) - } - - lg.Info("LUKS device opened successfully") - - _ = stdinPipe.Close() - wWG.Wait() - - return wErr + return wErr + }) } func (fm *FileManager) Mount(devName string, mo MountOptions) error { @@ -199,6 +161,10 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error { return fmt.Errorf("bad device name") } + // We're intentionally not calling filepath.Clean() as + // this causes unintended consequences when run on Windows. + // (Windows Go standard library treats the path as it's for + // Windows, but we're targeting a Linux VM.) fullDevPath := "/dev/" + devName if mo.FSType == "" { @@ -221,42 +187,16 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error { fullDevPath = "/dev/mapper/" + luksDMName } - done := make(chan struct{}) - defer close(done) - - var timedOut bool - - go func() { - tm := func() { - timedOut = true - _ = sc.Close() - } - select { - case <-fm.vm.ctx.Done(): - tm() - case <-time.After(time.Second * 10): - tm() - case <-done: - } - }() - - checkTimeoutErr := func(err error) error { - if timedOut { - return fmt.Errorf("timed out (%w)", err) - } - - return err - } - - _, err = runSSHCmd(sc, "mount -t "+shellescape.Quote(mo.FSType)+" "+shellescape.Quote(fullDevPath)+" /mnt") + _, err = sshutil.RunSSHCmd(fm.vm.ctx, sc, "mount -t "+shellescape.Quote(mo.FSType)+" "+shellescape.Quote(fullDevPath)+" /mnt") if err != nil { - return checkTimeoutErr(errors.Wrap(err, "run mount cmd")) + return errors.Wrap(err, "run mount cmd") } return nil } -func (fm *FileManager) StartFTP(pwd []byte, passivePortStart uint16, passivePortCount uint16) error { +func (fm *FileManager) StartFTP(pwd string, passivePortStart uint16, passivePortCount uint16) error { + // This timeout is for the SCP client exclusively. scpCtx, scpCtxCancel := context.WithTimeout(fm.vm.ctx, time.Second*5) defer scpCtxCancel() @@ -292,74 +232,16 @@ pasv_address=127.0.0.1 return errors.Wrap(err, "dial ssh") } - done := make(chan struct{}) - defer close(done) - - var timedOut bool - - go func() { - tm := func() { - timedOut = true - _ = sc.Close() - } - select { - case <-fm.vm.ctx.Done(): - tm() - case <-time.After(time.Second * 15): - tm() - case <-done: - } - }() - - checkTimeoutErr := func(err error) error { - if timedOut { - return fmt.Errorf("timed out (%w)", err) - } - - return err - } - defer func() { _ = sc.Close() }() - _, err = runSSHCmd(sc, "rc-update add vsftpd && rc-service vsftpd start") + _, err = sshutil.RunSSHCmd(fm.vm.ctx, sc, "rc-update add vsftpd && rc-service vsftpd start") if err != nil { - return checkTimeoutErr(errors.Wrap(err, "add and start ftpd service")) + return errors.Wrap(err, "add and start ftpd service") } - sess, err := sc.NewSession() + err = sshutil.ChangeUnixPass(fm.vm.ctx, sc, "linsk", pwd) if err != nil { - return checkTimeoutErr(errors.Wrap(err, "create new ssh session")) - } - - pwd = append(pwd, '\n') - - stderr := bytes.NewBuffer(nil) - sess.Stderr = stderr - - stdinPipe, err := sess.StdinPipe() - if err != nil { - return checkTimeoutErr(errors.Wrap(err, "stdin pipe")) - } - - err = sess.Start("passwd linsk") - if err != nil { - return checkTimeoutErr(errors.Wrap(err, "start change user password cmd")) - } - - go func() { - _, err = stdinPipe.Write(pwd) - if err != nil { - fm.vm.logger.Error("Failed to write FTP password to passwd stdin", "error", err.Error()) - } - _, err = stdinPipe.Write(pwd) - if err != nil { - fm.vm.logger.Error("Failed to write repeated FTP password to passwd stdin", "error", err.Error()) - } - }() - - err = sess.Wait() - if err != nil { - return checkTimeoutErr(utils.WrapErrWithLog(err, "wait for change user password cmd", stderr.String())) + return errors.Wrap(err, "change unix pass") } return nil diff --git a/vm/os_specifics_windows.go b/vm/os_specifics_windows.go index 72ff773..0bef4d0 100644 --- a/vm/os_specifics_windows.go +++ b/vm/os_specifics_windows.go @@ -9,6 +9,7 @@ import ( ) func prepareVMCmd(cmd *exec.Cmd) { + // This is to prevent Ctrl+C propagating to the child process. cmd.SysProcAttr = &syscall.SysProcAttr{ CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP, } diff --git a/vm/ssh.go b/vm/ssh.go index 18c3173..bf66feb 100644 --- a/vm/ssh.go +++ b/vm/ssh.go @@ -2,8 +2,6 @@ package vm import ( "bytes" - "crypto/rand" - "crypto/rsa" "encoding/base64" "fmt" "net" @@ -11,6 +9,7 @@ import ( "strings" "time" + "github.com/AlexSSD7/linsk/sshutil" "github.com/AlexSSD7/linsk/utils" "github.com/alessio/shellescape" "github.com/pkg/errors" @@ -74,6 +73,8 @@ func (vm *VM) scanSSHIdentity() ([]byte, error) { continue } + // This isn't clean at all, but there is no better + // way to achieve an exit status check like this. prefix := []byte("SERIAL STATUS: ") if bytes.HasPrefix(data, prefix) { if len(data) == len(prefix) { @@ -95,7 +96,7 @@ func (vm *VM) scanSSHIdentity() ([]byte, error) { func (vm *VM) sshSetup() (ssh.Signer, error) { vm.resetSerialStdout() - sshSigner, sshPublicKey, err := generateSSHKey() + sshSigner, sshPublicKey, err := sshutil.GenerateSSHKey() if err != nil { return nil, errors.Wrap(err, "generate ssh key") } @@ -123,6 +124,8 @@ func (vm *VM) sshSetup() (ssh.Signer, error) { case <-time.After(time.Until(deadline)): return nil, fmt.Errorf("setup command timed out %v", utils.GetLogErrMsg(stdOutErrBuf.String(), "stdout/stderr log")) case data := <-vm.serialStdoutCh: + // This isn't clean at all, but there is no better + // way to achieve an exit status check like this. prefix := []byte("SERIAL STATUS: ") stdOutErrBuf.WriteString(utils.ClearUnprintableChars(string(data), true)) if bytes.HasPrefix(data, prefix) { @@ -131,7 +134,10 @@ func (vm *VM) sshSetup() (ssh.Signer, error) { } if data[len(prefix)] != '0' { + // A non-pretty yet effective debug print to assist with debugging + // in case something ever goes wrong. fmt.Fprintf(os.Stderr, "SSH SETUP FAILURE:\n%v", stdOutErrBuf.String()) + return nil, fmt.Errorf("non-zero setup command status code: '%v' %v", string(data[len(prefix)]), utils.GetLogErrMsg(stdOutErrBuf.String(), "stdout/stderr log")) } @@ -140,41 +146,3 @@ func (vm *VM) sshSetup() (ssh.Signer, error) { } } } - -func generateSSHKey() (ssh.Signer, []byte, error) { - privateKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return nil, nil, errors.Wrap(err, "generate rsa private key") - } - - signer, err := ssh.NewSignerFromKey(privateKey) - if err != nil { - return nil, nil, errors.Wrap(err, "create signer from key") - } - - return signer, ssh.MarshalAuthorizedKey(signer.PublicKey()), nil -} - -func runSSHCmd(c *ssh.Client, cmd string) ([]byte, error) { - sess, err := c.NewSession() - if err != nil { - return nil, errors.Wrap(err, "create new vm ssh session") - } - - // TODO: Timeouts - - defer func() { _ = sess.Close() }() - - stdout := bytes.NewBuffer(nil) - stderr := bytes.NewBuffer(nil) - - sess.Stdout = stdout - sess.Stderr = stderr - - err = sess.Run(cmd) - if err != nil { - return nil, utils.WrapErrWithLog(err, "run cmd", stderr.String()) - } - - return stdout.Bytes(), nil -} diff --git a/vm/types.go b/vm/types.go index c844f87..b9401b3 100644 --- a/vm/types.go +++ b/vm/types.go @@ -24,7 +24,7 @@ func ParsePortForwardString(s string) (PortForwardingRule, error) { split := strings.Split(s, ":") switch len(split) { case 2: - // : + // Format: : hostPort, err := strconv.ParseUint(split[0], 10, 16) if err != nil { return PortForwardingRule{}, errors.Wrap(err, "parse host port") @@ -40,7 +40,7 @@ func ParsePortForwardString(s string) (PortForwardingRule, error) { VMPort: uint16(vmPort), }, nil case 3: - // :: + // Format: :: hostIP := net.ParseIP(split[0]) if hostIP == nil { return PortForwardingRule{}, fmt.Errorf("bad host ip") diff --git a/vm/vm.go b/vm/vm.go index 48dff19..c019f17 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -17,6 +17,7 @@ import ( "log/slog" + "github.com/AlexSSD7/linsk/sshutil" "github.com/AlexSSD7/linsk/utils" "github.com/alessio/shellescape" "github.com/bramvdbogaerde/go-scp" @@ -91,8 +92,6 @@ func NewVM(logger *slog.Logger, cfg VMConfig) (*VM, error) { return nil, errors.Wrap(err, "get free port for ssh server") } - // TODO: Configurable memory allocation - cmdArgs := []string{"-serial", "stdio", "-m", fmt.Sprint(cfg.MemoryAlloc), "-smp", fmt.Sprint(runtime.NumCPU())} baseCmd := "qemu-system" @@ -373,7 +372,7 @@ func (vm *VM) Cancel() error { } } else { vm.logger.Warn("Sending poweroff command to the VM") - _, err = runSSHCmd(sc, "poweroff") + _, err = sshutil.RunSSHCmd(context.Background(), sc, "poweroff") _ = sc.Close() if err != nil { vm.logger.Warn("Could not power off the VM safely", "error", err.Error())