Source code for otto.connections.ssh

# -*- coding: utf-8 -*-
"""
Paramiko based ssh module
"""
import os
import logging
import socket
from time import sleep, time
from multiprocessing import Process, Value, Array

import paramiko

from otto.lib.contextmanagers import ignored
from otto.lib.otypes import ReturnCode, ConnectionError, Data, Namespace

# Optional suffix read from the 'instance' environment variable; it is the
# empty string when the variable is unset or empty.
instance = os.environ.get('instance', '')

# Module logger named 'otto<instance>.connections'.  A NullHandler is attached
# so the logger always has at least one handler.
logger = logging.getLogger('otto%s.connections' % instance)
logger.addHandler(logging.NullHandler())


# pylint: disable=R0903,R0902,R0904
class AllowAnythingPolicy(paramiko.MissingHostKeyPolicy):
    """
    Host key policy whose handler does nothing when a host key is missing.
    """

    def missing_host_key(self, client, hostname, key):
        """Accept the unknown key silently: no exception, no bookkeeping."""
        return None


# noinspection PyMethodOverriding
class Client(paramiko.SSHClient):
    """
    Connect to a host using paramiko, a Python interface to the SSH2 protocol.

    Transport compression is enabled by default.

    Client.environmentals is a dictionary of environment variables that
    :meth:`run` prepends to every command.  They can be manipulated on the fly
    with care or using the env context manager,
    otto.lib.contextmanager.env().

    Client.cwd, when set, is the remote directory that :meth:`run` changes
    into before executing a command and that the SFTP helpers (:meth:`ls`,
    :meth:`mkdir`, :meth:`rmdir`, :meth:`rm`) resolve relative paths against.
    """

    # pylint: disable=R0913,R0921
    def __init__(self, host, user, password, port=22, compress=True):
        """
        :param host: hostname or address of the remote machine
        :param user: login name
        :param password: login password
        :param port: ssh port
        :type port: int
        :param compress: whether to request transport compression
        :type compress: bool
        """
        self.cwd = str()
        self.environmentals = dict()
        self.host = host
        self.user = user
        self.port = port
        self.password = password
        super(Client, self).__init__()
        self.compression = compress
        self.set_log_channel('otto' + instance + '.connections')
        # Policy for automatically adding the hostname and new host key to
        # the local HostKeys.
        self.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.connected = False

    @staticmethod
    def _connect_failed(message, error):
        """Log a failed connection attempt and build the failing ReturnCode."""
        logger.critical(message)
        logger.error(error)
        return ReturnCode(False, message=message)

    # pylint: disable=W0221
    def connect(self, timeout=10, key_file=None):
        """
        Open the ssh connection described by the constructor arguments.

        This method does not raise on the failures handled below; it returns
        a failing ReturnCode instead.  If the script ignores this return and
        proceeds, the traceback might not be obvious as to where the problem
        was.

        :param timeout: seconds to wait for the TCP connection
        :param key_file: optional private key file name(s) handed to paramiko
        :return: ReturnCode(True) on success, otherwise ReturnCode(False)
                 carrying a short description of the failure
        :rtype: ReturnCode
        """
        try:
            # Calling the base class connect method.
            super(Client, self).connect(hostname=self.host, port=self.port, username=self.user,
                                        password=self.password, timeout=timeout,
                                        key_filename=key_file, compress=self.compression)
        # When stacking exceptions, superclasses catch subclasses, so the
        # specific handlers must stay ahead of SSHException / socket.error.
        except paramiko.BadHostKeyException as e:
            return self._connect_failed("Server's host key could not be verified", e)
        except paramiko.AuthenticationException as e:
            return self._connect_failed("Authentication with the server failed", e)
        except paramiko.SSHException as e:
            return self._connect_failed("Couldn't complete connection", e)
        except socket.timeout as e:
            return self._connect_failed("No response from host", e)
        except socket.error as e:
            return self._connect_failed("Connection refused", e)
        self.connected = True
        return ReturnCode(True, message=self.connected)

    def run(self, cmd, timeout=None, bufsize=-1):
        """
        Run a command on the remote host and collect its output.

        The command is prefixed with ``VAR=value`` pairs from
        ``self.environmentals`` and, when ``self.cwd`` is set, wrapped in
        ``cd <cwd> && ...``.

        :param cmd: command to run on remote host
        :type cmd: str
        :param timeout: timeout on blocking read/write operations
        :type timeout: float
        :param bufsize: byte size of the buffer for the filehandle returned
        :type bufsize: int
        :return: status is True only when the command exited with 0; message
                 holds stdout on success, stderr on a non-zero exit,
                 ``"Timeout"`` when reading stderr timed out, or a
                 description of the ssh failure.  ``raw`` holds the
                 (status, stdout, stderr) Data.
        :rtype: ReturnCode
        :raises ConnectionError: when called before a successful connect()
        """
        ret = ReturnCode(False)
        if not self.connected:
            raise ConnectionError("Run was called on an unconnected host. Did you check the result of connect()?")
        try:
            if self.environmentals:
                envstring = ''.join("%s=%s " % (var, value)
                                    for var, value in self.environmentals.items())
                cmd = "%s%s" % (envstring, cmd)
            if self.cwd:
                cmd = "cd %s && %s" % (self.cwd, cmd)
            self._log(logging.DEBUG, 'running command: "%s"' % cmd)
            stdin, stdout, stderr = self.exec_command(command=cmd, timeout=timeout, bufsize=bufsize)
        except paramiko.SSHException as e:
            err = "Couldn't complete the command: %s" % str(e)
            logger.critical(err)
            ret.message = err
            return ret

        try:
            # we must read stderr _before_ stdout
            # otherwise paramiko loses the stdout data
            try:
                ret.raw = Data(ret.raw.status, ret.raw.stdout, stderr.read())
            except socket.timeout:
                ret.message = "Timeout"
                return ret

            status = stdout.channel.recv_exit_status()
            ret.raw = Data(status, stdout.read(), ret.raw.stderr)

            if status != 0:
                ret.message = ret.raw.stderr
            else:
                ret.status = True
                ret.message = ret.raw.stdout
        finally:
            # Close stdin on every exit path, including the timeout return.
            stdin.close()
        return ret

    def disconnect(self):
        """
        Disconnect from the host.

        :return: True when no transport is left after closing
        :rtype: bool
        """
        self.close()
        self.connected = False
        return self._transport is None

    def reconnect(self, after=10, timeout=10, key_file=None, conn_attempts=10):
        """
        Close the connection and try to re-establish it, e.g. after a reboot.

        :param after: seconds to sleep before the first attempt
        :param timeout: per-attempt connect timeout handed to connect()
        :param key_file: optional private key file handed to connect()
        :param conn_attempts: maximum number of attempts; a falsy value
                              (0 or None) retries until a connection succeeds
        :return: True once connected, False when every attempt failed
        :rtype: bool

        NOTE: attempts are made back to back; there is no pause between them
        other than the time a failing connect() itself takes.
        """
        self.close()
        # The old transport is gone; do not report a stale connection if
        # every attempt below fails.
        self.connected = False
        sleep(after)

        attempt = 0
        while not conn_attempts or attempt < conn_attempts:
            attempt += 1
            if conn_attempts:
                logger.debug('Attempt %d of %d to re-connect to the host', attempt, conn_attempts)
            else:
                logger.debug('Attempt %d of inf to re-connect to the host', attempt)
            with ignored(ConnectionError):
                # connect() reports failure through its return value rather
                # than by raising, and sets self.connected only on success,
                # so that flag is what decides whether we are done.
                self.connect(timeout=timeout, key_file=key_file)
                if self.connected:
                    return True

        logger.error("Giving up, no attempts left to re-connect to host")
        return False

    def _remote_path(self, path):
        """Resolve a relative path against self.cwd; absolute paths pass through."""
        if path.startswith('/') or not self.cwd:
            return path
        if self.cwd.endswith('/'):
            return self.cwd + path
        return "%s/%s" % (self.cwd, path)

    def _sftp_path_op(self, label, operation, path, expectation, **kwargs):
        """Run one single-path SFTPClient method in a short-lived SFTP session."""
        sftp = self.open_sftp()
        try:
            return getattr(sftp, operation)(path=self._remote_path(path), **kwargs)
        except IOError as e:
            if not expectation:
                return []
            raise ConnectionError("%s %s failed: %s" % (label, path, e))
        finally:
            # Each call opens its own SFTP session; release it when done.
            sftp.close()

    def ls(self, path=None, expectation=True):
        """
        Return a list of files in path.

        If path is not specified cwd will be used (or the SFTP default
        directory when cwd is empty).  Relative paths are resolved against
        cwd.  Works in conjunction with the cd context manager.

        :param path: remote directory to list
        :param expectation: when False, a failing listing returns [] instead
                            of raising
        :rtype: list
        :raises ConnectionError: when the listing fails and expectation is True
        """
        if path:
            target = self._remote_path(path)
        else:
            path = ""
            target = self.cwd or '.'

        sftp = self.open_sftp()
        try:
            return sftp.listdir(path=target)
        except IOError as e:
            if not expectation:
                return []
            raise ConnectionError("ls %s failed: %s" % (path, e))
        finally:
            sftp.close()

    def mkdir(self, dirname, mode=511, expectation=True):
        """
        Create a directory named by a string; relative names are resolved
        against cwd.

        :param dirname: remote directory to create
        :param mode: permission bits for the new directory (511 == 0o777)
        :param expectation: when False, a failure returns [] instead of raising
        :raises ConnectionError: when creation fails and expectation is True
        """
        return self._sftp_path_op('mkdir', 'mkdir', dirname, expectation, mode=mode)

    def rmdir(self, dirname, expectation=True):
        """
        Remove a directory named by a string; relative names are resolved
        against cwd.

        :param expectation: when False, a failure returns [] instead of raising
        :raises ConnectionError: when removal fails and expectation is True
        """
        return self._sftp_path_op('rmdir', 'rmdir', dirname, expectation)

    def rm(self, path, expectation=True):
        """
        Remove a file named by a string; relative names are resolved
        against cwd.

        :param expectation: when False, a failure returns [] instead of raising
        :raises ConnectionError: when removal fails and expectation is True
        """
        return self._sftp_path_op('rm', 'remove', path, expectation)

    def open(self, fname, mode='r', bufsize=-1, expectation=True):
        """
        Return a file-like object of fname on the remote.

        The SFTP session is intentionally left open here because the
        returned file object keeps using it.

        :param expectation: when False, a failure returns a failing
                            ReturnCode instead of raising
        :raises ConnectionError: when opening fails and expectation is True
        """
        sftp = self.open_sftp()
        try:
            ret = sftp.file(fname, mode, bufsize, )
        except IOError as e:
            if not expectation:
                ret = ReturnCode(False, "file failed: %s" % e)
            else:
                raise ConnectionError(str(e))
        return ret

    def stat(self, path):
        """
        Return a namespace of::

            {'size' : st.st_size,
             'uid': st.st_uid,
             'gid': st.st_gid,
             'mode': st.st_mode,
             'atime': st.st_atime,
             'mtime': st.st_mtime,
            }

        this can be accessed like so::

            fstat = init.stat('/etc/passwd')
            fstat.size

        :param path: remote path to stat (used as given, not resolved
                     against cwd)
        :type path: str
        :return: the selected stat fields
        :rtype: Namespace
        """
        sftp = paramiko.SFTPClient.from_transport(self.get_transport())
        try:
            st = sftp.stat(path)
        finally:
            sftp.close()
        dstat = {'size': st.st_size,
                 'uid': st.st_uid,
                 'gid': st.st_gid,
                 'mode': st.st_mode,
                 'atime': st.st_atime,
                 'mtime': st.st_mtime,
                 }
        return Namespace(dstat)
