Skip to content

[perf] Cache the canonical *instantiation* of param-envs #142316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 169 additions & 20 deletions compiler/rustc_infer/src/infer/canonical/instantiate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
//! [c]: https://rust-lang.github.io/chalk/book/canonical_queries/canonicalization.html

use rustc_macros::extension;
use rustc_middle::bug;
use rustc_middle::ty::{self, FnMutDelegate, GenericArgKind, TyCtxt, TypeFoldable};
use rustc_middle::ty::{
self, DelayedMap, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeSuperVisitable,
TypeVisitableExt, TypeVisitor,
};
use rustc_type_ir::TypeVisitable;

use crate::infer::canonical::{Canonical, CanonicalVarValues};

Expand Down Expand Up @@ -58,23 +61,169 @@ where
T: TypeFoldable<TyCtxt<'tcx>>,
{
if var_values.var_values.is_empty() {
value
} else {
let delegate = FnMutDelegate {
regions: &mut |br: ty::BoundRegion| match var_values[br.var].kind() {
GenericArgKind::Lifetime(l) => l,
r => bug!("{:?} is a region but value is {:?}", br, r),
},
types: &mut |bound_ty: ty::BoundTy| match var_values[bound_ty.var].kind() {
GenericArgKind::Type(ty) => ty,
r => bug!("{:?} is a type but value is {:?}", bound_ty, r),
},
consts: &mut |bound_ct: ty::BoundVar| match var_values[bound_ct].kind() {
GenericArgKind::Const(ct) => ct,
c => bug!("{:?} is a const but value is {:?}", bound_ct, c),
},
};

tcx.replace_escaping_bound_vars_uncached(value, delegate)
return value;
}

value.fold_with(&mut CanonicalInstantiator {
tcx,
current_index: ty::INNERMOST,
var_values: var_values.var_values,
cache: Default::default(),
})
}

/// Replaces the bound vars in a canonical binder with var values.
struct CanonicalInstantiator<'tcx> {
tcx: TyCtxt<'tcx>,

// The values that the bound vars are are being instantiated with.
var_values: ty::GenericArgsRef<'tcx>,

/// As with `BoundVarReplacer`, represents the index of a binder *just outside*
/// the ones we have visited.
current_index: ty::DebruijnIndex,

// Instantiation is a pure function of `DebruijnIndex` and `Ty`.
cache: DelayedMap<(ty::DebruijnIndex, Ty<'tcx>), Ty<'tcx>>,
}

impl<'tcx> TypeFolder<TyCtxt<'tcx>> for CanonicalInstantiator<'tcx> {
fn cx(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn fold_binder<T: TypeFoldable<TyCtxt<'tcx>>>(
&mut self,
t: ty::Binder<'tcx, T>,
) -> ty::Binder<'tcx, T> {
self.current_index.shift_in(1);
let t = t.super_fold_with(self);
self.current_index.shift_out(1);
t
}

fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> {
match *t.kind() {
ty::Bound(debruijn, bound_ty) if debruijn == self.current_index => {
self.var_values[bound_ty.var.as_usize()].expect_ty()
}
_ => {
if !t.has_vars_bound_at_or_above(self.current_index) {
t
} else if let Some(&t) = self.cache.get(&(self.current_index, t)) {
t
} else {
let res = t.super_fold_with(self);
assert!(self.cache.insert((self.current_index, t), res));
res
}
}
}
}

fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> {
match r.kind() {
ty::ReBound(debruijn, br) if debruijn == self.current_index => {
self.var_values[br.var.as_usize()].expect_region()
}
_ => r,
}
}

fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
match ct.kind() {
ty::ConstKind::Bound(debruijn, bound_const) if debruijn == self.current_index => {
self.var_values[bound_const.as_usize()].expect_const()
}
_ => ct.super_fold_with(self),
}
}

fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
}

fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
if !c.has_vars_bound_at_or_above(self.current_index) {
return c;
}

