Initial commit
This commit is contained in:
commit
b905244626
17 changed files with 1462 additions and 0 deletions
31
vm/errors.go
Normal file
31
vm/errors.go
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
package vm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/AlexSSD7/vldisk/utils"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSSHUnavailable = errors.New("ssh unavailable")
|
||||
)
|
||||
|
||||
func wrapErrWithLog(err error, msg, log string) error {
|
||||
return errors.Wrapf(err, "%v %v", msg, getLogErrMsg(log))
|
||||
}
|
||||
|
||||
func getLogErrMsg(s string) string {
|
||||
logToInclude := strings.ReplaceAll(s, "\n", "\\n")
|
||||
logToInclude = strings.TrimSuffix(logToInclude, "\\n")
|
||||
logToInclude = utils.ClearUnprintableChars(logToInclude)
|
||||
|
||||
origLogLen := len(logToInclude)
|
||||
const maxLogLen = 256
|
||||
if origLogLen > maxLogLen {
|
||||
logToInclude = fmt.Sprintf("[%v chars trimmed]", origLogLen) + logToInclude[len(logToInclude)-maxLogLen:]
|
||||
}
|
||||
|
||||
return fmt.Sprintf("(log: '%v')", logToInclude)
|
||||
}
|
||||
59
vm/filemanager.go
Normal file
59
vm/filemanager.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
package vm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type FileManager struct {
|
||||
vi *Instance
|
||||
}
|
||||
|
||||
func NewFileManager(vi *Instance) *FileManager {
|
||||
return &FileManager{
|
||||
vi: vi,
|
||||
}
|
||||
}
|
||||
|
||||
func (fm *FileManager) Init() error {
|
||||
c, err := fm.vi.DialSSH()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "dial vm ssh")
|
||||
}
|
||||
|
||||
_, err = runSSHCmd(c, "apk add util-linux lvm2")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "install utilities")
|
||||
}
|
||||
|
||||
_, err = runSSHCmd(c, "vgchange -ay")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "run vgchange cmd")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fm *FileManager) Lsblk() ([]byte, error) {
|
||||
c, err := fm.vi.DialSSH()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "dial vm ssh")
|
||||
}
|
||||
|
||||
sess, err := c.NewSession()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "create new vm ssh session")
|
||||
}
|
||||
|
||||
ret := new(bytes.Buffer)
|
||||
|
||||
sess.Stdout = ret
|
||||
|
||||
err = sess.Run("lsblk -o NAME,SIZE,FSTYPE,LABEL -e 7,11,2,253")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "run lsblk")
|
||||
}
|
||||
|
||||
return ret.Bytes(), nil
|
||||
}
|
||||
168
vm/ssh.go
Normal file
168
vm/ssh.go
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
package vm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func ParseSSHKeyScan(knownHosts []byte) (ssh.HostKeyCallback, error) {
|
||||
knownKeysMap := make(map[string][]byte)
|
||||
for _, line := range strings.Split(string(knownHosts), "\n") {
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
lineSplit := strings.Split(line, " ")
|
||||
if want, have := 3, len(lineSplit); want != have {
|
||||
return nil, fmt.Errorf("bad split ssh identity string length: want %v, have %v ('%v')", want, have, line)
|
||||
}
|
||||
|
||||
b, err := base64.StdEncoding.DecodeString(lineSplit[2])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "decode base64 public key")
|
||||
}
|
||||
|
||||
knownKeysMap[lineSplit[1]] = b
|
||||
}
|
||||
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
knownKey, ok := knownKeysMap[key.Type()]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown key type '%v'", key.Type())
|
||||
}
|
||||
|
||||
if !bytes.Equal(key.Marshal(), knownKey) {
|
||||
return fmt.Errorf("public key mismatch")
|
||||
}
|
||||
|
||||
return nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (vi *Instance) scanSSHIdentity() ([]byte, error) {
|
||||
vi.resetSerialStdout()
|
||||
|
||||
err := vi.writeSerial([]byte(`ssh-keyscan -H localhost; echo "SERIAL STATUS: $?"` + "\n"))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "write keyscan command to serial")
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(time.Second * 5)
|
||||
|
||||
var ret bytes.Buffer
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-vi.ctx.Done():
|
||||
return nil, vi.ctx.Err()
|
||||
case <-time.After(time.Until(deadline)):
|
||||
return nil, fmt.Errorf("keyscan command timed out")
|
||||
case data := <-vi.serialStdoutCh:
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
prefix := []byte("SERIAL STATUS: ")
|
||||
if bytes.HasPrefix(data, prefix) {
|
||||
if len(data) == len(prefix) {
|
||||
return nil, fmt.Errorf("keyscan command status code did not show up")
|
||||
}
|
||||
|
||||
if data[len(prefix)] != '0' {
|
||||
return nil, fmt.Errorf("non-zero keyscan command status code: '%v'", string(data[len(prefix)]))
|
||||
}
|
||||
|
||||
return ret.Bytes(), nil
|
||||
} else if data[0] == '|' {
|
||||
ret.Write(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (vi *Instance) sshSetup() (ssh.Signer, error) {
|
||||
vi.resetSerialStdout()
|
||||
|
||||
sshSigner, sshPublicKey, err := generateSSHKey()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "generate ssh key")
|
||||
}
|
||||
|
||||
cmd := `set -ex; do_setup () { sh -c "set -ex; ifconfig eth0 up; ifconfig lo up; udhcpc; mkdir -p ~/.ssh; echo ` + shellescape.Quote(string(sshPublicKey)) + ` > ~/.ssh/authorized_keys; rc-update add sshd; service sshd start"; echo "SERIAL STATUS: $?"; }; do_setup` + "\n"
|
||||
|
||||
err = vi.writeSerial([]byte(cmd))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "write ssh setup serial command")
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(time.Second * 5)
|
||||
|
||||
stdOutErrBuf := bytes.NewBuffer(nil)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-vi.ctx.Done():
|
||||
return nil, vi.ctx.Err()
|
||||
case <-time.After(time.Until(deadline)):
|
||||
return nil, fmt.Errorf("setup command timed out %v", getLogErrMsg(stdOutErrBuf.String()))
|
||||
case data := <-vi.serialStdoutCh:
|
||||
prefix := []byte("SERIAL STATUS: ")
|
||||
stdOutErrBuf.Write(data)
|
||||
if bytes.HasPrefix(data, prefix) {
|
||||
if len(data) == len(prefix) {
|
||||
return nil, fmt.Errorf("setup command status code did not show up")
|
||||
}
|
||||
|
||||
if data[len(prefix)] != '0' {
|
||||
return nil, fmt.Errorf("non-zero setup command status code: '%v' %v", string(data[len(prefix)]), getLogErrMsg(stdOutErrBuf.String()))
|
||||
}
|
||||
|
||||
return sshSigner, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func generateSSHKey() (ssh.Signer, []byte, error) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "generate rsa private key")
|
||||
}
|
||||
|
||||
signer, err := ssh.NewSignerFromKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "create signer from key")
|
||||
}
|
||||
|
||||
return signer, ssh.MarshalAuthorizedKey(signer.PublicKey()), nil
|
||||
}
|
||||
|
||||
func runSSHCmd(c *ssh.Client, cmd string) ([]byte, error) {
|
||||
sess, err := c.NewSession()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "create new vm ssh session")
|
||||
}
|
||||
|
||||
stdout := bytes.NewBuffer(nil)
|
||||
stderr := bytes.NewBuffer(nil)
|
||||
|
||||
sess.Stdout = stdout
|
||||
sess.Stderr = stderr
|
||||
|
||||
err = sess.Run(cmd)
|
||||
if err != nil {
|
||||
return nil, wrapErrWithLog(err, "run cmd", stderr.String())
|
||||
}
|
||||
|
||||
return stdout.Bytes(), nil
|
||||
}
|
||||
291
vm/vm.go
Normal file
291
vm/vm.go
Normal file
|
|
@ -0,0 +1,291 @@
|
|||
package vm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/alessio/shellescape"
|
||||
"github.com/inconshreveable/log15"
|
||||
"github.com/phayes/freeport"
|
||||
"github.com/pkg/errors"
|
||||
"go.uber.org/multierr"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type USBDevicePassthroughConfig struct {
|
||||
HostBus uint8
|
||||
HostPort uint8
|
||||
}
|
||||
|
||||
type Instance struct {
|
||||
logger log15.Logger
|
||||
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
|
||||
cmd *exec.Cmd
|
||||
|
||||
sshMappedPort uint16
|
||||
sshConf *ssh.ClientConfig
|
||||
sshReadyCh chan struct{}
|
||||
|
||||
serialRead *io.PipeReader
|
||||
serialReader *bufio.Reader
|
||||
serialWrite *io.PipeWriter
|
||||
serialWriteMu sync.Mutex
|
||||
stderrBuf *bytes.Buffer
|
||||
|
||||
serialStdoutCh chan []byte
|
||||
|
||||
// These are to be interacted with using `atomic` package
|
||||
disposed uint32
|
||||
canceled uint32
|
||||
}
|
||||
|
||||
func NewInstance(logger log15.Logger, alpineImagePath string, usbDevices []USBDevicePassthroughConfig, debug bool) (*Instance, error) {
|
||||
alpineImagePath = filepath.Clean(alpineImagePath)
|
||||
_, err := os.Stat(alpineImagePath)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to stat alpine image path")
|
||||
}
|
||||
|
||||
sshPort, err := freeport.GetFreePort()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get free port for ssh server")
|
||||
}
|
||||
|
||||
// TODO: Disable internet access
|
||||
|
||||
// TODO: Configurable memory allocation
|
||||
|
||||
baseCmd := "qemu-system-x86_64"
|
||||
cmdArgs := []string{"-serial", "stdio", "-enable-kvm", "-m", "2048", "-smp", fmt.Sprint(runtime.NumCPU()),
|
||||
"-device", "e1000,netdev=net0", "-netdev", "user,id=net0,hostfwd=tcp::" + fmt.Sprint(sshPort) + "-:22"}
|
||||
|
||||
cmdArgs = append(cmdArgs, "-drive", "file="+shellescape.Quote(alpineImagePath)+",format=qcow2,if=virtio", "-snapshot")
|
||||
|
||||
if !debug {
|
||||
cmdArgs = append(cmdArgs, "-display", "none")
|
||||
}
|
||||
|
||||
if len(usbDevices) != 0 {
|
||||
cmdArgs = append(cmdArgs, "-usb", "-device", "nec-usb-xhci,id=xhci")
|
||||
|
||||
for _, dev := range usbDevices {
|
||||
cmdArgs = append(cmdArgs, "-device", "usb-host,hostbus="+strconv.FormatUint(uint64(dev.HostBus), 10)+",hostport="+strconv.FormatUint(uint64(dev.HostPort), 10))
|
||||
}
|
||||
}
|
||||
|
||||
sysRead, userWrite := io.Pipe()
|
||||
userRead, sysWrite := io.Pipe()
|
||||
|
||||
cmd := exec.Command(baseCmd, cmdArgs...)
|
||||
|
||||
cmd.Stdin = sysRead
|
||||
cmd.Stdout = sysWrite
|
||||
stderrBuf := bytes.NewBuffer(nil)
|
||||
cmd.Stderr = stderrBuf
|
||||
|
||||
userReader := bufio.NewReader(userRead)
|
||||
|
||||
ctx, ctxCancel := context.WithCancel(context.Background())
|
||||
|
||||
vi := &Instance{
|
||||
logger: logger,
|
||||
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
|
||||
cmd: cmd,
|
||||
|
||||
sshMappedPort: uint16(sshPort),
|
||||
sshReadyCh: make(chan struct{}),
|
||||
|
||||
serialRead: userRead,
|
||||
serialReader: userReader,
|
||||
serialWrite: userWrite,
|
||||
stderrBuf: stderrBuf,
|
||||
}
|
||||
|
||||
vi.resetSerialStdout()
|
||||
|
||||
return vi, nil
|
||||
}
|
||||
|
||||
func (vi *Instance) Run() error {
|
||||
if atomic.AddUint32(&vi.disposed, 1) != 1 {
|
||||
return fmt.Errorf("vm disposed")
|
||||
}
|
||||
|
||||
err := vi.cmd.Start()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "start qemu cmd")
|
||||
}
|
||||
|
||||
var globalErrsMu sync.Mutex
|
||||
var globalErrs []error
|
||||
|
||||
globalErrFn := func(err error) {
|
||||
globalErrsMu.Lock()
|
||||
defer globalErrsMu.Unlock()
|
||||
|
||||
globalErrs = append(globalErrs, err, errors.Wrap(vi.Cancel(), "cancel on error"))
|
||||
}
|
||||
|
||||
vi.logger.Info("Booting the VM")
|
||||
|
||||
go func() {
|
||||
_ = vi.runSerialReader()
|
||||
_ = vi.Cancel()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
err = vi.runVMLoginHandler()
|
||||
if err != nil {
|
||||
globalErrFn(errors.Wrap(err, "run vm login handler"))
|
||||
return
|
||||
}
|
||||
|
||||
sshSigner, err := vi.sshSetup()
|
||||
if err != nil {
|
||||
globalErrFn(errors.Wrap(err, "set up ssh"))
|
||||
return
|
||||
}
|
||||
|
||||
vi.logger.Debug("Set up SSH server successfully")
|
||||
|
||||
sshKeyScan, err := vi.scanSSHIdentity()
|
||||
if err != nil {
|
||||
globalErrFn(errors.Wrap(err, "scan ssh identity"))
|
||||
return
|
||||
}
|
||||
|
||||
vi.logger.Debug("Scanned SSH identity")
|
||||
|
||||
knownHosts, err := ParseSSHKeyScan(sshKeyScan)
|
||||
if err != nil {
|
||||
// TODO: Test what actually happens in inline critical errors like this.
|
||||
globalErrFn(errors.Wrap(err, "parse ssh key scan"))
|
||||
return
|
||||
}
|
||||
|
||||
vi.sshConf = &ssh.ClientConfig{
|
||||
User: "root",
|
||||
HostKeyCallback: knownHosts,
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.PublicKeys(sshSigner),
|
||||
},
|
||||
Timeout: time.Second * 5,
|
||||
}
|
||||
|
||||
// This is to notify everyone waiting for SSH to be up that it's ready to go.
|
||||
close(vi.sshReadyCh)
|
||||
|
||||
vi.logger.Info("SSH up, the VM ready for work")
|
||||
}()
|
||||
|
||||
_, err = vi.cmd.Process.Wait()
|
||||
cancelErr := vi.Cancel()
|
||||
if err != nil {
|
||||
combinedErr := multierr.Combine(
|
||||
errors.Wrap(err, "wait for cmd to finish execution"),
|
||||
errors.Wrap(cancelErr, "cancel"),
|
||||
)
|
||||
|
||||
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vi.stderrBuf.String()))
|
||||
}
|
||||
|
||||
combinedErr := multierr.Combine(
|
||||
append(globalErrs, errors.Wrap(cancelErr, "cancel on exit"))...,
|
||||
)
|
||||
|
||||
return fmt.Errorf("%w %v", combinedErr, getLogErrMsg(vi.stderrBuf.String()))
|
||||
}
|
||||
|
||||
func (vi *Instance) Cancel() error {
|
||||
if atomic.AddUint32(&vi.canceled, 1) != 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
vi.ctxCancel()
|
||||
return multierr.Combine(
|
||||
errors.Wrap(vi.cmd.Process.Signal(os.Interrupt), "cancel cmd"),
|
||||
errors.Wrap(vi.serialRead.Close(), "close serial read pipe"),
|
||||
errors.Wrap(vi.serialWrite.Close(), "close serial write pipe"),
|
||||
)
|
||||
}
|
||||
|
||||
func (vi *Instance) runSerialReader() error {
|
||||
for {
|
||||
raw, err := vi.serialReader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read from serial reader")
|
||||
}
|
||||
|
||||
select {
|
||||
case vi.serialStdoutCh <- raw:
|
||||
default:
|
||||
// Message gets discarded if the buffer is full.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (vi *Instance) writeSerial(b []byte) error {
|
||||
vi.serialWriteMu.Lock()
|
||||
defer vi.serialWriteMu.Unlock()
|
||||
|
||||
_, err := vi.serialWrite.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (vi *Instance) runVMLoginHandler() error {
|
||||
for {
|
||||
select {
|
||||
case <-vi.ctx.Done():
|
||||
return nil
|
||||
case <-time.After(time.Second):
|
||||
peek, err := vi.serialReader.Peek(vi.serialReader.Buffered())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "peek stdout")
|
||||
}
|
||||
|
||||
if bytes.Contains(peek, []byte("login:")) {
|
||||
err = vi.writeSerial([]byte("root\n"))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to stdio write login")
|
||||
}
|
||||
|
||||
vi.logger.Debug("Logged into the VM serial")
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (vi *Instance) resetSerialStdout() {
|
||||
vi.serialStdoutCh = make(chan []byte, 32)
|
||||
}
|
||||
|
||||
func (vi *Instance) DialSSH() (*ssh.Client, error) {
|
||||
if vi.sshConf == nil {
|
||||
return nil, ErrSSHUnavailable
|
||||
}
|
||||
|
||||
return ssh.Dial("tcp", "localhost:"+fmt.Sprint(vi.sshMappedPort), vi.sshConf)
|
||||
}
|
||||
|
||||
func (vi *Instance) SSHUpNotifyChan() chan struct{} {
|
||||
return vi.sshReadyCh
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue