init notify

This commit is contained in:
ranfdev
2023-10-08 15:57:09 +02:00
parent c3de2224a8
commit 52ea57057e
40 changed files with 13845 additions and 552 deletions

1
ntfy-daemon/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

2312
ntfy-daemon/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

31
ntfy-daemon/Cargo.toml Normal file
View File

@ -0,0 +1,31 @@
[package]
name = "ntfy-daemon"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[build-dependencies]
capnpc = "0.17.2"
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
capnp = "0.17.2"
capnp-rpc = "0.17.0"
futures = "0.3.0"
tokio = { version = "1.0.0", features = ["net", "rt", "macros", "parking_lot"]}
tokio-util = { version = "0.7.4", features = ["compat", "io"] }
clap = { version = "4.3.11", features = ["derive"] }
anyhow = "1.0.71"
tokio-stream = { version = "0.1.14", features = ["io-util", "time"] }
rusqlite = "0.29.0"
rand = "0.8.5"
reqwest = { version = "0.11.18", features = ["stream", "rustls-tls"]}
url = "2.4.0"
ashpd = "0.6.0"
generational-arena = "0.2.9"
tracing = "0.1.37"
thiserror = "1.0.49"
regex = "1.9.6"

5
ntfy-daemon/README.md Normal file
View File

@ -0,0 +1,5 @@
# ntfy-daemon
Rust crate providing a capnp-rpc interface to multiple ntfy servers.
Connections to the same server are multiplexed over http2.
Messages are received and stored in a sqlite database for persistance.

6
ntfy-daemon/build.rs Normal file
View File

@ -0,0 +1,6 @@
fn main() {
capnpc::CompilerCommand::new()
.file("src/ntfy.capnp")
.run()
.unwrap();
}

File diff suppressed because it is too large Load Diff

24
ntfy-daemon/src/lib.rs Normal file
View File

