Merge pull request #27 from ranfdev/nullables

Rewrite of ntfy-daemon. Adds basic tests with Nullables and removes any trace of capnp
This commit is contained in:
ranfdev
2025-01-21 17:13:51 +01:00
committed by GitHub
29 changed files with 3053 additions and 1986 deletions

1788
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -23,8 +23,6 @@ tracing-subscriber = "0.3"
adw = { version = "0.7", package = "libadwaita", features = ["v1_6"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
capnp = "0.18.0"
capnp-rpc = "0.18.0"
anyhow = "1.0.71"
chrono = "0.4.26"
rand = "0.8.5"

View File

@ -17,6 +17,7 @@ https://ntfy.sh client application to receive everyday's notifications.
## Architecture
The code is split between the GUI and the underlying ntfy-daemon.
![](./architecture.svg)
## How to run
Use gnome-builder to clone and run the project. Note: after clicking the "run"

12
architecture.svg Normal file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 221 KiB

View File

@ -37,17 +37,6 @@
]
},
"modules": [
{
"name": "capnp",
"buildsystem": "cmake",
"sources": [
{
"type": "archive",
"url": "https://capnproto.org/capnproto-c++-0.10.4.tar.gz",
"sha256": "981e7ef6dbe3ac745907e55a78870fbb491c5d23abd4ebc04e20ec235af4458c"
}
]
},
{
"name": "blueprint-compiler",
"buildsystem": "meson",
@ -56,7 +45,8 @@
{
"type": "git",
"url": "https://gitlab.gnome.org/jwestman/blueprint-compiler",
"tag": "v0.14.0"
"tag": "v0.14.0",
"commit": "8e10fcf8692108b9d4ab78f41086c5d7773ef864"
}
]
},

View File

@ -5,27 +5,23 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[build-dependencies]
capnpc = "0.17.2"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
capnp = "0.18.0"
capnp-rpc = "0.18.0"
futures = "0.3.0"
tokio = { version = "1.0.0", features = ["net", "rt", "macros", "parking_lot"]}
tokio-util = { version = "0.7.4", features = ["compat", "io"] }
clap = { version = "4.3.11", features = ["derive"] }
anyhow = "1.0.71"
tokio-stream = { version = "0.1.14", features = ["io-util", "time"] }
tokio-stream = { version = "0.1.14", features = ["io-util", "time", "sync"] }
rusqlite = "0.29.0"
rand = "0.8.5"
reqwest = { version = "0.11.18", features = ["stream", "rustls-tls-native-roots"]}
url = "2.4.0"
generational-arena = "0.2.9"
reqwest = { version = "0.12.9", features = ["stream", "rustls-tls-native-roots"]}
url = { version = "2.4.0", features = ["serde"] }
tracing = "0.1.37"
thiserror = "1.0.49"
regex = "1.9.6"
oo7 = "0.2.1"
async-trait = "0.1.83"
http = "1.1.0"
async-channel = "2.3.1"

View File

@ -1,5 +0,0 @@
# ntfy-daemon
Rust crate providing a capnp-rpc interface to multiple ntfy servers.
Connections to the same server are multiplexed over http2.
Messages are received and stored in a sqlite database for persistance.

View File

@ -1,6 +0,0 @@
fn main() {
capnpc::CompilerCommand::new()
.file("src/ntfy.capnp")
.run()
.unwrap();
}

View File

@ -0,0 +1,14 @@
macro_rules! send_command {
($self:expr, $command:expr) => {{
let (resp_tx, resp_rx) = oneshot::channel();
use anyhow::Context;
$self
.command_tx
.send($command(resp_tx))
.await
.context("Actor mailbox error")?;
resp_rx.await.context("Actor response error")?
}};
}
pub(crate) use send_command;

View File

