# Source code for slave.vm

#!/usr/bin/env python
# encoding: utf-8

import json
import libvirt
import logging
import netifaces
import os
import random
import re
import sh
from sh import md5sum
from sh import wget
from sh import arp
import socket
import sys
import threading
import time
import uuid
import xmltodict

# Directory holding the VM image volumes: images are downloaded into it and
# used from it as backing files for per-job disks.
LIBVIRT_BASE = "/var/lib/libvirt/images"

from slave.models import Image
from slave.vm.comms import VMComms, VM_TYPE_LINUX, VM_TYPE_WINDOWS

# Only let warnings (and above) from the "sh" logger through.
logging.getLogger("sh").setLevel(logging.WARN)

def image_id_to_volume(image):
    """Build the volume file name that corresponds to an image id.

    :image: id of the image
    :returns: file name of the form ``<image>_vagrant_box_image_0.img``
    """
    volume_suffix = "_vagrant_box_image_0.img"
    return "{0}{1}".format(image, volume_suffix)
def qemu_img_info(image_path):
    """Return a dict of info returned by ``qemu-img info``.

    Assumes the image_path exists and points to a valid VM image.

    Each ``key: value`` line of the output becomes one dict entry. The line
    is split on its FIRST colon so that values which themselves contain a
    colon (e.g. ``backing file: base.img (actual path: /x/base.img)``) are
    still stored under the expected key. Lines without a colon, or with an
    empty key, are ignored.

    :image_path: path to the image
    :returns: dict of returned information about the image
    """
    output = sh.qemu_img.info(image_path)

    res = {}
    for line in output.split("\n"):
        key, sep, value = line.partition(":")
        key = key.strip()
        if not sep or not key:
            continue
        res[key] = value.strip()

    return res
