Polish code
This commit is contained in:
parent
8af3ab3fb6
commit
0870f8113a
10 changed files with 246 additions and 238 deletions
|
|
@ -38,7 +38,7 @@ func NewBuildContext(logger *slog.Logger, baseISOPath string, outPath string, sh
|
||||||
return nil, errors.Wrap(err, "stat output file")
|
return nil, errors.Wrap(err, "stat output file")
|
||||||
}
|
}
|
||||||
|
|
||||||
// File doesn't exist. Continuing.
|
// File doesn't exist. Going forward with creating a new .qcow2 image.
|
||||||
} else {
|
} else {
|
||||||
return nil, fmt.Errorf("output file already exists")
|
return nil, fmt.Errorf("output file already exists")
|
||||||
}
|
}
|
||||||
|
|
@ -183,8 +183,6 @@ func runAlpineSetupCmd(sc *ssh.Client, pkgs []string) error {
|
||||||
return errors.Wrap(err, "new session")
|
return errors.Wrap(err, "new session")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Timeout for this command.
|
|
||||||
|
|
||||||
stderr := bytes.NewBuffer(nil)
|
stderr := bytes.NewBuffer(nil)
|
||||||
sess.Stderr = stderr
|
sess.Stderr = stderr
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ var runCmd = &cobra.Command{
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
err = fm.StartFTP([]byte(sharePWD), networkSharePort+1, ftpPassivePortCount)
|
err = fm.StartFTP(sharePWD, networkSharePort+1, ftpPassivePortCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("Failed to start FTP server", "error", err.Error())
|
slog.Error("Failed to start FTP server", "error", err.Error())
|
||||||
return 1
|
return 1
|
||||||
|
|
|
||||||
97
sshutil/ssh.go
Normal file
97
sshutil/ssh.go
Normal file
|
|
@ -0,0 +1,97 @@
|
||||||
|
package sshutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AlexSSD7/linsk/utils"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GenerateSSHKey() (ssh.Signer, []byte, error) {
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "generate rsa private key")
|
||||||
|
}
|
||||||
|
|
||||||
|
signer, err := ssh.NewSignerFromKey(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, errors.Wrap(err, "create signer from key")
|
||||||
|
}
|
||||||
|
|
||||||
|
return signer, ssh.MarshalAuthorizedKey(signer.PublicKey()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RunSSHCmd(ctx context.Context, sc *ssh.Client, cmd string) ([]byte, error) {
|
||||||
|
var ret []byte
|
||||||
|
err := NewSSHSession(ctx, time.Second*15, sc, func(sess *ssh.Session) error {
|
||||||
|
stdout := bytes.NewBuffer(nil)
|
||||||
|
stderr := bytes.NewBuffer(nil)
|
||||||
|
|
||||||
|
sess.Stdout = stdout
|
||||||
|
sess.Stderr = stderr
|
||||||
|
|
||||||
|
err := sess.Run(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return utils.WrapErrWithLog(err, "run cmd", stderr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = stdout.Bytes()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return ret, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSSHSession(ctx context.Context, timeout time.Duration, sc *ssh.Client, fn func(*ssh.Session) error) error {
|
||||||
|
return NewSSHSessionWithDelayedTimeout(ctx, timeout, sc, func(sess *ssh.Session, startTimeout func()) error {
|
||||||
|
startTimeout()
|
||||||
|
return fn(sess)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSSHSessionWithDelayedTimeout(ctx context.Context, timeout time.Duration, sc *ssh.Client, fn func(sess *ssh.Session, startTimeout func()) error) error {
|
||||||
|
s, err := sc.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "create new ssh session")
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
var timedOut bool
|
||||||
|
|
||||||
|
// Start a thread to handle context cancelation.
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
timedOut = true
|
||||||
|
_ = sc.Close()
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = fn(s, func() {
|
||||||
|
// Now start a thread which will close the session
|
||||||
|
// down when the timeout hits.
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-time.After(timeout):
|
||||||
|
timedOut = true
|
||||||
|
_ = sc.Close()
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
})
|
||||||
|
if timedOut {
|
||||||
|
return fmt.Errorf("timed out (%w)", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
57
sshutil/unix.go
Normal file
57
sshutil/unix.go
Normal file
|
|
@ -0,0 +1,57 @@
|
||||||
|
package sshutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AlexSSD7/linsk/utils"
|
||||||
|
"github.com/alessio/shellescape"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"go.uber.org/multierr"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ChangeUnixPass(ctx context.Context, sc *ssh.Client, user string, pwd string) error {
|
||||||
|
if !utils.ValidateUnixUsername(user) {
|
||||||
|
return fmt.Errorf("invalid unix username")
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewSSHSession(ctx, time.Second*10, sc, func(sess *ssh.Session) error {
|
||||||
|
stderr := bytes.NewBuffer(nil)
|
||||||
|
sess.Stderr = stderr
|
||||||
|
|
||||||
|
stdinPipe, err := sess.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "stdin pipe")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = sess.Start("passwd " + shellescape.Quote(user))
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "start change user password cmd")
|
||||||
|
}
|
||||||
|
|
||||||
|
pwdBytes := []byte(pwd + "\n")
|
||||||
|
defer func() {
|
||||||
|
// Clearing the memory up for security.
|
||||||
|
for i := range pwdBytes {
|
||||||
|
pwdBytes[i] = 0
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// Writing the password. We're doing this two times
|
||||||
|
// as we need to confirm the password.
|
||||||
|
_, _ = stdinPipe.Write(pwdBytes)
|
||||||
|
_, _ = stdinPipe.Write(pwdBytes)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = sess.Wait()
|
||||||
|
if err != nil {
|
||||||
|
return multierr.Combine(utils.WrapErrWithLog(err, "wait for change user password cmd", stderr.String()))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -21,7 +21,7 @@ func ClearUnprintableChars(s string, allowNewlines bool) string {
|
||||||
}, s)
|
}, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
var devNameRegexp = regexp.MustCompile("^[0-9a-z_-]+$")
|
var devNameRegexp = regexp.MustCompile(`^[0-9a-z_-]+$`)
|
||||||
|
|
||||||
func ValidateDevName(s string) bool {
|
func ValidateDevName(s string) bool {
|
||||||
// Allow mapped devices.
|
// Allow mapped devices.
|
||||||
|
|
@ -30,6 +30,12 @@ func ValidateDevName(s string) bool {
|
||||||
return devNameRegexp.MatchString(s)
|
return devNameRegexp.MatchString(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var unixUsernameRegexp = regexp.MustCompile(`^[a-z_]([a-z0-9_-]{0,31}|[a-z0-9_-]{0,30}\$)$`)
|
||||||
|
|
||||||
|
func ValidateUnixUsername(s string) bool {
|
||||||
|
return unixUsernameRegexp.MatchString(s)
|
||||||
|
}
|
||||||
|
|
||||||
func Uint16ToBytesBE(v uint16) []byte {
|
func Uint16ToBytesBE(v uint16) []byte {
|
||||||
b := make([]byte, 2)
|
b := make([]byte, 2)
|
||||||
binary.BigEndian.PutUint16(b, v)
|
binary.BigEndian.PutUint16(b, v)
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AlexSSD7/linsk/sshutil"
|
||||||
"github.com/AlexSSD7/linsk/utils"
|
"github.com/AlexSSD7/linsk/utils"
|
||||||
"github.com/alessio/shellescape"
|
"github.com/alessio/shellescape"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
@ -42,7 +43,7 @@ func (fm *FileManager) Init() error {
|
||||||
|
|
||||||
defer func() { _ = sc.Close() }()
|
defer func() { _ = sc.Close() }()
|
||||||
|
|
||||||
_, err = runSSHCmd(sc, "vgchange -ay")
|
_, err = sshutil.RunSSHCmd(fm.vm.ctx, sc, "vgchange -ay")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "run vgchange cmd")
|
return errors.Wrap(err, "run vgchange cmd")
|
||||||
}
|
}
|
||||||
|
|
@ -56,25 +57,12 @@ func (fm *FileManager) Lsblk() ([]byte, error) {
|
||||||
return nil, errors.Wrap(err, "dial vm ssh")
|
return nil, errors.Wrap(err, "dial vm ssh")
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() { _ = sc.Close() }()
|
ret, err := sshutil.RunSSHCmd(fm.vm.ctx, sc, "lsblk -o NAME,SIZE,FSTYPE,LABEL -e 7,11,2,253")
|
||||||
|
|
||||||
sess, err := sc.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "create new vm ssh session")
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() { _ = sess.Close() }()
|
|
||||||
|
|
||||||
ret := new(bytes.Buffer)
|
|
||||||
|
|
||||||
sess.Stdout = ret
|
|
||||||
|
|
||||||
err = sess.Run("lsblk -o NAME,SIZE,FSTYPE,LABEL -e 7,11,2,253")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "run lsblk")
|
return nil, errors.Wrap(err, "run lsblk")
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret.Bytes(), nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type MountOptions struct {
|
type MountOptions struct {
|
||||||
|
|
@ -87,11 +75,7 @@ const luksDMName = "cryptmnt"
|
||||||
func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error {
|
func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error {
|
||||||
lg := fm.logger.With("vm-path", fullDevPath)
|
lg := fm.logger.With("vm-path", fullDevPath)
|
||||||
|
|
||||||
sess, err := sc.NewSession()
|
return sshutil.NewSSHSessionWithDelayedTimeout(fm.vm.ctx, time.Second*15, sc, func(sess *ssh.Session, startTimeout func()) error {
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "create new vm ssh session")
|
|
||||||
}
|
|
||||||
|
|
||||||
stdinPipe, err := sess.StdinPipe()
|
stdinPipe, err := sess.StdinPipe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "create vm ssh session stdin pipe")
|
return errors.Wrap(err, "create vm ssh session stdin pipe")
|
||||||
|
|
@ -119,6 +103,10 @@ func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error {
|
||||||
|
|
||||||
fmt.Print("\n")
|
fmt.Print("\n")
|
||||||
|
|
||||||
|
// We start the timeout countdown now only to avoid timing out
|
||||||
|
// while the user is entering the password, or shortly after that.
|
||||||
|
startTimeout()
|
||||||
|
|
||||||
var wErr error
|
var wErr error
|
||||||
var wWG sync.WaitGroup
|
var wWG sync.WaitGroup
|
||||||
|
|
||||||
|
|
@ -132,52 +120,25 @@ func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
// Clear the memory up.
|
// Clear the memory up for security.
|
||||||
{
|
{
|
||||||
for i := 0; i < len(pwd); i++ {
|
for i := 0; i < len(pwd); i++ {
|
||||||
pwd[i] = 0
|
pwd[i] = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < 4; i++ {
|
// This is my paranoia.
|
||||||
|
_, _ = rand.Read(pwd)
|
||||||
_, _ = rand.Read(pwd)
|
_, _ = rand.Read(pwd)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
defer close(done)
|
|
||||||
|
|
||||||
var timedOut bool
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
tm := func() {
|
|
||||||
timedOut = true
|
|
||||||
_ = sc.Close()
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-fm.vm.ctx.Done():
|
|
||||||
tm()
|
|
||||||
case <-time.After(time.Second * 1):
|
|
||||||
tm()
|
|
||||||
case <-done:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
checkTimeoutErr := func(err error) error {
|
|
||||||
if timedOut {
|
|
||||||
return fmt.Errorf("timed out (%w)", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = sess.Wait()
|
err = sess.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(stderrBuf.String(), "Not enough available memory to open a keyslot.") {
|
if strings.Contains(stderrBuf.String(), "Not enough available memory to open a keyslot.") {
|
||||||
fm.logger.Warn("Detected not enough memory to open a LUKS device, please allocate more memory using --vm-mem-alloc flag.")
|
fm.logger.Warn("Detected not enough memory to open a LUKS device, please allocate more memory using --vm-mem-alloc flag.")
|
||||||
}
|
}
|
||||||
|
|
||||||
return checkTimeoutErr(utils.WrapErrWithLog(err, "wait for cryptsetup luksopen cmd to finish", stderrBuf.String()))
|
return utils.WrapErrWithLog(err, "wait for cryptsetup luksopen cmd to finish", stderrBuf.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
lg.Info("LUKS device opened successfully")
|
lg.Info("LUKS device opened successfully")
|
||||||
|
|
@ -186,6 +147,7 @@ func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error {
|
||||||
wWG.Wait()
|
wWG.Wait()
|
||||||
|
|
||||||
return wErr
|
return wErr
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fm *FileManager) Mount(devName string, mo MountOptions) error {
|
func (fm *FileManager) Mount(devName string, mo MountOptions) error {
|
||||||
|
|
@ -199,6 +161,10 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error {
|
||||||
return fmt.Errorf("bad device name")
|
return fmt.Errorf("bad device name")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We're intentionally not calling filepath.Clean() as
|
||||||
|
// this causes unintended consequences when run on Windows.
|
||||||
|
// (Windows Go standard library treats the path as it's for
|
||||||
|
// Windows, but we're targeting a Linux VM.)
|
||||||
fullDevPath := "/dev/" + devName
|
fullDevPath := "/dev/" + devName
|
||||||
|
|
||||||
if mo.FSType == "" {
|
if mo.FSType == "" {
|
||||||
|
|
@ -221,42 +187,16 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error {
|
||||||
fullDevPath = "/dev/mapper/" + luksDMName
|
fullDevPath = "/dev/mapper/" + luksDMName
|
||||||
}
|
}
|
||||||
|
|
||||||
done := make(chan struct{})
|
_, err = sshutil.RunSSHCmd(fm.vm.ctx, sc, "mount -t "+shellescape.Quote(mo.FSType)+" "+shellescape.Quote(fullDevPath)+" /mnt")
|
||||||
defer close(done)
|
|
||||||
|
|
||||||
var timedOut bool
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
tm := func() {
|
|
||||||
timedOut = true
|
|
||||||
_ = sc.Close()
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-fm.vm.ctx.Done():
|
|
||||||
tm()
|
|
||||||
case <-time.After(time.Second * 10):
|
|
||||||
tm()
|
|
||||||
case <-done:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
checkTimeoutErr := func(err error) error {
|
|
||||||
if timedOut {
|
|
||||||
return fmt.Errorf("timed out (%w)", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = runSSHCmd(sc, "mount -t "+shellescape.Quote(mo.FSType)+" "+shellescape.Quote(fullDevPath)+" /mnt")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return checkTimeoutErr(errors.Wrap(err, "run mount cmd"))
|
return errors.Wrap(err, "run mount cmd")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fm *FileManager) StartFTP(pwd []byte, passivePortStart uint16, passivePortCount uint16) error {
|
func (fm *FileManager) StartFTP(pwd string, passivePortStart uint16, passivePortCount uint16) error {
|
||||||
|
// This timeout is for the SCP client exclusively.
|
||||||
scpCtx, scpCtxCancel := context.WithTimeout(fm.vm.ctx, time.Second*5)
|
scpCtx, scpCtxCancel := context.WithTimeout(fm.vm.ctx, time.Second*5)
|
||||||
defer scpCtxCancel()
|
defer scpCtxCancel()
|
||||||
|
|
||||||
|
|
@ -292,74 +232,16 @@ pasv_address=127.0.0.1
|
||||||
return errors.Wrap(err, "dial ssh")
|
return errors.Wrap(err, "dial ssh")
|
||||||
}
|
}
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
defer close(done)
|
|
||||||
|
|
||||||
var timedOut bool
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
tm := func() {
|
|
||||||
timedOut = true
|
|
||||||
_ = sc.Close()
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-fm.vm.ctx.Done():
|
|
||||||
tm()
|
|
||||||
case <-time.After(time.Second * 15):
|
|
||||||
tm()
|
|
||||||
case <-done:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
checkTimeoutErr := func(err error) error {
|
|
||||||
if timedOut {
|
|
||||||
return fmt.Errorf("timed out (%w)", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() { _ = sc.Close() }()
|
defer func() { _ = sc.Close() }()
|
||||||
|
|
||||||
_, err = runSSHCmd(sc, "rc-update add vsftpd && rc-service vsftpd start")
|
_, err = sshutil.RunSSHCmd(fm.vm.ctx, sc, "rc-update add vsftpd && rc-service vsftpd start")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return checkTimeoutErr(errors.Wrap(err, "add and start ftpd service"))
|
return errors.Wrap(err, "add and start ftpd service")
|
||||||
}
|
}
|
||||||
|
|
||||||
sess, err := sc.NewSession()
|
err = sshutil.ChangeUnixPass(fm.vm.ctx, sc, "linsk", pwd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return checkTimeoutErr(errors.Wrap(err, "create new ssh session"))
|
return errors.Wrap(err, "change unix pass")
|
||||||
}
|
|
||||||
|
|
||||||
pwd = append(pwd, '\n')
|
|
||||||
|
|
||||||
stderr := bytes.NewBuffer(nil)
|
|
||||||
sess.Stderr = stderr
|
|
||||||
|
|
||||||
stdinPipe, err := sess.StdinPipe()
|
|
||||||
if err != nil {
|
|
||||||
return checkTimeoutErr(errors.Wrap(err, "stdin pipe"))
|
|
||||||
}
|
|
||||||
|
|
||||||
err = sess.Start("passwd linsk")
|
|
||||||
if err != nil {
|
|
||||||
return checkTimeoutErr(errors.Wrap(err, "start change user password cmd"))
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_, err = stdinPipe.Write(pwd)
|
|
||||||
if err != nil {
|
|
||||||
fm.vm.logger.Error("Failed to write FTP password to passwd stdin", "error", err.Error())
|
|
||||||
}
|
|
||||||
_, err = stdinPipe.Write(pwd)
|
|
||||||
if err != nil {
|
|
||||||
fm.vm.logger.Error("Failed to write repeated FTP password to passwd stdin", "error", err.Error())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = sess.Wait()
|
|
||||||
if err != nil {
|
|
||||||
return checkTimeoutErr(utils.WrapErrWithLog(err, "wait for change user password cmd", stderr.String()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func prepareVMCmd(cmd *exec.Cmd) {
|
func prepareVMCmd(cmd *exec.Cmd) {
|
||||||
|
// This is to prevent Ctrl+C propagating to the child process.
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
|
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
50
vm/ssh.go
50
vm/ssh.go
|
|
@ -2,8 +2,6 @@ package vm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
|
@ -11,6 +9,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AlexSSD7/linsk/sshutil"
|
||||||
"github.com/AlexSSD7/linsk/utils"
|
"github.com/AlexSSD7/linsk/utils"
|
||||||
"github.com/alessio/shellescape"
|
"github.com/alessio/shellescape"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
|
@ -74,6 +73,8 @@ func (vm *VM) scanSSHIdentity() ([]byte, error) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This isn't clean at all, but there is no better
|
||||||
|
// way to achieve an exit status check like this.
|
||||||
prefix := []byte("SERIAL STATUS: ")
|
prefix := []byte("SERIAL STATUS: ")
|
||||||
if bytes.HasPrefix(data, prefix) {
|
if bytes.HasPrefix(data, prefix) {
|
||||||
if len(data) == len(prefix) {
|
if len(data) == len(prefix) {
|
||||||
|
|
@ -95,7 +96,7 @@ func (vm *VM) scanSSHIdentity() ([]byte, error) {
|
||||||
func (vm *VM) sshSetup() (ssh.Signer, error) {
|
func (vm *VM) sshSetup() (ssh.Signer, error) {
|
||||||
vm.resetSerialStdout()
|
vm.resetSerialStdout()
|
||||||
|
|
||||||
sshSigner, sshPublicKey, err := generateSSHKey()
|
sshSigner, sshPublicKey, err := sshutil.GenerateSSHKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "generate ssh key")
|
return nil, errors.Wrap(err, "generate ssh key")
|
||||||
}
|
}
|
||||||
|
|
@ -123,6 +124,8 @@ func (vm *VM) sshSetup() (ssh.Signer, error) {
|
||||||
case <-time.After(time.Until(deadline)):
|
case <-time.After(time.Until(deadline)):
|
||||||
return nil, fmt.Errorf("setup command timed out %v", utils.GetLogErrMsg(stdOutErrBuf.String(), "stdout/stderr log"))
|
return nil, fmt.Errorf("setup command timed out %v", utils.GetLogErrMsg(stdOutErrBuf.String(), "stdout/stderr log"))
|
||||||
case data := <-vm.serialStdoutCh:
|
case data := <-vm.serialStdoutCh:
|
||||||
|
// This isn't clean at all, but there is no better
|
||||||
|
// way to achieve an exit status check like this.
|
||||||
prefix := []byte("SERIAL STATUS: ")
|
prefix := []byte("SERIAL STATUS: ")
|
||||||
stdOutErrBuf.WriteString(utils.ClearUnprintableChars(string(data), true))
|
stdOutErrBuf.WriteString(utils.ClearUnprintableChars(string(data), true))
|
||||||
if bytes.HasPrefix(data, prefix) {
|
if bytes.HasPrefix(data, prefix) {
|
||||||
|
|
@ -131,7 +134,10 @@ func (vm *VM) sshSetup() (ssh.Signer, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if data[len(prefix)] != '0' {
|
if data[len(prefix)] != '0' {
|
||||||
|
// A non-pretty yet effective debug print to assist with debugging
|
||||||
|
// in case something ever goes wrong.
|
||||||
fmt.Fprintf(os.Stderr, "SSH SETUP FAILURE:\n%v", stdOutErrBuf.String())
|
fmt.Fprintf(os.Stderr, "SSH SETUP FAILURE:\n%v", stdOutErrBuf.String())
|
||||||
|
|
||||||
return nil, fmt.Errorf("non-zero setup command status code: '%v' %v", string(data[len(prefix)]), utils.GetLogErrMsg(stdOutErrBuf.String(), "stdout/stderr log"))
|
return nil, fmt.Errorf("non-zero setup command status code: '%v' %v", string(data[len(prefix)]), utils.GetLogErrMsg(stdOutErrBuf.String(), "stdout/stderr log"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -140,41 +146,3 @@ func (vm *VM) sshSetup() (ssh.Signer, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateSSHKey() (ssh.Signer, []byte, error) {
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, errors.Wrap(err, "generate rsa private key")
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, err := ssh.NewSignerFromKey(privateKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, errors.Wrap(err, "create signer from key")
|
|
||||||
}
|
|
||||||
|
|
||||||
return signer, ssh.MarshalAuthorizedKey(signer.PublicKey()), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func runSSHCmd(c *ssh.Client, cmd string) ([]byte, error) {
|
|
||||||
sess, err := c.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "create new vm ssh session")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Timeouts
|
|
||||||
|
|
||||||
defer func() { _ = sess.Close() }()
|
|
||||||
|
|
||||||
stdout := bytes.NewBuffer(nil)
|
|
||||||
stderr := bytes.NewBuffer(nil)
|
|
||||||
|
|
||||||
sess.Stdout = stdout
|
|
||||||
sess.Stderr = stderr
|
|
||||||
|
|
||||||
err = sess.Run(cmd)
|
|
||||||
if err != nil {
|
|
||||||
return nil, utils.WrapErrWithLog(err, "run cmd", stderr.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
return stdout.Bytes(), nil
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ func ParsePortForwardString(s string) (PortForwardingRule, error) {
|
||||||
split := strings.Split(s, ":")
|
split := strings.Split(s, ":")
|
||||||
switch len(split) {
|
switch len(split) {
|
||||||
case 2:
|
case 2:
|
||||||
// <HOST PORT>:<VM PORT>
|
// Format: <HOST PORT>:<VM PORT>
|
||||||
hostPort, err := strconv.ParseUint(split[0], 10, 16)
|
hostPort, err := strconv.ParseUint(split[0], 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return PortForwardingRule{}, errors.Wrap(err, "parse host port")
|
return PortForwardingRule{}, errors.Wrap(err, "parse host port")
|
||||||
|
|
@ -40,7 +40,7 @@ func ParsePortForwardString(s string) (PortForwardingRule, error) {
|
||||||
VMPort: uint16(vmPort),
|
VMPort: uint16(vmPort),
|
||||||
}, nil
|
}, nil
|
||||||
case 3:
|
case 3:
|
||||||
// <HOST IP>:<HOST PORT>:<VM PORT>
|
// Format: <HOST IP>:<HOST PORT>:<VM PORT>
|
||||||
hostIP := net.ParseIP(split[0])
|
hostIP := net.ParseIP(split[0])
|
||||||
if hostIP == nil {
|
if hostIP == nil {
|
||||||
return PortForwardingRule{}, fmt.Errorf("bad host ip")
|
return PortForwardingRule{}, fmt.Errorf("bad host ip")
|
||||||
|
|
|
||||||
5
vm/vm.go
5
vm/vm.go
|
|
@ -17,6 +17,7 @@ import (
|
||||||
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
|
||||||
|
"github.com/AlexSSD7/linsk/sshutil"
|
||||||
"github.com/AlexSSD7/linsk/utils"
|
"github.com/AlexSSD7/linsk/utils"
|
||||||
"github.com/alessio/shellescape"
|
"github.com/alessio/shellescape"
|
||||||
"github.com/bramvdbogaerde/go-scp"
|
"github.com/bramvdbogaerde/go-scp"
|
||||||
|
|
@ -91,8 +92,6 @@ func NewVM(logger *slog.Logger, cfg VMConfig) (*VM, error) {
|
||||||
return nil, errors.Wrap(err, "get free port for ssh server")
|
return nil, errors.Wrap(err, "get free port for ssh server")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Configurable memory allocation
|
|
||||||
|
|
||||||
cmdArgs := []string{"-serial", "stdio", "-m", fmt.Sprint(cfg.MemoryAlloc), "-smp", fmt.Sprint(runtime.NumCPU())}
|
cmdArgs := []string{"-serial", "stdio", "-m", fmt.Sprint(cfg.MemoryAlloc), "-smp", fmt.Sprint(runtime.NumCPU())}
|
||||||
|
|
||||||
baseCmd := "qemu-system"
|
baseCmd := "qemu-system"
|
||||||
|
|
@ -373,7 +372,7 @@ func (vm *VM) Cancel() error {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
vm.logger.Warn("Sending poweroff command to the VM")
|
vm.logger.Warn("Sending poweroff command to the VM")
|
||||||
_, err = runSSHCmd(sc, "poweroff")
|
_, err = sshutil.RunSSHCmd(context.Background(), sc, "poweroff")
|
||||||
_ = sc.Close()
|
_ = sc.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
vm.logger.Warn("Could not power off the VM safely", "error", err.Error())
|
vm.logger.Warn("Could not power off the VM safely", "error", err.Error())
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue