Files
test/framework/ssh/ssh_connection.py
croy 3246c599c5 Setting get_pty as kwarg in ssh_connection
get_pty was getting sent as a positional arg instead of
a kwarg causing problems if the timeout needed to be used.

Change-Id: Iff9ad5d623b48ea65f4b8e9baa158d5d4aad09d1
Signed-off-by: croy <Christian.Roy@windriver.com>
2025-05-01 13:52:38 -04:00

551 lines
21 KiB
Python

import codecs
import re
import time
from typing import List
import paramiko
from paramiko.client import SSHClient
from paramiko.sftp_client import SFTPClient
from config.host.objects.host_configuration import HostConfiguration
from framework.logging.automation_logger import get_logger
from framework.ssh.prompt_response import PromptResponse
from framework.threading.thread_manager import ThreadManager
class SSHConnection:
"""
This class holds information and actions for an ssh connection.
"""
def __init__(
self,
name: str,
host: str,
user: str,
password: str,
timeout: int = 30,
ssh_port: int = 22,
jump_host: HostConfiguration = None,
):
"""
Initialize the SSH connection object.
This sets up the basic configuration used to create an SSH session, optionally through a jump host.
Args:
name (str): The name of the connection.
host (str): The target host to connect to.
user (str): The SSH username.
password (str): The SSH password.
timeout (int): The timeout for establishing a connection, in seconds.
ssh_port (int): The port used for SSH. Defaults to 22.
jump_host (HostConfiguration, optional): Configuration for a jump host, if needed.
"""
self.client = SSHClient()
self.name = name
self.host = host
self.user = user
self.password = password
self.timeout = timeout
self.ssh_port = ssh_port
self.jump_host = jump_host
self.is_connected = False
self.last_return_code = None # The last Return Code
# these are values are used for commands that require ssh pass on remote nodes
self.use_ssh_pass = False
self.ssh_pass_host = None
self.ssh_pass_username = None
self.ssh_pass_password = None
self.output_start_line = -1 # for parsing out lines that come by default when using ssh pass
def _connect_to_jump_host(self, allow_agent: bool = True, look_for_keys: bool = True) -> None:
"""
Connect to the configured jump host using SSH.
Uses paramiko to establish the SSH session with the jump host, based on
credentials provided in the `jump_host` configuration.
Args:
allow_agent (bool): Connect to SSH agent (Paramiko arg). Default is True.
look_for_keys (bool): Re-use saved private keys (Paramiko arg). Default is True.
Returns:
None:
"""
try:
host = self.jump_host.get_host()
user_name = self.jump_host.get_credentials().get_user_name()
password = self.jump_host.get_credentials().get_password()
jump_host_ssh_port = self.jump_host.get_ssh_port()
self.client.connect(
host,
username=user_name,
password=password,
timeout=self.timeout,
allow_agent=allow_agent,
look_for_keys=look_for_keys,
port=jump_host_ssh_port,
)
except BaseException as exception:
get_logger().log_error(f"Failed to Connect to Jump-Host {host} with username/password =" f" {user_name}/{password} with timeout {self.timeout}s")
get_logger().log_error(f"Exception: {exception}")
raise BaseException("Failed to connect to Jump-Host")
def connect(self, allow_agent: bool = True, look_for_keys: bool = False) -> bool:
"""
Create an SSH connection to the target host.
Args:
allow_agent (bool): Use SSH agent forwarding (Paramiko arg). Default is True.
look_for_keys (bool): Search for saved private keys (Paramiko arg). Default is False.
Returns:
bool: True if the connection was successful, False otherwise.
"""
is_connection_success = True
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy)
sock = None
try:
# if a jump host is configured, create that connection first
if self.jump_host:
self._connect_to_jump_host(allow_agent, look_for_keys)
sock = self.client.get_transport().open_channel("direct-tcpip", (self.host, self.ssh_port), ("", 0), timeout=self.timeout)
self.client.connect(
self.host,
username=self.user,
password=self.password,
timeout=self.timeout,
allow_agent=allow_agent,
look_for_keys=look_for_keys,
port=self.ssh_port,
sock=sock,
)
self.is_connected = True
except BaseException as exception:
get_logger().log_error(f"Failed to Connect to host {self.host} with username/password =" f" {self.user}/{self.password} with timeout {self.timeout}s")
get_logger().log_error(f"Exception: {exception}")
is_connection_success = False
# connection failed but if a jump host is used, we may still have that connection. Reset the client object
self.client = SSHClient()
self.is_connected = False
return is_connection_success
def send(self, cmd: str, reconnect_timeout: int = 600, get_pty: bool = False) -> str:
"""
Send a command to the SSH session and return the output.
Retries the connection for up to `reconnect_timeout` seconds
if the session is lost.
Args:
cmd (str): The command to execute.
reconnect_timeout (int): Time in seconds to retry the connection.
get_pty (bool): Defaults to False. Whether to request a terminal when running a 'send' command.
Returns:
str: The output of the command.
"""
return self._execute_command("SEND", cmd, reconnect_timeout=reconnect_timeout, get_pty=get_pty)
def send_as_sudo(self, cmd: str, reconnect_timeout: int = 600) -> str:
"""
Sends a command using sudo and returns the output. Waits for reconnect timeout.
Args:
cmd (str): The command to send.
reconnect_timeout (int): How long to wait for SSH reconnection if needed.
Returns:
str: Output of the executed command.
"""
return self._execute_command("SEND_SUDO", cmd, reconnect_timeout=reconnect_timeout)
def send_expect_prompts(self, cmd: str, prompts: List[PromptResponse], reconnect_timeout: int = 600) -> str:
"""
Sends a command, waits for prompts and returns the output.
Waits for the reconnect timeout in case of SSH disconnects.
Args:
cmd (str): The command to send.
prompts (List[PromptResponse]): The prompts to expect.
reconnect_timeout (int): The amount of time in seconds to wait for SSH connection.
Returns:
str: The output of the command.
"""
return self._execute_command("SEND_EXPECT_PROMPTS", cmd, prompts=prompts, reconnect_timeout=reconnect_timeout)
def _execute_command(
self,
action: str,
cmd: str,
reconnect_timeout: int = 600,
prompts: List[PromptResponse] = None,
get_pty: bool = False,
) -> str:
"""
Executes the given action with the given command.
Waits for reconnect timeout in case of SSH disconnects.
Args:
action (str): The action to execute, e.g., SEND, SEND_SUDO, SEND_EXPECT_PROMPTS.
cmd (str): The command to run.
reconnect_timeout (int): The time in seconds to wait for SSH connection.
prompts (List[PromptResponse], optional): Expected prompts, if any.
get_pty (bool): Defaults to False. Whether to request a terminal when running a 'send' command.
Returns:
str: The output of the command.
"""
timeout = time.time() + reconnect_timeout
refresh_timeout = 5
# if we are using ssh pass, we need to wrap the call
if self.use_ssh_pass:
if action == "SEND_SUDO": # if it a sudo call we need further changes to avoid password prompt
cmd = f'{self.get_ssh_pass_str()} "echo "{self.ssh_pass_password}" | sudo -S {cmd}"'
# since we do not need prompts or to prepend sudo now, change Action to just 'SEND'
action = "SEND"
else:
cmd = f"{self.get_ssh_pass_str()} '{cmd}'"
while time.time() < timeout:
try:
if not self.is_connected:
self.connect()
thread_manager = ThreadManager(timeout=reconnect_timeout / 10)
if action == "SEND":
thread_manager.start_thread("SSH_Command", self._send, cmd, get_pty=get_pty)
elif action == "SEND_SUDO":
thread_manager.start_thread("SSH_Command", self._send_as_sudo, cmd)
elif action == "SEND_EXPECT_PROMPTS":
thread_manager.start_thread("SSH_Command", self._send_expect_prompts, cmd, prompts)
else:
raise ValueError(f"{action} is not a supported command for an SSHConnection.")
thread_manager.join_all_threads()
output = thread_manager.get_thread_object("SSH_Command").get_result()
# if we use ssh pass we want to skip the preamble before sending back ouput
if self.use_ssh_pass and self.output_start_line != -1: # if -1 it's the call to get preamble so return whole output
output = output[self.output_start_line :]
return output
except Exception as e:
get_logger().log_info(f"SSH command failed to execute. Reconnecting and trying again in {refresh_timeout} seconds. " f"Exception: {str(e)}")
time.sleep(refresh_timeout)
self.is_connected = False
def _send(self, cmd: str, timeout: int = 30, get_pty: bool = False) -> str:
"""
Sends the given command with the specified timeout.
Args:
cmd (str): The command to send.
timeout (int): The timeout in seconds for command execution.
get_pty (bool): Defaults to False. Whether to request a terminal when running a 'send' command.
Returns:
str: The output of the command.
"""
get_logger().log_ssh(cmd)
stdin, stdout, stderr = self.client.exec_command(cmd, timeout=timeout, get_pty=get_pty)
stdout.channel.set_combine_stderr(True)
self.last_return_code = stdout.channel.recv_exit_status()
output = stdout.readlines()
for line in output:
clean_line = line.rstrip("\n")
get_logger().log_ssh(clean_line)
return output
def _send_as_sudo(self, cmd: str) -> str:
"""
Sends the specified command using sudo and handles the password prompt.
Args:
cmd (str): The command to execute with sudo.
Returns:
str: The output of the command.
"""
# Deliberately skipping the "P" in the password as some prompts have
# different cases
sudo_password_prompt = PromptResponse("assword", self.password)
sudo_completed = PromptResponse("@")
sudo_prompts = [sudo_password_prompt, sudo_completed]
return self.send_expect_prompts("sudo {}".format(cmd), sudo_prompts)
def _send_expect_prompts(self, cmd: str, prompts: List[PromptResponse], timeout: int = 30) -> str:
"""
Send the command and wait for the specified prompts in order.
Args:
cmd (str): The command to execute.
prompts (List[PromptResponse]): An ordered list of prompts we expect and
the associated responses.
timeout (int): Timeout in seconds to wait for each prompt.
Returns:
str: The SSH output generated before the last prompt. If there are
intermediate prompts, returns the output between the last two prompts.
"""
if not prompts or len(prompts) < 1:
raise ValueError("You must specify a list with at least one prompt to call this " "function. Otherwise, please call 'send' instead.")
code = -1
# Open up a channel to control the SSH connection and send the command.
channel = self.client.invoke_shell()
self.__send_in_channel(channel, cmd)
# Keep going until we have matched every prompt in order
# Or we timeout from receiving output from the ssh connection
for prompt in prompts:
is_prompt_match = False
output_since_last_prompt = ""
while not is_prompt_match:
# Read the response from the server.
code, output_buffer = self.__read_from_channel(channel, timeout)
if code != 0:
get_logger().log_warning(f"Failed to match prompt of {prompt.get_prompt_substring()}")
break
# Log the current console output.
get_logger().log_info(output_buffer.rstrip())
# Add the currently read buffer to the output
output_since_last_prompt += output_buffer
prompt.set_complete_output(output_since_last_prompt)
is_prompt_match = prompt.get_prompt_substring() in output_since_last_prompt
# If we match the prompt, send the associated response if any.
if is_prompt_match and prompt.get_prompt_response():
self.__send_in_channel(channel, prompt.get_prompt_response())
self.last_return_code = code
complete_output = prompts[-1].get_complete_output()
# output is a long string, break into list using line breaks but add back the line break as it's needed
# for table parsing
output_list = [line + "\n" for line in complete_output.split("\n") if line]
return output_list
def __send_in_channel(self, ssh_channel: paramiko.Channel, cmd: str) -> None:
"""
Send a command through the given SSH channel.
This method assumes the channel was opened via `invoke_shell()` and waits
until the channel is ready before sending the command.
Args:
ssh_channel (paramiko.Channel): The SSH channel obtained from
`self.client.invoke_shell()`.
cmd (str): The command to send.
"""
while not ssh_channel.send_ready():
time.sleep(0.009) # Avoid spamming the channel. Value taken from paramiko-expect.
get_logger().log_info(f"Sending command: {cmd}")
ssh_channel.send(cmd)
ssh_channel.send("\n")
def __read_from_channel(self, ssh_channel: paramiko.Channel, timeout: int) -> tuple[int, str]:
"""
Read data from an SSH channel opened via `invoke_shell()`.
Waits for the channel to be ready and reads the output. Times out if no
response is received in the given number of seconds.
Args:
ssh_channel (paramiko.Channel): The SSH channel obtained from
`self.client.invoke_shell()`.
timeout (int): Time in seconds to wait for a response.
Returns:
tuple[int, str]: A tuple of return code and string output.
- Return code: 0 on success, -1 on timeout or connection closed.
- Output: The response string read from the SSH channel.
"""
# Setup Variables
decoder = codecs.getincrementaldecoder("utf-8")()
base_time = time.time()
# Avoids paramiko hang when recv is not ready yet
while not ssh_channel.recv_ready():
time.sleep(0.009) # Avoid spamming the channel. Value taken from paramiko-expect.
if time.time() >= (base_time + timeout):
get_logger().log_warning("SSH output read timed out — buffer may be incomplete or prompt unmatched.")
return -1, "Timeout Exceeded"
# Read some of the output
current_buffer = ssh_channel.recv(1024)
# If we have an empty buffer, then the SSH session has been closed
if len(current_buffer) == 0:
get_logger().log_warning("SSH session closed: received empty buffer from remote channel")
return -1, "Connection has been closed"
# Convert the buffer to our chosen encoding
current_buffer_decoded = decoder.decode(current_buffer)
# Strip ANSI escape sequences added by shell commands like sudo or colored prompts
# These sequences are common in interactive `invoke_shell()` sessions
current_buffer_decoded = self._strip_ansi_sequences(current_buffer_decoded)
# Strip all ugly \r (Ctrl-M making) characters from the current read
current_buffer_decoded = current_buffer_decoded.replace("\r", "")
return 0, current_buffer_decoded
def get_return_code(self) -> str:
"""
Return the last return code captured by this SSH connection.
Returns:
str: The last return code from the most recent SSH command.
"""
return self.last_return_code
def close(self) -> None:
"""
Close the SSH connection.
This shuts down the underlying Paramiko SSH client.
Returns:
None:
"""
self.client.close()
def set_name(self, name: str) -> None:
"""
Sets the name of this SSH connection
Args:
name (str): Name to assign to this SSH connection
"""
self.name = name
def get_name(self) -> str:
"""
Get the name of this SSH connection.
Returns:
str: The name of the connection.
"""
return self.name
def get_sftp_client(self, reconnect_timeout: int = 600) -> SFTPClient:
"""
Get an SFTP client for file operations.
Retries the connection for up to `reconnect_timeout` seconds if disconnected.
Args:
reconnect_timeout (int): The number of seconds to retry connecting.
Returns:
SFTPClient: A Paramiko SFTP client for performing file operations.
"""
timeout = time.time() + reconnect_timeout
refresh_timeout = 5
sftp_client: SFTPClient = None
while time.time() < timeout:
try:
if not self.is_connected:
self.connect()
sftp_client = self.client.open_sftp()
if sftp_client:
return sftp_client
else:
raise "SFTP Client was None" # should be caught in the except block which tries to reconnect
except Exception as e:
get_logger().log_info(f"Failed to get sftp client. Reconnecting and trying again in {refresh_timeout} seconds. " f"Exception: {str(e)}")
time.sleep(refresh_timeout)
self.is_connected = False
return sftp_client
def setup_ssh_pass(self, host_name: str, host_user_name: str, host_password: str) -> None:
"""
Set up the connection to use sshpass for remote authentication.
Stores SSH credentials and calculates the starting line number to strip
connection preamble in subsequent SSH commands.
Args:
host_name (str): The host to use sshpass on.
host_user_name (str): The username for SSH authentication.
host_password (str): The password for SSH authentication.
"""
# setup this ssh connection with ssh pass parameters
self.use_ssh_pass = True
self.ssh_pass_host = host_name
self.ssh_pass_username = host_user_name
self.ssh_pass_password = host_password
# get preamble so we can parse it out
output = self.send("\n")
self.output_start_line = len(output)
def get_ssh_pass_str(self) -> str:
"""
Return the SSH pass command string.
This wraps SSH calls with `sshpass` to support automated password-based login.
Returns:
str: The formatted SSH pass command string.
"""
return f"sshpass -p '{self.ssh_pass_password}' ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null {self.ssh_pass_username}@{self.ssh_pass_host}"
def __str__(self) -> str:
"""
Return the string representation of this connection.
Returns:
str: A string identifying this SSH connection.
"""
return f"ssh_con:{self.name}"
@staticmethod
def _strip_ansi_sequences(text: str) -> str:
"""
Remove ANSI escape sequences from a string.
This is a commonly used regular expression for matching ANSI terminal
control codes (e.g., color codes, cursor movement, etc.). These are
typically found in output from interactive shell commands.
Regex pattern adapted from:
- https://stackoverflow.com/a/14693789
- https://github.com/chalk/ansi-regex/blob/main/index.js
Args:
text (str): A string that may contain ANSI escape sequences.
Returns:
str: Cleaned string without ANSI sequences.
"""
ansi_escape_pattern = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
return ansi_escape_pattern.sub("", text)