class parallelCmd(object):
    """
    A non-blocking remote command running object ::

        c = parallelCmd(user, hostname, password, command)
        c.run()
        # do other things
        c.wait()  # or use an 'if not c.done:' control struct
        c.result.status
        c.result.stdout
        c.result.stderr

    The command runs in a separate process that opens its own ssh
    connection.  Exit status, stdout and stderr travel back to the parent
    through shared memory; stdout and stderr are each limited to 1 MB and
    anything beyond that is truncated.

    When the remote job has completed ``done`` is True, ``result`` holds the
    (status, stdout, stderr) Data and ``time`` the elapsed seconds measured
    in the worker.
    """

    def __init__(self, user, hostname, password, command, port=22):
        self.client = paramiko.SSHClient()
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.command = command
        self.time = float()
        self.dict = dict()  # fio output in dict format
        self.started = False
        # synchronization variables
        self.__status = Value('i')
        self.__stdout = Array('c', 1000000)  # 1 MB
        self.__stderr = Array('c', 1000000)  # 1 MB
        self.__elapsed = Value('d')  # run time measured in the worker
        self.p = None
        self.user = user
        self.hostname = hostname
        self.password = password
        self.port = port

    @staticmethod
    def _read(channelobj):
        """read until EOF"""
        buf = channelobj.readline()
        output = str(buf)
        while buf:
            buf = channelobj.readline()
            output += buf
        return output

    @staticmethod
    def _store(shared, data):
        """Copy data into a shared char array, truncating to its capacity."""
        shared.value = data[:len(shared)]

    @staticmethod
    def _text(raw):
        """Return raw as str, decoding it first when it is a bytes object."""
        if isinstance(raw, str):
            return raw
        return raw.decode('utf-8', 'replace')

    def _runcmd(self, status, stdout, stderr, elapsed=None):
        """
        Worker body executed in the child process.

        Attributes assigned to ``self`` in here live in the child only; the
        results reach the parent exclusively through the shared objects
        passed in.

        :param status: shared int that receives the exit status
        :param stdout: shared char array that receives stdout
        :param stderr: shared char array that receives stderr
        :param elapsed: optional shared double that receives the run time in
                        seconds
        """
        self.started = True
        start = time()
        self.client.connect(self.hostname, self.port, self.user, self.password)
        try:
            stdin, sout, serr = self.client.exec_command(self.command)
            # we must read stderr _before_ stdout
            # otherwise paramiko loses the stdout data
            self._store(stderr, serr.read())
            exit_code = sout.channel.recv_exit_status()
            # copy stdout into shared memory
            self._store(stdout, sout.read())
            # Write through .value: rebinding the parameter would leave the
            # shared status untouched.
            status.value = int(exit_code)
        finally:
            self.client.close()
            self.time = time() - start
            if elapsed is not None:
                elapsed.value = self.time

    def _join(self):
        """Block until the worker process has exited."""
        if self.p is None:
            raise RuntimeError("the command has not been started; call run() first")
        self.p.join()

    @property
    def done(self):
        """
        :return: whether or not the job is complete
        :rtype: bool
        """
        if not self.started or self.p.is_alive():
            return False
        # Publish the run time measured in the worker process.
        self.time = self.__elapsed.value
        return True

    @property
    def result(self):
        """
        Return the result as a Data tuple, which can be indexed or
        referenced by name:

            status or 0: exitcode as int
            stdout or 1: stdout as str
            stderr or 2: stderr as str

        so upon completion ``cmd.result[0]`` and ``cmd.result.status`` are
        equivalent.

        *If the process is not complete this will block.*

        :raises RuntimeError: when run() has not been called
        """
        if not self.done:
            self._join()
            self.time = self.__elapsed.value
        return Data(self.__status.value,
                    self._text(self.__stdout.value),
                    self._text(self.__stderr.value))

    def run(self):
        """
        Start the job on the remote host.
        """
        self.p = Process(target=self._runcmd,
                         args=(self.__status, self.__stdout, self.__stderr, self.__elapsed))
        self.p.start()
        # start() returns once the child has been launched.  Polling
        # is_alive() here could spin forever if the child exits quickly.
        self.started = True

    def wait(self):
        """
        This is basically a join.  It blocks until the job is done.

        :return: the (status, stdout, stderr) result of the command
        :rtype: Data
        :raises RuntimeError: when run() has not been called
        """
        then = time()
        self._join()
        logger.debug("waited for {:10.4f} sec".format(time() - then))
        return self.result
class TunnelSocketCreator(paramiko.SSHClient):
    """
    SSH client used as a proxy: it opens socket-like channels to other
    hosts through the ssh connection.
    """

    def __init__(self, host, user, port=22, key_filename=None, compress=True):
        """Remember the proxy's connection settings; no connection is made yet."""
        self.host, self.user, self.port = host, user, port
        super(TunnelSocketCreator, self).__init__()
        self.key_filename, self.compression = key_filename, compress
        self.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    def connect(self, timeout=10):
        """
        Connect to the proxy host with the stored settings.

        :param timeout: timeout handed to the base class connect
        :return: whatever the base class connect returns
        """
        options = {
            'hostname': self.host,
            'port': self.port,
            'username': self.user,
            'timeout': timeout,
            'key_filename': self.key_filename,
            'compress': self.compression,
        }
        return super(TunnelSocketCreator, self).connect(**options)

    def get_sock(self, rhost, rport):
        """
        Open a 'direct-tcpip' channel through the proxy.

        :param rhost: destination host
        :param rport: destination port
        :return: a socket like object connected to rhost:rport
        """
        destination = (rhost, rport)
        origin = ('127.0.0.1', 0)
        return self.get_transport().open_channel('direct-tcpip', destination, origin)