rewrite subscription complete
This commit is contained in:
@ -31,3 +31,4 @@ regex = "1.9.6"
|
||||
oo7 = "0.2.1"
|
||||
async-trait = "0.1.83"
|
||||
http = "1.1.0"
|
||||
async-channel = "2.3.1"
|
||||
@ -1,6 +1,7 @@
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
@ -135,14 +136,14 @@ pub struct Credential {
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Credentials {
|
||||
keyring: Rc<dyn LightKeyring>,
|
||||
creds: Rc<RefCell<HashMap<String, Credential>>>,
|
||||
keyring: Arc<dyn LightKeyring + Send + Sync>,
|
||||
creds: Arc<RwLock<HashMap<String, Credential>>>,
|
||||
}
|
||||
|
||||
impl Credentials {
|
||||
pub async fn new() -> anyhow::Result<Self> {
|
||||
let mut this = Self {
|
||||
keyring: Rc::new(RealKeyring {
|
||||
keyring: Arc::new(RealKeyring {
|
||||
keyring: oo7::Keyring::new()
|
||||
.await
|
||||
.expect("Failed to start Secret Service"),
|
||||
@ -154,7 +155,7 @@ impl Credentials {
|
||||
}
|
||||
pub async fn new_nullable(credentials: Vec<Credential>) -> anyhow::Result<Self> {
|
||||
let mut this = Self {
|
||||
keyring: Rc::new(NullableKeyring::with_credentials(credentials)),
|
||||
keyring: Arc::new(NullableKeyring::with_credentials(credentials)),
|
||||
creds: Default::default(),
|
||||
};
|
||||
this.load().await?;
|
||||
@ -168,12 +169,13 @@ impl Credentials {
|
||||
.await
|
||||
.map_err(|e| capnp::Error::failed(e.to_string()))?;
|
||||
|
||||
self.creds.borrow_mut().clear();
|
||||
let mut lock = self.creds.write().unwrap();
|
||||
lock.clear();
|
||||
for item in values {
|
||||
let attrs = item
|
||||
.attributes()
|
||||
.await;
|
||||
self.creds.borrow_mut().insert(
|
||||
lock.insert(
|
||||
attrs["server"].to_string(),
|
||||
Credential {
|
||||
username: attrs["username"].to_string(),
|
||||
@ -184,14 +186,14 @@ impl Credentials {
|
||||
Ok(())
|
||||
}
|
||||
pub fn get(&self, server: &str) -> Option<Credential> {
|
||||
self.creds.borrow().get(server).cloned()
|
||||
self.creds.read().unwrap().get(server).cloned()
|
||||
}
|
||||
pub fn list_all(&self) -> HashMap<String, Credential> {
|
||||
self.creds.borrow().clone()
|
||||
self.creds.read().unwrap().clone()
|
||||
}
|
||||
pub async fn insert(&self, server: &str, username: &str, password: &str) -> anyhow::Result<()> {
|
||||
{
|
||||
if let Some(cred) = self.creds.borrow().get(server) {
|
||||
if let Some(cred) = self.creds.read().unwrap().get(server) {
|
||||
if cred.username != username {
|
||||
anyhow::bail!("You can add only one account per server");
|
||||
}
|
||||
@ -207,7 +209,7 @@ impl Credentials {
|
||||
.await
|
||||
.map_err(|e| capnp::Error::failed(e.to_string()))?;
|
||||
|
||||
self.creds.borrow_mut().insert(
|
||||
self.creds.write().unwrap().insert(
|
||||
server.to_string(),
|
||||
Credential {
|
||||
username: username.to_string(),
|
||||
@ -219,7 +221,8 @@ impl Credentials {
|
||||
pub async fn delete(&self, server: &str) -> anyhow::Result<()> {
|
||||
let creds = {
|
||||
self.creds
|
||||
.borrow()
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(server)
|
||||
.ok_or(anyhow::anyhow!("server creds not found"))?
|
||||
.clone()
|
||||
@ -234,7 +237,8 @@ impl Credentials {
|
||||
.await
|
||||
.map_err(|e| capnp::Error::failed(e.to_string()))?;
|
||||
self.creds
|
||||
.borrow_mut()
|
||||
.write()
|
||||
.unwrap()
|
||||
.remove(server)
|
||||
.ok_or(anyhow::anyhow!("server creds not found"))?;
|
||||
Ok(())
|
||||
|
||||
@ -3,7 +3,7 @@ use async_trait::async_trait;
|
||||
use reqwest::{header::HeaderMap, Client, Request, RequestBuilder, Response, ResponseBuilderExt};
|
||||
use serde_json::{json, Value};
|
||||
use tokio::time;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::RwLock;
|
||||
@ -86,26 +86,90 @@ impl HttpClient {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct NullableClient {
|
||||
responses: Arc<RwLock<HashMap<String, Response>>>,
|
||||
responses: Arc<RwLock<HashMap<String, VecDeque<Response>>>>,
|
||||
default_response: Arc<RwLock<Option<Box<dyn Fn() -> Response + Send + Sync + 'static>>>>,
|
||||
}
|
||||
|
||||
impl NullableClient {
|
||||
/// Builder for configuring NullableClient
|
||||
#[derive(Default)]
|
||||
pub struct NullableClientBuilder {
|
||||
responses: HashMap<String, VecDeque<Response>>,
|
||||
default_response: Option<Box<dyn Fn() -> Response + Send + Sync + 'static>>,
|
||||
}
|
||||
|
||||
impl NullableClientBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub async fn set_response(&self, url: &str, response: Response) {
|
||||
/// Add a single response for a specific URL
|
||||
pub fn response(mut self, url: impl Into<String>, response: Response) -> Self {
|
||||
self.responses
|
||||
.write()
|
||||
.await
|
||||
.insert(url.to_string(), response);
|
||||
.entry(url.into())
|
||||
.or_default()
|
||||
.push_back(response);
|
||||
self
|
||||
}
|
||||
|
||||
pub async fn set_default_response(&self, res: Box<dyn Fn() -> Response + Send + Sync + 'static>) {
|
||||
*self.default_response.write().await = Some(res);
|
||||
/// Add multiple responses for a specific URL that will be returned in sequence
|
||||
pub fn responses(mut self, url: impl Into<String>, responses: Vec<Response>) -> Self {
|
||||
self.responses.insert(url.into(), responses.into());
|
||||
self
|
||||
}
|
||||
|
||||
/// Set a default response generator for any unmatched URLs
|
||||
pub fn default_response(
|
||||
mut self,
|
||||
response: impl Fn() -> Response + Send + Sync + 'static,
|
||||
) -> Self {
|
||||
self.default_response = Some(Box::new(response));
|
||||
self
|
||||
}
|
||||
|
||||
/// Helper method to quickly add a JSON response
|
||||
pub fn json_response(
|
||||
self,
|
||||
url: impl Into<String>,
|
||||
status: u16,
|
||||
body: impl serde::Serialize,
|
||||
) -> Result<Self> {
|
||||
let response = http::response::Builder::new()
|
||||
.status(status)
|
||||
.body(serde_json::to_string(&body)?)
|
||||
.unwrap()
|
||||
.into();
|
||||
Ok(self.response(url, response))
|
||||
}
|
||||
|
||||
/// Helper method to quickly add a text response
|
||||
pub fn text_response(
|
||||
self,
|
||||
url: impl Into<String>,
|
||||
status: u16,
|
||||
body: impl Into<String>,
|
||||
) -> Self {
|
||||
let response = http::response::Builder::new()
|
||||
.status(status)
|
||||
.body(body.into())
|
||||
.unwrap()
|
||||
.into();
|
||||
self.response(url, response)
|
||||
}
|
||||
|
||||
pub fn build(self) -> NullableClient {
|
||||
NullableClient {
|
||||
responses: Arc::new(RwLock::new(self.responses.into_iter().map(|(k, v)| (k, v.into())).collect())),
|
||||
default_response: Arc::new(RwLock::new(self.default_response)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl NullableClient {
|
||||
pub fn builder() -> NullableClientBuilder {
|
||||
NullableClientBuilder::new()
|
||||
}
|
||||
}
|
||||
|
||||
@ -116,15 +180,28 @@ impl LightHttpClient for NullableClient {
|
||||
}
|
||||
|
||||
async fn execute(&self, request: Request) -> Result<Response> {
|
||||
time::sleep(Duration::from_millis(1)).await; // else we spam the thread with responses
|
||||
// Get the configured response or return a default one
|
||||
time::sleep(Duration::from_millis(1)).await;
|
||||
let url = request.url().to_string();
|
||||
if let Some(response) = self.responses.write().await.remove(&url) {
|
||||
Ok(response)
|
||||
} else if let Some(res) = &*self.default_response.read().await {
|
||||
Ok(res())
|
||||
let mut responses = self.responses.write().await;
|
||||
|
||||
if let Some(url_responses) = responses.get_mut(&url) {
|
||||
if let Some(response) = url_responses.pop_front() {
|
||||
// Remove the URL entry if no more responses
|
||||
if url_responses.is_empty() {
|
||||
responses.remove(&url);
|
||||
}
|
||||
Ok(response)
|
||||
} else {
|
||||
if let Some(default_fn) = &*self.default_response.read().await {
|
||||
Ok(default_fn())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("no response configured for URL: {}", url))
|
||||
}
|
||||
}
|
||||
} else if let Some(default_fn) = &*self.default_response.read().await {
|
||||
Ok(default_fn())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("no response"))
|
||||
Err(anyhow::anyhow!("no response configured for URL: {}", url))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -132,80 +209,93 @@ impl LightHttpClient for NullableClient {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_nullable() -> Result<()> {
|
||||
let client = NullableClient::new();
|
||||
async fn test_nullable_with_builder() -> Result<()> {
|
||||
// Configure client using builder pattern
|
||||
let client = NullableClient::builder()
|
||||
.text_response("https://api.example.com/topic", 200, "ok")
|
||||
.json_response(
|
||||
"https://api.example.com/json",
|
||||
200,
|
||||
json!({ "status": "success" }),
|
||||
)?
|
||||
.default_response(|| {
|
||||
http::response::Builder::new()
|
||||
.status(404)
|
||||
.body("not found")
|
||||
.unwrap()
|
||||
.into()
|
||||
})
|
||||
.build();
|
||||
|
||||
// Configure mock response
|
||||
let mock_response = http::response::Builder::new()
|
||||
.status(200)
|
||||
.body("ok")
|
||||
.unwrap()
|
||||
.into();
|
||||
client
|
||||
.set_response("https://api.example.com/topic", mock_response)
|
||||
.await;
|
||||
|
||||
let client = HttpClient::new_nullable(client);
|
||||
let request_tracker = client.request_tracker().await;
|
||||
|
||||
let req = client
|
||||
.get("https://api.example.com/topic")
|
||||
.header("Content-Type", "application/x-ndjson")
|
||||
.header("Transfer-Encoding", "chunked")
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Execute request
|
||||
let response = client.execute(req).await?;
|
||||
let http_client = HttpClient::new_nullable(client);
|
||||
let request_tracker = http_client.request_tracker().await;
|
||||
|
||||
// Test successful text response
|
||||
let request = http_client.get("https://api.example.com/topic").build()?;
|
||||
let response = http_client.execute(request).await?;
|
||||
assert_eq!(response.status(), 200);
|
||||
assert_eq!(response.bytes().await.unwrap(), b"ok"[..]);
|
||||
assert_eq!(response.text().await?, "ok");
|
||||
|
||||
// Test successful JSON response
|
||||
let request = http_client.get("https://api.example.com/json").build()?;
|
||||
let response = http_client.execute(request).await?;
|
||||
assert_eq!(response.status(), 200);
|
||||
assert_eq!(response.text().await?, r#"{"status":"success"}"#);
|
||||
|
||||
// Test default response
|
||||
let request = http_client.get("https://api.example.com/unknown").build()?;
|
||||
let response = http_client.execute(request).await?;
|
||||
assert_eq!(response.status(), 404);
|
||||
assert_eq!(response.text().await?, "not found");
|
||||
|
||||
// Verify recorded requests
|
||||
let requests = request_tracker.items().await;
|
||||
assert_eq!(requests.len(), 1);
|
||||
|
||||
let request = &requests[0];
|
||||
assert_eq!(request.method, "GET");
|
||||
assert_eq!(
|
||||
request.headers.get("Content-Type").unwrap(),
|
||||
"application/x-ndjson"
|
||||
);
|
||||
assert_eq!(request.headers.get("Transfer-Encoding").unwrap(), "chunked");
|
||||
assert_eq!(requests.len(), 3);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_nullable_with_failing_response() -> Result<()> {
|
||||
let client = NullableClient::new();
|
||||
async fn test_sequence_of_responses() -> Result<()> {
|
||||
// Configure client with multiple responses for the same URL
|
||||
let client = NullableClient::builder()
|
||||
.responses(
|
||||
"https://api.example.com/sequence",
|
||||
vec![
|
||||
http::response::Builder::new()
|
||||
.status(200)
|
||||
.body("first")
|
||||
.unwrap()
|
||||
.into(),
|
||||
http::response::Builder::new()
|
||||
.status(200)
|
||||
.body("second")
|
||||
.unwrap()
|
||||
.into(),
|
||||
],
|
||||
)
|
||||
.build();
|
||||
|
||||
// Configure mock response
|
||||
let mock_response = http::response::Builder::new()
|
||||
.status(400)
|
||||
.body("fail")
|
||||
.unwrap()
|
||||
.into();
|
||||
client
|
||||
.set_response("https://api.example.com/topic", mock_response)
|
||||
.await;
|
||||
let http_client = HttpClient::new_nullable(client);
|
||||
|
||||
let req = client
|
||||
.get("https://api.example.com/topic")
|
||||
.header("Content-Type", "application/x-ndjson")
|
||||
.header("Transfer-Encoding", "chunked")
|
||||
.build()
|
||||
.unwrap();
|
||||
// First request gets first response
|
||||
let request = http_client.get("https://api.example.com/sequence").build()?;
|
||||
let response = http_client.execute(request).await?;
|
||||
assert_eq!(response.text().await?, "first");
|
||||
|
||||
// Execute request
|
||||
let response = client.execute(req).await?;
|
||||
let response: Result<_, _> = response.error_for_status();
|
||||
// Second request gets second response
|
||||
let request = http_client.get("https://api.example.com/sequence").build()?;
|
||||
let response = http_client.execute(request).await?;
|
||||
assert_eq!(response.text().await?, "second");
|
||||
|
||||
dbg!(&response);
|
||||
assert!(matches!(response, Err(_)));
|
||||
// Third request fails (no more responses)
|
||||
let request = http_client.get("https://api.example.com/sequence").build()?;
|
||||
let result = http_client.execute(request).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -8,8 +8,12 @@ mod http_client;
|
||||
mod output_tracker;
|
||||
mod listener;
|
||||
mod ntfy;
|
||||
mod subscription;
|
||||
|
||||
pub use ntfy::Ntfy;
|
||||
pub use subscription::SubscriptionHandle;
|
||||
pub use listener::*;
|
||||
pub use ntfy::NtfyHandle;
|
||||
pub use ntfy::start;
|
||||
|
||||
pub mod ntfy_capnp {
|
||||
include!(concat!(env!("OUT_DIR"), "/src/ntfy_capnp.rs"));
|
||||
|
||||
@ -1,15 +1,17 @@
|
||||
use std::cell::RefCell;
|
||||
use std::sync::Arc;
|
||||
use std::thread::JoinHandle;
|
||||
use std::{rc::Rc, time::Duration};
|
||||
use std::{time::Duration};
|
||||
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::spawn;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::task::{self, spawn_local, AbortHandle, LocalSet};
|
||||
use tokio::{
|
||||
select,
|
||||
sync::{broadcast, mpsc, watch},
|
||||
sync::{mpsc, watch, oneshot},
|
||||
};
|
||||
use tokio_stream::wrappers::LinesStream;
|
||||
use tracing::{debug, error, info};
|
||||
@ -23,7 +25,7 @@ use tokio::time::timeout;
|
||||
const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);
|
||||
const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(240); // 4 minutes
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "event")]
|
||||
pub enum ServerEvent {
|
||||
#[serde(rename = "open")]
|
||||
@ -34,12 +36,9 @@ pub enum ServerEvent {
|
||||
topic: String,
|
||||
},
|
||||
#[serde(rename = "message")]
|
||||
Message {
|
||||
id: String,
|
||||
expires: Option<usize>,
|
||||
#[serde(flatten)]
|
||||
message: models::Message,
|
||||
},
|
||||
Message (
|
||||
models::Message,
|
||||
),
|
||||
#[serde(rename = "keepalive")]
|
||||
KeepAlive {
|
||||
id: String,
|
||||
@ -64,10 +63,11 @@ pub struct ListenerConfig {
|
||||
pub(crate) since: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
#[derive(Debug)]
|
||||
pub enum ListenerCommand {
|
||||
Restart,
|
||||
Shutdown,
|
||||
GetState(oneshot::Sender<ConnectionState>),
|
||||
}
|
||||
|
||||
fn topic_request(
|
||||
@ -108,63 +108,82 @@ pub enum ConnectionState {
|
||||
},
|
||||
}
|
||||
|
||||
pub struct ConnectionHandler {
|
||||
pub event_tx: watch::Sender<ListenerEvent>,
|
||||
pub commands_rx: Option<broadcast::Receiver<ListenerCommand>>,
|
||||
pub struct ListenerActor {
|
||||
pub event_tx: async_channel::Sender<ListenerEvent>,
|
||||
pub commands_rx: Option<mpsc::Receiver<ListenerCommand>>,
|
||||
pub config: ListenerConfig,
|
||||
pub state: Rc<RefCell<ConnectionState>>,
|
||||
pub state: ConnectionState,
|
||||
}
|
||||
|
||||
impl ConnectionHandler {
|
||||
fn new(
|
||||
impl ListenerActor {
|
||||
pub fn new(
|
||||
config: ListenerConfig,
|
||||
event_tx: watch::Sender<ListenerEvent>,
|
||||
commands_rx: broadcast::Receiver<ListenerCommand>,
|
||||
) -> Self {
|
||||
let this = Self {
|
||||
event_tx,
|
||||
commands_rx: Some(commands_rx),
|
||||
) -> ListenerHandle {
|
||||
let (event_tx, event_rx) = async_channel::bounded(64);
|
||||
let (commands_tx, commands_rx) = mpsc::channel(1);
|
||||
|
||||
let config_clone = config.clone();
|
||||
|
||||
// use a new local set to isolate panics
|
||||
let local_set = LocalSet::new();
|
||||
local_set.spawn_local(async move {
|
||||
|
||||
let this = Self {
|
||||
event_tx,
|
||||
commands_rx: Some(commands_rx),
|
||||
config: config_clone,
|
||||
state: ConnectionState::Unitialized,
|
||||
};
|
||||
|
||||
this.run_loop().await;
|
||||
});
|
||||
spawn_local(local_set);
|
||||
|
||||
ListenerHandle {
|
||||
events: event_rx,
|
||||
config,
|
||||
state: Rc::new(RefCell::new(ConnectionState::Unitialized)),
|
||||
};
|
||||
this
|
||||
commands: commands_tx,
|
||||
listener_actor: Arc::new(RwLock::new(None)),
|
||||
join_handle: Arc::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run(mut self) -> task::JoinHandle<()> {
|
||||
spawn_local(async move {
|
||||
pub async fn run_loop(mut self) {
|
||||
let mut commands_rx = self.commands_rx.take().unwrap();
|
||||
loop {
|
||||
select! {
|
||||
_ = self.run_supervised_loop() => {
|
||||
// the supervised loop cannot fail. If it finished, don't restart.
|
||||
break;
|
||||
},
|
||||
cmd = commands_rx.recv() => {
|
||||
match cmd {
|
||||
Ok(ListenerCommand::Restart) => {
|
||||
info!("Received restart command");
|
||||
continue;
|
||||
}
|
||||
Ok(ListenerCommand::Shutdown) => {
|
||||
info!("Received shutdown command");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Command receive error: {:?}", e);
|
||||
break;
|
||||
}
|
||||
_ = self.run_supervised_loop() => {
|
||||
// the supervised loop cannot fail. If it finished, don't restart.
|
||||
break;
|
||||
},
|
||||
cmd = commands_rx.recv() => {
|
||||
match cmd {
|
||||
Some(ListenerCommand::Restart) => {
|
||||
info!("Received restart command");
|
||||
continue;
|
||||
}
|
||||
Some(ListenerCommand::Shutdown) => {
|
||||
info!("Received shutdown command");
|
||||
break;
|
||||
}
|
||||
Some(ListenerCommand::GetState(tx)) => {
|
||||
info!("Received get state command");
|
||||
let state = self.state.clone();
|
||||
let _ = tx.send(state);
|
||||
}
|
||||
None => {
|
||||
error!("Channel closed for ListenerActor");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn set_state(&mut self, state: ConnectionState) {
|
||||
self.state.replace(state.clone());
|
||||
self.event_tx
|
||||
.send(ListenerEvent::ConnectionStateChanged(state)).unwrap();
|
||||
async fn set_state(&mut self, state: ConnectionState) {
|
||||
self.state = state.clone();
|
||||
self.event_tx.send(ListenerEvent::ConnectionStateChanged(state)).await.unwrap();
|
||||
}
|
||||
async fn run_supervised_loop(&mut self) {
|
||||
dbg!("supervised");
|
||||
@ -189,7 +208,7 @@ impl ConnectionHandler {
|
||||
retry_count: retry.count(),
|
||||
delay: retry.next_delay(),
|
||||
error: Some(Arc::new(e)),
|
||||
});
|
||||
}).await;
|
||||
info!(delay = ?retry.next_delay(), "restarting");
|
||||
retry.wait().await;
|
||||
} else {
|
||||
@ -219,12 +238,11 @@ impl ConnectionHandler {
|
||||
|
||||
self.set_state(
|
||||
ConnectionState::Connected,
|
||||
);
|
||||
).await;
|
||||
|
||||
info!(topic = %&self.config.topic, "listening");
|
||||
while let Some(msg) = stream.next().await {
|
||||
let msg = msg?;
|
||||
dbg!(&msg);
|
||||
|
||||
let min_msg = serde_json::from_str::<models::MinMessage>(&msg)
|
||||
.map_err(|e| Error::InvalidMinMessage(msg.to_string(), e))?;
|
||||
@ -234,9 +252,9 @@ impl ConnectionHandler {
|
||||
.map_err(|e| Error::InvalidMessage(msg.to_string(), e))?;
|
||||
|
||||
match event {
|
||||
ServerEvent::Message { message, .. } => {
|
||||
ServerEvent::Message(msg) => {
|
||||
debug!("message event");
|
||||
self.event_tx.send(ListenerEvent::Message(message))?;
|
||||
self.event_tx.send(ListenerEvent::Message(msg)).await.unwrap();
|
||||
}
|
||||
ServerEvent::KeepAlive { .. } => {
|
||||
debug!("keepalive event");
|
||||
@ -253,61 +271,27 @@ impl ConnectionHandler {
|
||||
|
||||
// Reliable listener implementation
|
||||
#[derive(Clone)]
|
||||
pub struct Listener {
|
||||
pub state: Rc<RefCell<ConnectionState>>,
|
||||
pub events: watch::Receiver<ListenerEvent>,
|
||||
pub struct ListenerHandle {
|
||||
pub events: async_channel::Receiver<ListenerEvent>,
|
||||
pub config: ListenerConfig,
|
||||
pub commands: broadcast::Sender<ListenerCommand>,
|
||||
pub event_tracker: OutputTracker<ListenerEvent>,
|
||||
local_set: Rc<LocalSet>,
|
||||
connection_handler: Rc<RefCell<Option<ConnectionHandler>>>,
|
||||
pub commands: mpsc::Sender<ListenerCommand>,
|
||||
join_handle: Arc<Option<task::JoinHandle<()>>>,
|
||||
listener_actor: Arc<RwLock<Option<ListenerActor>>>,
|
||||
}
|
||||
|
||||
impl Listener {
|
||||
pub fn new(config: ListenerConfig) -> Self {
|
||||
let (tx, rx) = watch::channel(ListenerEvent::ConnectionStateChanged(
|
||||
ConnectionState::Unitialized,
|
||||
));
|
||||
let (commands_tx, commands_rx) = broadcast::channel(1);
|
||||
|
||||
let local_set = Rc::new(LocalSet::new());
|
||||
let connection_handler = ConnectionHandler::new(config.clone(), tx, commands_rx);
|
||||
let state = connection_handler.state.clone();
|
||||
|
||||
let event_tracker = OutputTracker::default();
|
||||
// let event_tracker_clone = event_tracker.clone();
|
||||
// let mut rx_clone = rx.clone();
|
||||
// local_set.spawn_local(async move {
|
||||
// rx_clone.changed().await.unwrap();
|
||||
// event_tracker_clone.push(rx_clone.borrow().clone());
|
||||
// });
|
||||
|
||||
Listener {
|
||||
state,
|
||||
events: rx,
|
||||
config,
|
||||
commands: commands_tx,
|
||||
local_set,
|
||||
event_tracker,
|
||||
connection_handler: Rc::new(RefCell::new(Some(connection_handler))),
|
||||
}
|
||||
}
|
||||
pub async fn run(&mut self) {
|
||||
let connection_handler = self.connection_handler.take().unwrap();
|
||||
|
||||
let _ = self
|
||||
.local_set
|
||||
.run_until(async move {
|
||||
connection_handler.run().await.unwrap();
|
||||
})
|
||||
.await;
|
||||
impl ListenerHandle {
|
||||
// the response will be sent as an event in self.events
|
||||
pub async fn request_state(&self) -> ConnectionState {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.commands.send(ListenerCommand::GetState(tx)).await.unwrap();
|
||||
rx.await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use models::Subscription;
|
||||
use reqwest::ResponseBuilderExt;
|
||||
use serde_json::json;
|
||||
use task::LocalSet;
|
||||
use tokio_stream::wrappers::WatchStream;
|
||||
|
||||
@ -328,34 +312,16 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_listener_reconnects_on_http_status_400() {
|
||||
async fn test_listener_reconnects_on_http_status_500() {
|
||||
let local_set = LocalSet::new();
|
||||
local_set
|
||||
.run_until(async {
|
||||
.spawn_local(async {
|
||||
let http_client = HttpClient::new_nullable({
|
||||
let nullable = NullableClient::new();
|
||||
let url = Subscription::build_url("http://localhost", "test", 0).unwrap();
|
||||
nullable
|
||||
.set_response(
|
||||
url.as_str(),
|
||||
reqwest::Response::from(
|
||||
http::response::Builder::new()
|
||||
.status(500)
|
||||
.url(url.clone())
|
||||
.body("failed")
|
||||
.unwrap(),
|
||||
),
|
||||
)
|
||||
.await;
|
||||
nullable.set_default_response(Box::new(move || {
|
||||
reqwest::Response::from(
|
||||
http::response::Builder::new()
|
||||
.status(200)
|
||||
.url(url.clone())
|
||||
.body(r#"{"id":"SLiKI64DOt","time":1635528757,"event":"open","topic":"mytopic"}"#)
|
||||
.unwrap(),
|
||||
)})
|
||||
).await;
|
||||
let nullable = NullableClient::builder()
|
||||
.text_response(url.clone(), 500, "failed")
|
||||
.json_response(url, 200, json!({"id":"SLiKI64DOt","time":1635528757,"event":"open","topic":"mytopic"})).unwrap()
|
||||
.build();
|
||||
nullable
|
||||
});
|
||||
let credentials = Credentials::new_nullable(vec![]).await.unwrap();
|
||||
@ -368,11 +334,8 @@ mod tests {
|
||||
since: 0,
|
||||
};
|
||||
|
||||
let mut listener = Listener::new(config.clone());
|
||||
let events = listener.events.clone();
|
||||
let changes = WatchStream::new(events);
|
||||
spawn_local(async move { listener.run().await });
|
||||
let items: Vec<_> = changes.take(3).collect().await;
|
||||
let mut listener = ListenerActor::new(config.clone());
|
||||
let items: Vec<_> = listener.events.take(3).collect().await;
|
||||
|
||||
|
||||
dbg!(&items);
|
||||
@ -391,40 +354,21 @@ mod tests {
|
||||
// ListenerEvent::Disconnected { .. },
|
||||
// ListenerEvent::Connected { .. },
|
||||
// ));
|
||||
})
|
||||
.await;
|
||||
});
|
||||
local_set.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_listener_reconnects_on_invalid_message() {
|
||||
let local_set = LocalSet::new();
|
||||
local_set
|
||||
.run_until(async {
|
||||
.spawn_local(async {
|
||||
let http_client = HttpClient::new_nullable({
|
||||
let nullable = NullableClient::new();
|
||||
let url = Subscription::build_url("http://localhost", "test", 0).unwrap();
|
||||
nullable
|
||||
.set_response(
|
||||
url.as_str(),
|
||||
reqwest::Response::from(
|
||||
http::response::Builder::new()
|
||||
.status(200)
|
||||
.url(url.clone())
|
||||
.body("failed")
|
||||
.unwrap(),
|
||||
),
|
||||
)
|
||||
.await;
|
||||
nullable.set_default_response(Box::new(move || {
|
||||
reqwest::Response::from(
|
||||
http::response::Builder::new()
|
||||
.status(200)
|
||||
.url(url.clone())
|
||||
.body(r#"{"id":"SLiKI64DOt","time":1635528757,"event":"open","topic":"mytopic"}"#)
|
||||
.unwrap(),
|
||||
)
|
||||
})).await;
|
||||
|
||||
let nullable = NullableClient::builder()
|
||||
.text_response(url.clone(), 200, "invalid message")
|
||||
.json_response(url, 200, json!({"id":"SLiKI64DOt","time":1635528757,"event":"open","topic":"mytopic"})).unwrap()
|
||||
.build();
|
||||
nullable
|
||||
});
|
||||
let credentials = Credentials::new_nullable(vec![]).await.unwrap();
|
||||
@ -437,11 +381,8 @@ mod tests {
|
||||
since: 0,
|
||||
};
|
||||
|
||||
let mut listener = Listener::new(config.clone());
|
||||
let events = listener.events.clone();
|
||||
let changes = WatchStream::new(events);
|
||||
spawn_local(async move { listener.run().await });
|
||||
let items: Vec<_> = changes.take(3).collect().await;
|
||||
let mut listener = ListenerActor::new(config.clone());
|
||||
let items: Vec<_> = listener.events.take(3).collect().await;
|
||||
|
||||
dbg!(&items);
|
||||
assert!(matches!(
|
||||
@ -452,15 +393,15 @@ mod tests {
|
||||
ListenerEvent::ConnectionStateChanged(ConnectionState::Connected { .. }),
|
||||
]
|
||||
));
|
||||
})
|
||||
.await;
|
||||
});
|
||||
local_set.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_connects_sends_receives_simple() {
|
||||
let local_set = LocalSet::new();
|
||||
local_set
|
||||
.run_until(async {
|
||||
.spawn_local(async {
|
||||
let http_client = HttpClient::new(reqwest::Client::new());
|
||||
let credentials = Credentials::new_nullable(vec![]).await.unwrap();
|
||||
|
||||
@ -472,10 +413,10 @@ mod tests {
|
||||
since: 0,
|
||||
};
|
||||
|
||||
let mut listener = Listener::new(config.clone());
|
||||
let mut listener = ListenerActor::new(config.clone());
|
||||
|
||||
// assert_event_matches!(listener, ListenerEvent::Connected { .. },);
|
||||
})
|
||||
.await;
|
||||
});
|
||||
local_set.await;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::{cell::RefCell, rc::Rc};
|
||||
|
||||
use rusqlite::{params, Connection, Result};
|
||||
@ -8,16 +9,16 @@ use crate::Error;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Db {
|
||||
conn: Rc<RefCell<Connection>>,
|
||||
conn: Arc<RwLock<Connection>>,
|
||||
}
|
||||
|
||||
impl Db {
|
||||
pub fn connect(path: &str) -> Result<Self> {
|
||||
let mut this = Self {
|
||||
conn: Rc::new(RefCell::new(Connection::open(path)?)),
|
||||
conn: Arc::new(RwLock::new(Connection::open(path)?)),
|
||||
};
|
||||
{
|
||||
this.conn.borrow().execute_batch(
|
||||
this.conn.read().unwrap().execute_batch(
|
||||
"PRAGMA foreign_keys = ON;
|
||||
PRAGMA journal_mode = wal;",
|
||||
)?;
|
||||
@ -27,12 +28,12 @@ impl Db {
|
||||
}
|
||||
fn migrate(&mut self) -> Result<()> {
|
||||
self.conn
|
||||
.borrow()
|
||||
.read().unwrap()
|
||||
.execute_batch(include_str!("./migrations/00.sql"))?;
|
||||
Ok(())
|
||||
}
|
||||
fn get_or_insert_server(&mut self, server: &str) -> Result<i64> {
|
||||
let mut conn = self.conn.borrow_mut();
|
||||
let mut conn = self.conn.write().unwrap();
|
||||
let tx = conn.transaction()?;
|
||||
let mut res = tx.query_row(
|
||||
"SELECT id
|
||||
@ -56,7 +57,7 @@ impl Db {
|
||||
}
|
||||
pub fn insert_message(&mut self, server: &str, json_data: &str) -> Result<(), Error> {
|
||||
let server_id = self.get_or_insert_server(server)?;
|
||||
let res = self.conn.borrow().execute(
|
||||
let res = self.conn.read().unwrap().execute(
|
||||
"INSERT INTO message (server, data) VALUES (?1, ?2)",
|
||||
params![server_id, json_data],
|
||||
);
|
||||
@ -76,7 +77,7 @@ impl Db {
|
||||
topic: &str,
|
||||
since: u64,
|
||||
) -> Result<Vec<String>, rusqlite::Error> {
|
||||
let conn = self.conn.borrow();
|
||||
let conn = self.conn.read().unwrap();
|
||||
let mut stmt = conn.prepare(
|
||||
"
|
||||
SELECT data
|
||||
@ -94,7 +95,7 @@ impl Db {
|
||||
}
|
||||
pub fn insert_subscription(&mut self, sub: models::Subscription) -> Result<(), Error> {
|
||||
let server_id = self.get_or_insert_server(&sub.server)?;
|
||||
self.conn.borrow().execute(
|
||||
self.conn.read().unwrap().execute(
|
||||
"INSERT INTO subscription (server, topic, display_name, reserved, muted, archived) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
|
||||
params![
|
||||
server_id,
|
||||
@ -109,7 +110,7 @@ impl Db {
|
||||
}
|
||||
pub fn remove_subscription(&mut self, server: &str, topic: &str) -> Result<(), Error> {
|
||||
let server_id = self.get_or_insert_server(server)?;
|
||||
let res = self.conn.borrow().execute(
|
||||
let res = self.conn.read().unwrap().execute(
|
||||
"DELETE FROM subscription
|
||||
WHERE server = ?1 AND topic = ?2",
|
||||
params![server_id, topic],
|
||||
@ -120,7 +121,7 @@ impl Db {
|
||||
Ok(())
|
||||
}
|
||||
pub fn list_subscriptions(&mut self) -> Result<Vec<models::Subscription>, Error> {
|
||||
let conn = self.conn.borrow();
|
||||
let conn = self.conn.read().unwrap();
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT server.endpoint, sub.topic, sub.display_name, sub.reserved, sub.muted, sub.archived, sub.symbolic_icon, sub.read_until
|
||||
FROM subscription sub
|
||||
@ -146,7 +147,7 @@ impl Db {
|
||||
|
||||
pub fn update_subscription(&mut self, sub: models::Subscription) -> Result<(), Error> {
|
||||
let server_id = self.get_or_insert_server(&sub.server)?;
|
||||
let res = self.conn.borrow().execute(
|
||||
let res = self.conn.read().unwrap().execute(
|
||||
"UPDATE subscription
|
||||
SET display_name = ?1, reserved = ?2, muted = ?3, archived = ?4, read_until = ?5
|
||||
WHERE server = ?6 AND topic = ?7",
|
||||
@ -174,7 +175,7 @@ impl Db {
|
||||
value: u64,
|
||||
) -> Result<(), Error> {
|
||||
let server_id = self.get_or_insert_server(server).unwrap();
|
||||
let conn = self.conn.borrow();
|
||||
let conn = self.conn.read().unwrap();
|
||||
let res = conn.execute(
|
||||
"UPDATE subscription
|
||||
SET read_until = ?3
|
||||
@ -189,7 +190,7 @@ impl Db {
|
||||
}
|
||||
pub fn delete_messages(&mut self, server: &str, topic: &str) -> Result<(), Error> {
|
||||
let server_id = self.get_or_insert_server(server).unwrap();
|
||||
let conn = self.conn.borrow();
|
||||
let conn = self.conn.read().unwrap();
|
||||
let res = conn.execute(
|
||||
"DELETE FROM message
|
||||
WHERE topic = ?2 AND server = ?1
|
||||
|
||||
@ -28,7 +28,9 @@ pub fn validate_topic(topic: &str) -> Result<&str, Error> {
|
||||
|
||||
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub id: String,
|
||||
pub topic: String,
|
||||
pub expires: Option<u64>,
|
||||
pub message: Option<String>,
|
||||
#[serde(default = "Default::default")]
|
||||
pub time: u64,
|
||||
@ -337,3 +339,35 @@ pub trait NotificationProxy: Sync + Send {
|
||||
pub trait NetworkMonitorProxy: Sync + Send {
|
||||
fn listen(&self) -> Pin<Box<dyn Stream<Item = ()>>>;
|
||||
}
|
||||
|
||||
|
||||
pub struct NullNotifier {
|
||||
|
||||
}
|
||||
|
||||
impl NullNotifier {
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
}
|
||||
impl NotificationProxy for NullNotifier {
|
||||
fn send(&self, n: Notification) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NullNetworkMonitor {
|
||||
|
||||
}
|
||||
|
||||
impl NullNetworkMonitor {
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
}
|
||||
|
||||
impl NetworkMonitorProxy for NullNetworkMonitor {
|
||||
fn listen(&self) -> Pin<Box<dyn Stream<Item = ()>>> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
@ -1,48 +1,89 @@
|
||||
use crate::models::NullNetworkMonitor;
|
||||
use crate::models::NullNotifier;
|
||||
use anyhow::{anyhow, Context};
|
||||
use futures::future::join_all;
|
||||
use std::{collections::HashMap, future::Future, sync::Arc};
|
||||
use tokio::{
|
||||
sync::{broadcast, mpsc, RwLock},
|
||||
task::LocalSet,
|
||||
sync::{broadcast, mpsc, oneshot, RwLock},
|
||||
task::{spawn_local, LocalSet},
|
||||
};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::{
|
||||
credentials::{self, Credential},
|
||||
http_client::HttpClient,
|
||||
listener::{Listener, ListenerCommand, ListenerConfig, ListenerEvent},
|
||||
message_repo::Db,
|
||||
models::{self, Account},
|
||||
topic_listener::build_client,
|
||||
SharedEnv,
|
||||
ListenerActor, ListenerCommand, ListenerConfig, ListenerHandle, SharedEnv, SubscriptionHandle,
|
||||
};
|
||||
|
||||
// Message types for the actor
|
||||
#[derive()]
|
||||
pub enum NtfyMessage {
|
||||
Subscribe {
|
||||
server: String,
|
||||
topic: String,
|
||||
respond_to: oneshot::Sender<Result<SubscriptionHandle, Vec<anyhow::Error>>>,
|
||||
},
|
||||
Unsubscribe {
|
||||
server: String,
|
||||
topic: String,
|
||||
respond_to: oneshot::Sender<anyhow::Result<()>>,
|
||||
},
|
||||
RefreshAll {
|
||||
respond_to: oneshot::Sender<anyhow::Result<()>>,
|
||||
},
|
||||
ListSubscriptions {
|
||||
respond_to: oneshot::Sender<anyhow::Result<Vec<SubscriptionHandle>>>,
|
||||
},
|
||||
ListAccounts {
|
||||
respond_to: oneshot::Sender<anyhow::Result<Vec<Account>>>,
|
||||
},
|
||||
WatchSubscribed {
|
||||
respond_to: oneshot::Sender<anyhow::Result<()>>,
|
||||
},
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct WatchKey {
|
||||
server: String,
|
||||
topic: String,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Ntfy {
|
||||
listener_handles: Arc<RwLock<HashMap<WatchKey, Listener>>>,
|
||||
pub struct NtfyActor {
|
||||
listener_handles: Arc<RwLock<HashMap<WatchKey, SubscriptionHandle>>>,
|
||||
env: SharedEnv,
|
||||
command_rx: mpsc::Receiver<NtfyMessage>,
|
||||
}
|
||||
|
||||
impl Ntfy {
|
||||
pub fn new(env: SharedEnv) -> Self {
|
||||
Self {
|
||||
#[derive(Clone)]
|
||||
pub struct NtfyHandle {
|
||||
command_tx: mpsc::Sender<NtfyMessage>,
|
||||
}
|
||||
|
||||
impl NtfyActor {
|
||||
pub fn new(env: SharedEnv) -> (Self, NtfyHandle) {
|
||||
let (command_tx, command_rx) = mpsc::channel(32);
|
||||
|
||||
let actor = Self {
|
||||
listener_handles: Default::default(),
|
||||
env,
|
||||
}
|
||||
command_rx,
|
||||
};
|
||||
|
||||
let handle = NtfyHandle { command_tx };
|
||||
|
||||
(actor, handle)
|
||||
}
|
||||
pub async fn subscribe(
|
||||
|
||||
async fn handle_subscribe(
|
||||
&self,
|
||||
server: &str,
|
||||
topic: &str,
|
||||
) -> Result<Listener, Vec<anyhow::Error>> {
|
||||
let subscription = models::Subscription::builder(topic.to_owned())
|
||||
.server(server.to_string())
|
||||
server: String,
|
||||
topic: String,
|
||||
) -> Result<SubscriptionHandle, Vec<anyhow::Error>> {
|
||||
let subscription = models::Subscription::builder(topic.clone())
|
||||
.server(server.clone())
|
||||
.build()
|
||||
.map_err(|e| e.into_iter().map(|e| anyhow!(e)).collect::<Vec<_>>())?;
|
||||
|
||||
@ -50,81 +91,94 @@ impl Ntfy {
|
||||
db.insert_subscription(subscription.clone())
|
||||
.map_err(|e| vec![anyhow!(e)])?;
|
||||
|
||||
let listener = self.listen(subscription).await;
|
||||
listener.map_err(|e| vec![anyhow!(e)])
|
||||
self.listen(subscription)
|
||||
.await
|
||||
.map_err(|e| vec![anyhow!(e)])
|
||||
}
|
||||
|
||||
pub async fn unsubscribe(&mut self, server: &str, topic: &str) -> anyhow::Result<()> {
|
||||
let listener = self.listener_handles.write().await.remove(&WatchKey {
|
||||
server: server.to_string(),
|
||||
topic: topic.to_string(),
|
||||
async fn handle_unsubscribe(&mut self, server: String, topic: String) -> anyhow::Result<()> {
|
||||
let subscription = self.listener_handles.write().await.remove(&WatchKey {
|
||||
server: server.clone(),
|
||||
topic: topic.clone(),
|
||||
});
|
||||
if let Some(listener) = listener {
|
||||
listener.commands.send(ListenerCommand::Shutdown)?;
|
||||
|
||||
if let Some(sub) = subscription {
|
||||
sub.shutdown().await?;
|
||||
}
|
||||
|
||||
self.env.db.remove_subscription(server, topic)?;
|
||||
self.env.db.remove_subscription(&server, &topic)?;
|
||||
info!(server, topic, "Unsubscribed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// TODO rename reconnect_all
|
||||
pub async fn refresh_all(&mut self) -> anyhow::Result<()> {
|
||||
for listener in self.listener_handles.read().await.values() {
|
||||
listener.commands.send(ListenerCommand::Restart)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
pub async fn run(&mut self) {
|
||||
while let Some(msg) = self.command_rx.recv().await {
|
||||
match msg {
|
||||
NtfyMessage::Subscribe {
|
||||
server,
|
||||
topic,
|
||||
respond_to,
|
||||
} => {
|
||||
let result = self.handle_subscribe(server, topic).await;
|
||||
let _ = respond_to.send(result);
|
||||
}
|
||||
|
||||
pub async fn list_subscriptions(&mut self) -> anyhow::Result<Vec<Listener>> {
|
||||
let values = self
|
||||
.listener_handles
|
||||
.read()
|
||||
.await
|
||||
.values()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
NtfyMessage::Unsubscribe {
|
||||
server,
|
||||
topic,
|
||||
respond_to,
|
||||
} => {
|
||||
let result = self.handle_unsubscribe(server, topic).await;
|
||||
let _ = respond_to.send(result);
|
||||
}
|
||||
|
||||
Ok(values)
|
||||
}
|
||||
NtfyMessage::RefreshAll { respond_to } => {
|
||||
let mut res = Ok(());
|
||||
for sub in self.listener_handles.read().await.values() {
|
||||
res = sub.restart().await;
|
||||
if res.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let _ = respond_to.send(res);
|
||||
}
|
||||
|
||||
pub async fn list_accounts(&mut self) -> anyhow::Result<Vec<Account>> {
|
||||
let values = self.env.credentials.list_all();
|
||||
let res = values
|
||||
.into_iter()
|
||||
.map(|(server, credential)| Account {
|
||||
server,
|
||||
username: credential.username,
|
||||
})
|
||||
.collect();
|
||||
NtfyMessage::ListSubscriptions { respond_to } => {
|
||||
let subs = self
|
||||
.listener_handles
|
||||
.read()
|
||||
.await
|
||||
.values()
|
||||
.cloned()
|
||||
.collect();
|
||||
let _ = respond_to.send(Ok(subs));
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
NtfyMessage::ListAccounts { respond_to } => {
|
||||
let accounts = self
|
||||
.env
|
||||
.credentials
|
||||
.list_all()
|
||||
.into_iter()
|
||||
.map(|(server, credential)| Account {
|
||||
server,
|
||||
username: credential.username,
|
||||
})
|
||||
.collect();
|
||||
let _ = respond_to.send(Ok(accounts));
|
||||
}
|
||||
|
||||
pub fn listen(
|
||||
&self,
|
||||
sub: models::Subscription,
|
||||
) -> impl Future<Output = anyhow::Result<Listener>> {
|
||||
let server = sub.server.clone();
|
||||
let topic = sub.topic.clone();
|
||||
let listener = Listener::new(ListenerConfig {
|
||||
http_client: self.env.nullable_http.clone(),
|
||||
credentials: self.env.credentials.clone(),
|
||||
endpoint: server.clone(),
|
||||
topic: topic.clone(),
|
||||
since: sub.read_until,
|
||||
});
|
||||
let listener_handles = self.listener_handles.clone();
|
||||
async move {
|
||||
listener_handles
|
||||
.write()
|
||||
.await
|
||||
.insert(WatchKey { server, topic }, listener.clone());
|
||||
Ok(listener)
|
||||
NtfyMessage::WatchSubscribed { respond_to } => {
|
||||
let result = self.handle_watch_subscribed().await;
|
||||
let _ = respond_to.send(result);
|
||||
}
|
||||
|
||||
NtfyMessage::Shutdown => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn watch_subscribed(&mut self) -> anyhow::Result<()> {
|
||||
async fn handle_watch_subscribed(&mut self) -> anyhow::Result<()> {
|
||||
let f: Vec<_> = self
|
||||
.env
|
||||
.db
|
||||
@ -132,48 +186,227 @@ impl Ntfy {
|
||||
.into_iter()
|
||||
.map(|m| self.listen(m))
|
||||
.collect();
|
||||
|
||||
join_all(f.into_iter().map(|x| async move {
|
||||
if let Err(e) = x.await {
|
||||
error!(error = ?e, "Can't rewatch subscribed topic");
|
||||
}
|
||||
}))
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn add_account(&mut self) {}
|
||||
fn remove_account(&mut self) {}
|
||||
fn listen(
|
||||
&self,
|
||||
sub: models::Subscription,
|
||||
) -> impl Future<Output = anyhow::Result<SubscriptionHandle>> {
|
||||
let server = sub.server.clone();
|
||||
let topic = sub.topic.clone();
|
||||
let listener = ListenerActor::new(ListenerConfig {
|
||||
http_client: self.env.nullable_http.clone(),
|
||||
credentials: self.env.credentials.clone(),
|
||||
endpoint: server.clone(),
|
||||
topic: topic.clone(),
|
||||
since: sub.read_until,
|
||||
});
|
||||
let listener_handles = self.listener_handles.clone();
|
||||
let sub = SubscriptionHandle::new(listener.clone(), sub, &self.env);
|
||||
|
||||
async move {
|
||||
listener_handles
|
||||
.write()
|
||||
.await
|
||||
.insert(WatchKey { server, topic }, sub.clone());
|
||||
Ok(sub)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl NtfyHandle {
|
||||
pub async fn subscribe(
|
||||
&self,
|
||||
server: &str,
|
||||
topic: &str,
|
||||
) -> Result<SubscriptionHandle, Vec<anyhow::Error>> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.command_tx
|
||||
.send(NtfyMessage::Subscribe {
|
||||
server: server.to_string(),
|
||||
topic: topic.to_string(),
|
||||
respond_to: tx,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| vec![anyhow!("Actor mailbox error")])?;
|
||||
|
||||
rx.await
|
||||
.map_err(|_| vec![anyhow!("Actor response error")])?
|
||||
}
|
||||
|
||||
pub async fn unsubscribe(&self, server: &str, topic: &str) -> anyhow::Result<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.command_tx
|
||||
.send(NtfyMessage::Unsubscribe {
|
||||
server: server.to_string(),
|
||||
topic: topic.to_string(),
|
||||
respond_to: tx,
|
||||
})
|
||||
.await
|
||||
.map_err(|_| anyhow!("Actor mailbox error"))?;
|
||||
|
||||
rx.await.map_err(|_| anyhow!("Actor response error"))?
|
||||
}
|
||||
|
||||
pub async fn refresh_all(&self) -> anyhow::Result<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.command_tx
|
||||
.send(NtfyMessage::RefreshAll { respond_to: tx })
|
||||
.await
|
||||
.map_err(|_| anyhow!("Actor mailbox error"))?;
|
||||
|
||||
rx.await.map_err(|_| anyhow!("Actor response error"))?
|
||||
}
|
||||
|
||||
pub async fn list_subscriptions(&self) -> anyhow::Result<Vec<SubscriptionHandle>> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.command_tx
|
||||
.send(NtfyMessage::ListSubscriptions { respond_to: tx })
|
||||
.await
|
||||
.map_err(|_| anyhow!("Actor mailbox error"))?;
|
||||
|
||||
rx.await.map_err(|_| anyhow!("Actor response error"))?
|
||||
}
|
||||
|
||||
pub async fn list_accounts(&self) -> anyhow::Result<Vec<Account>> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.command_tx
|
||||
.send(NtfyMessage::ListAccounts { respond_to: tx })
|
||||
.await
|
||||
.map_err(|_| anyhow!("Actor mailbox error"))?;
|
||||
|
||||
rx.await.map_err(|_| anyhow!("Actor response error"))?
|
||||
}
|
||||
|
||||
pub async fn watch_subscribed(&self) -> anyhow::Result<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.command_tx
|
||||
.send(NtfyMessage::WatchSubscribed { respond_to: tx })
|
||||
.await
|
||||
.map_err(|_| anyhow!("Actor mailbox error"))?;
|
||||
|
||||
rx.await.map_err(|_| anyhow!("Actor response error"))?
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(
|
||||
socket_path: std::path::PathBuf,
|
||||
dbpath: &str,
|
||||
notification_proxy: Arc<dyn models::NotificationProxy>,
|
||||
network_proxy: Arc<dyn models::NetworkMonitorProxy>,
|
||||
) -> anyhow::Result<Ntfy> {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
) -> anyhow::Result<NtfyHandle> {
|
||||
let dbpath = dbpath.to_owned();
|
||||
let credentials = rt.block_on(async { crate::credentials::Credentials::new().await.unwrap() });
|
||||
let local = tokio::task::LocalSet::new();
|
||||
|
||||
let env = SharedEnv {
|
||||
db: Db::connect(&dbpath).unwrap(),
|
||||
proxy: notification_proxy,
|
||||
http: build_client().unwrap(),
|
||||
nullable_http: HttpClient::new(build_client().unwrap()),
|
||||
network: network_proxy,
|
||||
credentials,
|
||||
};
|
||||
let ntfy = Ntfy::new(env);
|
||||
let mut ntfy_clone = ntfy.clone();
|
||||
local.spawn_local(async move {
|
||||
ntfy_clone.watch_subscribed().await.unwrap();
|
||||
// Create a channel to receive the handle from the spawned thread
|
||||
let (handle_tx, handle_rx) = oneshot::channel();
|
||||
|
||||
std::thread::spawn(move || {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Create everything inside the new thread's runtime
|
||||
let credentials =
|
||||
rt.block_on(async move { crate::credentials::Credentials::new().await.unwrap() });
|
||||
|
||||
let env = SharedEnv {
|
||||
db: Db::connect(&dbpath).unwrap(),
|
||||
proxy: notification_proxy,
|
||||
http: build_client().unwrap(),
|
||||
nullable_http: HttpClient::new(build_client().unwrap()),
|
||||
network: network_proxy,
|
||||
credentials,
|
||||
};
|
||||
|
||||
let (mut actor, handle) = NtfyActor::new(env);
|
||||
let handle_clone = handle.clone();
|
||||
|
||||
// Send the handle back to the calling thread
|
||||
handle_tx.send(handle.clone());
|
||||
|
||||
rt.block_on({
|
||||
let local_set = LocalSet::new();
|
||||
// Spawn the watch_subscribed task
|
||||
local_set.spawn_local(async move {
|
||||
if let Err(e) = handle_clone.watch_subscribed().await {
|
||||
error!(error = ?e, "Failed to watch subscribed topics");
|
||||
}
|
||||
});
|
||||
|
||||
// Run the actor
|
||||
local_set.spawn_local(async move {
|
||||
actor.run().await;
|
||||
});
|
||||
local_set
|
||||
})
|
||||
});
|
||||
|
||||
Ok(ntfy)
|
||||
// Wait for the handle from the spawned thread
|
||||
Ok(handle_rx
|
||||
.blocking_recv()
|
||||
.map_err(|_| anyhow!("Failed to receive actor handle"))?)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use models::Message;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::ListenerEvent;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_subscribe_and_publish() {
|
||||
let notification_proxy = Arc::new(NullNotifier::new());
|
||||
let network_proxy = Arc::new(NullNetworkMonitor::new());
|
||||
let dbpath = ":memory:";
|
||||
let socket_path = std::path::PathBuf::from("/tmp/ntfy.sock");
|
||||
|
||||
let handle = start(socket_path, dbpath, notification_proxy, network_proxy).unwrap();
|
||||
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
rt.block_on(async move {
|
||||
let server = "http://localhost:8000";
|
||||
let topic = "test_topic";
|
||||
|
||||
// Subscribe to the topic
|
||||
let subscription_handle = handle.subscribe(server, topic).await.unwrap();
|
||||
|
||||
// Publish a message
|
||||
let message = serde_json::to_string(&Message {
|
||||
topic: topic.to_string(),
|
||||
..Default::default()
|
||||
})
|
||||
.unwrap();
|
||||
let result = subscription_handle.publish(message).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
sleep(Duration::from_millis(250)).await;
|
||||
|
||||
// Attach to the subscription and check if the message is received and stored
|
||||
let (events, receiver) = subscription_handle.attach().await;
|
||||
dbg!(&events);
|
||||
assert!(events.iter().any(|event| match event {
|
||||
ListenerEvent::Message(msg) => msg.topic == topic,
|
||||
_ => false,
|
||||
}));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
257
ntfy-daemon/src/subscription.rs
Normal file
257
ntfy-daemon/src/subscription.rs
Normal file
@ -0,0 +1,257 @@
|
||||
use crate::listener::{ListenerEvent, ListenerHandle};
|
||||
use crate::message_repo::Db;
|
||||
use crate::models::{self, Message, NotificationProxy};
|
||||
use crate::{Error, ServerEvent, SharedEnv};
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use tokio::select;
|
||||
use tokio::sync::{broadcast, mpsc, oneshot, watch, RwLock};
|
||||
use tokio::task::spawn_local;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SubscriptionHandle {
|
||||
sender: mpsc::Sender<SubscriptionRequest>,
|
||||
listener: ListenerHandle,
|
||||
}
|
||||
|
||||
impl SubscriptionHandle {
|
||||
pub fn new(listener: ListenerHandle, model: models::Subscription, env: &SharedEnv) -> Self {
|
||||
let (sender, receiver) = mpsc::channel(32);
|
||||
let broadcast_tx = broadcast::channel(8).0;
|
||||
let actor = SubscriptionActor {
|
||||
listener: listener.clone(),
|
||||
model,
|
||||
receiver,
|
||||
env: env.clone(),
|
||||
broadcast_tx: broadcast_tx.clone(),
|
||||
};
|
||||
spawn_local(actor.run());
|
||||
Self { sender, listener }
|
||||
}
|
||||
|
||||
pub async fn model(&self) -> models::Subscription {
|
||||
let (resp_tx, resp_rx) = oneshot::channel();
|
||||
self.sender
|
||||
.send(SubscriptionRequest::GetModel { resp_tx })
|
||||
.await
|
||||
.unwrap();
|
||||
resp_rx.await.unwrap()
|
||||
}
|
||||
|
||||
pub async fn update_info(&self, new_model: models::Subscription) -> anyhow::Result<()> {
|
||||
let (resp_tx, resp_rx) = oneshot::channel();
|
||||
self.sender
|
||||
.send(SubscriptionRequest::UpdateInfo { new_model, resp_tx })
|
||||
.await?;
|
||||
resp_rx.await.unwrap()
|
||||
}
|
||||
|
||||
pub async fn restart(&self) -> anyhow::Result<()> {
|
||||
self.listener
|
||||
.commands
|
||||
.send(crate::ListenerCommand::Restart)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn shutdown(&self) -> anyhow::Result<()> {
|
||||
self.listener
|
||||
.commands
|
||||
.send(crate::ListenerCommand::Shutdown)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// returns a vector containing all the past messages stored in the database and the current connection state.
|
||||
// The first vector is useful to get a summary of what happened before.
|
||||
// The `ListenerHandle` is returned to receive new events.
|
||||
pub async fn attach(&self) -> (Vec<ListenerEvent>, broadcast::Receiver<ListenerEvent>) {
|
||||
let (resp_tx, resp_rx) = oneshot::channel();
|
||||
self.sender
|
||||
.send(SubscriptionRequest::Attach { resp_tx })
|
||||
.await
|
||||
.unwrap();
|
||||
resp_rx.await.unwrap()
|
||||
}
|
||||
|
||||
pub async fn publish(&self, msg: String) -> anyhow::Result<()> {
|
||||
let (resp_tx, resp_rx) = oneshot::channel();
|
||||
self.sender
|
||||
.send(SubscriptionRequest::Publish { msg, resp_tx })
|
||||
.await
|
||||
.unwrap();
|
||||
resp_rx.await.unwrap()
|
||||
}
|
||||
|
||||
pub async fn clear_notifications(&self) -> anyhow::Result<()> {
|
||||
let (resp_tx, resp_rx) = oneshot::channel();
|
||||
self.sender
|
||||
.send(SubscriptionRequest::ClearNotifications { resp_tx })
|
||||
.await
|
||||
.unwrap();
|
||||
resp_rx.await.unwrap()
|
||||
}
|
||||
|
||||
pub async fn update_read_until(&self, timestamp: u64) -> anyhow::Result<()> {
|
||||
let (resp_tx, resp_rx) = oneshot::channel();
|
||||
self.sender
|
||||
.send(SubscriptionRequest::UpdateReadUntil { timestamp, resp_tx })
|
||||
.await
|
||||
.unwrap();
|
||||
resp_rx.await.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
struct SubscriptionActor {
|
||||
listener: ListenerHandle,
|
||||
model: models::Subscription,
|
||||
receiver: mpsc::Receiver<SubscriptionRequest>,
|
||||
env: SharedEnv,
|
||||
broadcast_tx: broadcast::Sender<ListenerEvent>,
|
||||
}
|
||||
|
||||
impl SubscriptionActor {
|
||||
async fn run(mut self) {
|
||||
loop {
|
||||
select! {
|
||||
Ok(event) = self.listener.events.recv() => {
|
||||
match event {
|
||||
ListenerEvent::Message(msg) => self.handle_msg_event(msg),
|
||||
other => {
|
||||
let _ = self.broadcast_tx.send(other);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(request) = self.receiver.recv() => {
|
||||
match request {
|
||||
SubscriptionRequest::GetModel { resp_tx } => {
|
||||
let _ = resp_tx.send(self.model.clone());
|
||||
}
|
||||
SubscriptionRequest::UpdateInfo {
|
||||
mut new_model,
|
||||
resp_tx,
|
||||
} => {
|
||||
new_model.server = self.model.server.clone();
|
||||
new_model.topic = self.model.topic.clone();
|
||||
let res = self.env.db.update_subscription(new_model.clone());
|
||||
if let Ok(_) = res {
|
||||
self.model = new_model;
|
||||
}
|
||||
resp_tx.send(res.map_err(|e| e.into()));
|
||||
}
|
||||
SubscriptionRequest::Publish {msg, resp_tx} => {
|
||||
let _ = resp_tx.send(self.publish(msg).await);
|
||||
}
|
||||
SubscriptionRequest::Attach { resp_tx } => {
|
||||
let messages = self
|
||||
.env
|
||||
.db
|
||||
.list_messages(&self.model.server, &self.model.topic, 0)
|
||||
.unwrap_or_default();
|
||||
let mut previous_events: Vec<ListenerEvent> = messages
|
||||
.into_iter()
|
||||
.filter_map(|msg| {
|
||||
let msg = serde_json::from_str(&msg);
|
||||
match msg {
|
||||
Err(e) => {
|
||||
error!(error = ?e, "error parsing stored message");
|
||||
None
|
||||
}
|
||||
Ok(msg) => Some(msg),
|
||||
}
|
||||
})
|
||||
.map(ListenerEvent::Message)
|
||||
.collect();
|
||||
previous_events.push(ListenerEvent::ConnectionStateChanged(self.listener.request_state().await));
|
||||
let _ = resp_tx.send((previous_events, self.broadcast_tx.subscribe()));
|
||||
}
|
||||
SubscriptionRequest::ClearNotifications {resp_tx} => {
|
||||
let _ = resp_tx.send(self.env.db.delete_messages(&self.model.server, &self.model.topic).map_err(|e| anyhow::anyhow!(e)));
|
||||
}
|
||||
SubscriptionRequest::UpdateReadUntil { timestamp, resp_tx } => {
|
||||
let res = self.env.db.update_read_until(&self.model.server, &self.model.topic, timestamp);
|
||||
let _ = resp_tx.send(res.map_err(|e| anyhow::anyhow!(e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn publish(&self, msg: String) -> anyhow::Result<()> {
|
||||
let server = &self.model.server;
|
||||
let creds = self.env.credentials.get(server);
|
||||
let mut req = self.env.http.post(server);
|
||||
if let Some(creds) = creds {
|
||||
req = req.basic_auth(creds.username, Some(creds.password));
|
||||
}
|
||||
|
||||
info!("sending message");
|
||||
let res = req.body(msg).send().await?;
|
||||
res.error_for_status()?;
|
||||
Ok(())
|
||||
}
|
||||
fn handle_msg_event(&mut self, msg: Message) {
|
||||
// Store in database
|
||||
let already_stored: bool = {
|
||||
let json_ev = &serde_json::to_string(&msg).unwrap();
|
||||
match self.env.db.insert_message(&self.model.server, json_ev) {
|
||||
Err(Error::DuplicateMessage) => {
|
||||
warn!("Received duplicate message");
|
||||
true
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = ?e, "Can't store the message");
|
||||
false
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
};
|
||||
|
||||
if !already_stored {
|
||||
// Show notification. If this fails, panic
|
||||
if !{ self.model.muted } {
|
||||
let notifier = self.env.proxy.clone();
|
||||
|
||||
let title = { msg.notification_title(&self.model) };
|
||||
|
||||
let n = models::Notification {
|
||||
title,
|
||||
body: msg.display_message().as_deref().unwrap_or("").to_string(),
|
||||
actions: msg.actions.clone(),
|
||||
};
|
||||
|
||||
info!("Showing notification");
|
||||
notifier.send(n).unwrap();
|
||||
}
|
||||
|
||||
// Forward to app
|
||||
let _ = self.broadcast_tx.send(ListenerEvent::Message(msg));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum SubscriptionRequest {
|
||||
GetModel {
|
||||
resp_tx: oneshot::Sender<models::Subscription>,
|
||||
},
|
||||
UpdateInfo {
|
||||
new_model: models::Subscription,
|
||||
resp_tx: oneshot::Sender<anyhow::Result<()>>,
|
||||
},
|
||||
Attach {
|
||||
resp_tx: oneshot::Sender<(Vec<ListenerEvent>, broadcast::Receiver<ListenerEvent>)>,
|
||||
},
|
||||
Publish {
|
||||
msg: String,
|
||||
resp_tx: oneshot::Sender<anyhow::Result<()>>,
|
||||
},
|
||||
ClearNotifications {
|
||||
resp_tx: oneshot::Sender<anyhow::Result<()>>,
|
||||
},
|
||||
UpdateReadUntil {
|
||||
timestamp: u64,
|
||||
resp_tx: oneshot::Sender<anyhow::Result<()>>,
|
||||
},
|
||||
}
|
||||
Reference in New Issue
Block a user