This commit is contained in:
Nikos Papadakis 2024-01-25 01:18:30 +02:00
parent 6c72ebbe3e
commit 221b4348d8
Signed by untrusted user who does not match committer: nikos
GPG key ID: 78871F9905ADFF02
18 changed files with 642 additions and 2270 deletions

1386
Cargo.lock generated

File diff suppressed because it is too large Load diff

1
agent/.gitignore vendored
View file

@ -1 +0,0 @@
/target/

View file

@ -4,26 +4,10 @@ version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = "1.0.71"
chrono = "0.4.26"
clap = "4.3.9"
envy = "0.4.2"
itertools = "0.11.0"
once_cell = "1.18.0"
prost = "0.12.1"
regex = "1.10.2"
reqwest = { version = "0.11.18", features = ["blocking", "json"], default-features = false }
rustix = { version = "0.38.28", features = ["fs", "process", "pty", "stdio", "termios"] }
serde = { version = "1.0.173", features = ["derive"] }
serde_json = "1.0.103"
sysinfo = { version = "0.29.2", default-features = false }
tokio = { version = "1.28.2", features = ["full"] }
tokio-stream = { version = "0.1.14", features = ["net", "sync"] }
tokio-util = { version = "0.7.10", features = ["codec"] }
tonic = { version = "0.10.2" }
tower-http = { version = "0.4.3", features = ["trace"] }
tracing = "0.1.37"
tracing-subscriber = "0.3.17"
[build-dependencies]
tonic-build = "0.10.2"
anyhow = "1.0.79"
async-nats = "0.33.0"
serde_json = "1.0.111"
tokio = { version = "1.35.1", features = ["full"] }
tokio-stream = { version = "0.1.14", default-features = false }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt"] }

View file

@ -1,6 +0,0 @@
fn main() {
tonic_build::configure()
.build_client(false)
.compile(&["../proto/agent.proto"], &["../proto"])
.unwrap();
}

27
agent/server.conf Normal file
View file

@ -0,0 +1,27 @@
authorization: {
users = [
{
user: prymn_admin
password: prymn_admin
permissions: {
publish: ">"
subscribe: ">"
}
}
{
user: demo_agent
password: demo_agent_password
permissions: {
publish: [
"agents.v1.demo_agent.>"
]
subscribe: [
"agents.v1.demo_agent.>"
"_INBOX_demo_agent.>"
]
}
}
]
}
jetstream: {}

View file

@ -1,33 +0,0 @@
use anyhow::Context;
use clap::arg;
use prymn_agent::{self_update, server};
use tracing::Level;
use tracing_subscriber::{filter::Targets, layer::SubscriberExt, util::SubscriberInitExt};
fn main() -> anyhow::Result<()> {
// Debug subscriber, should be configurable in the future
tracing_subscriber::fmt()
.with_max_level(Level::TRACE)
.pretty()
.without_time()
.with_file(false)
.finish()
.with(
Targets::new()
.with_target("prymn_agent", Level::DEBUG)
.with_target("tower_http", Level::DEBUG),
)
.init();
let command = clap::Command::new(env!("CARGO_BIN_NAME"))
.version(env!("CARGO_PKG_VERSION"))
.arg(arg!(--install <TOKEN> "Install this agent binary to the system").exclusive(true))
.try_get_matches()
.unwrap_or_else(|e| e.exit());
if let Some(token) = command.get_one::<String>("install") {
self_update::install(token).context("failed to install the agent to the system")
} else {
server::run()
}
}

View file

@ -1,18 +0,0 @@
use once_cell::sync::Lazy;
use serde::Deserialize;
#[derive(Deserialize, Debug)]
pub struct Config {
#[serde(default = "default_backend_url")]
pub backend_url: String,
}
fn default_backend_url() -> String {
"https://app.prymn.net".to_string()
}
pub static CONFIG: Lazy<Config> =
Lazy::new(|| match envy::prefixed("PRYMN_").from_env::<Config>() {
Ok(config) => config,
Err(_) => todo!("handle this error"),
});

View file

