diff --git a/src/controllers/helpers/pagination.rs b/src/controllers/helpers/pagination.rs index 5e27806a103..22948afb55c 100644 --- a/src/controllers/helpers/pagination.rs +++ b/src/controllers/helpers/pagination.rs @@ -7,12 +7,14 @@ use crate::models::helpers::with_count::*; use crate::util::errors::{bad_request, AppResult}; use crate::util::{HeaderMapExt, RequestUtils}; -use crate::util::diesel::prelude::*; use base64::{engine::general_purpose, Engine}; use diesel::pg::Pg; +use diesel::prelude::*; use diesel::query_builder::{AstPass, Query, QueryFragment, QueryId}; -use diesel::query_dsl::LoadQuery; use diesel::sql_types::BigInt; +use diesel_async::AsyncPgConnection; +use futures_util::future::BoxFuture; +use futures_util::{FutureExt, TryStreamExt}; use http::header; use indexmap::IndexMap; use serde::{Deserialize, Serialize}; @@ -250,16 +252,29 @@ pub(crate) struct PaginatedQuery { } impl PaginatedQuery { - pub(crate) fn load<'a, U, Conn>(self, conn: &mut Conn) -> QueryResult> + pub fn load<'a, U>( + self, + conn: &'a mut AsyncPgConnection, + ) -> BoxFuture<'a, QueryResult>> where - Self: LoadQuery<'a, Conn, WithCount>, + Self: diesel_async::methods::LoadQuery<'a, AsyncPgConnection, WithCount>, + T: 'a, + U: Send + 'a, { + use diesel_async::methods::LoadQuery; + let options = self.options.clone(); - let records_and_total = self.internal_load(conn)?.collect::>()?; - Ok(Paginated { - records_and_total, - options, - }) + let future = self.internal_load(conn); + + async move { + let records_and_total = future.await?.try_collect().await?; + + Ok(Paginated { + records_and_total, + options, + }) + } + .boxed() } } @@ -272,8 +287,6 @@ impl Query for PaginatedQuery { type SqlType = (T::SqlType, BigInt); } -impl diesel::RunQueryDsl for PaginatedQuery {} - impl QueryFragment for PaginatedQuery where T: QueryFragment, @@ -366,8 +379,6 @@ impl< type SqlType = (T::SqlType, BigInt); } -impl diesel::RunQueryDsl for PaginatedQueryWithCountSubq {} - impl QueryFragment for PaginatedQueryWithCountSubq where T: QueryFragment, @@ -390,16 +401,30 @@ where } impl PaginatedQueryWithCountSubq { - pub(crate) fn load<'a, U, Conn>(self, conn: &mut Conn) -> QueryResult> + pub fn load<'a, U>( + self, + conn: &'a mut AsyncPgConnection, + ) -> BoxFuture<'a, QueryResult>> where - Self: LoadQuery<'a, Conn, WithCount>, + Self: diesel_async::methods::LoadQuery<'a, AsyncPgConnection, WithCount> + Send, + C: 'a, + T: 'a, + U: Send + 'a, { + use diesel_async::methods::LoadQuery; + let options = self.options.clone(); - let records_and_total = self.internal_load(conn)?.collect::>()?; - Ok(Paginated { - records_and_total, - options, - }) + let future = self.internal_load(conn); + + async move { + let records_and_total = future.await?.try_collect().await?; + + Ok(Paginated { + records_and_total, + options, + }) + } + .boxed() } } diff --git a/src/controllers/keyword.rs b/src/controllers/keyword.rs index a3598462e91..6843b72461a 100644 --- a/src/controllers/keyword.rs +++ b/src/controllers/keyword.rs @@ -2,14 +2,12 @@ use crate::app::AppState; use crate::controllers::helpers::pagination::PaginationOptions; use crate::controllers::helpers::{pagination::Paginated, Paginate}; use crate::models::Keyword; -use crate::tasks::spawn_blocking; use crate::util::errors::AppResult; use crate::views::EncodableKeyword; use axum::extract::{Path, Query}; use axum_extra::json; use axum_extra::response::ErasedJson; use diesel::prelude::*; -use diesel_async::async_connection_wrapper::AsyncConnectionWrapper; use http::request::Parts; #[derive(Deserialize)] @@ -30,23 +28,18 @@ pub async fn index(state: AppState, qp: Query, req: Parts) -> AppRes let query = query.pages_pagination(PaginationOptions::builder().gather(&req)?); - let conn = state.db_read().await?; - spawn_blocking(move || { - let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into(); - - let data: Paginated = query.load(conn)?; - let total = data.total(); - let kws = data - .into_iter() - .map(Keyword::into) - .collect::>(); - - Ok(json!({ - "keywords": kws, - "meta": { "total": total }, - })) - }) - .await? + let mut conn = state.db_read().await?; + let data: Paginated = query.load(&mut conn).await?; + let total = data.total(); + let kws = data + .into_iter() + .map(Keyword::into) + .collect::>(); + + Ok(json!({ + "keywords": kws, + "meta": { "total": total }, + })) } /// Handles the `GET /keywords/:keyword_id` route. diff --git a/src/controllers/krate/search.rs b/src/controllers/krate/search.rs index a2c09be6be6..990e91d89d0 100644 --- a/src/controllers/krate/search.rs +++ b/src/controllers/krate/search.rs @@ -10,8 +10,8 @@ use diesel_async::async_connection_wrapper::AsyncConnectionWrapper; use diesel_async::AsyncPgConnection; use diesel_full_text_search::*; use http::request::Parts; -use std::cell::OnceCell; -use tokio::runtime::Handle; +use std::sync::OnceLock; +use tracing::Instrument; use crate::app::AppState; use crate::controllers::helpers::Paginate; @@ -20,10 +20,9 @@ use crate::schema::*; use crate::util::errors::{bad_request, AppResult}; use crate::views::EncodableCrate; -use crate::controllers::helpers::pagination::{Page, Paginated, PaginationOptions}; +use crate::controllers::helpers::pagination::{Page, PaginationOptions}; use crate::models::krate::ALL_COLUMNS; use crate::sql::{array_agg, canon_crate_name, lower}; -use crate::tasks::spawn_blocking; use crate::util::RequestUtils; /// Handles the `GET /crates` route. @@ -48,226 +47,215 @@ use crate::util::RequestUtils; /// function out to cover the different use cases, and create unit tests /// for them. pub async fn search(app: AppState, req: Parts) -> AppResult { + use diesel_async::RunQueryDsl; + let conn = app.db_read().await?; - spawn_blocking(move || { - use diesel::RunQueryDsl; - let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into(); + let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into(); + + use diesel::sql_types::Float; + use seek::*; + + let params = req.query(); + let option_param = |s| match params.get(s).map(|v| v.as_str()) { + Some(v) if v.contains('\0') => Err(bad_request(format!( + "parameter {s} cannot contain a null byte" + ))), + Some(v) => Ok(Some(v)), + None => Ok(None), + }; + let sort = option_param("sort")?; + let include_yanked = option_param("include_yanked")? + .map(|s| s == "yes") + .unwrap_or(true); + + let filter_params = FilterParams { + q_string: option_param("q")?, + include_yanked, + category: option_param("category")?, + all_keywords: option_param("all_keywords")?, + keyword: option_param("keyword")?, + letter: option_param("letter")?, + user_id: option_param("user_id")?.and_then(|s| s.parse::().ok()), + team_id: option_param("team_id")?.and_then(|s| s.parse::().ok()), + following: option_param("following")?.is_some(), + has_ids: option_param("ids[]")?.is_some(), + ..Default::default() + }; + + let selection = ( + ALL_COLUMNS, + false.into_sql::(), + crate_downloads::downloads, + recent_crate_downloads::downloads.nullable(), + 0_f32.into_sql::(), + versions::num.nullable(), + versions::yanked.nullable(), + ); - use diesel::sql_types::Float; - use seek::*; + let mut seek: Option = None; + let mut query = filter_params + .make_query(&req, conn) + .await? + .inner_join(crate_downloads::table) + .left_join(recent_crate_downloads::table) + .left_join(default_versions::table) + .left_join(versions::table.on(default_versions::version_id.eq(versions::id))) + .select(selection); - let params = req.query(); - let option_param = |s| match params.get(s).map(|v| v.as_str()) { - Some(v) if v.contains('\0') => Err(bad_request(format!( - "parameter {s} cannot contain a null byte" - ))), - Some(v) => Ok(Some(v)), - None => Ok(None), - }; - let sort = option_param("sort")?; - let include_yanked = option_param("include_yanked")? - .map(|s| s == "yes") - .unwrap_or(true); - - let filter_params = FilterParams { - q_string: option_param("q")?, - include_yanked, - category: option_param("category")?, - all_keywords: option_param("all_keywords")?, - keyword: option_param("keyword")?, - letter: option_param("letter")?, - user_id: option_param("user_id")?.and_then(|s| s.parse::().ok()), - team_id: option_param("team_id")?.and_then(|s| s.parse::().ok()), - following: option_param("following")?.is_some(), - has_ids: option_param("ids[]")?.is_some(), - ..Default::default() - }; + if let Some(q_string) = &filter_params.q_string { + if !q_string.is_empty() { + let sort = sort.unwrap_or("relevance"); - let selection = ( - ALL_COLUMNS, - false.into_sql::(), - crate_downloads::downloads, - recent_crate_downloads::downloads.nullable(), - 0_f32.into_sql::(), - versions::num.nullable(), - versions::yanked.nullable(), - ); - - let mut seek: Option = None; - let mut query = filter_params - .make_query(&req, conn)? - .inner_join(crate_downloads::table) - .left_join(recent_crate_downloads::table) - .left_join(default_versions::table) - .left_join(versions::table.on(default_versions::version_id.eq(versions::id))) - .select(selection); - - if let Some(q_string) = &filter_params.q_string { - if !q_string.is_empty() { - let sort = sort.unwrap_or("relevance"); - - query = query.order(Crate::with_name(q_string).desc()); - - if sort == "relevance" { - let q = sql::("plainto_tsquery('english', ") - .bind::(q_string) - .sql(")"); - let rank = ts_rank_cd(crates::textsearchable_index_col, q); - query = query.select(( - ALL_COLUMNS, - Crate::with_name(q_string), - crate_downloads::downloads, - recent_crate_downloads::downloads.nullable(), - rank.clone(), - versions::num.nullable(), - versions::yanked.nullable(), - )); - seek = Some(Seek::Relevance); - query = query.then_order_by(rank.desc()) - } else { - query = query.select(( - ALL_COLUMNS, - Crate::with_name(q_string), - crate_downloads::downloads, - recent_crate_downloads::downloads.nullable(), - 0_f32.into_sql::(), - versions::num.nullable(), - versions::yanked.nullable(), - )); - seek = Some(Seek::Query); - } - } - } - - // Any sort other than 'relevance' (default) would ignore exact crate name matches - // Seek-based pagination requires a unique ordering to avoid unexpected row skipping - // during pagination. - // Therefore, when the ordering isn't unique an auxiliary ordering column should be added - // to ensure predictable pagination behavior. - if sort == Some("downloads") { - seek = Some(Seek::Downloads); - query = query.order((crate_downloads::downloads.desc(), crates::id.desc())) - } else if sort == Some("recent-downloads") { - seek = Some(Seek::RecentDownloads); - query = query.order(( - recent_crate_downloads::downloads.desc().nulls_last(), - crates::id.desc(), - )) - } else if sort == Some("recent-updates") { - seek = Some(Seek::RecentUpdates); - query = query.order((crates::updated_at.desc(), crates::id.desc())); - } else if sort == Some("new") { - seek = Some(Seek::New); - query = query.order((crates::created_at.desc(), crates::id.desc())); - } else { - seek = seek.or(Some(Seek::Name)); - // Since the name is unique value, the inherent ordering becomes naturally unique. - // Therefore, an additional auxiliary ordering column is unnecessary in this case. - query = query.then_order_by(crates::name.asc()) - } + query = query.order(Crate::with_name(q_string).desc()); - let pagination: PaginationOptions = PaginationOptions::builder() - .limit_page_numbers() - .enable_seek(true) - .gather(&req)?; - - let explicit_page = matches!(pagination.page, Page::Numeric(_)); - - // To avoid breaking existing users, seek-based pagination is only used if an explicit page has - // not been provided. This way clients relying on meta.next_page will use the faster seek-based - // paginations, while client hardcoding pages handling will use the slower offset-based code. - let (total, next_page, prev_page, data, conn) = if !explicit_page && seek.is_some() { - let seek = seek.unwrap(); - if let Some(condition) = seek - .after(&pagination.page)? - .map(|s| filter_params.seek_after(&s)) - { - query = query.filter(condition); + if sort == "relevance" { + let q = sql::("plainto_tsquery('english', ") + .bind::(q_string) + .sql(")"); + let rank = ts_rank_cd(crates::textsearchable_index_col, q); + query = query.select(( + ALL_COLUMNS, + Crate::with_name(q_string), + crate_downloads::downloads, + recent_crate_downloads::downloads.nullable(), + rank.clone(), + versions::num.nullable(), + versions::yanked.nullable(), + )); + seek = Some(Seek::Relevance); + query = query.then_order_by(rank.desc()) + } else { + query = query.select(( + ALL_COLUMNS, + Crate::with_name(q_string), + crate_downloads::downloads, + recent_crate_downloads::downloads.nullable(), + 0_f32.into_sql::(), + versions::num.nullable(), + versions::yanked.nullable(), + )); + seek = Some(Seek::Query); } + } + } - // This does a full index-only scan over the crates table to gather how many crates were - // published. Unfortunately on PostgreSQL counting the rows in a table requires scanning - // the table, and the `total` field is part of the stable registries API. - // - // If this becomes a problem in the future the crates count could be denormalized, at least - // for the filterless happy path. - let query = query.pages_pagination_with_count_query( - pagination, - filter_params.make_query(&req, conn)?.count(), - ); - let data: Paginated = - info_span!("db.query", message = "SELECT ..., COUNT(*) FROM crates") - .in_scope(|| query.load(conn))?; - - ( - data.total(), - data.next_seek_params(|last| seek.to_payload(last))? - .map(|p| req.query_with_params(p)), - None, - data.into_iter().collect::>(), - conn, - ) - } else { - let query = query.pages_pagination_with_count_query( - pagination, - filter_params.make_query(&req, conn)?.count(), - ); - let data: Paginated = - info_span!("db.query", message = "SELECT ..., COUNT(*) FROM crates") - .in_scope(|| query.load(conn))?; - ( - data.total(), - data.next_page_params().map(|p| req.query_with_params(p)), - data.prev_page_params().map(|p| req.query_with_params(p)), - data.into_iter().collect::>(), - conn, - ) - }; + // Any sort other than 'relevance' (default) would ignore exact crate name matches + // Seek-based pagination requires a unique ordering to avoid unexpected row skipping + // during pagination. + // Therefore, when the ordering isn't unique an auxiliary ordering column should be added + // to ensure predictable pagination behavior. + if sort == Some("downloads") { + seek = Some(Seek::Downloads); + query = query.order((crate_downloads::downloads.desc(), crates::id.desc())) + } else if sort == Some("recent-downloads") { + seek = Some(Seek::RecentDownloads); + query = query.order(( + recent_crate_downloads::downloads.desc().nulls_last(), + crates::id.desc(), + )) + } else if sort == Some("recent-updates") { + seek = Some(Seek::RecentUpdates); + query = query.order((crates::updated_at.desc(), crates::id.desc())); + } else if sort == Some("new") { + seek = Some(Seek::New); + query = query.order((crates::created_at.desc(), crates::id.desc())); + } else { + seek = seek.or(Some(Seek::Name)); + // Since the name is unique value, the inherent ordering becomes naturally unique. + // Therefore, an additional auxiliary ordering column is unnecessary in this case. + query = query.then_order_by(crates::name.asc()) + } - let crates = data.iter().map(|(c, ..)| c).collect::>(); + let pagination: PaginationOptions = PaginationOptions::builder() + .limit_page_numbers() + .enable_seek(true) + .gather(&req)?; + + let explicit_page = matches!(pagination.page, Page::Numeric(_)); + + // To avoid breaking existing users, seek-based pagination is only used if an explicit page has + // not been provided. This way clients relying on meta.next_page will use the faster seek-based + // paginations, while client hardcoding pages handling will use the slower offset-based code. + let (total, next_page, prev_page, data, conn) = if !explicit_page && seek.is_some() { + let seek = seek.unwrap(); + if let Some(condition) = seek + .after(&pagination.page)? + .map(|s| filter_params.seek_after(&s)) + { + query = query.filter(condition); + } - let versions: Vec = info_span!("db.query", message = "SELECT ... FROM versions") - .in_scope(|| { - Version::belonging_to(&crates) - .filter(versions::yanked.eq(false)) - .load(conn) - })?; - let versions = versions - .grouped_by(&crates) - .into_iter() - .map(TopVersions::from_versions); - - let crates = versions - .zip(data) - .map( - |( - max_version, - (krate, perfect_match, total, recent, _, default_version, yanked), - )| { - EncodableCrate::from_minimal( - krate, - default_version.as_deref(), - yanked, - Some(&max_version), - perfect_match, - total, - Some(recent.unwrap_or(0)), - ) - }, - ) - .collect::>(); - - Ok(json!({ - "crates": crates, - "meta": { - "total": total, - "next_page": next_page, - "prev_page": prev_page, + // This does a full index-only scan over the crates table to gather how many crates were + // published. Unfortunately on PostgreSQL counting the rows in a table requires scanning + // the table, and the `total` field is part of the stable registries API. + // + // If this becomes a problem in the future the crates count could be denormalized, at least + // for the filterless happy path. + let count_query = filter_params.make_query(&req, conn).await?.count(); + let query = query.pages_pagination_with_count_query(pagination, count_query); + let span = info_span!("db.query", message = "SELECT ..., COUNT(*) FROM crates"); + let data = query.load::(conn).instrument(span).await?; + ( + data.total(), + data.next_seek_params(|last| seek.to_payload(last))? + .map(|p| req.query_with_params(p)), + None, + data.into_iter().collect::>(), + conn, + ) + } else { + let count_query = filter_params.make_query(&req, conn).await?.count(); + let query = query.pages_pagination_with_count_query(pagination, count_query); + let span = info_span!("db.query", message = "SELECT ..., COUNT(*) FROM crates"); + let data = query.load::(conn).instrument(span).await?; + ( + data.total(), + data.next_page_params().map(|p| req.query_with_params(p)), + data.prev_page_params().map(|p| req.query_with_params(p)), + data.into_iter().collect::>(), + conn, + ) + }; + + let crates = data.iter().map(|(c, ..)| c).collect::>(); + + let span = info_span!("db.query", message = "SELECT ... FROM versions"); + let versions: Vec = Version::belonging_to(&crates) + .filter(versions::yanked.eq(false)) + .load(conn) + .instrument(span) + .await?; + let versions = versions + .grouped_by(&crates) + .into_iter() + .map(TopVersions::from_versions); + + let crates = versions + .zip(data) + .map( + |(max_version, (krate, perfect_match, total, recent, _, default_version, yanked))| { + EncodableCrate::from_minimal( + krate, + default_version.as_deref(), + yanked, + Some(&max_version), + perfect_match, + total, + Some(recent.unwrap_or(0)), + ) }, - })) - }) - .await? + ) + .collect::>(); + + Ok(json!({ + "crates": crates, + "meta": { + "total": total, + "next_page": next_page, + "prev_page": prev_page, + }, + })) } #[derive(Default)] @@ -282,8 +270,8 @@ struct FilterParams<'a> { team_id: Option, following: bool, has_ids: bool, - _auth_user_id: OnceCell, - _ids: OnceCell>>, + _auth_user_id: OnceLock, + _ids: OnceLock>>, } impl<'a> FilterParams<'a> { @@ -304,14 +292,12 @@ impl<'a> FilterParams<'a> { .as_deref() } - fn authed_user_id(&self, req: &Parts, conn: &mut AsyncPgConnection) -> AppResult { + async fn authed_user_id(&self, req: &Parts, conn: &mut AsyncPgConnection) -> AppResult { if let Some(val) = self._auth_user_id.get() { return Ok(*val); } - let user_id = Handle::current() - .block_on(AuthCheck::default().check(req, conn))? - .user_id(); + let user_id = AuthCheck::default().check(req, conn).await?.user_id(); // This should not fail, because of the `get()` check above let _ = self._auth_user_id.set(user_id); @@ -319,7 +305,7 @@ impl<'a> FilterParams<'a> { Ok(user_id) } - fn make_query( + async fn make_query( &'a self, req: &Parts, conn: &mut AsyncPgConnection, @@ -408,7 +394,7 @@ impl<'a> FilterParams<'a> { ), ); } else if self.following { - let user_id = self.authed_user_id(req, conn)?; + let user_id = self.authed_user_id(req, conn).await?; query = query.filter( crates::id.eq_any( follows::table diff --git a/src/controllers/user/me.rs b/src/controllers/user/me.rs index be4cb7fcd67..401794d09cb 100644 --- a/src/controllers/user/me.rs +++ b/src/controllers/user/me.rs @@ -5,7 +5,6 @@ use axum::response::Response; use axum::Json; use axum_extra::json; use axum_extra::response::ErasedJson; -use diesel_async::async_connection_wrapper::AsyncConnectionWrapper; use http::request::Parts; use std::collections::HashMap; @@ -15,7 +14,6 @@ use crate::controllers::helpers::{ok_true, Paginate}; use crate::models::krate::CrateName; use crate::models::{CrateOwner, Follow, OwnerKind, User, Version, VersionOwnerAction}; use crate::schema::{crate_owners, crates, emails, follows, users, versions}; -use crate::tasks::spawn_blocking; use crate::util::errors::{bad_request, AppResult}; use crate::util::BytesRequest; use crate::views::{EncodableMe, EncodablePrivateUser, EncodableVersion, OwnedCrate}; @@ -69,41 +67,39 @@ pub async fn me(app: AppState, req: Parts) -> AppResult> { pub async fn updates(app: AppState, req: Parts) -> AppResult { let mut conn = app.db_read_prefer_primary().await?; let auth = AuthCheck::only_cookie().check(&req, &mut conn).await?; - spawn_blocking(move || { - let conn: &mut AsyncConnectionWrapper<_> = &mut conn.into(); - - let user = auth.user(); - - let followed_crates = Follow::belonging_to(user).select(follows::crate_id); - let query = versions::table - .inner_join(crates::table) - .left_outer_join(users::table) - .filter(crates::id.eq_any(followed_crates)) - .order(versions::created_at.desc()) - .select(<(Version, CrateName, Option)>::as_select()) - .pages_pagination(PaginationOptions::builder().gather(&req)?); - let data: Paginated<(Version, CrateName, Option)> = query.load(conn)?; - let more = data.next_page_params().is_some(); - let versions = data.iter().map(|(v, ..)| v).collect::>(); - let actions = VersionOwnerAction::for_versions(conn, &versions)?; - let data = data - .into_iter() - .zip(actions) - .map(|((v, cn, pb), voas)| (v, cn, pb, voas)); - - let versions = data - .into_iter() - .map(|(version, crate_name, published_by, actions)| { - EncodableVersion::from(version, &crate_name.name, published_by, actions) - }) - .collect::>(); - - Ok(json!({ - "versions": versions, - "meta": { "more": more }, - })) - }) - .await? + + let user = auth.user(); + + let followed_crates = Follow::belonging_to(user).select(follows::crate_id); + let query = versions::table + .inner_join(crates::table) + .left_outer_join(users::table) + .filter(crates::id.eq_any(followed_crates)) + .order(versions::created_at.desc()) + .select(<(Version, CrateName, Option)>::as_select()) + .pages_pagination(PaginationOptions::builder().gather(&req)?); + + let data: Paginated<(Version, CrateName, Option)> = query.load(&mut conn).await?; + + let more = data.next_page_params().is_some(); + let versions = data.iter().map(|(v, ..)| v).collect::>(); + let actions = VersionOwnerAction::async_for_versions(&mut conn, &versions).await?; + let data = data + .into_iter() + .zip(actions) + .map(|((v, cn, pb), voas)| (v, cn, pb, voas)); + + let versions = data + .into_iter() + .map(|(version, crate_name, published_by, actions)| { + EncodableVersion::from(version, &crate_name.name, published_by, actions) + }) + .collect::>(); + + Ok(json!({ + "versions": versions, + "meta": { "more": more }, + })) } /// Handles the `PUT /confirm/:email_token` route