mas_matrix/
lib.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7mod mock;
8mod readonly;
9
10use std::{collections::HashSet, sync::Arc};
11
12use ruma_common::UserId;
13
14pub use self::{
15    mock::HomeserverConnection as MockHomeserverConnection, readonly::ReadOnlyHomeserverConnection,
16};
17
/// The state of a user on the homeserver, as returned by
/// [`HomeserverConnection::query_user`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MatrixUser {
    /// The display name of the user, if one is set.
    pub displayname: Option<String>,

    /// The avatar URL of the user, if one is set.
    pub avatar_url: Option<String>,

    /// Whether the user is deactivated.
    pub deactivated: bool,
}
24
/// What should happen to an optional field of a user when provisioning it.
#[derive(Debug, Default)]
enum FieldAction<T> {
    /// Leave the field as it is.
    #[default]
    DoNothing,

    /// Set the field to the given value.
    Set(T),

    /// Clear the field.
    Unset,
}

/// A request to provision a user, passed to
/// [`HomeserverConnection::provision_user`].
///
/// Created with [`ProvisionRequest::new`], then refined with the
/// `set_*`/`unset_*` builder methods. Fields which were neither set nor unset
/// are left untouched: the matching `on_*` callback is simply not called.
#[derive(Debug)]
pub struct ProvisionRequest {
    /// The localpart of the user to provision.
    localpart: String,

    /// The `sub` of the user, aka the internal ID.
    sub: String,

    /// Whether the user is locked.
    locked: bool,

    /// What to do with the displayname of the user.
    displayname: FieldAction<String>,

    /// What to do with the avatar URL of the user.
    avatar_url: FieldAction<String>,

    /// What to do with the emails of the user.
    emails: FieldAction<Vec<String>>,
}

impl ProvisionRequest {
    /// Create a new [`ProvisionRequest`].
    ///
    /// The displayname, avatar URL and emails all start out as "do nothing".
    ///
    /// # Parameters
    ///
    /// * `localpart` - The localpart of the user to provision.
    /// * `sub` - The `sub` of the user, aka the internal ID.
    /// * `locked` - Whether the user is locked.
    #[must_use]
    pub fn new(localpart: impl Into<String>, sub: impl Into<String>, locked: bool) -> Self {
        Self {
            localpart: localpart.into(),
            sub: sub.into(),
            locked,
            displayname: FieldAction::DoNothing,
            avatar_url: FieldAction::DoNothing,
            emails: FieldAction::DoNothing,
        }
    }

    /// Get the `sub` of the user to provision, aka the internal ID.
    #[must_use]
    pub fn sub(&self) -> &str {
        &self.sub
    }

    /// Get the localpart of the user to provision.
    #[must_use]
    pub fn localpart(&self) -> &str {
        &self.localpart
    }

    /// Get the locked flag of the user to provision.
    #[must_use]
    pub fn locked(&self) -> bool {
        self.locked
    }

    /// Ask to set the displayname of the user.
    ///
    /// # Parameters
    ///
    /// * `displayname` - The displayname to set.
    #[must_use]
    pub fn set_displayname(mut self, displayname: String) -> Self {
        self.displayname = FieldAction::Set(displayname);
        self
    }

    /// Ask to unset the displayname of the user.
    #[must_use]
    pub fn unset_displayname(mut self) -> Self {
        self.displayname = FieldAction::Unset;
        self
    }

    /// Call the given callback if the displayname should be set or unset.
    ///
    /// The callback receives `Some(displayname)` when the displayname should
    /// be set, and `None` when it should be unset. It is not called at all
    /// when nothing should be done. Returns `self` to allow chaining.
    ///
    /// # Parameters
    ///
    /// * `callback` - The callback to call.
    pub fn on_displayname<F>(&self, callback: F) -> &Self
    where
        F: FnOnce(Option<&str>),
    {
        match &self.displayname {
            FieldAction::Unset => callback(None),
            FieldAction::Set(displayname) => callback(Some(displayname.as_str())),
            FieldAction::DoNothing => {}
        }

        self
    }