@ -0,0 +1,24 @@
pub mod message_repo;
pub mod models;
pub mod ntfy_proxy;
pub mod retry;
pub mod system_client;
pub mod ntfy_capnp {
include!(concat!(env!("OUT_DIR"), "/src/ntfy_capnp.rs"));
}
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("topic {0} must not be empty and must contain only alphanumeric characters and _ (underscore)")]
InvalidTopic(String),
#[error("duplicate message")]
DuplicateMessage,
#[error("can't parse the minimum set of required fields from the message {0}")]
InvalidMinMessage(String, #[source] serde_json::Error),
#[error("can't parse the complete message {0}")]
InvalidMessage(String, #[source] serde_json::Error),
#[error("database error")]
Db(#[from] rusqlite::Error),
#[error("subscription not found while {0}")]
SubscriptionNotFound(String),
}

View File

@ -0,0 +1,29 @@
CREATE TABLE IF NOT EXISTS server (
id INTEGER PRIMARY KEY,
endpoint TEXT NOT NULL UNIQUE,
timeout INTEGER
);
CREATE TABLE IF NOT EXISTS subscription (
topic TEXT,
display_name TEXT,
muted INTEGER NOT NULL DEFAULT 0,
server INTEGER REFERENCES server(id),
archived INTEGER NOT NULL DEFAULT 0,
reserved INTEGER NOT NULL DEFAULT 0,
read_until INTEGER NOT NULL DEFAULT 0,
symbolic_icon TEXT,
PRIMARY KEY (server, topic)
);
CREATE TABLE IF NOT EXISTS message (
server INTEGER,
data TEXT NOT NULL,
topic TEXT AS (data ->> '$.topic'), -- For the FOREIGN KEY constraint
FOREIGN KEY (server, topic) REFERENCES subscription(server, topic) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS message_by_time ON message (data ->> '$.time');
-- I can't put a JSON expression inside a UNIQUE constraint,
-- but I can do it on a UNIQUE INDEX
CREATE UNIQUE INDEX IF NOT EXISTS server_and_message_id ON message (server, data ->> '$.id');

View File

@ -0,0 +1,206 @@
use std::{cell::RefCell, rc::Rc};
use rusqlite::{params, Connection, Result};
use tracing::info;
use crate::models;
use crate::Error;
#[derive(Clone, Debug)]
pub struct Db {
conn: Rc<RefCell<Connection>>,
}
impl Db {
pub fn connect(path: &str) -> Result<Self> {
let mut this = Self {
conn: Rc::new(RefCell::new(Connection::open(path)?)),
};
{
this.conn.borrow().execute_batch(
"PRAGMA foreign_keys = ON;
PRAGMA journal_mode = wal;",
)?;
}
this.migrate()?;
Ok(this)
}
fn migrate(&mut self) -> Result<()> {
{
self.conn
.borrow()
.execute_batch(include_str!("./migrations/00.sql"))?
};
Ok(())
}
fn get_or_insert_server(&mut self, server: &str) -> Result<i64> {
let mut conn = self.conn.borrow_mut();
let tx = conn.transaction()?;
let mut res = tx.query_row(
"SELECT id
FROM server
WHERE endpoint = ?1",
params![server,],
|row| {
let id: i64 = row.get(0)?;
Ok(id)
},
);
if let Err(rusqlite::Error::QueryReturnedNoRows) = res {
tx.execute(
"INSERT INTO server (id, endpoint) VALUES (NULL, ?1)",
params![server,],
)?;
res = Ok(tx.last_insert_rowid());
}
tx.commit()?;
res
}
pub fn insert_message(&mut self, server: &str, json_data: &str) -> Result<(), Error> {
let server_id = self.get_or_insert_server(server)?;
let res = self.conn.borrow().execute(
"INSERT INTO message (server, data) VALUES (?1, ?2)",
params![server_id, json_data],
);
match res {
Err(rusqlite::Error::SqliteFailure(_, Some(text)))
if text.starts_with("UNIQUE constraint failed") =>
{
Err(Error::DuplicateMessage)
}
Err(e) => Err(Error::Db(e)),
Ok(_) => Ok(()),
}
}
pub fn list_messages(
&self,
server: &str,
topic: &str,
since: u64,
) -> Result<Vec<String>, rusqlite::Error> {
let conn = self.conn.borrow();
let mut stmt = conn.prepare(
"
SELECT data
FROM subscription sub
JOIN server s ON sub.server = s.id
JOIN message m ON m.server = sub.server AND m.topic = sub.topic
WHERE s.endpoint = ?1 AND m.topic = ?2 AND m.data ->> 'time' >= ?3
ORDER BY m.data ->> 'time'
",
)?;
let msgs: Result<Vec<String>, _> = stmt
.query_map(params![server, topic, since], |row| Ok(row.get(0)?))?
.collect();
Ok(msgs?)
}
pub fn insert_subscription(&mut self, sub: models::Subscription) -> Result<(), Error> {
let server_id = self.get_or_insert_server(&sub.server)?;
self.conn.borrow().execute(
"INSERT INTO subscription (server, topic, display_name, reserved, muted, archived) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![
server_id,
sub.topic,
sub.display_name,
sub.reserved,
sub.muted,
sub.archived
],
)?;
Ok(())
}
pub fn remove_subscription(&mut self, server: &str, topic: &str) -> Result<(), Error> {
let server_id = self.get_or_insert_server(server)?;
let res = self.conn.borrow().execute(
"DELETE FROM subscription
WHERE server = ?1 AND topic = ?2",
params![server_id, topic],
)?;
if res <= 0 {
return Err(Error::SubscriptionNotFound("removing subscription".into()));
}
Ok(())
}
pub fn list_subscriptions(&mut self) -> Result<Vec<models::Subscription>, Error> {
let conn = self.conn.borrow();
let mut stmt = conn.prepare(
"SELECT server.endpoint, sub.topic, sub.display_name, sub.reserved, sub.muted, sub.archived, sub.symbolic_icon, sub.read_until
FROM subscription sub
JOIN server ON server.id = sub.server
ORDER BY server.endpoint, sub.display_name, sub.topic
",
)?;
let rows = stmt.query_map(params![], |row| {
Ok(models::Subscription {
server: row.get(0)?,
topic: row.get(1)?,
display_name: row.get(2)?,
reserved: row.get(3)?,
muted: row.get(4)?,
archived: row.get(5)?,
symbolic_icon: row.get(6)?,
read_until: row.get(7)?,
})
})?;
let subs: Result<Vec<_>, rusqlite::Error> = rows.collect();
Ok(subs?)
}
pub fn update_subscription(&mut self, sub: models::Subscription) -> Result<(), Error> {
let server_id = self.get_or_insert_server(&sub.server)?;
let res = self.conn.borrow().execute(
"UPDATE subscription
SET display_name = ?1, reserved = ?2, muted = ?3, archived = ?4, read_until = ?5
WHERE server = ?6 AND topic = ?7",
params![
sub.display_name,
sub.reserved,
sub.muted,
sub.archived,
sub.read_until,
server_id,
sub.topic,
],
)?;
if res <= 0 {
return Err(Error::SubscriptionNotFound("updating subscription".into()));
}
info!(info = ?sub, "stored subscription info");
Ok(())
}
pub fn update_read_until(
&mut self,
server: &str,
topic: &str,
value: u64,
) -> Result<(), Error> {
let server_id = self.get_or_insert_server(server).unwrap();
let conn = self.conn.borrow();
let res = conn.execute(
"UPDATE subscription
SET read_until = ?3
WHERE topic = ?2 AND server = ?1
",
params![server_id, topic, value],
)?;
if res <= 0 {
return Err(Error::SubscriptionNotFound("updating read_until".into()));
}
Ok(())
}
pub fn delete_messages(&mut self, server: &str, topic: &str) -> Result<(), Error> {
let server_id = self.get_or_insert_server(server).unwrap();
let conn = self.conn.borrow();
let res = conn.execute(
"DELETE FROM message
WHERE topic = ?2 AND server = ?1
",
params![server_id, topic],
)?;
if res <= 0 {
return Err(Error::SubscriptionNotFound("deleting messages".into()));
}
Ok(())
}
}

263
ntfy-daemon/src/models.rs Normal file
View File

@ -0,0 +1,263 @@
use std::collections::HashMap;
use std::sync::OnceLock;
use regex::Regex;
use serde::{Deserialize, Serialize};
use crate::Error;
static EMOJI_MAP: OnceLock<HashMap<String, String>> = OnceLock::new();
fn emoji_map() -> &'static HashMap<String, String> {
EMOJI_MAP.get_or_init(move || {
serde_json::from_str(include_str!("../data/mailer_emoji_map.json")).unwrap()
})
}
fn validate_topic(topic: &str) -> Result<&str, Error> {
let re = Regex::new(r"^([A-z]|[0-9]|_)+$").unwrap();
if re.is_match(topic) {
Ok(topic)
} else {
Err(Error::InvalidTopic(topic.to_string()))
}
}
#[derive(Default, Clone, Debug, Serialize, Deserialize)]
pub struct Message {
pub topic: String,
pub message: Option<String>,
pub time: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Vec::is_empty")]
pub tags: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i8>,
#[serde(skip_serializing_if = "Vec::is_empty")]
#[serde(default)]
pub attach: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub icon: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub filename: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub delay: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub email: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub call: Option<String>,
#[serde(default)]
#[serde(skip_serializing_if = "Vec::is_empty")]
pub actions: Vec<Action>,
}
impl Message {
fn extend_with_emojis(&self, text: &mut String) {
// Add emojis
for t in &self.tags {
if let Some(emoji) = emoji_map().get(t) {
text.push_str(emoji);
}
}
}
pub fn display_title(&self) -> Option<String> {
self.title.as_ref().map(|title| {
let mut title_text = String::new();
self.extend_with_emojis(&mut title_text);
if !title_text.is_empty() {
title_text.push(' ');
}
title_text.push_str(title);
title_text
})
}
pub fn display_message(&self) -> Option<String> {
self.message.as_ref().map(|message| {
let mut out = String::new();
if self.title.is_none() {
self.extend_with_emojis(&mut out);
}
if !out.is_empty() {
out.push(' ');
}
out.push_str(message);
out
})
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct MinMessage {
pub id: String,
pub topic: String,
pub time: u64,
}
#[derive(Clone, Debug)]
pub struct Subscription {
pub server: String,
pub topic: String,
pub display_name: String,
pub muted: bool,
pub archived: bool,
pub reserved: bool,
pub symbolic_icon: Option<String>,
pub read_until: u64,
}
impl Subscription {
pub fn build_url(server: &str, topic: &str, since: u64) -> anyhow::Result<url::Url> {
let mut url = url::Url::parse(server)?;
url.path_segments_mut()
.map_err(|_| anyhow::anyhow!("url can't be base"))?
.push(&topic)
.push("json");
url.query_pairs_mut()
.append_pair("since", &since.to_string());
Ok(url)
}
pub fn validate(self) -> anyhow::Result<Self> {
validate_topic(&self.topic)?;
Self::build_url(&self.server, &self.topic, 0)?;
Ok(self)
}
pub fn builder(server: String, topic: String) -> SubscriptionBuilder {
SubscriptionBuilder::new(server, topic)
}
}
#[derive(Clone)]
pub struct SubscriptionBuilder {
server: String,
topic: String,
muted: bool,
archived: bool,
reserved: bool,
symbolic_icon: Option<String>,
display_name: String,
}
impl SubscriptionBuilder {
pub fn new(server: String, topic: String) -> Self {
Self {
server,
topic,
muted: false,
archived: false,
reserved: false,
symbolic_icon: None,
display_name: String::new(),
}
}
pub fn muted(mut self, muted: bool) -> Self {
self.muted = muted;
self
}
pub fn archived(mut self, archived: bool) -> Self {
self.archived = archived;
self
}
pub fn reserved(mut self, reserved: bool) -> Self {
self.reserved = reserved;
self
}
pub fn symbolic_icon(mut self, symbolic_icon: Option<String>) -> Self {
self.symbolic_icon = symbolic_icon;
self
}
pub fn display_name(mut self, display_name: String) -> Self {
self.display_name = display_name;
self
}
pub fn build(self) -> anyhow::Result<Subscription> {
let res = Subscription {
server: self.server,
topic: self.topic,
muted: self.muted,
archived: self.archived,
reserved: self.reserved,
symbolic_icon: self.symbolic_icon,
display_name: self.display_name,
read_until: 0,
};
res.validate()
}
}
fn default_method() -> String {
"POST".to_string()
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "action")]
pub enum Action {
#[serde(rename = "view")]
View {
label: String,
url: String,
#[serde(default)]
clear: bool,
},
#[serde(rename = "http")]
Http {
label: String,
url: String,
#[serde(default = "default_method")]
method: String,
#[serde(default)]
headers: HashMap<String, String>,
#[serde(default)]
body: String,
#[serde(default)]
clear: bool,
},
#[serde(rename = "broadcast")]
Broadcast {
label: String,
intent: Option<String>,
#[serde(default)]
extras: HashMap<String, String>,
#[serde(default)]
clear: bool,
},
}
#[derive(Debug, PartialEq, Copy, Clone, Default)]
pub enum Status {
#[default]
Down,
Degraded,
Up,
}
impl From<u8> for Status {
fn from(item: u8) -> Self {
match item {
0 => Status::Down,
1 => Status::Degraded,
2 => Status::Up,
_ => Status::Down,
}
}
}
impl From<Status> for u8 {
fn from(item: Status) -> Self {
match item {
Status::Down => 0,
Status::Degraded => 1,
Status::Up => 2,
}
}
}

View File

@ -0,0 +1,47 @@
@0x9663f4dd604afa35;
enum Status {
down @0;
degraded @1;
up @2;
}
interface WatchHandle {}
interface OutputChannel {
sendMessage @0 (message: Text);
sendStatus @1 (status: Status);
done @2 ();
}
interface NtfyProxy {
getServer @0 () -> (server: Text);
watch @1 (topic: Text, watcher: OutputChannel, since: UInt64) -> (handle: WatchHandle);
publish @2 (message: Text);
}
struct SubscriptionInfo {
server @0 :Text;
topic @1 :Text;
displayName @2 :Text;
muted @3 :Bool;
readUntil @4 :UInt64;
}
interface Subscription {
watch @0 (watcher: OutputChannel, since: UInt64) -> (handle: WatchHandle);
publish @1 (message: Text);
getInfo @2 () -> SubscriptionInfo;
updateInfo @3 (value: SubscriptionInfo);
updateReadUntil @4 (value: UInt64);
clearNotifications @5 ();
}
interface SystemNotifier {
subscribe @0 (server: Text, topic: Text) -> (subscription: Subscription);
unsubscribe @1 (server: Text, topic: Text);
listSubscriptions @2 () -> (list: List(Subscription));
}

View File

@ -0,0 +1,355 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::ops::ControlFlow;
use std::rc::{Rc, Weak};
use std::time::Duration;
use ashpd::desktop::network_monitor::NetworkMonitor;
use capnp::capability::Promise;
use capnp_rpc::pry;
use futures::future::RemoteHandle;
use futures::prelude::*;
use reqwest::header::HeaderValue;
use serde::{Deserialize, Serialize};
use tokio::io::AsyncBufReadExt;
use tokio::sync::mpsc;
use tokio_stream::wrappers::LinesStream;
use tracing::{debug, error, info, instrument, Instrument};
use crate::{
models,
ntfy_capnp::{ntfy_proxy, output_channel, watch_handle, Status},
Error,
};
const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);
const TIMEOUT: std::time::Duration = std::time::Duration::from_secs(240); // 4 minutes
static GLOBAL_MONITOR: tokio::sync::OnceCell<NetworkMonitor> = tokio::sync::OnceCell::const_new();
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
#[serde(rename = "open")]
Open {
id: String,
time: usize,
expires: Option<usize>,
topic: String,
},
#[serde(rename = "message")]
Message {
id: String,
expires: Option<usize>,
#[serde(flatten)]
message: models::Message,
},
#[serde(rename = "keepalive")]
KeepAlive {
id: String,
time: usize,
expires: Option<usize>,
topic: String,
},
}
fn build_client() -> anyhow::Result<reqwest::Client> {
Ok(reqwest::Client::builder()
.connect_timeout(CONNECT_TIMEOUT)
.pool_idle_timeout(TIMEOUT)
// rustls is used because HTTP 2 isn't discovered with native-tls.
// HTTP 2 is required to multiplex multiple requests over a single connection.
// You can check that the app is using a single connection to a server by doing
// ```
// ping ntfy.sh # to get the ip address
// netstat | grep $ip
// ```
.use_rustls_tls()
.build()?)
}
fn topic_request(endpoint: &str, topic: &str, since: u64) -> anyhow::Result<reqwest::Request> {
let url = models::Subscription::build_url(endpoint, topic, since)?;
let mut req = reqwest::Request::new(reqwest::Method::GET, url);
let headers = req.headers_mut();
headers.append(
"Content-Type",
HeaderValue::from_static("application/x-ndjson"),
);
headers.append("Transfer-Encoding", HeaderValue::from_static("chunked"));
Ok(req)
}
async fn response_lines(
res: impl tokio::io::AsyncBufRead,
) -> Result<impl futures::Stream<Item = Result<String, std::io::Error>>, reqwest::Error> {
let lines = LinesStream::new(res.lines());
Ok(lines)
}
pub enum BroadcasterEvent {
Stop,
Restart,
}
struct TopicListener {
endpoint: String,
topic: String,
status: Status,
output_channel: output_channel::Client,
since: u64,
client: reqwest::Client,
}
impl TopicListener {
fn new(
client: reqwest::Client,
endpoint: String,
topic: String,
since: u64,
output_channel: output_channel::Client,
) -> anyhow::Result<mpsc::Sender<ControlFlow<()>>> {
let (tx, mut rx) = mpsc::channel(8);
let mut this = Self {
endpoint,
topic,
status: Status::Down,
output_channel,
since,
client,
};
tokio::task::spawn_local(async move {
loop {
tokio::select! {
_ = this.run_supervised_loop().instrument(tracing::debug_span!("run_supervised_loop")) => {},
res = rx.recv() => match res {
Some(ControlFlow::Continue(_)) => {}
None | Some(ControlFlow::Break(_)) => {
break;
}
}
}
}
});
Ok(tx)
}
fn send_current_status(&mut self) -> impl Future<Output = anyhow::Result<()>> {
let mut req = self.output_channel.send_status_request();
req.get().set_status(self.status);
async move {
req.send().promise.await?;
Ok(())
}
}
#[instrument(skip_all)]
async fn recv_and_forward(&mut self) -> anyhow::Result<()> {
let req = topic_request(&self.endpoint, &self.topic, self.since)?;
let res = self.client.execute(req).await?;
let reader = tokio_util::io::StreamReader::new(
res.bytes_stream()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
);
let stream = response_lines(reader).await?;
tokio::pin!(stream);
self.status = Status::Up;
self.send_current_status().await.unwrap();
info!(topic = %&self.topic, "listening");
while let Some(msg) = stream.next().await {
let msg = msg?;
let min_msg = serde_json::from_str::<models::MinMessage>(&msg)
.map_err(|e| Error::InvalidMinMessage(msg.to_string(), e))?;
self.since = min_msg.time.max(self.since);
let event = serde_json::from_str(&msg)
.map_err(|e| Error::InvalidMessage(msg.to_string(), e))?;
match event {
Event::Message { .. } => {
debug!("message event");
let mut req = self.output_channel.send_message_request();
req.get().set_message(&msg);
req.send().promise.await?;
}
Event::KeepAlive { .. } => {
debug!("keepalive event");
}
Event::Open { .. } => {
debug!("open event");
}
}
}
Ok(())
}
async fn run_supervised_loop(&mut self) {
let retrier = || {
crate::retry::WaitExponentialRandom::builder()
.min(Duration::from_secs(1))
.max(Duration::from_secs(60 * 10))
.build()
};
let mut retry = retrier();
loop {
let start_time = std::time::Instant::now();
if let Err(e) = self.recv_and_forward().await {
let uptime = std::time::Instant::now().duration_since(start_time);
// Reset retry delay to minimum if uptime was decent enough
if uptime > Duration::from_secs(60 * 4) {
retry = retrier();
}
error!(error = ?e);
self.status = Status::Degraded;
self.send_current_status().await.unwrap();
info!(delay = ?retry.next_delay(), "restarting");
retry.wait().await;
} else {
break;
}
}
}
}
struct WatcherImpl {
topic: String,
all_topics: Weak<RefCell<HashMap<String, mpsc::Sender<ControlFlow<()>>>>>,
}
impl Drop for WatcherImpl {
fn drop(&mut self) {
if let Some(m) = self.all_topics.upgrade() {
debug!("Dropped WatcherImpl");
let mut m = m.borrow_mut();
let tx = m[&self.topic].clone();
tokio::task::spawn_local(async move {
tx.send(ControlFlow::Break(())).await.unwrap();
});
m.remove(&self.topic);
}
}
}
impl watch_handle::Server for WatcherImpl {}
// This is a proxy to the actual ntfy server. After a network issue, this will reconnect to the
// server and re-establish all watches.
pub struct NtfyProxyImpl {
endpoint: String,
watching: Rc<RefCell<HashMap<String, mpsc::Sender<ControlFlow<()>>>>>,
client: reqwest::Client,
_monitor_task: RemoteHandle<()>,
}
impl NtfyProxyImpl {
pub fn new(endpoint: String) -> NtfyProxyImpl {
let watching = Rc::new(RefCell::new(
HashMap::<String, mpsc::Sender<ControlFlow<()>>>::new(),
));
let watching_clone = Rc::downgrade(&watching);
let (f, handle) = async move {
let mut prev_available = false;
let monitor = GLOBAL_MONITOR
.get_or_init(|| async move { NetworkMonitor::new().await.unwrap() })
.await;
while let Ok(_) = monitor.receive_changed().await {
let available = monitor.is_available().await.unwrap();
if available && !prev_available {
info!("Refreshed");
if let Some(ws) = watching_clone.upgrade() {
for (_, w) in ws.borrow().iter() {
w.send(ControlFlow::Continue(())).await.unwrap();
}
}
}
prev_available = available;
}
}
.remote_handle();
tokio::task::spawn_local(f);
NtfyProxyImpl {
endpoint,
watching: watching.clone(),
client: build_client().unwrap(),
_monitor_task: handle,
}
}
fn _watch(
&mut self,
topic: String,
watcher: output_channel::Client,
since: u64,
) -> anyhow::Result<watch_handle::Client> {
if !{ self.watching.borrow().contains_key(&topic) } {
self.watching.borrow_mut().insert(
topic.clone(),
TopicListener::new(
self.client.clone(),
self.endpoint.clone(),
topic.clone(),
since,
watcher,
)?,
);
}
Ok(capnp_rpc::new_client(WatcherImpl {
topic,
all_topics: Rc::downgrade(&self.watching),
}))
}
fn _send_msg<'a>(
&'a mut self,
msg: &'a models::Message,
) -> impl Future<Output = Result<(), capnp::Error>> {
let client = reqwest::Client::new();
let json = serde_json::to_string(&msg).unwrap();
let req = client.post(&self.endpoint).body(json.clone());
async move {
info!(json = ?json, "sending message");
let res = req.send().await;
match res {
Err(e) => Err(capnp::Error::failed(e.to_string())),
Ok(res) => {
res.error_for_status()
.map_err(|e| capnp::Error::failed(e.to_string()))?;
Ok(())
}
}
}
}
}
impl ntfy_proxy::Server for NtfyProxyImpl {
fn publish(
&mut self,
params: ntfy_proxy::PublishParams,
_results: ntfy_proxy::PublishResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let params = params.get();
let message = pry!(pry!(params).get_message());
let message: models::Message = serde_json::from_str(message).unwrap();
let res = self._send_msg(&message);
Promise::from_future(async move {
res.await.map_err(|e| capnp::Error::failed(e.to_string()))?;
Ok(())
})
}
fn watch(
&mut self,
params: ntfy_proxy::WatchParams,
mut results: ntfy_proxy::WatchResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let topic = pry!(pry!(params.get()).get_topic());
let watcher = pry!(pry!(params.get()).get_watcher());
let since = pry!(params.get()).get_since();
let handle = pry!(self
._watch(topic.to_owned(), watcher, since.to_owned())
.map_err(|e| capnp::Error::failed(e.to_string())));
results.get().set_handle(handle);
Promise::ok(())
}
}

