From 0f74f4aa8190848388b139abc0ae015144f02a37 Mon Sep 17 00:00:00 2001 From: AlexSSD7 Date: Wed, 30 Aug 2023 09:19:02 +0100 Subject: [PATCH] Working VM image retrieval --- cmd/root.go | 19 +++++++++++++++++++ cmd/utils.go | 15 ++++++++++++++- storage/storage.go | 29 ++++++++++++++++++++++++----- vm/ssh.go | 15 +++++++++------ 4 files changed, 66 insertions(+), 12 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 6bc5c84..d20958d 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -2,6 +2,8 @@ package cmd import ( "os" + "path/filepath" + "runtime" "log/slog" @@ -29,6 +31,7 @@ var unrestrictedNetworkingFlag bool var vmMemAllocFlag uint32 var vmSSHSetupTimeoutFlag uint32 var vmOSUpTimeoutFlag uint32 +var dataDirFlag string // TODO: Version command. @@ -44,4 +47,20 @@ func init() { rootCmd.PersistentFlags().Uint32Var(&vmMemAllocFlag, "vm-mem-alloc", 512, "Specifies the VM memory allocation in KiB") rootCmd.PersistentFlags().Uint32Var(&vmOSUpTimeoutFlag, "vm-os-up-timeout", 30, "Specifies the VM OS-up timeout in seconds.") rootCmd.PersistentFlags().Uint32Var(&vmSSHSetupTimeoutFlag, "vm-ssh-setup-timeout", 60, "Specifies the VM SSH server setup timeout in seconds. This cannot be lower than the OS-up timeout.") + + defaultDataDir := "linsk-data-dir" + + homeDir, err := os.UserHomeDir() + if err != nil { + slog.Error("Failed to get user home directory, will use a local directory as a fallback", "error", err.Error(), "dir", defaultDataDir) + } else { + homeDirName := ".linsk" + if runtime.GOOS == "windows" { + homeDirName = "Linsk" + } + + defaultDataDir = filepath.Join(homeDir, homeDirName) + } + + rootCmd.PersistentFlags().StringVar(&dataDirFlag, "data-dir", defaultDataDir, "Specifies the data directory (folder) to use. The VM images will be stored here.") } diff --git a/cmd/utils.go b/cmd/utils.go index 9617dbd..ed51906 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -14,6 +14,7 @@ import ( "log/slog" + "github.com/AlexSSD7/linsk/storage" "github.com/AlexSSD7/linsk/vm" "github.com/pkg/errors" ) @@ -51,6 +52,18 @@ func doUSBRootCheck() { } func runVM(passthroughArg string, fn func(context.Context, *vm.VM, *vm.FileManager) int, forwardPortsRules []vm.PortForwardingRule, unrestrictedNetworking bool) int { + store, err := storage.NewStorage(slog.With("caller", "storage"), dataDirFlag) + if err != nil { + slog.Error("Failed to create Linsk data storage", "error", err.Error(), "data-dir", dataDirFlag) + os.Exit(1) + } + + _, err = store.ValidateImageHashOrDownload() + if err != nil { + slog.Error("Failed to validate image hash or download image", "error", err.Error()) + os.Exit(1) + } + var passthroughConfig vm.PassthroughConfig if passthroughArg != "" { @@ -60,7 +73,7 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.VM, *vm.FileManag vmCfg := vm.VMConfig{ Drives: []vm.DriveConfig{{ - Path: "alpine.qcow2", + Path: store.GetLocalImagePath(), SnapshotMode: true, }}, diff --git a/storage/storage.go b/storage/storage.go index 4a9c7d0..5b53842 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -18,7 +18,7 @@ import ( const imageURL = "http://localhost:8000/linsk-base.qcow2" -var imageHash = [64]byte{96, 134, 26, 122, 43, 140, 212, 78, 44, 123, 103, 209, 21, 36, 81, 152, 9, 177, 47, 114, 225, 117, 64, 198, 50, 151, 71, 100, 1, 92, 106, 24, 224, 254, 157, 125, 188, 118, 84, 200, 47, 11, 215, 252, 100, 173, 64, 202, 132, 110, 15, 240, 234, 223, 56, 125, 94, 94, 179, 39, 193, 215, 41, 109} +var imageHash = [64]byte{70, 23, 243, 131, 146, 197, 41, 223, 67, 223, 41, 243, 128, 147, 82, 238, 34, 24, 123, 246, 251, 117, 120, 72, 72, 64, 96, 146, 227, 199, 49, 169, 164, 33, 205, 217, 98, 255, 109, 18, 130, 203, 126, 83, 34, 4, 229, 108, 173, 22, 107, 37, 181, 17, 84, 13, 129, 110, 25, 126, 158, 50, 135, 9} type Storage struct { logger *slog.Logger @@ -41,12 +41,12 @@ func NewStorage(logger *slog.Logger, dataDir string) (*Storage, error) { }, nil } -func (s *Storage) getLocalImagePath() string { +func (s *Storage) GetLocalImagePath() string { return filepath.Join(s.path, hex.EncodeToString(imageHash[:])+".qcow2") } func (s *Storage) DownloadImage() error { - localImagePath := s.getLocalImagePath() + localImagePath := s.GetLocalImagePath() var created, success bool @@ -77,6 +77,8 @@ func (s *Storage) DownloadImage() error { defer func() { _ = f.Close() }() + s.logger.Info("Starting to download the VM image", "path", localImagePath) + resp, err := http.Get(imageURL) if err != nil { return errors.Wrap(err, "http get image") @@ -85,7 +87,7 @@ func (s *Storage) DownloadImage() error { defer func() { _ = resp.Body.Close() }() _, err = copyWithProgress(f, resp.Body, 1024, resp.ContentLength, func(i int, f float64) { - s.logger.Info("Downloading image", "url", imageURL, "percent", math.Round(f*100*100)/100, "content-length", humanize.Bytes(uint64(resp.ContentLength))) + s.logger.Info("Downloading the VM image", "url", imageURL, "percent", math.Round(f*100*100)/100, "content-length", humanize.Bytes(uint64(resp.ContentLength))) }) if err != nil { return errors.Wrap(err, "copy resp to file") @@ -96,13 +98,15 @@ func (s *Storage) DownloadImage() error { return errors.Wrap(err, "validate image hash") } + s.logger.Info("Successfully downloaded the VM image", "dst", localImagePath) + success = true return nil } func (s *Storage) ValidateImageHash() error { - localImagePath := s.getLocalImagePath() + localImagePath := s.GetLocalImagePath() f, err := os.OpenFile(localImagePath, os.O_RDONLY, 0400) if err != nil { @@ -133,9 +137,24 @@ func (s *Storage) ValidateImageHash() error { return fmt.Errorf("hash mismatch: want '%v', have '%v'", hex.EncodeToString(imageHash[:]), hex.EncodeToString(sum)) } + s.logger.Info("Validated the VM image hash", "path", localImagePath) + return nil } +func (s *Storage) ValidateImageHashOrDownload() (bool, error) { + err := s.ValidateImageHash() + if err == nil { + return false, nil + } + + if errors.Is(err, os.ErrNotExist) { + return true, errors.Wrap(s.DownloadImage(), "download image") + } + + return false, err +} + func copyWithProgress(dst io.Writer, src io.Reader, blockSize int, length int64, report func(int, float64)) (int, error) { block := make([]byte, blockSize) diff --git a/vm/ssh.go b/vm/ssh.go index edb7a3f..d1f78a9 100644 --- a/vm/ssh.go +++ b/vm/ssh.go @@ -53,26 +53,29 @@ func ParseSSHKeyScan(knownHosts []byte) (ssh.HostKeyCallback, error) { func (vm *VM) scanSSHIdentity() ([]byte, error) { vm.resetSerialStdout() - err := vm.writeSerial([]byte(`ssh-keyscan -H localhost; echo "SERIAL STATUS: $?"; rm /root/.ash_history` + "\n")) + err := vm.writeSerial([]byte(`ssh-keyscan -H 127.0.0.1 && echo "SERIAL STATUS: $?" && rm /root/.ash_history` + "\n")) if err != nil { return nil, errors.Wrap(err, "write keyscan command to serial") } deadline := time.Now().Add(time.Second * 5) - var ret bytes.Buffer + stdOutErrBuf := bytes.NewBuffer(nil) + ret := bytes.NewBuffer(nil) for { select { case <-vm.ctx.Done(): return nil, vm.ctx.Err() case <-time.After(time.Until(deadline)): - return nil, fmt.Errorf("keyscan command timed out") + return nil, fmt.Errorf("keyscan command timed out %v", utils.GetLogErrMsg(stdOutErrBuf.String(), "stdout/stderr log")) case data := <-vm.serialStdoutCh: if len(data) == 0 { continue } + stdOutErrBuf.Write(data) + // This isn't clean at all, but there is no better // way to achieve an exit status check like this. prefix := []byte("SERIAL STATUS: ") @@ -82,7 +85,7 @@ func (vm *VM) scanSSHIdentity() ([]byte, error) { } if data[len(prefix)] != '0' { - return nil, fmt.Errorf("non-zero keyscan command status code: '%v'", string(data[len(prefix)])) + return nil, fmt.Errorf("non-zero keyscan command status code: '%v' %v", string(data[len(prefix)]), utils.GetLogErrMsg(stdOutErrBuf.String(), "stdout/stderr log")) } return ret.Bytes(), nil @@ -103,10 +106,10 @@ func (vm *VM) sshSetup() (ssh.Signer, error) { installSSHDCmd := "" if vm.installSSH { - installSSHDCmd = "ifconfig eth0 up && ifconfig lo up && udhcpc; apk add openssh; " + installSSHDCmd = "apk add openssh; " } - cmd := `do_setup () { sh -c "set -ex; ` + installSSHDCmd + `mkdir -p ~/.ssh; echo ` + shellescape.Quote(string(sshPublicKey)) + ` > ~/.ssh/authorized_keys; rc-update add sshd; rc-service sshd start"; echo "SERIAL"" ""STATUS: $?"; }; do_setup` + "\n" + cmd := `do_setup () { sh -c "set -ex; ifconfig eth0 up && ifconfig lo up && udhcpc; ` + installSSHDCmd + `mkdir -p ~/.ssh; echo ` + shellescape.Quote(string(sshPublicKey)) + ` > ~/.ssh/authorized_keys; rc-update add sshd; rc-service sshd start"; echo "SERIAL"" ""STATUS: $?"; }; do_setup` + "\n" err = vm.writeSerial([]byte(cmd)) if err != nil {