    /// Ask to set the avatar URL of the user.
    ///
    /// # Parameters
    ///
    /// * `avatar_url` - The avatar URL to set.
    #[must_use]
    pub fn set_avatar_url(mut self, avatar_url: String) -> Self {
        self.avatar_url = FieldAction::Set(avatar_url);
        self
    }

    /// Ask to unset the avatar URL of the user.
    #[must_use]
    pub fn unset_avatar_url(mut self) -> Self {
        self.avatar_url = FieldAction::Unset;
        self
    }

    /// Call the given callback if the avatar URL should be set or unset.
    ///
    /// The callback receives `Some(avatar_url)` when the avatar URL should be
    /// set, and `None` when it should be unset. It is not called at all when
    /// nothing should be done. Returns `self` to allow chaining.
    ///
    /// # Parameters
    ///
    /// * `callback` - The callback to call.
    pub fn on_avatar_url<F>(&self, callback: F) -> &Self
    where
        F: FnOnce(Option<&str>),
    {
        match &self.avatar_url {
            FieldAction::Unset => callback(None),
            FieldAction::Set(avatar_url) => callback(Some(avatar_url.as_str())),
            FieldAction::DoNothing => {}
        }

        self
    }

    /// Ask to set the emails of the user.
    ///
    /// # Parameters
    ///
    /// * `emails` - The list of emails to set.
    #[must_use]
    pub fn set_emails(mut self, emails: Vec<String>) -> Self {
        self.emails = FieldAction::Set(emails);
        self
    }

    /// Ask to unset the emails of the user.
    #[must_use]
    pub fn unset_emails(mut self) -> Self {
        self.emails = FieldAction::Unset;
        self
    }

