1use std::collections::BTreeMap;
8
9use camino::Utf8PathBuf;
10use mas_iana::jose::JsonWebSignatureAlg;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize, de::Error};
13use serde_with::{serde_as, skip_serializing_none};
14use ulid::Ulid;
15use url::Url;
16
17use crate::{ClientSecret, ClientSecretRaw, ConfigurationSection};
18
/// Upstream OAuth 2.0 providers configuration, read from the `upstream_oauth2`
/// configuration section.
///
/// The default value has no providers.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
pub struct UpstreamOAuth2Config {
    /// List of upstream providers.
    ///
    /// Cross-field constraints on each entry are checked by
    /// `ConfigurationSection::validate`.
    pub providers: Vec<Provider>,
}
25
26impl UpstreamOAuth2Config {
27 pub(crate) fn is_default(&self) -> bool {
29 self.providers.is_empty()
30 }
31}
32
33impl ConfigurationSection for UpstreamOAuth2Config {
34 const PATH: Option<&'static str> = Some("upstream_oauth2");
35
36 fn validate(
37 &self,
38 figment: &figment::Figment,
39 ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
40 for (index, provider) in self.providers.iter().enumerate() {
41 let annotate = |mut error: figment::Error| {
42 error.metadata = figment
43 .find_metadata(&format!("{root}.providers", root = Self::PATH.unwrap()))
44 .cloned();
45 error.profile = Some(figment::Profile::Default);
46 error.path = vec![
47 Self::PATH.unwrap().to_owned(),
48 "providers".to_owned(),
49 index.to_string(),
50 ];
51 error
52 };
53
54 if !matches!(provider.discovery_mode, DiscoveryMode::Disabled)
55 && provider.issuer.is_none()
56 {
57 return Err(annotate(figment::Error::custom(
58 "The `issuer` field is required when discovery is enabled",
59 ))
60 .into());
61 }
62
63 match provider.token_endpoint_auth_method {
64 TokenAuthMethod::None
65 | TokenAuthMethod::PrivateKeyJwt
66 | TokenAuthMethod::SignInWithApple => {
67 if provider.client_secret.is_some() {
68 return Err(annotate(figment::Error::custom(
69 "Unexpected field `client_secret` for the selected authentication method",
70 )).into());
71 }
72 }
73 TokenAuthMethod::ClientSecretBasic
74 | TokenAuthMethod::ClientSecretPost
75 | TokenAuthMethod::ClientSecretJwt => {
76 if provider.client_secret.is_none() {
77 return Err(annotate(figment::Error::missing_field("client_secret")).into());
78 }
79 }
80 }
81
82 match provider.token_endpoint_auth_method {
83 TokenAuthMethod::None
84 | TokenAuthMethod::ClientSecretBasic
85 | TokenAuthMethod::ClientSecretPost
86 | TokenAuthMethod::SignInWithApple => {
87 if provider.token_endpoint_auth_signing_alg.is_some() {
88 return Err(annotate(figment::Error::custom(
89 "Unexpected field `token_endpoint_auth_signing_alg` for the selected authentication method",
90 )).into());
91 }
92 }
93 TokenAuthMethod::ClientSecretJwt | TokenAuthMethod::PrivateKeyJwt => {
94 if provider.token_endpoint_auth_signing_alg.is_none() {
95 return Err(annotate(figment::Error::missing_field(
96 "token_endpoint_auth_signing_alg",
97 ))
98 .into());
99 }
100 }
101 }
102
103 match provider.token_endpoint_auth_method {
104 TokenAuthMethod::SignInWithApple => {
105 if provider.sign_in_with_apple.is_none() {
106 return Err(
107 annotate(figment::Error::missing_field("sign_in_with_apple")).into(),
108 );
109 }
110 }
111
112 _ => {
113 if provider.sign_in_with_apple.is_some() {
114 return Err(annotate(figment::Error::custom(
115 "Unexpected field `sign_in_with_apple` for the selected authentication method",
116 )).into());
117 }
118 }
119 }
120
121 if matches!(
122 provider.claims_imports.localpart.on_conflict,
123 OnConflict::Add | OnConflict::Replace | OnConflict::Set
124 ) && !matches!(
125 provider.claims_imports.localpart.action,
126 ImportAction::Force | ImportAction::Require
127 ) {
128 return Err(annotate(figment::Error::custom(
129 "The field `action` must be either `force` or `require` when `on_conflict` is set to `add`, `replace` or `set`",
130 )).with_path("claims_imports.localpart").into());
131 }
132 }
133
134 Ok(())
135 }
136}
137
/// Response mode that can be configured on a provider (`response_mode`).
///
/// Serialized in `snake_case`: `query` or `form_post`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ResponseMode {
    /// `query`: the OAuth 2.0 `query` response mode, where the authorization
    /// response parameters are carried in the callback URL's query string.
    Query,

    /// `form_post`: the OAuth 2.0 Form Post response mode, where the
    /// authorization response parameters are sent to the callback as a POSTed
    /// form.
    ///
    /// <https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html>
    FormPost,
}
152
/// Authentication method configured as a provider's
/// `token_endpoint_auth_method`.
///
/// Serialized in `snake_case`. Configuration validation ties each method to
/// the `client_secret`, `token_endpoint_auth_signing_alg` and
/// `sign_in_with_apple` provider fields it requires or forbids, as noted on
/// each variant.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum TokenAuthMethod {
    /// `none`: neither `client_secret` nor `token_endpoint_auth_signing_alg`
    /// may be set.
    None,

    /// `client_secret_basic`: requires `client_secret`;
    /// `token_endpoint_auth_signing_alg` must not be set.
    ClientSecretBasic,

    /// `client_secret_post`: requires `client_secret`;
    /// `token_endpoint_auth_signing_alg` must not be set.
    ClientSecretPost,

    /// `client_secret_jwt`: requires both `client_secret` and
    /// `token_endpoint_auth_signing_alg`.
    ClientSecretJwt,

    /// `private_key_jwt`: requires `token_endpoint_auth_signing_alg`;
    /// `client_secret` must not be set.
    PrivateKeyJwt,

    /// `sign_in_with_apple`: requires the `sign_in_with_apple` section;
    /// `client_secret` and `token_endpoint_auth_signing_alg` must not be set.
    SignInWithApple,
}
179
/// Action applied to a claim imported from the upstream provider, as used in
/// the `claims_imports` settings.
///
/// Serialized in lowercase; defaults to `ignore`.
// NOTE(review): the runtime effect of each action is implemented outside this
// file; the variant notes only record what this file enforces — confirm the
// semantics with the consumer before relying on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ImportAction {
    /// `ignore` (the default).
    #[default]
    Ignore,

    /// `suggest`.
    Suggest,

    /// `force`. One of the two actions accepted for the localpart when its
    /// `on_conflict` is `add`, `replace` or `set`.
    Force,

    /// `require`. One of the two actions accepted for the localpart when its
    /// `on_conflict` is `add`, `replace` or `set`.
    Require,
}
197
198impl ImportAction {
199 #[allow(clippy::trivially_copy_pass_by_ref)]
200 const fn is_default(&self) -> bool {
201 matches!(self, ImportAction::Ignore)
202 }
203}
204
/// Conflict-handling setting for the localpart import
/// (`claims_imports.localpart.on_conflict`).
///
/// Serialized in lowercase; defaults to `fail`. Any value other than `fail`
/// is only accepted when the localpart `action` is `force` or `require`.
// NOTE(review): what each variant does on a conflict is implemented outside
// this file — confirm before documenting the behaviour further.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum OnConflict {
    /// `fail` (the default).
    #[default]
    Fail,

    /// `add`. Requires the localpart `action` to be `force` or `require`.
    Add,

    /// `replace`. Requires the localpart `action` to be `force` or `require`.
    Replace,

    /// `set`. Requires the localpart `action` to be `force` or `require`.
    Set,
}
224
225impl OnConflict {
226 #[allow(clippy::trivially_copy_pass_by_ref)]
227 const fn is_default(&self) -> bool {
228 matches!(self, OnConflict::Fail)
229 }
230}
231
/// Import settings for the subject (`claims_imports.subject`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct SubjectImportPreference {
    /// Optional template for the subject; omitted from the serialized output
    /// when unset.
    // NOTE(review): the template syntax is defined by the consumer of this
    // config, not in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
