From 723107806b5a7231752e70a00aef9c5c820adc81 Mon Sep 17 00:00:00 2001 From: bnewbold Date: Tue, 25 Dec 2012 22:45:13 +0100 Subject: refactor auth checks with a decorator; fail fast --- exmachina.py | 137 +++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 77 insertions(+), 60 deletions(-) diff --git a/exmachina.py b/exmachina.py index 1656c5e..e67c584 100755 --- a/exmachina.py +++ b/exmachina.py @@ -38,6 +38,7 @@ import socket import subprocess import time import base64 +import functools import bjsonrpc import bjsonrpc.handlers @@ -109,6 +110,18 @@ def execute_apt(packagename, action, timeout=120, aptargs=['-q', '-y']): stdout, stderr = proc.communicate() return stdout, stderr, proc.returncode +def authreq(fn): + """ + Decorator to force authentication before allowing calls to a method + """ + @functools.wraps(fn) + def wrappedfunc(self, *args, **kwargs): + if not self.secret_key: + return fn(self, *args, **kwargs) + else: + log.error("Unauthorized function call attempt; bailing") + exit(-1) + return wrappedfunc class ExMachinaHandler(bjsonrpc.handlers.BaseHandler): @@ -131,101 +144,102 @@ class ExMachinaHandler(bjsonrpc.handlers.BaseHandler): sys.exit() self.secret_key = None + # ------------- Augeas API Passthrough ----------------- + @authreq def augeas_save(self): - if not self.secret_key: - log.info("augeas: saving config") - return self.augeas.save() + log.info("augeas: saving config") + return self.augeas.save() + @authreq def augeas_set(self, path, value): - if not self.secret_key: - log.info("augeas: set %s=%s" % (path, value)) - return self.augeas.set(path.encode('utf-8'), - value.encode('utf-8')) + log.info("augeas: set %s=%s" % (path, value)) + return self.augeas.set(path.encode('utf-8'), + value.encode('utf-8')) + @authreq def augeas_setm(self, base, sub, value): - if not self.secret_key: - log.info("augeas: setm %s %s = %s" % (base, sub, value)) - return self.augeas.setm(base.encode('utf-8'), - sub.encode('utf-8'), - value.encode('utf-8')) + log.info("augeas: setm %s %s = %s" % (base, sub, value)) + return self.augeas.setm(base.encode('utf-8'), + sub.encode('utf-8'), + value.encode('utf-8')) + @authreq def augeas_get(self, path): - if not self.secret_key: - # reduce verbosity - log.debug("augeas: get %s" % path) - return self.augeas.get(path.encode('utf-8')) + # reduce verbosity + log.debug("augeas: get %s" % path) + return self.augeas.get(path.encode('utf-8')) + @authreq def augeas_match(self, path): - if not self.secret_key: - # reduce verbosity - log.debug("augeas: match %s" % path) - return self.augeas.match("%s" % path.encode('utf-8')) + # reduce verbosity + log.debug("augeas: match %s" % path) + return self.augeas.match("%s" % path.encode('utf-8')) + @authreq def augeas_insert(self, path, label, before=True): - if not self.secret_key: - log.info("augeas: insert %s=%s" % (path, value)) - return self.augeas.insert(path.encode('utf-8'), - label.encode('utf-8'), - before=before) + log.info("augeas: insert %s=%s" % (path, value)) + return self.augeas.insert(path.encode('utf-8'), + label.encode('utf-8'), + before=before) + @authreq def augeas_move(self, src, dst): - if not self.secret_key: - log.info("augeas: move %s -> %s" % (src, dst)) - return self.augeas.move(src.encode('utf-8'), dst.encode('utf-8')) + log.info("augeas: move %s -> %s" % (src, dst)) + return self.augeas.move(src.encode('utf-8'), dst.encode('utf-8')) + @authreq def augeas_remove(self, path): - if not self.secret_key: - log.info("augeas: remove %s" % path) - return self.augeas.remove(path.encode('utf-8')) + log.info("augeas: remove %s" % path) + return self.augeas.remove(path.encode('utf-8')) # ------------- Misc. non-Augeas Helpers ----------------- + @authreq def set_timezone(self, tzname): - if not self.secret_key: - log.info("reset timezone to %s" % tzname) - tzname = tzname.strip() - tzpath = os.path.join("/usr/share/zoneinfo", tzname) - try: - os.stat(tzpath) - except OSError: - # file not found - raise ValueError("timezone not valid: %s" % tzname) - shutil.copy( - os.path.join("/usr/share/zoneinfo", tzname), - "/etc/localtime") - with open("/etc/timezone", "w") as tzfile: - tzfile.write(tzname + "\n") - return "timezone changed to %s" % tzname + log.info("reset timezone to %s" % tzname) + tzname = tzname.strip() + tzpath = os.path.join("/usr/share/zoneinfo", tzname) + try: + os.stat(tzpath) + except OSError: + # file not found + raise ValueError("timezone not valid: %s" % tzname) + shutil.copy( + os.path.join("/usr/share/zoneinfo", tzname), + "/etc/localtime") + with open("/etc/timezone", "w") as tzfile: + tzfile.write(tzname + "\n") + return "timezone changed to %s" % tzname # ------------- init.d Service Control ----------------- + @authreq def initd_status(self, servicename): - if not self.secret_key: - return execute_service(servicename, "status") + return execute_service(servicename, "status") + @authreq def initd_start(self, servicename): - if not self.secret_key: - return execute_service(servicename, "start") + return execute_service(servicename, "start") + @authreq def initd_stop(self, servicename): - if not self.secret_key: - return execute_service(servicename, "stop") + return execute_service(servicename, "stop") + @authreq def initd_restart(self, servicename): - if not self.secret_key: - return execute_service(servicename, "restart") + return execute_service(servicename, "restart") # ------------- apt-get Package Control ----------------- + @authreq def apt_install(self, packagename): - if not self.secret_key: - return execute_apt(packagename, "install") + return execute_apt(packagename, "install") + @authreq def apt_update(self): - if not self.secret_key: - return execute_apt("", "update") + return execute_apt("", "update") + @authreq def apt_remove(self, packagename): - if not self.secret_key: - return execute_apt(packagename, "remove") + return execute_apt(packagename, "remove") class EmptyClass(): @@ -289,7 +303,9 @@ def run_server(socket_path, secret_key=None, socket_group=None): if not 0 == os.geteuid(): log.warn("Expected to be running as root!") + # if the socket was left open after a previous run, overwrite it if os.path.exists(socket_path): + log.warn("Clobbering pre-existing socket: %s" % socket_path) os.unlink(socket_path) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.bind(socket_path) @@ -302,6 +318,7 @@ def run_server(socket_path, secret_key=None, socket_group=None): os.chown(socket_path, socket_uid, socket_gid) else: os.chmod(socket_path, 0666) + if secret_key: ExMachinaHandler.secret_key = secret_key -- cgit v1.2.3