mas_storage_pg/compat/
session.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12    BrowserSession, CompatSession, CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device,
13    User,
14};
15use mas_storage::{
16    Clock, Page, Pagination,
17    compat::{CompatSessionFilter, CompatSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use url::Url;
25use uuid::Uuid;
26
27use crate::{
28    DatabaseError, DatabaseInconsistencyError,
29    filter::{Filter, StatementExt, StatementWithJoinsExt},
30    iden::{CompatSessions, CompatSsoLogins, UserSessions},
31    pagination::QueryBuilderExt,
32    tracing::ExecuteExt,
33};
34
35/// An implementation of [`CompatSessionRepository`] for a PostgreSQL connection
36pub struct PgCompatSessionRepository<'c> {
37    conn: &'c mut PgConnection,
38}
39
40impl<'c> PgCompatSessionRepository<'c> {
41    /// Create a new [`PgCompatSessionRepository`] from an active PostgreSQL
42    /// connection
43    pub fn new(conn: &'c mut PgConnection) -> Self {
44        Self { conn }
45    }
46}
47
/// Raw row returned by the `compat_sessions` lookup query.
///
/// Field names mirror the selected column names; rows are turned into a
/// [`CompatSession`] through the `From` implementation below.
struct CompatSessionLookup {
    compat_session_id: Uuid,
    device_id: Option<String>,
    human_name: Option<String>,
    user_id: Uuid,
    // Browser session (`user_sessions` row) this compat session is tied to, if any
    user_session_id: Option<Uuid>,
    created_at: DateTime<Utc>,
    // `None` while the session is still valid
    finished_at: Option<DateTime<Utc>>,
    is_synapse_admin: bool,
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
}
61
62impl From<CompatSessionLookup> for CompatSession {
63    fn from(value: CompatSessionLookup) -> Self {
64        let id = value.compat_session_id.into();
65
66        let state = match value.finished_at {
67            None => CompatSessionState::Valid,
68            Some(finished_at) => CompatSessionState::Finished { finished_at },
69        };
70
71        CompatSession {
72            id,
73            state,
74            user_id: value.user_id.into(),
75            user_session_id: value.user_session_id.map(Ulid::from),
76            device: value.device_id.map(Device::from),
77            human_name: value.human_name,
78            created_at: value.created_at,
79            is_synapse_admin: value.is_synapse_admin,
80            user_agent: value.user_agent,
81            last_active_at: value.last_active_at,
82            last_active_ip: value.last_active_ip,
83        }
84    }
85}
86
/// Row returned by the session listing query, which LEFT JOINs
/// `compat_sso_logins` onto `compat_sessions`.
///
/// `#[enum_def]` generates `CompatSessionAndSsoLoginLookupIden`, whose
/// variants are used as the column aliases in that query. The
/// `compat_sso_login_*` fields are expected to all be `NULL` when no SSO
/// login matched the join.
#[derive(sqlx::FromRow)]
#[enum_def]
struct CompatSessionAndSsoLoginLookup {
    compat_session_id: Uuid,
    device_id: Option<String>,
    human_name: Option<String>,
    user_id: Uuid,
    user_session_id: Option<Uuid>,
    created_at: DateTime<Utc>,
    finished_at: Option<DateTime<Utc>>,
    is_synapse_admin: bool,
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
    // Columns from the LEFT JOINed `compat_sso_logins` table
    compat_sso_login_id: Option<Uuid>,
    compat_sso_login_token: Option<String>,
    compat_sso_login_redirect_uri: Option<String>,
    compat_sso_login_created_at: Option<DateTime<Utc>>,
    compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
    compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
}
108
109impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSsoLogin>) {
110    type Error = DatabaseInconsistencyError;
111
112    fn try_from(value: CompatSessionAndSsoLoginLookup) -> Result<Self, Self::Error> {
113        let id = value.compat_session_id.into();
114
115        let state = match value.finished_at {
116            None => CompatSessionState::Valid,
117            Some(finished_at) => CompatSessionState::Finished { finished_at },
118        };
119
120        let session = CompatSession {
121            id,
122            state,
123            user_id: value.user_id.into(),
124            device: value.device_id.map(Device::from),
125            human_name: value.human_name,
126            user_session_id: value.user_session_id.map(Ulid::from),
127            created_at: value.created_at,
128            is_synapse_admin: value.is_synapse_admin,
129            user_agent: value.user_agent,
130            last_active_at: value.last_active_at,
131            last_active_ip: value.last_active_ip,
132        };
133
134        match (
135            value.compat_sso_login_id,
136            value.compat_sso_login_token,
137            value.compat_sso_login_redirect_uri,
138            value.compat_sso_login_created_at,
139            value.compat_sso_login_fulfilled_at,
140            value.compat_sso_login_exchanged_at,
141        ) {
142            (None, None, None, None, None, None) => Ok((session, None)),
143            (
144                Some(id),
145                Some(login_token),
146                Some(redirect_uri),
147                Some(created_at),
148                fulfilled_at,
149                exchanged_at,
150            ) => {
151                let id = id.into();
152                let redirect_uri = Url::parse(&redirect_uri).map_err(|e| {
153                    DatabaseInconsistencyError::on("compat_sso_logins")
154                        .column("redirect_uri")
155                        .row(id)
156                        .source(e)
157                })?;
158
159                let state = match (fulfilled_at, exchanged_at) {
160                    (Some(fulfilled_at), Some(exchanged_at)) => CompatSsoLoginState::Exchanged {
161                        fulfilled_at,
162                        exchanged_at,
163                        compat_session_id: session.id,
164                    },
165                    _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
166                };
167
168                let login = CompatSsoLogin {
169                    id,
170                    redirect_uri,
171                    login_token,
172                    created_at,
173                    state,
174                };
175
176                Ok((session, Some(login)))
177            }
178            _ => Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
179        }
180    }
181}
182
183impl Filter for CompatSessionFilter<'_> {
184    fn generate_condition(&self, has_joins: bool) -> impl sea_query::IntoCondition {
185        sea_query::Condition::all()
186            .add_option(self.user().map(|user| {
187                Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
188            }))
189            .add_option(self.browser_session().map(|browser_session| {
190                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
191                    .eq(Uuid::from(browser_session.id))
192            }))
193            .add_option(self.browser_session_filter().map(|browser_session_filter| {
194                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)).in_subquery(
195                    Query::select()
196                        .expr(Expr::col((
197                            UserSessions::Table,
198                            UserSessions::UserSessionId,
199                        )))
200                        .apply_filter(browser_session_filter)
201                        .from(UserSessions::Table)
202                        .take(),
203                )
204            }))
205            .add_option(self.state().map(|state| {
206                if state.is_active() {
207                    Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
208                } else {
209                    Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
210                }
211            }))
212            .add_option(self.auth_type().map(|auth_type| {
213                // In in the SELECT to list sessions, we can rely on the JOINed table, whereas
214                // in other queries we need to do a subquery
215                if has_joins {
216                    if auth_type.is_sso_login() {
217                        Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
218                            .is_not_null()
219                    } else {
220                        Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
221                            .is_null()
222                    }
223                } else {
224                    // This builds either a:
225                    // `WHERE compat_session_id = ANY(...)`
226                    // or a `WHERE compat_session_id <> ALL(...)`
227                    let compat_sso_logins = Query::select()
228                        .expr(Expr::col((
229                            CompatSsoLogins::Table,
230                            CompatSsoLogins::CompatSessionId,
231                        )))
232                        .from(CompatSsoLogins::Table)
233                        .take();
234
235                    if auth_type.is_sso_login() {
236                        Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
237                            .eq(Expr::any(compat_sso_logins))
238                    } else {
239                        Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
240                            .ne(Expr::all(compat_sso_logins))
241                    }
242                }
243            }))
244            .add_option(self.last_active_after().map(|last_active_after| {
245                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
246                    .gt(last_active_after)
247            }))
248            .add_option(self.last_active_before().map(|last_active_before| {
249                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
250                    .lt(last_active_before)
251            }))
252            .add_option(self.device().map(|device| {
253                Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
254            }))
255    }
256}
257
258#[async_trait]
259impl CompatSessionRepository for PgCompatSessionRepository<'_> {
260    type Error = DatabaseError;
261
262    #[tracing::instrument(
263        name = "db.compat_session.lookup",
264        skip_all,
265        fields(
266            db.query.text,
267            compat_session.id = %id,
268        ),
269        err,
270    )]
271    async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
272        let res = sqlx::query_as!(
273            CompatSessionLookup,
274            r#"
275                SELECT compat_session_id
276                     , device_id
277                     , human_name
278                     , user_id
279                     , user_session_id
280                     , created_at
281                     , finished_at
282                     , is_synapse_admin
283                     , user_agent
284                     , last_active_at
285                     , last_active_ip as "last_active_ip: IpAddr"
286                FROM compat_sessions
287                WHERE compat_session_id = $1
288            "#,
289            Uuid::from(id),
290        )
291        .traced()
292        .fetch_optional(&mut *self.conn)
293        .await?;
294
295        let Some(res) = res else { return Ok(None) };
296
297        Ok(Some(res.into()))
298    }
299
300    #[tracing::instrument(
301        name = "db.compat_session.add",
302        skip_all,
303        fields(
304            db.query.text,
305            compat_session.id,
306            %user.id,
307            %user.username,
308            compat_session.device.id = device.as_str(),
309        ),
310        err,
311    )]
312    async fn add(
313        &mut self,
314        rng: &mut (dyn RngCore + Send),
315        clock: &dyn Clock,
316        user: &User,
317        device: Device,
318        browser_session: Option<&BrowserSession>,
319        is_synapse_admin: bool,
320        human_name: Option<String>,
321    ) -> Result<CompatSession, Self::Error> {
322        let created_at = clock.now();
323        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
324        tracing::Span::current().record("compat_session.id", tracing::field::display(id));
325
326        sqlx::query!(
327            r#"
328                INSERT INTO compat_sessions
329                    (compat_session_id, user_id, device_id,
330                     user_session_id, created_at, is_synapse_admin,
331                     human_name)
332                VALUES ($1, $2, $3, $4, $5, $6, $7)
333            "#,
334            Uuid::from(id),
335            Uuid::from(user.id),
336            device.as_str(),
337            browser_session.map(|s| Uuid::from(s.id)),
338            created_at,
339            is_synapse_admin,
340            human_name.as_deref(),
341        )
342        .traced()
343        .execute(&mut *self.conn)
344        .await?;
345
346        Ok(CompatSession {
347            id,
348            state: CompatSessionState::default(),
349            user_id: user.id,
350            device: Some(device),
351            human_name,
352            user_session_id: browser_session.map(|s| s.id),
353            created_at,
354            is_synapse_admin,
355            user_agent: None,
356            last_active_at: None,
357            last_active_ip: None,
358        })
359    }
360
361    #[tracing::instrument(
362        name = "db.compat_session.finish",
363        skip_all,
364        fields(
365            db.query.text,
366            %compat_session.id,
367            user.id = %compat_session.user_id,
368            compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
369        ),
370        err,
371    )]
372    async fn finish(
373        &mut self,
374        clock: &dyn Clock,
375        compat_session: CompatSession,
376    ) -> Result<CompatSession, Self::Error> {
377        let finished_at = clock.now();
378
379        let res = sqlx::query!(
380            r#"
381                UPDATE compat_sessions cs
382                SET finished_at = $2
383                WHERE compat_session_id = $1
384            "#,
385            Uuid::from(compat_session.id),
386            finished_at,
387        )
388        .traced()
389        .execute(&mut *self.conn)
390        .await?;
391
392        DatabaseError::ensure_affected_rows(&res, 1)?;
393
394        let compat_session = compat_session
395            .finish(finished_at)
396            .map_err(DatabaseError::to_invalid_operation)?;
397
398        Ok(compat_session)
399    }
400
401    #[tracing::instrument(
402        name = "db.compat_session.finish_bulk",
403        skip_all,
404        fields(db.query.text),
405        err,
406    )]
407    async fn finish_bulk(
408        &mut self,
409        clock: &dyn Clock,
410        filter: CompatSessionFilter<'_>,
411    ) -> Result<usize, Self::Error> {
412        let finished_at = clock.now();
413        let (sql, arguments) = Query::update()
414            .table(CompatSessions::Table)
415            .value(CompatSessions::FinishedAt, finished_at)
416            .apply_filter(filter)
417            .build_sqlx(PostgresQueryBuilder);
418
419        let res = sqlx::query_with(&sql, arguments)
420            .traced()
421            .execute(&mut *self.conn)
422            .await?;
423
424        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
425    }
426
427    #[tracing::instrument(
428        name = "db.compat_session.list",
429        skip_all,
430        fields(
431            db.query.text,
432        ),
433        err,
434    )]
435    async fn list(
436        &mut self,
437        filter: CompatSessionFilter<'_>,
438        pagination: Pagination,
439    ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> {
440        let (sql, arguments) = Query::select()
441            .expr_as(
442                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
443                CompatSessionAndSsoLoginLookupIden::CompatSessionId,
444            )
445            .expr_as(
446                Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
447                CompatSessionAndSsoLoginLookupIden::DeviceId,
448            )
449            .expr_as(
450                Expr::col((CompatSessions::Table, CompatSessions::HumanName)),
451                CompatSessionAndSsoLoginLookupIden::HumanName,
452            )
453            .expr_as(
454                Expr::col((CompatSessions::Table, CompatSessions::UserId)),
455                CompatSessionAndSsoLoginLookupIden::UserId,
456            )
457            .expr_as(
458                Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)),
459                CompatSessionAndSsoLoginLookupIden::UserSessionId,
460            )
461            .expr_as(
462                Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
463                CompatSessionAndSsoLoginLookupIden::CreatedAt,
464            )
465            .expr_as(
466                Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
467                CompatSessionAndSsoLoginLookupIden::FinishedAt,
468            )
469            .expr_as(
470                Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
471                CompatSessionAndSsoLoginLookupIden::IsSynapseAdmin,
472            )
473            .expr_as(
474                Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
475                CompatSessionAndSsoLoginLookupIden::UserAgent,
476            )
477            .expr_as(
478                Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
479                CompatSessionAndSsoLoginLookupIden::LastActiveAt,
480            )
481            .expr_as(
482                Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
483                CompatSessionAndSsoLoginLookupIden::LastActiveIp,
484            )
485            .expr_as(
486                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
487                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginId,
488            )
489            .expr_as(
490                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
491                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginToken,
492            )
493            .expr_as(
494                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
495                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginRedirectUri,
496            )
497            .expr_as(
498                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
499                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginCreatedAt,
500            )
501            .expr_as(
502                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
503                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginFulfilledAt,
504            )
505            .expr_as(
506                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
507                CompatSessionAndSsoLoginLookupIden::CompatSsoLoginExchangedAt,
508            )
509            .from(CompatSessions::Table)
510            .left_join(
511                CompatSsoLogins::Table,
512                Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
513                    .equals((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
514            )
515            .apply_filter_with_joins(filter)
516            .generate_pagination(
517                (CompatSessions::Table, CompatSessions::CompatSessionId),
518                pagination,
519            )
520            .build_sqlx(PostgresQueryBuilder);
521
522        let edges: Vec<CompatSessionAndSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
523            .traced()
524            .fetch_all(&mut *self.conn)
525            .await?;
526
527        let page = pagination.process(edges).try_map(TryFrom::try_from)?;
528
529        Ok(page)
530    }
531
532    #[tracing::instrument(
533        name = "db.compat_session.count",
534        skip_all,
535        fields(
536            db.query.text,
537        ),
538        err,
539    )]
540    async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error> {
541        let (sql, arguments) = sea_query::Query::select()
542            .expr(Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)).count())
543            .from(CompatSessions::Table)
544            .apply_filter(filter)
545            .build_sqlx(PostgresQueryBuilder);
546
547        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
548            .traced()
549            .fetch_one(&mut *self.conn)
550            .await?;
551
552        count
553            .try_into()
554            .map_err(DatabaseError::to_invalid_operation)
555    }
556
557    #[tracing::instrument(
558        name = "db.compat_session.record_batch_activity",
559        skip_all,
560        fields(
561            db.query.text,
562        ),
563        err,
564    )]
565    async fn record_batch_activity(
566        &mut self,
567        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
568    ) -> Result<(), Self::Error> {
569        // Sort the activity by ID, so that when batching the updates, Postgres
570        // locks the rows in a stable order, preventing deadlocks
571        activities.sort_unstable();
572        let mut ids = Vec::with_capacity(activities.len());
573        let mut last_activities = Vec::with_capacity(activities.len());
574        let mut ips = Vec::with_capacity(activities.len());
575
576        for (id, last_activity, ip) in activities {
577            ids.push(Uuid::from(id));
578            last_activities.push(last_activity);
579            ips.push(ip);
580        }
581
582        let res = sqlx::query!(
583            r#"
584                UPDATE compat_sessions
585                SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)
586                  , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)
587                FROM (
588                    SELECT *
589                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
590                        AS t(compat_session_id, last_active_at, last_active_ip)
591                ) AS t
592                WHERE compat_sessions.compat_session_id = t.compat_session_id
593            "#,
594            &ids,
595            &last_activities,
596            &ips as &[Option<IpAddr>],
597        )
598        .traced()
599        .execute(&mut *self.conn)
600        .await?;
601
602        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
603
604        Ok(())
605    }
606
607    #[tracing::instrument(
608        name = "db.compat_session.record_user_agent",
609        skip_all,
610        fields(
611            db.query.text,
612            %compat_session.id,
613        ),
614        err,
615    )]
616    async fn record_user_agent(
617        &mut self,
618        mut compat_session: CompatSession,
619        user_agent: String,
620    ) -> Result<CompatSession, Self::Error> {
621        let res = sqlx::query!(
622            r#"
623            UPDATE compat_sessions
624            SET user_agent = $2
625            WHERE compat_session_id = $1
626        "#,
627            Uuid::from(compat_session.id),
628            &*user_agent,
629        )
630        .traced()
631        .execute(&mut *self.conn)
632        .await?;
633
634        compat_session.user_agent = Some(user_agent);
635
636        DatabaseError::ensure_affected_rows(&res, 1)?;
637
638        Ok(compat_session)
639    }
640
641    #[tracing::instrument(
642        name = "repository.compat_session.set_human_name",
643        skip(self),
644        fields(
645            compat_session.id = %compat_session.id,
646            compat_session.human_name = ?human_name,
647        ),
648        err,
649    )]
650    async fn set_human_name(
651        &mut self,
652        mut compat_session: CompatSession,
653        human_name: Option<String>,
654    ) -> Result<CompatSession, Self::Error> {
655        let res = sqlx::query!(
656            r#"
657            UPDATE compat_sessions
658            SET human_name = $2
659            WHERE compat_session_id = $1
660        "#,
661            Uuid::from(compat_session.id),
662            human_name.as_deref(),
663        )
664        .traced()
665        .execute(&mut *self.conn)
666        .await?;
667
668        compat_session.human_name = human_name;
669
670        DatabaseError::ensure_affected_rows(&res, 1)?;
671
672        Ok(compat_session)
673    }
674}