@ -1,6 +1,135 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::{Arc, RwLock};
use async_trait::async_trait;
#[derive(Clone)]
pub struct KeyringItem {
attributes: HashMap<String, String>,
// we could zero-out this region of memory
secret: Vec<u8>,
}
impl KeyringItem {
async fn attributes(&self) -> HashMap<String, String> {
self.attributes.clone()
}
async fn secret(&self) -> &[u8] {
&self.secret[..]
}
}
#[async_trait]
trait LightKeyring {
async fn search_items(
&self,
attributes: HashMap<&str, &str>,
) -> anyhow::Result<Vec<KeyringItem>>;
async fn create_item(
&self,
label: &str,
attributes: HashMap<&str, &str>,
secret: &str,
replace: bool,
) -> anyhow::Result<()>;
async fn delete(&self, attributes: HashMap<&str, &str>) -> anyhow::Result<()>;
}
struct RealKeyring {
keyring: oo7::Keyring,
}
#[async_trait]
impl LightKeyring for RealKeyring {
async fn search_items(
&self,
attributes: HashMap<&str, &str>,
) -> anyhow::Result<Vec<KeyringItem>> {
let items = self.keyring.search_items(attributes).await?;
let mut out_items = vec![];
for item in items {
out_items.push(KeyringItem {
attributes: item.attributes().await?,
secret: item.secret().await?.to_vec(),
});
}
Ok(out_items)
}
async fn create_item(
&self,
label: &str,
attributes: HashMap<&str, &str>,
secret: &str,
replace: bool,
) -> anyhow::Result<()> {
self.keyring
.create_item(label, attributes, secret, replace)
.await?;
Ok(())
}
async fn delete(&self, attributes: HashMap<&str, &str>) -> anyhow::Result<()> {
self.keyring.delete(attributes).await?;
Ok(())
}
}
struct NullableKeyring {
search_response: Vec<KeyringItem>,
}
impl NullableKeyring {
pub fn new(search_response: Vec<KeyringItem>) -> Self {
Self { search_response }
}
}
#[async_trait]
impl LightKeyring for NullableKeyring {
async fn search_items(
&self,
_attributes: HashMap<&str, &str>,
) -> anyhow::Result<Vec<KeyringItem>> {
Ok(self.search_response.clone())
}
async fn create_item(
&self,
_label: &str,
_attributes: HashMap<&str, &str>,
_secret: &str,
_replace: bool,
) -> anyhow::Result<()> {
Ok(())
}
async fn delete(&self, _attributes: HashMap<&str, &str>) -> anyhow::Result<()> {
Ok(())
}
}
impl NullableKeyring {
pub fn with_credentials(credentials: Vec<Credential>) -> Self {
let mut search_response = vec![];
for cred in credentials {
let attributes = HashMap::from([
("type".to_string(), "password".to_string()),
("username".to_string(), cred.username.clone()),
("server".to_string(), cred.password.clone()),
]);
search_response.push(KeyringItem {
attributes,
secret: cred.password.into_bytes(),
});
}
Self { search_response }
}
}
#[derive(Debug, Clone)]
pub struct Credential {
@ -8,20 +137,28 @@ pub struct Credential {
pub password: String,
}
#[derive(Debug, Clone)]
#[derive(Clone)]
pub struct Credentials {
keyring: Rc<oo7::Keyring>,
creds: Rc<RefCell<HashMap<String, Credential>>>,
keyring: Arc<dyn LightKeyring + Send + Sync>,
creds: Arc<RwLock<HashMap<String, Credential>>>,
}
impl Credentials {
pub async fn new() -> anyhow::Result<Self> {
let mut this = Self {
keyring: Rc::new(
oo7::Keyring::new()
keyring: Arc::new(RealKeyring {
keyring: oo7::Keyring::new()
.await
.expect("Failed to start Secret Service"),
),
}),
creds: Default::default(),
};
this.load().await?;
Ok(this)
}
pub async fn new_nullable(credentials: Vec<Credential>) -> anyhow::Result<Self> {
let mut this = Self {
keyring: Arc::new(NullableKeyring::with_credentials(credentials)),
creds: Default::default(),
};
this.load().await?;
@ -29,37 +166,31 @@ impl Credentials {
}
pub async fn load(&mut self) -> anyhow::Result<()> {
let attrs = HashMap::from([("type", "password")]);
let values = self
.keyring
.search_items(attrs)
.await
.map_err(|e| capnp::Error::failed(e.to_string()))?;
let values = self.keyring.search_items(attrs).await?;
self.creds.borrow_mut().clear();
let mut lock = self.creds.write().unwrap();
lock.clear();
for item in values {
let attrs = item
.attributes()
.await
.map_err(|e| capnp::Error::failed(e.to_string()))?;
self.creds.borrow_mut().insert(
let attrs = item.attributes().await;
lock.insert(
attrs["server"].to_string(),
Credential {
username: attrs["username"].to_string(),
password: std::str::from_utf8(&item.secret().await?)?.to_string(),
password: std::str::from_utf8(&item.secret().await)?.to_string(),
},
);
}
Ok(())
}
pub fn get(&self, server: &str) -> Option<Credential> {
self.creds.borrow().get(server).cloned()
self.creds.read().unwrap().get(server).cloned()
}
pub fn list_all(&self) -> HashMap<String, Credential> {
self.creds.borrow().clone()
self.creds.read().unwrap().clone()
}
pub async fn insert(&self, server: &str, username: &str, password: &str) -> anyhow::Result<()> {
{
if let Some(cred) = self.creds.borrow().get(server) {
if let Some(cred) = self.creds.read().unwrap().get(server) {
if cred.username != username {
anyhow::bail!("You can add only one account per server");
}
@ -72,10 +203,9 @@ impl Credentials {
]);
self.keyring
.create_item("Password", attrs, password, true)
.await
.map_err(|e| capnp::Error::failed(e.to_string()))?;
.await?;
self.creds.borrow_mut().insert(
self.creds.write().unwrap().insert(
server.to_string(),
Credential {
username: username.to_string(),
@ -87,7 +217,8 @@ impl Credentials {
pub async fn delete(&self, server: &str) -> anyhow::Result<()> {
let creds = {
self.creds
.borrow()
.read()
.unwrap()
.get(server)
.ok_or(anyhow::anyhow!("server creds not found"))?
.clone()
@ -97,12 +228,10 @@ impl Credentials {
("username", &creds.username),
("server", server),
]);
self.keyring
.delete(attrs)
.await
.map_err(|e| capnp::Error::failed(e.to_string()))?;
self.keyring.delete(attrs).await?;
self.creds
.borrow_mut()
.write()
.unwrap()
.remove(server)
.ok_or(anyhow::anyhow!("server creds not found"))?;
Ok(())

View File

@ -0,0 +1,324 @@
use anyhow::Result;
use async_trait::async_trait;
use reqwest::{header::HeaderMap, Client, Request, RequestBuilder, Response, ResponseBuilderExt};
use serde_json::{json, Value};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time;
use crate::models;
use crate::output_tracker::OutputTrackerAsync;
// Structure to store request information for verification
#[derive(Clone, Debug)]
pub struct RequestInfo {
pub url: String,
pub method: String,
pub headers: HeaderMap,
pub body: Option<Vec<u8>>,
}
impl RequestInfo {
fn from_request(request: &Request) -> Self {
RequestInfo {
url: request.url().to_string(),
method: request.method().to_string(),
headers: request.headers().clone(),
body: None, // Note: Request body can't be accessed after it's built
}
}
}
#[async_trait]
trait LightHttpClient: Send + Sync {
fn get(&self, url: &str) -> RequestBuilder;
fn post(&self, url: &str) -> RequestBuilder;
async fn execute(&self, request: Request) -> Result<Response>;
}
#[async_trait]
impl LightHttpClient for Client {
fn get(&self, url: &str) -> RequestBuilder {
self.get(url)
}
fn post(&self, url: &str) -> RequestBuilder {
self.post(url)
}
async fn execute(&self, request: Request) -> Result<Response> {
Ok(self.execute(request).await?)
}
}
#[derive(Clone)]
pub struct HttpClient {
client: Arc<dyn LightHttpClient>,
request_tracker: OutputTrackerAsync<RequestInfo>,
}
impl HttpClient {
pub fn new(client: reqwest::Client) -> Self {
Self {
client: Arc::new(client),
request_tracker: Default::default(),
}
}
pub fn new_nullable(client: NullableClient) -> Self {
Self {
client: Arc::new(client),
request_tracker: Default::default(),
}
}
pub async fn request_tracker(&self) -> OutputTrackerAsync<RequestInfo> {
self.request_tracker.enable().await;
self.request_tracker.clone()
}
pub fn get(&self, url: &str) -> RequestBuilder {
self.client.get(url)
}
pub fn post(&self, url: &str) -> RequestBuilder {
self.client.post(url)
}
pub async fn execute(&self, request: Request) -> Result<Response> {
self.request_tracker
.push(RequestInfo::from_request(&request))
.await;
Ok(self.client.execute(request).await?)
}
}
#[derive(Clone, Default)]
pub struct NullableClient {
responses: Arc<RwLock<HashMap<String, VecDeque<Response>>>>,
default_response: Arc<RwLock<Option<Box<dyn Fn() -> Response + Send + Sync + 'static>>>>,
}
/// Builder for configuring NullableClient
#[derive(Default)]
pub struct NullableClientBuilder {
responses: HashMap<String, VecDeque<Response>>,
default_response: Option<Box<dyn Fn() -> Response + Send + Sync + 'static>>,
}
impl NullableClientBuilder {
pub fn new() -> Self {
Self::default()
}
/// Add a single response for a specific URL
pub fn response(mut self, url: impl Into<String>, response: Response) -> Self {
self.responses
.entry(url.into())
.or_default()
.push_back(response);
self
}
/// Add multiple responses for a specific URL that will be returned in sequence
pub fn responses(mut self, url: impl Into<String>, responses: Vec<Response>) -> Self {
self.responses.insert(url.into(), responses.into());
self
}
/// Set a default response generator for any unmatched URLs
pub fn default_response(
mut self,
response: impl Fn() -> Response + Send + Sync + 'static,
) -> Self {
self.default_response = Some(Box::new(response));
self
}
/// Helper method to quickly add a JSON response
pub fn json_response(
self,
url: impl Into<String>,
status: u16,
body: impl serde::Serialize,
) -> Result<Self> {
let response = http::response::Builder::new()
.status(status)
.body(serde_json::to_string(&body)?)
.unwrap()
.into();
Ok(self.response(url, response))
}
/// Helper method to quickly add a text response
pub fn text_response(
self,
url: impl Into<String>,
status: u16,
body: impl Into<String>,
) -> Self {
let response = http::response::Builder::new()
.status(status)
.body(body.into())
.unwrap()
.into();
self.response(url, response)
}
pub fn build(self) -> NullableClient {
NullableClient {
responses: Arc::new(RwLock::new(
self.responses
.into_iter()
.map(|(k, v)| (k, v.into()))
.collect(),
)),
default_response: Arc::new(RwLock::new(self.default_response)),
}
}
}
impl NullableClient {
pub fn builder() -> NullableClientBuilder {
NullableClientBuilder::new()
}
}
#[async_trait]
impl LightHttpClient for NullableClient {
fn get(&self, url: &str) -> RequestBuilder {
Client::new().get(url)
}
fn post(&self, url: &str) -> RequestBuilder {
Client::new().post(url)
}
async fn execute(&self, request: Request) -> Result<Response> {
time::sleep(Duration::from_millis(1)).await;
let url = request.url().to_string();
let mut responses = self.responses.write().await;
if let Some(url_responses) = responses.get_mut(&url) {
if let Some(response) = url_responses.pop_front() {
// Remove the URL entry if no more responses
if url_responses.is_empty() {
responses.remove(&url);
}
Ok(response)
} else {
if let Some(default_fn) = &*self.default_response.read().await {
Ok(default_fn())
} else {
Err(anyhow::anyhow!("no response configured for URL: {}", url))
}
}
} else if let Some(default_fn) = &*self.default_response.read().await {
Ok(default_fn())
} else {
Err(anyhow::anyhow!("no response configured for URL: {}", url))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[tokio::test]
async fn test_nullable_with_builder() -> Result<()> {
// Configure client using builder pattern
let client = NullableClient::builder()
.text_response("https://api.example.com/topic", 200, "ok")
.json_response(
"https://api.example.com/json",
200,
json!({ "status": "success" }),
)?
.default_response(|| {
http::response::Builder::new()
.status(404)
.body("not found")
.unwrap()
.into()
})
.build();
let http_client = HttpClient::new_nullable(client);
let request_tracker = http_client.request_tracker().await;
// Test successful text response
let request = http_client.get("https://api.example.com/topic").build()?;
let response = http_client.execute(request).await?;
assert_eq!(response.status(), 200);
assert_eq!(response.text().await?, "ok");
// Test successful JSON response
let request = http_client.get("https://api.example.com/json").build()?;
let response = http_client.execute(request).await?;
assert_eq!(response.status(), 200);
assert_eq!(response.text().await?, r#"{"status":"success"}"#);
// Test default response
let request = http_client.get("https://api.example.com/unknown").build()?;
let response = http_client.execute(request).await?;
assert_eq!(response.status(), 404);
assert_eq!(response.text().await?, "not found");
// Verify recorded requests
let requests = request_tracker.items().await;
assert_eq!(requests.len(), 3);
Ok(())
}
#[tokio::test]
async fn test_sequence_of_responses() -> Result<()> {
// Configure client with multiple responses for the same URL
let client = NullableClient::builder()
.responses(
"https://api.example.com/sequence",
vec![
http::response::Builder::new()
.status(200)
.body("first")
.unwrap()
.into(),
http::response::Builder::new()
.status(200)
.body("second")
.unwrap()
.into(),
],
)
.build();
let http_client = HttpClient::new_nullable(client);
// First request gets first response
let request = http_client
.get("https://api.example.com/sequence")
.build()?;
let response = http_client.execute(request).await?;
assert_eq!(response.text().await?, "first");
// Second request gets second response
let request = http_client
.get("https://api.example.com/sequence")
.build()?;
let response = http_client.execute(request).await?;
assert_eq!(response.text().await?, "second");
// Third request fails (no more responses)
let request = http_client
.get("https://api.example.com/sequence")
.build()?;
let result = http_client.execute(request).await;
assert!(result.is_err());
Ok(())
}
}

View File

@ -1,21 +1,28 @@
mod actor_utils;
pub mod credentials;
mod http_client;
mod listener;
pub mod message_repo;
pub mod models;
mod ntfy;
mod output_tracker;
pub mod retry;
pub mod system_client;
pub mod topic_listener;
pub mod ntfy_capnp {
include!(concat!(env!("OUT_DIR"), "/src/ntfy_capnp.rs"));
}
mod subscription;
pub use listener::*;
pub use ntfy::start;
pub use ntfy::NtfyHandle;
use std::sync::Arc;
pub use subscription::SubscriptionHandle;
use http_client::HttpClient;
#[derive(Clone)]
pub struct SharedEnv {
db: message_repo::Db,
proxy: Arc<dyn models::NotificationProxy>,
http: reqwest::Client,
network: Arc<dyn models::NetworkMonitorProxy>,
notifier: Arc<dyn models::NotificationProxy>,
http_client: HttpClient,
network_monitor: Arc<dyn models::NetworkMonitorProxy>,
credentials: credentials::Credentials,
}
@ -25,6 +32,8 @@ pub enum Error {
InvalidTopic(String),
#[error("invalid server base url {0:?}")]
InvalidServer(#[from] url::ParseError),
#[error("multiple errors in subscription model: {0:?}")]
InvalidSubscription(Vec<Error>),
#[error("duplicate message")]
DuplicateMessage,
#[error("can't parse the minimum set of required fields from the message {0}")]
@ -36,9 +45,3 @@ pub enum Error {
#[error("subscription not found while {0}")]
SubscriptionNotFound(String),
}
impl From<Error> for capnp::Error {
fn from(value: Error) -> Self {
capnp::Error::failed(format!("{:?}", value))
}
}

393
ntfy-daemon/src/listener.rs Normal file
View File

@ -0,0 +1,393 @@
use std::sync::Arc;
use std::time::Duration;
use futures::{StreamExt, TryStreamExt};
use serde::{Deserialize, Serialize};
use tokio::io::AsyncBufReadExt;
use tokio::task::{self, spawn_local, LocalSet};
use tokio::{
select,
sync::{mpsc, oneshot},
};
use tokio_stream::wrappers::LinesStream;
use tracing::{debug, error, info, warn, Instrument, Span};
use crate::credentials::Credentials;
use crate::http_client::HttpClient;
use crate::{models, Error};
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum ServerEvent {
#[serde(rename = "open")]
Open {
id: String,
time: usize,
expires: Option<usize>,
topic: String,
},
#[serde(rename = "message")]
Message(models::ReceivedMessage),
#[serde(rename = "keepalive")]
KeepAlive {
id: String,
time: usize,
expires: Option<usize>,
topic: String,
},
}
#[derive(Debug, Clone)]
pub enum ListenerEvent {
Message(models::ReceivedMessage),
ConnectionStateChanged(ConnectionState),
}
#[derive(Clone)]
pub struct ListenerConfig {
pub(crate) http_client: HttpClient,
pub(crate) credentials: Credentials,
pub(crate) endpoint: String,
pub(crate) topic: String,
pub(crate) since: u64,
}
#[derive(Debug)]
pub enum ListenerCommand {
Restart,
Shutdown,
GetState(oneshot::Sender<ConnectionState>),
}
fn topic_request(
client: &HttpClient,
endpoint: &str,
topic: &str,
since: u64,
username: Option<&str>,
password: Option<&str>,
) -> anyhow::Result<reqwest::Request> {
let url = models::Subscription::build_url(endpoint, topic, since)?;
let mut req = client
.get(url.as_str())
.header("Content-Type", "application/x-ndjson")
.header("Transfer-Encoding", "chunked");
if let Some(username) = username {
req = req.basic_auth(username, password);
}
Ok(req.build()?)
}
async fn response_lines(
res: impl tokio::io::AsyncBufRead,
) -> Result<impl futures::Stream<Item = Result<String, std::io::Error>>, reqwest::Error> {
let lines = LinesStream::new(res.lines());
Ok(lines)
}
#[derive(Clone, Debug)]
pub enum ConnectionState {
Unitialized,
Connected,
Reconnecting {
retry_count: u64,
delay: Duration,
error: Option<Arc<anyhow::Error>>,
},
}
pub struct ListenerActor {
pub event_tx: async_channel::Sender<ListenerEvent>,
pub commands_rx: Option<mpsc::Receiver<ListenerCommand>>,
pub config: ListenerConfig,
pub state: ConnectionState,
}
impl ListenerActor {
pub async fn run_loop(mut self) {
let span = tracing::info_span!("listener_loop", topic = %self.config.topic);
async {
let mut commands_rx = self.commands_rx.take().unwrap();
loop {
select! {
_ = self.run_supervised_loop() => {
info!("supervised loop ended");
break;
},
cmd = commands_rx.recv() => {
match cmd {
Some(ListenerCommand::Restart) => {
info!("restarting listener");
continue;
}
Some(ListenerCommand::Shutdown) => {
info!("shutting down listener");
break;
}
Some(ListenerCommand::GetState(tx)) => {
debug!("getting listener state");
let state = self.state.clone();
if tx.send(state).is_err() {
warn!("failed to send state - receiver dropped");
}
}
None => {
error!("command channel closed");
break;
}
}
}
}
}
}
.instrument(span)
.await;
}
async fn set_state(&mut self, state: ConnectionState) {
self.state = state.clone();
self.event_tx
.send(ListenerEvent::ConnectionStateChanged(state))
.await
.unwrap();
}
async fn run_supervised_loop(&mut self) {
let span = tracing::info_span!("supervised_loop");
async {
let retrier = || {
crate::retry::WaitExponentialRandom::builder()
.min(Duration::from_secs(1))
.max(Duration::from_secs(5 * 60))
.build()
};
let mut retry = retrier();
loop {
let start_time = std::time::Instant::now();
if let Err(e) = self.recv_and_forward_loop().await {
let uptime = std::time::Instant::now().duration_since(start_time);
// Reset retry delay to minimum if uptime was decent enough
if uptime > Duration::from_secs(60 * 4) {
debug!("resetting retry delay due to sufficient uptime");
retry = retrier();
}
error!(error = ?e, "connection error");
self.set_state(ConnectionState::Reconnecting {
retry_count: retry.count(),
delay: retry.next_delay(),
error: Some(Arc::new(e)),
})
.await;
info!(delay = ?retry.next_delay(), "waiting before reconnect attempt");
retry.wait().await;
} else {
break;
}
}
}
.instrument(span)
.await;
}
async fn recv_and_forward_loop(&mut self) -> anyhow::Result<()> {
let span = tracing::info_span!("receive_loop",
endpoint = %self.config.endpoint,
topic = %self.config.topic,
since = %self.config.since
);
async {
let creds = self.config.credentials.get(&self.config.endpoint);
debug!("creating request");
let req = topic_request(
&self.config.http_client,
&self.config.endpoint,
&self.config.topic,
self.config.since,
creds.as_ref().map(|x| x.username.as_str()),
creds.as_ref().map(|x| x.password.as_str()),
);
debug!("executing request");
let res = self.config.http_client.execute(req?).await?;
let res = res.error_for_status()?;
let reader = tokio_util::io::StreamReader::new(
res.bytes_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
);
let stream = response_lines(reader).await?;
tokio::pin!(stream);
self.set_state(ConnectionState::Connected).await;
info!("connection established");
info!(topic = %&self.config.topic, "listening");
while let Some(msg) = stream.next().await {
let msg = msg?;
let min_msg = serde_json::from_str::<models::MinMessage>(&msg)
.map_err(|e| Error::InvalidMinMessage(msg.to_string(), e))?;
self.config.since = min_msg.time.max(self.config.since);
let event = serde_json::from_str(&msg)
.map_err(|e| Error::InvalidMessage(msg.to_string(), e))?;
match event {
ServerEvent::Message(msg) => {
debug!(id = %msg.id, "forwarding message");
self.event_tx
.send(ListenerEvent::Message(msg))
.await
.unwrap();
}
ServerEvent::KeepAlive { id, .. } => {
debug!(id = %id, "received keepalive");
}
ServerEvent::Open { id, .. } => {
debug!(id = %id, "received open event");
}
}
}
Ok(())
}
.instrument(span)
.await
}
}
// Reliable listener implementation
#[derive(Clone)]
pub struct ListenerHandle {
pub events: async_channel::Receiver<ListenerEvent>,
pub config: ListenerConfig,
pub commands: mpsc::Sender<ListenerCommand>,
}
impl ListenerHandle {
pub fn new(config: ListenerConfig) -> ListenerHandle {
let (event_tx, event_rx) = async_channel::bounded(64);
let (commands_tx, commands_rx) = mpsc::channel(1);
let config_clone = config.clone();
// use a new local set to isolate panics
let local_set = LocalSet::new();
local_set.spawn_local(async move {
let this = ListenerActor {
event_tx,
commands_rx: Some(commands_rx),
config: config_clone,
state: ConnectionState::Unitialized,
};
this.run_loop().await;
});
spawn_local(local_set);
Self {
events: event_rx,
config,
commands: commands_tx,
}
}
// the response will be sent as an event in self.events
pub async fn state(&self) -> ConnectionState {
let (tx, rx) = oneshot::channel();
self.commands
.send(ListenerCommand::GetState(tx))
.await
.unwrap();
rx.await.unwrap()
}
}
#[cfg(test)]
mod tests {
use models::Subscription;
use serde_json::json;
use task::LocalSet;
use crate::http_client::NullableClient;
use super::*;
#[tokio::test]
async fn test_listener_reconnects_on_http_status_500() {
let local_set = LocalSet::new();
local_set
.spawn_local(async {
let http_client = HttpClient::new_nullable({
let url = Subscription::build_url("http://localhost", "test", 0).unwrap();
let nullable = NullableClient::builder()
.text_response(url.clone(), 500, "failed")
.json_response(url, 200, json!({"id":"SLiKI64DOt","time":1635528757,"event":"open","topic":"mytopic"})).unwrap()
.build();
nullable
});
let credentials = Credentials::new_nullable(vec![]).await.unwrap();
let config = ListenerConfig {
http_client,
credentials,
endpoint: "http://localhost".to_string(),
topic: "test".to_string(),
since: 0,
};
let listener = ListenerHandle::new(config.clone());
let items: Vec<_> = listener.events.take(3).collect().await;
dbg!(&items);
assert!(matches!(
&items[..],
&[
ListenerEvent::ConnectionStateChanged(ConnectionState::Unitialized),
ListenerEvent::ConnectionStateChanged(ConnectionState::Reconnecting { .. }),
ListenerEvent::ConnectionStateChanged(ConnectionState::Connected { .. }),
]
));
});
local_set.await;
}
#[tokio::test]
async fn test_listener_reconnects_on_invalid_message() {
let local_set = LocalSet::new();
local_set
.spawn_local(async {
let http_client = HttpClient::new_nullable({
let url = Subscription::build_url("http://localhost", "test", 0).unwrap();
let nullable = NullableClient::builder()
.text_response(url.clone(), 200, "invalid message")
.json_response(url, 200, json!({"id":"SLiKI64DOt","time":1635528757,"event":"open","topic":"mytopic"})).unwrap()
.build();
nullable
});
let credentials = Credentials::new_nullable(vec![]).await.unwrap();
let config = ListenerConfig {
http_client,
credentials,
endpoint: "http://localhost".to_string(),
topic: "test".to_string(),
since: 0,
};
let listener = ListenerHandle::new(config.clone());
let items: Vec<_> = listener.events.take(3).collect().await;
dbg!(&items);
assert!(matches!(
&items[..],
&[
ListenerEvent::ConnectionStateChanged(ConnectionState::Unitialized),
ListenerEvent::ConnectionStateChanged(ConnectionState::Reconnecting { .. }),
ListenerEvent::ConnectionStateChanged(ConnectionState::Connected { .. }),
]
));
});
local_set.await;
}
}

View File

@ -1,3 +1,4 @@
use std::sync::{Arc, RwLock};
use std::{cell::RefCell, rc::Rc};
use rusqlite::{params, Connection, Result};
@ -8,16 +9,16 @@ use crate::Error;
#[derive(Clone, Debug)]
pub struct Db {
conn: Rc<RefCell<Connection>>,
conn: Arc<RwLock<Connection>>,
}
impl Db {
pub fn connect(path: &str) -> Result<Self> {
let mut this = Self {
conn: Rc::new(RefCell::new(Connection::open(path)?)),
conn: Arc::new(RwLock::new(Connection::open(path)?)),
};
{
this.conn.borrow().execute_batch(
this.conn.read().unwrap().execute_batch(
"PRAGMA foreign_keys = ON;
PRAGMA journal_mode = wal;",
)?;
@ -27,12 +28,13 @@ impl Db {
}
fn migrate(&mut self) -> Result<()> {
self.conn
.borrow()
.read()
.unwrap()
.execute_batch(include_str!("./migrations/00.sql"))?;
Ok(())
}
fn get_or_insert_server(&mut self, server: &str) -> Result<i64> {
let mut conn = self.conn.borrow_mut();
let mut conn = self.conn.write().unwrap();
let tx = conn.transaction()?;
let mut res = tx.query_row(
"SELECT id
@ -56,7 +58,7 @@ impl Db {
}
pub fn insert_message(&mut self, server: &str, json_data: &str) -> Result<(), Error> {
let server_id = self.get_or_insert_server(server)?;
let res = self.conn.borrow().execute(
let res = self.conn.read().unwrap().execute(
"INSERT INTO message (server, data) VALUES (?1, ?2)",
params![server_id, json_data],
);
@ -76,7 +78,7 @@ impl Db {
topic: &str,
since: u64,
) -> Result<Vec<String>, rusqlite::Error> {
let conn = self.conn.borrow();
let conn = self.conn.read().unwrap();
let mut stmt = conn.prepare(
"
SELECT data
@ -94,7 +96,7 @@ impl Db {
}
pub fn insert_subscription(&mut self, sub: models::Subscription) -> Result<(), Error> {
let server_id = self.get_or_insert_server(&sub.server)?;
self.conn.borrow().execute(
self.conn.read().unwrap().execute(
"INSERT INTO subscription (server, topic, display_name, reserved, muted, archived) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
server_id,
@ -109,7 +111,7 @@ impl Db {
}
pub fn remove_subscription(&mut self, server: &str, topic: &str) -> Result<(), Error> {
let server_id = self.get_or_insert_server(server)?;
let res = self.conn.borrow().execute(
let res = self.conn.read().unwrap().execute(
"DELETE FROM subscription
WHERE server = ?1 AND topic = ?2",
params![server_id, topic],
@ -120,7 +122,7 @@ impl Db {
Ok(())
}
pub fn list_subscriptions(&mut self) -> Result<Vec<models::Subscription>, Error> {
let conn = self.conn.borrow();
let conn = self.conn.read().unwrap();
let mut stmt = conn.prepare(
"SELECT server.endpoint, sub.topic, sub.display_name, sub.reserved, sub.muted, sub.archived, sub.symbolic_icon, sub.read_until
FROM subscription sub
@ -146,7 +148,7 @@ impl Db {
pub fn update_subscription(&mut self, sub: models::Subscription) -> Result<(), Error> {
let server_id = self.get_or_insert_server(&sub.server)?;
let res = self.conn.borrow().execute(
let res = self.conn.read().unwrap().execute(
"UPDATE subscription
SET display_name = ?1, reserved = ?2, muted = ?3, archived = ?4, read_until = ?5
WHERE server = ?6 AND topic = ?7",
@ -174,7 +176,7 @@ impl Db {
value: u64,
) -> Result<(), Error> {
let server_id = self.get_or_insert_server(server).unwrap();
let conn = self.conn.borrow();
let conn = self.conn.read().unwrap();
let res = conn.execute(
"UPDATE subscription
SET read_until = ?3
@ -189,7 +191,7 @@ impl Db {
}
pub fn delete_messages(&mut self, server: &str, topic: &str) -> Result<(), Error> {
let server_id = self.get_or_insert_server(server).unwrap();
let conn = self.conn.borrow();
let conn = self.conn.read().unwrap();
let res = conn.execute(
"DELETE FROM message
WHERE topic = ?2 AND server = ?1

View File

@ -27,8 +27,10 @@ pub fn validate_topic(topic: &str) -> Result<&str, Error> {
}
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct Message {
pub struct ReceivedMessage {
pub id: String,
pub topic: String,
pub expires: Option<u64>,
pub message: Option<String>,
#[serde(default = "Default::default")]
pub time: u64,
@ -57,7 +59,7 @@ pub struct Message {
pub actions: Vec<Action>,
}
impl Message {
impl ReceivedMessage {
fn extend_with_emojis(&self, text: &mut String) {
// Add emojis
for t in &self.tags {
@ -105,6 +107,37 @@ impl Message {
}
}
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct OutgoingMessage {
pub topic: String,
pub message: Option<String>,
#[serde(default = "Default::default")]
pub time: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Vec::is_empty")]
pub tags: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i8>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(default)]
pub attachment: Option<Attachment>,
#[serde(skip_serializing_if = "Option::is_none")]
pub icon: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub filename: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub delay: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub email: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub call: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Vec::is_empty")]
pub actions: Vec<Action>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MinMessage {
pub id: String,
@ -165,7 +198,7 @@ impl Subscription {
.push("auth");
Ok(url)
}
pub fn validate(self) -> Result<Self, Vec<crate::Error>> {
pub fn validate(self) -> Result<Self, crate::Error> {
let mut errs = vec![];
if let Err(e) = validate_topic(&self.topic) {
errs.push(e);
@ -174,7 +207,7 @@ impl Subscription {
errs.push(e);
};
if !errs.is_empty() {
return Err(errs);
return Err(Error::InvalidSubscription(errs));
}
Ok(self)
}
@ -237,7 +270,7 @@ impl SubscriptionBuilder {
self
}
pub fn build(self) -> Result<Subscription, Vec<Error>> {
pub fn build(self) -> Result<Subscription, Error> {
let res = Subscription {
server: self.server,
topic: self.topic,
@ -318,6 +351,12 @@ impl From<Status> for u8 {
}
}
#[derive(Clone, Debug)]
pub struct Account {
pub server: String,
pub username: String,
}
pub struct Notification {
pub title: String,
pub body: String,
@ -331,3 +370,30 @@ pub trait NotificationProxy: Sync + Send {
pub trait NetworkMonitorProxy: Sync + Send {
fn listen(&self) -> Pin<Box<dyn Stream<Item = ()>>>;
}
pub struct NullNotifier {}
impl NullNotifier {
pub fn new() -> Self {
Self {}
}
}
impl NotificationProxy for NullNotifier {
fn send(&self, n: Notification) -> anyhow::Result<()> {
Ok(())
}
}
pub struct NullNetworkMonitor {}
impl NullNetworkMonitor {
pub fn new() -> Self {
Self {}
}
}
impl NetworkMonitorProxy for NullNetworkMonitor {
fn listen(&self) -> Pin<Box<dyn Stream<Item = ()>>> {
Box::pin(futures::stream::empty())
}
}

View File

@ -1,49 +0,0 @@
@0x9663f4dd604afa35;
enum Status {
down @0;
degraded @1;
up @2;
}
interface WatchHandle {}
interface OutputChannel {
sendMessage @0 (message: Text);
sendStatus @1 (status: Status);
done @2 ();
}
struct SubscriptionInfo {
server @0 :Text;
topic @1 :Text;
displayName @2 :Text;
muted @3 :Bool;
readUntil @4 :UInt64;
}
interface Subscription {
watch @0 (watcher: OutputChannel, since: UInt64) -> (handle: WatchHandle);
publish @1 (message: Text);
getInfo @2 () -> SubscriptionInfo;
updateInfo @3 (value: SubscriptionInfo);
updateReadUntil @4 (value: UInt64);
clearNotifications @5 ();
refresh @6 ();
}
struct Account {
server @0 :Text;
username @1 :Text;
}
interface SystemNotifier {
subscribe @0 (server: Text, topic: Text) -> (subscription: Subscription);
unsubscribe @1 (server: Text, topic: Text);
listSubscriptions @2 () -> (list: List(Subscription));
addAccount @3 (account: Account, password: Text);
removeAccount @4 (account: Account);
listAccounts @5 () -> (list: List(Account));
}

450
ntfy-daemon/src/ntfy.rs Normal file
View File

@ -0,0 +1,450 @@
use crate::actor_utils::send_command;
use crate::models::NullNetworkMonitor;
use crate::models::NullNotifier;
use anyhow::{anyhow, Context};
use futures::future::join_all;
use futures::StreamExt;
use std::{collections::HashMap, future::Future, sync::Arc};
use tokio::select;
use tokio::{
sync::{broadcast, mpsc, oneshot, RwLock},
task::{spawn_local, LocalSet},
};
use tracing::{error, info};
use crate::{
http_client::HttpClient,
message_repo::Db,
models::{self, Account},
ListenerActor, ListenerCommand, ListenerConfig, ListenerHandle, SharedEnv, SubscriptionHandle,
};
const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);
const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(240); // 4 minutes
pub fn build_client() -> anyhow::Result<reqwest::Client> {
Ok(reqwest::Client::builder()
.connect_timeout(CONNECT_TIMEOUT)
.pool_idle_timeout(TIMEOUT)
// rustls is used because HTTP 2 isn't discovered with native-tls.
// HTTP 2 is required to multiplex multiple requests over a single connection.
// You can check that the app is using a single connection to a server by doing
// ```
// ping ntfy.sh # to get the ip address
// netstat | grep $ip
// ```
.use_rustls_tls()
.build()?)
}
// Message types for the actor
#[derive()]
pub enum NtfyCommand {
Subscribe {
server: String,
topic: String,
resp_tx: oneshot::Sender<Result<SubscriptionHandle, anyhow::Error>>,
},
Unsubscribe {
server: String,
topic: String,
resp_tx: oneshot::Sender<anyhow::Result<()>>,
},
RefreshAll {
resp_tx: oneshot::Sender<anyhow::Result<()>>,
},
ListSubscriptions {
resp_tx: oneshot::Sender<anyhow::Result<Vec<SubscriptionHandle>>>,
},
ListAccounts {
resp_tx: oneshot::Sender<anyhow::Result<Vec<Account>>>,
},
WatchSubscribed {
resp_tx: oneshot::Sender<anyhow::Result<()>>,
},
AddAccount {
server: String,
username: String,
password: String,
resp_tx: oneshot::Sender<anyhow::Result<()>>,
},
RemoveAccount {
server: String,
resp_tx: oneshot::Sender<anyhow::Result<()>>,
},
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct WatchKey {
server: String,
topic: String,
}
pub struct NtfyActor {
listener_handles: Arc<RwLock<HashMap<WatchKey, SubscriptionHandle>>>,
env: SharedEnv,
command_rx: mpsc::Receiver<NtfyCommand>,
}
#[derive(Clone)]
pub struct NtfyHandle {
command_tx: mpsc::Sender<NtfyCommand>,
}
impl NtfyActor {
pub fn new(env: SharedEnv) -> (Self, NtfyHandle) {
let (command_tx, command_rx) = mpsc::channel(32);
let actor = Self {
listener_handles: Default::default(),
env,
command_rx,
};
let handle = NtfyHandle { command_tx };
(actor, handle)
}
async fn handle_subscribe(
&self,
server: String,
topic: String,
) -> Result<SubscriptionHandle, anyhow::Error> {
let subscription = models::Subscription::builder(topic.clone())
.server(server.clone())
.build()?;
let mut db = self.env.db.clone();
db.insert_subscription(subscription.clone())?;
self.listen(subscription).await
}
async fn handle_unsubscribe(&mut self, server: String, topic: String) -> anyhow::Result<()> {
let subscription = self.listener_handles.write().await.remove(&WatchKey {
server: server.clone(),
topic: topic.clone(),
});
if let Some(sub) = subscription {
sub.shutdown().await?;
}
self.env.db.remove_subscription(&server, &topic)?;
info!(server, topic, "Unsubscribed");
Ok(())
}
pub async fn run(&mut self) {
let mut network_change_stream = self.env.network_monitor.listen();
loop {
select! {
Some(_) = network_change_stream.next() => {
let _ = self.refresh_all().await;
},
Some(command) = self.command_rx.recv() => self.handle_command(command).await,
};
}
}
async fn handle_command(&mut self, command: NtfyCommand) {
match command {
NtfyCommand::Subscribe {
server,
topic,
resp_tx,
} => {
let result = self.handle_subscribe(server, topic).await;
let _ = resp_tx.send(result);
}
NtfyCommand::Unsubscribe {
server,
topic,
resp_tx,
} => {
let result = self.handle_unsubscribe(server, topic).await;
let _ = resp_tx.send(result);
}
NtfyCommand::RefreshAll { resp_tx } => {
let res = self.refresh_all().await;
let _ = resp_tx.send(res);
}
NtfyCommand::ListSubscriptions { resp_tx } => {
let subs = self
.listener_handles
.read()
.await
.values()
.cloned()
.collect();
let _ = resp_tx.send(Ok(subs));
}
NtfyCommand::ListAccounts { resp_tx } => {
let accounts = self
.env
.credentials
.list_all()
.into_iter()
.map(|(server, credential)| Account {
server,
username: credential.username,
})
.collect();
let _ = resp_tx.send(Ok(accounts));
}
NtfyCommand::WatchSubscribed { resp_tx } => {
let result = self.handle_watch_subscribed().await;
let _ = resp_tx.send(result);
}
NtfyCommand::AddAccount {
server,
username,
password,
resp_tx,
} => {
let result = self
.env
.credentials
.insert(&server, &username, &password)
.await;
let _ = resp_tx.send(result);
}
NtfyCommand::RemoveAccount { server, resp_tx } => {
let result = self.env.credentials.delete(&server).await;
let _ = resp_tx.send(result);
}
}
}
async fn handle_watch_subscribed(&mut self) -> anyhow::Result<()> {
let f: Vec<_> = self
.env
.db
.list_subscriptions()?
.into_iter()
.map(|m| self.listen(m))
.collect();
join_all(f.into_iter().map(|x| async move {
if let Err(e) = x.await {
error!(error = ?e, "Can't rewatch subscribed topic");
}
}))
.await;
Ok(())
}
fn listen(
&self,
sub: models::Subscription,
) -> impl Future<Output = anyhow::Result<SubscriptionHandle>> {
let server = sub.server.clone();
let topic = sub.topic.clone();
let listener = ListenerHandle::new(ListenerConfig {
http_client: self.env.http_client.clone(),
credentials: self.env.credentials.clone(),
endpoint: server.clone(),
topic: topic.clone(),
since: sub.read_until,
});
let listener_handles = self.listener_handles.clone();
let sub = SubscriptionHandle::new(listener.clone(), sub, &self.env);
async move {
listener_handles
.write()
.await
.insert(WatchKey { server, topic }, sub.clone());
Ok(sub)
}
}
async fn refresh_all(&self) -> anyhow::Result<()> {
let mut res = Ok(());
for sub in self.listener_handles.read().await.values() {
res = sub.restart().await;
if res.is_err() {
break;
}
}
res
}
}
impl NtfyHandle {
pub async fn subscribe(
&self,
server: &str,
topic: &str,
) -> Result<SubscriptionHandle, anyhow::Error> {
send_command!(self, |resp_tx| NtfyCommand::Subscribe {
server: server.to_string(),
topic: topic.to_string(),
resp_tx,
})
}
pub async fn unsubscribe(&self, server: &str, topic: &str) -> anyhow::Result<()> {
send_command!(self, |resp_tx| NtfyCommand::Unsubscribe {
server: server.to_string(),
topic: topic.to_string(),
resp_tx,
})
}
pub async fn refresh_all(&self) -> anyhow::Result<()> {
send_command!(self, |resp_tx| NtfyCommand::RefreshAll { resp_tx })
}
pub async fn list_subscriptions(&self) -> anyhow::Result<Vec<SubscriptionHandle>> {
send_command!(self, |resp_tx| NtfyCommand::ListSubscriptions { resp_tx })
}
pub async fn list_accounts(&self) -> anyhow::Result<Vec<Account>> {
send_command!(self, |resp_tx| NtfyCommand::ListAccounts { resp_tx })
}
pub async fn watch_subscribed(&self) -> anyhow::Result<()> {
send_command!(self, |resp_tx| NtfyCommand::WatchSubscribed { resp_tx })
}
pub async fn add_account(
&self,
server: &str,
username: &str,
password: &str,
) -> anyhow::Result<()> {
send_command!(self, |resp_tx| NtfyCommand::AddAccount {
server: server.to_string(),
username: username.to_string(),
password: password.to_string(),
resp_tx,
})
}
pub async fn remove_account(&self, server: &str) -> anyhow::Result<()> {
send_command!(self, |resp_tx| NtfyCommand::RemoveAccount {
server: server.to_string(),
resp_tx,
})
}
}
pub fn start(
dbpath: &str,
notification_proxy: Arc<dyn models::NotificationProxy>,
network_proxy: Arc<dyn models::NetworkMonitorProxy>,
) -> anyhow::Result<NtfyHandle> {
let dbpath = dbpath.to_owned();
// Create a channel to receive the handle from the spawned thread
let (handle_tx, handle_rx) = oneshot::channel();
std::thread::spawn(move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
// Create everything inside the new thread's runtime
let credentials =
rt.block_on(async move { crate::credentials::Credentials::new().await.unwrap() });
let env = SharedEnv {
db: Db::connect(&dbpath).unwrap(),
notifier: notification_proxy,
http_client: HttpClient::new(build_client().unwrap()),
network_monitor: network_proxy,
credentials,
};
let (mut actor, handle) = NtfyActor::new(env);
let handle_clone = handle.clone();
// Send the handle back to the calling thread
handle_tx.send(handle.clone());
rt.block_on({
let local_set = LocalSet::new();
// Spawn the watch_subscribed task
local_set.spawn_local(async move {
if let Err(e) = handle_clone.watch_subscribed().await {
error!(error = ?e, "Failed to watch subscribed topics");
}
});
// Run the actor
local_set.spawn_local(async move {
actor.run().await;
});
local_set
})
});
// Wait for the handle from the spawned thread
Ok(handle_rx
.blocking_recv()
.map_err(|_| anyhow!("Failed to receive actor handle"))?)
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use models::{OutgoingMessage, ReceivedMessage};
use tokio::time::sleep;
use crate::ListenerEvent;
use super::*;
#[test]
fn test_subscribe_and_publish() {
let notification_proxy = Arc::new(NullNotifier::new());
let network_proxy = Arc::new(NullNetworkMonitor::new());
let dbpath = ":memory:";
let handle = start(dbpath, notification_proxy, network_proxy).unwrap();
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
rt.block_on(async move {
let server = "http://localhost:8000";
let topic = "test_topic";
// Subscribe to the topic
let subscription_handle = handle.subscribe(server, topic).await.unwrap();
// Publish a message
let message = serde_json::to_string(&OutgoingMessage {
topic: topic.to_string(),
..Default::default()
})
.unwrap();
let result = subscription_handle.publish(message).await;
assert!(result.is_ok());
sleep(Duration::from_millis(250)).await;
// Attach to the subscription and check if the message is received and stored
let (events, receiver) = subscription_handle.attach().await;
dbg!(&events);
assert!(events.iter().any(|event| match event {
ListenerEvent::Message(msg) => msg.topic == topic,
_ => false,
}));
});
}
}

View File

@ -0,0 +1,71 @@
use std::{cell::RefCell, rc::Rc, sync::Arc};
use tokio::sync::RwLock;
#[derive(Clone)]
pub struct OutputTracker<T> {
store: Rc<RefCell<Option<Vec<T>>>>,
}
impl<T> Default for OutputTracker<T> {
fn default() -> Self {
Self {
store: Default::default(),
}
}
}
impl<T: Clone> OutputTracker<T> {
pub fn enable(&self) {
let mut inner = self.store.borrow_mut();
if inner.is_none() {
*inner = Some(vec![]);
}
}
pub fn push(&self, item: T) {
if let Some(v) = &mut *self.store.borrow_mut() {
v.push(item);
}
}
pub fn items(&self) -> Vec<T> {
if let Some(v) = &*self.store.borrow() {
v.clone()
} else {
vec![]
}
}
}
#[derive(Clone)]
pub struct OutputTrackerAsync<T> {
store: Arc<RwLock<Option<Vec<T>>>>,
}
impl<T> Default for OutputTrackerAsync<T> {
fn default() -> Self {
Self {
store: Default::default(),
}
}
}
impl<T: Clone> OutputTrackerAsync<T> {
pub async fn enable(&self) {
let mut inner = self.store.write().await;
if inner.is_none() {
*inner = Some(vec![]);
}
}
pub async fn push(&self, item: T) {
if let Some(v) = &mut *self.store.write().await {
v.push(item);
}
}
pub async fn items(&self) -> Vec<T> {
if let Some(v) = &*self.store.read().await {
v.clone()
} else {
vec![]
}
}
}

View File

@ -53,4 +53,8 @@ impl WaitExponentialRandom {
sleep(self.next_delay()).await;
self.i += 1;
}
pub fn count(&self) -> u64 {
self.i
}
}

View File

@ -0,0 +1,276 @@
use crate::listener::{ListenerEvent, ListenerHandle};
use crate::models::{self, ReceivedMessage};
use crate::{Error, SharedEnv};
use tokio::select;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::task::spawn_local;
use tracing::{debug, error, info, trace, warn};
#[derive(Debug)]
enum SubscriptionCommand {
GetModel {
resp_tx: oneshot::Sender<models::Subscription>,
},
UpdateInfo {
new_model: models::Subscription,
resp_tx: oneshot::Sender<anyhow::Result<()>>,
},
Attach {
resp_tx: oneshot::Sender<(Vec<ListenerEvent>, broadcast::Receiver<ListenerEvent>)>,
},
Publish {
msg: String,
resp_tx: oneshot::Sender<anyhow::Result<()>>,
},
ClearNotifications {
resp_tx: oneshot::Sender<anyhow::Result<()>>,
},
UpdateReadUntil {
timestamp: u64,
resp_tx: oneshot::Sender<anyhow::Result<()>>,
},
}
#[derive(Clone)]
pub struct SubscriptionHandle {
command_tx: mpsc::Sender<SubscriptionCommand>,
listener: ListenerHandle,
}
impl SubscriptionHandle {
pub fn new(listener: ListenerHandle, model: models::Subscription, env: &SharedEnv) -> Self {
let (command_tx, command_rx) = mpsc::channel(32);
let broadcast_tx = broadcast::channel(8).0;
let actor = SubscriptionActor {
listener: listener.clone(),
model,
command_rx,
env: env.clone(),
broadcast_tx: broadcast_tx.clone(),
};
spawn_local(actor.run());
Self {
command_tx,
listener,
}
}
pub async fn model(&self) -> models::Subscription {
let (resp_tx, resp_rx) = oneshot::channel();
self.command_tx
.send(SubscriptionCommand::GetModel { resp_tx })
.await
.unwrap();
resp_rx.await.unwrap()
}
pub async fn update_info(&self, new_model: models::Subscription) -> anyhow::Result<()> {
let (resp_tx, resp_rx) = oneshot::channel();
self.command_tx
.send(SubscriptionCommand::UpdateInfo { new_model, resp_tx })
.await?;
resp_rx.await.unwrap()
}
pub async fn restart(&self) -> anyhow::Result<()> {
self.listener
.commands
.send(crate::ListenerCommand::Restart)
.await?;
Ok(())
}
pub async fn shutdown(&self) -> anyhow::Result<()> {
self.listener
.commands
.send(crate::ListenerCommand::Shutdown)
.await?;
Ok(())
}
// returns a vector containing all the past messages stored in the database and the current connection state.
// The first vector is useful to get a summary of what happened before.
// The `ListenerHandle` is returned to receive new events.
pub async fn attach(&self) -> (Vec<ListenerEvent>, broadcast::Receiver<ListenerEvent>) {
let (resp_tx, resp_rx) = oneshot::channel();
self.command_tx
.send(SubscriptionCommand::Attach { resp_tx })
.await
.unwrap();
resp_rx.await.unwrap()
}
pub async fn publish(&self, msg: String) -> anyhow::Result<()> {
let (resp_tx, resp_rx) = oneshot::channel();
self.command_tx
.send(SubscriptionCommand::Publish { msg, resp_tx })
.await
.unwrap();
resp_rx.await.unwrap()
}
pub async fn clear_notifications(&self) -> anyhow::Result<()> {
let (resp_tx, resp_rx) = oneshot::channel();
self.command_tx
.send(SubscriptionCommand::ClearNotifications { resp_tx })
.await
.unwrap();
resp_rx.await.unwrap()
}
pub async fn update_read_until(&self, timestamp: u64) -> anyhow::Result<()> {
let (resp_tx, resp_rx) = oneshot::channel();
self.command_tx
.send(SubscriptionCommand::UpdateReadUntil { timestamp, resp_tx })
.await
.unwrap();
resp_rx.await.unwrap()
}
}
struct SubscriptionActor {
listener: ListenerHandle,
model: models::Subscription,
command_rx: mpsc::Receiver<SubscriptionCommand>,
env: SharedEnv,
broadcast_tx: broadcast::Sender<ListenerEvent>,
}
impl SubscriptionActor {
async fn run(mut self) {
loop {
select! {
Ok(event) = self.listener.events.recv() => {
debug!(?event, "received listener event");
match event {
ListenerEvent::Message(msg) => self.handle_msg_event(msg),
other => {
let _ = self.broadcast_tx.send(other);
}
}
}
Some(command) = self.command_rx.recv() => {
trace!(?command, "processing subscription command");
match command {
SubscriptionCommand::GetModel { resp_tx } => {
debug!("getting subscription model");
let _ = resp_tx.send(self.model.clone());
}
SubscriptionCommand::UpdateInfo {
mut new_model,
resp_tx,
} => {
debug!(server=?new_model.server, topic=?new_model.topic, "updating subscription info");
new_model.server = self.model.server.clone();
new_model.topic = self.model.topic.clone();
let res = self.env.db.update_subscription(new_model.clone());
if let Ok(_) = res {
self.model = new_model;
}
let _ = resp_tx.send(res.map_err(|e| e.into()));
}
SubscriptionCommand::Publish {msg, resp_tx} => {
debug!(topic=?self.model.topic, "publishing message");
let _ = resp_tx.send(self.publish(msg).await);
}
SubscriptionCommand::Attach { resp_tx } => {
debug!(topic=?self.model.topic, "attaching new listener");
let messages = self
.env
.db
.list_messages(&self.model.server, &self.model.topic, 0)
.unwrap_or_default();
let mut previous_events: Vec<ListenerEvent> = messages
.into_iter()
.filter_map(|msg| {
let msg = serde_json::from_str(&msg);
match msg {
Err(e) => {
error!(error = ?e, "error parsing stored message");
None
}
Ok(msg) => Some(msg),
}
})
.map(ListenerEvent::Message)
.collect();
previous_events.push(ListenerEvent::ConnectionStateChanged(self.listener.state().await));
let _ = resp_tx.send((previous_events, self.broadcast_tx.subscribe()));
}
SubscriptionCommand::ClearNotifications {resp_tx} => {
debug!(topic=?self.model.topic, "clearing notifications");
let _ = resp_tx.send(self.env.db.delete_messages(&self.model.server, &self.model.topic).map_err(|e| anyhow::anyhow!(e)));
}
SubscriptionCommand::UpdateReadUntil { timestamp, resp_tx } => {
debug!(topic=?self.model.topic, timestamp=timestamp, "updating read until timestamp");
let res = self.env.db.update_read_until(&self.model.server, &self.model.topic, timestamp);
let _ = resp_tx.send(res.map_err(|e| anyhow::anyhow!(e)));
}
}
}
}
}
}
async fn publish(&self, msg: String) -> anyhow::Result<()> {
let server = &self.model.server;
debug!(server=?server, "preparing to publish message");
let creds = self.env.credentials.get(server);
let mut req = self.env.http_client.post(server);
if let Some(creds) = creds {
req = req.basic_auth(creds.username, Some(creds.password));
}
info!(server=?server, "sending message");
let res = req.body(msg).send().await?;
res.error_for_status()?;
debug!(server=?server, "message published successfully");
Ok(())
}
fn handle_msg_event(&mut self, msg: ReceivedMessage) {
debug!(topic=?self.model.topic, "handling new message");
// Store in database
let already_stored: bool = {
let json_ev = &serde_json::to_string(&msg).unwrap();
match self.env.db.insert_message(&self.model.server, json_ev) {
Err(Error::DuplicateMessage) => {
warn!(topic=?self.model.topic, "received duplicate message");
true
}
Err(e) => {
error!(error=?e, topic=?self.model.topic, "can't store the message");
false
}
_ => {
debug!(topic=?self.model.topic, "message stored successfully");
false
}
}
};
if !already_stored {
debug!(topic=?self.model.topic, muted=?self.model.muted, "checking if notification should be shown");
// Show notification. If this fails, panic
if !{ self.model.muted } {
let notifier = self.env.notifier.clone();
let title = { msg.notification_title(&self.model) };
let n = models::Notification {
title,
body: msg.display_message().as_deref().unwrap_or("").to_string(),
actions: msg.actions.clone(),
};
info!(topic=?self.model.topic, "showing notification");
notifier.send(n).unwrap();
} else {
debug!(topic=?self.model.topic, "notification muted, skipping");
}
// Forward to app
debug!(topic=?self.model.topic, "forwarding message to app");
let _ = self.broadcast_tx.send(ListenerEvent::Message(msg));
}
}
}

View File

@ -1,619 +0,0 @@
use std::cell::{Cell, RefCell};
use std::ops::ControlFlow;
use std::rc::{Rc, Weak};
use std::sync::Arc;
use std::time::Duration;
use std::{collections::HashMap, hash::Hash};
use capnp::capability::Promise;
use capnp_rpc::{pry, rpc_twoparty_capnp, twoparty, RpcSystem};
use futures::future::join_all;
use futures::prelude::*;
use generational_arena::Arena;
use tokio::net::UnixListener;
use tokio::sync::mpsc;
use tracing::{error, info, warn};
use crate::models::Message;
use crate::Error;
use crate::SharedEnv;
use crate::{
message_repo::Db,
models::{self, MinMessage},
ntfy_capnp::{output_channel, subscription, system_notifier, watch_handle, Status},
topic_listener::{build_client, TopicListener},
};
const MESSAGE_THROTTLE: Duration = Duration::from_millis(150);
pub struct NotifyForwarder {
model: Rc<RefCell<models::Subscription>>,
env: SharedEnv,
watching: Weak<RefCell<Arena<output_channel::Client>>>,
status: Rc<Cell<Status>>,
}
impl NotifyForwarder {
pub fn new(
model: Rc<RefCell<models::Subscription>>,
env: SharedEnv,
watching: Weak<RefCell<Arena<output_channel::Client>>>,
status: Rc<Cell<Status>>,
) -> Self {
Self {
model,
env,
watching,
status,
}
}
}
impl output_channel::Server for NotifyForwarder {
// Stores the message, sends a system notification, forwards the message to watching clients
fn send_message(
&mut self,
params: output_channel::SendMessageParams,
_results: output_channel::SendMessageResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let request = pry!(params.get());
let message = pry!(pry!(request.get_message()).to_str());
// Store in database
let already_stored: bool = {
// If this fails parsing, the message is not valid at all.
// The server is probably misbehaving.
let min_message: MinMessage = pry!(serde_json::from_str(message)
.map_err(|e| Error::InvalidMinMessage(message.to_string(), e)));
let model = self.model.borrow();
match self.env.db.insert_message(&model.server, message) {
Err(Error::DuplicateMessage) => {
warn!(min_message = ?min_message, "Received duplicate message");
true
}
Err(e) => {
error!(min_message = ?min_message, error = ?e, "Can't store the message");
false
}
_ => false,
}
};
if !already_stored {
// Show notification
// Our priority is to show notifications. If anything fails, panic.
if !{ self.model.borrow().muted } {
let msg: Message = pry!(serde_json::from_str(message)
.map_err(|e| Error::InvalidMessage(message.to_string(), e)));
let np = self.env.proxy.clone();
let title = { msg.notification_title(&self.model.borrow()) };
let n = models::Notification {
title,
body: msg.display_message().as_deref().unwrap_or("").to_string(),
actions: msg.actions,
};
info!("Showing notification");
np.send(n).unwrap();
}
// Forward
if let Some(watching) = self.watching.upgrade() {
let watching = watching.borrow();
let futs = watching.iter().map(|(_id, w)| {
let mut req = w.send_message_request();
req.get().set_message(message.into());
async move {
if let Err(e) = req.send().promise.await {
error!(error = ?e, "Error forwarding");
}
}
});
tokio::task::spawn_local(join_all(futs));
}
}
Promise::from_future(async move {
// some backpressure
tokio::time::sleep(MESSAGE_THROTTLE).await;
Ok(())
})
}
fn send_status(
&mut self,
params: output_channel::SendStatusParams,
_: output_channel::SendStatusResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let status = pry!(pry!(params.get()).get_status());
if let Some(watching) = self.watching.upgrade() {
for (_, w) in watching.borrow().iter() {
let mut req = w.send_status_request();
req.get().set_status(status);
tokio::task::spawn_local(async move {
req.send().promise.await.unwrap();
});
}
}
self.status.set(status);
Promise::ok(())
}
}
struct WatcherImpl {
id: generational_arena::Index,
watchers: Weak<RefCell<Arena<output_channel::Client>>>,
}
impl watch_handle::Server for WatcherImpl {}
impl Drop for WatcherImpl {
fn drop(&mut self) {
if let Some(w) = self.watchers.upgrade() {
w.borrow_mut().remove(self.id);
}
}
}
pub struct SubscriptionImpl {
model: Rc<RefCell<models::Subscription>>,
env: SharedEnv,
watchers: Rc<RefCell<Arena<output_channel::Client>>>,
status: Rc<Cell<Status>>,
topic_listener: mpsc::Sender<ControlFlow<()>>,
}
impl Drop for SubscriptionImpl {
fn drop(&mut self) {
let t = self.topic_listener.clone();
tokio::task::spawn_local(async move {
t.send(ControlFlow::Break(())).await.unwrap();
});
}
}
impl SubscriptionImpl {
fn new(model: models::Subscription, env: SharedEnv) -> Self {
let status = Rc::new(Cell::new(Status::Down));
let watchers = Default::default();
let rc_model = Rc::new(RefCell::new(model.clone()));
let output_channel = NotifyForwarder::new(
rc_model.clone(),
env.clone(),
Rc::downgrade(&watchers),
status.clone(),
);
let topic_listener = TopicListener::new(
env.clone(),
model.server.clone(),
model.topic.clone(),
model.read_until,
capnp_rpc::new_client(output_channel),
);
Self {
model: rc_model,
env,
watchers,
status,
topic_listener,
}
}
fn _publish<'a>(&'a mut self, msg: &'a str) -> impl Future<Output = Result<(), capnp::Error>> {
let msg = msg.to_owned();
let server = &self.model.borrow().server;
let creds = self.env.credentials.get(server);
let mut req = self.env.http.post(server);
if let Some(creds) = creds {
req = req.basic_auth(creds.username, Some(creds.password));
}
async move {
info!("sending message");
let res = req.body(msg).send().await;
match res {
Err(e) => Err(capnp::Error::failed(e.to_string())),
Ok(res) => {
res.error_for_status()
.map_err(|e| capnp::Error::failed(e.to_string()))?;
Ok(())
}
}
}
}
}
impl subscription::Server for SubscriptionImpl {
fn watch(
&mut self,
params: subscription::WatchParams,
mut results: subscription::WatchResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let watcher = pry!(pry!(params.get()).get_watcher());
let since = pry!(params.get()).get_since();
// Send old messages
let msgs = {
let model = self.model.borrow();
pry!(self
.env
.db
.list_messages(&model.server, &model.topic, since)
.map_err(Error::Db))
};
let futs = msgs.into_iter().map(move |msg| {
let mut req = watcher.send_message_request();
req.get().set_message(msg.as_str().into());
req.send().promise
});
let watcher = pry!(pry!(params.get()).get_watcher());
let mut req = watcher.send_status_request();
req.get().set_status(self.status.get());
let id = { self.watchers.borrow_mut().insert(watcher) };
results.get().set_handle(capnp_rpc::new_client(WatcherImpl {
id,
watchers: Rc::downgrade(&self.watchers),
}));
Promise::from_future(async move {
futures::future::try_join_all(futs).await?;
req.send().promise.await?;
Ok(())
})
}
fn publish(
&mut self,
params: subscription::PublishParams,
_results: subscription::PublishResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let msg = pry!(pry!(pry!(params.get()).get_message()).to_str());
let fut = self._publish(msg);
Promise::from_future(async move {
fut.await?;
Ok(())
})
}
fn get_info(
&mut self,
_: subscription::GetInfoParams,
mut results: subscription::GetInfoResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let mut res = results.get();
let model = self.model.borrow();
res.set_server(model.server.as_str().into());
res.set_display_name(model.display_name.as_str().into());
res.set_topic(model.topic.as_str().into());
res.set_muted(model.muted);
res.set_read_until(model.read_until);
Promise::ok(())
}
fn update_info(
&mut self,
params: subscription::UpdateInfoParams,
_results: subscription::UpdateInfoResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let info = pry!(pry!(params.get()).get_value());
let mut model = self.model.borrow_mut();
model.display_name = pry!(pry!(info.get_display_name()).to_string());
model.muted = info.get_muted();
model.read_until = info.get_read_until();
pry!(self.env.db.update_subscription(model.clone()));
Promise::ok(())
}
fn clear_notifications(
&mut self,
_params: subscription::ClearNotificationsParams,
_results: subscription::ClearNotificationsResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let model = self.model.borrow_mut();
pry!(self.env.db.delete_messages(&model.server, &model.topic));
Promise::ok(())
}
fn update_read_until(
&mut self,
params: subscription::UpdateReadUntilParams,
_: subscription::UpdateReadUntilResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let value = pry!(params.get()).get_value();
let mut model = self.model.borrow_mut();
pry!(self
.env
.db
.update_read_until(&model.server, &model.topic, value));
model.read_until = value;
Promise::ok(())
}
fn refresh(
&mut self,
_: subscription::RefreshParams,
_: subscription::RefreshResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let sender = self.topic_listener.clone();
Promise::from_future(async move {
sender
.send(ControlFlow::Continue(()))
.await
.map_err(|e| capnp::Error::failed(format!("{:?}", e)))?;
Ok(())
})
}
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct WatchKey {
server: String,
topic: String,
}
pub struct SystemNotifier {
watching: Rc<RefCell<HashMap<WatchKey, subscription::Client>>>,
env: SharedEnv,
}
impl SystemNotifier {
pub fn new(
dbpath: &str,
notification_proxy: Arc<dyn models::NotificationProxy>,
network: Arc<dyn models::NetworkMonitorProxy>,
credentials: crate::credentials::Credentials,
) -> Self {
Self {
watching: Rc::new(RefCell::new(HashMap::new())),
env: SharedEnv {
db: Db::connect(dbpath).unwrap(),
proxy: notification_proxy,
http: build_client().unwrap(),
network,
credentials,
},
}
}
fn watch(&mut self, sub: models::Subscription) -> Promise<subscription::Client, capnp::Error> {
let subscription = SubscriptionImpl::new(sub.clone(), self.env.clone());
let watching = self.watching.clone();
let subc: subscription::Client = capnp_rpc::new_client(subscription);
Promise::from_future(async move {
watching.borrow_mut().insert(
WatchKey {
server: sub.server.to_owned(),
topic: sub.topic.to_owned(),
},
subc.clone(),
);
Ok(subc)
})
}
pub fn watch_subscribed(&mut self) -> Promise<(), capnp::Error> {
let f: Vec<_> = pry!(self.env.db.list_subscriptions())
.into_iter()
.map(|m| self.watch(m))
.collect();
Promise::from_future(async move {
join_all(f.into_iter().map(|x| async move {
if let Err(e) = x.await {
error!(error = ?e, "Can't rewatch subscribed topic");
}
}))
.await;
Ok(())
})
}
pub fn refresh_all(&mut self) -> Promise<(), capnp::Error> {
let watching = self.watching.clone();
Promise::from_future(async move {
let reqs: Vec<_> = watching
.borrow()
.values()
.map(|w| w.refresh_request())
.collect();
join_all(reqs.into_iter().map(|x| x.send().promise)).await;
Ok(())
})
}
}
impl system_notifier::Server for SystemNotifier {
fn subscribe(
&mut self,
params: system_notifier::SubscribeParams,
mut results: system_notifier::SubscribeResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let topic = pry!(pry!(pry!(params.get()).get_topic()).to_str());
let server: &str = pry!(pry!(pry!(params.get()).get_server()).to_str());
let subscription = pry!(models::Subscription::builder(topic.to_owned())
.server(server.to_string())
.build()
.map_err(|e| capnp::Error::failed(format!("{:?}", e))));
let sub: Promise<subscription::Client, capnp::Error> = self.watch(subscription.clone());
let mut db = self.env.db.clone();
Promise::from_future(async move {
results.get().set_subscription(sub.await?);
db.insert_subscription(subscription).map_err(|e| {
capnp::Error::failed(format!("could not insert subscription: {}", e))
})?;
Ok(())
})
}
fn unsubscribe(
&mut self,
params: system_notifier::UnsubscribeParams,
_results: system_notifier::UnsubscribeResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let topic = pry!(pry!(pry!(params.get()).get_topic()).to_str());
let server = pry!(pry!(pry!(params.get()).get_server()).to_str());
{
self.watching.borrow_mut().remove(&WatchKey {
server: server.to_string(),
topic: topic.to_string(),
});
pry!(self
.env
.db
.remove_subscription(server, topic)
.map_err(|e| capnp::Error::failed(e.to_string())));
info!(server, topic, "Unsubscribed");
}
Promise::ok(())
}
fn list_subscriptions(
&mut self,
_: system_notifier::ListSubscriptionsParams,
mut results: system_notifier::ListSubscriptionsResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let req = results.get();
let values = self.watching.borrow().values().cloned().collect::<Vec<_>>();
let mut list = req.init_list(values.len() as u32);
for (i, v) in values.iter().enumerate() {
use capnp::capability::FromClientHook;
list.set(i as u32, v.clone().clone().into_client_hook());
}
Promise::ok(())
}
fn list_accounts(
&mut self,
_: system_notifier::ListAccountsParams,
mut results: system_notifier::ListAccountsResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let values = self.env.credentials.list_all();
Promise::from_future(async move {
let mut list = results.get().init_list(values.len() as u32);
for (i, item) in values.into_iter().enumerate() {
let mut acc = list.reborrow().get(i as u32);
acc.set_server(item.0[..].into());
acc.set_username(item.1.username[..].into());
}
Ok(())
})
}
fn add_account(
&mut self,
params: system_notifier::AddAccountParams,
_: system_notifier::AddAccountResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let credentials = self.env.credentials.clone();
let http = self.env.http.clone();
let refresh = self.refresh_all();
Promise::from_future(async move {
let account = params.get()?.get_account()?;
let username = account.get_username()?.to_str()?;
let server = account.get_server()?.to_str()?;
let password = params.get()?.get_password()?.to_str()?;
info!("validating account");
let url = models::Subscription::build_auth_url(server, "stats")?;
http.get(url)
.basic_auth(username, Some(password))
.send()
.await
.map_err(|e| capnp::Error::failed(e.to_string()))?
.error_for_status()
.map_err(|e| capnp::Error::failed(e.to_string()))?;
credentials
.insert(server, username, password)
.await
.map_err(|e| capnp::Error::failed(e.to_string()))?;
refresh.await?;
info!(server = %server, username = %username, "added account");
Ok(())
})
}
fn remove_account(
&mut self,
params: system_notifier::RemoveAccountParams,
_: system_notifier::RemoveAccountResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let credentials = self.env.credentials.clone();
Promise::from_future(async move {
let account = params.get()?.get_account()?;
let username = account.get_username()?.to_str()?;
let server = account.get_server()?.to_str()?;
credentials
.delete(server)
.await
.map_err(|e| capnp::Error::failed(e.to_string()))?;
info!(server = %server, username = %username, "removed account");
Ok(())
})
}
}
pub fn start(
socket_path: std::path::PathBuf,
dbpath: &str,
notification_proxy: Arc<dyn models::NotificationProxy>,
network_proxy: Arc<dyn models::NetworkMonitorProxy>,
) -> anyhow::Result<()> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let listener = rt.block_on(async move {
let _ = std::fs::remove_file(&socket_path);
UnixListener::bind(&socket_path).unwrap()
});
let dbpath = dbpath.to_owned();
let f = move || {
let credentials =
rt.block_on(async { crate::credentials::Credentials::new().await.unwrap() });
let local = tokio::task::LocalSet::new();
let mut system_notifier =
SystemNotifier::new(&dbpath, notification_proxy, network_proxy, credentials);
local.spawn_local(async move {
system_notifier.watch_subscribed().await.unwrap();
let system_client: system_notifier::Client = capnp_rpc::new_client(system_notifier);
loop {
match listener.accept().await {
Ok((stream, _addr)) => {
info!("client connected");
let (reader, writer) =
tokio_util::compat::TokioAsyncReadCompatExt::compat(stream).split();
let network = twoparty::VatNetwork::new(
reader,
writer,
rpc_twoparty_capnp::Side::Server,
Default::default(),
);
let rpc_system =
RpcSystem::new(Box::new(network), Some(system_client.clone().client));
tokio::task::spawn_local(rpc_system);
}
Err(e) => {
error!(error=%e);
}
}
}
});
rt.block_on(local);
};
std::thread::spawn(move || {
f();
});
Ok(())
}

View File

@ -1,244 +0,0 @@
use std::ops::ControlFlow;
use std::sync::Arc;
use std::time::Duration;
use futures::prelude::*;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncBufReadExt;
use tokio::sync::mpsc;
use tokio_stream::wrappers::LinesStream;
use tracing::warn;
use tracing::{debug, error, info, instrument, Instrument};
use crate::{
models,
ntfy_capnp::{output_channel, Status},
Error, SharedEnv,
};
const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);
const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(240); // 4 minutes
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
#[serde(rename = "open")]
Open {
id: String,
time: usize,
expires: Option<usize>,
topic: String,
},
#[serde(rename = "message")]
Message {
id: String,
expires: Option<usize>,
#[serde(flatten)]
message: models::Message,
},
#[serde(rename = "keepalive")]
KeepAlive {
id: String,
time: usize,
expires: Option<usize>,
topic: String,
},
}
pub fn build_client() -> anyhow::Result<reqwest::Client> {
Ok(reqwest::Client::builder()
.connect_timeout(CONNECT_TIMEOUT)
.pool_idle_timeout(TIMEOUT)
// rustls is used because HTTP 2 isn't discovered with native-tls.
// HTTP 2 is required to multiplex multiple requests over a single connection.
// You can check that the app is using a single connection to a server by doing
// ```
// ping ntfy.sh # to get the ip address
// netstat | grep $ip
// ```
.use_rustls_tls()
.build()?)
}
fn topic_request(
client: &reqwest::Client,
endpoint: &str,
topic: &str,
since: u64,
username: Option<&str>,
password: Option<&str>,
) -> anyhow::Result<reqwest::Request> {
let url = models::Subscription::build_url(endpoint, topic, since)?;
let mut req = client
.get(url)
.header("Content-Type", "application/x-ndjson")
.header("Transfer-Encoding", "chunked");
if let Some(username) = username {
req = req.basic_auth(username, password);
}
Ok(req.build()?)
}
async fn response_lines(
res: impl tokio::io::AsyncBufRead,
) -> Result<impl futures::Stream<Item = Result<String, std::io::Error>>, reqwest::Error> {
let lines = LinesStream::new(res.lines());
Ok(lines)
}
pub enum BroadcasterEvent {
Stop,
Restart,
}
pub struct TopicListener {
env: crate::SharedEnv,
endpoint: String,
topic: String,
status: Status,
output_channel: output_channel::Client,
since: u64,
}
impl TopicListener {
pub fn new(
env: SharedEnv,
endpoint: String,
topic: String,
since: u64,
output_channel: output_channel::Client,
) -> mpsc::Sender<ControlFlow<()>> {
let (tx, mut rx) = mpsc::channel(8);
let network = env.network.clone();
let mut this = Self {
env,
endpoint,
topic,
status: Status::Down,
output_channel,
since,
};
tokio::task::spawn_local(async move {
loop {
tokio::select! {
_ = this.run_supervised_loop().instrument(tracing::debug_span!("run_supervised_loop")) => {},
res = rx.recv() => match res {
Some(ControlFlow::Continue(_)) => {
info!("Refreshed");
}
None | Some(ControlFlow::Break(_)) => {
break;
}
}
}
}
});
let tx_clone = tx.clone();
tokio::task::spawn_local(async move {
if let Err(e) = Self::reload_on_network_change(network, tx_clone.clone()).await {
warn!(error = %e, "watching network failed")
}
});
tx
}
async fn reload_on_network_change(
monitor: Arc<dyn models::NetworkMonitorProxy>,
tx: mpsc::Sender<ControlFlow<()>>,
) -> anyhow::Result<()> {
let mut m = monitor.listen();
while let Some(_) = m.next().await {
tx.send(ControlFlow::Continue(())).await?;
}
Ok(())
}
fn send_current_status(&mut self) -> impl Future<Output = anyhow::Result<()>> {
let mut req = self.output_channel.send_status_request();
req.get().set_status(self.status);
async move {
req.send().promise.await?;
Ok(())
}
}
#[instrument(skip_all)]
async fn recv_and_forward(&mut self) -> anyhow::Result<()> {
let creds = self.env.credentials.get(&self.endpoint);
let req = topic_request(
&self.env.http,
&self.endpoint,
&self.topic,
self.since,
creds.as_ref().map(|x| x.username.as_str()),
creds.as_ref().map(|x| x.password.as_str()),
);
let res = self.env.http.execute(req?).await?;
let reader = tokio_util::io::StreamReader::new(
res.bytes_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
);
let stream = response_lines(reader).await?;
tokio::pin!(stream);
self.status = Status::Up;
self.send_current_status().await.unwrap();
info!(topic = %&self.topic, "listening");
while let Some(msg) = stream.next().await {
let msg = msg?;
let min_msg = serde_json::from_str::<models::MinMessage>(&msg)
.map_err(|e| Error::InvalidMinMessage(msg.to_string(), e))?;
self.since = min_msg.time.max(self.since);
let event = serde_json::from_str(&msg)
.map_err(|e| Error::InvalidMessage(msg.to_string(), e))?;
match event {
Event::Message { .. } => {
debug!("message event");
let mut req = self.output_channel.send_message_request();
req.get().set_message(msg.as_str().into());
req.send().promise.await?;
}
Event::KeepAlive { .. } => {
debug!("keepalive event");
}
Event::Open { .. } => {
debug!("open event");
}
}
}
Ok(())
}
async fn run_supervised_loop(&mut self) {
let retrier = || {
crate::retry::WaitExponentialRandom::builder()
.min(Duration::from_secs(1))
.max(Duration::from_secs(5 * 60))
.build()
};
let mut retry = retrier();
loop {
let start_time = std::time::Instant::now();
if let Err(e) = self.recv_and_forward().await {
let uptime = std::time::Instant::now().duration_since(start_time);
// Reset retry delay to minimum if uptime was decent enough
if uptime > Duration::from_secs(60 * 4) {
retry = retrier();
}
error!(error = ?e);
self.status = Status::Degraded;
self.send_current_status().await.unwrap();
info!(delay = ?retry.next_delay(), "restarting");
retry.wait().await;
} else {
break;
}
}
}
}

View File

@ -1,19 +1,13 @@
use std::cell::Cell;
use std::path::Path;
use std::path::PathBuf;
use std::pin::Pin;
use std::rc::Rc;
use adw::prelude::*;
use adw::subclass::prelude::*;
use capnp_rpc::{rpc_twoparty_capnp, twoparty, RpcSystem};
use futures::stream::Stream;
use futures::AsyncReadExt;
use gio::SocketClient;
use gio::UnixSocketAddress;
use gtk::{gdk, gio, glib};
use ntfy_daemon::models;
use ntfy_daemon::ntfy_capnp::system_notifier;
use ntfy_daemon::NtfyHandle;
use tracing::{debug, error, info, warn};
use crate::config::{APP_ID, PKGDATADIR, PROFILE, VERSION};
@ -30,8 +24,8 @@ mod imp {
#[derive(Default)]
pub struct NotifyApplication {
pub window: RefCell<WeakRef<NotifyWindow>>,
pub socket_path: RefCell<PathBuf>,
pub hold_guard: OnceCell<gio::ApplicationHoldGuard>,
pub ntfy: OnceCell<NtfyHandle>,
}
#[glib::object_subclass]
@ -58,8 +52,6 @@ mod imp {
// Set icons for shell
gtk::Window::set_default_icon_name(APP_ID);
let socket_path = glib::user_data_dir().join("com.ranfdev.Notify.socket");
self.socket_path.replace(socket_path);
app.setup_css();
app.setup_gactions();
app.setup_accels();
@ -71,7 +63,7 @@ mod imp {
let app = self.obj();
if self.hold_guard.get().is_none() {
app.ensure_rpc_running(&self.socket_path.borrow());
app.ensure_rpc_running();
}
glib::MainContext::default().spawn_local(async move {
@ -108,7 +100,7 @@ impl NotifyApplication {
return;
}
}
self.build_window(&self.imp().socket_path.borrow());
self.build_window();
self.main_window().present();
}
@ -253,7 +245,7 @@ impl NotifyApplication {
Ok(())
}
fn ensure_rpc_running(&self, socket_path: &Path) {
fn ensure_rpc_running(&self) {
let dbpath = glib::user_data_dir().join("com.ranfdev.Notify.sqlite");
info!(database_path = %dbpath.display());
@ -317,42 +309,19 @@ impl NotifyApplication {
}
}
let proxies = std::sync::Arc::new(Proxies { notification: s });
ntfy_daemon::system_client::start(
socket_path.to_owned(),
dbpath.to_str().unwrap(),
proxies.clone(),
proxies,
)
.unwrap();
let ntfy = ntfy_daemon::start(dbpath.to_str().unwrap(), proxies.clone(), proxies).unwrap();
self.imp()
.ntfy
.set(ntfy)
.or(Err(anyhow::anyhow!("failed setting ntfy")))
.unwrap();
self.imp().hold_guard.set(self.hold()).unwrap();
}
fn build_window(&self, socket_path: &Path) {
let address = UnixSocketAddress::new(socket_path);
let client = SocketClient::new();
let connection =
SocketClientExt::connect(&client, &address, gio::Cancellable::NONE).unwrap();
fn build_window(&self) {
let ntfy = self.imp().ntfy.get().unwrap();
let rw = connection.into_async_read_write().unwrap();
let (reader, writer) = rw.split();
let rpc_network = Box::new(twoparty::VatNetwork::new(
reader,
writer,
rpc_twoparty_capnp::Side::Client,
Default::default(),
));
let mut rpc_system = RpcSystem::new(rpc_network, None);
let client: system_notifier::Client =
rpc_system.bootstrap(rpc_twoparty_capnp::Side::Server);
glib::MainContext::default().spawn_local(async move {
debug!("rpc_system started");
rpc_system.await.unwrap();
debug!("rpc_system stopped");
});
let window = NotifyWindow::new(self, client);
let window = NotifyWindow::new(self, ntfy.clone());
*self.imp().window.borrow_mut() = window.downgrade();
}
}

View File

@ -1,56 +1,36 @@
use std::cell::{Cell, OnceCell, RefCell};
use std::future::Future;
use std::rc::Rc;
use adw::prelude::*;
use capnp::capability::Promise;
use capnp_rpc::pry;
use glib::subclass::prelude::*;
use glib::Properties;
use gtk::{gio, glib};
use ntfy_daemon::models;
use ntfy_daemon::ntfy_capnp::{output_channel, subscription, watch_handle, Status};
use tracing::{debug, error, instrument};
use ntfy_daemon::{models, ConnectionState, ListenerEvent};
use tracing::{error, instrument};
struct TopicWatcher {
sub: glib::WeakRef<Subscription>,
#[repr(u16)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Status {
Down = 0,
Degraded = 1,
Up = 2,
}
impl output_channel::Server for TopicWatcher {
fn send_message(
&mut self,
params: output_channel::SendMessageParams,
_results: output_channel::SendMessageResults,
) -> capnp::capability::Promise<(), capnp::Error> {
if let Some(sub) = self.sub.upgrade() {
let request = pry!(params.get());
let message = pry!(pry!(request.get_message()).to_str());
let msg: models::Message = serde_json::from_str(message).unwrap();
sub.imp().messages.append(&glib::BoxedAnyObject::new(msg));
sub.update_unread_count();
Promise::ok(())
} else {
Promise::err(capnp::Error::failed("dead channel".to_string()))
}
}
fn send_status(
&mut self,
params: output_channel::SendStatusParams,
_: output_channel::SendStatusResults,
) -> capnp::capability::Promise<(), capnp::Error> {
if let Some(sub) = self.sub.upgrade() {
let status = pry!(pry!(params.get()).get_status());
sub.imp().status.set(status);
sub.notify_status();
Promise::ok(())
} else {
Promise::err(capnp::Error::failed("dead channel".to_string()))
impl From<u16> for Status {
fn from(value: u16) -> Self {
match value {
0 => Status::Down,
1 => Status::Degraded,
2 => Status::Up,
_ => panic!("Invalid value for Status"),
}
}
}
impl Drop for TopicWatcher {
fn drop(&mut self) {
debug!("Dropped topic watcher");
impl From<Status> for u16 {
fn from(status: Status) -> Self {
status as u16
}
}
@ -76,8 +56,7 @@ mod imp {
pub unread_count: Cell<u32>,
pub read_until: Cell<u64>,
pub messages: gio::ListStore,
pub client: OnceCell<subscription::Client>,
pub remote_handle: RefCell<Option<watch_handle::Client>>,
pub client: OnceCell<ntfy_daemon::SubscriptionHandle>,
}
impl Subscription {
@ -100,7 +79,6 @@ mod imp {
client: Default::default(),
unread_count: Default::default(),
read_until: Default::default(),
remote_handle: Default::default(),
}
}
}
@ -120,7 +98,7 @@ glib::wrapper! {
}
impl Subscription {
pub fn new(client: subscription::Client) -> Self {
pub fn new(client: ntfy_daemon::SubscriptionHandle) -> Self {
let this: Self = glib::Object::builder().build();
let imp = this.imp();
if let Err(_) = imp.client.set(client) {
@ -159,34 +137,54 @@ impl Subscription {
self._set_display_name(display_name.to_string());
}
fn load(&self) -> Promise<(), capnp::Error> {
let imp = self.imp();
let req_info = imp.client.get().unwrap().get_info_request();
let req_messages = {
let mut req = imp.client.get().unwrap().watch_request();
req.get().set_watcher(capnp_rpc::new_client(TopicWatcher {
sub: self.downgrade(),
}));
req
};
fn load(&self) -> impl Future<Output = anyhow::Result<()>> {
let this = self.clone();
Promise::from_future(async move {
let info = req_info.send().promise.await?;
let info = info.get()?;
async move {
let remote_subscription = this.imp().client.get().unwrap();
let model = remote_subscription.model().await;
this.init_info(
info.get_topic()?.to_str()?,
info.get_server()?.to_str()?,
info.get_muted(),
info.get_read_until(),
info.get_display_name()?.to_str()?,
&model.topic,
&model.server,
model.muted,
model.read_until,
&model.display_name,
);
let message_stream = req_messages.send().promise.await?;
let handle = message_stream.get()?.get_handle()?;
this.imp().remote_handle.replace(Some(handle));
let (prev_msgs, mut rx) = remote_subscription.attach().await;
for msg in prev_msgs {
this.handle_event(msg);
}
while let Ok(ev) = rx.recv().await {
this.handle_event(ev);
}
Ok(())
})
}
}
fn handle_event(&self, ev: ListenerEvent) {
match ev {
ListenerEvent::Message(msg) => {
self.imp().messages.append(&glib::BoxedAnyObject::new(msg));
self.update_unread_count();
}
ListenerEvent::ConnectionStateChanged(connection_state) => {
self.set_connection_state(connection_state);
}
}
}
fn set_connection_state(&self, state: ConnectionState) {
let status = match state {
ConnectionState::Unitialized => Status::Degraded,
ConnectionState::Connected => Status::Up,
ConnectionState::Reconnecting { .. } => Status::Degraded,
};
self.imp().status.set(status);
dbg!(status);
self.notify_status();
}
fn _set_display_name(&self, value: String) {
@ -200,34 +198,36 @@ impl Subscription {
self.notify_display_name();
}
#[instrument(skip_all)]
pub fn set_display_name(&self, value: String) -> Promise<(), anyhow::Error> {
pub fn set_display_name(&self, value: String) -> impl Future<Output = anyhow::Result<()>> {
let this = self.clone();
Promise::from_future(async move {
async move {
this._set_display_name(value);
this.send_updated_info().await?;
Ok(())
})
}
}
fn send_updated_info(&self) -> Promise<(), anyhow::Error> {
async fn send_updated_info(&self) -> anyhow::Result<()> {
let imp = self.imp();
let mut req = imp.client.get().unwrap().update_info_request();
let mut val = pry!(req.get().get_value());
val.set_muted(imp.muted.get());
val.set_display_name(imp.display_name.borrow().as_str().into());
val.set_read_until(imp.read_until.get());
Promise::from_future(async move {
debug!("sending update_info");
req.send().promise.await?;
Ok(())
})
imp.client
.get()
.unwrap()
.update_info(
models::Subscription::builder(self.topic())
.display_name((imp.display_name.borrow().to_string()))
.muted(imp.muted.get())
.build()
.map_err(|e| anyhow::anyhow!("invalid subscription data {:?}", e))?,
)
.await?;
Ok(())
}
fn last_message(list: &gio::ListStore) -> Option<models::Message> {
fn last_message(list: &gio::ListStore) -> Option<models::ReceivedMessage> {
let n = list.n_items();
let last = list
.item(n.checked_sub(1)?)
.and_downcast::<glib::BoxedAnyObject>()?;
let last = last.borrow::<models::Message>();
let last = last.borrow::<models::ReceivedMessage>();
Some(last.clone())
}
fn update_unread_count(&self) {
@ -240,60 +240,52 @@ impl Subscription {
self.notify_unread_count();
}
pub fn set_muted(&self, value: bool) -> Promise<(), anyhow::Error> {
pub fn set_muted(&self, value: bool) -> impl Future<Output = anyhow::Result<()>> {
let this = self.clone();
Promise::from_future(async move {
async move {
this.imp().muted.replace(value);
this.notify_muted();
this.send_updated_info().await?;
Ok(())
})
}
}
pub fn flag_all_as_read(&self) -> Promise<(), anyhow::Error> {
pub async fn flag_all_as_read(&self) -> anyhow::Result<()> {
let imp = self.imp();
let Some(value) = Self::last_message(&imp.messages)
.map(|last| last.time)
.filter(|time| *time > self.imp().read_until.get())
else {
return Promise::ok(());
return Ok(());
};
let this = self.clone();
Promise::from_future(async move {
let mut req = this.imp().client.get().unwrap().update_read_until_request();
req.get().set_value(value);
req.send().promise.await?;
this.imp().read_until.set(value);
this.update_unread_count();
Ok(())
})
this.imp()
.client
.get()
.unwrap()
.update_read_until(value)
.await?;
this.imp().read_until.set(value);
this.update_unread_count();
Ok(())
}
pub fn publish_msg(&self, mut msg: models::Message) -> Promise<(), anyhow::Error> {
pub async fn publish_msg(&self, mut msg: models::OutgoingMessage) -> anyhow::Result<()> {
let imp = self.imp();
let json = {
msg.topic = self.topic();
serde_json::to_string(&msg)
serde_json::to_string(&msg)?
};
let mut req = imp.client.get().unwrap().publish_request();
req.get().set_message(pry!(json).as_str().into());
Promise::from_future(async move {
debug!("sending publish");
req.send().promise.await?;
Ok(())
})
imp.client.get().unwrap().publish(json).await?;
Ok(())
}
#[instrument(skip_all)]
pub fn clear_notifications(&self) -> Promise<(), anyhow::Error> {
pub async fn clear_notifications(&self) -> anyhow::Result<()> {
let imp = self.imp();
let req = imp.client.get().unwrap().clear_notifications_request();
let this = self.clone();
Promise::from_future(async move {
debug!("sending clear_notifications");
req.send().promise.await?;
this.imp().messages.remove_all();
Ok(())
})
imp.client.get().unwrap().clear_notifications().await?;
self.imp().messages.remove_all();
Ok(())
}
pub fn nice_status(&self) -> Status {

View File

@ -166,7 +166,7 @@ impl AddSubscriptionDialog {
obj.set_content_width(480);
obj.set_child(Some(&toolbar_view));
}
pub fn subscription(&self) -> Result<models::Subscription, Vec<ntfy_daemon::Error>> {
pub fn subscription(&self) -> Result<models::Subscription, ntfy_daemon::Error> {
let w = { self.imp().widgets.borrow().clone() };
let mut sub = models::Subscription::builder(w.topic_entry.text().to_string());
if w.server_expander.enables_expansion() {
@ -183,7 +183,7 @@ impl AddSubscriptionDialog {
w.topic_entry.remove_css_class("error");
w.sub_btn.set_sensitive(true);
if let Err(errs) = sub {
if let Err(ntfy_daemon::Error::InvalidSubscription(errs)) = sub {
w.sub_btn.set_sensitive(false);
for e in errs {
match e {

View File

@ -182,7 +182,7 @@ impl AdvancedMessageDialog {
&mut buffer.start_iter(),
&mut buffer.end_iter(),
true,
)).map_err(|e| capnp::Error::failed(e.to_string()))?;
))?;
thisc.imp().subscription.get().unwrap()
.publish_msg(msg).await
};

View File

@ -34,12 +34,12 @@ glib::wrapper! {
}
impl MessageRow {
pub fn new(msg: models::Message) -> Self {
pub fn new(msg: models::ReceivedMessage) -> Self {
let this: Self = glib::Object::new();
this.build_ui(msg);
this
}
fn build_ui(&self, msg: models::Message) {
fn build_ui(&self, msg: models::ReceivedMessage) {
self.set_margin_top(8);
self.set_margin_bottom(8);
self.set_margin_start(8);

View File

@ -3,11 +3,12 @@ use std::cell::OnceCell;
use adw::prelude::*;
use adw::subclass::prelude::*;
use gtk::{gio, glib};
use ntfy_daemon::ntfy_capnp::system_notifier;
use crate::error::*;
mod imp {
use ntfy_daemon::NtfyHandle;
use super::*;
#[derive(gtk::CompositeTemplate)]
@ -25,7 +26,7 @@ mod imp {
pub added_accounts: TemplateChild<gtk::ListBox>,
#[template_child]
pub added_accounts_group: TemplateChild<adw::PreferencesGroup>,
pub notifier: OnceCell<system_notifier::Client>,
pub notifier: OnceCell<NtfyHandle>,
}
impl Default for NotifyPreferences {
@ -77,7 +78,7 @@ glib::wrapper! {
}
impl NotifyPreferences {
pub fn new(notifier: system_notifier::Client) -> Self {
pub fn new(notifier: ntfy_daemon::NtfyHandle) -> Self {
let obj: Self = glib::Object::builder().build();
obj.imp()
.notifier
@ -100,21 +101,15 @@ impl NotifyPreferences {
pub async fn show_accounts(&self) -> anyhow::Result<()> {
let imp = self.imp();
let req = imp.notifier.get().unwrap().list_accounts_request();
let res = req.send().promise.await?;
let accounts = res.get()?.get_list()?;
let accounts = imp.notifier.get().unwrap().list_accounts().await?;
imp.added_accounts_group.set_visible(!accounts.is_empty());
imp.added_accounts.remove_all();
for a in accounts {
let server = a.get_server()?.to_string()?;
let username = a.get_username()?.to_string()?;
let row = adw::ActionRow::builder()
.title(&server)
.subtitle(&username)
.title(&a.server)
.subtitle(&a.username)
.build();
row.add_css_class("property");
row.add_suffix(&{
@ -125,10 +120,9 @@ impl NotifyPreferences {
let this = self.clone();
btn.connect_clicked(move |btn| {
let this = this.clone();
let username = username.clone();
let server = server.clone();
let a = a.clone();
btn.error_boundary()
.spawn(async move { this.remove_account(&server, &username).await });
.spawn(async move { this.remove_account(&a.server).await });
});
btn
});
@ -142,29 +136,23 @@ impl NotifyPreferences {
let server = imp.server_entry.text();
let username = imp.username_entry.text();
let mut req = imp.notifier.get().unwrap().add_account_request();
let mut acc = req.get().get_account()?;
acc.set_username(username[..].into());
acc.set_server(server[..].into());
req.get().set_password(password[..].into());
req.send().promise.await?;
imp.notifier
.get()
.unwrap()
.add_account(&server, &username, &password)
.await?;
self.show_accounts().await?;
Ok(())
}
pub async fn remove_account(&self, server: &str, username: &str) -> anyhow::Result<()> {
let mut req = self.imp().notifier.get().unwrap().remove_account_request();
let mut acc = req.get().get_account()?;
acc.set_username(username[..].into());
acc.set_server(server[..].into());
req.send().promise.await?;
pub async fn remove_account(&self, server: &str) -> anyhow::Result<()> {
self.imp()
.notifier
.get()
.unwrap()
.remove_account(server)
.await?;
self.show_accounts().await?;
Ok(())
}
}

View File

@ -3,15 +3,15 @@ use std::cell::OnceCell;
use adw::prelude::*;
use adw::subclass::prelude::*;
use futures::prelude::*;
use gtk::{gio, glib};
use ntfy_daemon::models;
use ntfy_daemon::ntfy_capnp::{system_notifier, Status};
use ntfy_daemon::NtfyHandle;
use tracing::warn;
use crate::application::NotifyApplication;
use crate::config::{APP_ID, PROFILE};
use crate::error::*;
use crate::subscription::Status;
use crate::subscription::Subscription;
use crate::widgets::*;
@ -52,7 +52,7 @@ mod imp {
pub send_btn: TemplateChild<gtk::Button>,
#[template_child]
pub code_btn: TemplateChild<gtk::Button>,
pub notifier: OnceCell<system_notifier::Client>,
pub notifier: OnceCell<NtfyHandle>,
pub conn: OnceCell<gio::SocketConnection>,
pub settings: gio::Settings,
pub banner_binding: Cell<Option<(Subscription, glib::SignalHandlerId)>>,
@ -138,7 +138,8 @@ mod imp {
});
klass.install_action("win.clear-notifications", None, |this, _, _| {
this.selected_subscription().map(|sub| {
this.error_boundary().spawn(sub.clear_notifications());
this.error_boundary()
.spawn(async move { sub.clear_notifications().await });
});
});
//klass.bind_template_instance_callbacks();
@ -190,7 +191,7 @@ glib::wrapper! {
}
impl NotifyWindow {
pub fn new(app: &NotifyApplication, notifier: system_notifier::Client) -> Self {
pub fn new(app: &NotifyApplication, notifier: NtfyHandle) -> Self {
let obj: Self = glib::Object::builder().property("application", app).build();
if let Err(_) = obj.imp().notifier.set(notifier) {
@ -211,24 +212,25 @@ impl NotifyWindow {
fn connect_entry_and_send_btn(&self) {
let imp = self.imp();
let this = self.clone();
let entry = imp.entry.clone();
let publish = move || {
let p = this
.selected_subscription()
.unwrap()
.publish_msg(models::Message {
message: Some(entry.text().as_str().to_string()),
..models::Message::default()
});
entry.error_boundary().spawn(async move {
p.await?;
Ok(())
});
};
let publishc = publish.clone();
imp.entry.connect_activate(move |_| publishc());
imp.send_btn.connect_clicked(move |_| publish());
imp.entry.connect_activate(move |_| this.publish_msg());
let this = self.clone();
imp.send_btn.connect_clicked(move |_| this.publish_msg());
}
fn publish_msg(&self) {
let entry = self.imp().entry.clone();
let this = self.clone();
entry.error_boundary().spawn(async move {
this.selected_subscription()
.unwrap()
.publish_msg(models::OutgoingMessage {
message: Some(entry.text().as_str().to_string()),
..models::OutgoingMessage::default()
})
.await?;
Ok(())
});
}
fn connect_code_btn(&self) {
let imp = self.imp();
@ -260,19 +262,14 @@ impl NotifyWindow {
}
fn add_subscription(&self, sub: models::Subscription) {
let mut req = self.notifier().subscribe_request();
req.get().set_server(sub.server.as_str().into());
req.get().set_topic(sub.topic.as_str().into());
let res = req.send();
let this = self.clone();
self.error_boundary().spawn(async move {
let sub = this.notifier().subscribe(&sub.server, &sub.topic).await?;
let imp = this.imp();
// Subscription::new will use the pipelined client to retrieve info about the subscription
let subscription = Subscription::new(res.pipeline.get_subscription());
let subscription = Subscription::new(sub);
// We want to still check if there were any errors adding the subscription.
res.promise.await?;
imp.subscription_list_model.append(&subscription);
let i = imp.subscription_list_model.n_items() - 1;
@ -283,26 +280,22 @@ impl NotifyWindow {
}
fn unsubscribe(&self) {
let mut req = self.notifier().unsubscribe_request();
let sub = self.selected_subscription().unwrap();
req.get().set_server(sub.server().as_str().into());
req.get().set_topic(sub.topic().as_str().into());
let res = req.send();
let this = self.clone();
self.error_boundary().spawn(async move {
let imp = this.imp();
res.promise.await?;
this.notifier()
.unsubscribe(sub.server().as_str(), sub.topic().as_str())
.await?;
let imp = this.imp();
if let Some(i) = imp.subscription_list_model.find(&sub) {
imp.subscription_list_model.remove(i);
}
Ok(())
});
}
fn notifier(&self) -> &system_notifier::Client {
fn notifier(&self) -> &NtfyHandle {
self.imp().notifier.get().unwrap()
}
fn selected_subscription(&self) -> Option<Subscription> {
@ -328,14 +321,13 @@ impl NotifyWindow {
});
let this = self.clone();
let req = self.notifier().list_subscriptions_request();
let res = req.send();
self.error_boundary().spawn(async move {
let list = res.promise.await?;
let list = list.get()?.get_list()?;
let imp = this.imp();
glib::timeout_future_seconds(1).await;
let list = this.notifier().list_subscriptions().await?;
for sub in list {
imp.subscription_list_model.append(&Subscription::new(sub?));
this.imp()
.subscription_list_model
.append(&Subscription::new(sub));
}
Ok(())
});
@ -371,7 +363,7 @@ impl NotifyWindow {
imp.message_list
.bind_model(Some(&sub.imp().messages), move |obj| {
let b = obj.downcast_ref::<glib::BoxedAnyObject>().unwrap();
let msg = b.borrow::<models::Message>();
let msg = b.borrow::<models::ReceivedMessage>();
MessageRow::new(msg.clone()).upcast()
});
@ -402,7 +394,7 @@ impl NotifyWindow {
{
self.selected_subscription().map(|sub| {
self.error_boundary()
.spawn(sub.flag_all_as_read().map_err(|e| e.into()));
.spawn(async move { sub.flag_all_as_read().await });
});
}
}