1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12 BrowserSession, CompatSession, CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device,
13 User,
14};
15use mas_storage::{
16 Clock, Page, Pagination,
17 compat::{CompatSessionFilter, CompatSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use url::Url;
25use uuid::Uuid;
26
27use crate::{
28 DatabaseError, DatabaseInconsistencyError,
29 filter::{Filter, StatementExt, StatementWithJoinsExt},
30 iden::{CompatSessions, CompatSsoLogins, UserSessions},
31 pagination::QueryBuilderExt,
32 tracing::ExecuteExt,
33};
34
/// An implementation of [`CompatSessionRepository`] backed by a PostgreSQL
/// connection.
pub struct PgCompatSessionRepository<'c> {
    /// Connection on which every query of this repository is executed.
    conn: &'c mut PgConnection,
}
39
40impl<'c> PgCompatSessionRepository<'c> {
41 pub fn new(conn: &'c mut PgConnection) -> Self {
44 Self { conn }
45 }
46}
47
/// Raw row returned by the single-session lookup on `compat_sessions`.
///
/// Field names must match the selected column names, since the row is
/// filled by `sqlx::query_as!`. Converted into a [`CompatSession`] via
/// `From`.
struct CompatSessionLookup {
    /// Primary key, converted to the session's `Ulid` ID.
    compat_session_id: Uuid,
    /// Device ID, if any; converted into a [`Device`].
    device_id: Option<String>,
    /// Optional human-readable name of the session.
    human_name: Option<String>,
    /// Owner of the session.
    user_id: Uuid,
    /// Browser session associated with this compat session, if any.
    user_session_id: Option<Uuid>,
    created_at: DateTime<Utc>,
    /// `None` means the session is still valid; `Some` means it was finished
    /// at that time.
    finished_at: Option<DateTime<Utc>>,
    is_synapse_admin: bool,
    /// Last recorded user agent, if any.
    user_agent: Option<String>,
    /// Last recorded activity time, if any.
    last_active_at: Option<DateTime<Utc>>,
    /// Last recorded client IP, if any (selected with an explicit
    /// `IpAddr` type override in the query).
    last_active_ip: Option<IpAddr>,
}
61
62impl From<CompatSessionLookup> for CompatSession {
63 fn from(value: CompatSessionLookup) -> Self {
64 let id = value.compat_session_id.into();
65
66 let state = match value.finished_at {
67 None => CompatSessionState::Valid,
68 Some(finished_at) => CompatSessionState::Finished { finished_at },
69 };
70
71 CompatSession {
72 id,
73 state,
74 user_id: value.user_id.into(),
75 user_session_id: value.user_session_id.map(Ulid::from),
76 device: value.device_id.map(Device::from),
77 human_name: value.human_name,
78 created_at: value.created_at,
79 is_synapse_admin: value.is_synapse_admin,
80 user_agent: value.user_agent,
81 last_active_at: value.last_active_at,
82 last_active_ip: value.last_active_ip,
83 }
84 }
85}
86
/// Raw row returned by the `list` query: the `compat_sessions` columns plus
/// the columns of the `compat_sso_logins` row that is `LEFT JOIN`ed onto it.
///
/// `#[enum_def]` generates `CompatSessionAndSsoLoginLookupIden`. The `list`
/// query uses it as column aliases, so that `sqlx::FromRow` can map the result
/// by field name. The `compat_sso_login_*` fields are all `None` when no SSO
/// login row was joined.
#[derive(sqlx::FromRow)]
#[enum_def]
struct CompatSessionAndSsoLoginLookup {
    compat_session_id: Uuid,
    device_id: Option<String>,
    human_name: Option<String>,
    user_id: Uuid,
    user_session_id: Option<Uuid>,
    created_at: DateTime<Utc>,
    /// `None` while the session is valid; set once it is finished.
    finished_at: Option<DateTime<Utc>>,
    is_synapse_admin: bool,
    user_agent: Option<String>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
    // Columns from the joined `compat_sso_logins` row, if any.
    compat_sso_login_id: Option<Uuid>,
    compat_sso_login_token: Option<String>,
    /// Stored as text; parsed into a `Url` during conversion.
    compat_sso_login_redirect_uri: Option<String>,
    compat_sso_login_created_at: Option<DateTime<Utc>>,
    compat_sso_login_fulfilled_at: Option<DateTime<Utc>>,
    compat_sso_login_exchanged_at: Option<DateTime<Utc>>,
}
108
109impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSsoLogin>) {
110 type Error = DatabaseInconsistencyError;
111
112 fn try_from(value: CompatSessionAndSsoLoginLookup) -> Result<Self, Self::Error> {
113 let id = value.compat_session_id.into();
114
115 let state = match value.finished_at {
116 None => CompatSessionState::Valid,
117 Some(finished_at) => CompatSessionState::Finished { finished_at },
118 };
119
120 let session = CompatSession {
121 id,
122 state,
123 user_id: value.user_id.into(),
124 device: value.device_id.map(Device::from),
125 human_name: value.human_name,
126 user_session_id: value.user_session_id.map(Ulid::from),
127 created_at: value.created_at,
128 is_synapse_admin: value.is_synapse_admin,
129 user_agent: value.user_agent,
130 last_active_at: value.last_active_at,
131 last_active_ip: value.last_active_ip,
132 };
133
134 match (
135 value.compat_sso_login_id,
136 value.compat_sso_login_token,
137 value.compat_sso_login_redirect_uri,
138 value.compat_sso_login_created_at,
139 value.compat_sso_login_fulfilled_at,
140 value.compat_sso_login_exchanged_at,
141 ) {
142 (None, None, None, None, None, None) => Ok((session, None)),
143 (
144 Some(id),
145 Some(login_token),
146 Some(redirect_uri),
147 Some(created_at),
148 fulfilled_at,
149 exchanged_at,
150 ) => {
151 let id = id.into();
152 let redirect_uri = Url::parse(&redirect_uri).map_err(|e| {
153 DatabaseInconsistencyError::on("compat_sso_logins")
154 .column("redirect_uri")
155 .row(id)
156 .source(e)
157 })?;
158
159 let state = match (fulfilled_at, exchanged_at) {
160 (Some(fulfilled_at), Some(exchanged_at)) => CompatSsoLoginState::Exchanged {
161 fulfilled_at,
162 exchanged_at,
163 compat_session_id: session.id,
164 },
165 _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
166 };
167
168 let login = CompatSsoLogin {
169 id,
170 redirect_uri,
171 login_token,
172 created_at,
173 state,
174 };
175
176 Ok((session, Some(login)))
177 }
178 _ => Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
179 }
180 }
181}
182
183impl Filter for CompatSessionFilter<'_> {
184 fn generate_condition(&self, has_joins: bool) -> impl sea_query::IntoCondition {
185 sea_query::Condition::all()
186 .add_option(self.user().map(|user| {
187 Expr::col((CompatSessions::Table, CompatSessions::UserId)).eq(Uuid::from(user.id))
188 }))
189 .add_option(self.browser_session().map(|browser_session| {
190 Expr::col((CompatSessions::Table, CompatSessions::UserSessionId))
191 .eq(Uuid::from(browser_session.id))
192 }))
193 .add_option(self.browser_session_filter().map(|browser_session_filter| {
194 Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)).in_subquery(
195 Query::select()
196 .expr(Expr::col((
197 UserSessions::Table,
198 UserSessions::UserSessionId,
199 )))
200 .apply_filter(browser_session_filter)
201 .from(UserSessions::Table)
202 .take(),
203 )
204 }))
205 .add_option(self.state().map(|state| {
206 if state.is_active() {
207 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_null()
208 } else {
209 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)).is_not_null()
210 }
211 }))
212 .add_option(self.auth_type().map(|auth_type| {
213 if has_joins {
216 if auth_type.is_sso_login() {
217 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
218 .is_not_null()
219 } else {
220 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId))
221 .is_null()
222 }
223 } else {
224 let compat_sso_logins = Query::select()
228 .expr(Expr::col((
229 CompatSsoLogins::Table,
230 CompatSsoLogins::CompatSessionId,
231 )))
232 .from(CompatSsoLogins::Table)
233 .take();
234
235 if auth_type.is_sso_login() {
236 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
237 .eq(Expr::any(compat_sso_logins))
238 } else {
239 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
240 .ne(Expr::all(compat_sso_logins))
241 }
242 }
243 }))
244 .add_option(self.last_active_after().map(|last_active_after| {
245 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
246 .gt(last_active_after)
247 }))
248 .add_option(self.last_active_before().map(|last_active_before| {
249 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt))
250 .lt(last_active_before)
251 }))
252 .add_option(self.device().map(|device| {
253 Expr::col((CompatSessions::Table, CompatSessions::DeviceId)).eq(device.as_str())
254 }))
255 }
256}
257
258#[async_trait]
259impl CompatSessionRepository for PgCompatSessionRepository<'_> {
260 type Error = DatabaseError;
261
262 #[tracing::instrument(
263 name = "db.compat_session.lookup",
264 skip_all,
265 fields(
266 db.query.text,
267 compat_session.id = %id,
268 ),
269 err,
270 )]
271 async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSession>, Self::Error> {
272 let res = sqlx::query_as!(
273 CompatSessionLookup,
274 r#"
275 SELECT compat_session_id
276 , device_id
277 , human_name
278 , user_id
279 , user_session_id
280 , created_at
281 , finished_at
282 , is_synapse_admin
283 , user_agent
284 , last_active_at
285 , last_active_ip as "last_active_ip: IpAddr"
286 FROM compat_sessions
287 WHERE compat_session_id = $1
288 "#,
289 Uuid::from(id),
290 )
291 .traced()
292 .fetch_optional(&mut *self.conn)
293 .await?;
294
295 let Some(res) = res else { return Ok(None) };
296
297 Ok(Some(res.into()))
298 }
299
300 #[tracing::instrument(
301 name = "db.compat_session.add",
302 skip_all,
303 fields(
304 db.query.text,
305 compat_session.id,
306 %user.id,
307 %user.username,
308 compat_session.device.id = device.as_str(),
309 ),
310 err,
311 )]
312 async fn add(
313 &mut self,
314 rng: &mut (dyn RngCore + Send),
315 clock: &dyn Clock,
316 user: &User,
317 device: Device,
318 browser_session: Option<&BrowserSession>,
319 is_synapse_admin: bool,
320 human_name: Option<String>,
321 ) -> Result<CompatSession, Self::Error> {
322 let created_at = clock.now();
323 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
324 tracing::Span::current().record("compat_session.id", tracing::field::display(id));
325
326 sqlx::query!(
327 r#"
328 INSERT INTO compat_sessions
329 (compat_session_id, user_id, device_id,
330 user_session_id, created_at, is_synapse_admin,
331 human_name)
332 VALUES ($1, $2, $3, $4, $5, $6, $7)
333 "#,
334 Uuid::from(id),
335 Uuid::from(user.id),
336 device.as_str(),
337 browser_session.map(|s| Uuid::from(s.id)),
338 created_at,
339 is_synapse_admin,
340 human_name.as_deref(),
341 )
342 .traced()
343 .execute(&mut *self.conn)
344 .await?;
345
346 Ok(CompatSession {
347 id,
348 state: CompatSessionState::default(),
349 user_id: user.id,
350 device: Some(device),
351 human_name,
352 user_session_id: browser_session.map(|s| s.id),
353 created_at,
354 is_synapse_admin,
355 user_agent: None,
356 last_active_at: None,
357 last_active_ip: None,
358 })
359 }
360
361 #[tracing::instrument(
362 name = "db.compat_session.finish",
363 skip_all,
364 fields(
365 db.query.text,
366 %compat_session.id,
367 user.id = %compat_session.user_id,
368 compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
369 ),
370 err,
371 )]
372 async fn finish(
373 &mut self,
374 clock: &dyn Clock,
375 compat_session: CompatSession,
376 ) -> Result<CompatSession, Self::Error> {
377 let finished_at = clock.now();
378
379 let res = sqlx::query!(
380 r#"
381 UPDATE compat_sessions cs
382 SET finished_at = $2
383 WHERE compat_session_id = $1
384 "#,
385 Uuid::from(compat_session.id),
386 finished_at,
387 )
388 .traced()
389 .execute(&mut *self.conn)
390 .await?;
391
392 DatabaseError::ensure_affected_rows(&res, 1)?;
393
394 let compat_session = compat_session
395 .finish(finished_at)
396 .map_err(DatabaseError::to_invalid_operation)?;
397
398 Ok(compat_session)
399 }
400
401 #[tracing::instrument(
402 name = "db.compat_session.finish_bulk",
403 skip_all,
404 fields(db.query.text),
405 err,
406 )]
407 async fn finish_bulk(
408 &mut self,
409 clock: &dyn Clock,
410 filter: CompatSessionFilter<'_>,
411 ) -> Result<usize, Self::Error> {
412 let finished_at = clock.now();
413 let (sql, arguments) = Query::update()
414 .table(CompatSessions::Table)
415 .value(CompatSessions::FinishedAt, finished_at)
416 .apply_filter(filter)
417 .build_sqlx(PostgresQueryBuilder);
418
419 let res = sqlx::query_with(&sql, arguments)
420 .traced()
421 .execute(&mut *self.conn)
422 .await?;
423
424 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
425 }
426
427 #[tracing::instrument(
428 name = "db.compat_session.list",
429 skip_all,
430 fields(
431 db.query.text,
432 ),
433 err,
434 )]
435 async fn list(
436 &mut self,
437 filter: CompatSessionFilter<'_>,
438 pagination: Pagination,
439 ) -> Result<Page<(CompatSession, Option<CompatSsoLogin>)>, Self::Error> {
440 let (sql, arguments) = Query::select()
441 .expr_as(
442 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)),
443 CompatSessionAndSsoLoginLookupIden::CompatSessionId,
444 )
445 .expr_as(
446 Expr::col((CompatSessions::Table, CompatSessions::DeviceId)),
447 CompatSessionAndSsoLoginLookupIden::DeviceId,
448 )
449 .expr_as(
450 Expr::col((CompatSessions::Table, CompatSessions::HumanName)),
451 CompatSessionAndSsoLoginLookupIden::HumanName,
452 )
453 .expr_as(
454 Expr::col((CompatSessions::Table, CompatSessions::UserId)),
455 CompatSessionAndSsoLoginLookupIden::UserId,
456 )
457 .expr_as(
458 Expr::col((CompatSessions::Table, CompatSessions::UserSessionId)),
459 CompatSessionAndSsoLoginLookupIden::UserSessionId,
460 )
461 .expr_as(
462 Expr::col((CompatSessions::Table, CompatSessions::CreatedAt)),
463 CompatSessionAndSsoLoginLookupIden::CreatedAt,
464 )
465 .expr_as(
466 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt)),
467 CompatSessionAndSsoLoginLookupIden::FinishedAt,
468 )
469 .expr_as(
470 Expr::col((CompatSessions::Table, CompatSessions::IsSynapseAdmin)),
471 CompatSessionAndSsoLoginLookupIden::IsSynapseAdmin,
472 )
473 .expr_as(
474 Expr::col((CompatSessions::Table, CompatSessions::UserAgent)),
475 CompatSessionAndSsoLoginLookupIden::UserAgent,
476 )
477 .expr_as(
478 Expr::col((CompatSessions::Table, CompatSessions::LastActiveAt)),
479 CompatSessionAndSsoLoginLookupIden::LastActiveAt,
480 )
481 .expr_as(
482 Expr::col((CompatSessions::Table, CompatSessions::LastActiveIp)),
483 CompatSessionAndSsoLoginLookupIden::LastActiveIp,
484 )
485 .expr_as(
486 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
487 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginId,
488 )
489 .expr_as(
490 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
491 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginToken,
492 )
493 .expr_as(
494 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
495 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginRedirectUri,
496 )
497 .expr_as(
498 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
499 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginCreatedAt,
500 )
501 .expr_as(
502 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
503 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginFulfilledAt,
504 )
505 .expr_as(
506 Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
507 CompatSessionAndSsoLoginLookupIden::CompatSsoLoginExchangedAt,
508 )
509 .from(CompatSessions::Table)
510 .left_join(
511 CompatSsoLogins::Table,
512 Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId))
513 .equals((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
514 )
515 .apply_filter_with_joins(filter)
516 .generate_pagination(
517 (CompatSessions::Table, CompatSessions::CompatSessionId),
518 pagination,
519 )
520 .build_sqlx(PostgresQueryBuilder);
521
522 let edges: Vec<CompatSessionAndSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
523 .traced()
524 .fetch_all(&mut *self.conn)
525 .await?;
526
527 let page = pagination.process(edges).try_map(TryFrom::try_from)?;
528
529 Ok(page)
530 }
531
532 #[tracing::instrument(
533 name = "db.compat_session.count",
534 skip_all,
535 fields(
536 db.query.text,
537 ),
538 err,
539 )]
540 async fn count(&mut self, filter: CompatSessionFilter<'_>) -> Result<usize, Self::Error> {
541 let (sql, arguments) = sea_query::Query::select()
542 .expr(Expr::col((CompatSessions::Table, CompatSessions::CompatSessionId)).count())
543 .from(CompatSessions::Table)
544 .apply_filter(filter)
545 .build_sqlx(PostgresQueryBuilder);
546
547 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
548 .traced()
549 .fetch_one(&mut *self.conn)
550 .await?;
551
552 count
553 .try_into()
554 .map_err(DatabaseError::to_invalid_operation)
555 }
556
557 #[tracing::instrument(
558 name = "db.compat_session.record_batch_activity",
559 skip_all,
560 fields(
561 db.query.text,
562 ),
563 err,
564 )]
565 async fn record_batch_activity(
566 &mut self,
567 mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
568 ) -> Result<(), Self::Error> {
569 activities.sort_unstable();
572 let mut ids = Vec::with_capacity(activities.len());
573 let mut last_activities = Vec::with_capacity(activities.len());
574 let mut ips = Vec::with_capacity(activities.len());
575
576 for (id, last_activity, ip) in activities {
577 ids.push(Uuid::from(id));
578 last_activities.push(last_activity);
579 ips.push(ip);
580 }
581
582 let res = sqlx::query!(
583 r#"
584 UPDATE compat_sessions
585 SET last_active_at = GREATEST(t.last_active_at, compat_sessions.last_active_at)
586 , last_active_ip = COALESCE(t.last_active_ip, compat_sessions.last_active_ip)
587 FROM (
588 SELECT *
589 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
590 AS t(compat_session_id, last_active_at, last_active_ip)
591 ) AS t
592 WHERE compat_sessions.compat_session_id = t.compat_session_id
593 "#,
594 &ids,
595 &last_activities,
596 &ips as &[Option<IpAddr>],
597 )
598 .traced()
599 .execute(&mut *self.conn)
600 .await?;
601
602 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
603
604 Ok(())
605 }
606
607 #[tracing::instrument(
608 name = "db.compat_session.record_user_agent",
609 skip_all,
610 fields(
611 db.query.text,
612 %compat_session.id,
613 ),
614 err,
615 )]
616 async fn record_user_agent(
617 &mut self,
618 mut compat_session: CompatSession,
619 user_agent: String,
620 ) -> Result<CompatSession, Self::Error> {
621 let res = sqlx::query!(
622 r#"
623 UPDATE compat_sessions
624 SET user_agent = $2
625 WHERE compat_session_id = $1
626 "#,
627 Uuid::from(compat_session.id),
628 &*user_agent,
629 )
630 .traced()
631 .execute(&mut *self.conn)
632 .await?;
633
634 compat_session.user_agent = Some(user_agent);
635
636 DatabaseError::ensure_affected_rows(&res, 1)?;
637
638 Ok(compat_session)
639 }
640
641 #[tracing::instrument(
642 name = "repository.compat_session.set_human_name",
643 skip(self),
644 fields(
645 compat_session.id = %compat_session.id,
646 compat_session.human_name = ?human_name,
647 ),
648 err,
649 )]
650 async fn set_human_name(
651 &mut self,
652 mut compat_session: CompatSession,
653 human_name: Option<String>,
654 ) -> Result<CompatSession, Self::Error> {
655 let res = sqlx::query!(
656 r#"
657 UPDATE compat_sessions
658 SET human_name = $2
659 WHERE compat_session_id = $1
660 "#,
661 Uuid::from(compat_session.id),
662 human_name.as_deref(),
663 )
664 .traced()
665 .execute(&mut *self.conn)
666 .await?;
667
668 compat_session.human_name = human_name;
669
670 DatabaseError::ensure_affected_rows(&res, 1)?;
671
672 Ok(compat_session)
673 }
674}