From b5209d2df69128011993328b63b2fe5a2fc289cb Mon Sep 17 00:00:00 2001 From: Nick Alexander Date: Thu, 19 Sep 2024 15:52:20 -0700 Subject: [PATCH] WIP: feat: add Windows Push Notification Services (WNS) router. (#775) The relevant settings are like: ``` [wns] credentials = """{ "nightly": { "project_id": "3e776f1c-927b-4918-9924-c7d292ddf732", "credential": "{\\"client_id\\": \\"1f39a2f4-4621-49e7-bde8-fc4c90b9e783\\", \\"client_secret\\": \\"REDACTED\\"}" } }""" ``` Here `project_id` is my personal account's Microsoft Entra "tenant ID" and `client_id` is my personal account's Microsoft Entra "application ID". Different Firefox channels would share a single `project_id` value while having varying `client_id` values. --- .../src/extractors/router_data_input.rs | 2 +- autoendpoint/src/extractors/routers.rs | 7 + autoendpoint/src/routers/mod.rs | 11 + autoendpoint/src/routers/wns/client.rs | 493 ++++++++++++++++++ autoendpoint/src/routers/wns/error.rs | 112 ++++ autoendpoint/src/routers/wns/mod.rs | 6 + autoendpoint/src/routers/wns/router.rs | 414 +++++++++++++++ autoendpoint/src/routers/wns/settings.rs | 65 +++ autoendpoint/src/routes/health.rs | 1 + autoendpoint/src/server.rs | 14 +- autoendpoint/src/settings.rs | 3 + 11 files changed, 1126 insertions(+), 2 deletions(-) create mode 100644 autoendpoint/src/routers/wns/client.rs create mode 100644 autoendpoint/src/routers/wns/error.rs create mode 100644 autoendpoint/src/routers/wns/mod.rs create mode 100644 autoendpoint/src/routers/wns/router.rs create mode 100644 autoendpoint/src/routers/wns/settings.rs diff --git a/autoendpoint/src/extractors/router_data_input.rs b/autoendpoint/src/extractors/router_data_input.rs index dd83975ef..ef93f65f2 100644 --- a/autoendpoint/src/extractors/router_data_input.rs +++ b/autoendpoint/src/extractors/router_data_input.rs @@ -38,7 +38,7 @@ impl FromRequest for RouterDataInput { // Validate the token according to each router's token schema let is_valid = match path_args.router_type { RouterType::WebPush => true, - RouterType::FCM | RouterType::GCM | RouterType::APNS => { + RouterType::FCM | RouterType::GCM | RouterType::APNS | RouterType::WNS => { VALID_TOKEN.is_match(&data.token) } #[cfg(feature = "stub")] diff --git a/autoendpoint/src/extractors/routers.rs b/autoendpoint/src/extractors/routers.rs index c4c012a29..b3a907d3e 100644 --- a/autoendpoint/src/extractors/routers.rs +++ b/autoendpoint/src/extractors/routers.rs @@ -1,6 +1,7 @@ use crate::error::{ApiError, ApiResult}; use crate::routers::apns::router::ApnsRouter; use crate::routers::fcm::router::FcmRouter; +use crate::routers::wns::router::WnsRouter; #[cfg(feature = "stub")] use crate::routers::stub::router::StubRouter; use crate::routers::webpush::WebPushRouter; @@ -21,6 +22,7 @@ pub enum RouterType { FCM, GCM, APNS, + WNS, #[cfg(feature = "stub")] STUB, } @@ -34,6 +36,7 @@ impl FromStr for RouterType { "fcm" => Ok(RouterType::FCM), "gcm" => Ok(RouterType::GCM), "apns" => Ok(RouterType::APNS), + "wns" => Ok(RouterType::WNS), #[cfg(feature = "stub")] "stub" => Ok(RouterType::STUB), _ => Err(()), @@ -48,6 +51,7 @@ impl Display for RouterType { RouterType::FCM => "fcm", RouterType::GCM => "gcm", RouterType::APNS => "apns", + RouterType::WNS => "wns", #[cfg(feature = "stub")] RouterType::STUB => "stub", }) @@ -60,6 +64,7 @@ pub struct Routers { webpush: WebPushRouter, fcm: Arc, apns: Arc, + wns: Arc, #[cfg(feature = "stub")] stub: Arc, } @@ -82,6 +87,7 @@ impl FromRequest for Routers { }, fcm: app_state.fcm_router.clone(), apns: app_state.apns_router.clone(), + wns: app_state.wns_router.clone(), #[cfg(feature = "stub")] stub: app_state.stub_router.clone(), }) @@ -95,6 +101,7 @@ impl Routers { RouterType::WebPush => &self.webpush, RouterType::FCM | RouterType::GCM => self.fcm.as_ref(), RouterType::APNS => self.apns.as_ref(), + RouterType::WNS => self.wns.as_ref(), #[cfg(feature = "stub")] RouterType::STUB => self.stub.as_ref(), } diff --git a/autoendpoint/src/routers/mod.rs b/autoendpoint/src/routers/mod.rs index 766bd8dc9..9c4f3121e 100644 --- a/autoendpoint/src/routers/mod.rs +++ b/autoendpoint/src/routers/mod.rs @@ -5,6 +5,7 @@ use crate::extractors::notification::Notification; use crate::extractors::router_data_input::RouterDataInput; use crate::routers::apns::error::ApnsError; use crate::routers::fcm::error::FcmError; +use crate::routers::wns::error::WnsError; use autopush_common::db::error::DbError; @@ -23,6 +24,7 @@ pub mod fcm; #[cfg(feature = "stub")] pub mod stub; pub mod webpush; +pub mod wns; #[async_trait(?Send)] pub trait Router { @@ -82,6 +84,9 @@ pub enum RouterError { #[error(transparent)] Fcm(#[from] FcmError), + #[error(transparent)] + Wns(#[from] WnsError), + #[cfg(feature = "stub")] #[error(transparent)] Stub(#[from] StubError), @@ -123,6 +128,7 @@ impl RouterError { match self { RouterError::Apns(e) => e.status(), RouterError::Fcm(e) => StatusCode::from_u16(e.status().as_u16()).unwrap_or_default(), + RouterError::Wns(e) => StatusCode::from_u16(e.status().as_u16()).unwrap_or_default(), RouterError::SaveDb(e, _) => e.status(), #[cfg(feature = "stub")] @@ -145,6 +151,7 @@ impl RouterError { match self { RouterError::Apns(e) => e.errno(), RouterError::Fcm(e) => e.errno(), + RouterError::Wns(e) => e.errno(), #[cfg(feature = "stub")] RouterError::Stub(e) => e.errno(), @@ -175,6 +182,7 @@ impl ReportableError for RouterError { match &self { RouterError::Apns(e) => Some(e), RouterError::Fcm(e) => Some(e), + RouterError::Wns(e) => Some(e), RouterError::SaveDb(e, _) => Some(e), _ => None, } @@ -185,6 +193,7 @@ impl ReportableError for RouterError { // apns handle_error emits a metric for ApnsError::Unregistered RouterError::Apns(e) => e.is_sentry_event(), RouterError::Fcm(e) => e.is_sentry_event(), + RouterError::Wns(e) => e.is_sentry_event(), // common handle_error emits metrics for these RouterError::Authentication | RouterError::GCMAuthentication @@ -205,6 +214,7 @@ impl ReportableError for RouterError { match self { RouterError::Apns(e) => e.metric_label(), RouterError::Fcm(e) => e.metric_label(), + RouterError::Wns(e) => e.metric_label(), RouterError::TooMuchData(_) => Some("notification.bridge.error.too_much_data"), _ => None, } @@ -214,6 +224,7 @@ impl ReportableError for RouterError { match &self { RouterError::Apns(e) => e.extras(), RouterError::Fcm(e) => e.extras(), + RouterError::Wns(e) => e.extras(), RouterError::SaveDb(e, sub) => { let mut extras = e.extras(); if let Some(sub) = sub { diff --git a/autoendpoint/src/routers/wns/client.rs b/autoendpoint/src/routers/wns/client.rs new file mode 100644 index 000000000..f50bfc6ff --- /dev/null +++ b/autoendpoint/src/routers/wns/client.rs @@ -0,0 +1,493 @@ +use crate::routers::common::message_size_check; +use crate::routers::wns::error::WnsError; +use crate::routers::wns::settings::{WnsServerCredential, WnsSettings}; +use crate::routers::RouterError; +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::Path; +use std::time::Duration; +use url::Url; + +use url::form_urlencoded; + +const SCOPE: &str = "https://wns.windows.com/.default"; + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct AuthorizedUserSecret { + /// client_id + pub client_id: String, + /// client_secret + pub client_secret: String, +} + +pub struct Authenticator { + pub(crate) tenant_id: String, + pub(crate) secret: AuthorizedUserSecret, + pub(crate) timeout: Duration, + pub(crate) http_client: reqwest::Client, +} + +impl Authenticator { + pub async fn token(&self, _scopes: &[&str]) -> Result { + let req = form_urlencoded::Serializer::new(String::new()) + .extend_pairs(&[ + ("client_id", self.secret.client_id.as_str()), + ("client_secret", self.secret.client_secret.as_str()), + ("scope", SCOPE), + ("grant_type", "client_credentials"), + ]) + .finish(); + + let url = format!("https://login.microsoftonline.com/{}/oauth2/v2.0/token", self.tenant_id); + trace!("Contacting url: {}", url); + + // Make the request + let response = self + .http_client + .post(url) + // .header("Authorization", format!("Bearer {}", token)) + .header("Content-Type", "application/x-www-form-urlencoded") + .body(req) + .timeout(self.timeout) + .send() + .await + .map_err(|e| { + if e.is_timeout() { + RouterError::RequestTimeout + } else { + RouterError::Connect(e) + } + })?; + + // Handle error + let status = response.status(); + let raw_data = response + .bytes() + .await + .map_err(WnsError::DeserializeResponse)?; + + if raw_data.is_empty() { + warn!("Empty WNS auth response [{status}]"); + return Err(WnsError::EmptyResponse(status).into()); + } + + let data: WnsAuthResponse = serde_json::from_slice(&raw_data).map_err(|e| { + let s = String::from_utf8(raw_data.to_vec()).unwrap_or_else(|e| e.to_string()); + warn!("Invalid WNS auth response [{status}] \"{s}\""); + WnsError::InvalidResponse(e, s, status) + })?; + + if status.is_client_error() || status.is_server_error() { + // we only ever send one. + return Err(match (status, data.error) { + (StatusCode::UNAUTHORIZED, _) => RouterError::Authentication, + (StatusCode::NOT_FOUND, _) => RouterError::NotFound, + (_, Some(error)) => RouterError::Upstream { + status: error.status, + message: error.message, + }, + (status, None) => RouterError::Upstream { + status: status.to_string(), + message: "Unknown reason".to_string(), + }, + }); + } + + // XXX Probably not the right error. + trace!("Got token '{:?}'", data.access_token); + data.access_token.ok_or(RouterError::Wns(WnsError::NoOAuthToken)) + } +} + +const OAUTH_SCOPES: &[&str] = &["https://wns.windows.com/.default"]; + +/// Holds application-specific Firebase data and authentication. This client +/// handles sending notifications to Firebase. +pub struct WnsClient { + endpoint: Url, + timeout: Duration, + max_data: usize, + authenticator: Option, + http_client: reqwest::Client, +} + +impl WnsClient { + /// Create an `WnsClient` using the provided credential + pub async fn new( + settings: &WnsSettings, + server_credential: WnsServerCredential, + http: reqwest::Client, + ) -> std::io::Result { + // `map`ping off of `serde_json::from_str` gets hairy and weird, requiring + // async blocks and a number of other specialty items. Doing a very stupid + // json detection does not. WNS keys are serialized JSON constructs. + // These are both set in the settings and come from the `credentials` value. + // let key_data = if server_credential.server_access_token.contains('{') { + + trace!( + "Reading credential for {} from string {}...", + &server_credential.project_id, + &server_credential.server_access_token, + ); + let key_data = serde_json::from_str::(&server_credential.server_access_token)?; + // // Some( + // // ServiceAccountAuthenticator::builder(key_data) + // // .build() + // // .await?, + // // ) + // } else { + // // check to see if this is a path to a file, and read in the credentials. + // warn!( + // "Reading credential for {} from file...", + // &server_credential.project_id + // ); + // let content = std::fs::read_to_string(&server_credential.server_access_token)?; + // serde_json::from_str::(&content) + // } + // }; + + let auth = Authenticator { + tenant_id: server_credential.project_id.clone(), + secret: key_data, + timeout: Duration::from_secs(settings.timeout as u64), + http_client: http.clone(), + }; + + Ok(WnsClient { + endpoint: settings + .base_url + .join(&format!( + "v1/projects/{}/messages:send", + server_credential.project_id + )) + .expect("Project ID is not URL-safe"), + timeout: Duration::from_secs(settings.timeout as u64), + max_data: settings.max_data, + authenticator: Some(auth), + http_client: http, + }) + } + + /// Send the message data to WNS + pub async fn send( + &self, + data: HashMap<&'static str, String>, + routing_token: String, + ttl: usize, + ) -> Result<(), RouterError> { + // Check the payload size. WNS only cares about the `data` field when + // checking size. + let data_json = serde_json::to_string(&data).unwrap(); + message_size_check(data_json.as_bytes(), self.max_data)?; + + let server_access_token = self + .authenticator + .as_ref() + .unwrap() + .token(OAUTH_SCOPES) + .await?; + + // Make the request + let response = self + .http_client + .post(routing_token.clone()) // Routing token is the WNS endpoint. TODO: verify endpoint origin! + .header("Content-Type", "application/octet-stream") + .header("Authorization", format!("Bearer {}", server_access_token)) + .header("X-WNS-Type", "wns/raw") + .json(&data) + .timeout(self.timeout) + .send() + .await + .map_err(|e| { + if e.is_timeout() { + RouterError::RequestTimeout + } else { + RouterError::Connect(e) + } + })?; + + // Handle error + let status = response.status(); + if status.is_client_error() || status.is_server_error() { + let raw_data = response + .bytes() + .await + .map_err(WnsError::DeserializeResponse)?; + if raw_data.is_empty() { + warn!("Empty WNS response [{status}]"); + return Err(WnsError::EmptyResponse(status).into()); + } + let data: WnsResponse = serde_json::from_slice(&raw_data).map_err(|e| { + let s = String::from_utf8(raw_data.to_vec()).unwrap_or_else(|e| e.to_string()); + warn!("Invalid WNS response [{status}] \"{s}\""); + WnsError::InvalidResponse(e, s, status) + })?; + + // we only ever send one. + return Err(match (status, data.error) { + (StatusCode::UNAUTHORIZED, _) => RouterError::Authentication, + (StatusCode::NOT_FOUND, _) => RouterError::NotFound, + (_, Some(error)) => RouterError::Upstream { + status: error.status, + message: error.message, + }, + (status, None) => RouterError::Upstream { + status: status.to_string(), + message: "Unknown reason".to_string(), + }, + }); + } + + Ok(()) + } +} + + +#[derive(Deserialize)] +struct WnsAuthResponse { + access_token: Option, + error: Option, +} + +#[derive(Deserialize)] +struct WnsResponse { + error: Option, +} + +#[derive(Deserialize)] +struct WnsErrorResponse { + status: String, + message: String, +} + +// #[cfg(test)] +// pub mod tests { +// use crate::routers::fcm::client::FcmClient; +// use crate::routers::fcm::settings::{FcmServerCredential, FcmSettings}; +// use crate::routers::RouterError; +// use std::collections::HashMap; +// use url::Url; + +// pub const PROJECT_ID: &str = "yup-test-243420"; +// const ACCESS_TOKEN: &str = "ya29.c.ElouBywiys0LyNaZoLPJcp1Fdi2KjFMxzvYKLXkTdvM-rDfqKlvEq6PiMhGoGHx97t5FAvz3eb_ahdwlBjSStxHtDVQB4ZPRJQ_EOi-iS7PnayahU2S9Jp8S6rk"; +// pub const GCM_PROJECT_ID: &str = "valid_gcm_access_token"; + +// /// Write service data to a temporary file +// pub fn make_service_key(server: &mockito::ServerGuard) -> String { +// // Taken from the yup-oauth2 tests +// serde_json::json!({ +// "type": "service_account", +// "project_id": PROJECT_ID, +// "private_key_id": "26de294916614a5ebdf7a065307ed3ea9941902b", +// "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDemmylrvp1KcOn\n9yTAVVKPpnpYznvBvcAU8Qjwr2fSKylpn7FQI54wCk5VJVom0jHpAmhxDmNiP8yv\nHaqsef+87Oc0n1yZ71/IbeRcHZc2OBB33/LCFqf272kThyJo3qspEqhuAw0e8neg\nLQb4jpm9PsqR8IjOoAtXQSu3j0zkXemMYFy93PWHjVpPEUX16NGfsWH7oxspBHOk\n9JPGJL8VJdbiAoDSDgF0y9RjJY5I52UeHNhMsAkTYs6mIG4kKXt2+T9tAyHw8aho\nwmuytQAfydTflTfTG8abRtliF3nil2taAc5VB07dP1b4dVYy/9r6M8Z0z4XM7aP+\nNdn2TKm3AgMBAAECggEAWi54nqTlXcr2M5l535uRb5Xz0f+Q/pv3ceR2iT+ekXQf\n+mUSShOr9e1u76rKu5iDVNE/a7H3DGopa7ZamzZvp2PYhSacttZV2RbAIZtxU6th\n7JajPAM+t9klGh6wj4jKEcE30B3XVnbHhPJI9TCcUyFZoscuPXt0LLy/z8Uz0v4B\nd5JARwyxDMb53VXwukQ8nNY2jP7WtUig6zwE5lWBPFMbi8GwGkeGZOruAK5sPPwY\nGBAlfofKANI7xKx9UXhRwisB4+/XI1L0Q6xJySv9P+IAhDUI6z6kxR+WkyT/YpG3\nX9gSZJc7qEaxTIuDjtep9GTaoEqiGntjaFBRKoe+VQKBgQDzM1+Ii+REQqrGlUJo\nx7KiVNAIY/zggu866VyziU6h5wjpsoW+2Npv6Dv7nWvsvFodrwe50Y3IzKtquIal\nVd8aa50E72JNImtK/o5Nx6xK0VySjHX6cyKENxHRDnBmNfbALRM+vbD9zMD0lz2q\nmns/RwRGq3/98EqxP+nHgHSr9QKBgQDqUYsFAAfvfT4I75Glc9svRv8IsaemOm07\nW1LCwPnj1MWOhsTxpNF23YmCBupZGZPSBFQobgmHVjQ3AIo6I2ioV6A+G2Xq/JCF\nmzfbvZfqtbbd+nVgF9Jr1Ic5T4thQhAvDHGUN77BpjEqZCQLAnUWJx9x7e2xvuBl\n1A6XDwH/ewKBgQDv4hVyNyIR3nxaYjFd7tQZYHTOQenVffEAd9wzTtVbxuo4sRlR\nNM7JIRXBSvaATQzKSLHjLHqgvJi8LITLIlds1QbNLl4U3UVddJbiy3f7WGTqPFfG\nkLhUF4mgXpCpkMLxrcRU14Bz5vnQiDmQRM4ajS7/kfwue00BZpxuZxst3QKBgQCI\nRI3FhaQXyc0m4zPfdYYVc4NjqfVmfXoC1/REYHey4I1XetbT9Nb/+ow6ew0UbgSC\nUZQjwwJ1m1NYXU8FyovVwsfk9ogJ5YGiwYb1msfbbnv/keVq0c/Ed9+AG9th30qM\nIf93hAfClITpMz2mzXIMRQpLdmQSR4A2l+E4RjkSOwKBgQCB78AyIdIHSkDAnCxz\nupJjhxEhtQ88uoADxRoEga7H/2OFmmPsqfytU4+TWIdal4K+nBCBWRvAX1cU47vH\nJOlSOZI0gRKe0O4bRBQc8GXJn/ubhYSxI02IgkdGrIKpOb5GG10m85ZvqsXw3bKn\nRVHMD0ObF5iORjZUqD0yRitAdg==\n-----END PRIVATE KEY-----\n", +// "client_email": "yup-test-sa-1@yup-test-243420.iam.gserviceaccount.com", +// "client_id": "102851967901799660408", +// "auth_uri": "https://accounts.google.com/o/oauth2/auth", +// "token_uri": server.url() + "/token", +// "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", +// "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/yup-test-sa-1%40yup-test-243420.iam.gserviceaccount.com" +// }).to_string() +// } + +// /// Mock the OAuth token endpoint to provide the access token +// pub async fn mock_token_endpoint(server: &mut mockito::ServerGuard) -> mockito::Mock { +// server +// .mock("POST", "/token") +// .with_body( +// serde_json::json!({ +// "access_token": ACCESS_TOKEN, +// "expires_in": 3600, +// "token_type": "Bearer" +// }) +// .to_string(), +// ) +// .create_async() +// .await +// } + +// /// Start building a mock for the FCM endpoint +// pub fn mock_fcm_endpoint_builder(server: &mut mockito::ServerGuard, id: &str) -> mockito::Mock { +// server.mock("POST", format!("/v1/projects/{id}/messages:send").as_str()) +// } + +// /// Make a FcmClient from the service auth data +// async fn make_client( +// server: &mockito::ServerGuard, +// credential: FcmServerCredential, +// ) -> FcmClient { +// FcmClient::new( +// &FcmSettings { +// base_url: Url::parse(&server.url()).unwrap(), +// server_credentials: serde_json::json!(credential).to_string(), +// ..Default::default() +// }, +// credential, +// reqwest::Client::new(), +// ) +// .await +// .unwrap() +// } + +// /// The FCM client uses the access token and parameters to build the +// /// expected FCM request. +// #[tokio::test] +// async fn sends_correct_fcm_request() { +// let mut server = mockito::Server::new_async().await; + +// let client = make_client( +// &server, +// FcmServerCredential { +// project_id: PROJECT_ID.to_owned(), +// is_gcm: None, +// server_access_token: make_service_key(&server), +// }, +// ) +// .await; +// let _token_mock = mock_token_endpoint(&mut server).await; +// let fcm_mock = mock_fcm_endpoint_builder(&mut server, PROJECT_ID) +// .match_header("Authorization", format!("Bearer {ACCESS_TOKEN}").as_str()) +// .match_header("Content-Type", "application/json") +// .match_body(r#"{"message":{"android":{"data":{"is_test":"true"},"ttl":"42s"},"token":"test-token"}}"#) +// .create(); + +// let mut data = HashMap::new(); +// data.insert("is_test", "true".to_string()); + +// let result = client.send(data, "test-token".to_string(), 42).await; +// assert!(result.is_ok(), "result = {result:?}"); +// fcm_mock.assert(); +// } + +// /// Authorization errors are handled +// #[tokio::test] +// async fn unauthorized() { +// let mut server = mockito::Server::new_async().await; + +// let client = make_client( +// &server, +// FcmServerCredential { +// project_id: PROJECT_ID.to_owned(), +// is_gcm: None, +// server_access_token: make_service_key(&server), +// }, +// ) +// .await; +// let _token_mock = mock_token_endpoint(&mut server).await; +// let _fcm_mock = mock_fcm_endpoint_builder(&mut server, PROJECT_ID) +// .with_status(401) +// .with_body(r#"{"error":{"status":"UNAUTHENTICATED","message":"test-message"}}"#) +// .create_async() +// .await; + +// let result = client +// .send(HashMap::new(), "test-token".to_string(), 42) +// .await; +// assert!(result.is_err()); +// assert!( +// matches!(result.as_ref().unwrap_err(), RouterError::Authentication), +// "result = {result:?}" +// ); +// } + +// /// 404 errors are handled +// #[tokio::test] +// async fn not_found() { +// let mut server = mockito::Server::new_async().await; + +// let client = make_client( +// &server, +// FcmServerCredential { +// project_id: PROJECT_ID.to_owned(), +// is_gcm: None, +// server_access_token: make_service_key(&server), +// }, +// ) +// .await; +// let _token_mock = mock_token_endpoint(&mut server).await; +// let _fcm_mock = mock_fcm_endpoint_builder(&mut server, PROJECT_ID) +// .with_status(404) +// .with_body(r#"{"error":{"status":"NOT_FOUND","message":"test-message"}}"#) +// .create_async() +// .await; + +// let result = client +// .send(HashMap::new(), "test-token".to_string(), 42) +// .await; +// assert!(result.is_err()); +// assert!( +// matches!(result.as_ref().unwrap_err(), RouterError::NotFound), +// "result = {result:?}" +// ); +// } + +// /// Unhandled errors (where an error object is returned) are wrapped and returned +// #[tokio::test] +// async fn other_fcm_error() { +// let mut server = mockito::Server::new_async().await; + +// let client = make_client( +// &server, +// FcmServerCredential { +// project_id: PROJECT_ID.to_owned(), +// is_gcm: Some(false), +// server_access_token: make_service_key(&server), +// }, +// ) +// .await; +// let _token_mock = mock_token_endpoint(&mut server).await; +// let _fcm_mock = mock_fcm_endpoint_builder(&mut server, PROJECT_ID) +// .with_status(400) +// .with_body(r#"{"error":{"status":"TEST_ERROR","message":"test-message"}}"#) +// .create_async() +// .await; + +// let result = client +// .send(HashMap::new(), "test-token".to_string(), 42) +// .await; +// assert!(result.is_err()); +// assert!( +// matches!( +// result.as_ref().unwrap_err(), +// RouterError::Upstream { status, message } +// if status == "TEST_ERROR" && message == "test-message" +// ), +// "result = {result:?}" +// ); +// } + +// /// Unknown errors (where an error object is NOT returned) is handled +// #[tokio::test] +// async fn unknown_fcm_error() { +// let mut server = mockito::Server::new_async().await; + +// let client = make_client( +// &server, +// FcmServerCredential { +// project_id: PROJECT_ID.to_owned(), +// is_gcm: Some(true), +// server_access_token: make_service_key(&server), +// }, +// ) +// .await; +// let _token_mock = mock_token_endpoint(&mut server).await; +// let _fcm_mock = mock_fcm_endpoint_builder(&mut server, PROJECT_ID) +// .with_status(400) +// .with_body("{}") +// .create_async() +// .await; + +// let result = client +// .send(HashMap::new(), "test-token".to_string(), 42) +// .await; +// assert!(result.is_err()); +// assert!( +// matches!( +// result.as_ref().unwrap_err(), +// RouterError::Upstream { status, message } +// if status == "400 Bad Request" && message == "Unknown reason" +// ), +// "result = {result:?}" +// ); +// } +// } diff --git a/autoendpoint/src/routers/wns/error.rs b/autoendpoint/src/routers/wns/error.rs new file mode 100644 index 000000000..504788911 --- /dev/null +++ b/autoendpoint/src/routers/wns/error.rs @@ -0,0 +1,112 @@ +use crate::error::ApiErrorKind; +use crate::routers::RouterError; + +use autopush_common::errors::ReportableError; +use reqwest::StatusCode; + +/// Errors that may occur in the Firebase Cloud Messaging router +#[derive(thiserror::Error, Debug)] +pub enum WnsError { + #[error("Failed to decode the credential settings")] + CredentialDecode(#[from] serde_json::Error), + + #[error("Error while building the OAuth client")] + OAuthClientBuild(#[source] std::io::Error), + + #[error("Error while retrieving an OAuth token")] + OAuthToken(#[from] yup_oauth2::Error), + + #[error("Unable to deserialize WNS response")] + DeserializeResponse(#[source] reqwest::Error), + + #[error("Invalid JSON response from WNS")] + InvalidResponse(#[source] serde_json::Error, String, StatusCode), + + #[error("Empty response from WNS")] + EmptyResponse(StatusCode), + + #[error("No OAuth token was present")] + NoOAuthToken, + + #[error("No registration token found for user")] + NoRegistrationToken, + + #[error("No app ID found for user")] + NoAppId, + + #[error("User has invalid app ID {0}")] + InvalidAppId(String), +} + +impl WnsError { + /// Get the associated HTTP status code + pub fn status(&self) -> StatusCode { + match self { + WnsError::NoRegistrationToken | WnsError::NoAppId | WnsError::InvalidAppId(_) => { + StatusCode::GONE + } + + WnsError::CredentialDecode(_) + | WnsError::OAuthClientBuild(_) + | WnsError::OAuthToken(_) + | WnsError::NoOAuthToken => StatusCode::INTERNAL_SERVER_ERROR, + + WnsError::DeserializeResponse(_) + | WnsError::EmptyResponse(_) + | WnsError::InvalidResponse(_, _, _) => StatusCode::BAD_GATEWAY, + } + } + + /// Get the associated error number + pub fn errno(&self) -> Option { + match self { + WnsError::NoRegistrationToken | WnsError::NoAppId | WnsError::InvalidAppId(_) => { + Some(106) + } + + WnsError::CredentialDecode(_) + | WnsError::OAuthClientBuild(_) + | WnsError::OAuthToken(_) + | WnsError::DeserializeResponse(_) + | WnsError::EmptyResponse(_) + | WnsError::InvalidResponse(_, _, _) + | WnsError::NoOAuthToken => None, + } + } +} + +impl From for ApiErrorKind { + fn from(e: WnsError) -> Self { + ApiErrorKind::Router(RouterError::Wns(e)) + } +} + +impl ReportableError for WnsError { + fn is_sentry_event(&self) -> bool { + matches!(&self, WnsError::InvalidAppId(_) | WnsError::NoAppId) + } + + fn metric_label(&self) -> Option<&'static str> { + match &self { + WnsError::InvalidAppId(_) | WnsError::NoAppId => { + Some("notification.bridge.error.wns.badappid") + } + _ => None, + } + } + + fn extras(&self) -> Vec<(&str, String)> { + match self { + WnsError::InvalidAppId(appid) => { + vec![("app_id", appid.to_string())] + } + WnsError::EmptyResponse(status) => { + vec![("status", status.to_string())] + } + WnsError::InvalidResponse(_, body, status) => { + vec![("status", status.to_string()), ("body", body.to_owned())] + } + _ => vec![], + } + } +} diff --git a/autoendpoint/src/routers/wns/mod.rs b/autoendpoint/src/routers/wns/mod.rs new file mode 100644 index 000000000..a185136d2 --- /dev/null +++ b/autoendpoint/src/routers/wns/mod.rs @@ -0,0 +1,6 @@ +//! A notification router for Windows devices, using Windows Push Notification Services. + +mod client; +pub mod error; +pub mod router; +pub mod settings; diff --git a/autoendpoint/src/routers/wns/router.rs b/autoendpoint/src/routers/wns/router.rs new file mode 100644 index 000000000..9fa4f0adc --- /dev/null +++ b/autoendpoint/src/routers/wns/router.rs @@ -0,0 +1,414 @@ +use autopush_common::db::client::DbClient; + +use crate::error::ApiResult; +use crate::extractors::notification::Notification; +use crate::extractors::router_data_input::RouterDataInput; +use crate::routers::common::{build_message_data, handle_error, incr_success_metrics}; +use crate::routers::wns::client::WnsClient; +use crate::routers::wns::error::WnsError; +use crate::routers::wns::settings::{WnsServerCredential, WnsSettings}; +use crate::routers::{Router, RouterError, RouterResponse}; +use async_trait::async_trait; +use cadence::StatsdClient; +use serde_json::Value; +use std::collections::HashMap; +use std::sync::Arc; +use url::Url; +use uuid::Uuid; + +/// 28 days +const MAX_TTL: usize = 28 * 24 * 60 * 60; + +/// Firebase Cloud Messaging router +pub struct WnsRouter { + settings: WnsSettings, + endpoint_url: Url, + metrics: Arc, + db: Box, + /// A map from application ID to an authenticated WNS client + clients: HashMap, +} + +impl WnsRouter { + /// Create a new `WnsRouter` + pub async fn new( + settings: WnsSettings, + endpoint_url: Url, + http: reqwest::Client, + metrics: Arc, + db: Box, + ) -> Result { + let server_credentials = settings.credentials()?; + let clients = Self::create_clients(&settings, server_credentials, http.clone()) + .await + .map_err(WnsError::OAuthClientBuild)?; + Ok(Self { + settings, + endpoint_url, + metrics, + db, + clients, + }) + } + + /// Create WNS clients for each application + async fn create_clients( + settings: &WnsSettings, + server_credentials: HashMap, + http: reqwest::Client, + ) -> std::io::Result> { + let mut clients = HashMap::new(); + + for (profile, server_credential) in server_credentials { + trace!("Inserting client {}: {:?}", profile, server_credential); + clients.insert( + profile, + WnsClient::new(settings, server_credential, http.clone()).await?, + ); + } + trace!("Initialized {} WNS clients", clients.len()); + Ok(clients) + } + + /// if we have any clients defined, this connection is "active" + pub fn active(&self) -> bool { + !self.clients.is_empty() + } + + /// Do the gauntlet check to get the routing credentials, these are the + /// sender/project ID, and the subscription specific user routing token. + /// WNS stores the values in the top hash as `token` & `app_id`. + /// If any of these error out, it's probably because of a corrupted key. + fn routing_info( + &self, + router_data: &HashMap, + uaid: &Uuid, + ) -> ApiResult<(String, String)> { + // let creds = router_data.get("creds").and_then(Value::as_object); + // // GCM and WNS both should store the client registration_token as token in the router_data. + // // There was some confusion about router table records that may store the client + // // routing token in `creds.auth`, but it's believed that this a duplicate of the + // // server authentication token and can be ignored since we use the value specified + // // in the settings. + let routing_token = match router_data.get("token").and_then(Value::as_str) { + Some(v) => v.to_owned(), + None => { + warn!("No Registration token found for user {}", uaid.to_string()); + return Err(WnsError::NoRegistrationToken.into()); + } + }; + let app_id = match router_data.get("app_id").and_then(Value::as_str) { + Some(v) => v.to_owned(), + None => { + warn!("No App_id found for user {}", uaid.to_string()); + return Err(WnsError::NoAppId.into()); + } + }; + Ok((routing_token, app_id)) + } +} + +#[async_trait(?Send)] +impl Router for WnsRouter { + fn register( + &self, + router_data_input: &RouterDataInput, + app_id: &str, + ) -> Result, RouterError> { + trace!("{} in {:?}", app_id, self.clients.keys()); + if !self.clients.contains_key(app_id) { + return Err(WnsError::InvalidAppId(app_id.to_owned()).into()); + } + + let mut router_data = HashMap::new(); + router_data.insert( + "token".to_string(), + serde_json::to_value(&router_data_input.token).unwrap(), + ); + router_data.insert("app_id".to_string(), serde_json::to_value(app_id).unwrap()); + + // TODO: round trip some profile identifier here? Or maybe + // map the "chid" provided? + + Ok(router_data) + } + + async fn route_notification(&self, notification: &Notification) -> ApiResult { + debug!( + "Sending WNS notification to UAID {}", + notification.subscription.user.uaid + ); + trace!("Notification = {:?}", notification); + + let router_data = notification + .subscription + .user + .router_data + .as_ref() + .ok_or(WnsError::NoRegistrationToken)?; + + let (routing_token, app_id) = + self.routing_info(router_data, ¬ification.subscription.user.uaid)?; + let ttl = MAX_TTL.min(self.settings.min_ttl.max(notification.headers.ttl as usize)); + + // Send the notification to WNS + let client = self + .clients + .get(&app_id) + .ok_or_else(|| WnsError::InvalidAppId(app_id.clone()))?; + + let message_data = build_message_data(notification)?; + let platform = "wnsv1"; + trace!("Sending message to {platform}: [{:?}]", &app_id); + if let Err(e) = client.send(message_data, routing_token, ttl).await { + trace!("Sending message to {platform}: [{:?}] error {:?}", &app_id, e); + return Err(handle_error( + e, + &self.metrics, + self.db.as_ref(), + platform, + &app_id, + notification.subscription.user.uaid, + notification.subscription.vapid.clone(), + ) + .await); + }; + incr_success_metrics(&self.metrics, platform, &app_id, notification); + // Sent successfully, update metrics and make response + trace!("Send request was successful"); + + Ok(RouterResponse::success( + self.endpoint_url + .join(&format!("/m/{}", notification.message_id)) + .expect("Message ID is not URL-safe") + .to_string(), + notification.headers.ttl as usize, + )) + } +} + +// #[cfg(test)] +// mod tests { +// use crate::error::ApiErrorKind; +// use crate::extractors::routers::RouterType; +// use crate::routers::common::tests::{make_notification, CHANNEL_ID}; +// use crate::routers::wns::client::tests::{ +// make_service_key, mock_wns_endpoint_builder, mock_token_endpoint, GCM_PROJECT_ID, +// PROJECT_ID, +// }; +// use crate::routers::wns::error::WnsError; +// use crate::routers::wns::router::WnsRouter; +// use crate::routers::wns::settings::WnsSettings; +// use crate::routers::RouterError; +// use crate::routers::{Router, RouterResponse}; +// use autopush_common::db::client::DbClient; +// use autopush_common::db::mock::MockDbClient; +// use std::sync::Arc; + +// use cadence::StatsdClient; +// use mockall::predicate; +// use std::collections::HashMap; +// use url::Url; + +// const WNS_TOKEN: &str = "test-token"; + +// /// Create a router for testing, using the given service auth file +// async fn make_router( +// server: &mut mockito::ServerGuard, +// wns_credential: String, +// gcm_credential: String, +// db: Box, +// ) -> WnsRouter { +// let url = &server.url(); +// WnsRouter::new( +// WnsSettings { +// base_url: Url::parse(url).unwrap(), +// server_credentials: serde_json::json!({ +// "dev": { +// "project_id": PROJECT_ID, +// "credential": wns_credential +// }, +// GCM_PROJECT_ID: { +// "project_id": GCM_PROJECT_ID, +// "credential": gcm_credential, +// "is_gcm": true, +// } +// }) +// .to_string(), +// ..Default::default() +// }, +// Url::parse("http://localhost:8080/").unwrap(), +// reqwest::Client::new(), +// Arc::new(StatsdClient::from_sink("autopush", cadence::NopMetricSink)), +// db, +// ) +// .await +// .unwrap() +// } + +// /// Create default user router data +// fn default_router_data() -> HashMap { +// let mut map = HashMap::new(); +// map.insert( +// "token".to_string(), +// serde_json::to_value(WNS_TOKEN).unwrap(), +// ); +// map.insert("app_id".to_string(), serde_json::to_value("dev").unwrap()); +// map +// } + +// /// A notification with no data is sent to WNS +// #[tokio::test] +// async fn successful_routing_no_data() { +// let mut server = mockito::Server::new_async().await; + +// let mdb = MockDbClient::new(); +// let db = mdb.into_boxed_arc(); +// let service_key = make_service_key(&server); +// let router = make_router(&mut server, service_key, "whatever".to_string(), db).await; +// assert!(router.active()); +// let _token_mock = mock_token_endpoint(&mut server).await; +// let wns_mock = mock_wns_endpoint_builder(&mut server, PROJECT_ID) +// .match_body( +// serde_json::json!({ +// "message": { +// "android": { +// "data": { +// "chid": CHANNEL_ID +// }, +// "ttl": "60s" +// }, +// "token": "test-token" +// } +// }) +// .to_string() +// .as_str(), +// ) +// .create(); +// let notification = make_notification(default_router_data(), None, RouterType::WNS); + +// let result = router.route_notification(¬ification).await; +// assert!(result.is_ok(), "result = {result:?}"); +// assert_eq!( +// result.unwrap(), +// RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0) +// ); +// wns_mock.assert(); +// } + +// /// A notification with data is sent to WNS +// #[tokio::test] +// async fn successful_routing_with_data() { +// let mut server = mockito::Server::new_async().await; + +// let mdb = MockDbClient::new(); +// let db = mdb.into_boxed_arc(); +// let service_key = make_service_key(&server); +// let router = make_router(&mut server, service_key, "whatever".to_string(), db).await; +// let _token_mock = mock_token_endpoint(&mut server).await; +// let wns_mock = mock_wns_endpoint_builder(&mut server, PROJECT_ID) +// .match_body( +// serde_json::json!({ +// "message": { +// "android": { +// "data": { +// "chid": CHANNEL_ID, +// "body": "test-data", +// "con": "test-encoding", +// "enc": "test-encryption", +// "cryptokey": "test-crypto-key", +// "enckey": "test-encryption-key" +// }, +// "ttl": "60s" +// }, +// "token": "test-token" +// } +// }) +// .to_string() +// .as_str(), +// ) +// .create(); +// let data = "test-data".to_string(); +// let notification = make_notification(default_router_data(), Some(data), RouterType::WNS); + +// let result = router.route_notification(¬ification).await; +// assert!(result.is_ok(), "result = {result:?}"); +// assert_eq!( +// result.unwrap(), +// RouterResponse::success("http://localhost:8080/m/test-message-id".to_string(), 0) +// ); +// wns_mock.assert(); +// } + +// /// If there is no client for the user's app ID, an error is returned and +// /// the WNS request is not sent. +// #[tokio::test] +// async fn missing_client() { +// let mut server = mockito::Server::new_async().await; + +// let db = MockDbClient::new().into_boxed_arc(); +// let service_key = make_service_key(&server); +// let router = make_router(&mut server, service_key, "whatever".to_string(), db).await; +// let _token_mock = mock_token_endpoint(&mut server).await; +// let wns_mock = mock_wns_endpoint_builder(&mut server, PROJECT_ID) +// .expect(0) +// .create_async() +// .await; +// let mut router_data = default_router_data(); +// let app_id = "app_id".to_string(); +// router_data.insert( +// app_id.clone(), +// serde_json::to_value("unknown-app-id").unwrap(), +// ); +// let notification = make_notification(router_data, None, RouterType::WNS); + +// let result = router.route_notification(¬ification).await; +// assert!(result.is_err()); +// assert!( +// matches!( +// &result.as_ref().unwrap_err().kind, +// ApiErrorKind::Router(RouterError::Wns(WnsError::InvalidAppId(_app_id))) +// ), +// "result = {result:?}" +// ); +// wns_mock.assert(); +// } + +// /// If the WNS user no longer exists (404), we drop the user from our database +// #[tokio::test] +// async fn no_wns_user() { +// let mut server = mockito::Server::new_async().await; + +// let notification = make_notification(default_router_data(), None, RouterType::WNS); +// let mut db = MockDbClient::new(); +// db.expect_remove_user() +// .with(predicate::eq(notification.subscription.user.uaid)) +// .times(1) +// .return_once(|_| Ok(())); + +// let service_key = make_service_key(&server); +// let router = make_router( +// &mut server, +// service_key, +// "whatever".to_string(), +// db.into_boxed_arc(), +// ) +// .await; +// let _token_mock = mock_token_endpoint(&mut server).await; +// let _wns_mock = mock_wns_endpoint_builder(&mut server, PROJECT_ID) +// .with_status(404) +// .with_body(r#"{"error":{"status":"NOT_FOUND","message":"test-message"}}"#) +// .create_async() +// .await; + +// let result = router.route_notification(¬ification).await; +// assert!(result.is_err()); +// assert!( +// matches!( +// result.as_ref().unwrap_err().kind, +// ApiErrorKind::Router(RouterError::NotFound) +// ), +// "result = {result:?}" +// ); +// } +// } diff --git a/autoendpoint/src/routers/wns/settings.rs b/autoendpoint/src/routers/wns/settings.rs new file mode 100644 index 000000000..11ba2becb --- /dev/null +++ b/autoendpoint/src/routers/wns/settings.rs @@ -0,0 +1,65 @@ +use std::collections::HashMap; + +use url::Url; + +/// Settings for `WnsRouter` +#[derive(Clone, Debug, serde::Deserialize)] +#[serde(default)] +#[serde(deny_unknown_fields)] +pub struct WnsSettings { + /// The minimum TTL to use for WNS notifications + pub min_ttl: usize, + /// A JSON dict of `WnsCredential`s. This must be a `String` because + /// environment variables cannot encode a `HashMap` + /// WNS is specified as + /// + /// ```json + /// {"_project_id_":{"project_id": "_project_id_", "credential": "_key_"}, ...} + /// ``` + /// For WNS, `credential` keys can be either a serialized JSON string, or the + /// path to the JSON key file. + /// + /// ```json + /// {"bar-project":{"project_id": "bar-project-1234", "credential": "{\"type\": ...}"}, + /// "gorp-project":{"project_id": "gorp-project-abcd", "credential": "keys/gorp-project.json"}, + /// "f00": {"project_id": "f00", "credential": "abcd0123457"}, + /// ... + /// } + /// ``` + #[serde(rename = "credentials")] + pub server_credentials: String, + /// The max size of notification data in bytes + pub max_data: usize, + /// The base URL to use for WNS requests + pub base_url: Url, + /// The number of seconds to wait for WNS requests to complete + pub timeout: usize, +} + +/// Credential information for each application +#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] +pub struct WnsServerCredential { + pub project_id: String, + #[serde(rename = "credential")] + pub server_access_token: String, +} + +impl Default for WnsSettings { + fn default() -> Self { + Self { + min_ttl: 60, + server_credentials: "{}".to_string(), + max_data: 4096, + base_url: Url::parse("https://login.microsoftonline.com").unwrap(), + timeout: 3, + } + } +} + +impl WnsSettings { + /// Read the credentials from the provided JSON + pub fn credentials(&self) -> serde_json::Result> { + trace!("credentials: {}", self.server_credentials); + serde_json::from_str(&self.server_credentials) + } +} diff --git a/autoendpoint/src/routes/health.rs b/autoendpoint/src/routes/health.rs index 249c5dcd6..173bc4d75 100644 --- a/autoendpoint/src/routes/health.rs +++ b/autoendpoint/src/routes/health.rs @@ -21,6 +21,7 @@ pub async fn health_route(state: Data) -> Json { let mut routers: HashMap<&str, bool> = HashMap::new(); routers.insert("apns", state.apns_router.active()); routers.insert("fcm", state.fcm_router.active()); + routers.insert("wns", state.wns_router.active()); let health = json!({ "status": "OK", diff --git a/autoendpoint/src/server.rs b/autoendpoint/src/server.rs index bebbc9a8e..1b11cc22d 100644 --- a/autoendpoint/src/server.rs +++ b/autoendpoint/src/server.rs @@ -21,7 +21,7 @@ use autopush_common::{ use crate::metrics; #[cfg(feature = "stub")] use crate::routers::stub::router::StubRouter; -use crate::routers::{apns::router::ApnsRouter, fcm::router::FcmRouter}; +use crate::routers::{apns::router::ApnsRouter, fcm::router::FcmRouter, wns::router::WnsRouter}; use crate::routes::{ health::{health_route, lb_heartbeat_route, log_check, status_route, version_route}, registration::{ @@ -46,6 +46,7 @@ pub struct AppState { pub http: reqwest::Client, pub fcm_router: Arc, pub apns_router: Arc, + pub wns_router: Arc, #[cfg(feature = "stub")] pub stub_router: Arc, pub reliability: Arc, @@ -110,6 +111,16 @@ impl Server { .await?, ); let reliability = Arc::new(VapidTracker(settings.tracking_keys())); + let wns_router = Arc::new( + WnsRouter::new( + settings.wns.clone(), + endpoint_url.clone(), + http.clone(), + metrics.clone(), + db.clone(), + ) + .await?, + ); #[cfg(feature = "stub")] let stub_router = Arc::new(StubRouter::new(settings.stub.clone())?); let app_state = AppState { @@ -120,6 +131,7 @@ impl Server { http, fcm_router, apns_router, + wns_router, #[cfg(feature = "stub")] stub_router, reliability, diff --git a/autoendpoint/src/settings.rs b/autoendpoint/src/settings.rs index e519ac3ef..82f418d33 100644 --- a/autoendpoint/src/settings.rs +++ b/autoendpoint/src/settings.rs @@ -9,6 +9,7 @@ use url::Url; use crate::headers::vapid::VapidHeaderWithKey; use crate::routers::apns::settings::ApnsSettings; use crate::routers::fcm::settings::FcmSettings; +use crate::routers::wns::settings::WnsSettings; #[cfg(feature = "stub")] use crate::routers::stub::settings::StubSettings; @@ -55,6 +56,7 @@ pub struct Settings { pub fcm: FcmSettings, pub apns: ApnsSettings, + pub wns: WnsSettings, #[cfg(feature = "stub")] pub stub: StubSettings, } @@ -90,6 +92,7 @@ impl Default for Settings { statsd_label: "autoendpoint".to_string(), fcm: FcmSettings::default(), apns: ApnsSettings::default(), + wns: WnsSettings::default(), #[cfg(feature = "stub")] stub: StubSettings::default(), }