Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 34 additions & 15 deletions src/ssh/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@

use super::tokio_client::{AuthMethod, Client};
use crate::jump::{parse_jump_hosts, JumpHostChain};
use crate::ssh::tokio_client::client::{CommandOutput, CommandOutputBuffer};
use anyhow::{Context, Result};
use std::path::Path;
use std::time::Duration;
use tokio::sync::mpsc::Sender;
use zeroize::Zeroizing;

/// Configuration for SSH connection and command execution
Expand Down Expand Up @@ -84,6 +86,32 @@ impl SshClient {
command: &str,
config: &ConnectionConfig<'_>,
) -> Result<CommandResult> {
let CommandOutputBuffer {
sender,
receiver_task,
} = CommandOutputBuffer::new();

let exit_status = self
.connect_and_execute_with_output_streaming(command, config, sender)
.await?;

let (output, stderr) = receiver_task.await?;

// Convert result to our format
Ok(CommandResult {
host: self.host.clone(),
output,
stderr,
exit_status,
})
}

pub async fn connect_and_execute_with_output_streaming(
&mut self,
command: &str,
config: &ConnectionConfig<'_>,
output_sender: Sender<CommandOutput>,
) -> Result<u32> {
tracing::debug!("Connecting to {}:{}", self.host, self.port);

// Determine authentication method based on parameters
Expand Down Expand Up @@ -137,11 +165,11 @@ impl SshClient {
tracing::debug!("Executing command: {}", command);

// Execute command with timeout
let result = if let Some(timeout_secs) = config.timeout_seconds {
let exit_status = if let Some(timeout_secs) = config.timeout_seconds {
if timeout_secs == 0 {
// No timeout (unlimited)
tracing::debug!("Executing command with no timeout (unlimited)");
client.execute(command)
client.execute_streaming(command, output_sender)
.await
.with_context(|| format!("Failed to execute command '{}' on {}:{}. The SSH connection was successful but the command could not be executed.", command, self.host, self.port))?
} else {
Expand All @@ -150,7 +178,7 @@ impl SshClient {
tracing::debug!("Executing command with timeout of {} seconds", timeout_secs);
tokio::time::timeout(
command_timeout,
client.execute(command)
client.execute_streaming(command, output_sender)
)
.await
.with_context(|| format!("Command execution timeout: The command '{}' did not complete within {} seconds on {}:{}", command, timeout_secs, self.host, self.port))?
Expand All @@ -168,25 +196,16 @@ impl SshClient {
tracing::debug!("Executing command with default timeout of 300 seconds");
tokio::time::timeout(
command_timeout,
client.execute(command)
client.execute_streaming(command, output_sender)
)
.await
.with_context(|| format!("Command execution timeout: The command '{}' did not complete within 5 minutes on {}:{}", command, self.host, self.port))?
.with_context(|| format!("Failed to execute command '{}' on {}:{}. The SSH connection was successful but the command could not be executed.", command, self.host, self.port))?
};

tracing::debug!(
"Command execution completed with status: {}",
result.exit_status
);
tracing::debug!("Command execution completed with status: {exit_status}",);

// Convert result to our format
Ok(CommandResult {
host: self.host.clone(),
output: result.stdout.into_bytes(),
stderr: result.stderr.into_bytes(),
exit_status: result.exit_status,
})
Ok(exit_status)
}

/// Create a direct SSH connection (no jump hosts)
Expand Down
121 changes: 99 additions & 22 deletions src/ssh/tokio_client/client.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use russh::client::KeyboardInteractiveAuthResponse;
use russh::{
client::{Config, Handle, Handler, Msg},
Channel,
Channel, CryptoVec,
};
use russh_sftp::{client::SftpSession, protocol::OpenFlags};
use std::net::SocketAddr;
use std::sync::Arc;
use std::{fmt::Debug, path::Path};
use std::{io, path::PathBuf};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::mpsc::Sender;
use tokio::task::JoinHandle;
use zeroize::Zeroizing;

use super::ToSocketAddrsWithHostname;
Expand Down Expand Up @@ -813,38 +815,81 @@ impl Client {
///
/// Returns stdout, stderr and the exit code of the command,
/// packaged in a [`CommandExecutedResult`] struct.
///
/// If you need the stderr output interleaved within stdout, you should postfix the command with a redirection,
/// e.g. `echo foo 2>&1`.
/// If you dont want any output at all, use something like `echo foo >/dev/null 2>&1`.
/// If you don't want any output at all, use something like `echo foo >/dev/null 2>&1`.
///
/// Make sure your commands don't read from stdin and exit after bounded time.
///
/// Can be called multiple times, but every invocation is a new shell context.
/// Thus `cd`, setting variables and alike have no effect on future invocations.
/// Thus, `cd` and setting variables and alike have no effect on future invocations.
pub async fn execute(&self, command: &str) -> Result<CommandExecutedResult, super::Error> {
let CommandOutputBuffer {
sender,
receiver_task,
} = CommandOutputBuffer::new();

let exit_status = self.execute_streaming(command, sender).await?;

let (stdout, stderr) = receiver_task.await?;

let result = CommandExecutedResult {
stdout: String::from_utf8_lossy(&stdout).into(),
stderr: String::from_utf8_lossy(&stderr).into(),
exit_status,
};

Ok(result)
}

/// The same as [`Self:: execute`] except that output from stdout and stderr is
/// provided as it is received via callback functions. Once the command has
/// finished, returns its exit code.
pub async fn execute_streaming(
&self,
command: &str,
sender: Sender<CommandOutput>,
) -> Result<u32, super::Error> {
// Sanitize command to prevent injection attacks
let sanitized_command = crate::utils::sanitize_command(command)
.map_err(|e| super::Error::CommandValidationFailed(e.to_string()))?;

// Pre-allocate buffers with capacity to avoid frequent reallocations
let mut stdout_buffer = Vec::with_capacity(SSH_CMD_BUFFER_SIZE);
let mut stderr_buffer = Vec::with_capacity(SSH_RESPONSE_BUFFER_SIZE);
let mut channel = self.connection_handle.channel_open_session().await?;
channel.exec(true, sanitized_command.as_str()).await?;

let mut result: Option<u32> = None;

let mut receiver_dropped = false;

// While the channel has messages...
while let Some(msg) = channel.wait().await {
//dbg!(&msg);
match msg {
// If we get data, add it to the buffer
russh::ChannelMsg::Data { ref data } => {
stdout_buffer.write_all(data).await.unwrap()
russh::ChannelMsg::Data { data } => {
if let Err(_send_error) = sender.send(CommandOutput::StdOut(data)).await {
// only log the warning once per command
if !receiver_dropped {
receiver_dropped = true;

tracing::warn!(
"receiver dropped; cannot send command output to receiver"
);
}
}
}
russh::ChannelMsg::ExtendedData { ref data, ext } => {
russh::ChannelMsg::ExtendedData { data, ext } => {
if ext == 1 {
stderr_buffer.write_all(data).await.unwrap()
if let Err(_send_error) = sender.send(CommandOutput::StdErr(data)).await {
// only log the warning once per command
if !receiver_dropped {
receiver_dropped = true;

tracing::warn!(
"receiver dropped; cannot send command output to receiver"
);
}
}
}
}

Expand All @@ -862,17 +907,7 @@ impl Client {
}

// If we received an exit code, report it back
if let Some(result) = result {
Ok(CommandExecutedResult {
stdout: String::from_utf8_lossy(&stdout_buffer).to_string(),
stderr: String::from_utf8_lossy(&stderr_buffer).to_string(),
exit_status: result,
})

// Otherwise, report an error
} else {
Err(super::Error::CommandDidntExit)
}
result.ok_or(super::Error::CommandDidntExit)
}

/// Request an interactive shell with PTY support.
Expand Down Expand Up @@ -1008,6 +1043,48 @@ impl Debug for Client {
}
}

/// Partial output of a command
pub enum CommandOutput {
/// Partial stdout output of a command
StdOut(CryptoVec),
/// Partial stderr output of a command
StdErr(CryptoVec),
}

pub(crate) struct CommandOutputBuffer {
pub(crate) sender: Sender<CommandOutput>,
pub(crate) receiver_task: JoinHandle<(Vec<u8>, Vec<u8>)>,
}

impl CommandOutputBuffer {
pub(crate) fn new() -> Self {
// The output collection task should easily keep up with output received from ssh server
const OUTPUT_EVENTS_CHANNEL_SIZE: usize = 100;

let (sender, mut receiver) = tokio::sync::mpsc::channel(OUTPUT_EVENTS_CHANNEL_SIZE);

let receiver_task = tokio::task::spawn(async move {
// Pre-allocate buffers with capacity to avoid frequent reallocations
let mut stdout = Vec::with_capacity(SSH_CMD_BUFFER_SIZE);
let mut stderr = Vec::with_capacity(SSH_RESPONSE_BUFFER_SIZE);

while let Some(output) = receiver.recv().await {
match output {
CommandOutput::StdOut(buffer) => stdout.extend_from_slice(&buffer),
CommandOutput::StdErr(buffer) => stderr.extend_from_slice(&buffer),
}
}

(stdout, stderr)
});

Self {
sender,
receiver_task,
}
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CommandExecutedResult {
/// The stdout output of the command.
Expand Down
2 changes: 2 additions & 0 deletions src/ssh/tokio_client/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ pub enum Error {
SftpError(#[from] russh_sftp::client::error::Error),
#[error("I/O error")]
IoError(#[from] io::Error),
#[error("Task join error: {0}")]
JoinError(#[from] tokio::task::JoinError),
#[error("Command validation failed: {0}")]
CommandValidationFailed(String),
#[error("Port forwarding request failed: {0}")]
Expand Down
Loading