Implement shell --forward-ports

This commit is contained in:
AlexSSD7 2023-08-26 11:57:12 +01:00
commit 7789923672
5 changed files with 105 additions and 10 deletions

View file

@ -108,3 +108,9 @@ var shellCmd = &cobra.Command{
return nil return nil
}, },
} }
var forwardPortsFlagStr string
func init() {
shellCmd.Flags().StringVar(&forwardPortsFlagStr, "forward-ports", "", "Extra TCP port forwarding rules. Syntax: '<HOST PORT>:<VM PORT>' OR '<HOST BIND IP>:<HOST PORT>:<VM PORT>'. Multiple rules split by comma are accepted.")
}

View file

@ -5,6 +5,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"os/user" "os/user"
"strings"
"sync" "sync"
"syscall" "syscall"
@ -44,8 +45,20 @@ func runVM(passthroughArg string, fn func(context.Context, *vm.Instance, *vm.Fil
passthroughConfig = []vm.USBDevicePassthroughConfig{getDevicePassthroughConfig(passthroughArg)} passthroughConfig = []vm.USBDevicePassthroughConfig{getDevicePassthroughConfig(passthroughArg)}
} }
var forwardPortsConfig []vm.PortForwardingConfig
for i, fp := range strings.Split(forwardPortsFlagStr, ",") {
fpc, err := vm.ParsePortForwardString(fp)
if err != nil {
slog.Error("Failed to parse port forward string", "index", i, "value", fp, "error", err)
os.Exit(1)
}
forwardPortsConfig = append(forwardPortsConfig, fpc)
}
// TODO: Alpine image should be downloaded from somewhere. // TODO: Alpine image should be downloaded from somewhere.
vi, err := vm.NewInstance(slog.Default().With("caller", "vm"), "alpine-img/alpine.qcow2", passthroughConfig, vmDebugFlag) vi, err := vm.NewInstance(slog.Default().With("caller", "vm"), "alpine-img/alpine.qcow2", passthroughConfig, vmDebugFlag, forwardPortsConfig)
if err != nil { if err != nil {
slog.Error("Failed to create vm instance", "error", err) slog.Error("Failed to create vm instance", "error", err)
os.Exit(1) os.Exit(1)

View file

@ -7,6 +7,7 @@ import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net" "net"
"os"
"strings" "strings"
"time" "time"
@ -98,7 +99,7 @@ func (vi *Instance) sshSetup() (ssh.Signer, error) {
return nil, errors.Wrap(err, "generate ssh key") return nil, errors.Wrap(err, "generate ssh key")
} }
cmd := `set -ex; do_setup () { sh -c "set -ex; ifconfig eth0 up; ifconfig lo up; udhcpc; mkdir -p ~/.ssh; echo ` + shellescape.Quote(string(sshPublicKey)) + ` > ~/.ssh/authorized_keys; rc-update add sshd; service sshd start"; echo "SERIAL STATUS: $?"; }; do_setup` + "\n" cmd := `set -ex; do_setup () { sh -c "set -ex; ifconfig eth0 up; ifconfig lo up; udhcpc; mkdir -p ~/.ssh; echo ` + shellescape.Quote(string(sshPublicKey)) + ` > ~/.ssh/authorized_keys; rc-update add sshd; rc-service sshd start"; echo "SERIAL"" ""STATUS: $?"; }; do_setup` + "\n"
err = vi.writeSerial([]byte(cmd)) err = vi.writeSerial([]byte(cmd))
if err != nil { if err != nil {
@ -124,6 +125,7 @@ func (vi *Instance) sshSetup() (ssh.Signer, error) {
} }
if data[len(prefix)] != '0' { if data[len(prefix)] != '0' {
fmt.Fprintf(os.Stderr, "SSH SETUP FAILURE:\n%v", stdOutErrBuf.String())
return nil, fmt.Errorf("non-zero setup command status code: '%v' %v", string(data[len(prefix)]), getLogErrMsg(stdOutErrBuf.String())) return nil, fmt.Errorf("non-zero setup command status code: '%v' %v", string(data[len(prefix)]), getLogErrMsg(stdOutErrBuf.String()))
} }

67
vm/types.go Normal file
View file

@ -0,0 +1,67 @@
package vm
import (
"fmt"
"net"
"strconv"
"strings"
"github.com/pkg/errors"
)
type USBDevicePassthroughConfig struct {
HostBus uint8
HostPort uint8
}
type PortForwardingConfig struct {
HostIP net.IP
HostPort uint16
VMPort uint16
}
func ParsePortForwardString(s string) (PortForwardingConfig, error) {
split := strings.Split(s, ":")
switch len(split) {
case 2:
// <HOST PORT>:<VM PORT>
hostPort, err := strconv.ParseUint(split[0], 10, 16)
if err != nil {
return PortForwardingConfig{}, errors.Wrap(err, "parse host port")
}
vmPort, err := strconv.ParseUint(split[1], 10, 16)
if err != nil {
return PortForwardingConfig{}, errors.Wrap(err, "parse vm port")
}
return PortForwardingConfig{
HostPort: uint16(hostPort),
VMPort: uint16(vmPort),
}, nil
case 3:
// <HOST IP>:<HOST PORT>:<VM PORT>
hostIP := net.ParseIP(split[0])
if hostIP == nil {
return PortForwardingConfig{}, fmt.Errorf("bad host ip")
}
hostPort, err := strconv.ParseUint(split[1], 10, 16)
if err != nil {
return PortForwardingConfig{}, errors.Wrap(err, "parse host port")
}
vmPort, err := strconv.ParseUint(split[2], 10, 16)
if err != nil {
return PortForwardingConfig{}, errors.Wrap(err, "parse vm port")
}
return PortForwardingConfig{
HostIP: hostIP,
HostPort: uint16(hostPort),
VMPort: uint16(vmPort),
}, nil
default:
return PortForwardingConfig{}, fmt.Errorf("bad split by ':' length: want 2 or 3, have %v", len(split))
}
}

View file

@ -25,11 +25,6 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type USBDevicePassthroughConfig struct {
HostBus uint8
HostPort uint8
}
type Instance struct { type Instance struct {
logger *slog.Logger logger *slog.Logger
@ -55,7 +50,7 @@ type Instance struct {
canceled uint32 canceled uint32
} }
func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDevicePassthroughConfig, debug bool) (*Instance, error) { func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDevicePassthroughConfig, debug bool, extraPortForwardings []PortForwardingConfig) (*Instance, error) {
alpineImagePath = filepath.Clean(alpineImagePath) alpineImagePath = filepath.Clean(alpineImagePath)
_, err := os.Stat(alpineImagePath) _, err := os.Stat(alpineImagePath)
if err != nil { if err != nil {
@ -72,8 +67,20 @@ func NewInstance(logger *slog.Logger, alpineImagePath string, usbDevices []USBDe
// TODO: Configurable memory allocation // TODO: Configurable memory allocation
baseCmd := "qemu-system-x86_64" baseCmd := "qemu-system-x86_64"
cmdArgs := []string{"-serial", "stdio", "-enable-kvm", "-m", "2048", "-smp", fmt.Sprint(runtime.NumCPU()), cmdArgs := []string{"-serial", "stdio", "-enable-kvm", "-m", "2048", "-smp", fmt.Sprint(runtime.NumCPU())}
"-device", "e1000,netdev=net0", "-netdev", "user,id=net0,hostfwd=tcp::" + fmt.Sprint(sshPort) + "-:22"}
netdevOpts := "user,id=net0,hostfwd=tcp:127.0.0.1:" + fmt.Sprint(sshPort) + "-:22"
for _, pf := range extraPortForwardings {
hostIPStr := ""
if pf.HostIP != nil {
hostIPStr = pf.HostIP.String()
}
netdevOpts += ",hostfwd=tcp:" + hostIPStr + ":" + fmt.Sprint(pf.HostPort) + "-:" + fmt.Sprint(pf.VMPort)
}
cmdArgs = append(cmdArgs, "-device", "e1000,netdev=net0", "-netdev", netdevOpts)
cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(alpineImagePath)+",format=qcow2,if=virtio", "-snapshot") cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(alpineImagePath)+",format=qcow2,if=virtio", "-snapshot")