diff --git a/cmd/run.go b/cmd/run.go index cfa254d..7a084e4 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -135,7 +135,7 @@ func init() { defaultShareType = "ftp" } - runCmd.Flags().StringVar(&shareBackendFlag, "share-backend", defaultShareType, "Specifies the file share backend to use. The default value is OS-specific.") + runCmd.Flags().StringVar(&shareBackendFlag, "share-backend", defaultShareType, `Specifies the file share backend to use. The default value is OS-specific. (available "smb", "afp", "ftp")`) runCmd.Flags().StringVar(&shareListenIPFlag, "share-listen", share.GetDefaultListenIPStr(), "Specifies the IP to bind the network share port to. NOTE: For FTP, changing the bind address is not enough to connect remotely. You should also specify --ftp-extip.") runCmd.Flags().StringVar(&ftpExtIPFlag, "ftp-extip", share.GetDefaultListenIPStr(), "Specifies the external IP the FTP server should advertise.") diff --git a/cmd/utils.go b/cmd/utils.go index a0cf5f8..7ebfcfe 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "os" "path/filepath" @@ -43,7 +44,7 @@ func runVM(passthroughArg string, fn runvm.RunVMFunc, forwardPortsRules []vm.Por return 1 } - biosPath, err := store.CheckDownloadVMBIOS() + biosPath, err := store.CheckDownloadVMBIOS(context.Background()) if err != nil { slog.Error("Failed to check/download VM BIOS", "error", err.Error()) return 1 @@ -171,7 +172,7 @@ func runVM(passthroughArg string, fn runvm.RunVMFunc, forwardPortsRules []vm.Por }} } - vmCfg := vm.VMConfig{ + vmCfg := vm.Config{ Drives: []vm.DriveConfig{{ Path: vmImagePath, SnapshotMode: true, diff --git a/imgbuilder/imgbuilder.go b/imgbuilder/imgbuilder.go index 842b41f..4b3a1b1 100644 --- a/imgbuilder/imgbuilder.go +++ b/imgbuilder/imgbuilder.go @@ -47,7 +47,7 @@ func NewBuildContext(logger *slog.Logger, baseISOPath string, outPath string, sh return nil, errors.Wrap(err, "create temporary qemu image") } - vi, err := vm.NewVM(logger.With("subcaller", "vm"), vm.VMConfig{ + vi, err := vm.NewVM(logger.With("subcaller", "vm"), vm.Config{ CdromImagePath: baseISOPath, BIOSPath: biosPath, Drives: []vm.DriveConfig{{ diff --git a/nettap/alloc.go b/nettap/alloc.go index 711ac48..d995658 100644 --- a/nettap/alloc.go +++ b/nettap/alloc.go @@ -7,12 +7,12 @@ import ( "github.com/shirou/gopsutil/process" ) -type NetTapAlloc struct { +type Alloc struct { TapName string PID int } -func (a *NetTapAlloc) Validate() error { +func (a *Alloc) Validate() error { err := ValidateTapName(a.TapName) if err != nil { return errors.Wrap(err, "validate tap name") @@ -22,7 +22,7 @@ func (a *NetTapAlloc) Validate() error { return fmt.Errorf("pid is zero") } - if a.PID > int(a.PID) { + if a.PID > int(int32(a.PID)) { return fmt.Errorf("pid int32 overflow (%v)", a.PID) } @@ -31,7 +31,7 @@ func (a *NetTapAlloc) Validate() error { // The taps removed slice always returns the taps removed, even after // an error has occurred sometime while deleting non-first interfaces. -func (tm *TapManager) PruneTaps(knownAllocs []NetTapAlloc) ([]string, error) { +func (tm *TapManager) PruneTaps(knownAllocs []Alloc) ([]string, error) { var tapsRemoved []string for i, alloc := range knownAllocs { diff --git a/sshutil/pass.go b/sshutil/pass.go index 28e9bd7..613ab27 100644 --- a/sshutil/pass.go +++ b/sshutil/pass.go @@ -56,6 +56,8 @@ func genericChangePass(ctx context.Context, sc *ssh.Client, user string, pwd str }) } +type ChangePassFunc func(ctx context.Context, sc *ssh.Client, user string, pwd string) error + func ChangeUnixPass(ctx context.Context, sc *ssh.Client, user string, pwd string) error { return genericChangePass(ctx, sc, user, pwd, "passwd") } diff --git a/storage/download.go b/storage/download.go index 31f1bd9..6a5326f 100644 --- a/storage/download.go +++ b/storage/download.go @@ -2,6 +2,7 @@ package storage import ( "bytes" + "context" "crypto/sha256" "encoding/hex" "fmt" @@ -15,7 +16,7 @@ import ( "github.com/pkg/errors" ) -func (s *Storage) download(url string, hash []byte, out string, applyReaderMiddleware func(io.Reader) io.Reader) error { +func (s *Storage) download(ctx context.Context, url string, hash []byte, out string, applyReaderMiddleware func(io.Reader) io.Reader) error { var created, success bool defer func() { @@ -40,9 +41,14 @@ func (s *Storage) download(url string, hash []byte, out string, applyReaderMiddl defer func() { _ = f.Close() }() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return errors.Wrap(err, "create new http get request") + } + s.logger.Info("Starting to download file", "from", url, "to", out) - resp, err := http.Get(url) + resp, err := http.DefaultClient.Do(req) if err != nil { return errors.Wrap(err, "http get") } diff --git a/storage/nettap.go b/storage/nettap.go index 9adae79..a819b80 100644 --- a/storage/nettap.go +++ b/storage/nettap.go @@ -55,13 +55,13 @@ func (s *Storage) ReleaseNetTapAllocation(tapName string) error { return nil } -func (s *Storage) ListNetTapAllocations() ([]nettap.NetTapAlloc, error) { +func (s *Storage) ListNetTapAllocations() ([]nettap.Alloc, error) { dirEntries, err := os.ReadDir(s.path) if err != nil { return nil, errors.Wrap(err, "read data dir") } - var ret []nettap.NetTapAlloc + var ret []nettap.Alloc for _, entry := range dirEntries { if strings.HasPrefix(entry.Name(), tapAllocPrefix) { @@ -84,7 +84,7 @@ func (s *Storage) ListNetTapAllocations() ([]nettap.NetTapAlloc, error) { return nil, errors.Wrapf(err, "parse pid (alloc file '%v')", entryPath) } - ret = append(ret, nettap.NetTapAlloc{ + ret = append(ret, nettap.Alloc{ TapName: tapName, PID: int(pid), }) diff --git a/storage/storage.go b/storage/storage.go index 824a71f..8f0b235 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -2,6 +2,7 @@ package storage import ( "compress/bzip2" + "context" "fmt" "log/slog" "os" @@ -34,7 +35,7 @@ func NewStorage(logger *slog.Logger, dataDir string) (*Storage, error) { }, nil } -func (s *Storage) CheckDownloadBaseImage() (string, error) { +func (s *Storage) CheckDownloadBaseImage(ctx context.Context) (string, error) { baseImagePath := filepath.Join(s.path, constants.GetAlpineBaseImageFileName()) _, err := os.Stat(baseImagePath) if err != nil { @@ -43,7 +44,7 @@ func (s *Storage) CheckDownloadBaseImage() (string, error) { } // Image doesn't exist. Download one. - err := s.download(constants.GetAlpineBaseImageURL(), constants.GetAlpineBaseImageHash(), baseImagePath, nil) + err := s.download(ctx, constants.GetAlpineBaseImageURL(), constants.GetAlpineBaseImageHash(), baseImagePath, nil) if err != nil { return "", errors.Wrap(err, "download base alpine image") } @@ -76,13 +77,16 @@ func (s *Storage) RunCLIImageBuild(showBuilderVMDisplay bool, overwrite bool) in return 1 } - baseImagePath, err := s.CheckDownloadBaseImage() + // We're using context.Background() everywhere because this is intended + // to be executed as a blocking CLI command. + + baseImagePath, err := s.CheckDownloadBaseImage(context.Background()) if err != nil { slog.Error("Failed to check or download base VM image", "error", err.Error()) return 1 } - biosPath, err := s.CheckDownloadVMBIOS() + biosPath, err := s.CheckDownloadVMBIOS(context.Background()) if err != nil { slog.Error("Failed to check or download VM BIOS", "error", err.Error()) return 1 @@ -131,9 +135,9 @@ func (s *Storage) DataDirPath() string { return s.path } -func (s *Storage) CheckDownloadVMBIOS() (string, error) { +func (s *Storage) CheckDownloadVMBIOS(ctx context.Context) (string, error) { if runtime.GOARCH == "arm64" { - p, err := s.CheckDownloadAarch64EFIImage() + p, err := s.CheckDownloadAarch64EFIImage(ctx) if err != nil { return "", errors.Wrap(err, "check/download aarch64 efi image") } @@ -146,7 +150,7 @@ func (s *Storage) CheckDownloadVMBIOS() (string, error) { return "", nil } -func (s *Storage) CheckDownloadAarch64EFIImage() (string, error) { +func (s *Storage) CheckDownloadAarch64EFIImage(ctx context.Context) (string, error) { efiImagePath := s.GetAarch64EFIImagePath() _, err := os.Stat(efiImagePath) if err != nil { @@ -155,7 +159,7 @@ func (s *Storage) CheckDownloadAarch64EFIImage() (string, error) { } // EFI image doesn't exist. Download one. - err := s.download(constants.GetAarch64EFIImageBZ2URL(), constants.GetAarch64EFIImageHash(), efiImagePath, bzip2.NewReader) + err := s.download(ctx, constants.GetAarch64EFIImageBZ2URL(), constants.GetAarch64EFIImageHash(), efiImagePath, bzip2.NewReader) if err != nil { return "", errors.Wrap(err, "download base alpine image") } diff --git a/vm/cfg.go b/vm/cfg.go index 4d3ff3c..59b5dd7 100644 --- a/vm/cfg.go +++ b/vm/cfg.go @@ -37,7 +37,7 @@ func cleanQEMUPath(s string) string { return path } -func configureBaseVMCmd(logger *slog.Logger, cfg VMConfig) (string, []qemucli.Arg, error) { +func configureBaseVMCmd(logger *slog.Logger, cfg Config) (string, []qemucli.Arg, error) { baseCmd := "qemu-system" if osspecifics.IsWindows() { @@ -180,7 +180,7 @@ func configureVMCmdTapNetwork(tapName string) ([]qemucli.Arg, error) { return []qemucli.Arg{netdevArg, deviceArg}, nil } -func configureVMCmdNetworking(logger *slog.Logger, cfg VMConfig, sshPort uint16) ([]qemucli.Arg, error) { +func configureVMCmdNetworking(logger *slog.Logger, cfg Config, sshPort uint16) ([]qemucli.Arg, error) { // SSH port config. ports := []PortForwardingRule{{ HostIP: net.ParseIP("127.0.0.1"), @@ -211,7 +211,7 @@ func configureVMCmdNetworking(logger *slog.Logger, cfg VMConfig, sshPort uint16) return args, nil } -func configureVMCmdDrives(cfg VMConfig) ([]qemucli.Arg, error) { +func configureVMCmdDrives(cfg Config) ([]qemucli.Arg, error) { var args []qemucli.Arg for i, drive := range cfg.Drives { @@ -265,7 +265,7 @@ func configureVMCmdDrives(cfg VMConfig) ([]qemucli.Arg, error) { return args, nil } -func configureVMCmdUSBPassthrough(cfg VMConfig) []qemucli.Arg { +func configureVMCmdUSBPassthrough(cfg Config) []qemucli.Arg { var args []qemucli.Arg if len(cfg.PassthroughConfig.USB) != 0 { @@ -283,7 +283,7 @@ func configureVMCmdUSBPassthrough(cfg VMConfig) []qemucli.Arg { return args } -func configureVMCmdBlockDevicePassthrough(logger *slog.Logger, cfg VMConfig) ([]qemucli.Arg, error) { +func configureVMCmdBlockDevicePassthrough(logger *slog.Logger, cfg Config) ([]qemucli.Arg, error) { var args []qemucli.Arg if len(cfg.PassthroughConfig.Block) != 0 { diff --git a/vm/filemanager.go b/vm/filemanager.go index 16f53c2..fd50dbe 100644 --- a/vm/filemanager.go +++ b/vm/filemanager.go @@ -97,7 +97,7 @@ func (fm *FileManager) luksOpen(sc *ssh.Client, fullDevPath string) error { return errors.Wrap(err, "write prompt to stderr") } - pwd, err := term.ReadPassword(int(syscall.Stdin)) + pwd, err := term.ReadPassword(syscall.Stdin) if err != nil { return errors.Wrap(err, "read luks password") } @@ -199,17 +199,6 @@ func (fm *FileManager) Mount(devName string, mo MountOptions) error { } func (fm *FileManager) StartFTP(pwd string, passivePortStart uint16, passivePortCount uint16, extIP net.IP) error { - // This timeout is for the SCP client exclusively. - scpCtx, scpCtxCancel := context.WithTimeout(fm.vm.ctx, time.Second*5) - defer scpCtxCancel() - - scpClient, err := fm.vm.DialSCP() - if err != nil { - return errors.Wrap(err, "dial scp") - } - - defer scpClient.Close() - ftpdCfg := `anonymous_enable=NO local_enable=YES write_enable=YES @@ -223,45 +212,10 @@ pasv_max_port=` + fmt.Sprint(passivePortStart+passivePortCount) + ` pasv_address=` + extIP.String() + ` ` - err = scpClient.CopyFile(scpCtx, strings.NewReader(ftpdCfg), "/etc/vsftpd/vsftpd.conf", "0400") - if err != nil { - return errors.Wrap(err, "copy ftpd .conf file") - } - - scpClient.Close() - - sc, err := fm.vm.DialSSH() - if err != nil { - return errors.Wrap(err, "dial ssh") - } - - defer func() { _ = sc.Close() }() - - _, err = sshutil.RunSSHCmd(fm.vm.ctx, sc, "rc-update add vsftpd && rc-service vsftpd start") - if err != nil { - return errors.Wrap(err, "add and start ftpd service") - } - - err = sshutil.ChangeUnixPass(fm.vm.ctx, sc, "linsk", pwd) - if err != nil { - return errors.Wrap(err, "change unix pass") - } - - return nil + return fm.startGenericShare(pwd, ftpdCfg, "/etc/vsftpd/vsftpd.conf", "vsftpd", sshutil.ChangeUnixPass) } func (fm *FileManager) StartSMB(pwd string) error { - // This timeout is for the SCP client exclusively. - scpCtx, scpCtxCancel := context.WithTimeout(fm.vm.ctx, time.Second*5) - defer scpCtxCancel() - - scpClient, err := fm.vm.DialSCP() - if err != nil { - return errors.Wrap(err, "dial scp") - } - - defer scpClient.Close() - sambaCfg := `[global] workgroup = WORKGROUP dos charset = cp866 @@ -286,46 +240,10 @@ force user = linsk force group = linsk create mask = 0664 ` - - err = scpClient.CopyFile(scpCtx, strings.NewReader(sambaCfg), "/etc/samba/smb.conf", "0400") - if err != nil { - return errors.Wrap(err, "copy samba config file") - } - - scpClient.Close() - - sc, err := fm.vm.DialSSH() - if err != nil { - return errors.Wrap(err, "dial ssh") - } - - defer func() { _ = sc.Close() }() - - _, err = sshutil.RunSSHCmd(fm.vm.ctx, sc, "rc-update add samba && rc-service samba start") - if err != nil { - return errors.Wrap(err, "add and start samba service") - } - - err = sshutil.ChangeSambaPass(fm.vm.ctx, sc, "linsk", pwd) - if err != nil { - return errors.Wrap(err, "change samba pass") - } - - return nil + return fm.startGenericShare(pwd, sambaCfg, "/etc/samba/smb.conf", "samba", sshutil.ChangeSambaPass) } func (fm *FileManager) StartAFP(pwd string) error { - // This timeout is for the SCP client exclusively. - scpCtx, scpCtxCancel := context.WithTimeout(fm.vm.ctx, time.Second*5) - defer scpCtxCancel() - - scpClient, err := fm.vm.DialSCP() - if err != nil { - return errors.Wrap(err, "dial scp") - } - - defer scpClient.Close() - afpCfg := `[Global] [linsk] @@ -337,9 +255,24 @@ force user = linsk force group = linsk ` - err = scpClient.CopyFile(scpCtx, strings.NewReader(afpCfg), "/etc/afp.conf", "0400") + return fm.startGenericShare(pwd, afpCfg, "/etc/afp.conf", "netatalk", sshutil.ChangeUnixPass) +} + +func (fm *FileManager) startGenericShare(pwd string, cfg string, cfgPath string, rcServiceName string, changePassFunc sshutil.ChangePassFunc) error { + // This timeout is for the SCP client exclusively. + scpCtx, scpCtxCancel := context.WithTimeout(fm.vm.ctx, time.Second*5) + defer scpCtxCancel() + + scpClient, err := fm.vm.DialSCP() if err != nil { - return errors.Wrap(err, "copy netatalk config file") + return errors.Wrap(err, "dial scp") + } + + defer scpClient.Close() + + err = scpClient.CopyFile(scpCtx, strings.NewReader(cfg), cfgPath, "0400") + if err != nil { + return errors.Wrap(err, "copy config file") } scpClient.Close() @@ -351,14 +284,14 @@ force group = linsk defer func() { _ = sc.Close() }() - _, err = sshutil.RunSSHCmd(fm.vm.ctx, sc, "rc-update add netatalk && rc-service netatalk start") + _, err = sshutil.RunSSHCmd(fm.vm.ctx, sc, "rc-update add "+shellescape.Quote(rcServiceName)+" && rc-service "+shellescape.Quote(rcServiceName)+" start") if err != nil { - return errors.Wrap(err, "add and start netatalk service") + return errors.Wrap(err, "add and start rc service") } - err = sshutil.ChangeUnixPass(fm.vm.ctx, sc, "linsk", pwd) + err = changePassFunc(fm.vm.ctx, sc, "linsk", pwd) if err != nil { - return errors.Wrap(err, "change unix pass") + return errors.Wrap(err, "change pass") } return nil diff --git a/vm/vm.go b/vm/vm.go index e15822c..a7e1a7e 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -52,7 +52,7 @@ type VM struct { disposed uint32 canceled uint32 - originalCfg VMConfig + originalCfg Config } type DriveConfig struct { @@ -64,7 +64,7 @@ type TapConfig struct { Name string } -type VMConfig struct { +type Config struct { CdromImagePath string BIOSPath string Drives []DriveConfig @@ -87,7 +87,7 @@ type VMConfig struct { InstallBaseUtilities bool } -func NewVM(logger *slog.Logger, cfg VMConfig) (*VM, error) { +func NewVM(logger *slog.Logger, cfg Config) (*VM, error) { sshPort, err := freeport.GetFreePort() if err != nil { return nil, errors.Wrap(err, "get free port for ssh server")