class ImageManager(object):
    """Singleton that makes sure VM images exist locally and are up to date.

    Images live in ``LIBVIRT_BASE``; missing or changed images are fetched
    from ``image_url`` (which must be set before anything is downloaded).
    """

    _instance = None

    @classmethod
    def instance(cls):
        """Return the shared ImageManager, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # maps real path -> {"md5": ..., "modtime": ...}
        self._cache = {}
        self._log = logging.getLogger("ImageMan")
        # base url that images are downloaded from
        self.image_url = None

    def get_md5(self, path):
        """Get the md5 of the file at ``path``.

        A cache will be used based on last modified time of the file. If the
        file does not exist, None will be returned.

        :path: path of the file to hash
        :returns: the md5 hex digest, or None if ``path`` does not exist
        """
        path = os.path.realpath(os.path.abspath(path))

        if not os.path.exists(path):
            self._log.debug("could not get md5, path does not exist: {}".format(path))
            return None

        # Read the mtime *before* hashing: if the file changes while it is
        # being hashed, the stored mtime will not match next time and the
        # md5 gets recalculated instead of a stale value being served.
        modtime = os.path.getmtime(path)

        cached = self._cache.get(path)
        if cached is not None and cached["modtime"] == modtime:
            self._log.debug("path was in cache ({}), md5: {}".format(path, cached["md5"]))
            return cached["md5"]

        self._log.info("path not in cache ({}), calculating".format(path))
        output = md5sum(path).split()[0]
        self._log.info("calculated md5: {}".format(output))

        self._cache[path] = {
            "md5": output,
            "modtime": modtime,
        }
        return output

    def download_image(self, image_id):
        """Download the image ``image_id`` from ``image_url`` into LIBVIRT_BASE.

        :image_id: id of the image to download
        """
        image_filename = image_id_to_volume(image_id)
        dest = os.path.join(LIBVIRT_BASE, image_filename)

        self._log.debug("downloading image {} to {}".format(image_id, dest))
        wget("-q", "-O", dest, self.image_url + "/" + image_filename)
        self._log.debug("downloaded {}".format(image_id))

    def ensure_image(self, image_id):
        """Ensure that the image ``image_id`` and its bases exist in
        LIBVIRT_BASE, checking its md5 against the md5 sum stored in the
        database.

        A missing image is downloaded. An existing image whose md5 differs
        from the database record is downloaded again. Backing files are
        checked recursively, and a failure on any of them makes the whole
        check fail.

        :image_id: id of the image to check
        :returns: True/False on success
        """
        self._log.info("ensuring image {} exists and is valid".format(image_id))
        dest = os.path.join(LIBVIRT_BASE, image_id_to_volume(image_id))

        if not os.path.exists(dest):
            self.download_image(image_id)
        else:
            images = Image.objects(id=image_id)
            if len(images) == 0:
                self._log.warning("image id {} does not reference a valid image".format(image_id))
                return False

            image = images[0]
            md5 = self.get_md5(dest)

            if md5 == image.md5:
                # all good, nothing has changed
                self._log.debug("image {} is unchanged".format(image_id))
            else:
                self._log.debug("image {} changed (model: {}, disk: {}), redownloading".format(image_id, image.md5, md5))
                self.download_image(image_id)

        info = qemu_img_info(dest)
        if "backing file" not in info:
            self._log.debug("no backing file, image looks good!")
            return True

        self._log.debug("checking backing files for validity")
        # backing will be a (absolute?) path; the image id is the part of the
        # file name before the first underscore
        backing = info["backing file"]
        backing_id = os.path.basename(backing).split("_")[0]

        if not self.ensure_image(backing_id):
            self._log.warning("backing image {} of image {} is not valid".format(backing_id, image_id))
            return False

        return True
class VMHandler(threading.Thread):
    """Thread that runs a single job inside a transient libvirt/KVM VM.

    The handler creates a copy-on-write disk backed by the requested image,
    boots a transient domain from it, injects and starts the bootstrap
    script inside the guest, waits for the VM to shut down (or for the
    timeout to expire) and finally destroys the domain and removes the disk.
    """

    def __init__(self, job, idx, image, image_username, image_password, tool, params, code_loc, code_username, code_password, timeout=600, network="whitelist", on_finished=None, on_vnc_available=None, startup_timeout=60):
        """Start up the VM image ``image`` in libvirt, with a timeout of
        ``timeout``, and params ``params``, using network ``network``.

        :job: id of the job being run
        :idx: index of this VM within the job
        :image: The name of the image
        :image_username: username used to connect to the guest
        :image_password: password used to connect to the guest
        :tool: name of the tool to run (passed on to the guest)
        :params: Params that specify what to run inside of the VM
        :code_loc: location (url) of the code the guest should fetch
        :code_username: username for ``code_loc``
        :code_password: password for ``code_loc``
        :timeout: The timeout for the vm
        :network: ``"all"`` or ``"whitelist"``, optionally followed by a
            colon and a comma-separated list of hosts to whitelist
        :on_finished: callback called with this handler once the VM is done
        :on_vnc_available: callback called with this handler once the vnc
            port is known
        :startup_timeout: how long to wait for the VM to start running
        """
        super(VMHandler, self).__init__()

        self.job = job
        self.idx = idx
        self.image = image
        self.image_username = image_username
        self.image_password = image_password
        self.tool = tool
        self.params = params
        self.code_loc = code_loc
        self.code_username = code_username
        self.code_password = code_password
        self.timeout = timeout
        self.startup_timeout = startup_timeout

        # network can be 'all' or 'whitelist'
        # whitelist values can also be followed by a colon
        # and a comma-separated list of domain names/ip addresses
        parts = network.split(":")
        self.network = parts[0]
        self.whitelisted_hosts = []
        if self.network == "whitelist" and len(parts) > 1:
            # ignore empty entries (e.g. "whitelist:" or a trailing comma)
            self.whitelisted_hosts = [x.strip() for x in parts[1].split(",") if x.strip()]

        self.ram = 1024  # in MB
        self.vnc_port = -1
        self.on_vnc_available = on_vnc_available
        self.on_finished = on_finished
        self.start_time = time.time()

        self._log = logging.getLogger("VM-JOB:{}:{}".format(self.job, self.idx))
        self._running = threading.Event()
        self._image_man = ImageManager.instance()
        self._libvirt_conn = None
        self._vm_image_loc = None
        self._domain = None
        self._comms = None

    def run(self):
        """Run the VMHandler: start the VM, run the job in it, clean up.

        ``on_finished`` is called once the VM has been cleaned up. It is not
        called when the VM could not be started at all.
        """
        self._running.set()
        self.start_time = time.time()

        self._log.debug("starting")
        if not self._vm_start():
            self._log.warning("error, could not start vm, bailing")
            self._running.clear()
            return

        start_time = time.time()
        total_time = 0

        # wait for the VM to startup before waiting for it to be shutdown
        while self._running.is_set() and not self._vm_is_running() and total_time < self.startup_timeout:
            time.sleep(0.2)
            total_time = time.time() - start_time

        if self._running.is_set():
            if total_time >= self.startup_timeout:
                self._log.warning("VM took too long to startup, bailing")
                self._running.clear()
            else:
                # means it started up
                self._log.info("VM started up, waiting to inject bootstrap and run job")
                self._connect_comms()
                self._inject_and_run_bootstrap()

        start_time = time.time()
        total_time = 0

        while self._running.is_set() and self._vm_is_running() and total_time < self.timeout:
            vnc_port = self._vm_vnc_port()
            # None means the domain went away, -1 means vnc is not ready yet;
            # neither must be reported as an available port
            if vnc_port is not None and vnc_port != -1 and self.vnc_port == -1:
                self.vnc_port = vnc_port
                if self.on_vnc_available is not None:
                    self.on_vnc_available(self)

            time.sleep(0.1)
            total_time = time.time() - start_time

        self._vm_cleanup()

        self._log.debug("finished")

        if self.on_finished is not None:
            self.on_finished(self)

    def stop(self):
        """Ask the handler to stop; ``run`` notices on its next poll."""
        self._log.info("stopping")
        self._running.clear()

    def handle_comms(self, data):
        """Handle guest communications.

        :data: dict sent by the guest; ``data["type"]`` selects the handler
        :returns: the handler's response as a JSON string, or ``"{}"`` when
            the message has no type or an unknown type
        """
        self._log.info("handling comms: {}".format(data))

        handlers = {
            "startup": self.handle_guest_startup,
        }

        if "type" not in data:
            self._log.debug("guest comms does not include a type")
            return "{}"

        handler = handlers.get(data["type"])
        if handler is None:
            # the message comes from the guest, don't blow up on junk
            self._log.warning("guest comms has an unknown type: {}".format(data["type"]))
            return "{}"

        return handler(data)

    def handle_guest_startup(self, data):
        """Answer a guest "startup" message with the job description.

        :data: the message received from the guest (unused)
        :returns: JSON string with id, tool, params, idx and code_loc
        """
        res = dict(
            id=self.job,
            tool=self.tool,
            params=self.params,
            idx=self.idx,
            code_loc=self.code_loc,
        )
        return json.dumps(res)

    # ----------------------------

    def _get_filter_params(self):
        """Return the ``<parameter>`` elements for the network filter.

        :returns: an empty string for network ``"all"`` (and for unknown
            network names), otherwise one WHITELIST parameter per allowed ip,
            joined with newlines
        """
        if self.network != "whitelist":
            # "all" needs no parameters; anything else has none defined
            return ""

        code_loc_host = self.code_loc.replace("http://", "").replace("https://", "").split("/", 1)[0]
        # drop "user:pass@" credentials and a ":port" suffix, neither of
        # which can be resolved by gethostbyname
        code_loc_host = code_loc_host.rsplit("@", 1)[-1].split(":", 1)[0]
        code_loc_ip = socket.gethostbyname(code_loc_host)

        this_ip = netifaces.ifaddresses('virbr0')[netifaces.AF_INET][0]['addr']
        bcast = this_ip.rsplit(".", 1)[0] + ".255"

        ips = [
            "255.255.255.255",
            bcast,
            # always include the guest comms ip
            this_ip,
            code_loc_ip,
        ]
        ips.extend(socket.gethostbyname(other_host) for other_host in self.whitelisted_hosts)

        return "\n".join(
            "<parameter name='WHITELIST' value='{}' />".format(ip) for ip in ips
        )

    def _connect_comms(self):
        """Wait for the VM to get an ip address, then connect comms to it.

        Returns without connecting if the handler is stopped or the VM shuts
        down while waiting.
        """
        self._log.debug("waiting for vm to get an ip")

        ip_addr = self._vm_ip_address()
        while self._running.is_set() and self._vm_is_running() and ip_addr is None:
            time.sleep(0.5)
            ip_addr = self._vm_ip_address()

        if not self._running.is_set() or not self._vm_is_running():
            self._log.debug("stopped waiting for ip, was told to quit (or vm shutdown)")
            return

        self._log.info("vm has an ip ({})! connecting comms".format(ip_addr))

        # TODO probe ports 22/5569 instead of this?
        self._comms = VMComms.get_comms(VM_TYPE_WINDOWS)
        self._comms.connect(ip_addr, self.image_username, self.image_password)

    def _inject_and_run_bootstrap(self):
        """Inject and run the bootstrap inside the VM.

        Bails out quietly between steps if the handler was stopped, the VM
        shut down, or comms were never connected.
        """
        self._log.info("injecting bootstrap")

        with open(os.path.join(os.path.dirname(__file__), "bootstrap.py"), "r") as f:
            bootstrap_contents = f.read()

        if not self._running.is_set() or not self._vm_is_running():
            return

        comms = getattr(self, "_comms", None)
        if comms is None:
            self._log.debug("comms are not connected, not injecting bootstrap")
            return

        tmp_path = comms.sep.join([comms.tmp_loc(), "bootstrap.py"])
        self._log.debug("saving bootstrap to {}".format(tmp_path))
        comms.put_file(tmp_path, bootstrap_contents)

        if not self._running.is_set() or not self._vm_is_running():
            return

        config_path = comms.sep.join([comms.tmp_loc(), "config.json"])
        self._log.debug("writing config to {}".format(config_path))
        comms.put_file(config_path, self._make_config())

        if not self._running.is_set() or not self._vm_is_running():
            return

        self._log.info("running bootstrap")
        comms.run_script("python \"" + tmp_path + "\"", background=True)

        if not self._running.is_set() or not self._vm_is_running():
            return

        self._log.debug("started bootstrap")

    def _make_config(self):
        """Return the JSON config that is written next to the bootstrap."""
        res = dict(
            id=self.job,
            idx=self.idx,
            tool=self.tool,
            params=self.params,
            code=dict(
                loc=self.code_loc,
                username=self.code_username,
                password=self.code_password,
            ),
        )
        return json.dumps(res, indent=4, separators=(',', ': '))

    # ----------------------------

    def _libvirt(self):
        """Return the (lazily opened) libvirt connection."""
        if self._libvirt_conn is None:
            self._libvirt_conn = libvirt.open("qemu:///system")
        return self._libvirt_conn

    def _libvirt_domain(self):
        """Return the libvirt domain for the currently-running vagrant box

        :returns: libvirt.Domain if exists, None if it does not exist
        """
        conn = self._libvirt()
        try:
            return conn.lookupByName(self._domain)
        except libvirt.libvirtError:
            return None

    def _vm_start(self):
        """Make sure the image is present, then create and boot the VM.

        :returns: False if the image could not be ensured, True otherwise
        """
        if not self._image_man.ensure_image(self.image):
            return False

        self._vm_create()
        self._vm_run()
        return True

    def _vm_cleanup(self):
        """Destroy the domain and remove the per-job disk image."""
        self._log.info("cleaning up")
        if self._vm_image_loc is None:
            # nothing was ever created
            return

        self._vm_kill()
        try:
            os.remove(self._vm_image_loc)
        except OSError as e:
            # a missing/unremovable disk must not keep the job from finishing
            self._log.warning("could not remove {}: {}".format(self._vm_image_loc, e))

    def _vm_kill(self):
        """Destroy the domain (best effort)."""
        vm_name = os.path.basename(self._vm_image_loc)

        # the domains created are transient, don't need to undefine them

        # if the VM has already been shutdown, this will fail
        try:
            sh.virsh.destroy(vm_name)
        except Exception as e:
            self._log.debug("could not destroy {} (already shut down?): {}".format(vm_name, e))

    def _vm_is_running(self):
        """Return True/False if the current image is still running

        :returns: True/False
        """
        domain = self._libvirt_domain()
        if domain is None:
            return False

        state, _reason = domain.state()
        return state == libvirt.VIR_DOMAIN_RUNNING

    def _vm_vnc_port(self):
        """Return the vnc port of the vagrant VM

        :returns: The vnc port. If the domain is not running, None is
            returned. If vnc is not (yet?) available, -1 is returned.
        """
        domain = self._libvirt_domain()
        # VM isn't running (yet?)
        if domain is None:
            return None

        info = xmltodict.parse(domain.XMLDesc())
        return int(info["domain"]["devices"]["graphics"]["@port"])

    def _vm_ip_address(self):
        """Return the ip address (on the host) of the VM being handled

        The address is looked up in the output of ``arp -a -n`` using the
        mac address of the domain's network interface.

        :returns: IP Address, or None if it does not yet have one
        """
        domain = self._libvirt_domain()
        if domain is None:
            return None

        info = xmltodict.parse(domain.XMLDesc())
        mac_addr = info["domain"]["devices"]["interface"]["mac"]["@address"]

        output = arp("-a", "-n")
        for line in output.split("\n"):
            if mac_addr not in line:
                continue
            ip_address = re.search(r'(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})', line)
            if ip_address is not None:
                return ip_address.group(1)

        return None

    def _vm_create(self):
        """Create the per-job qcow2 disk, backed by the job's image."""
        self._vm_image_loc = "/tmp/{}_{}.img".format(self.job, self.idx)
        self._domain = os.path.basename(self._vm_image_loc)

        sh.qemu_img.create(
            self._vm_image_loc,
            b=os.path.join(LIBVIRT_BASE, image_id_to_volume(self.image)),
            f="qcow2",
        )

    def _rand_mac_addr(self):
        """Return a random mac address with the ``00:16:3e`` prefix."""
        mac = [
            0x00, 0x16, 0x3e,
            random.randint(0x00, 0x7f),
            random.randint(0x00, 0xff),
            random.randint(0x00, 0xff),
        ]
        return ':'.join("%02x" % x for x in mac)

    def _vm_run(self):
        """Creates the domain and runs it"""
        vm_name = os.path.basename(self._vm_image_loc)
        self._log.info("running VM {}".format(vm_name))

        domain_xml = """
        <domain type='kvm'>
            <name>{domain_name}</name>
            <uuid>{domain_uuid}</uuid>
            <memory>{mem_size}</memory>
            <currentMemory>{mem_size}</currentMemory>
            <vcpu>{num_cpus}</vcpu>
            <os>
                <type arch='x86_64'>hvm</type>
                <boot dev='hd'/>
            </os>
            <features>
                <acpi/><apic/><pae/>
            </features>
            <clock offset="utc"/>
            <on_poweroff>destroy</on_poweroff>
            <on_reboot>restart</on_reboot>
            <on_crash>destroy</on_crash>
            <devices>
                <emulator>/usr/bin/kvm-spice</emulator>
                <disk type='file' device='disk'>
                    <driver name='qemu' type='qcow2'/>
                    <source file='{image_path}'/>
                    <target dev='vda' bus='sata'/>
                </disk>
                <interface type='network'>
                    <source network='default'/>
                    <model type='virtio'/>
                    <filterref filter='{filter_name}'>
                        {filter_params}
                    </filterref>
                </interface>
                <input type='mouse' bus='ps2'/>
                <graphics type='vnc' port='-1'/>
                <console type='pty'/>
                <video>
                    <model type='cirrus'/>
                </video>
            </devices>
        </domain>
        """.format(
            domain_name=vm_name,
            domain_uuid=str(uuid.uuid4()),
            mem_size=self.ram * 1024,  # ram is in MB
            num_cpus=2,
            image_path=self._vm_image_loc,
            filter_name="talus-" + self.network,
            filter_params=self._get_filter_params(),
        )

        conn = self._libvirt()
        # createXML both creates and starts a transient domain (as opposed
        # to defineXML, which would only define a persistent one)
        conn.createXML(domain_xml, 0)