    /// Call the given callback if the emails should be set or unset.
    ///
    /// The callback receives `Some(emails)` when the emails should be set,
    /// and `None` when they should be unset. It is not called at all when
    /// nothing should be done. Returns `self` to allow chaining.
    ///
    /// # Parameters
    ///
    /// * `callback` - The callback to call.
    pub fn on_emails<F>(&self, callback: F) -> &Self
    where
        F: FnOnce(Option<&[String]>),
    {
        match &self.emails {
            FieldAction::Unset => callback(None),
            FieldAction::Set(emails) => callback(Some(emails.as_slice())),
            FieldAction::DoNothing => {}
        }

        self
    }
}
187
188#[async_trait::async_trait]
189pub trait HomeserverConnection: Send + Sync {
190    /// Get the homeserver URL.
191    fn homeserver(&self) -> &str;
192
193    /// Get the Matrix ID of the user with the given localpart.
194    ///
195    /// # Parameters
196    ///
197    /// * `localpart` - The localpart of the user.
198    fn mxid(&self, localpart: &str) -> String {
199        format!("@{}:{}", localpart, self.homeserver())
200    }
201
202    /// Get the localpart of a Matrix ID if it has the right server name
203    ///
204    /// Returns [`None`] if the input isn't a valid MXID, or if the server name
205    /// doesn't match
206    ///
207    /// # Parameters
208    ///
209    /// * `mxid` - The MXID of the user
210    fn localpart<'a>(&self, mxid: &'a str) -> Option<&'a str> {
211        let mxid = <&UserId>::try_from(mxid).ok()?;
212        if mxid.server_name() != self.homeserver() {
213            return None;
214        }
215        Some(mxid.localpart())
216    }
217
218    /// Verify a bearer token coming from the homeserver for homeserver to MAS
219    /// interactions
220    ///
221    /// Returns `true` if the token is valid, `false` otherwise.
222    ///
223    /// # Parameters
224    ///
225    /// * `token` - The token to verify.
226    ///
227    /// # Errors
228    ///
229    /// Returns an error if the token failed to verify.
230    async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error>;
231
232    /// Query the state of a user on the homeserver.
233    ///
234    /// # Parameters
235    ///
236    /// * `localpart` - The localpart of the user to query.
237    ///
238    /// # Errors
239    ///
240    /// Returns an error if the homeserver is unreachable or the user does not
241    /// exist.
242    async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error>;
243
244    /// Provision a user on the homeserver.
245    ///
246    /// # Parameters
247    ///
248    /// * `request` - a [`ProvisionRequest`] containing the details of the user
249    ///   to provision.
250    ///
251    /// # Errors
252    ///
253    /// Returns an error if the homeserver is unreachable or the user could not
254    /// be provisioned.
255    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error>;
256
257    /// Check whether a given username is available on the homeserver.
258    ///
259    /// # Parameters
260    ///
261    /// * `localpart` - The localpart to check.
262    ///
263    /// # Errors
264    ///
265    /// Returns an error if the homeserver is unreachable.
266    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error>;
267
268    /// Create a device for a user on the homeserver.
269    ///
270    /// # Parameters
271    ///
272    /// * `localpart` - The localpart of the user to create a device for.
273    /// * `device_id` - The device ID to create.
274    ///
275    /// # Errors
276    ///
277    /// Returns an error if the homeserver is unreachable or the device could
278    /// not be created.
279    async fn upsert_device(
280        &self,
281        localpart: &str,
282        device_id: &str,
283        initial_display_name: Option<&str>,
284    ) -> Result<(), anyhow::Error>;
285
286    /// Update the display name of a device for a user on the homeserver.
287    ///
288    /// # Parameters
289    ///
290    /// * `localpart` - The localpart of the user to update a device for.
291    /// * `device_id` - The device ID to update.
292    /// * `display_name` - The new display name to set
293    ///
294    /// # Errors
295    ///
296    /// Returns an error if the homeserver is unreachable or the device could
297    /// not be updated.
298    async fn update_device_display_name(
299        &self,
300        localpart: &str,
301        device_id: &str,
302        display_name: &str,
303    ) -> Result<(), anyhow::Error>;
304
305    /// Delete a device for a user on the homeserver.
306    ///
307    /// # Parameters
308    ///
309    /// * `localpart` - The localpart of the user to delete a device for.
310    /// * `device_id` - The device ID to delete.
311    ///
312    /// # Errors
313    ///
314    /// Returns an error if the homeserver is unreachable or the device could
315    /// not be deleted.
316    async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error>;
317
318    /// Sync the list of devices of a user with the homeserver.
319    ///
320    /// # Parameters
321    ///
322    /// * `localpart` - The localpart of the user to sync the devices for.
323    /// * `devices` - The list of devices to sync.
324    ///
325    /// # Errors
326    ///
327    /// Returns an error if the homeserver is unreachable or the devices could
328    /// not be synced.
329    async fn sync_devices(
330        &self,
331        localpart: &str,
332        devices: HashSet<String>,
333    ) -> Result<(), anyhow::Error>;
334
335    /// Delete a user on the homeserver.
336    ///
337    /// # Parameters
338    ///
339    /// * `localpart` - The localpart of the user to delete.
340    /// * `erase` - Whether to ask the homeserver to erase the user's data.
341    ///
342    /// # Errors
343    ///
344    /// Returns an error if the homeserver is unreachable or the user could not
345    /// be deleted.
346    async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error>;
347
348    /// Reactivate a user on the homeserver.
349    ///
350    /// # Parameters
351    ///
352    /// * `localpart` - The localpart of the user to reactivate.
353    ///
354    /// # Errors
355    ///
356    /// Returns an error if the homeserver is unreachable or the user could not
357    /// be reactivated.
358    async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error>;
359
360    /// Set the displayname of a user on the homeserver.
361    ///
362    /// # Parameters
363    ///
364    /// * `localpart` - The localpart of the user to set the displayname for.
365    /// * `displayname` - The displayname to set.
366    ///
367    /// # Errors
368    ///
369    /// Returns an error if the homeserver is unreachable or the displayname
370    /// could not be set.
371    async fn set_displayname(
372        &self,
373        localpart: &str,
374        displayname: &str,
375    ) -> Result<(), anyhow::Error>;
376
377    /// Unset the displayname of a user on the homeserver.
378    ///
379    /// # Parameters
380    ///
381    /// * `localpart` - The localpart of the user to unset the displayname for.
382    ///
383    /// # Errors
384    ///
385    /// Returns an error if the homeserver is unreachable or the displayname
386    /// could not be unset.
387    async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error>;
388
389    /// Temporarily allow a user to reset their cross-signing keys.
390    ///
391    /// # Parameters
392    ///
393    /// * `localpart` - The localpart of the user to allow cross-signing key
394    ///   reset
395    ///
396    /// # Errors
397    ///
398    /// Returns an error if the homeserver is unreachable or the cross-signing
399    /// reset could not be allowed.
400    async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error>;
401}
402
403#[async_trait::async_trait]
404impl<T: HomeserverConnection + Send + Sync + ?Sized> HomeserverConnection for &T {
405    fn homeserver(&self) -> &str {
406        (**self).homeserver()
407    }
408
409    async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
410        (**self).verify_token(token).await
411    }
412
413    async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
414        (**self).query_user(localpart).await
415    }
416
417    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
418        (**self).provision_user(request).await
419    }
420
421    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
422        (**self).is_localpart_available(localpart).await
423    }
424
425    async fn upsert_device(
426        &self,
427        localpart: &str,
428        device_id: &str,
429        initial_display_name: Option<&str>,
430    ) -> Result<(), anyhow::Error> {
431        (**self)
432            .upsert_device(localpart, device_id, initial_display_name)
433            .await
434    }
435
436    async fn update_device_display_name(
437        &self,
438        localpart: &str,
439        device_id: &str,
440        display_name: &str,
441    ) -> Result<(), anyhow::Error> {
442        (**self)
443            .update_device_display_name(localpart, device_id, display_name)
444            .await
445    }
446
447    async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error> {
448        (**self).delete_device(localpart, device_id).await
449    }
450
451    async fn sync_devices(
452        &self,
453        localpart: &str,
454        devices: HashSet<String>,
455    ) -> Result<(), anyhow::Error> {
456        (**self).sync_devices(localpart, devices).await
457    }
458
459    async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error> {
460        (**self).delete_user(localpart, erase).await
461    }
462
463    async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error> {
464        (**self).reactivate_user(localpart).await
465    }
466
467    async fn set_displayname(
468        &self,
469        localpart: &str,
470        displayname: &str,
471    ) -> Result<(), anyhow::Error> {
472        (**self).set_displayname(localpart, displayname).await
473    }
474
475    async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error> {
476        (**self).unset_displayname(localpart).await
477    }
478
479    async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error> {
480        (**self).allow_cross_signing_reset(localpart).await
481    }
482}
483
484// Implement for Arc<T> where T: HomeserverConnection
485#[async_trait::async_trait]
486impl<T: HomeserverConnection + ?Sized> HomeserverConnection for Arc<T> {
487    fn homeserver(&self) -> &str {
488        (**self).homeserver()
489    }
490
491    async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
492        (**self).verify_token(token).await
493    }
494
495    async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
496        (**self).query_user(localpart).await
497    }
498
499    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
500        (**self).provision_user(request).await
501    }
502
503    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
504        (**self).is_localpart_available(localpart).await
505    }
506
507    async fn upsert_device(
508        &self,
509        localpart: &str,
510        device_id: &str,
511        initial_display_name: Option<&str>,
512    ) -> Result<(), anyhow::Error> {
513        (**self)
514            .upsert_device(localpart, device_id, initial_display_name)
515            .await
516    }
517
518    async fn update_device_display_name(
519        &self,
520        localpart: &str,
521        device_id: &str,
522        display_name: &str,
523    ) -> Result<(), anyhow::Error> {
524        (**self)
525            .update_device_display_name(localpart, device_id, display_name)
526            .await
527    }
528
529    async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error> {
530        (**self).delete_device(localpart, device_id).await
531    }
532
533    async fn sync_devices(
534        &self,
535        localpart: &str,
536        devices: HashSet<String>,
537    ) -> Result<(), anyhow::Error> {
538        (**self).sync_devices(localpart, devices).await
539    }
540
541    async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error> {
542        (**self).delete_user(localpart, erase).await
543    }
544
545    async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error> {
546        (**self).reactivate_user(localpart).await
547    }
548
549    async fn set_displayname(
550        &self,
551        localpart: &str,
552        displayname: &str,
553    ) -> Result<(), anyhow::Error> {
554        (**self).set_displayname(localpart, displayname).await
555    }
556
557    async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error> {
558        (**self).unset_displayname(localpart).await
559    }
560
561    async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error> {
562        (**self).allow_cross_signing_reset(localpart).await
563    }
564}