241
242impl SubjectImportPreference {
243 const fn is_default(&self) -> bool {
244 self.template.is_none()
245 }
246}
247
/// Import settings for the localpart (`claims_imports.localpart`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct LocalpartImportPreference {
    /// Import action; defaults to `ignore` and is omitted from the serialized
    /// output while it has that value.
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// Optional template for the localpart; omitted from the serialized output
    /// when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,

    /// Conflict handling; defaults to `fail` and is omitted from the serialized
    /// output while it has that value. Values other than `fail` require
    /// `action` to be `force` or `require`.
    #[serde(default, skip_serializing_if = "OnConflict::is_default")]
    pub on_conflict: OnConflict,
}
265
266impl LocalpartImportPreference {
267 const fn is_default(&self) -> bool {
268 self.action.is_default() && self.template.is_none()
269 }
270}
271
/// Import settings for the display name (`claims_imports.displayname`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct DisplaynameImportPreference {
    /// Import action; defaults to `ignore` and is omitted from the serialized
    /// output while it has that value.
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// Optional template for the display name; omitted from the serialized
    /// output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
285
286impl DisplaynameImportPreference {
287 const fn is_default(&self) -> bool {
288 self.action.is_default() && self.template.is_none()
289 }
290}
291
/// Import settings for the email address (`claims_imports.email`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct EmailImportPreference {
    /// Import action; defaults to `ignore` and is omitted from the serialized
    /// output while it has that value.
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// Optional template for the email address; omitted from the serialized
    /// output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