@ -1,186 +0,0 @@
use std::process::{Command, Output};
use regex::Regex;
pub fn update_package_index() -> std::io::Result<Output> {
Command::new("apt-get").arg("-y").arg("update").output()
}
pub fn run_updates(dry_run: bool) -> std::io::Result<Output> {
let mut command = Command::new("apt-get");
if dry_run {
command.arg("-s");
}
command.arg("-y").arg("upgrade").output()
}
pub fn install_packages(packages: &[&str]) -> std::io::Result<Output> {
Command::new("apt-get")
.arg("install")
.arg("-y")
.args(packages)
.output()
}
pub fn get_available_updates() -> std::io::Result<Vec<String>> {
let output = Command::new("apt-get").arg("-sV").arg("upgrade").output()?;
let upgradables = parse_upgrade_output(&String::from_utf8_lossy(&output.stdout));
Ok(upgradables)
}
fn parse_upgrade_output(output: &str) -> Vec<String> {
output
.split_once("The following packages will be upgraded:\n")
.and_then(|(_, rest)| {
// Find the first line with non-whitespace characters (indicating the end of the list)
let re = Regex::new(r"(?m)^\S").unwrap();
re.find(rest).map(|m| rest.split_at(m.start()).0)
})
.map_or_else(Vec::new, |text| {
let lines = text.lines();
lines.map(|line| line.trim().to_owned()).collect()
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parses_upgrade_output_correctly() {
// `apt-get -sV upgrade`
let test_output = r"
NOTE: This is only a simulation!
apt-get needs root privileges for real execution.
Keep also in mind that locking is deactivated,
so don't depend on the relevance to the real current situation!
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
Calculating upgrade... Done
The following packages have been kept back:
linux-image-amd64 (5.10.191-1 => 5.10.197-1)
The following packages will be upgraded:
adduser (3.118 => 3.118+deb11u1)
base-files (11.1+deb11u7 => 11.1+deb11u8)
cpio (2.13+dfsg-4 => 2.13+dfsg-7.1~deb11u1)
dbus (1.12.24-0+deb11u1 => 1.12.28-0+deb11u1)
distro-info-data (0.51+deb11u3 => 0.51+deb11u4)
dpkg (1.20.12 => 1.20.13)
grub-common (2.06-3~deb11u5 => 2.06-3~deb11u6)
grub-pc (2.06-3~deb11u5 => 2.06-3~deb11u6)
grub-pc-bin (2.06-3~deb11u5 => 2.06-3~deb11u6)
grub2-common (2.06-3~deb11u5 => 2.06-3~deb11u6)
krb5-locales (1.18.3-6+deb11u3 => 1.18.3-6+deb11u4)
libbsd0 (0.11.3-1 => 0.11.3-1+deb11u1)
libcurl3-gnutls (7.74.0-1.3+deb11u7 => 7.74.0-1.3+deb11u10)
libdbus-1-3 (1.12.24-0+deb11u1 => 1.12.28-0+deb11u1)
libgssapi-krb5-2 (1.18.3-6+deb11u3 => 1.18.3-6+deb11u4)
libk5crypto3 (1.18.3-6+deb11u3 => 1.18.3-6+deb11u4)
libkrb5-3 (1.18.3-6+deb11u3 => 1.18.3-6+deb11u4)
libkrb5support0 (1.18.3-6+deb11u3 => 1.18.3-6+deb11u4)
libncurses6 (6.2+20201114-2+deb11u1 => 6.2+20201114-2+deb11u2)
libncursesw6 (6.2+20201114-2+deb11u1 => 6.2+20201114-2+deb11u2)
libnss-systemd (247.3-7+deb11u2 => 247.3-7+deb11u4)
libpam-systemd (247.3-7+deb11u2 => 247.3-7+deb11u4)
libssl1.1 (1.1.1n-0+deb11u5 => 1.1.1w-0+deb11u1)
libsystemd0 (247.3-7+deb11u2 => 247.3-7+deb11u4)
libtinfo6 (6.2+20201114-2+deb11u1 => 6.2+20201114-2+deb11u2)
libudev1 (247.3-7+deb11u2 => 247.3-7+deb11u4)
logrotate (3.18.0-2+deb11u1 => 3.18.0-2+deb11u2)
ncurses-base (6.2+20201114-2+deb11u1 => 6.2+20201114-2+deb11u2)
ncurses-bin (6.2+20201114-2+deb11u1 => 6.2+20201114-2+deb11u2)
ncurses-term (6.2+20201114-2+deb11u1 => 6.2+20201114-2+deb11u2)
openssh-client (1:8.4p1-5+deb11u1 => 1:8.4p1-5+deb11u2)
openssh-server (1:8.4p1-5+deb11u1 => 1:8.4p1-5+deb11u2)
openssh-sftp-server (1:8.4p1-5+deb11u1 => 1:8.4p1-5+deb11u2)
openssl (1.1.1n-0+deb11u5 => 1.1.1w-0+deb11u1)
qemu-utils (1:5.2+dfsg-11+deb11u2 => 1:5.2+dfsg-11+deb11u3)
systemd (247.3-7+deb11u2 => 247.3-7+deb11u4)
systemd-sysv (247.3-7+deb11u2 => 247.3-7+deb11u4)
udev (247.3-7+deb11u2 => 247.3-7+deb11u4)
38 upgraded, 0 newly installed, 0 to remove and 1 not upgraded.
Inst base-files [11.1+deb11u7] (11.1+deb11u8 Debian:11.8/oldstable [amd64])
Conf base-files (11.1+deb11u8 Debian:11.8/oldstable [amd64])
Inst dpkg [1.20.12] (1.20.13 Debian:11.8/oldstable [amd64])
Conf dpkg (1.20.13 Debian:11.8/oldstable [amd64])
Inst ncurses-bin [6.2+20201114-2+deb11u1] (6.2+20201114-2+deb11u2 Debian:11.8/oldstable [amd64])
Conf ncurses-bin (6.2+20201114-2+deb11u2 Debian:11.8/oldstable [amd64])
Inst ncurses-base [6.2+20201114-2+deb11u1] (6.2+20201114-2+deb11u2 Debian:11.8/oldstable [all])
Conf ncurses-base (6.2+20201114-2+deb11u2 Debian:11.8/oldstable [all])
Inst libnss-systemd [247.3-7+deb11u2] (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64]) []
Inst libsystemd0 [247.3-7+deb11u2] (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64]) [systemd:amd64 ]
Conf libsystemd0 (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64]) [systemd:amd64 ]
Inst libpam-systemd [247.3-7+deb11u2] (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64]) [systemd:amd64 ]
Inst systemd [247.3-7+deb11u2] (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64])
Inst udev [247.3-7+deb11u2] (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64]) []
Inst libudev1 [247.3-7+deb11u2] (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64])
Conf libudev1 (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64])
Inst adduser [3.118] (3.118+deb11u1 Debian:11.8/oldstable [all])
Conf adduser (3.118+deb11u1 Debian:11.8/oldstable [all])
Conf systemd (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64])
Inst systemd-sysv [247.3-7+deb11u2] (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64])
Inst dbus [1.12.24-0+deb11u1] (1.12.28-0+deb11u1 Debian:11.8/oldstable [amd64]) []
Inst libdbus-1-3 [1.12.24-0+deb11u1] (1.12.28-0+deb11u1 Debian:11.8/oldstable [amd64])
Inst libk5crypto3 [1.18.3-6+deb11u3] (1.18.3-6+deb11u4 Debian:11.8/oldstable [amd64])
Conf libk5crypto3 (1.18.3-6+deb11u4 Debian:11.8/oldstable [amd64])
Inst libkrb5support0 [1.18.3-6+deb11u3] (1.18.3-6+deb11u4 Debian:11.8/oldstable [amd64]) [libkrb5-3:amd64 ]
Conf libkrb5support0 (1.18.3-6+deb11u4 Debian:11.8/oldstable [amd64]) [libkrb5-3:amd64 ]
Inst libkrb5-3 [1.18.3-6+deb11u3] (1.18.3-6+deb11u4 Debian:11.8/oldstable [amd64]) [libgssapi-krb5-2:amd64 ]
Conf libkrb5-3 (1.18.3-6+deb11u4 Debian:11.8/oldstable [amd64]) [libgssapi-krb5-2:amd64 ]
Inst libgssapi-krb5-2 [1.18.3-6+deb11u3] (1.18.3-6+deb11u4 Debian:11.8/oldstable [amd64])
Conf libgssapi-krb5-2 (1.18.3-6+deb11u4 Debian:11.8/oldstable [amd64])
Inst libssl1.1 [1.1.1n-0+deb11u5] (1.1.1w-0+deb11u1 Debian:11.8/oldstable [amd64])
Conf libssl1.1 (1.1.1w-0+deb11u1 Debian:11.8/oldstable [amd64])
Inst libncurses6 [6.2+20201114-2+deb11u1] (6.2+20201114-2+deb11u2 Debian:11.8/oldstable [amd64]) []
Inst libncursesw6 [6.2+20201114-2+deb11u1] (6.2+20201114-2+deb11u2 Debian:11.8/oldstable [amd64]) []
Inst libtinfo6 [6.2+20201114-2+deb11u1] (6.2+20201114-2+deb11u2 Debian:11.8/oldstable [amd64])
Conf libtinfo6 (6.2+20201114-2+deb11u2 Debian:11.8/oldstable [amd64])
Inst cpio [2.13+dfsg-4] (2.13+dfsg-7.1~deb11u1 Debian:11.8/oldstable [amd64])
Inst logrotate [3.18.0-2+deb11u1] (3.18.0-2+deb11u2 Debian:11.8/oldstable [amd64])
Inst krb5-locales [1.18.3-6+deb11u3] (1.18.3-6+deb11u4 Debian:11.8/oldstable [all])
Inst ncurses-term [6.2+20201114-2+deb11u1] (6.2+20201114-2+deb11u2 Debian:11.8/oldstable [all])
Inst openssh-sftp-server [1:8.4p1-5+deb11u1] (1:8.4p1-5+deb11u2 Debian:11.8/oldstable [amd64]) []
Inst openssh-server [1:8.4p1-5+deb11u1] (1:8.4p1-5+deb11u2 Debian:11.8/oldstable [amd64]) []
Inst openssh-client [1:8.4p1-5+deb11u1] (1:8.4p1-5+deb11u2 Debian:11.8/oldstable [amd64])
Inst distro-info-data [0.51+deb11u3] (0.51+deb11u4 Debian:11.8/oldstable [all])
Inst grub2-common [2.06-3~deb11u5] (2.06-3~deb11u6 Debian-Security:11/oldstable-security [amd64]) [grub-pc:amd64 ]
Inst grub-pc [2.06-3~deb11u5] (2.06-3~deb11u6 Debian-Security:11/oldstable-security [amd64]) []
Inst grub-pc-bin [2.06-3~deb11u5] (2.06-3~deb11u6 Debian-Security:11/oldstable-security [amd64]) []
Inst grub-common [2.06-3~deb11u5] (2.06-3~deb11u6 Debian-Security:11/oldstable-security [amd64])
Inst libbsd0 [0.11.3-1] (0.11.3-1+deb11u1 Debian:11.8/oldstable [amd64])
Inst libcurl3-gnutls [7.74.0-1.3+deb11u7] (7.74.0-1.3+deb11u10 Debian-Security:11/oldstable-security [amd64])
Inst openssl [1.1.1n-0+deb11u5] (1.1.1w-0+deb11u1 Debian:11.8/oldstable [amd64])
Inst qemu-utils [1:5.2+dfsg-11+deb11u2] (1:5.2+dfsg-11+deb11u3 Debian:11.8/oldstable [amd64])
Conf libnss-systemd (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64])
Conf libpam-systemd (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64])
Conf udev (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64])
Conf systemd-sysv (247.3-7+deb11u4 Debian:11.8/oldstable, Debian:11-updates/oldstable-updates [amd64])
Conf dbus (1.12.28-0+deb11u1 Debian:11.8/oldstable [amd64])
Conf libdbus-1-3 (1.12.28-0+deb11u1 Debian:11.8/oldstable [amd64])
Conf libncurses6 (6.2+20201114-2+deb11u2 Debian:11.8/oldstable [amd64])
Conf libncursesw6 (6.2+20201114-2+deb11u2 Debian:11.8/oldstable [amd64])
Conf cpio (2.13+dfsg-7.1~deb11u1 Debian:11.8/oldstable [amd64])
Conf logrotate (3.18.0-2+deb11u2 Debian:11.8/oldstable [amd64])
Conf krb5-locales (1.18.3-6+deb11u4 Debian:11.8/oldstable [all])
Conf ncurses-term (6.2+20201114-2+deb11u2 Debian:11.8/oldstable [all])
Conf openssh-sftp-server (1:8.4p1-5+deb11u2 Debian:11.8/oldstable [amd64])
Conf openssh-server (1:8.4p1-5+deb11u2 Debian:11.8/oldstable [amd64])
Conf openssh-client (1:8.4p1-5+deb11u2 Debian:11.8/oldstable [amd64])
Conf distro-info-data (0.51+deb11u4 Debian:11.8/oldstable [all])
Conf grub2-common (2.06-3~deb11u6 Debian-Security:11/oldstable-security [amd64])
Conf grub-pc (2.06-3~deb11u6 Debian-Security:11/oldstable-security [amd64])
Conf grub-pc-bin (2.06-3~deb11u6 Debian-Security:11/oldstable-security [amd64])
Conf grub-common (2.06-3~deb11u6 Debian-Security:11/oldstable-security [amd64])
Conf libbsd0 (0.11.3-1+deb11u1 Debian:11.8/oldstable [amd64])
Conf libcurl3-gnutls (7.74.0-1.3+deb11u10 Debian-Security:11/oldstable-security [amd64])
Conf openssl (1.1.1w-0+deb11u1 Debian:11.8/oldstable [amd64])
Conf qemu-utils (1:5.2+dfsg-11+deb11u3 Debian:11.8/oldstable [amd64])
";
let upgradables = parse_upgrade_output(test_output);
assert_eq!(upgradables.len(), 38);
}
}

