More tiny fixes/enhancements

This commit is contained in:
AlexSSD7 2023-08-29 13:29:46 +01:00
commit 8af3ab3fb6
4 changed files with 114 additions and 35 deletions

View file

@ -2,6 +2,7 @@ package vm
import (
"bytes"
"context"
"crypto/rand"
"fmt"
"log/slog"
@ -9,6 +10,7 @@ import (
"strings"
"sync"
"syscall"
"time"
"github.com/AlexSSD7/linsk/utils"
"github.com/alessio/shellescape"
@ -90,8 +92,6 @@ func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error {
return errors.Wrap(err, "create new vm ssh session")
}
defer func() { sess.Close() }()
stdinPipe, err := sess.StdinPipe()
if err != nil {
return errors.Wrap(err, "create vm ssh session stdin pipe")
@ -131,7 +131,45 @@ func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error {
wErr = errors.Wrap(multierr.Combine(err, err2), "write password to stdin")
}()
// TODO: Timeout for this command
defer func() {
// Clear the memory up.
{
for i := 0; i < len(pwd); i++ {
pwd[i] = 0
}
for i := 0; i < 4; i++ {
_, _ = rand.Read(pwd)
}
}
}()
done := make(chan struct{})
defer close(done)
var timedOut bool
go func() {
tm := func() {
timedOut = true
_ = sc.Close()
}
select {
case <-fm.vm.ctx.Done():
tm()
case <-time.After(time.Second * 1):
tm()
case <-done:
}
}()
checkTimeoutErr := func(err error) error {
if timedOut {
return fmt.Errorf("timed out (%w)", err)
}
return err
}
err = sess.Wait()
if err != nil {
@ -139,22 +177,11 @@ func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error {
fm.logger.Warn("Detected not enough memory to open a LUKS device, please allocate more memory using --vm-mem-alloc flag.")
}
return utils.WrapErrWithLog(err, "wait for cryptsetup luksopen cmd to finish", stderrBuf.String())
return checkTimeoutErr(utils.WrapErrWithLog(err, "wait for cryptsetup luksopen cmd to finish", stderrBuf.String()))
}
lg.Info("LUKS device opened successfully")
// Clear the memory up
{
for i := 0; i < len(pwd); i++ {
pwd[i] = 0
}
for i := 0; i < 16; i++ {
_, _ = rand.Read(pwd)
}
}
_ = stdinPipe.Close()
wWG.Wait()
@ -194,15 +221,45 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error {
fullDevPath = "/dev/mapper/" + luksDMName
}
done := make(chan struct{})
defer close(done)
var timedOut bool
go func() {
tm := func() {
timedOut = true
_ = sc.Close()
}
select {
case <-fm.vm.ctx.Done():
tm()
case <-time.After(time.Second * 10):
tm()
case <-done:
}
}()
checkTimeoutErr := func(err error) error {
if timedOut {
return fmt.Errorf("timed out (%w)", err)
}
return err
}
_, err = runSSHCmd(sc, "mount -t "+shellescape.Quote(mo.FSType)+" "+shellescape.Quote(fullDevPath)+" /mnt")
if err != nil {
return errors.Wrap(err, "run mount cmd")
return checkTimeoutErr(errors.Wrap(err, "run mount cmd"))
}
return nil
}
func (fm *FileManager) StartFTP(pwd []byte, passivePortStart uint16, passivePortCount uint16) error {
scpCtx, scpCtxCancel := context.WithTimeout(fm.vm.ctx, time.Second*5)
defer scpCtxCancel()
scpClient, err := fm.vm.DialSCP()
if err != nil {
return errors.Wrap(err, "dial scp")
@ -223,7 +280,7 @@ pasv_max_port=` + fmt.Sprint(passivePortStart+passivePortCount) + `
pasv_address=127.0.0.1
`
err = scpClient.CopyFile(fm.vm.ctx, strings.NewReader(ftpdCfg), "/etc/vsftpd/vsftpd.conf", "0400")
err = scpClient.CopyFile(scpCtx, strings.NewReader(ftpdCfg), "/etc/vsftpd/vsftpd.conf", "0400")
if err != nil {
return errors.Wrap(err, "copy ftpd .conf file")
}
@ -235,16 +292,43 @@ pasv_address=127.0.0.1
return errors.Wrap(err, "dial ssh")
}
done := make(chan struct{})
defer close(done)
var timedOut bool
go func() {
tm := func() {
timedOut = true
_ = sc.Close()
}
select {
case <-fm.vm.ctx.Done():
tm()
case <-time.After(time.Second * 15):
tm()
case <-done:
}
}()
checkTimeoutErr := func(err error) error {
if timedOut {
return fmt.Errorf("timed out (%w)", err)
}
return err
}
defer func() { _ = sc.Close() }()
_, err = runSSHCmd(sc, "rc-update add vsftpd && rc-service vsftpd start")
if err != nil {
return errors.Wrap(err, "add and start ftpd service")
return checkTimeoutErr(errors.Wrap(err, "add and start ftpd service"))
}
sess, err := sc.NewSession()
if err != nil {
return errors.Wrap(err, "create new ssh session")
return checkTimeoutErr(errors.Wrap(err, "create new ssh session"))
}
pwd = append(pwd, '\n')
@ -254,14 +338,12 @@ pasv_address=127.0.0.1
stdinPipe, err := sess.StdinPipe()
if err != nil {
return errors.Wrap(err, "stdin pipe")
return checkTimeoutErr(errors.Wrap(err, "stdin pipe"))
}
// TODO: Timeout for this command
err = sess.Start("passwd linsk")
if err != nil {
return errors.Wrap(err, "start change user password cmd")
return checkTimeoutErr(errors.Wrap(err, "start change user password cmd"))
}
go func() {
@ -277,7 +359,7 @@ pasv_address=127.0.0.1
err = sess.Wait()
if err != nil {
return utils.WrapErrWithLog(err, "wait for change user password cmd", stderr.String())
return checkTimeoutErr(utils.WrapErrWithLog(err, "wait for change user password cmd", stderr.String()))
}
return nil