Skip to content

Commit dab24e2

Browse files
Don't worry about uncaptured contravariant lifetimes if they outlive a captured lifetime
1 parent 009e738 commit dab24e2

File tree

2 files changed

+242
-21
lines changed

2 files changed

+242
-21
lines changed

compiler/rustc_lint/src/impl_trait_overcaptures.rs

Lines changed: 226 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
1-
use rustc_data_structures::fx::FxIndexSet;
1+
use std::cell::LazyCell;
2+
3+
use rustc_data_structures::fx::{FxHashMap, FxIndexMap, FxIndexSet};
24
use rustc_data_structures::unord::UnordSet;
35
use rustc_errors::{Applicability, LintDiagnostic};
46
use rustc_hir as hir;
57
use rustc_hir::def::DefKind;
68
use rustc_hir::def_id::{DefId, LocalDefId};
9+
use rustc_infer::infer::outlives::env::OutlivesEnvironment;
10+
use rustc_infer::infer::TyCtxtInferExt;
711
use rustc_macros::LintDiagnostic;
8-
use rustc_middle::bug;
912
use rustc_middle::middle::resolve_bound_vars::ResolvedArg;
13+
use rustc_middle::ty::relate::{
14+
structurally_relate_consts, structurally_relate_tys, Relate, RelateResult, TypeRelation,
15+
};
1016
use rustc_middle::ty::{
1117
self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor,
1218
};
19+
use rustc_middle::{bug, span_bug};
1320
use rustc_session::lint::FutureIncompatibilityReason;
1421
use rustc_session::{declare_lint, declare_lint_pass};
1522
use rustc_span::edition::Edition;
16-
use rustc_span::Span;
23+
use rustc_span::{Span, Symbol};
24+
use rustc_trait_selection::traits::outlives_bounds::InferCtxtExt;
25+
use rustc_trait_selection::traits::ObligationCtxt;
1726

1827
use crate::{fluent_generated as fluent, LateContext, LateLintPass};
1928

@@ -119,38 +128,86 @@ impl<'tcx> LateLintPass<'tcx> for ImplTraitOvercaptures {
119128
}
120129
}
121130

131+
#[derive(PartialEq, Eq, Hash, Debug, Copy, Clone)]
132+
enum ParamKind {
133+
// Early-bound var.
134+
Early(Symbol, u32),
135+
// Late-bound var on function, not within a binder. We can capture these.
136+
Free(DefId, Symbol),
137+
// Late-bound var in a binder. We can't capture these yet.
138+
Late,
139+
}
140+
122141
fn check_fn(tcx: TyCtxt<'_>, parent_def_id: LocalDefId) {
123142
let sig = tcx.fn_sig(parent_def_id).instantiate_identity();
124143

125-
let mut in_scope_parameters = FxIndexSet::default();
144+
let mut in_scope_parameters = FxIndexMap::default();
126145
// Populate the in_scope_parameters list first with all of the generics in scope
127146
let mut current_def_id = Some(parent_def_id.to_def_id());
128147
while let Some(def_id) = current_def_id {
129148
let generics = tcx.generics_of(def_id);
130149
for param in &generics.own_params {
131-
in_scope_parameters.insert(param.def_id);
150+
in_scope_parameters.insert(param.def_id, ParamKind::Early(param.name, param.index));
132151
}
133152
current_def_id = generics.parent;
134153
}
135154

155+
for bound_var in sig.bound_vars() {
156+
let ty::BoundVariableKind::Region(ty::BoundRegionKind::BrNamed(def_id, name)) = bound_var
157+
else {
158+
span_bug!(tcx.def_span(parent_def_id), "unexpected non-lifetime binder on fn sig");
159+
};
160+
161+
in_scope_parameters.insert(def_id, ParamKind::Free(def_id, name));
162+
}
163+
164+
let sig = tcx.liberate_late_bound_regions(parent_def_id.to_def_id(), sig);
165+
136166
// Then visit the signature to walk through all the binders (incl. the late-bound
137167
// vars on the function itself, which we need to count too).
138168
sig.visit_with(&mut VisitOpaqueTypes {
139169
tcx,
140170
parent_def_id,
141171
in_scope_parameters,
142172
seen: Default::default(),
173+
// Lazily compute these two, since they're likely a bit expensive.
174+
variances: LazyCell::new(|| {
175+
let mut functional_variances = FunctionalVariances {
176+
tcx: tcx,
177+
variances: FxHashMap::default(),
178+
ambient_variance: ty::Covariant,
179+
generics: tcx.generics_of(parent_def_id),
180+
};
181+
let _ = functional_variances.relate(sig, sig);
182+
functional_variances.variances
183+
}),
184+
outlives_env: LazyCell::new(|| {
185+
let param_env = tcx.param_env(parent_def_id);
186+
let infcx = tcx.infer_ctxt().build();
187+
let ocx = ObligationCtxt::new(&infcx);
188+
let assumed_wf_tys = ocx.assumed_wf_types(param_env, parent_def_id).unwrap_or_default();
189+
let implied_bounds =
190+
infcx.implied_bounds_tys_compat(param_env, parent_def_id, &assumed_wf_tys, false);
191+
OutlivesEnvironment::with_bounds(param_env, implied_bounds)
192+
}),
143193
});
144194
}
145195

