#!/usr/bin/python
# Copyright 2012 Google Inc. All Rights Reserved.

"""This module creates SSH sessions to the device.

This file defines the command/console interaction between the test
instance and the CPE device. This includes ssh tunnel establishment,
shell command passing, and device log collection.

"""

__author__ = 'Lehan Meng (lmeng@google.com)'


import os
import tempfile
import time

import pexpect


class SSH(object):
  """SSH class that creates an ssh tunnel to the CPE device.

  The ssh tunnel can be established from various locations:
    - from a PC on the same LAN as the device (Ssh / SshRetry)
    - from a PC that must ssh to the device over an authentication server
      (e.g., jump) and a lab server (e.g., athenasrv3)
      (SshToAthena, then SshFromAthena / SshRetryFromAthena)

  Once logged in, the session output is written to self.tmp_file, which
  GetCmdOutput reads back.
  """

  def __init__(self, dev=None, **kwargs):
    """Initialize the SSH class.

    Every keyword argument overrides the entry of the same name in
    self.params. When athena_user is set, the jump and athena server prompts
    are derived from it.

    Args:
      dev: a device instance; when given, 'user', 'addr', 'pwd',
          'bruno_prompt' and 'addr_ipv6' must all be present in kwargs
      kwargs['user']: ssh user name
      kwargs['addr']: device ip address
      kwargs['addr_ipv6']: device IPv6 address, preferred over addr when set
      kwargs['pwd']: ssh password
      kwargs['bruno_prompt']: device shell command prompt (a pexpect regex)
      kwargs['athena_addr']: address of the athena lab server
      etc.
    Raises:
      KeyError: dev is given but one of the keys listed above is missing.
    """
    self.params = {
      'user': 'root',
      'addr': None,
      'addr_ipv6': None,
      'pwd': 'google',
      'bruno_prompt': r'\(none\)#',
      'jump_server': 'jmp.googlefiber.net',
      'jump_prompt': None,
      'athena_addr': '10.1.16.250',
      'athena_user': None,
      'athena_pwd': None,
      'athena_prompt': None}

    self.tmp_file = tempfile.TemporaryFile(mode='w+t')

    self.__short_delay = 5  # 5 seconds delay
    self.__delay = 10  # medium delay
    self.__long_delay = 30  # long delay

    self.params.update(kwargs)

    # Populate the prompts of the jump and athena servers. They stay None
    # when no athena user is configured (direct-LAN use).
    if self.params['athena_user'] is not None:
      self.params['jump_prompt'] = self.params['athena_user'] + '@jmp101.nuq1:'
      # NOTE(review): pexpect treats this as a regex. The leading '[' makes
      # it a character class, so it matches much more than the literal
      # prompt. It is left unchanged because callers may rely on the loose
      # match; confirm before escaping it.
      self.params['athena_prompt'] = ('[' + self.params['athena_user']
                                      + '@athenasrv3 ~]$')

    if dev is not None:
      # Parameters used to open an ssh tunnel to this device.
      # NOTE(review): these are read from kwargs, not from `dev`, and were
      # already copied above. This only enforces that they were supplied.
      # Confirm whether `dev` was meant to be the source.
      for s in ('user', 'addr', 'pwd', 'bruno_prompt', 'addr_ipv6'):
        self.params[s] = kwargs[s]

  def Setlogging(self, logging):
    """Set the logger used for progress and error reports.

    Args:
      logging: an object providing CreateErrorInfo, CreateProgressInfo and
          SendLine
    """
    self.log = logging

  def _ResetLogFile(self):
    """Discard previously captured output by replacing self.tmp_file."""
    self.tmp_file.close()
    self.tmp_file = tempfile.TemporaryFile(mode='w+t')

  def _DeviceAddr(self):
    """Return the device address, preferring IPv6 when configured."""
    return self.params['addr_ipv6'] or self.params['addr']

  def _Login(self, p_ssh):
    """Answer ssh's host-key and password questions until a shell prompt.

    On success, params['bruno_prompt'] is updated if one of the alternate
    device prompts was the one seen.

    Args:
      p_ssh: a pexpect spawn on which an ssh command was just started
    Returns:
      True once a device prompt is seen, False on EOF or timeout.
    """
    ssh_newkey = 'Are you sure you want to continue connecting'
    while True:
      # The prompt list is rebuilt each time so every expect can see the
      # same outcomes, including TIMEOUT, and never raises on it.
      i = p_ssh.expect(
          [ssh_newkey, 'password:', self.params['bruno_prompt'], pexpect.EOF,
           r'\(none\)#', 'gfibertv#', pexpect.TIMEOUT])
      if i == 0:
        print('choose Yes')
        p_ssh.sendline('yes')
      elif i == 1:
        print('need password')
        p_ssh.sendline(self.params['pwd'])
      elif i in (2, 4, 5):
        print('Login OK')
        if i == 4:
          self.params['bruno_prompt'] = r'\(none\)#'
        elif i == 5:
          self.params['bruno_prompt'] = 'gfibertv#'
        return True
      else:  # 3 (EOF) or 6 (TIMEOUT)
        print('Key or connection timeout')
        return False

  def Ssh(self):
    """Set up an ssh tunnel to the device; return immediately upon failure.

    This method creates an ssh tunnel from a PC on the same network as the
    device.

    Returns:
      True on success, False otherwise.
    """
    self._ResetLogFile()
    p_ssh = pexpect.spawn('ssh ' + self.params['user'] + '@'
                          + self._DeviceAddr())
    if not self._Login(p_ssh):
      # Do not leak the failed ssh child process.
      p_ssh.close()
      return False
    p_ssh.logfile = self.tmp_file
    self.p_ssh = p_ssh
    self._via_athena = False
    return True

  def SshFromAthena(self):
    """Initiate an ssh tunnel from athenasrv3 to the device.

    Requires self.p_ssh to be a shell on athenasrv3 (see SshToAthena).

    Returns:
      True on success, False otherwise.
    """
    self._ResetLogFile()
    self.p_ssh.logfile = self.tmp_file

    self.p_ssh.sendline('ssh-agent /bin/bash')
    self.p_ssh.expect(self.params['athena_prompt'])
    print('========================')
    print('add bruno key file:')
    self.p_ssh.sendline('ssh-add /home/' + self.params['athena_user']
                        + '/.ssh/bruno-sshkey')
    self.p_ssh.expect(self.params['athena_prompt'])

    self.p_ssh.sendline('ssh ' + self.params['user'] + '@'
                        + self._DeviceAddr())
    if not self._Login(self.p_ssh):
      return False
    self.p_ssh.logfile = self.tmp_file
    self._via_athena = True
    return True

  def ExitToAthena(self):
    """Leave the device shell and the ssh-agent shell, back to athenasrv3."""
    self.p_ssh.sendline('exit')
    self.p_ssh.expect(self.params['athena_prompt'])
    self.p_ssh.sendline('exit')
    self.p_ssh.expect(self.params['athena_prompt'])

  def SshToAthena(self):
    """ssh to athenasrv3, over the jump server from a Google PC.

    Some test environments can only be reached via a corp PC, an
    authentication server (e.g., jump) and a lab server (e.g., athenasrv3).

    Returns:
      True on success. False if the jump or athena server did not answer
      with the expected prompt (a critical error is logged).
    """
    jump_server = self.params['jump_server']
    p_ssh = pexpect.spawn('ssh -a ' + jump_server)
    print('ssh -a ' + jump_server + ' ...')
    i = p_ssh.expect(
        [self.params['jump_prompt'], pexpect.TIMEOUT, pexpect.EOF],
        self.__short_delay)
    if i != 0:
      info = self.log.CreateErrorInfo(
          'critical', 'Cannot establish SSH tunnel to jump server. '
          'Timeout or incorrect prompt!')
      self.log.SendLine(None, info)
      p_ssh.close()
      return False

    print('%s %s' % (p_ssh.before, p_ssh.after))
    print('========================')
    athena_login = self.params['athena_user'] + '@' + self.params['athena_addr']
    p_ssh.sendline('ssh -a ' + athena_login)
    print('ssh -a ' + athena_login + ' ...')
    i = p_ssh.expect([athena_login + '\'s password:',
                      pexpect.TIMEOUT, pexpect.EOF], self.__short_delay)
    if i != 0:
      info = self.log.CreateErrorInfo(
          'critical', 'Cannot establish SSH tunnel to athena server. '
          'Timeout or incorrect prompt!')
      self.log.SendLine(None, info)
      p_ssh.close()
      return False

    print('========================')
    print(athena_login + '\'s password:')
    p_ssh.sendline(self.params['athena_pwd'])
    # NOTE(review): the athena prompt is awaited twice; the reason is not
    # visible here (possibly tied to the loose prompt regex).
    p_ssh.expect(self.params['athena_prompt'])
    p_ssh.expect(self.params['athena_prompt'])
    print(p_ssh.before)
    print('========================')
    print('login athena3')
    self.p_ssh = p_ssh
    return True

  def _Retry(self, connect, max_retry, retry_delay):
    """Call `connect` until it returns True, exiting the process on failure.

    Args:
      connect: zero-argument callable returning True on success
      max_retry: number of retries after the first failed attempt
      retry_delay: seconds to sleep before each retry
    """
    tunnel = connect()
    retry = 0
    while not tunnel and retry < max_retry:
      info = self.log.CreateErrorInfo(
          'Warning', 'Failed to create ssh tunnel, retry after '
          + str(retry_delay) + ' seconds')
      self.log.SendLine(None, info)
      time.sleep(retry_delay)
      tunnel = connect()
      retry += 1

    # Check the tunnel itself: a success on the final retry must not exit.
    if not tunnel:
      info = self.log.CreateErrorInfo(
          'critical', 'Cannot establish ssh tunnel to Device within '
          + str(max_retry * retry_delay)
          + ' seconds! Timeout or incorrect command prompt!')
      self.log.SendLine(None, info)
      os.sys.exit(1)

    info = self.log.CreateProgressInfo(
        '---', 'ssh session to Device successfully established!')
    self.log.SendLine(None, info)

  def SshRetry(self, max_retry=20, retry_delay=15):
    """Establish an ssh connection to the device, retrying upon failure.

    Exits the process with status 1 if every attempt fails.

    Args:
      max_retry: the total number of retries
      retry_delay: delay in seconds between each ssh try
    """
    self._Retry(self.Ssh, max_retry, retry_delay)

  def SshRetryFromAthena(self, max_retry=5, retry_delay=15):
    """Establish an ssh connection from athenasrv3, retrying upon failure.

    Exits the process with status 1 if every attempt fails.

    Args:
      max_retry: the total number of retries
      retry_delay: delay in seconds between each ssh try
    """
    self._Retry(self.SshFromAthena, max_retry, retry_delay)

  def SendCmd(self, cmd):
    """Send a command to the device and wait for its prompt.

    If the prompt does not come back, the tunnel is re-established the same
    way it was opened, and the command is resent until the prompt is seen.

    Args:
      cmd: shell command line to send
    Returns:
      True once the device prompt is seen after the command.
    """
    self.p_ssh.sendline(cmd)
    i = self.p_ssh.expect(
        [self.params['bruno_prompt'], pexpect.EOF, pexpect.TIMEOUT], 10)
    while i != 0:
      # Prompt not returned (EOF or timeout).
      info = self.log.CreateErrorInfo(
          'Warning', 'Device not responding to shell command or timeout.'
          ' Connect after 3 seconds ...')
      self.log.SendLine(None, info)
      time.sleep(3)
      if getattr(self, '_via_athena', True):
        self.SshRetryFromAthena()
      else:
        # A direct session has its own ssh process; drop it before respawning.
        self.p_ssh.close()
        self.SshRetry()
      self.p_ssh.sendline(cmd)
      i = self.p_ssh.expect(
          [self.params['bruno_prompt'], pexpect.EOF, pexpect.TIMEOUT], 5)
    return True

  def GetCmdOutput(self, buff_size=0):
    """Get lines from the command output of the device, bottom up.

    Args:
      buff_size: if buff_size <= 0, return all captured output lines;
          if buff_size > 0, return only the last 'buff_size' lines
    Returns:
      A list of lines with the most recent line first.
    """
    self.tmp_file.seek(0)
    lines = self.tmp_file.readlines()
    if buff_size > 0:
      lines = lines[-buff_size:]
    return lines[::-1]

  def IsClosed(self):
    """Return the pexpect `closed` flag of the ssh process."""
    return self.p_ssh.closed

  def ExitStatus(self):
    """Return the exit status of the ssh process (pexpect `exitstatus`)."""
    return self.p_ssh.exitstatus

  def SignalStatus(self):
    """Return the signal status of the ssh process (pexpect `signalstatus`)."""
    return self.p_ssh.signalstatus

  def _ExpectCloseState(self):
    """Wait for a known prompt, EOF or timeout while closing.

    Prompts that are not configured (None) are skipped, because pexpect
    rejects None patterns.

    Returns:
      The params key of the matched prompt, 'eof', or 'timeout'.
    """
    prompts = [(name, self.params[name])
               for name in ('athena_prompt', 'jump_prompt', 'bruno_prompt')
               if self.params[name] is not None]
    patterns = [pattern for _, pattern in prompts]
    i = self.p_ssh.expect(patterns + [pexpect.EOF, pexpect.TIMEOUT])
    if i < len(prompts):
      return prompts[i][0]
    return 'eof' if i == len(prompts) else 'timeout'

  def Close(self):
    """Exit the nested shells until the jump server is reached or ssh ends."""
    if getattr(self, 'p_ssh', None) is None:
      return
    self.p_ssh.sendline('exit')
    state = self._ExpectCloseState()
    while True:
      if state in ('athena_prompt', 'bruno_prompt'):
        # Still on athena or on the device: keep exiting.
        self.p_ssh.sendline('exit')
        state = self._ExpectCloseState()
      elif state == 'jump_prompt':
        self.p_ssh.sendline('exit')
        break
      elif state == 'eof':
        # The ssh session has ended.
        break
      else:
        print('Timeout when closing the ssh tunnel!')
        break

  def __del__(self):
    # Attributes may be missing if __init__ did not complete, or if no
    # session was opened.
    tmp_file = getattr(self, 'tmp_file', None)
    if tmp_file is not None:
      tmp_file.close()
    p_ssh = getattr(self, 'p_ssh', None)
    if p_ssh is not None:
      p_ssh.close()

# End of SSH class

