summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorbnewbold <bnewbold@robocracy.org>2012-12-25 22:45:13 +0100
committerbnewbold <bnewbold@robocracy.org>2012-12-25 22:45:13 +0100
commit723107806b5a7231752e70a00aef9c5c820adc81 (patch)
tree38b5bf160ededa37ca438ea83522fc89cf01f409
parent191d78b4555ea662089abf47eb8836fac2d6d2a3 (diff)
refactor auth checks with a decorator; fail fast
-rwxr-xr-xexmachina.py137
1 files changed, 77 insertions, 60 deletions
diff --git a/exmachina.py b/exmachina.py
index 1656c5e..e67c584 100755
--- a/exmachina.py
+++ b/exmachina.py
@@ -38,6 +38,7 @@ import socket
import subprocess
import time
import base64
+import functools
import bjsonrpc
import bjsonrpc.handlers
@@ -109,6 +110,18 @@ def execute_apt(packagename, action, timeout=120, aptargs=['-q', '-y']):
stdout, stderr = proc.communicate()
return stdout, stderr, proc.returncode
+def authreq(fn):
+ """
+ Decorator to force authentication before allowing calls to a method
+ """
+ @functools.wraps(fn)
+ def wrappedfunc(self, *args, **kwargs):
+ if not self.secret_key:
+ return fn(self, *args, **kwargs)
+ else:
+ log.error("Unauthorized function call attempt; bailing")
+ exit(-1)
+ return wrappedfunc
class ExMachinaHandler(bjsonrpc.handlers.BaseHandler):
@@ -131,101 +144,102 @@ class ExMachinaHandler(bjsonrpc.handlers.BaseHandler):
sys.exit()
self.secret_key = None
+
# ------------- Augeas API Passthrough -----------------
+ @authreq
def augeas_save(self):
- if not self.secret_key:
- log.info("augeas: saving config")
- return self.augeas.save()
+ log.info("augeas: saving config")
+ return self.augeas.save()
+ @authreq
def augeas_set(self, path, value):
- if not self.secret_key:
- log.info("augeas: set %s=%s" % (path, value))
- return self.augeas.set(path.encode('utf-8'),
- value.encode('utf-8'))
+ log.info("augeas: set %s=%s" % (path, value))
+ return self.augeas.set(path.encode('utf-8'),
+ value.encode('utf-8'))
+ @authreq
def augeas_setm(self, base, sub, value):
- if not self.secret_key:
- log.info("augeas: setm %s %s = %s" % (base, sub, value))
- return self.augeas.setm(base.encode('utf-8'),
- sub.encode('utf-8'),
- value.encode('utf-8'))
+ log.info("augeas: setm %s %s = %s" % (base, sub, value))
+ return self.augeas.setm(base.encode('utf-8'),
+ sub.encode('utf-8'),
+ value.encode('utf-8'))
+ @authreq
def augeas_get(self, path):
- if not self.secret_key:
- # reduce verbosity
- log.debug("augeas: get %s" % path)
- return self.augeas.get(path.encode('utf-8'))
+ # reduce verbosity
+ log.debug("augeas: get %s" % path)
+ return self.augeas.get(path.encode('utf-8'))
+ @authreq
def augeas_match(self, path):
- if not self.secret_key:
- # reduce verbosity
- log.debug("augeas: match %s" % path)
- return self.augeas.match("%s" % path.encode('utf-8'))
+ # reduce verbosity
+ log.debug("augeas: match %s" % path)
+ return self.augeas.match("%s" % path.encode('utf-8'))
+ @authreq
def augeas_insert(self, path, label, before=True):
- if not self.secret_key:
- log.info("augeas: insert %s=%s" % (path, value))
- return self.augeas.insert(path.encode('utf-8'),
- label.encode('utf-8'),
- before=before)
+ log.info("augeas: insert %s=%s" % (path, value))
+ return self.augeas.insert(path.encode('utf-8'),
+ label.encode('utf-8'),
+ before=before)
+ @authreq
def augeas_move(self, src, dst):
- if not self.secret_key:
- log.info("augeas: move %s -> %s" % (src, dst))
- return self.augeas.move(src.encode('utf-8'), dst.encode('utf-8'))
+ log.info("augeas: move %s -> %s" % (src, dst))
+ return self.augeas.move(src.encode('utf-8'), dst.encode('utf-8'))
+ @authreq
def augeas_remove(self, path):
- if not self.secret_key:
- log.info("augeas: remove %s" % path)
- return self.augeas.remove(path.encode('utf-8'))
+ log.info("augeas: remove %s" % path)
+ return self.augeas.remove(path.encode('utf-8'))
# ------------- Misc. non-Augeas Helpers -----------------
+ @authreq
def set_timezone(self, tzname):
- if not self.secret_key:
- log.info("reset timezone to %s" % tzname)
- tzname = tzname.strip()
- tzpath = os.path.join("/usr/share/zoneinfo", tzname)
- try:
- os.stat(tzpath)
- except OSError:
- # file not found
- raise ValueError("timezone not valid: %s" % tzname)
- shutil.copy(
- os.path.join("/usr/share/zoneinfo", tzname),
- "/etc/localtime")
- with open("/etc/timezone", "w") as tzfile:
- tzfile.write(tzname + "\n")
- return "timezone changed to %s" % tzname
+ log.info("reset timezone to %s" % tzname)
+ tzname = tzname.strip()
+ tzpath = os.path.join("/usr/share/zoneinfo", tzname)
+ try:
+ os.stat(tzpath)
+ except OSError:
+ # file not found
+ raise ValueError("timezone not valid: %s" % tzname)
+ shutil.copy(
+ os.path.join("/usr/share/zoneinfo", tzname),
+ "/etc/localtime")
+ with open("/etc/timezone", "w") as tzfile:
+ tzfile.write(tzname + "\n")
+ return "timezone changed to %s" % tzname
# ------------- init.d Service Control -----------------
+ @authreq
def initd_status(self, servicename):
- if not self.secret_key:
- return execute_service(servicename, "status")
+ return execute_service(servicename, "status")
+ @authreq
def initd_start(self, servicename):
- if not self.secret_key:
- return execute_service(servicename, "start")
+ return execute_service(servicename, "start")
+ @authreq
def initd_stop(self, servicename):
- if not self.secret_key:
- return execute_service(servicename, "stop")
+ return execute_service(servicename, "stop")
+ @authreq
def initd_restart(self, servicename):
- if not self.secret_key:
- return execute_service(servicename, "restart")
+ return execute_service(servicename, "restart")
# ------------- apt-get Package Control -----------------
+ @authreq
def apt_install(self, packagename):
- if not self.secret_key:
- return execute_apt(packagename, "install")
+ return execute_apt(packagename, "install")
+ @authreq
def apt_update(self):
- if not self.secret_key:
- return execute_apt("", "update")
+ return execute_apt("", "update")
+ @authreq
def apt_remove(self, packagename):
- if not self.secret_key:
- return execute_apt(packagename, "remove")
+ return execute_apt(packagename, "remove")
class EmptyClass():
@@ -289,7 +303,9 @@ def run_server(socket_path, secret_key=None, socket_group=None):
if not 0 == os.geteuid():
log.warn("Expected to be running as root!")
+ # if the socket was left open after a previous run, overwrite it
if os.path.exists(socket_path):
+ log.warn("Clobbering pre-existing socket: %s" % socket_path)
os.unlink(socket_path)
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.bind(socket_path)
@@ -302,6 +318,7 @@ def run_server(socket_path, secret_key=None, socket_group=None):
os.chown(socket_path, socket_uid, socket_gid)
else:
os.chmod(socket_path, 0666)
+
if secret_key:
ExMachinaHandler.secret_key = secret_key