// Since instantiation is a function of `DebruijnIndex`, we don't want
// to have to cache more copies of clauses when we're inside of binders.
// Since we currently expect to only have clauses in the outermost
// debruijn index, we just fold if we're inside of a binder.
if self.current_index > ty::INNERMOST {
return c.super_fold_with(self);
}

// Our cache key is `(clauses, var_values)`, but we also don't care about
// var values that aren't named in the clauses, since they can change without
// affecting the output. Since `ParamEnv`s are cached first, we compute the
// last var value that is mentioned in the clauses, and cut off the list so
// that we have more hits in the cache.

// We also cache the computation of "highest var named by clauses" since that
// is both expensive (depending on the size of the clauses) and a pure function.
let index = *self
.tcx
.highest_var_in_clauses_cache
.lock()
.entry(c)
.or_insert_with(|| highest_var_in_clauses(c));
let c_args = &self.var_values[..=index];

if let Some(c) = self.tcx.clauses_cache.lock().get(&(c, c_args)) {
c
} else {
let folded = c.super_fold_with(self);
self.tcx.clauses_cache.lock().insert((c, c_args), folded);
folded
}
}
}

fn highest_var_in_clauses<'tcx>(c: ty::Clauses<'tcx>) -> usize {
struct HighestVarInClauses {
max_var: usize,
current_index: ty::DebruijnIndex,
}
impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for HighestVarInClauses {
fn visit_binder<T: TypeVisitable<TyCtxt<'tcx>>>(
&mut self,
t: &ty::Binder<'tcx, T>,
) -> Self::Result {
self.current_index.shift_in(1);
let t = t.super_visit_with(self);
self.current_index.shift_out(1);
t
}
fn visit_ty(&mut self, t: Ty<'tcx>) {
if let ty::Bound(debruijn, bound_ty) = *t.kind()
&& debruijn == self.current_index
{
self.max_var = self.max_var.max(bound_ty.var.as_usize());
} else if t.has_vars_bound_at_or_above(self.current_index) {
t.super_visit_with(self);
}
}
fn visit_region(&mut self, r: ty::Region<'tcx>) {
if let ty::ReBound(debruijn, bound_region) = r.kind()
&& debruijn == self.current_index
{
self.max_var = self.max_var.max(bound_region.var.as_usize());
}
}
fn visit_const(&mut self, ct: ty::Const<'tcx>) {
if let ty::ConstKind::Bound(debruijn, bound_const) = ct.kind()
&& debruijn == self.current_index
{
self.max_var = self.max_var.max(bound_const.as_usize());
} else if ct.has_vars_bound_at_or_above(self.current_index) {
ct.super_visit_with(self);
}
}
}
let mut visitor = HighestVarInClauses { max_var: 0, current_index: ty::INNERMOST };
c.visit_with(&mut visitor);
visitor.max_var
}
8 changes: 8 additions & 0 deletions compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1479,6 +1479,12 @@ pub struct GlobalCtxt<'tcx> {

pub canonical_param_env_cache: CanonicalParamEnvCache<'tcx>,

/// Caches the index of the highest bound var in clauses in a canonical binder.
pub highest_var_in_clauses_cache: Lock<FxHashMap<ty::Clauses<'tcx>, usize>>,
/// Caches the instantiation of a canonical binder given a set of args.
pub clauses_cache:
Lock<FxHashMap<(ty::Clauses<'tcx>, &'tcx [ty::GenericArg<'tcx>]), ty::Clauses<'tcx>>>,

/// Data layout specification for the current target.
pub data_layout: TargetDataLayout,

Expand Down Expand Up @@ -1727,6 +1733,8 @@ impl<'tcx> TyCtxt<'tcx> {
new_solver_evaluation_cache: Default::default(),
new_solver_canonical_param_env_cache: Default::default(),
canonical_param_env_cache: Default::default(),
highest_var_in_clauses_cache: Default::default(),
clauses_cache: Default::default(),
data_layout,
alloc_map: interpret::AllocMap::new(),
current_gcx,
Expand Down
Loading