305
306impl EmailImportPreference {
307 const fn is_default(&self) -> bool {
308 self.action.is_default() && self.template.is_none()
309 }
310}
311
/// Import settings for the account name (`claims_imports.account_name`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct AccountNameImportPreference {
    /// Optional template for the account name; omitted from the serialized
    /// output when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
322
323impl AccountNameImportPreference {
324 const fn is_default(&self) -> bool {
325 self.template.is_none()
326 }
327}
328
/// Per-attribute claim import settings for a provider (`claims_imports`).
///
/// Every section defaults to its `Default` value. Each section is omitted
/// from the serialized output when its `is_default` helper returns `true`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct ClaimsImports {
    /// Subject import settings.
    #[serde(default, skip_serializing_if = "SubjectImportPreference::is_default")]
    pub subject: SubjectImportPreference,

    /// Localpart import settings. Validated so that a non-`fail`
    /// `on_conflict` comes with a `force` or `require` action.
    #[serde(default, skip_serializing_if = "LocalpartImportPreference::is_default")]
    pub localpart: LocalpartImportPreference,

    /// Display name import settings.
    #[serde(
        default,
        skip_serializing_if = "DisplaynameImportPreference::is_default"
    )]
    pub displayname: DisplaynameImportPreference,

    /// Email address import settings.
    #[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
    pub email: EmailImportPreference,

    /// Account name import settings.
    #[serde(
        default,
        skip_serializing_if = "AccountNameImportPreference::is_default"
    )]
    pub account_name: AccountNameImportPreference,
}
359
360impl ClaimsImports {
361 const fn is_default(&self) -> bool {
362 self.subject.is_default()
363 && self.localpart.is_default()
364 && self.displayname.is_default()
365 && self.email.is_default()
366 }
367}
368
/// Discovery setting for a provider (`discovery_mode`).
///
/// Serialized in `snake_case`; defaults to `oidc`. Any mode other than
/// `disabled` requires the provider's `issuer` to be set.
// NOTE(review): how each mode performs (or skips) discovery is implemented
// outside this file — confirm before relying on the hints below.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum DiscoveryMode {
    /// `oidc` (the default). Requires `issuer`.
    #[default]
    Oidc,

    /// `insecure`. Requires `issuer`.
    // NOTE(review): presumably discovery with relaxed checks — confirm.
    Insecure,

    /// `disabled`. `issuer` becomes optional.
    // NOTE(review): presumably the endpoint fields on `Provider` must then be
    // configured explicitly — confirm.
    Disabled,
}
383
384impl DiscoveryMode {
385 #[allow(clippy::trivially_copy_pass_by_ref)]
386 const fn is_default(&self) -> bool {
387 matches!(self, DiscoveryMode::Oidc)
388 }
389}
390
/// PKCE setting for a provider (`pkce_method`).
///
/// Serialized in `snake_case`; defaults to `auto`.
// NOTE(review): presumably `auto` enables PKCE only when the provider
// advertises support for it — confirm where this setting is consumed.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum PkceMethod {
    /// `auto` (the default).
    #[default]
    Auto,

    /// `always`.
    Always,

    /// `never`.
    Never,
}
408
409impl PkceMethod {
410 #[allow(clippy::trivially_copy_pass_by_ref)]
411 const fn is_default(&self) -> bool {
412 matches!(self, PkceMethod::Auto)
413 }
414}
415
/// Serde `default` helper: a boolean field using it is `true` when omitted.
const fn default_true() -> bool {
    true
}