View file

@ -1,190 +0,0 @@
//! System health module
use std::{collections::HashMap, sync::Arc};
use tokio::sync::watch;
use super::{info::Info, task::TaskStatus};
const MEMORY_USAGE_CRITICAL_THRESHOLD: u64 = 90;
const CPU_USAGE_CRITICAL_THRESHOLD: u64 = 90;
const DISK_USAGE_CRITICAL_THRESHOLD: u64 = 90;
#[derive(Clone, PartialEq)]
pub enum CriticalReason {
HighMemoryUsage,
HighCpuUsage,
HighDiskUsage,
}
#[derive(Clone, Default, PartialEq)]
pub enum SystemStatus {
#[default]
Normal,
OutOfDate,
Updating,
Critical(Vec<CriticalReason>),
}
#[derive(Clone, Default)]
pub struct SystemHealth {
pub status: SystemStatus,
}
#[derive(Default, Clone)]
pub struct Health {
system: SystemHealth,
tasks: HashMap<String, TaskStatus>,
}
impl Health {
pub fn system(&self) -> &SystemHealth {
&self.system
}
pub fn tasks(&self) -> &HashMap<String, TaskStatus> {
&self.tasks
}
}
/// [HealthMonitor] gives access to shared system health state, allowing to watch health and update
/// task health status.
///
/// # Usage
/// Internally it uses [Arc] so it can be cheaply cloned and shared.
/// ```
/// use prymn_agent::health::HealthMonitor;
/// use prymn_agent::info::Info;
///
/// let mut info = Info::new();
/// let health_monitor = HealthMonitor::new();
///
/// // Monitor health changes
/// let _receiver = health_monitor.monitor();
///
/// // Refresh system resources
/// info.refresh_resources();
///
/// // Update the health monitor with the refreshed info
/// health_monitor.check_system_info(&info);
/// ```
#[derive(Clone)]
pub struct HealthMonitor {
sender: Arc<watch::Sender<Health>>,
}
impl HealthMonitor {
pub fn new() -> Self {
let (sender, _) = watch::channel(Health::default());
Self {
sender: Arc::new(sender),
}
}
pub fn check_system_info(&self, info: &Info) {
use sysinfo::{CpuExt, DiskExt, SystemExt};
let sys = info.system();
let mut status = SystemStatus::Normal;
let mut statuses = vec![];
// Check for critical memory usage
let memory_usage = if sys.total_memory() > 0 {
sys.used_memory() * 100 / sys.total_memory()
} else {
0
};
if memory_usage > MEMORY_USAGE_CRITICAL_THRESHOLD {
statuses.push(CriticalReason::HighMemoryUsage);
}
// Check for critical CPU usage
let cpu_usage = sys.global_cpu_info().cpu_usage();
if cpu_usage > CPU_USAGE_CRITICAL_THRESHOLD as f32 {
statuses.push(CriticalReason::HighCpuUsage);
}
// Check for any disk usage that is critical
for disk in sys.disks() {
let available_disk = if disk.total_space() > 0 {
disk.available_space() * 100 / disk.total_space()
} else {
0
};
if available_disk < 100 - DISK_USAGE_CRITICAL_THRESHOLD {
statuses.push(CriticalReason::HighDiskUsage);
}
}
if !statuses.is_empty() {
status = SystemStatus::Critical(statuses);
}
self.sender.send_if_modified(|Health { system, .. }| {
if system.status == status {
return false;
}
system.status = status;
true
});
}
/// Spawns a new tokio task that tracks from the [watch::Receiver] the status of a Prymn task
/// via [TaskStatus]
pub fn track_task(&self, name: String, mut task_recv: watch::Receiver<TaskStatus>) {
let sender = self.sender.clone();
tokio::task::spawn(async move {
while task_recv.changed().await.is_ok() {
sender.send_modify(|health| {
health
.tasks
.insert(String::from(&name), task_recv.borrow().clone());
});
}
// At this point the Sender part of the watch dropped, meaning we can clear the task
// because it is complete.
sender.send_if_modified(|health| health.tasks.remove(&name).is_some());
});
}
pub fn clear_task(&self, task_name: &str) {
self.sender
.send_if_modified(|Health { tasks, .. }| tasks.remove(task_name).is_some());
}
pub fn monitor(&self) -> watch::Receiver<Health> {
self.sender.subscribe()
}
}
impl Default for HealthMonitor {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for SystemStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SystemStatus::Normal => write!(f, "normal"),
SystemStatus::OutOfDate => write!(f, "out of date"),
SystemStatus::Updating => write!(f, "updating"),
SystemStatus::Critical(_) => write!(f, "critical"),
}
}
}
impl std::fmt::Display for CriticalReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CriticalReason::HighMemoryUsage => write!(f, "high memory usage"),
CriticalReason::HighCpuUsage => write!(f, "high cpu usage"),
CriticalReason::HighDiskUsage => write!(f, "high disk usage"),
}
}
}

