
get_pty was being passed as a positional argument instead of a keyword argument, which caused problems when the timeout needed to be used. Change-Id: Iff9ad5d623b48ea65f4b8e9baa158d5d4aad09d1 Signed-off-by: croy <Christian.Roy@windriver.com>
551 lines
21 KiB
Python
551 lines
21 KiB
Python
import codecs
|
|
import re
|
|
import time
|
|
from typing import List
|
|
|
|
import paramiko
|
|
from paramiko.client import SSHClient
|
|
from paramiko.sftp_client import SFTPClient
|
|
|
|
from config.host.objects.host_configuration import HostConfiguration
|
|
from framework.logging.automation_logger import get_logger
|
|
from framework.ssh.prompt_response import PromptResponse
|
|
from framework.threading.thread_manager import ThreadManager
|
|
|
|
|
|
class SSHConnection:
    """
    This class holds information and actions for an ssh connection.

    A connection can optionally be tunnelled through a jump host, and commands can be wrapped with
    ``sshpass`` to reach a further remote node (see ``setup_ssh_pass``).
    """

    # Matches ANSI terminal control sequences (color codes, cursor movement, etc.).
    # Compiled once for the class instead of on every read from an interactive channel.
    # Pattern adapted from https://stackoverflow.com/a/14693789 and https://github.com/chalk/ansi-regex/blob/main/index.js
    _ANSI_ESCAPE_PATTERN = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")

    def __init__(
        self,
        name: str,
        host: str,
        user: str,
        password: str,
        timeout: int = 30,
        ssh_port: int = 22,
        jump_host: HostConfiguration = None,
    ):
        """
        Initialize the SSH connection object.

        This sets up the basic configuration used to create an SSH session, optionally through a jump host.
        No connection is opened here; it is opened by ``connect`` (called automatically by the ``send*`` methods).

        Args:
            name (str): The name of the connection.
            host (str): The target host to connect to.
            user (str): The SSH username.
            password (str): The SSH password.
            timeout (int): The timeout for establishing a connection, in seconds.
            ssh_port (int): The port used for SSH. Defaults to 22.
            jump_host (HostConfiguration, optional): Configuration for a jump host, if needed.
        """
        self.client = SSHClient()
        self.name = name
        self.host = host
        self.user = user
        self.password = password
        self.timeout = timeout
        self.ssh_port = ssh_port
        self.jump_host = jump_host
        self.is_connected = False

        self.last_return_code = None  # The last Return Code

        # these values are used for commands that require ssh pass on remote nodes
        self.use_ssh_pass = False
        self.ssh_pass_host = None
        self.ssh_pass_username = None
        self.ssh_pass_password = None
        self.output_start_line = -1  # for parsing out lines that come by default when using ssh pass

    def _connect_to_jump_host(self, allow_agent: bool = True, look_for_keys: bool = True) -> None:
        """
        Connect to the configured jump host using SSH.

        Uses paramiko to establish the SSH session with the jump host, based on
        credentials provided in the `jump_host` configuration. On success, ``self.client``
        is connected to the jump host, and ``connect`` tunnels through its transport.

        Args:
            allow_agent (bool): Connect to SSH agent (Paramiko arg). Default is True.
            look_for_keys (bool): Re-use saved private keys (Paramiko arg). Default is True.

        Returns:
            None:

        Raises:
            BaseException: If reading the jump host configuration or connecting to it fails.
                The original exception is chained as the cause.
        """
        # Pre-initialise so the error log below cannot raise UnboundLocalError
        # when one of the configuration getters is what failed.
        host = None
        user_name = None
        password = None
        try:
            host = self.jump_host.get_host()
            user_name = self.jump_host.get_credentials().get_user_name()
            password = self.jump_host.get_credentials().get_password()
            jump_host_ssh_port = self.jump_host.get_ssh_port()
            self.client.connect(
                host,
                username=user_name,
                password=password,
                timeout=self.timeout,
                allow_agent=allow_agent,
                look_for_keys=look_for_keys,
                port=jump_host_ssh_port,
            )

        except BaseException as exception:
            # NOTE(review): this logs the jump-host password in clear text.
            get_logger().log_error(f"Failed to Connect to Jump-Host {host} with username/password =" f" {user_name}/{password} with timeout {self.timeout}s")
            get_logger().log_error(f"Exception: {exception}")
            raise BaseException("Failed to connect to Jump-Host") from exception

    def connect(self, allow_agent: bool = True, look_for_keys: bool = False) -> bool:
        """
        Create an SSH connection to the target host.

        If a jump host is configured, it is connected first and the target is reached through a
        ``direct-tcpip`` channel opened on the jump host's transport.

        Args:
            allow_agent (bool): Use SSH agent forwarding (Paramiko arg). Default is True.
            look_for_keys (bool): Search for saved private keys (Paramiko arg). Default is False.

        Returns:
            bool: True if the connection was successful, False otherwise. Failures are logged, not raised.
        """
        is_connection_success = True
        # Unknown host keys are accepted and added automatically.
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        sock = None
        try:
            # if a jump host is configured, create that connection first
            if self.jump_host:
                self._connect_to_jump_host(allow_agent, look_for_keys)
                sock = self.client.get_transport().open_channel("direct-tcpip", (self.host, self.ssh_port), ("", 0), timeout=self.timeout)

            self.client.connect(
                self.host,
                username=self.user,
                password=self.password,
                timeout=self.timeout,
                allow_agent=allow_agent,
                look_for_keys=look_for_keys,
                port=self.ssh_port,
                sock=sock,
            )
            self.is_connected = True
        except BaseException as exception:
            # NOTE(review): this logs the target password in clear text.
            get_logger().log_error(f"Failed to Connect to host {self.host} with username/password =" f" {self.user}/{self.password} with timeout {self.timeout}s")
            get_logger().log_error(f"Exception: {exception}")
            is_connection_success = False
            # connection failed but if a jump host is used, we may still have that connection.
            # Close the old client so its transport is not left open, then reset the client object.
            try:
                self.client.close()
            except Exception:
                pass  # best effort: the client is replaced below regardless
            self.client = SSHClient()
            self.is_connected = False

        return is_connection_success

    def send(self, cmd: str, reconnect_timeout: int = 600, get_pty: bool = False) -> List[str]:
        """
        Send a command to the SSH session and return the output.

        Retries the connection for up to `reconnect_timeout` seconds
        if the session is lost.

        Args:
            cmd (str): The command to execute.
            reconnect_timeout (int): Time in seconds to retry the connection.
            get_pty (bool): Defaults to False. Whether to request a terminal when running a 'send' command.

        Returns:
            List[str]: The output lines of the command (stdout and stderr combined), or None if the
            command could not be executed before `reconnect_timeout` expired.
        """
        return self._execute_command("SEND", cmd, reconnect_timeout=reconnect_timeout, get_pty=get_pty)

    def send_as_sudo(self, cmd: str, reconnect_timeout: int = 600) -> List[str]:
        """
        Sends a command using sudo and returns the output. Waits for reconnect timeout.

        Args:
            cmd (str): The command to send.
            reconnect_timeout (int): How long to wait for SSH reconnection if needed.

        Returns:
            List[str]: Output lines of the executed command, or None if it could not be executed
            before `reconnect_timeout` expired.
        """
        return self._execute_command("SEND_SUDO", cmd, reconnect_timeout=reconnect_timeout)

    def send_expect_prompts(self, cmd: str, prompts: List[PromptResponse], reconnect_timeout: int = 600) -> List[str]:
        """
        Sends a command, waits for prompts and returns the output.

        Waits for the reconnect timeout in case of SSH disconnects.

        Args:
            cmd (str): The command to send.
            prompts (List[PromptResponse]): The prompts to expect, in order.
            reconnect_timeout (int): The amount of time in seconds to wait for SSH connection.

        Returns:
            List[str]: The output lines associated with the last prompt, or None if the command could
            not be executed before `reconnect_timeout` expired.
        """
        return self._execute_command("SEND_EXPECT_PROMPTS", cmd, prompts=prompts, reconnect_timeout=reconnect_timeout)

    def _execute_command(
        self,
        action: str,
        cmd: str,
        reconnect_timeout: int = 600,
        prompts: List[PromptResponse] = None,
        get_pty: bool = False,
    ) -> List[str]:
        """
        Executes the given action with the given command.

        The command runs in a ThreadManager worker thread and is retried, reconnecting as needed,
        until it succeeds or `reconnect_timeout` seconds have elapsed.

        Args:
            action (str): The action to execute: SEND, SEND_SUDO or SEND_EXPECT_PROMPTS.
            cmd (str): The command to run.
            reconnect_timeout (int): The time in seconds to keep retrying. A tenth of it is used as the
                ThreadManager timeout for each attempt.
            prompts (List[PromptResponse], optional): Expected prompts, used by SEND_EXPECT_PROMPTS.
            get_pty (bool): Defaults to False. Whether to request a terminal when running a 'send' command.

        Returns:
            List[str]: The output lines of the command, or None if it could not be executed before
            `reconnect_timeout` expired.

        Raises:
            ValueError: If `action` is not a supported action.
        """
        # Validate before the retry loop: inside it, the ValueError would be caught by the generic
        # handler and retried until reconnect_timeout expired instead of reaching the caller.
        if action not in ("SEND", "SEND_SUDO", "SEND_EXPECT_PROMPTS"):
            raise ValueError(f"{action} is not a supported command for an SSHConnection.")

        timeout = time.time() + reconnect_timeout
        refresh_timeout = 5

        # if we are using ssh pass, we need to wrap the call
        if self.use_ssh_pass:
            if action == "SEND_SUDO":  # if it is a sudo call we need further changes to avoid the password prompt
                cmd = f'{self.get_ssh_pass_str()} "echo "{self.ssh_pass_password}" | sudo -S {cmd}"'
                # since we do not need prompts or to prepend sudo now, change Action to just 'SEND'
                action = "SEND"
            else:
                cmd = f"{self.get_ssh_pass_str()} '{cmd}'"

        while time.time() < timeout:
            try:
                if not self.is_connected:
                    self.connect()

                thread_manager = ThreadManager(timeout=reconnect_timeout / 10)

                if action == "SEND":
                    # get_pty is passed by keyword so it cannot bind to _send's positional 'timeout' parameter
                    thread_manager.start_thread("SSH_Command", self._send, cmd, get_pty=get_pty)
                elif action == "SEND_SUDO":
                    thread_manager.start_thread("SSH_Command", self._send_as_sudo, cmd)
                else:  # SEND_EXPECT_PROMPTS (validated above)
                    thread_manager.start_thread("SSH_Command", self._send_expect_prompts, cmd, prompts)

                thread_manager.join_all_threads()
                output = thread_manager.get_thread_object("SSH_Command").get_result()

                # if we use ssh pass we want to skip the preamble before sending back output
                if self.use_ssh_pass and self.output_start_line != -1:  # if -1 it's the call to get preamble so return whole output
                    output = output[self.output_start_line :]

                return output

            except Exception as e:
                get_logger().log_info(f"SSH command failed to execute. Reconnecting and trying again in {refresh_timeout} seconds. " f"Exception: {str(e)}")
                time.sleep(refresh_timeout)
                self.is_connected = False

        get_logger().log_error(f"SSH command failed to execute within {reconnect_timeout} seconds. Returning None.")
        return None

    def _send(self, cmd: str, timeout: int = 30, get_pty: bool = False) -> List[str]:
        """
        Sends the given command with the specified timeout.

        Stderr is combined into stdout, and the remote exit status is stored in `last_return_code`.

        Args:
            cmd (str): The command to send.
            timeout (int): The timeout in seconds for command execution (Paramiko channel timeout).
            get_pty (bool): Defaults to False. Whether to request a terminal when running a 'send' command.

        Returns:
            List[str]: The output lines of the command, each keeping its trailing newline.
        """
        get_logger().log_ssh(cmd)

        _stdin, stdout, _stderr = self.client.exec_command(cmd, timeout=timeout, get_pty=get_pty)
        stdout.channel.set_combine_stderr(True)
        # NOTE(review): Paramiko warns that recv_exit_status() can hang when the remote output exceeds the
        # channel window before it is read. It is kept before readlines() because reading first would
        # subject long-running silent commands to the channel read timeout.
        self.last_return_code = stdout.channel.recv_exit_status()
        output = stdout.readlines()

        for line in output:
            get_logger().log_ssh(line.rstrip("\n"))

        return output

    def _send_as_sudo(self, cmd: str) -> List[str]:
        """
        Sends the specified command using sudo and handles the password prompt.

        Args:
            cmd (str): The command to execute with sudo.

        Returns:
            List[str]: The output lines of the command.
        """
        # Deliberately skipping the "P" in the password as some prompts have
        # different cases
        sudo_password_prompt = PromptResponse("assword", self.password)
        # "@" is expected to appear in the shell prompt once the sudo command has finished
        sudo_completed = PromptResponse("@")
        sudo_prompts = [sudo_password_prompt, sudo_completed]
        return self.send_expect_prompts("sudo {}".format(cmd), sudo_prompts)

    def _send_expect_prompts(self, cmd: str, prompts: List[PromptResponse], timeout: int = 30) -> List[str]:
        """
        Send the command and wait for the specified prompts in order.

        Args:
            cmd (str): The command to execute.
            prompts (List[PromptResponse]): An ordered list of prompts we expect and
                the associated responses.
            timeout (int): Timeout in seconds to wait for each read from the channel.

        Returns:
            List[str]: The SSH output recorded for the last prompt, split into lines (each ending in
            a newline, empty lines dropped). With intermediate prompts, this is the output read
            after the previous prompt matched.

        Raises:
            ValueError: If `prompts` is empty or None.
        """
        if not prompts:
            raise ValueError("You must specify a list with at least one prompt to call this function. Otherwise, please call 'send' instead.")

        code = -1
        # One decoder for the whole exchange, so multi-byte UTF-8 characters split across recv() chunks decode correctly.
        decoder = codecs.getincrementaldecoder("utf-8")()

        # Open up a channel to control the SSH connection and send the command.
        channel = self.client.invoke_shell()
        try:
            self.__send_in_channel(channel, cmd)

            # Keep going until we have matched every prompt in order
            # Or we timeout from receiving output from the ssh connection
            for prompt in prompts:
                is_prompt_match = False
                output_since_last_prompt = ""

                while not is_prompt_match:
                    # Read the response from the server.
                    code, output_buffer = self.__read_from_channel(channel, timeout, decoder)

                    if code != 0:
                        get_logger().log_warning(f"Failed to match prompt of {prompt.get_prompt_substring()}")
                        break

                    # Log the current console output.
                    get_logger().log_info(output_buffer.rstrip())

                    # Add the currently read buffer to the output
                    output_since_last_prompt += output_buffer
                    prompt.set_complete_output(output_since_last_prompt)
                    is_prompt_match = prompt.get_prompt_substring() in output_since_last_prompt

                # If we match the prompt, send the associated response if any.
                if is_prompt_match and prompt.get_prompt_response():
                    self.__send_in_channel(channel, prompt.get_prompt_response())
        finally:
            # Always release the interactive shell channel, including on error paths.
            channel.close()

        # Note: this is the channel read status (0 or -1), not the remote command's exit status.
        self.last_return_code = code

        complete_output = prompts[-1].get_complete_output()

        # output is a long string, break into list using line breaks but add back the line break as it's needed
        # for table parsing
        return [line + "\n" for line in complete_output.split("\n") if line]

    def __send_in_channel(self, ssh_channel: paramiko.Channel, cmd: str) -> None:
        """
        Send a command through the given SSH channel.

        This method assumes the channel was opened via `invoke_shell()` and waits
        until the channel is ready before sending the command, followed by a newline.

        Args:
            ssh_channel (paramiko.Channel): The SSH channel obtained from
                `self.client.invoke_shell()`.
            cmd (str): The command to send.
        """
        while not ssh_channel.send_ready():
            time.sleep(0.009)  # Avoid spamming the channel. Value taken from paramiko-expect.
        get_logger().log_info(f"Sending command: {cmd}")
        ssh_channel.send(cmd)
        ssh_channel.send("\n")

    def __read_from_channel(self, ssh_channel: paramiko.Channel, timeout: int, decoder: codecs.IncrementalDecoder = None) -> tuple[int, str]:
        """
        Read data from an SSH channel opened via `invoke_shell()`.

        Waits for the channel to be ready and reads up to 1024 bytes. Times out if no
        response is received in the given number of seconds.

        Args:
            ssh_channel (paramiko.Channel): The SSH channel obtained from
                `self.client.invoke_shell()`.
            timeout (int): Time in seconds to wait for a response.
            decoder (codecs.IncrementalDecoder, optional): UTF-8 incremental decoder to reuse across
                reads so a multi-byte character split between two reads is decoded correctly.
                A fresh decoder is created when omitted.

        Returns:
            tuple[int, str]: A tuple of return code and string output.
                - Return code: 0 on success, -1 on timeout or connection closed.
                - Output: The response string read from the SSH channel, with ANSI sequences and
                  carriage returns removed.
        """
        if decoder is None:
            decoder = codecs.getincrementaldecoder("utf-8")()
        base_time = time.time()

        # Avoids paramiko hang when recv is not ready yet
        while not ssh_channel.recv_ready():
            time.sleep(0.009)  # Avoid spamming the channel. Value taken from paramiko-expect.
            if time.time() >= (base_time + timeout):
                get_logger().log_warning("SSH output read timed out — buffer may be incomplete or prompt unmatched.")
                return -1, "Timeout Exceeded"

        # Read some of the output
        current_buffer = ssh_channel.recv(1024)

        # If we have an empty buffer, then the SSH session has been closed
        if len(current_buffer) == 0:
            get_logger().log_warning("SSH session closed: received empty buffer from remote channel")
            return -1, "Connection has been closed"

        # Convert the buffer to our chosen encoding; incomplete trailing bytes stay buffered in the decoder.
        current_buffer_decoded = decoder.decode(current_buffer)

        # Strip ANSI escape sequences added by shell commands like sudo or colored prompts
        # These sequences are common in interactive `invoke_shell()` sessions
        # NOTE(review): a sequence split across two reads is not stripped.
        current_buffer_decoded = self._strip_ansi_sequences(current_buffer_decoded)

        # Strip all \r (Ctrl-M) characters from the current read
        current_buffer_decoded = current_buffer_decoded.replace("\r", "")

        return 0, current_buffer_decoded

    def get_return_code(self) -> int:
        """
        Return the last return code captured by this SSH connection.

        Returns:
            int: The remote exit status of the last `send`, or the channel read status (0 on success,
            -1 on timeout/closed channel) of the last prompt-driven command. None if nothing has run yet.
        """
        return self.last_return_code

    def close(self) -> None:
        """
        Close the SSH connection.

        This shuts down the underlying Paramiko SSH client and marks the connection as disconnected,
        so the next command reconnects immediately.

        Returns:
            None:
        """
        self.client.close()
        self.is_connected = False

    def set_name(self, name: str) -> None:
        """
        Sets the name of this SSH connection.

        Args:
            name (str): Name to assign to this SSH connection.
        """
        self.name = name

    def get_name(self) -> str:
        """
        Get the name of this SSH connection.

        Returns:
            str: The name of the connection.
        """
        return self.name

    def get_sftp_client(self, reconnect_timeout: int = 600) -> SFTPClient:
        """
        Get an SFTP client for file operations.

        Retries the connection for up to `reconnect_timeout` seconds if disconnected.

        Args:
            reconnect_timeout (int): The number of seconds to retry connecting.

        Returns:
            SFTPClient: A Paramiko SFTP client for performing file operations, or None if one could not
            be opened before `reconnect_timeout` expired.
        """
        timeout = time.time() + reconnect_timeout
        refresh_timeout = 5

        sftp_client: SFTPClient = None
        while time.time() < timeout:
            try:
                if not self.is_connected:
                    self.connect()
                sftp_client = self.client.open_sftp()
                if sftp_client:
                    return sftp_client
                # Caught by the except block below, which triggers a reconnect.
                raise RuntimeError("SFTP Client was None")
            except Exception as e:
                get_logger().log_info(f"Failed to get sftp client. Reconnecting and trying again in {refresh_timeout} seconds. " f"Exception: {str(e)}")
                time.sleep(refresh_timeout)
                self.is_connected = False

        get_logger().log_error(f"Failed to get sftp client within {reconnect_timeout} seconds. Returning None.")
        return sftp_client

    def setup_ssh_pass(self, host_name: str, host_user_name: str, host_password: str) -> None:
        """
        Set up the connection to use sshpass for remote authentication.

        Stores SSH credentials and calculates the starting line number to strip
        connection preamble in subsequent SSH commands.

        Args:
            host_name (str): The host to use sshpass on.
            host_user_name (str): The username for SSH authentication.
            host_password (str): The password for SSH authentication.
        """
        # setup this ssh connection with ssh pass parameters
        self.use_ssh_pass = True
        self.ssh_pass_host = host_name
        self.ssh_pass_username = host_user_name
        self.ssh_pass_password = host_password

        # get preamble so we can parse it out: output_start_line is still -1 here, so the full output is returned
        output = self.send("\n")
        self.output_start_line = len(output)

    def get_ssh_pass_str(self) -> str:
        """
        Return the SSH pass command string.

        This wraps SSH calls with `sshpass` to support automated password-based login.
        Host key checking is disabled and no known_hosts file is used.

        Returns:
            str: The formatted SSH pass command string.
        """
        # NOTE(review): the password is wrapped in single quotes without escaping; a password containing "'" breaks the command.
        return f"sshpass -p '{self.ssh_pass_password}' ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null {self.ssh_pass_username}@{self.ssh_pass_host}"

    def __str__(self) -> str:
        """
        Return the string representation of this connection.

        Returns:
            str: A string identifying this SSH connection.
        """
        return f"ssh_con:{self.name}"

    @staticmethod
    def _strip_ansi_sequences(text: str) -> str:
        """
        Remove ANSI escape sequences from a string.

        Uses a commonly used regular expression for matching ANSI terminal
        control codes (e.g., color codes, cursor movement, etc.). These are
        typically found in output from interactive shell commands.

        Regex pattern adapted from:
            - https://stackoverflow.com/a/14693789
            - https://github.com/chalk/ansi-regex/blob/main/index.js

        Args:
            text (str): A string that may contain ANSI escape sequences.

        Returns:
            str: Cleaned string without ANSI sequences.
        """
        return SSHConnection._ANSI_ESCAPE_PATTERN.sub("", text)