146-
struct VisitOpaqueTypes<'tcx> {
196+
struct VisitOpaqueTypes<'tcx, VarFn, OutlivesFn> {
147197
tcx: TyCtxt<'tcx>,
148198
parent_def_id: LocalDefId,
149-
in_scope_parameters: FxIndexSet<DefId>,
199+
in_scope_parameters: FxIndexMap<DefId, ParamKind>,
200+
variances: LazyCell<FxHashMap<DefId, ty::Variance>, VarFn>,
201+
outlives_env: LazyCell<OutlivesEnvironment<'tcx>, OutlivesFn>,
150202
seen: FxIndexSet<LocalDefId>,
151203
}
152204

153-
impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
205+
impl<'tcx, VarFn, OutlivesFn> TypeVisitor<TyCtxt<'tcx>>
206+
for VisitOpaqueTypes<'tcx, VarFn, OutlivesFn>
207+
where
208+
VarFn: FnOnce() -> FxHashMap<DefId, ty::Variance>,
209+
OutlivesFn: FnOnce() -> OutlivesEnvironment<'tcx>,
210+
{
154211
fn visit_binder<T: TypeVisitable<TyCtxt<'tcx>>>(
155212
&mut self,
156213
t: &ty::Binder<'tcx, T>,
@@ -163,8 +220,8 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
163220
ty::BoundVariableKind::Region(ty::BoundRegionKind::BrNamed(def_id, ..))
164221
| ty::BoundVariableKind::Ty(ty::BoundTyKind::Param(def_id, _)) => {
165222
added.push(def_id);
166-
let unique = self.in_scope_parameters.insert(def_id);
167-
assert!(unique);
223+
let unique = self.in_scope_parameters.insert(def_id, ParamKind::Late);
224+
assert_eq!(unique, None);
168225
}
169226
_ => {
170227
self.tcx.dcx().span_delayed_bug(
@@ -209,6 +266,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
209266
{
210267
// Compute the set of args that are captured by the opaque...
211268
let mut captured = FxIndexSet::default();
269+
let mut captured_regions = FxIndexSet::default();
212270
let variances = self.tcx.variances_of(opaque_def_id);
213271
let mut current_def_id = Some(opaque_def_id.to_def_id());
214272
while let Some(def_id) = current_def_id {
@@ -218,25 +276,60 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
218276
if variances[param.index as usize] != ty::Invariant {
219277
continue;
220278
}
279+
280+
let arg = opaque_ty.args[param.index as usize];
221281
// We need to turn all `ty::Param`/`ConstKind::Param` and
222282
// `ReEarlyParam`/`ReBound` into def ids.
223-
captured.insert(extract_def_id_from_arg(
224-
self.tcx,
225-
generics,
226-
opaque_ty.args[param.index as usize],
227-
));
283+
captured.insert(extract_def_id_from_arg(self.tcx, generics, arg));
284+
285+
captured_regions.extend(arg.as_region());
228286
}
229287
current_def_id = generics.parent;
230288
}
231289

232290
// Compute the set of in scope params that are not captured. Get their spans,
233291
// since that's all we really care about them for emitting the diagnostic.
234-
let uncaptured_spans: Vec<_> = self
292+
let mut uncaptured_args: FxIndexSet<_> = self
235293
.in_scope_parameters
236294
.iter()
237-
.filter(|def_id| !captured.contains(*def_id))
238-
.map(|def_id| self.tcx.def_span(def_id))
295+
.filter(|&(def_id, _)| !captured.contains(def_id))
296+
.collect();
297+
298+
// These are args that we know are likely fine to "overcapture", since they can be
299+
// contravariantly shortened to one of the already-captured lifetimes that they
300+
// outlive.
301+
let covariant_long_args: FxIndexSet<_> = uncaptured_args
302+
.iter()
303+
.copied()
304+
.filter(|&(def_id, kind)| {
305+
let Some(ty::Bivariant | ty::Contravariant) = self.variances.get(def_id) else {
306+
return false;
307+
};
308+
let DefKind::LifetimeParam = self.tcx.def_kind(def_id) else {
309+
return false;
310+
};
311+
let uncaptured = match *kind {
312+
ParamKind::Early(name, index) => ty::Region::new_early_param(
313+
self.tcx,
314+
ty::EarlyParamRegion { name, index },
315+
),
316+
ParamKind::Free(def_id, name) => ty::Region::new_late_param(
317+
self.tcx,
318+
self.parent_def_id.to_def_id(),
319+
ty::BoundRegionKind::BrNamed(def_id, name),
320+
),
321+
ParamKind::Late => return false,
322+
};
323+
// Does this region outlive any captured region?
324+
captured_regions.iter().any(|r| {
325+
self.outlives_env
326+
.free_region_map()
327+
.sub_free_regions(self.tcx, *r, uncaptured)
328+
})
329+
})
239330
.collect();
331+
// We don't care to warn on these args.
332+
uncaptured_args.retain(|arg| !covariant_long_args.contains(arg));
240333

241334
let opaque_span = self.tcx.def_span(opaque_def_id);
242335
let new_capture_rules =
@@ -246,7 +339,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
246339
// `use<>` syntax on it, and we're < edition 2024, then warn the user.
247340
if !new_capture_rules
248341
&& !opaque.bounds.iter().any(|bound| matches!(bound, hir::GenericBound::Use(..)))
249-
&& !uncaptured_spans.is_empty()
342+
&& !uncaptured_args.is_empty()
250343
{
251344
let suggestion = if let Ok(snippet) =
252345
self.tcx.sess.source_map().span_to_snippet(opaque_span)
@@ -274,6 +367,11 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
274367
None
275368
};
276369

370+
let uncaptured_spans: Vec<_> = uncaptured_args
371+
.into_iter()
372+
.map(|(def_id, _)| self.tcx.def_span(def_id))
373+
.collect();
374+
277375
self.tcx.emit_node_span_lint(
278376
IMPL_TRAIT_OVERCAPTURES,
279377
self.tcx.local_def_id_to_hir_id(opaque_def_id),
@@ -327,7 +425,7 @@ impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for VisitOpaqueTypes<'tcx> {
327425
if self
328426
.in_scope_parameters
329427
.iter()
330-
.all(|def_id| explicitly_captured.contains(def_id))
428+
.all(|(def_id, _)| explicitly_captured.contains(def_id))
331429
{
332430
self.tcx.emit_node_span_lint(
333431
IMPL_TRAIT_REDUNDANT_CAPTURES,
@@ -396,7 +494,11 @@ fn extract_def_id_from_arg<'tcx>(
396494
ty::ReBound(
397495
_,
398496
ty::BoundRegion { kind: ty::BoundRegionKind::BrNamed(def_id, ..), .. },
399-
) => def_id,
497+
)
498+
| ty::ReLateParam(ty::LateParamRegion {
499+
scope: _,
500+
bound_region: ty::BoundRegionKind::BrNamed(def_id, ..),
501+
}) => def_id,
400502
_ => unreachable!(),
401503
},
402504
ty::GenericArgKind::Type(ty) => {
@@ -413,3 +515,106 @@ fn extract_def_id_from_arg<'tcx>(
413515
}
414516
}
415517
}
518+
519+
/// Computes the variances of regions that appear in the type, but considering
520+
/// late-bound regions too, which don't have their variance computed usually.
521+
///
522+
/// Like generalization, this is a unary operation implemented on top of the binary
523+
/// relation infrastructure, mostly because it's much easier to have the relation
524+
/// track the variance for you, rather than having to do it yourself.
525+
struct FunctionalVariances<'tcx> {
526+
tcx: TyCtxt<'tcx>,
527+
variances: FxHashMap<DefId, ty::Variance>,
528+
ambient_variance: ty::Variance,
529+
generics: &'tcx ty::Generics,
530+
}
531+
532+
impl<'tcx> TypeRelation<TyCtxt<'tcx>> for FunctionalVariances<'tcx> {
533+
fn cx(&self) -> TyCtxt<'tcx> {
534+
self.tcx
535+
}
536+
537+
fn relate_with_variance<T: ty::relate::Relate<TyCtxt<'tcx>>>(
538+
&mut self,
539+
variance: rustc_type_ir::Variance,
540+
_: ty::VarianceDiagInfo<TyCtxt<'tcx>>,
541+
a: T,
542+
b: T,
543+
) -> RelateResult<'tcx, T> {
544+
let old_variance = self.ambient_variance;
545+
self.ambient_variance = self.ambient_variance.xform(variance);
546+
self.relate(a, b)?;
547+
self.ambient_variance = old_variance;
548+
Ok(a)
549+
}
550+
551+
fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
552+
structurally_relate_tys(self, a, b)?;
553+
Ok(a)
554+
}
555+
556+
fn regions(
557+
&mut self,
558+
a: ty::Region<'tcx>,
559+
_: ty::Region<'tcx>,
560+
) -> RelateResult<'tcx, ty::Region<'tcx>> {
561+
let def_id = match *a {
562+
ty::ReEarlyParam(ebr) => self.generics.region_param(ebr, self.tcx).def_id,
563+
ty::ReBound(
564+
_,
565+
ty::BoundRegion { kind: ty::BoundRegionKind::BrNamed(def_id, ..), .. },
566+
)
567+
| ty::ReLateParam(ty::LateParamRegion {
568+
scope: _,
569+
bound_region: ty::BoundRegionKind::BrNamed(def_id, ..),
570+
}) => def_id,
571+
_ => {
572+
return Ok(a);
573+
}
574+
};
575+
576+
if let Some(variance) = self.variances.get_mut(&def_id) {
577+
*variance = unify(*variance, self.ambient_variance);
578+
} else {
579+
self.variances.insert(def_id, self.ambient_variance);
580+
}
581+
582+
Ok(a)
583+
}
584+
585+
fn consts(
586+
&mut self,
587+
a: ty::Const<'tcx>,
588+
b: ty::Const<'tcx>,
589+
) -> RelateResult<'tcx, ty::Const<'tcx>> {
590+
structurally_relate_consts(self, a, b)?;
591+
Ok(a)
592+
}
593+
594+
fn binders<T>(
595+
&mut self,
596+
a: ty::Binder<'tcx, T>,
597+
b: ty::Binder<'tcx, T>,
598+
) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
599+
where
600+
T: Relate<TyCtxt<'tcx>>,
601+
{
602+
self.relate(a.skip_binder(), b.skip_binder())?;
603+
Ok(a)
604+
}
605+
}
606+
607+
/// What is the variance that satisfies the two variances?
608+
fn unify(a: ty::Variance, b: ty::Variance) -> ty::Variance {
609+
match (a, b) {
610+
// Bivariance is lattice bottom.
611+
(ty::Bivariant, other) | (other, ty::Bivariant) => other,
612+
// Invariant is lattice top.
613+
(ty::Invariant, _) | (_, ty::Invariant) => ty::Invariant,
614+
// If type is required to be covariant and contravariant, then it's invariant.
615+
(ty::Contravariant, ty::Covariant) | (ty::Covariant, ty::Contravariant) => ty::Invariant,
616+
// Otherwise, co + co = co, contra + contra = contra.
617+
(ty::Contravariant, ty::Contravariant) => ty::Contravariant,
618+
(ty::Covariant, ty::Covariant) => ty::Covariant,
619+
}
620+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
//@ check-pass
2+
3+
#![feature(precise_capturing)]
4+
#![deny(impl_trait_overcaptures)]
5+
6+
struct Ctxt<'tcx>(&'tcx ());
7+
8+
// In `compute`, we don't care that we're "overcapturing" `'tcx`
9+
// in edition 2024, because it can be shortened at the call site
10+
// and we know it outlives `'_`.
11+
12+
impl<'tcx> Ctxt<'tcx> {
13+
fn compute(&self) -> impl Sized + '_ {}
14+
}
15+
16+
fn main() {}

0 commit comments

Comments
 (0)