View file

@ -1,87 +0,0 @@
//! System info
use std::{sync::Mutex, time::Duration};
use anyhow::Context;
use sysinfo::{CpuRefreshKind, SystemExt};
use crate::debian;
pub struct Info {
system: sysinfo::System,
updates: Vec<String>,
}
impl Info {
pub fn new() -> Self {
Self {
system: sysinfo::System::new(),
updates: Vec::new(),
}
}
pub fn refresh_resources(&mut self) {
self.system.refresh_specifics(
sysinfo::RefreshKind::new()
.with_disks_list()
.with_memory()
.with_cpu(CpuRefreshKind::new().with_cpu_usage()),
);
}
pub fn refresh_updates(&mut self) -> anyhow::Result<()> {
debian::update_package_index().context("while fetching the package index")?;
let updates =
debian::get_available_updates().context("while fetching available updates")?;
self.updates = updates;
Ok(())
}
pub fn system(&self) -> &sysinfo::System {
&self.system
}
pub fn updates(&self) -> &Vec<String> {
&self.updates
}
}
impl Default for Info {
fn default() -> Self {
Self::new()
}
}
/// Spawns a new thread that forever gathers system information.
pub fn spawn_info_subsystem() -> &'static Mutex<Info> {
const REFRESH_RESOURCES_INTERVAL: Duration = Duration::from_secs(5);
const REFRESH_UPDATES_INTERVAL: Duration = Duration::from_secs(3600);
let info = Box::new(Mutex::new(Info::new()));
let info = Box::leak(info);
std::thread::spawn(|| loop {
tracing::debug!("refreshing system resources");
#[allow(clippy::mut_mutex_lock)]
info.lock().unwrap().refresh_resources();
std::thread::sleep(REFRESH_RESOURCES_INTERVAL);
});
std::thread::spawn(|| loop {
tracing::debug!("refreshing available system updates");
#[allow(clippy::mut_mutex_lock)]
if let Err(err) = info.lock().unwrap().refresh_updates() {
tracing::warn!(?err, "failed to refresh updates");
}
std::thread::sleep(REFRESH_UPDATES_INTERVAL);
});
info
}

