diff --git a/Cargo.lock b/Cargo.lock index 8d3f5f9..eb64435 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -709,6 +709,7 @@ dependencies = [ "bytes", "chrono", "futures", + "once_cell", "rustix", "serde", "serde_json", diff --git a/agent/Cargo.toml b/agent/Cargo.toml index 96a6809..a4d9053 100644 --- a/agent/Cargo.toml +++ b/agent/Cargo.toml @@ -9,6 +9,7 @@ async-nats = "0.33.0" bytes = "1.5.0" chrono = { version = "0.4.33", default-features = false, features = ["now", "serde"] } futures = { version = "0.3.30", default-features = false, features = ["std"] } +once_cell = "1.19.0" rustix = { version = "0.38.30", features = ["termios", "stdio", "pty", "process"] } serde = { version = "1.0.195", features = ["derive"] } serde_json = "1.0.111" diff --git a/agent/src/health.rs b/agent/src/health.rs index 58296e9..321af35 100644 --- a/agent/src/health.rs +++ b/agent/src/health.rs @@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use serde::{Deserialize, Serialize}; use tokio::sync::watch; -use crate::messaging::{Client, Message}; +use crate::messaging::Message; const MEMORY_USAGE_CRITICAL_THRESHOLD: f64 = 90.0; const CPU_USAGE_CRITICAL_THRESHOLD: f32 = 90.0; @@ -171,7 +171,7 @@ impl Default for HealthMonitor { } } -pub async fn init_health_subsystem(client: Client) -> HealthMonitor { +pub async fn init_health_subsystem(client: crate::messaging::Client) -> HealthMonitor { let health_monitor = HealthMonitor::new(); let health_monitor_clone = health_monitor.clone(); let health_monitor_ret = health_monitor.clone(); diff --git a/agent/src/messaging.rs b/agent/src/messaging.rs index 3b71335..65f502e 100644 --- a/agent/src/messaging.rs +++ b/agent/src/messaging.rs @@ -1,4 +1,7 @@ -use std::fmt::{Debug, Display}; +use std::{ + fmt::{Debug, Display}, + sync::Arc, +}; use anyhow::{anyhow, Result}; use bytes::Bytes; @@ -6,11 +9,7 @@ use futures::Stream; use tokio::sync::mpsc; use tokio_stream::StreamExt; -#[derive(Clone)] -pub struct Client { - id: String, - nats: async_nats::Client, -} +const PREFIX: &'static str = "agents.v1"; #[derive(Debug)] pub enum Subject { @@ -46,18 +45,19 @@ impl TryFrom<&str> for Subject { pub struct Message { subject: Subject, payload: Bytes, - reply: Option, + // reply: Option, } impl Message { - fn from_transport(msg: async_nats::Message) -> Result { - let suffix = msg.subject.split_terminator('.').last().unwrap_or_default(); - let subject = suffix.try_into()?; + fn from_transport(subject: &str, payload: Bytes) -> Result { + let subject = subject.try_into()?; + + // msg.headers.unwrap().get("x-idempotent-key"); Ok(Message { subject, - payload: msg.payload, - reply: msg.reply, + payload, + // reply: None, }) } @@ -73,20 +73,16 @@ impl Message { Ok(Message { subject: Subject::Health, payload: Bytes::from_iter(serde_json::to_vec(&health)?), - reply: None, + // reply: None, }) } +} - async fn open_subchannels(&self, client: &Client, sender: mpsc::Sender) { - match self.subject { - Subject::OpenTerminal => { - // let subject = format!("agents.v1.{}.terminal.{}.input"); - // client.nats.subscribe(subject).await; - // sender.send(message).await.unwrap(); - } - _ => {} - } - } +#[derive(Clone)] +pub struct Client { + id: Arc, + nats: async_nats::Client, + subj_prefix: Arc, } impl Client { @@ -101,8 +97,9 @@ impl Client { tracing::debug!("connected to NATS"); Ok(Self { - id: id.to_owned(), nats, + id: Arc::new(id.to_owned()), + subj_prefix: Arc::new(format!("{}.{}", PREFIX, id)), }) } @@ -111,26 +108,29 @@ impl Client { self.nats.publish(subject, msg.payload).await.unwrap(); } - pub async fn reply(&self, msg: Message, mut stream: impl Stream + Unpin) { - match msg.reply { - Some(reply) => { - while let Some(data) = stream.next().await { - self.nats.publish(reply.clone(), data.into()).await.unwrap(); - } - } - None => tracing::warn!(?msg, "tried to reply to message without a reply subject"), - } - } + // pub async fn reply(&self, msg: Message, mut stream: impl Stream + Unpin) { + // match msg.reply { + // Some(reply) => { + // while let Some(data) = stream.next().await { + // self.nats.publish(reply.clone(), data.into()).await.unwrap(); + // } + // } + // None => tracing::warn!(?msg, "tried to reply to message without a reply subject"), + // } + // } pub async fn messages_channel(&self) -> Result> { - let subject = format!("agents.v1.{}.*", self.id); let (sender, receiver) = mpsc::channel(100); - let mut stream = self.subscribe(subject).await?; + let mut stream = self.subscribe("*").await?; let self_clone = self.clone(); tokio::spawn(async move { while let Some(msg) = stream.next().await { - self_clone.clone().open_subchannels(&msg).await; + self_clone + .clone() + .open_subchannels(&msg, sender.clone()) + .await; + sender.send(msg).await.unwrap(); } }); @@ -138,16 +138,22 @@ impl Client { Ok(receiver) } - async fn subscribe(&self, subject: String) -> Result> { + async fn subscribe(&self, subject: &str) -> Result> { + let prefix = Arc::clone(&self.subj_prefix); + let stream = self .nats - .subscribe(subject.clone()) + .subscribe(format!("{}.{}", self.subj_prefix, subject)) .await? - .filter_map(|data| match Message::from_transport(data) { - Ok(msg) => Some(msg), - Err(err) => { - tracing::warn!("{}", err); - None + .filter_map(move |data| { + let subject = data.subject[..].trim_start_matches(&*prefix); + + match Message::from_transport(subject, data.payload) { + Ok(msg) => Some(msg), + Err(err) => { + tracing::warn!("{}", err); + None + } } }); @@ -156,29 +162,29 @@ impl Client { Ok(stream) } - async fn open_subchannels(self, message: &Message, sender: mpsc::Sender) { + async fn open_subchannels(&self, message: &Message, sender: mpsc::Sender) { + fn send_messages_from_stream( + mut stream: impl Stream + Send + Unpin + 'static, + sender: mpsc::Sender, + ) { + tokio::spawn(async move { + while let Some(message) = stream.next().await { + sender.send(message).await.unwrap(); + } + }); + } + match message.subject { Subject::OpenTerminal => { - let terminal_id = "test"; - let stream = self - .subscribe(format!( - "agents.v1.{}.terminal.{}.input", - self.id, terminal_id - )) - .await - .unwrap(); + let input_stream = self.subscribe("terminal.key.input").await.unwrap(); + let resize_stream = self.subscribe("terminal.key.resize").await.unwrap(); + + send_messages_from_stream(input_stream, sender.clone()); + send_messages_from_stream(resize_stream, sender); } _ => {} } } - - fn send_messages_from_stream(self, stream: impl Stream + Send + Unpin) { - tokio::spawn(async { - // while let Some(msg) = stream.next().await { - // - // } - }); - } } impl Debug for Client { diff --git a/agent/src/services/mod.rs b/agent/src/services/mod.rs index 663e4da..d6df248 100644 --- a/agent/src/services/mod.rs +++ b/agent/src/services/mod.rs @@ -1,7 +1,7 @@ -use anyhow::Context; +use std::collections::HashMap; + use bytes::Bytes; use thiserror::Error; -use tokio_stream::StreamExt; mod exec; mod terminal; @@ -12,6 +12,20 @@ enum ServiceError { BodyFormatError, } +struct Service { + terminals: HashMap, +} + +impl Service { + fn serve(&mut self, message: crate::messaging::Message) { + match message.subject() { + crate::messaging::Subject::OpenTerminal => {} + crate::messaging::Subject::Exec => todo!(), + crate::messaging::Subject::Health => todo!(), + } + } +} + async fn route_message(message: crate::messaging::Message) -> Result<(), ServiceError> { match message.subject() { crate::messaging::Subject::Health => {} @@ -30,7 +44,7 @@ async fn route_message(message: crate::messaging::Message) -> Result<(), Service struct Ctx { body: T, - // input_streams: + terminals: HashMap, } impl Ctx @@ -38,10 +52,13 @@ where T: TryFrom, { fn with_body(body: Bytes) -> Result { + let body = body + .try_into() + .map_err(|_err| ServiceError::BodyFormatError)?; + Ok(Self { - body: body - .try_into() - .map_err(|_err| ServiceError::BodyFormatError)?, + body, + terminals: HashMap::default(), }) } } diff --git a/agent/src/services/terminal.rs b/agent/src/services/terminal.rs index 2553b4e..6511b11 100644 --- a/agent/src/services/terminal.rs +++ b/agent/src/services/terminal.rs @@ -1,38 +1,44 @@ +use std::convert::Infallible; + use bytes::Bytes; -use futures::Stream; use serde::Deserialize; use tokio::io::AsyncWriteExt; use tokio_stream::StreamExt; use tokio_util::codec::{BytesCodec, FramedRead}; +use crate::pty::open_shell; + use super::Ctx; #[derive(Debug, Deserialize)] -pub struct OpenTerminalMessage { - id: String, -} +pub struct OpenTerminalMessage(Bytes); impl TryFrom for OpenTerminalMessage { - type Error = serde_json::Error; + type Error = Infallible; fn try_from(value: Bytes) -> Result { - serde_json::from_slice(&value[..]) + Ok(Self(value)) } } -pub async fn open_terminal(ctx: Ctx) -> anyhow::Result<()> { +#[derive(Debug, Deserialize)] +pub struct TerminalInput(Bytes); + +impl TryFrom for TerminalInput { + type Error = Infallible; + + fn try_from(value: Bytes) -> Result { + Ok(Self(value)) + } +} + +pub async fn open_terminal( + mut ctx: Ctx, +) -> anyhow::Result> { let pty = crate::pty::Pty::open()?; - let mut pty_clone = pty.try_clone()?; + let shell = open_shell(pty.child()?, "/bin/bash")?; - tokio::spawn(async move { - while let Some(data) = tokio_stream::once(b"foo").next().await { - if let Err(err) = pty_clone.write_all(&data[..]).await { - tracing::warn!(%err, "pseudoterminal write error"); - } - } - }); - - let _out_stream = FramedRead::new(pty, BytesCodec::new()).filter_map(|inner| { + let _out_stream = FramedRead::new(pty.try_clone()?, BytesCodec::new()).filter_map(|inner| { inner .map(|bytes| bytes.freeze()) .map_err(|err| { @@ -41,5 +47,20 @@ pub async fn open_terminal(ctx: Ctx) -> anyhow::Result<()> .ok() }); - Ok(()) + ctx.terminals.insert(String::from("test"), (pty, shell)); + + Ok(ctx) +} + +pub async fn terminal_input( + terminal_id: &str, + mut ctx: Ctx, +) -> anyhow::Result> { + let (pty, _) = ctx.terminals.get_mut(terminal_id).unwrap(); + + if let Err(err) = pty.write_all(&ctx.body.0[..]).await { + tracing::warn!(%err, "pseudoterminal write error"); + } + + Ok(ctx) }