/// Serde `skip_serializing_if` helper paired with `default_true`.
///
/// `true` is the default, so a `true` value is left out of the serialized
/// output.
// serde hands `skip_serializing_if` predicates a reference, hence the `&bool`.
#[allow(clippy::trivially_copy_pass_by_ref)]
const fn is_default_true(flag: &bool) -> bool {
    *flag
}
424
425#[allow(clippy::ref_option)]
426fn is_signed_response_alg_default(signed_response_alg: &JsonWebSignatureAlg) -> bool {
427 *signed_response_alg == signed_response_alg_default()
428}
429
430#[allow(clippy::unnecessary_wraps)]
431fn signed_response_alg_default() -> JsonWebSignatureAlg {
432 JsonWebSignatureAlg::Rs256
433}
434
/// Settings for the `sign_in_with_apple` token endpoint authentication method.
///
/// Required on a provider whose `token_endpoint_auth_method` is
/// `sign_in_with_apple`, and rejected for any other method.
// NOTE(review): nothing in this file checks that exactly one of
// `private_key_file` / `private_key` is set; confirm this is enforced where the
// key is loaded.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SignInWithApple {
    /// Path to a file holding the private key; omitted from the serialized
    /// output when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub private_key_file: Option<Utf8PathBuf>,

    /// The private key, given inline; omitted from the serialized output when
    /// unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub private_key: Option<String>,

    /// Team identifier.
    // NOTE(review): presumably the Apple developer team ID — confirm.
    pub team_id: String,

    /// Key identifier.
    // NOTE(review): presumably the ID of the private key above as registered
    // with Apple — confirm.
    pub key_id: String,
}
452
/// Scope used for a provider when none is configured.
const DEFAULT_SCOPE: &str = "openid";

/// Serde `default` helper for a provider's `scope`: returns `"openid"`.
fn default_scope() -> String {
    DEFAULT_SCOPE.to_owned()
}