56
ntfy-daemon/src/retry.rs Normal file
View File

@ -0,0 +1,56 @@
use std::cmp;
use std::time::Duration;
use rand::prelude::*;
use tokio::time::sleep;
pub struct WaitExponentialRandom {
min: Duration,
max: Duration,
i: u64,
multiplier: u64,
}
pub struct WaitExponentialRandomBuilder {
inner: WaitExponentialRandom,
}
impl WaitExponentialRandomBuilder {
pub fn build(self) -> WaitExponentialRandom {
self.inner
}
pub fn min(mut self, duration: Duration) -> Self {
self.inner.min = duration;
self
}
pub fn max(mut self, duration: Duration) -> Self {
self.inner.max = duration;
self
}
pub fn multiplier(mut self, mul: u64) -> Self {
self.inner.multiplier = mul;
self
}
}
impl WaitExponentialRandom {
pub fn builder() -> WaitExponentialRandomBuilder {
WaitExponentialRandomBuilder {
inner: WaitExponentialRandom {
min: Duration::ZERO,
max: Duration::MAX,
i: 0,
multiplier: 1,
},
}
}
pub fn next_delay(&self) -> Duration {
let secs = (1 << self.i) * self.multiplier;
let secs = rand::thread_rng().gen_range(self.min.as_secs()..=secs);
let dur = Duration::from_secs(secs);
cmp::min(cmp::max(dur, self.min), self.max)
}
pub async fn wait(&mut self) {
sleep(self.next_delay()).await;
self.i += 1;
}
}