View file

@ -1,8 +0,0 @@
pub mod config;
pub mod debian;
pub mod health;
pub mod info;
pub mod pty;
pub mod self_update;
pub mod server;
pub mod task;

107
agent/src/main.rs Normal file
View file

@ -0,0 +1,107 @@
use async_nats as nats;
use nats::{Client, ConnectOptions};
use tracing::Level;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let subscriber = tracing_subscriber::fmt()
.with_max_level(Level::TRACE)
.finish();
tracing::subscriber::set_global_default(subscriber)
.expect("to set a tracing global subscriber");
let client = ConnectOptions::new()
.name("Prymn Agent demo_agent")
.custom_inbox_prefix("_INBOX_demo_agent")
.user_and_password("demo_agent".to_owned(), "demo_agent_password".to_owned())
.connect("localhost")
.await
.map_err(|err| err)?;
tracing::info!("connected to nats server");
wait_for_commands(client).await;
Ok(())
}
async fn wait_for_commands(_client: Client) {
// let mut sub = client
// .subscribe("agents.v1.demo_agent.cmd.*")
// .await
// .unwrap();
//
// let mut command_queue = CommandQueue::default();
//
// while let Some(msg) = sub.next().await {
// let suffix = msg.subject.trim_start_matches("agents.v1.demo_agent.cmd.");
//
// match suffix {
// "end" => {
// let key = std::str::from_utf8(&msg.payload).unwrap();
// command_queue.end_command(key);
// }
// key => {
// if let Some(mut receiver) = command_queue.add_command(key) {
// tokio::spawn(async move {
// while let Ok(()) = receiver.changed().await {
// let queue = receiver.borrow();
// while let Some(cmd) = queue.lock().unwrap().pop_back() {
// handle_command(cmd);
// }
// }
// });
// }
// }
// }
// }
}
//
// fn handle_command(cmd: Command) {
// println!("{cmd:?}");
// }
//
// #[derive(Default)]
// struct CommandQueue(HashMap<String, watch::Sender<Mutex<VecDeque<Command>>>>);
//
// impl CommandQueue {
// pub fn add_command(&mut self, key: &str) -> Option<watch::Receiver<Mutex<VecDeque<Command>>>> {
// match self.0.get_mut(key) {
// Some(sender) => {
// sender.send_modify(|q| q.lock().unwrap().push_back(Command::Foo));
// None
// }
// None => {
// let (sender, receiver) = watch::channel(Mutex::new(VecDeque::new()));
// sender.send_modify(|q| q.lock().unwrap().push_back(Command::Foo));
// Some(receiver)
// }
// }
// }
//
// pub fn end_command(&mut self, key: &str) {
// self.0.remove(key);
// }
// }
//
// mod cmd {
// // use std::borrow::Cow;
//
// #[derive(Debug, Clone)]
// pub enum Command {
// Foo,
// }
//
// #[derive(Debug)]
// pub struct UnknownCommand<'a>(&'a str);
//
// impl<'a> TryFrom<&'a str> for Command {
// type Error = UnknownCommand<'a>;
//
// fn try_from(cmd: &'a str) -> Result<Self, Self::Error> {
// match cmd {
// "foo" => Ok(Command::Foo),
// _ => Err(UnknownCommand(cmd)),
// }
// }
// }
// }

View file