/// Serde `skip_serializing_if` helper for a provider's `scope`.
///
/// Returns `true` when `scope` equals the default, so it is left out of the
/// serialized output. Compares against the constant directly instead of
/// allocating a fresh default `String` on every call.
fn is_default_scope(scope: &str) -> bool {
    scope == DEFAULT_SCOPE
}
460
/// Reaction to a back-channel logout from the provider
/// (`on_backchannel_logout`).
///
/// Serialized in `snake_case`; defaults to `do_nothing`.
// NOTE(review): the exact effect of each variant is implemented outside this
// file; the hints below are inferred from the names — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum OnBackchannelLogout {
    /// `do_nothing` (the default).
    #[default]
    DoNothing,

    /// `logout_browser_only`.
    // NOTE(review): presumably ends only the browser session tied to the
    // upstream session — confirm.
    LogoutBrowserOnly,

    /// `logout_all`.
    // NOTE(review): presumably also ends sessions derived from the browser
    // session — confirm.
    LogoutAll,
}
476
477impl OnBackchannelLogout {
478 #[allow(clippy::trivially_copy_pass_by_ref)]
479 const fn is_default(&self) -> bool {
480 matches!(self, OnBackchannelLogout::DoNothing)
481 }
482}
483
/// Configuration of a single upstream OAuth 2.0 provider.
///
/// `None`-valued `Option` fields are left out of the serialized output, and
/// defaulted fields are left out while they hold their default.
#[serde_as]
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Provider {
    /// Whether this provider is enabled.
    ///
    /// Defaults to `true`; omitted from the serialized output when `true`.
    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
    pub enabled: bool,

    /// Identifier of the provider, a ULID (e.g. `01GFWR28C4KNE04WG3HKXB7C9R`).
    #[schemars(
        with = "String",
        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
        description = "A ULID as per https://github.com/ulid/spec"
    )]
    pub id: Ulid,

    /// Optional identifier of this provider on the Synapse side.
    // NOTE(review): presumably used to map users of a Synapse identity
    // provider onto this provider — confirm with the consumer.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub synapse_idp_id: Option<String>,

    /// Issuer of the provider.
    ///
    /// Required unless `discovery_mode` is `disabled`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub issuer: Option<String>,

    /// Optional human-readable name of the provider.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub human_name: Option<String>,

    /// Optional brand name of the provider.
    // NOTE(review): presumably used to pick branding (e.g. a logo) in the UI —
    // confirm which values are recognised.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub brand_name: Option<String>,

    /// Client ID registered with the provider.
    pub client_id: String,

    /// Client secret, given either inline as `client_secret` or through a
    /// `client_secret_file`.
    ///
    /// Required for `client_secret_basic`, `client_secret_post` and
    /// `client_secret_jwt`; must be absent for `none`, `private_key_jwt` and
    /// `sign_in_with_apple`.
    #[schemars(with = "ClientSecretRaw")]
    #[serde_as(as = "serde_with::TryFromInto<ClientSecretRaw>")]
    #[serde(flatten)]
    pub client_secret: Option<ClientSecret>,

    /// Method used to authenticate against the provider's token endpoint.
    pub token_endpoint_auth_method: TokenAuthMethod,

    /// Sign in with Apple settings.
    ///
    /// Required when `token_endpoint_auth_method` is `sign_in_with_apple`, and
    /// rejected otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sign_in_with_apple: Option<SignInWithApple>,

    /// Signing algorithm for token endpoint authentication.
    ///
    /// Required for `client_secret_jwt` and `private_key_jwt`, and rejected for
    /// every other method.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,

    /// Expected signing algorithm of the ID token.
    ///
    /// Defaults to `RS256`; omitted from the serialized output while it has
    /// that value.
    #[serde(
        default = "signed_response_alg_default",
        skip_serializing_if = "is_signed_response_alg_default"
    )]
    pub id_token_signed_response_alg: JsonWebSignatureAlg,

    /// Scope to request; defaults to `openid` and is omitted from the
    /// serialized output while it has that value.
    #[serde(default = "default_scope", skip_serializing_if = "is_default_scope")]
    pub scope: String,

    /// Discovery mode; defaults to `oidc`.
    ///
    /// When it is anything other than `disabled`, `issuer` is required.
    #[serde(default, skip_serializing_if = "DiscoveryMode::is_default")]
    pub discovery_mode: DiscoveryMode,

    /// PKCE setting; defaults to `auto`.
    #[serde(default, skip_serializing_if = "PkceMethod::is_default")]
    pub pkce_method: PkceMethod,

    /// Whether to fetch the userinfo endpoint; defaults to `false`.
    #[serde(default)]
    pub fetch_userinfo: bool,

    /// Optional expected signing algorithm of userinfo responses.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,

    /// Optional authorization endpoint URL.
    // NOTE(review): presumably overrides (or, with discovery disabled,
    // supplies) the discovered value — confirm. Same for the three URL fields
    // below.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub authorization_endpoint: Option<Url>,

    /// Optional userinfo endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_endpoint: Option<Url>,

    /// Optional token endpoint URL.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint: Option<Url>,

    /// Optional JWKS URI.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<Url>,

    /// Optional response mode to request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_mode: Option<ResponseMode>,

    /// Claim import settings.
    ///
    /// Omitted from the serialized output when `ClaimsImports::is_default`
    /// returns `true`.
    #[serde(default, skip_serializing_if = "ClaimsImports::is_default")]
    pub claims_imports: ClaimsImports,

    /// Extra key/value parameters; omitted from the serialized output when
    /// empty.
    // NOTE(review): presumably appended to the authorization request — confirm.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub additional_authorization_parameters: BTreeMap<String, String>,

    /// Whether to forward a login hint to the provider; defaults to `false`.
    #[serde(default)]
    pub forward_login_hint: bool,

    /// Reaction to back-channel logouts; defaults to `do_nothing`.
    #[serde(default, skip_serializing_if = "OnBackchannelLogout::is_default")]
    pub on_backchannel_logout: OnBackchannelLogout,
}
669
670impl Provider {
671 pub async fn client_secret(&self) -> anyhow::Result<Option<String>> {
679 Ok(match &self.client_secret {
680 Some(client_secret) => Some(client_secret.value().await?),
681 None => None,
682 })
683 }
684}
685
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };
    use tokio::{runtime::Handle, task};

    use super::*;

    /// Deserializes five providers, one per non-Apple token endpoint
    /// authentication method, and checks the client secret handling.
    ///
    /// It checks three things:
    /// - `client_secret` is parsed as an inline value;
    /// - `client_secret_file` is parsed as a file reference;
    /// - both forms resolve to the same secret through
    ///   `Provider::client_secret`.
    ///
    /// Only deserialization is exercised: `validate` is not called. Its rules
    /// are therefore not applied here, e.g. the missing `issuer` and the
    /// missing `token_endpoint_auth_signing_alg` on the `*_jwt` providers.
    #[tokio::test]
    async fn load_config() {
        // The jail is synchronous, and the body blocks on async work through
        // `Handle::block_on`, which Tokio does not allow from inside an async
        // task. The body therefore runs on a blocking thread.
        task::spawn_blocking(|| {
            Jail::expect_with(|jail| {
                jail.create_file(
                    "config.yaml",
                    r#"
                      upstream_oauth2:
                        providers:
                          - id: 01GFWR28C4KNE04WG3HKXB7C9R
                            client_id: upstream-oauth2
                            token_endpoint_auth_method: none

                          - id: 01GFWR32NCQ12B8Z0J8CPXRRB6
                            client_id: upstream-oauth2
                            client_secret_file: secret
                            token_endpoint_auth_method: client_secret_basic

                          - id: 01GFWR3WHR93Y5HK389H28VHZ9
                            client_id: upstream-oauth2
                            client_secret: c1!3n753c237
                            token_endpoint_auth_method: client_secret_post

                          - id: 01GFWR43R2ZZ8HX9CVBNW9TJWG
                            client_id: upstream-oauth2
                            client_secret_file: secret
                            token_endpoint_auth_method: client_secret_jwt

                          - id: 01GFWR4BNFDCC4QDG6AMSP1VRR
                            client_id: upstream-oauth2
                            token_endpoint_auth_method: private_key_jwt
                            jwks:
                              keys:
                              - kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
                                kty: "RSA"
                                alg: "RS256"
                                use: "sig"
                                e: "AQAB"
                                n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"

                              - kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
                                kty: "RSA"
                                alg: "RS256"
                                use: "sig"
                                e: "AQAB"
                                n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
                  "#,
                )?;
                // NOTE(review): `Provider` has no `jwks` field in this file, so
                // the `jwks` section above appears to be ignored during
                // deserialization — confirm that is intentional.

                // Target of the `client_secret_file: secret` entries above.
                jail.create_file("secret", r"c1!3n753c237")?;

                let config = Figment::new()
                    .merge(Yaml::file("config.yaml"))
                    .extract_inner::<UpstreamOAuth2Config>("upstream_oauth2")?;

                assert_eq!(config.providers.len(), 5);

                assert_eq!(
                    config.providers[1].id,
                    Ulid::from_str("01GFWR32NCQ12B8Z0J8CPXRRB6").unwrap()
                );

                // Inline secrets become `Value`, file secrets become `File`.
                assert!(config.providers[0].client_secret.is_none());
                assert!(matches!(config.providers[1].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(matches!(config.providers[2].client_secret, Some(ClientSecret::Value(ref v)) if v == "c1!3n753c237"));
                assert!(matches!(config.providers[3].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
                assert!(config.providers[4].client_secret.is_none());

                // Both representations resolve to the same secret string.
                Handle::current().block_on(async move {
                    assert_eq!(config.providers[1].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.providers[2].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                    assert_eq!(config.providers[3].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
                });

                Ok(())
            });
        }).await.unwrap();
    }
}