View File

@ -0,0 +1,504 @@
use std::cell::OnceCell;
use std::cell::{Cell, RefCell};
use std::rc::{Rc, Weak};
use std::time::Duration;
use std::{collections::HashMap, hash::Hash};
use ashpd::desktop::notification::{Notification, NotificationProxy};
use capnp::capability::Promise;
use capnp_rpc::{pry, rpc_twoparty_capnp, twoparty, RpcSystem};
use futures::future::join_all;
use futures::prelude::*;
use generational_arena::Arena;
use tokio::net::UnixListener;
use tracing::{error, info, warn};
use crate::models::Message;
use crate::Error;
use crate::{
message_repo::Db,
models::{self, MinMessage},
ntfy_capnp::ntfy_proxy,
ntfy_capnp::{output_channel, subscription, system_notifier, watch_handle, Status},
ntfy_proxy::NtfyProxyImpl,
};
const MESSAGE_THROTTLE: Duration = Duration::from_millis(150);
impl From<Error> for capnp::Error {
fn from(value: Error) -> Self {
capnp::Error::failed(format!("{:?}", value))
}
}
pub struct NotifyForwarder {
model: Rc<RefCell<models::Subscription>>,
db: Db,
watching: Weak<RefCell<Arena<output_channel::Client>>>,
status: Rc<Cell<Status>>,
}
impl NotifyForwarder {
pub fn new(
model: Rc<RefCell<models::Subscription>>,
db: Db,
watching: Weak<RefCell<Arena<output_channel::Client>>>,
status: Rc<Cell<Status>>,
) -> Self {
Self {
model,
db,
watching,
status,
}
}
}
impl output_channel::Server for NotifyForwarder {
// Stores the message, sends a system notification, forwards the message to watching clients
fn send_message(
&mut self,
params: output_channel::SendMessageParams,
_results: output_channel::SendMessageResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let request = pry!(params.get());
let message = pry!(request.get_message());
// Store in database
let already_stored: bool = {
// If this fails parsing, the message is not valid at all.
// The server is probably misbehaving.
let min_message: MinMessage = pry!(serde_json::from_str(&message)
.map_err(|e| Error::InvalidMinMessage(message.to_string(), e)));
let model = self.model.borrow();
match self.db.insert_message(&model.server, message) {
Err(Error::DuplicateMessage) => {
warn!(min_message = ?min_message, "Received duplicate message");
true
}
Err(e) => {
error!(min_message = ?min_message, error = ?e, "Can't store the message");
false
}
_ => false,
}
};
if !already_stored {
// Show notification
// Our priority is to show notifications. If anything fails, panic.
if !{ self.model.borrow().muted } {
let msg: Message = pry!(serde_json::from_str(&message)
.map_err(|e| Error::InvalidMessage(message.to_string(), e)));
tokio::task::spawn_local(async move {
let proxy = match NotificationProxy::new().await {
Ok(p) => p,
Err(e) => {
panic!("Can't show notification: {:?}", e);
}
};
let title = msg.display_title();
let title = title.as_ref().map(|x| x.as_str()).unwrap_or(&msg.topic);
let n = Notification::new(&title).body(
msg.display_message()
.as_ref()
.map(|x| x.as_str())
.unwrap_or(""),
);
let notification_id = "com.ranfdev.Notify";
info!("Showing notification");
proxy.add_notification(notification_id, n).await.unwrap();
});
}
// Forward
if let Some(watching) = self.watching.upgrade() {
let watching = watching.borrow();
let futs = watching.iter().map(|(_id, w)| {
let mut req = w.send_message_request();
req.get().set_message(message);
async move {
if let Err(e) = req.send().promise.await {
error!(error = ?e, "Error forwarding");
}
}
});
tokio::task::spawn_local(join_all(futs));
}
}
Promise::from_future(async move {
// some backpressure
tokio::time::sleep(MESSAGE_THROTTLE).await;
Ok(())
})
}
fn send_status(
&mut self,
params: output_channel::SendStatusParams,
_: output_channel::SendStatusResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let status = pry!(pry!(params.get()).get_status());
if let Some(watching) = self.watching.upgrade() {
for (_, w) in watching.borrow().iter() {
let mut req = w.send_status_request();
req.get().set_status(status);
tokio::task::spawn_local(async move {
req.send().promise.await.unwrap();
});
}
}
self.status.set(status);
Promise::ok(())
}
}
struct WatcherImpl {
id: generational_arena::Index,
watchers: Weak<RefCell<Arena<output_channel::Client>>>,
}
impl watch_handle::Server for WatcherImpl {}
impl Drop for WatcherImpl {
fn drop(&mut self) {
if let Some(w) = self.watchers.upgrade() {
w.borrow_mut().remove(self.id);
}
}
}
pub struct SubscriptionImpl {
model: Rc<RefCell<models::Subscription>>,
db: Db,
server: ntfy_proxy::Client,
server_watch_handle: OnceCell<watch_handle::Client>,
watchers: Rc<RefCell<Arena<output_channel::Client>>>,
status: Rc<Cell<Status>>,
}
impl SubscriptionImpl {
fn new(model: models::Subscription, server: ntfy_proxy::Client, db: Db) -> Self {
Self {
model: Rc::new(RefCell::new(model)),
server,
db,
watchers: Default::default(),
server_watch_handle: Default::default(),
status: Rc::new(Cell::new(Status::Down)),
}
}
fn output_channel(&self) -> NotifyForwarder {
NotifyForwarder::new(
self.model.clone(),
self.db.clone(),
Rc::downgrade(&self.watchers),
self.status.clone(),
)
}
}
impl subscription::Server for SubscriptionImpl {
fn watch(
&mut self,
params: subscription::WatchParams,
mut results: subscription::WatchResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let watcher = pry!(pry!(params.get()).get_watcher());
let since = pry!(params.get()).get_since();
// Send old messages
let msgs = {
let model = self.model.borrow();
pry!(self
.db
.list_messages(&model.server, &model.topic, since)
.map_err(Error::Db))
};
let futs = msgs.into_iter().map(move |msg| {
let mut req = watcher.send_message_request();
req.get().set_message(&msg);
req.send().promise
});
let watcher = pry!(pry!(params.get()).get_watcher());
let mut req = watcher.send_status_request();
req.get().set_status(self.status.get());
let id = { self.watchers.borrow_mut().insert(watcher) };
results.get().set_handle(capnp_rpc::new_client(WatcherImpl {
id,
watchers: Rc::downgrade(&self.watchers),
}));
Promise::from_future(async move {
futures::future::try_join_all(futs).await?;
req.send().promise.await?;
Ok(())
})
}
fn publish(
&mut self,
params: subscription::PublishParams,
_results: subscription::PublishResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let msg = pry!(pry!(params.get()).get_message());
let mut req = self.server.publish_request();
req.get().set_message(msg);
Promise::from_future(async move {
req.send().promise.await?;
Ok(())
})
}
fn get_info(
&mut self,
_: subscription::GetInfoParams,
mut results: subscription::GetInfoResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let mut res = results.get();
let model = self.model.borrow();
res.set_server(&model.server);
res.set_display_name(&model.display_name);
res.set_topic(&model.topic);
res.set_muted(model.muted);
res.set_read_until(model.read_until);
Promise::ok(())
}
fn update_info(
&mut self,
params: subscription::UpdateInfoParams,
_results: subscription::UpdateInfoResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let info = pry!(pry!(params.get()).get_value());
let mut model = self.model.borrow_mut();
model.display_name = pry!(info.get_display_name()).to_string();
model.muted = info.get_muted();
model.read_until = info.get_read_until();
pry!(self.db.update_subscription(model.clone()));
Promise::ok(())
}
fn clear_notifications(
&mut self,
_params: subscription::ClearNotificationsParams,
_results: subscription::ClearNotificationsResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let model = self.model.borrow_mut();
pry!(self.db.delete_messages(&model.server, &model.topic));
Promise::ok(())
}
fn update_read_until(
&mut self,
params: subscription::UpdateReadUntilParams,
_: subscription::UpdateReadUntilResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let value = pry!(params.get()).get_value();
let mut model = self.model.borrow_mut();
pry!(self
.db
.update_read_until(&model.server, &model.topic, value));
model.read_until = value;
Promise::ok(())
}
}
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct WatchKey {
server: String,
topic: String,
}
pub struct SystemNotifier {
servers: HashMap<String, ntfy_proxy::Client>,
watching: Rc<RefCell<HashMap<WatchKey, subscription::Client>>>,
db: Db,
}
impl SystemNotifier {
pub fn new(dbpath: &str) -> Self {
Self {
servers: HashMap::new(),
watching: Rc::new(RefCell::new(HashMap::new())),
db: Db::connect(dbpath).unwrap(),
}
}
fn watch(&mut self, sub: models::Subscription) -> Promise<subscription::Client, capnp::Error> {
let ntfy = self
.servers
.entry(sub.server.to_owned())
.or_insert_with(|| capnp_rpc::new_client(NtfyProxyImpl::new(sub.server.to_owned())));
let subscription = SubscriptionImpl::new(sub.clone(), ntfy.clone(), self.db.clone());
let mut req = ntfy.watch_request();
req.get().set_topic(&sub.topic);
req.get()
.set_watcher(capnp_rpc::new_client(subscription.output_channel()));
let res = req.send();
let handle = res.pipeline.get_handle();
subscription
.server_watch_handle
.set(handle)
.map_err(|_| "already set")
.unwrap();
let watching = self.watching.clone();
let subc: subscription::Client = capnp_rpc::new_client(subscription);
Promise::from_future(async move {
res.promise
.await
.map_err(|e| capnp::Error::failed(e.to_string()))?;
watching.borrow_mut().insert(
WatchKey {
server: sub.server.to_owned(),
topic: sub.topic.to_owned(),
},
subc.clone(),
);
Ok(subc)
})
}
pub fn watch_subscribed(&mut self) -> Promise<(), capnp::Error> {
let f: Vec<_> = pry!(self.db.list_subscriptions())
.into_iter()
.map(|m| self.watch(m.clone()))
.collect();
Promise::from_future(async move {
join_all(f.into_iter().map(|x| async move {
if let Err(e) = x.await {
error!(error = ?e, "Can't rewatch subscribed topic");
}
}))
.await;
Ok(())
})
}
}
impl system_notifier::Server for SystemNotifier {
fn subscribe(
&mut self,
params: system_notifier::SubscribeParams,
mut results: system_notifier::SubscribeResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let topic = pry!(pry!(params.get()).get_topic());
let server: &str = pry!(pry!(params.get()).get_server());
let server = if server.is_empty() {
"https://ntfy.sh"
} else {
""
};
let subscription = pry!(
models::Subscription::builder(server.to_owned(), topic.to_owned())
.build()
.map_err(|e| capnp::Error::failed(e.to_string()))
);
let sub: Promise<subscription::Client, capnp::Error> = self.watch(subscription.clone());
let mut db = self.db.clone();
Promise::from_future(async move {
results.get().set_subscription(sub.await?);
db.insert_subscription(subscription).map_err(|e| {
capnp::Error::failed(format!("could not insert subscription: {}", e))
})?;
Ok(())
})
}
fn unsubscribe(
&mut self,
params: system_notifier::UnsubscribeParams,
_results: system_notifier::UnsubscribeResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let topic = pry!(pry!(params.get()).get_topic());
let server = pry!(pry!(params.get()).get_server());
{
self.watching.borrow_mut().remove(&WatchKey {
server: server.to_string(),
topic: topic.to_string(),
});
pry!(self
.db
.remove_subscription(&server, &topic)
.map_err(|e| capnp::Error::failed(e.to_string())));
info!(server, topic, "Unsubscribed");
}
Promise::ok(())
}
fn list_subscriptions(
&mut self,
_: system_notifier::ListSubscriptionsParams,
mut results: system_notifier::ListSubscriptionsResults,
) -> capnp::capability::Promise<(), capnp::Error> {
let req = results.get();
let values = self.watching.borrow().values().cloned().collect::<Vec<_>>();
let mut list = req.init_list(values.len() as u32);
for (i, v) in values.iter().enumerate() {
use capnp::capability::FromClientHook;
list.set(i as u32, v.clone().clone().into_client_hook());
}
Promise::ok(())
}
}
pub fn start(socket_path: std::path::PathBuf, dbpath: &str) -> anyhow::Result<()> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let listener = rt.block_on(async move {
let _ = std::fs::remove_file(&socket_path);
UnixListener::bind(&socket_path).unwrap()
});
let dbpath = dbpath.to_owned();
let f = move || {
let local = tokio::task::LocalSet::new();
let mut system_notifier = SystemNotifier::new(&dbpath);
local.spawn_local(async move {
system_notifier.watch_subscribed().await.unwrap();
let system_client: system_notifier::Client = capnp_rpc::new_client(system_notifier);
loop {
match listener.accept().await {
Ok((stream, _addr)) => {
info!("client connected");
let (reader, writer) =
tokio_util::compat::TokioAsyncReadCompatExt::compat(stream).split();
let network = twoparty::VatNetwork::new(
reader,
writer,
rpc_twoparty_capnp::Side::Server,
Default::default(),
);
let rpc_system =
RpcSystem::new(Box::new(network), Some(system_client.clone().client));
tokio::task::spawn_local(rpc_system);
}
Err(e) => {
error!(error=%e);
}
}
}
});
rt.block_on(local);
};
std::thread::spawn(move || {
f();
});
Ok(())
}