@ -1,166 +0,0 @@
use std::{io, task::ready};
use rustix::{
fd::OwnedFd,
fs::{fcntl_getfl, fcntl_setfl, OFlags},
process::{ioctl_tiocsctty, setsid},
pty::{grantpt, ioctl_tiocgptpeer, openpt, unlockpt, OpenptFlags},
stdio::{dup2_stderr, dup2_stdin, dup2_stdout},
termios::{tcsetwinsize, Winsize},
};
use tokio::{
io::{unix::AsyncFd, AsyncRead, AsyncWrite},
process::Child,
};
#[derive(Debug)]
pub struct Pty {
fd: AsyncFd<OwnedFd>,
}
impl Pty {
pub fn open() -> io::Result<Self> {
let master = openpt(OpenptFlags::RDWR | OpenptFlags::NOCTTY | OpenptFlags::CLOEXEC)?;
grantpt(&master)?;
unlockpt(&master)?;
// Set nonblocking
let flags = fcntl_getfl(&master)?;
fcntl_setfl(&master, flags | OFlags::NONBLOCK)?;
let fd = AsyncFd::new(master)?;
Ok(Self { fd })
}
pub fn child(&self) -> io::Result<PtyChild> {
// NOTE: Linux v4.13 and above
let fd = ioctl_tiocgptpeer(&self.fd, OpenptFlags::RDWR | OpenptFlags::NOCTTY)?;
let child = PtyChild { fd };
Ok(child)
}
pub fn resize_window(&self, rows: u16, cols: u16) -> io::Result<()> {
let winsize = Winsize {
ws_row: rows,
ws_col: cols,
ws_xpixel: 0,
ws_ypixel: 0,
};
tcsetwinsize(&self.fd, winsize)?;
Ok(())
}
pub fn try_clone(&self) -> io::Result<Pty> {
let fd = self.fd.get_ref().try_clone()?;
Ok(Pty {
fd: AsyncFd::new(fd)?,
})
}
}
impl AsyncRead for Pty {
fn poll_read(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<io::Result<()>> {
loop {
let mut guard = ready!(self.fd.poll_read_ready(cx)?);
match guard.try_io(|inner| {
let fd = inner.get_ref();
let n = rustix::io::read(fd, buf.initialize_unfilled())?;
buf.advance(n);
Ok(())
}) {
Ok(result) => return std::task::Poll::Ready(result),
Err(_would_block) => continue,
}
}
}
}
impl AsyncWrite for Pty {
fn poll_write(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<Result<usize, io::Error>> {
loop {
let mut guard = ready!(self.fd.poll_write_ready(cx))?;
match guard.try_io(|inner| Ok(rustix::io::write(inner.get_ref(), buf)?)) {
Ok(result) => return std::task::Poll::Ready(result),
Err(_would_block) => continue,
}
}
}
fn poll_flush(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
std::task::Poll::Ready(Ok(()))
}
fn poll_shutdown(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), io::Error>> {
std::task::Poll::Ready(Ok(()))
}
}
#[derive(Debug)]
pub struct PtyChild {
fd: OwnedFd,
}
impl PtyChild {
pub fn login_tty(&self) -> io::Result<()> {
setsid()?;
ioctl_tiocsctty(&self.fd)?;
dup2_stdin(&self.fd)?;
dup2_stdout(&self.fd)?;
dup2_stderr(&self.fd)?;
Ok(())
}
}
pub fn open_shell(pty_child: PtyChild) -> io::Result<Child> {
let mut cmd = tokio::process::Command::new("/bin/bash");
unsafe {
cmd.pre_exec(move || {
pty_child.login_tty()?;
Ok(())
});
}
cmd.spawn()
}
#[cfg(test)]
mod test {
use rustix::fd::AsRawFd;
use super::*;
#[tokio::test]
async fn can_open_pty() {
let pty = Pty::open().unwrap();
let child = pty.child().unwrap();
let master_fd = pty.fd.get_ref().as_raw_fd();
let child_fd = child.fd.as_raw_fd();
assert!(master_fd != child_fd);
}
}

View file

@ -1,143 +0,0 @@
use std::{fs::File, io::Write, os::unix::prelude::PermissionsExt, path::Path, process::Command};
use anyhow::Context;
use reqwest::{blocking::Client, StatusCode};
use serde::Deserialize;
use crate::config;
const PRYMN_PATH: &str = "/usr/local/bin/prymn_agent";
pub fn install(token: &str) -> anyhow::Result<()> {
let this_exe = std::env::current_exe()?;
copy_binary(&this_exe, Path::new(PRYMN_PATH)).with_context(|| {
format!(
"could not copy the file {} to the destination {PRYMN_PATH}",
this_exe.to_str().unwrap(),
)
})?;
install_service_file(Path::new("/etc/systemd/system/prymn.service"))
.context("could not install the agent daemon service")?;
register_to_backend(token).context("could not register the agent")?;
Ok(())
}
fn copy_binary(src: &Path, dest: &Path) -> anyhow::Result<()> {
if dest.exists() {
// unlink the potentially running binary
std::fs::remove_file(dest)?;
}
std::fs::copy(src, dest)?;
let mut perms = dest.metadata()?.permissions();
perms.set_mode(0o755);
std::fs::set_permissions(dest, perms)?;
Ok(())
}
fn install_service_file(dest: &Path) -> anyhow::Result<()> {
let mut file = File::create(dest)?;
write!(
file,
r#"
[Unit]
Description=Prymn Agent Service
After=network.target
[Service]
ExecStart={PRYMN_PATH}
Type=simple
Restart=always
[Install]
WantedBy=default.target
"#
)?;
if !Command::new("systemctl")
.arg("daemon-reload")
.status()?
.success()
{
anyhow::bail!("command exit with non-zero exit code; could not reload systemd daemon");
}
if !Command::new("systemctl")
.arg("enable")
.arg("--now")
.arg("prymn.service")
.status()?
.success()
{
anyhow::bail!("command exit with non-zero exit code; could not enable systemd service");
}
Ok(())
}
fn register_to_backend(token: &str) -> anyhow::Result<()> {
let client = Client::new();
let response = client
.post(format!(
"{}/api/v1/servers/register",
config::CONFIG.backend_url
))
.json(&serde_json::json!({ "token": token }))
.send()?;
// TODO: When the backend API is established more concretely, change this to something better.
#[derive(Deserialize)]
struct ApiError {
errors: serde_json::Value,
}
match response.status() {
StatusCode::UNPROCESSABLE_ENTITY => {
let error = response.json::<ApiError>()?;
anyhow::bail!(
"request was unsuccessful: the backend received invalid data: {}",
error.errors.to_string()
)
}
status if !status.is_success() => {
anyhow::bail!("request was unsuccessful: error {}", status)
}
_ => Ok(()),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn copy_binary_works() {
let temp_dir = std::env::temp_dir();
// let temp_dir = tempdir().unwrap();
let file1_path = temp_dir.join("file1");
let mut file1 = File::create(&file1_path).unwrap();
let file2_path = temp_dir.join("file2");
let mut file2 = File::create(&file2_path).unwrap();
writeln!(file1, "old data").unwrap();
writeln!(file2, "new data").unwrap();
copy_binary(&file2_path, &file1_path).expect("could not copy file");
let perms = file1_path
.metadata()
.expect("could not retrieve metadata")
.permissions();
let new_data = std::fs::read_to_string(file1_path).unwrap();
assert_eq!(new_data, "new data\n");
assert!(perms.mode() & 0o755 == 0o755);
}
}

View file

@ -1,224 +0,0 @@
use std::{pin::Pin, process::Stdio, sync::Mutex};
use tokio::{io::AsyncWriteExt, process::Command};
use tokio_stream::{
wrappers::{ReceiverStream, WatchStream},
Stream, StreamExt,
};
use tokio_util::codec::{BytesCodec, FramedRead};
use tonic::{Request, Response, Status, Streaming};
use crate::{
debian,
health::HealthMonitor,
info::Info,
pty::{open_shell, Pty},
task::TaskBuilder,
};
use super::proto::*;
type AgentResult<T> = std::result::Result<Response<T>, Status>;
pub struct AgentService<'a> {
pub health: HealthMonitor,
pub info: &'a Mutex<Info>, // TODO: Find a way to remove the Mutex dependency here
}
#[tonic::async_trait]
impl agent_server::Agent for AgentService<'static> {
type HealthStream = Pin<Box<dyn Stream<Item = Result<HealthResponse, Status>> + Send>>;
async fn health(&self, _: Request<()>) -> AgentResult<Self::HealthStream> {
let receiver = self.health.monitor();
let version = env!("CARGO_PKG_VERSION");
let output = WatchStream::new(receiver).map(|health| {
Ok(HealthResponse {
version: version.to_owned(),
system: Some(health.system().into()),
tasks: health
.tasks()
.iter()
.map(|(k, v)| (k.clone(), v.into()))
.collect(),
})
});
Ok(Response::new(Box::pin(output)))
}
async fn get_sys_info(&self, _: Request<()>) -> AgentResult<SysInfoResponse> {
Ok(Response::new(SysInfoResponse::from(
&*self.info.lock().unwrap(),
)))
}
type SysUpdateStream = Pin<Box<dyn Stream<Item = Result<SysUpdateResponse, Status>> + Send>>;
async fn sys_update(
&self,
req: Request<SysUpdateRequest>,
) -> AgentResult<Self::SysUpdateStream> {
let dry_run = req.get_ref().dry_run;
let mut receiver =
TaskBuilder::new("system update".to_owned()).health_monitor(self.health.clone());
if dry_run {
receiver = receiver
.add_step(async { Ok("simulating a system update...".to_owned()) })
.add_step(async {
const DUR: std::time::Duration = std::time::Duration::from_secs(5);
tokio::time::sleep(DUR).await;
Ok("completed running an artifical delay...".to_owned())
});
}
let receiver = receiver
.add_step(async move {
tokio::task::spawn_blocking(move || {
let output = debian::run_updates(dry_run).map_err(|err| {
tracing::error!(%err, "failed to run updates");
err
})?;
let out = if !output.status.success() {
tracing::error!(?output, "child process exited unsuccessfuly");
match output.status.code() {
Some(exit_code) => Err(Status::internal(format!(
"operation exited with error (code {exit_code})"
))),
None => Err(Status::cancelled("operation was cancelled by signal")),
}
} else {
Ok(String::from_utf8_lossy(output.stdout.as_slice()).to_string())
};
// TODO: We could split the output by lines and emit those as "steps" so the
// upgrade process is more interactive
out
})
.await
.unwrap()
})
.build()
.into_background();
let stream = ReceiverStream::new(receiver).map(|output| {
output
.map(|output| SysUpdateResponse {
output,
progress: 1,
})
.map_err(|err| Status::internal(err.to_string()))
});
Ok(Response::new(Box::pin(stream)))
}
type ExecStream = Pin<Box<dyn Stream<Item = Result<ExecResponse, Status>> + Send>>;
async fn exec(&self, req: Request<ExecRequest>) -> AgentResult<Self::ExecStream> {
use exec_response::Out;
let ExecRequest {
user,
program,
args,
} = req.get_ref();
if user.is_empty() {
return Err(Status::invalid_argument("you must specify a user"));
}
if program.is_empty() {
return Err(Status::invalid_argument("you must specify a program"));
}
let mut command = if user != "root" {
let mut cmd = Command::new("sudo");
cmd.arg("-iu").arg(user).arg("--").arg(program);
cmd
} else {
Command::new(program)
};
let mut io = command
.args(args)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()?;
let stdout = FramedRead::new(io.stdout.take().unwrap(), BytesCodec::new()).map(|stdout| {
let stdout = stdout.unwrap();
Out::Stdout(String::from_utf8_lossy(&stdout[..]).to_string())
});
let stderr = FramedRead::new(io.stderr.take().unwrap(), BytesCodec::new()).map(|stderr| {
let stderr = stderr.unwrap();
Out::Stderr(String::from_utf8_lossy(&stderr[..]).to_string())
});
let exit = TaskBuilder::new(format!("exec {program}"))
.health_monitor(self.health.clone())
.add_step(async move { io.wait().await.unwrap() })
.build()
.into_stream();
let stream = stdout
.merge(stderr)
.chain(exit.map(|code| Out::ExitCode(code.code().unwrap_or_default())))
.map(|out| Ok(ExecResponse { out: Some(out) }));
Ok(Response::new(Box::pin(stream)))
}
type TerminalStream = Pin<Box<dyn Stream<Item = Result<TerminalResponse, Status>> + Send>>;
async fn terminal(
&self,
req: Request<Streaming<TerminalRequest>>,
) -> AgentResult<Self::TerminalStream> {
let mut in_stream = req.into_inner();
let mut pty = Pty::open()?;
let pty_clone = pty.try_clone()?;
let pty_child = pty.child()?;
let mut child = open_shell(pty_child)?;
tokio::spawn(async move {
// TODO: Handle errors inside here
while let Some(result) = in_stream.next().await {
match result {
Ok(req) => {
if let Some(resize) = req.resize {
pty.resize_window(resize.rows as u16, resize.cols as u16)
.unwrap();
}
pty.write_all(&req.input[..]).await.unwrap();
}
Err(err) => {
// Log and ignore the error...
tracing::warn!(%err, "received an incoming stream error");
}
}
}
// TODO: Maybe there's a more graceful way to stop the process?
child.kill().await.unwrap();
});
let out_stream = FramedRead::new(pty_clone, BytesCodec::new()).map(|inner| {
inner
.map(|b| TerminalResponse { output: b.to_vec() })
.map_err(|err| {
tracing::error!(%err, "read error on pseudoterminal");
Status::internal("terminal read error")
})
});
Ok(Response::new(Box::pin(out_stream)))
}
}

View file

@ -1,128 +0,0 @@
use std::time::Duration;
use tokio::{signal, sync::oneshot};
use tower_http::trace::TraceLayer;
use crate::{
health::HealthMonitor,
info,
server::{agent::AgentService, proto::agent_server},
};
mod agent;
mod proto {
tonic::include_proto!("prymn");
impl From<&crate::health::SystemHealth> for SystemHealth {
fn from(val: &crate::health::SystemHealth) -> Self {
if let crate::health::SystemStatus::Critical(ref reasons) = val.status {
SystemHealth {
status: itertools::join(reasons.iter().map(ToString::to_string), ","),
}
} else {
SystemHealth {
status: val.status.to_string(),
}
}
}
}
impl From<&crate::task::TaskStatus> for TaskHealth {
fn from(value: &crate::task::TaskStatus) -> Self {
Self {
started_on: value.started_on().to_string(),
progress: value.progress(),
}
}
}
impl From<&crate::info::Info> for SysInfoResponse {
fn from(info: &crate::info::Info) -> Self {
use sysinfo::{CpuExt, DiskExt, SystemExt};
let system = info.system();
let cpus = system
.cpus()
.iter()
.map(|cpu| sys_info_response::Cpu {
freq_mhz: cpu.frequency(),
usage: cpu.cpu_usage(),
})
.collect();
let disks = system
.disks()
.iter()
.map(|disk| sys_info_response::Disk {
name: disk.name().to_string_lossy().into_owned(),
total_bytes: disk.total_space(),
avail_bytes: disk.available_space(),
mount_point: disk.mount_point().to_string_lossy().into_owned(),
})
.collect();
Self {
uptime: system.uptime(),
hostname: system.host_name().unwrap_or_default(),
os: system.long_os_version().unwrap_or_default(),
mem_total_bytes: system.total_memory(),
mem_avail_bytes: system.available_memory(),
swap_total_bytes: system.total_swap(),
swap_free_bytes: system.free_swap(),
updates_available: info.updates().len() as u32,
cpus,
disks,
}
}
}
}
/// Run the server. This is the main entry point of the application.
#[tokio::main]
pub async fn run() -> anyhow::Result<()> {
let (shutdown_tx, shutdown_rx) = oneshot::channel();
// Listen for shutdown signals
tokio::spawn(async {
signal::ctrl_c()
.await
.expect("failed to listen to a ctrl-c signal");
let _ = shutdown_tx.send(());
});
let info = info::spawn_info_subsystem();
let health_monitor = HealthMonitor::new();
// Monitor system info forever
// TODO: Maybe we can move it inside the server response function?
// We could spawn a new loop whenever we need it, but the problem is when does it get
// destroyed?
{
let health_monitor = health_monitor.clone();
tokio::spawn(async move {
loop {
health_monitor.check_system_info(&info.lock().unwrap());
tokio::time::sleep(Duration::from_secs(1)).await;
}
});
}
let agent_service = agent_server::AgentServer::new(AgentService {
health: health_monitor.clone(),
info,
});
let addr = "[::]:50012".parse()?;
tracing::info!("listening on {}", addr);
tonic::transport::Server::builder()
.layer(TraceLayer::new_for_grpc())
.add_service(agent_service)
.serve_with_shutdown(addr, async {
let _ = shutdown_rx.await;
})
.await?;
Ok(())
}

View file

@ -1,154 +0,0 @@
//! A task is an atomic executing routine that the agent is running, potentially in the background.
//! The task is tracked by the system monitor.
// TODO: Take a look at futures::stream::FuturesOrdered
// It is used to store futures in an ordered fashion, and it also implements Stream
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use chrono::{DateTime, Utc};
use tokio::sync::{mpsc, watch};
use tokio_stream::{Stream, StreamExt};
use super::health::HealthMonitor;
#[derive(Clone, Default)]
pub struct TaskStatus {
started_on: DateTime<Utc>,
curr_step: usize,
max_steps: usize,
}
impl TaskStatus {
/// Returns the task progress as a percentage value
pub fn progress(&self) -> f32 {
100.0 * (self.curr_step as f32 / self.max_steps as f32)
}
/// Returns the datetime when this task began executing
pub fn started_on(&self) -> &DateTime<Utc> {
&self.started_on
}
fn next_step(&mut self) {
if self.curr_step < self.max_steps {
self.curr_step += 1;
}
}
}
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
pub struct TaskBuilder<Step> {
task: Task<Step>,
}
impl<T> TaskBuilder<T> {
pub fn new(name: String) -> Self {
let (sender, _) = watch::channel(TaskStatus::default());
Self {
task: Task {
name,
health_monitor: None,
status_channel: sender,
steps: Vec::new(),
},
}
}
/// Attaches a health monitor to notify the health system on progress made.
pub fn health_monitor(mut self, health_monitor: HealthMonitor) -> Self {
self.task.health_monitor = Some(health_monitor);
self
}
pub fn build(self) -> Task<T> {
self.task
}
}
impl<T: Send + 'static> TaskBuilder<BoxFuture<T>> {
pub fn add_step(mut self, step: impl Future<Output = T> + Send + 'static) -> Self {
self.task.add_step(step);
self
}
}
pub struct Task<T> {
name: String,
health_monitor: Option<HealthMonitor>,
status_channel: watch::Sender<TaskStatus>,
steps: Vec<T>,
}
impl<T: Send + 'static> Task<BoxFuture<T>> {
fn add_step(&mut self, step: impl Future<Output = T> + Send + 'static) {
self.steps.push(Box::pin(step))
}
/// Turn this Task into an object that implements [Stream].
///
/// The new stream will output each step's future output.
pub fn into_stream(self) -> TaskStream<T> {
if let Some(health) = &self.health_monitor {
health.track_task(self.name.clone(), self.status_channel.subscribe());
}
// Immediately notify the initial status (step 0)
self.status_channel.send_replace(TaskStatus {
started_on: Utc::now(),
curr_step: 0,
max_steps: self.steps.len(),
});
TaskStream { inner: self }
}
/// Run this task concurrently in the background.
///
/// Returns a [mpsc::Receiver<T>] which receives the returned values of each step's future
/// output.
pub fn into_background(self) -> mpsc::Receiver<T> {
let (sender, receiver) = mpsc::channel(10);
tokio::spawn(async move {
let mut stream = self.into_stream();
while let Some(value) = stream.next().await {
let _ = sender.send(value).await;
}
});
receiver
}
}
pub struct TaskStream<T> {
inner: Task<BoxFuture<T>>,
}
impl<T> Stream for TaskStream<T> {
type Item = T;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.inner.steps.get_mut(0) {
Some(fut) => match fut.as_mut().poll(cx) {
Poll::Ready(value) => {
self.inner.steps.remove(0);
self.inner
.status_channel
.send_modify(|task| task.next_step());
Poll::Ready(Some(value))
}
Poll::Pending => Poll::Pending,
},
None => Poll::Ready(None),
}
}
}

View file

@ -1,18 +0,0 @@
use prymn_agent::{health::HealthMonitor, task::TaskBuilder};
#[tokio::test]
async fn task_is_gone_from_health_monitor_when_complete() {
let health_monitor = HealthMonitor::new();
let health_recv = health_monitor.monitor();
let mut task_recv = TaskBuilder::new("test task".to_owned())
.health_monitor(health_monitor)
.add_step(async { "foo" })
.add_step(async { "bar" })
.build()
.into_background();
assert_eq!(task_recv.recv().await.unwrap(), "foo");
assert_eq!(task_recv.recv().await.unwrap(), "bar");
assert!(health_recv.borrow().tasks().is_empty());
}