From 2d582b251f53ca28fe2a144c61271cc4252bb9e4 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Tue, 29 Apr 2025 18:55:15 -0700 Subject: [PATCH 1/2] [naga] Process overrides selectively for the active entry point This adds an argument to `process_overrides` to specify the desired entry point, and then modifies override processing to tolerate missing overrides if (and hopefully only if) they are not used by that entry point. During that processing, lists are constructed of the functions, globals, and entry points that cannot be used due to missing overrides. The MSL backend is changed to skip over those items while writing. - [ ] update the other backends - [ ] make save_overrides_resolved a bit nicer - [ ] add some more tests --- naga-cli/src/bin/naga.rs | 55 +- naga/src/back/msl/mod.rs | 15 +- naga/src/back/msl/writer.rs | 60 ++- naga/src/back/pipeline_constants.rs | 476 ++++++++++++++---- naga/src/front/glsl/context.rs | 9 +- naga/src/front/glsl/parser.rs | 2 +- naga/src/front/wgsl/lower/mod.rs | 26 +- naga/src/proc/constant_evaluator.rs | 31 +- naga/src/proc/mod.rs | 37 +- naga/src/valid/expression.rs | 19 +- naga/src/valid/function.rs | 7 +- naga/src/valid/handles.rs | 2 +- naga/src/valid/interface.rs | 2 +- naga/src/valid/mod.rs | 86 +++- naga/src/valid/type.rs | 2 +- .../in/wgsl/missing-unused-overrides.toml | 13 + .../in/wgsl/missing-unused-overrides.wgsl | 45 ++ naga/tests/naga/snapshots.rs | 64 ++- naga/tests/naga/validation.rs | 167 +++++- naga/tests/naga/wgsl_errors.rs | 2 +- .../out/msl/wgsl-missing-unused-overrides.msl | 35 ++ wgpu-core/src/device/mod.rs | 4 +- wgpu-hal/src/dx12/device.rs | 1 + wgpu-hal/src/gles/device.rs | 7 +- wgpu-hal/src/metal/device.rs | 8 +- wgpu-hal/src/vulkan/device.rs | 7 +- 26 files changed, 960 insertions(+), 222 deletions(-) create mode 100644 naga/tests/in/wgsl/missing-unused-overrides.toml create mode 100644 naga/tests/in/wgsl/missing-unused-overrides.wgsl create mode 100644 naga/tests/out/msl/wgsl-missing-unused-overrides.msl diff --git a/naga-cli/src/bin/naga.rs b/naga-cli/src/bin/naga.rs index 6f95e429f68..cae745f97e9 100644 --- a/naga-cli/src/bin/naga.rs +++ b/naga-cli/src/bin/naga.rs @@ -686,6 +686,8 @@ fn write_output( params: &Parameters, output_path: &str, ) -> anyhow::Result<()> { + use naga::back::pipeline_constants::ProcessOverridesOutput; + match Path::new(&output_path) .extension() .ok_or(CliError("Output filename has no extension"))? @@ -717,9 +719,14 @@ fn write_output( succeed, and it failed in a previous step", ))?; - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, ¶ms.overrides) - .unwrap_pretty(); + let ProcessOverridesOutput { module, info, .. } = + naga::back::pipeline_constants::process_overrides( + module, + info, + None, + ¶ms.overrides, + ) + .unwrap_pretty(); let pipeline_options = msl::PipelineOptions::default(); let (msl, _) = @@ -751,9 +758,17 @@ fn write_output( succeed, and it failed in a previous step", ))?; - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, ¶ms.overrides) - .unwrap_pretty(); + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( + module, + info, + None, + ¶ms.overrides, + ) + .unwrap_pretty(); let spv = spv::write_vec(&module, &info, ¶ms.spv_out, pipeline_options).unwrap_pretty(); @@ -788,9 +803,17 @@ fn write_output( succeed, and it failed in a previous step", ))?; - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, ¶ms.overrides) - .unwrap_pretty(); + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( + module, + info, + None, + ¶ms.overrides, + ) + .unwrap_pretty(); let mut buffer = String::new(); let mut writer = glsl::Writer::new( @@ -819,9 +842,17 @@ fn write_output( succeed, and it failed in a previous step", ))?; - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, ¶ms.overrides) - .unwrap_pretty(); + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( + module, + info, + None, + ¶ms.overrides, + ) + .unwrap_pretty(); let mut buffer = String::new(); let pipeline_options = Default::default(); diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 7bc8289b9b8..f34d2b23a81 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -52,7 +52,12 @@ use alloc::{ }; use core::fmt::{Error as FmtError, Write}; -use crate::{arena::Handle, ir, proc::index, valid::ModuleInfo}; +use crate::{ + arena::Handle, + ir, + proc::index, + valid::{ModuleInfo, UnresolvedOverrides}, +}; mod keywords; pub mod sampler; @@ -431,6 +436,14 @@ pub struct PipelineOptions { /// point is not found, an error will be thrown while writing. pub entry_point: Option<(ir::ShaderStage, String)>, + /// Information about unresolved overrides. + /// + /// This struct is returned by `process_overrides`. It tells the writer + /// which items to omit from the output because they are not used and refer + /// to overrides that were not resolved to a concrete value. + #[cfg_attr(feature = "serialize", serde(skip))] + pub unresolved_overrides: UnresolvedOverrides, + /// Allow `BuiltIn::PointSize` and inject it if doesn't exist. /// /// Metal doesn't like this for non-point primitive topologies and requires it for diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index f05e5c233aa..c7232faf1af 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -21,9 +21,10 @@ use crate::{ proc::{ self, index::{self, BoundsCheck}, - NameKey, TypeResolution, + NameKey, ResolveArraySizeError, TypeResolution, }, - valid, FastHashMap, FastHashSet, + valid::{self, UnresolvedOverrides}, + FastHashMap, FastHashSet, }; #[cfg(test)] @@ -436,6 +437,7 @@ pub struct Writer { /// Set of (struct type, struct field index) denoting which fields require /// padding inserted **before** them (i.e. between fields at index - 1 and index) struct_member_pads: FastHashSet<(Handle, u32)>, + unresolved_overrides: Option, } impl crate::Scalar { @@ -775,6 +777,7 @@ impl Writer { #[cfg(test)] put_block_stack_pointers: Default::default(), struct_member_pads: FastHashSet::default(), + unresolved_overrides: None, } } @@ -4032,6 +4035,7 @@ impl Writer { ); self.wrapped_functions.clear(); self.struct_member_pads.clear(); + self.unresolved_overrides = Some(pipeline_options.unresolved_overrides.clone()); writeln!( self.out, @@ -4216,8 +4220,20 @@ impl Writer { first_time: false, }; - match size.resolve(module.to_ctx())? { - proc::IndexableLength::Known(size) => { + match size.resolve(module.to_ctx()) { + Err(ResolveArraySizeError::NonConstArrayLength) => { + // The array size was never resolved. This _should_ + // be because it is an override expression and the + // type is not needed for the entry point being + // written. + // TODO: do we want to assemble `UnresolvedOverrides.types` to make this safer? + // (And if so, do we also want to validate that those types are truly unused?) + continue; + } + Err(err @ ResolveArraySizeError::ExpectedPositiveArrayLength) => { + return Err(err.into()); + } + Ok(proc::IndexableLength::Known(size)) => { writeln!(self.out, "struct {name} {{")?; writeln!( self.out, @@ -4229,7 +4245,7 @@ impl Writer { )?; writeln!(self.out, "}};")?; } - proc::IndexableLength::Dynamic => { + Ok(proc::IndexableLength::Dynamic) => { writeln!(self.out, "typedef {base_name} {name}[1];")?; } } @@ -5757,6 +5773,17 @@ template fun_handle ); + if self + .unresolved_overrides + .as_ref() + .unwrap() + .functions + .contains_key(&fun_handle) + { + log::trace!("skipping due to unresolved overrides"); + continue; + } + let ctx = back::FunctionCtx { ty: back::FunctionType::Function(fun_handle), info: &mod_info[fun_handle], @@ -5880,6 +5907,19 @@ template }; for ep_index in ep_range { + if self + .unresolved_overrides + .as_ref() + .unwrap() + .entry_points + .contains_key(&ep_index) + { + log::error!( + "must write the same entry point that was passed to `process_overrides`" + ); + return Err(Error::Override); + } + let ep = &module.entry_points[ep_index]; let fun = &ep.function; let fun_info = mod_info.get_entry_point(ep_index); @@ -6288,7 +6328,15 @@ template // within the entry point. for (handle, var) in module.global_variables.iter() { let usage = fun_info[handle]; - if usage.is_empty() || var.space == crate::AddressSpace::Private { + if usage.is_empty() + || var.space == crate::AddressSpace::Private + || self + .unresolved_overrides + .as_ref() + .unwrap() + .global_variables + .contains_key(&handle) + { continue; } diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 1cf1c805249..b92e3cfd816 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -10,10 +10,13 @@ use thiserror::Error; use super::PipelineConstants; use crate::{ arena::HandleVec, - proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter}, - valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator}, - Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar, - Span, Statement, TypeInner, WithSpan, + ir, + proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter, U32EvalError}, + valid::{ + Capabilities, ModuleInfo, UnresolvedOverrides, ValidationError, ValidationFlags, Validator, + }, + Arena, Block, Constant, Expression, FastHashMap, Function, Handle, Literal, Module, Override, + Range, Scalar, Span, Statement, TypeInner, WithSpan, }; #[cfg(no_std)] @@ -37,28 +40,69 @@ pub enum PipelineConstantError { ValidationError(#[from] WithSpan), #[error("workgroup_size override isn't strictly positive")] NegativeWorkgroupSize, + #[error("unable to evaluate workgroup_size override")] + WorkgroupSizeOverrideEvaluationError, } -/// Replace all overrides in `module` with constants. +// Returns the key to use for an override in `pipeline_constants`. +fn override_key(ov: &Override) -> Cow<'_, str> { + if let Some(id) = ov.id { + Cow::Owned(id.to_string()) + } else if let Some(ref name) = ov.name { + Cow::Borrowed(name) + } else { + unreachable!() + } +} + +#[derive(Debug)] +pub struct ProcessOverridesOutput<'a> { + pub module: Cow<'a, Module>, + pub info: Cow<'a, ModuleInfo>, + pub unresolved: UnresolvedOverrides, +} + +/// Replace overrides in `module` with constants. /// /// If no changes are needed, this just returns `Cow::Borrowed` /// references to `module` and `module_info`. Otherwise, it clones -/// `module`, edits its [`global_expressions`] arena to contain only -/// fully-evaluated expressions, and returns `Cow::Owned` values -/// holding the simplified module and its validation results. +/// `module`, updates it with evaluated override expressions, and returns +/// `Cow::Owned` values holding the simplified module and its validation +/// results. +/// +/// If `entry_point` is specified, then any override referenced by +/// that entry point must be supplied, and other overrides are +/// optional. The returned module may still have override expressions, +/// but they should not be reachable from the specified entry point. /// -/// In either case, the module returned has an empty `overrides` -/// arena, and the `global_expressions` arena contains only -/// fully-evaluated expressions. +/// If `entry_point` is not specified, then all overrides must be specified. /// -/// [`global_expressions`]: Module::global_expressions +/// This function completely rewrites both the [`global`] and function-local +/// arenas, replacing [`Expression::Override`] with [`Expression::Constant`]. +/// It then updates expressions, statements, and initializers that refer to a +/// an updated expression. +/// +/// The types arena is not updated. This means that the size of an array (in the +/// workgroup space, because this is the only place override-sized arrays are +/// permitted) may still require indirection through an override handle to the +/// initializer expression, which will be an evaluated constant. See +/// [#6787](https://github.com/gfx-rs/wgpu/pull/6787). +/// +/// [`global`]: Module::global_expressions pub fn process_overrides<'a>( module: &'a Module, module_info: &'a ModuleInfo, + entry_point: Option<(ir::ShaderStage, &str)>, pipeline_constants: &PipelineConstants, -) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> { +) -> Result, PipelineConstantError> { + let mut unresolved = UnresolvedOverrides::default(); + if module.overrides.is_empty() { - return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info))); + return Ok(ProcessOverridesOutput { + module: Cow::Borrowed(module), + info: Cow::Borrowed(module_info), + unresolved, + }); } let mut module = module.clone(); @@ -84,6 +128,7 @@ pub fn process_overrides<'a>( let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len()); let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); + let mut global_expressions_missing_overrides = FastHashMap::default(); let mut layouter = crate::proc::Layouter::default(); // An iterator through the original overrides table, consumed in @@ -93,10 +138,9 @@ pub fn process_overrides<'a>( // Do two things in tandem: // - // - Rebuild the global expression arena from scratch, fully - // evaluating all expressions, and replacing each `Override` - // expression in `module.global_expressions` with a `Constant` - // expression. + // - Rebuild the global expression arena from scratch, replacing + // `Override` expressions in `module.global_expressions` that can + // now be evaluated with `Constant` expressions. // // - Build a new `Constant` in `module.constants` to take the // place of each `Override`. @@ -123,28 +167,43 @@ pub fn process_overrides<'a>( for (old_h, expr, span) in module.global_expressions.drain() { let mut expr = match expr { Expression::Override(h) => { - let c_h = if let Some(new_h) = override_map.get(h) { - *new_h - } else { - let mut new_h = None; - for entry in override_iter.by_ref() { - let stop = entry.0 == h; - new_h = Some(process_override( - entry, - pipeline_constants, - &mut module, - &mut override_map, - &adjusted_global_expressions, - &mut adjusted_constant_initializers, - &mut global_expression_kind_tracker, - )?); - if stop { - break; + match override_map.get(h) { + Some(&Some(new_h)) => { + // Already evaluated. + Expression::Constant(new_h) + } + Some(&None) => { + // Already processed and could not evaluate. Leave + // expression unchanged. + expr + } + None => { + let mut result = None; + for entry in override_iter.by_ref() { + let stop = entry.0 == h; + result = process_override( + entry, + pipeline_constants, + &mut module, + &mut override_map, + &adjusted_global_expressions, + &mut adjusted_constant_initializers, + &mut global_expression_kind_tracker, + )?; + if stop { + break; + } + } + match result { + None => { + // Could not evaluate. Leave expression + // unchanged. + expr + } + Some(new_h) => Expression::Constant(new_h), } } - new_h.unwrap() - }; - Expression::Constant(c_h) + } } Expression::Constant(c_h) => { if adjusted_constant_initializers.insert(c_h) { @@ -155,6 +214,9 @@ pub fn process_overrides<'a>( } expr => expr, }; + // Attempt constant evaluation now that overrides referenced by this + // expression may have been resolved. If they have not been resolved, + // the expression will remain with `ExpressionKind::Override`. let mut evaluator = ConstantEvaluator::for_wgsl_module( &mut module, &mut global_expression_kind_tracker, @@ -162,8 +224,17 @@ pub fn process_overrides<'a>( false, ); adjust_expr(&adjusted_global_expressions, &mut expr); - let h = evaluator.try_eval_and_append(expr, span)?; - adjusted_global_expressions.insert(old_h, h); + match evaluator.try_eval_and_append(expr, span) { + Err((expr, ConstantEvaluatorError::Override(ov_h))) => { + let h = module.global_expressions.append(expr, span); + global_expression_kind_tracker.insert(h, crate::proc::ExpressionKind::Override); + adjusted_global_expressions.insert(old_h, h); + global_expressions_missing_overrides.insert(h, ov_h); + log::trace!("global {:?} initializer missing override {:?}", h, ov_h); + } + Err((_, e)) => return Err(e.into()), + Ok(h) => adjusted_global_expressions.insert(old_h, h), + } } // Finish processing any overrides we didn't visit in the loop above. @@ -184,6 +255,9 @@ pub fn process_overrides<'a>( init: Some(ref mut init), .. } => { + // Anonymous override representing by an array size expression. + // These are not handled through `process_override`, are not + // replaced by a constant, and are not added to `override_map`. *init = adjusted_global_expressions[*init]; } _ => {} @@ -201,23 +275,137 @@ pub fn process_overrides<'a>( c.init = adjusted_global_expressions[c.init]; } - for (_, v) in module.global_variables.iter_mut() { + // Identify which global variables are still unusable due to missing + // overrides. Overrides can appear in the initializer, and in the + // case of workgroup space arrays, in the array size. + for (v_handle, v) in module.global_variables.iter_mut() { if let Some(ref mut init) = v.init { *init = adjusted_global_expressions[*init]; + if let Some(&o_handle) = global_expressions_missing_overrides.get(init) { + log::trace!( + "global {:?} initializer missing override {:?}", + v.name, + overrides[o_handle].name + ); + unresolved.global_variables.insert(v_handle, o_handle); + } + } else if let TypeInner::Array { + size: crate::ArraySize::Pending(o_handle), + .. + } = module.types[v.ty].inner + { + let resolved = match override_map.get(o_handle) { + Some(&Some(_)) => { + // Override was processed successfully. + true + } + Some(&None) => { + // Override could not be processed. + false + } + None => { + // Anonymous override for array size expression + // These are not added to override_map. + match overrides[o_handle].init { + Some(init) => global_expression_kind_tracker.is_const(init), + None => { + // This should not happen. + log::error!("anonymous override with no initializer?"); + true + } + } + } + }; + if !resolved { + log::trace!( + "array size of global {:?} missing override {:?}", + v.name, + overrides[o_handle].name + ); + unresolved.global_variables.insert(v_handle, o_handle); + } } } + // Process functions, taking note of which ones require overrides that were + // not specified. Like expressions, callees are guarenteed to appear before + // their callers. let mut functions = mem::take(&mut module.functions); - for (_, function) in functions.iter_mut() { - process_function(&mut module, &override_map, &mut layouter, function)?; + for (f_handle, function) in functions.iter_mut() { + if let Some(o_handle) = process_function( + &mut module, + &override_map, + &unresolved.functions, + &mut layouter, + function, + )? { + log::trace!( + "function {:?} missing override {:?}", + function.name, + overrides[o_handle].name + ); + unresolved.functions.insert(f_handle, o_handle); + } } module.functions = functions; + // Process entry points let mut entry_points = mem::take(&mut module.entry_points); - for ep in entry_points.iter_mut() { - process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?; - process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?; + for (ep_index, ep) in entry_points.iter_mut().enumerate() { + let result = if let Some(o_handle) = process_function( + &mut module, + &override_map, + &unresolved.functions, + &mut layouter, + &mut ep.function, + )? { + log::trace!( + "entry point {} missing override {:?}", + ep.name, + overrides[o_handle].name + ); + Some(o_handle) + } else if let Some(o_handle) = + process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)? + { + Some(o_handle) + } else { + // See if we use any global variables that are missing overrides. + let mut missing = None; + for (var_handle, _) in module.global_variables.iter() { + let global_use = module_info.get_entry_point(ep_index)[var_handle]; + match unresolved.global_variables.get(&var_handle) { + Some(&o_handle) if !global_use.is_empty() => { + missing = Some(o_handle); + break; + } + _ => {} + } + } + missing + }; + if let Some(o_handle) = result { + // We found a missing override that is required by this entry point. + // Decide whether that is an error. + match entry_point { + Some((tgt_stage, tgt_name)) if ep.stage != tgt_stage || ep.name != tgt_name => { + // An entry point was specified, and we are not currently + // processing that one, so it is okay not to have this + // override. + unresolved.entry_points.insert(ep_index, o_handle); + } + _ => { + // Either we are missing an override for the active entry point, + // or no entry point was specified. Either way, this override + // is required. + return Err(PipelineConstantError::MissingValue( + override_key(&overrides[o_handle]).to_string(), + )); + } + } + } } + module.entry_points = entry_points; module.overrides = overrides; @@ -225,66 +413,84 @@ pub fn process_overrides<'a>( // recompute their types and other metadata. For the time being, // do a full re-validation. let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); - let module_info = validator.validate_resolved_overrides(&module)?; + let module_info = validator.validate_with_resolved_overrides(&module, &unresolved)?; - Ok((Cow::Owned(module), Cow::Owned(module_info))) + Ok(ProcessOverridesOutput { + module: Cow::Owned(module), + info: Cow::Owned(module_info), + unresolved, + }) } +/// Process override expressions in the WGSL `@workgroup_size` attribute. +/// +/// If all expressions are resolved, returns `Ok(None)`. If any expression could +/// not be resolved due to missing override values, returns `Ok(Some(handle))`, +/// with the handle of the first identified missing override. The caller is +/// responsible for determining whether translation can proceed despite the +/// missing override. fn process_workgroup_size_override( module: &mut Module, adjusted_global_expressions: &HandleVec>, ep: &mut crate::EntryPoint, -) -> Result<(), PipelineConstantError> { +) -> Result>, PipelineConstantError> { match ep.workgroup_size_overrides { None => {} Some(overrides) => { - overrides.iter().enumerate().try_for_each( - |(i, overridden)| -> Result<(), PipelineConstantError> { - match *overridden { - None => Ok(()), - Some(h) => { - ep.workgroup_size[i] = module - .to_ctx() - .eval_expr_to_u32(adjusted_global_expressions[h]) - .map(|n| { - if n == 0 { - Err(PipelineConstantError::NegativeWorkgroupSize) - } else { - Ok(n) - } - }) - .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??; - Ok(()) + for (ov_index, ov) in overrides.iter().enumerate() { + match *ov { + None => continue, + Some(h) => { + match module + .to_ctx() + .eval_expr_to_u32(adjusted_global_expressions[h]) + { + Ok(n) => { + if n == 0 { + return Err(PipelineConstantError::NegativeWorkgroupSize); + } else { + ep.workgroup_size[ov_index] = n; + } + } + Err(U32EvalError::Override(handle)) => { + return Ok(Some(handle)); + } + Err(U32EvalError::Runtime) => { + return Err( + PipelineConstantError::WorkgroupSizeOverrideEvaluationError, + ); + } + Err(U32EvalError::Negative) => { + return Err(PipelineConstantError::NegativeWorkgroupSize); + } } } - }, - )?; + } + } ep.workgroup_size_overrides = None; } } - Ok(()) + Ok(None) } -/// Add a [`Constant`] to `module` for the override `old_h`. +/// If a value for the override `old_h` is given in `self.pipeline_constants`, +/// add a [`Constant`] for that override to `module`. +/// +/// If a value is found, adds the new `Constant` to `override_map` and +/// `adjusted_constant_initializers`, and returns it. /// -/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`. +/// If no value is found, returns `Ok(None)`. The caller is responsible for +/// determining whether translation can proceed despite the missing override. fn process_override( (old_h, r#override, span): (Handle, &mut Override, &Span), pipeline_constants: &PipelineConstants, module: &mut Module, - override_map: &mut HandleVec>, + override_map: &mut HandleVec>>, adjusted_global_expressions: &HandleVec>, adjusted_constant_initializers: &mut HashSet>, global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker, -) -> Result, PipelineConstantError> { - // Determine which key to use for `r#override` in `pipeline_constants`. - let key = if let Some(id) = r#override.id { - Cow::Owned(id.to_string()) - } else if let Some(ref name) = r#override.name { - Cow::Borrowed(name) - } else { - unreachable!(); - }; +) -> Result>, PipelineConstantError> { + let key = override_key(r#override); // Generate a global expression for `r#override`'s value, either // from the provided `pipeline_constants` table or its initializer @@ -299,10 +505,16 @@ fn process_override( .append(Expression::Literal(literal), Span::UNDEFINED); global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const); expr - } else if let Some(init) = r#override.init { - adjusted_global_expressions[init] } else { - return Err(PipelineConstantError::MissingValue(key.to_string())); + match r#override.init { + Some(init) if global_expression_kind_tracker.is_const(init) => { + adjusted_global_expressions[init] + } + _ => { + override_map.insert(old_h, None); + return Ok(None); + } + } }; // Generate a new `Constant` to represent the override's value. @@ -312,27 +524,35 @@ fn process_override( init, }; let h = module.constants.append(constant, *span); - override_map.insert(old_h, h); + override_map.insert(old_h, Some(h)); adjusted_constant_initializers.insert(h); r#override.init = Some(init); - Ok(h) + Ok(Some(h)) } -/// Replace all override expressions in `function` with fully-evaluated constants. +/// Replace override expressions in `function` with fully-evaluated constants. /// -/// Replace all `Expression::Override`s in `function`'s expression arena with +/// Replace `Expression::Override`s in `function`'s expression arena with /// the corresponding `Expression::Constant`s, as given in `override_map`. /// Replace any expressions whose values are now known with their fully /// evaluated form. /// /// If `h` is a `Handle`, then `override_map[h]` is the /// `Handle` for the override's final value. +/// +/// If all override expressions are replaced, returns `Ok(None)`. If any +/// expression could not be replaced due to missing override values, or if +/// the function calls another function that is present in +/// `functions_missing_overrides`, returns `Ok(Some(handle))`, with the handle +/// of the first identified missing override. The caller is responsible for +/// determining whether translation can proceed despite the missing override. fn process_function( module: &mut Module, - override_map: &HandleVec>, + override_map: &HandleVec>>, + functions_missing_overrides: &FastHashMap, Handle>, layouter: &mut crate::proc::Layouter, function: &mut Function, -) -> Result<(), ConstantEvaluatorError> { +) -> Result>, ConstantEvaluatorError> { // A map from original local expression handles to // handles in the new, local expression arena. let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len()); @@ -341,6 +561,8 @@ fn process_function( let mut expressions = mem::take(&mut function.expressions); + let mut missing_override = None; + // Dummy `emitter` and `block` for the constant evaluator. // We can ignore the concept of emitting expressions here since // expressions have already been covered by a `Statement::Emit` @@ -363,14 +585,29 @@ fn process_function( for (old_h, mut expr, span) in expressions.drain() { if let Expression::Override(h) = expr { - expr = Expression::Constant(override_map[h]); + if let Some(&Some(const_h)) = override_map.get(h) { + expr = Expression::Constant(const_h); + } else if missing_override.is_none() { + missing_override = Some(h); + } } adjust_expr(&adjusted_local_expressions, &mut expr); - let h = evaluator.try_eval_and_append(expr, span)?; + let h = evaluator + .try_eval_and_append(expr, span) + .map_err(|(_expr, err)| err)?; adjusted_local_expressions.insert(old_h, h); } - adjust_block(&adjusted_local_expressions, &mut function.body); + match adjust_block( + &adjusted_local_expressions, + functions_missing_overrides, + &mut function.body, + ) { + missing @ Some(_) if missing_override.is_none() => { + missing_override = missing; + } + _ => {} + } filter_emits_in_block(&mut function.body, &function.expressions); @@ -390,7 +627,7 @@ fn process_function( .insert(adjusted_local_expressions[expr_h], name); } - Ok(()) + Ok(missing_override) } /// Replace every expression handle in `expr` with its counterpart @@ -606,15 +843,39 @@ fn adjust_expr(new_pos: &HandleVec>, expr: &mut E /// Replace every expression handle in `block` with its counterpart /// given by `new_pos`. -fn adjust_block(new_pos: &HandleVec>, block: &mut Block) { +/// +/// On success, returns `Ok(None)`. If `block` calls a function that is present +/// in `functions_missing_overrides`, returns `Ok(Some(handle))`, with the +/// handle of the first identified missing override. +fn adjust_block( + new_pos: &HandleVec>, + functions_missing_overrides: &FastHashMap, Handle>, + block: &mut Block, +) -> Option> { + let mut missing_override = None; for stmt in block.iter_mut() { - adjust_stmt(new_pos, stmt); + match adjust_stmt(new_pos, functions_missing_overrides, stmt) { + missing @ Some(_) if missing_override.is_none() => { + missing_override = missing; + } + _ => {} + } } + missing_override } /// Replace every expression handle in `stmt` with its counterpart /// given by `new_pos`. -fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut Statement) { +/// +/// On success, returns `Ok(None)`. If `stmt` calls a function that is present +/// in `functions_missing_overrides`, returns `Ok(Some(handle))`, with the +/// handle of the first identified missing override. +fn adjust_stmt( + new_pos: &HandleVec>, + functions_missing_overrides: &FastHashMap, Handle>, + stmt: &mut Statement, +) -> Option> { + let mut missing_override = None; let adjust = |expr: &mut Handle| { *expr = new_pos[*expr]; }; @@ -627,7 +888,7 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S } } Statement::Block(ref mut block) => { - adjust_block(new_pos, block); + adjust_block(new_pos, functions_missing_overrides, block); } Statement::If { ref mut condition, @@ -635,8 +896,8 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S ref mut reject, } => { adjust(condition); - adjust_block(new_pos, accept); - adjust_block(new_pos, reject); + adjust_block(new_pos, functions_missing_overrides, accept); + adjust_block(new_pos, functions_missing_overrides, reject); } Statement::Switch { ref mut selector, @@ -644,7 +905,7 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S } => { adjust(selector); for case in cases.iter_mut() { - adjust_block(new_pos, &mut case.body); + adjust_block(new_pos, functions_missing_overrides, &mut case.body); } } Statement::Loop { @@ -652,8 +913,8 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S ref mut continuing, ref mut break_if, } => { - adjust_block(new_pos, body); - adjust_block(new_pos, continuing); + adjust_block(new_pos, functions_missing_overrides, body); + adjust_block(new_pos, functions_missing_overrides, continuing); if let Some(e) = break_if.as_mut() { adjust(e); } @@ -769,8 +1030,14 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S Statement::Call { ref mut arguments, ref mut result, - function: _, + function, } => { + match functions_missing_overrides.get(&function).copied() { + missing @ Some(_) if missing_override.is_none() => { + missing_override = missing; + } + _ => {} + } for argument in arguments.iter_mut() { adjust(argument); } @@ -803,6 +1070,7 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S } Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {} } + missing_override } /// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced. diff --git a/naga/src/front/glsl/context.rs b/naga/src/front/glsl/context.rs index e6c5546fb95..b151d7af8eb 100644 --- a/naga/src/front/glsl/context.rs +++ b/naga/src/front/glsl/context.rs @@ -277,10 +277,11 @@ impl<'a> Context<'a> { ) }; - eval.try_eval_and_append(expr, meta).map_err(|e| Error { - kind: e.into(), - meta, - }) + eval.try_eval_and_append(expr, meta) + .map_err(|(_expr, err)| Error { + kind: err.into(), + meta, + }) } /// Add variable to current scope diff --git a/naga/src/front/glsl/parser.rs b/naga/src/front/glsl/parser.rs index 2eb3ec4b009..290bcc54431 100644 --- a/naga/src/front/glsl/parser.rs +++ b/naga/src/front/glsl/parser.rs @@ -219,7 +219,7 @@ impl<'source> ParsingContext<'source> { kind: ErrorKind::SemanticError("int constant overflows".into()), meta, }), - Err(U32EvalError::NonConst) => Err(Error { + Err(U32EvalError::Runtime | U32EvalError::Override(_)) => Err(Error { kind: ErrorKind::SemanticError("Expected a uint constant".into()), meta, }), diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 93ccb7143ca..35c2e8a69c2 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -520,7 +520,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ) -> Result<'source, Handle> { let mut eval = self.as_const_evaluator(); eval.try_eval_and_append(expr, span) - .map_err(|e| Box::new(Error::ConstantEvaluatorError(e.into(), span))) + .map_err(|(_expr, err)| Box::new(Error::ConstantEvaluatorError(err.into(), span))) } fn const_eval_expr_to_u32( @@ -530,7 +530,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => { if !ctx.local_expression_kind_tracker.is_const(handle) { - return Err(proc::U32EvalError::NonConst); + return Err(proc::U32EvalError::Runtime); } self.module @@ -544,7 +544,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { .eval_expr_to_u32_from(handle, &ctx.function.expressions) } ExpressionContextType::Constant(None) => self.module.to_ctx().eval_expr_to_u32(handle), - ExpressionContextType::Override => Err(proc::U32EvalError::NonConst), + ExpressionContextType::Override => Err(proc::U32EvalError::Runtime), } } @@ -628,7 +628,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { .to_ctx() .eval_expr_to_u32_from(expr, &rctx.function.expressions) .map_err(|err| match err { - proc::U32EvalError::NonConst => { + proc::U32EvalError::Runtime | proc::U32EvalError::Override(_) => { Error::ExpectedConstExprConcreteIntegerScalar(component_span) } proc::U32EvalError::Negative => Error::ExpectedNonNegative(component_span), @@ -1431,7 +1431,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Err(err) => { if let Error::ConstantEvaluatorError(ref ty, _) = *err { match **ty { - proc::ConstantEvaluatorError::OverrideExpr => { + proc::ConstantEvaluatorError::Override(_) => { workgroup_size_overrides_out[i] = Some(self.workgroup_size_override( size_expr, @@ -1739,12 +1739,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .to_ctx() .eval_expr_to_literal_from(expr, &ctx.function.expressions) { - Some(ir::Literal::I32(value)) => { - ir::SwitchValue::I32(value) - } - Some(ir::Literal::U32(value)) => { - ir::SwitchValue::U32(value) - } + Ok(ir::Literal::I32(value)) => ir::SwitchValue::I32(value), + Ok(ir::Literal::U32(value)) => ir::SwitchValue::U32(value), _ => { return Err(Box::new(Error::InvalidSwitchCase { span, @@ -3587,7 +3583,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .to_ctx() .eval_expr_to_u32(expr) .map_err(|err| match err { - proc::U32EvalError::NonConst => Error::ExpectedConstExprConcreteIntegerScalar(span), + proc::U32EvalError::Runtime | proc::U32EvalError::Override(_) => { + Error::ExpectedConstExprConcreteIntegerScalar(span) + } proc::U32EvalError::Negative => Error::ExpectedNonNegative(span), })?; Ok((value, span)) @@ -3606,7 +3604,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(value) => { let len = ctx.const_eval_expr_to_u32(value).map_err(|err| { Box::new(match err { - proc::U32EvalError::NonConst => { + proc::U32EvalError::Runtime | proc::U32EvalError::Override(_) => { Error::ExpectedConstExprConcreteIntegerScalar(span) } proc::U32EvalError::Negative => { @@ -3621,7 +3619,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Err(err) => { if let Error::ConstantEvaluatorError(ref ty, _) = *err { match **ty { - proc::ConstantEvaluatorError::OverrideExpr => { + proc::ConstantEvaluatorError::Override(_) => { ir::ArraySize::Pending(self.array_size_override( expr, &mut ctx.as_global().as_override(), diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 27d6addc826..573cc12413f 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -558,11 +558,11 @@ pub enum ConstantEvaluatorError { #[error(transparent)] Literal(#[from] crate::valid::LiteralError), #[error("Can't use pipeline-overridable constants in const-expressions")] - Override, + Override(Handle), #[error("Unexpected runtime-expression")] RuntimeExpr, - #[error("Unexpected override-expression")] - OverrideExpr, + #[error("Unexpectedly able to evaluate an override expression")] + EvaluatedOverrideExpr, } impl<'a> ConstantEvaluator<'a> { @@ -740,7 +740,8 @@ impl<'a> ConstantEvaluator<'a> { /// contributing to some function's expression arena, then append `expr` to /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be /// contributing to the module's constant expression arena; since `expr`'s - /// value is not a constant, return an error. + /// value is not a constant, return an error (along with the original + /// expression, in case the caller needs it). /// /// We only consider `expr` itself, without recursing into its operands. Its /// operands must all have been produced by prior calls to @@ -755,7 +756,7 @@ impl<'a> ConstantEvaluator<'a> { &mut self, expr: Expression, span: Span, - ) -> Result, ConstantEvaluatorError> { + ) -> Result, (Expression, ConstantEvaluatorError)> { match self.expression_kind_tracker.type_of_with_expr(&expr) { ExpressionKind::Const => { let eval_result = self.try_eval_and_append_impl(&expr, span); @@ -772,7 +773,7 @@ impl<'a> ConstantEvaluator<'a> { { Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) } else { - eval_result + eval_result.map_err(|err| (expr, err)) } } ExpressionKind::Override => match self.behavior { @@ -780,7 +781,19 @@ impl<'a> ConstantEvaluator<'a> { Ok(self.append_expr(expr, span, ExpressionKind::Override)) } Behavior::Wgsl(WgslRestrictions::Const(_)) => { - Err(ConstantEvaluatorError::OverrideExpr) + // We should always get `ConstantEvaluatorError::Override` + // here. If we get something else, then it's probably a bug + // in the expression kind determination. We attempt evaluation + // here in order to identify the overrides that would be + // required to evaluate this expression, for use in diagnostics. + match self.try_eval_and_append_impl(&expr, span) { + Err(ov_err @ ConstantEvaluatorError::Override(_)) => Err((expr, ov_err)), + Err(err) => { + log::debug!("expected an override error, but got {:?}", err); + Err((expr, err)) + } + Ok(_) => Err((expr, ConstantEvaluatorError::EvaluatedOverrideExpr)), + } } Behavior::Glsl(_) => { unreachable!() @@ -790,7 +803,7 @@ impl<'a> ConstantEvaluator<'a> { if self.behavior.has_runtime_restrictions() { Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) } else { - Err(ConstantEvaluatorError::RuntimeExpr) + Err((expr, ConstantEvaluatorError::RuntimeExpr)) } } } @@ -830,7 +843,7 @@ impl<'a> ConstantEvaluator<'a> { // This is mainly done to avoid having constants pointing to other constants. Ok(self.constants[c].init) } - Expression::Override(_) => Err(ConstantEvaluatorError::Override), + Expression::Override(ov) => Err(ConstantEvaluatorError::Override(ov)), Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { self.register_evaluated_expr(expr.clone(), span) } diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 0843e709b5d..e1bc6b4758c 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -415,7 +415,12 @@ impl crate::Module { #[derive(Debug)] pub(super) enum U32EvalError { - NonConst, + /// Expression is not constant. + Runtime, + + /// Expression is not constant because the indicated override value is not supplied. + Override(crate::Handle), + Negative, } @@ -444,11 +449,10 @@ impl GlobalCtx<'_> { arena: &crate::Arena, ) -> Result { match self.eval_expr_to_literal_from(handle, arena) { - Some(crate::Literal::U32(value)) => Ok(value), - Some(crate::Literal::I32(value)) => { - value.try_into().map_err(|_| U32EvalError::Negative) - } - _ => Err(U32EvalError::NonConst), + Ok(crate::Literal::U32(value)) => Ok(value), + Ok(crate::Literal::I32(value)) => value.try_into().map_err(|_| U32EvalError::Negative), + Err(Some(ov_handle)) => Err(U32EvalError::Override(ov_handle)), + _ => Err(U32EvalError::Runtime), } } @@ -460,7 +464,7 @@ impl GlobalCtx<'_> { arena: &crate::Arena, ) -> Option { match self.eval_expr_to_literal_from(handle, arena) { - Some(crate::Literal::Bool(value)) => Some(value), + Ok(crate::Literal::Bool(value)) => Some(value), _ => None, } } @@ -469,7 +473,7 @@ impl GlobalCtx<'_> { pub(crate) fn eval_expr_to_literal( &self, handle: crate::Handle, - ) -> Option { + ) -> Result>> { self.eval_expr_to_literal_from(handle, self.global_expressions) } @@ -477,25 +481,26 @@ impl GlobalCtx<'_> { &self, handle: crate::Handle, arena: &crate::Arena, - ) -> Option { + ) -> Result>> { fn get( gctx: GlobalCtx, handle: crate::Handle, arena: &crate::Arena, - ) -> Option { + ) -> Result>> { match arena[handle] { - crate::Expression::Literal(literal) => Some(literal), + crate::Expression::Literal(literal) => Ok(literal), crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner { - crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar), - _ => None, + crate::TypeInner::Scalar(scalar) => crate::Literal::zero(scalar).ok_or(None), + _ => Err(None), }, - _ => None, + _ => Err(None), } } match arena[handle] { crate::Expression::Constant(c) => { get(*self, self.constants[c].init, self.global_expressions) } + crate::Expression::Override(handle) => Err(Some(handle)), _ => get(*self, handle, arena), } } @@ -531,7 +536,9 @@ impl crate::ArraySize { return Err(ResolveArraySizeError::NonConstArrayLength); }; let length = gctx.eval_expr_to_u32(expr).map_err(|err| match err { - U32EvalError::NonConst => ResolveArraySizeError::NonConstArrayLength, + U32EvalError::Runtime | U32EvalError::Override(_) => { + ResolveArraySizeError::NonConstArrayLength + } U32EvalError::Negative => ResolveArraySizeError::ExpectedPositiveArrayLength, })?; diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 63de450372a..6466ca45d91 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -139,6 +139,8 @@ pub enum ExpressionError { Literal(#[from] LiteralError), #[error("{0:?} is not supported for Width {2} {1:?} arguments yet, see https://github.com/gfx-rs/wgpu/issues/5276")] UnsupportedWidth(crate::MathFunction, crate::ScalarKind, crate::Bytes), + #[error("Missing value for pipeline-overridable constant {0:?}")] + UnresolvedOverride(Handle), } #[derive(Clone, Debug, thiserror::Error)] @@ -194,7 +196,7 @@ impl core::ops::Index> for ExpressionTypeResolver<'_> } } -impl super::Validator { +impl super::Validator<'_> { pub(super) fn validate_const_expression( &self, handle: Handle, @@ -224,7 +226,7 @@ impl super::Validator { crate::TypeInner::Scalar { .. } => {} _ => return Err(ConstExpressionError::InvalidSplatType(value)), }, - _ if global_expr_kind.is_const(handle) || self.overrides_resolved => { + _ if global_expr_kind.is_const(handle) => { return Err(ConstExpressionError::NonFullyEvaluatedConst) } // the constant evaluator will report errors about override-expressions @@ -302,7 +304,9 @@ impl super::Validator { Err(crate::proc::U32EvalError::Negative) => { return Err(ExpressionError::NegativeIndex(base)) } - Err(crate::proc::U32EvalError::NonConst) => {} + Err( + crate::proc::U32EvalError::Runtime | crate::proc::U32EvalError::Override(_), + ) => {} } ShaderStages::all() @@ -373,7 +377,14 @@ impl super::Validator { self.validate_literal(literal)?; ShaderStages::all() } - E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(), + E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(), + E::Override(handle) => { + if self.overrides_resolved { + return Err(ExpressionError::UnresolvedOverride(handle)); + } else { + ShaderStages::all() + } + } E::Compose { ref components, ty } => { validate_compose( ty, diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 7865f1fc42e..5b71b7289f1 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -217,6 +217,8 @@ pub enum FunctionError { EmitResult(Handle), #[error("Expression not visited by the appropriate statement")] UnvisitedExpression(Handle), + #[error("Missing value for pipeline-overridable constant {0:?}")] + UnresolvedOverride(Handle), } bitflags::bitflags! { @@ -318,7 +320,7 @@ impl<'a> BlockContext<'a> { } } -impl super::Validator { +impl super::Validator<'_> { fn validate_call( &mut self, function: Handle, @@ -1760,6 +1762,9 @@ impl super::Validator { &local_expr_kind, ) { Ok(stages) => info.available_stages &= stages, + Err(ExpressionError::UnresolvedOverride(handle)) => { + return Err(FunctionError::UnresolvedOverride(handle).with_span()) + } Err(source) => { return Err(FunctionError::Expression { handle, source } .with_span_handle(handle, &fun.expressions)) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 86285c2818b..38787789a08 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -14,7 +14,7 @@ use crate::{Arena, UniqueArena}; #[cfg(test)] use alloc::string::ToString; -impl super::Validator { +impl super::Validator<'_> { /// Validates that all handles within `module` are: /// /// * Valid, in the sense that they contain indices within each arena structure inside the diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 3792c71abc5..b50199abcb6 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -487,7 +487,7 @@ impl VaryingContext<'_> { } } -impl super::Validator { +impl super::Validator<'_> { pub(super) fn validate_global_var( &self, var: &crate::GlobalVariable, diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index c8a02db1afa..0a08661cb67 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -18,7 +18,7 @@ use bit_set::BitSet; use crate::{ arena::{Handle, HandleSet}, proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution}, - FastHashSet, + FastHashMap, FastHashSet, }; //TODO: analyze the model at the same time as we validate it, @@ -268,8 +268,22 @@ impl ops::Index> for ModuleInfo { } } +/// Information about overrides that remain unresolved after [`process_overrides`]. +/// +/// This struct may be passed to the various backend writers. +/// +/// [`process_overrides`]: crate::back::pipeline_constants::process_overrides +#[cfg(any(hlsl_out, msl_out, spv_out, glsl_out))] +#[derive(Clone, Debug, Default)] +pub struct UnresolvedOverrides { + pub(crate) global_variables: + FastHashMap, Handle>, + pub(crate) functions: FastHashMap, Handle>, + pub(crate) entry_points: FastHashMap>, +} + #[derive(Debug)] -pub struct Validator { +pub struct Validator<'a> { flags: ValidationFlags, capabilities: Capabilities, subgroup_stages: ShaderStages, @@ -289,6 +303,8 @@ pub struct Validator { /// constant expressions as errors. overrides_resolved: bool, + unresolved_overrides: Option<&'a UnresolvedOverrides>, + /// A checklist of expressions that must be visited by a specific kind of /// statement. /// @@ -452,7 +468,7 @@ impl crate::TypeInner { } } -impl Validator { +impl<'a> Validator<'a> { /// Construct a new validator instance. pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self { let subgroup_operations = if capabilities.contains(Capabilities::SUBGROUP) { @@ -487,6 +503,7 @@ impl Validator { valid_expression_set: HandleSet::new(), override_ids: FastHashSet::default(), overrides_resolved: false, + unresolved_overrides: None, needs_visit: HandleSet::new(), } } @@ -574,8 +591,6 @@ impl Validator { if !gctx.compare_types(&TypeResolution::Handle(o.ty), &mod_info[init]) { return Err(OverrideError::InvalidType); } - } else if self.overrides_resolved { - return Err(OverrideError::UninitializedOverride); } Ok(()) @@ -590,18 +605,19 @@ impl Validator { self.validate_impl(module) } - /// Check the given module to be valid, requiring overrides to be resolved. + /// Check the given module to be valid, after resolving overrides. /// - /// This is the same as [`validate`], except that any override - /// whose value is not a fully-evaluated constant expression is - /// treated as an error. + /// This is the same as [`validate`], but override expressions are allowed + /// in items that appear in one of the maps in `unresolved`. /// /// [`validate`]: Validator::validate - pub fn validate_resolved_overrides( + pub(crate) fn validate_with_resolved_overrides( &mut self, module: &crate::Module, + unresolved: &'a UnresolvedOverrides, ) -> Result> { self.overrides_resolved = true; + self.unresolved_overrides = Some(unresolved); self.validate_impl(module) } @@ -703,19 +719,36 @@ impl Validator { } for (var_handle, var) in module.global_variables.iter() { - self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind) - .map_err(|source| { - ValidationError::GlobalVariable { - handle: var_handle, - name: var.name.clone().unwrap_or_default(), - source, - } - .with_span_handle(var_handle, &module.global_variables) - })?; + let save_overrides_resolved = self.overrides_resolved; + match self.unresolved_overrides { + Some(unres) if unres.global_variables.contains_key(&var_handle) => { + self.overrides_resolved = false; + } + _ => {} + } + let res = self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind); + self.overrides_resolved = save_overrides_resolved; + res.map_err(|source| { + ValidationError::GlobalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(var_handle, &module.global_variables) + })?; } for (handle, fun) in module.functions.iter() { - match self.validate_function(fun, module, &mod_info, false) { + let save_overrides_resolved = self.overrides_resolved; + match self.unresolved_overrides { + Some(unres) if unres.functions.contains_key(&handle) => { + self.overrides_resolved = false; + } + _ => {} + } + let res = self.validate_function(fun, module, &mod_info, false); + self.overrides_resolved = save_overrides_resolved; + match res { Ok(info) => mod_info.functions.push(info), Err(error) => { return Err(error.and_then(|source| { @@ -731,7 +764,7 @@ impl Validator { } let mut ep_map = FastHashSet::default(); - for ep in module.entry_points.iter() { + for (ep_index, ep) in module.entry_points.iter().enumerate() { if !ep_map.insert((ep.stage, &ep.name)) { return Err(ValidationError::EntryPoint { stage: ep.stage, @@ -741,7 +774,16 @@ impl Validator { .with_span()); // TODO: keep some EP span information? } - match self.validate_entry_point(ep, module, &mod_info) { + let save_overrides_resolved = self.overrides_resolved; + match self.unresolved_overrides { + Some(unres) if unres.entry_points.contains_key(&ep_index) => { + self.overrides_resolved = false; + } + _ => {} + } + let res = self.validate_entry_point(ep, module, &mod_info); + self.overrides_resolved = save_overrides_resolved; + match res { Ok(info) => mod_info.entry_points.push(info), Err(error) => { return Err(error.and_then(|source| { diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index b3ae13b7d4a..67b91a84357 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -252,7 +252,7 @@ impl TypeInfo { } } -impl super::Validator { +impl super::Validator<'_> { const fn require_type_capability(&self, capability: Capabilities) -> Result<(), TypeError> { if self.capabilities.contains(capability) { Ok(()) diff --git a/naga/tests/in/wgsl/missing-unused-overrides.toml b/naga/tests/in/wgsl/missing-unused-overrides.toml new file mode 100644 index 00000000000..3eb9ae5c9ca --- /dev/null +++ b/naga/tests/in/wgsl/missing-unused-overrides.toml @@ -0,0 +1,13 @@ +pipeline_constants = { ov_for_vertex = 1.5 } +#targets = "IR | ANALYSIS | SPIRV | METAL | HLSL | GLSL" +targets = "METAL" + +[msl] +lang_version = [2, 1] + +[msl_pipeline] +entry_point = ["Vertex", "vert_main"] + +[spv] +separate_entry_points = true +version = [1, 0] diff --git a/naga/tests/in/wgsl/missing-unused-overrides.wgsl b/naga/tests/in/wgsl/missing-unused-overrides.wgsl new file mode 100644 index 00000000000..2243a17c060 --- /dev/null +++ b/naga/tests/in/wgsl/missing-unused-overrides.wgsl @@ -0,0 +1,45 @@ +override ov_for_vertex: f32; + +@vertex +fn vert_main( + @location(0) pos : vec2, + @builtin(instance_index) ii: u32, + @builtin(vertex_index) vi: u32, +) -> @builtin(position) vec4 { + return vec4(pos.x * ov_for_vertex, pos.y, 0.0, 1.0); +} + +struct FragmentIn { + @location(0) color: vec4 +} + +override ov_for_fragment: f32; + +fn frag_helper(color: vec4) -> vec4 { + return color * ov_for_fragment; +} + +@fragment +fn frag_main(in: FragmentIn) -> @location(0) vec4 { + return frag_helper(in.color); +} + +override ov_global_init: u32; +var foo: u32 = ov_global_init; + +override ov_array_size: u32; +var arr: array; + +override ov_for_compute: u32; + +fn compute_helper() { + _ = foo; + _ = arr[0]; +} + +override ov_workgroup_size: u32; +@compute @workgroup_size(ov_workgroup_size) +fn compute_main() { + _ = ov_for_compute; + compute_helper(); +} diff --git a/naga/tests/naga/snapshots.rs b/naga/tests/naga/snapshots.rs index 931136ed8d2..3b39c81ee1d 100644 --- a/naga/tests/naga/snapshots.rs +++ b/naga/tests/naga/snapshots.rs @@ -145,6 +145,8 @@ struct Parameters { // -- HLSL options -- #[cfg(all(feature = "deserialize", hlsl_out))] hlsl: naga::back::hlsl::Options, + #[serde(default)] + hlsl_pipeline: naga::back::hlsl::PipelineOptions, // -- WGSL options -- wgsl: WgslOutParameters, @@ -548,6 +550,7 @@ fn check_targets(input: &Input, module: &mut naga::Module, source_code: Option<& module, &info, ¶ms.hlsl, + ¶ms.hlsl_pipeline, ¶ms.pipeline_constants, frag_ep, ); @@ -598,9 +601,12 @@ fn write_output_spv( debug_info, }; - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants) - .expect("override evaluation failed"); + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides(module, info, None, pipeline_constants) + .expect("override evaluation failed"); if params.separate_entry_points { for ep in module.entry_points.iter() { @@ -660,15 +666,28 @@ fn write_output_msl( ) { use naga::back::msl; - println!("generating MSL"); - - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants) - .expect("override evaluation failed"); + println!("generating MSL for {:?}", pipeline_options.entry_point); + + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved, + } = naga::back::pipeline_constants::process_overrides( + module, + info, + pipeline_options + .entry_point + .as_ref() + .map(|&(stage, ref name)| (stage, name.as_str())), + pipeline_constants, + ) + .expect("override evaluation failed"); let mut options = options.clone(); options.bounds_check_policies = bounds_check_policies; - let (string, tr_info) = msl::write_string(&module, &info, &options, pipeline_options) + let mut pipeline_options = pipeline_options.clone(); + pipeline_options.unresolved_overrides = unresolved; + let (string, tr_info) = msl::write_string(&module, &info, &options, &pipeline_options) .unwrap_or_else(|err| panic!("Metal write failed: {err}")); for (ep, result) in module.entry_points.iter().zip(tr_info.entry_point_names) { @@ -704,9 +723,12 @@ fn write_output_glsl( }; let mut buffer = String::new(); - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants) - .expect("override evaluation failed"); + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides(module, info, None, pipeline_constants) + .expect("override evaluation failed"); let mut writer = glsl::Writer::new( &mut buffer, &module, @@ -728,6 +750,7 @@ fn write_output_hlsl( module: &naga::Module, info: &naga::valid::ModuleInfo, options: &naga::back::hlsl::Options, + pipeline_options: &naga::back::hlsl::PipelineOptions, pipeline_constants: &naga::back::PipelineConstants, frag_ep: Option, ) { @@ -736,9 +759,20 @@ fn write_output_hlsl( println!("generating HLSL"); - let (module, info) = - naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants) - .expect("override evaluation failed"); + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( + module, + info, + pipeline_options + .entry_point + .as_ref() + .map(|&(stage, ref name)| (stage, name.as_str())), + pipeline_constants, + ) + .expect("override evaluation failed"); let mut buffer = String::new(); let pipeline_options = Default::default(); diff --git a/naga/tests/naga/validation.rs b/naga/tests/naga/validation.rs index 4b813d384e9..d832cc0659a 100644 --- a/naga/tests/naga/validation.rs +++ b/naga/tests/naga/validation.rs @@ -1,4 +1,8 @@ -use naga::{valid, Expression, Function, Scalar}; +use naga::{ + ir, + valid::{self, ModuleInfo}, + Expression, Function, Module, Scalar, +}; /// Validation should fail if `AtomicResult` expressions are not /// populated by `Atomic` statements. @@ -536,16 +540,16 @@ fn main(input: VertexOutput) {{ } #[allow(dead_code)] -struct BindingArrayFixture { +struct BindingArrayFixture<'a> { module: naga::Module, span: naga::Span, ty_u32: naga::Handle, ty_array: naga::Handle, ty_struct: naga::Handle, - validator: naga::valid::Validator, + validator: naga::valid::Validator<'a>, } -impl BindingArrayFixture { +impl BindingArrayFixture<'_> { fn new() -> Self { let mut module = naga::Module::default(); let span = naga::Span::default(); @@ -770,7 +774,6 @@ fn bad_texture_dimensions_level() { fn arity_check() { use ir::MathFunction as Mf; use naga::Span; - use naga::{ir, valid}; let _ = env_logger::builder().is_test(true).try_init(); type Result = core::result::Result; @@ -923,3 +926,157 @@ fn main() { naga::valid::GlobalUse::QUERY ); } + +fn parse_validate(source: &str) -> (Module, ModuleInfo) { + let module = naga::front::wgsl::parse_str(source).expect("module should parse"); + let info = valid::Validator::new(Default::default(), valid::Capabilities::all()) + .validate(&module) + .unwrap(); + (module, info) +} + +/// Helper for `process_overrides` tests. +/// +/// The goal of these tests is to verify that `process_overrides` accepts cases +/// where all necessary overrides are specified (even if some unnecessary ones +/// are not), and does not accept cases where necessary overrides are missing. +/// "Necessary" means that the entry point is referenced in some way by some +/// function reachable from the specified entry point. +/// +/// Each test passes a source snippet containing a compute entry point `used` +/// that makes use of the override `ov` in some way. We augment that with (1) +/// the definition of `ov` and (2) a dummy entrypoint that does not use the +/// override, and then test the matrix of (specified or not) x (used or not). +/// +/// Since `process_overrides` leaves unresolved overrides in the output module, +/// there could be bugs where a backend to reaches one of the remaining overrides +/// and fails. That is not exercised here, but is covered by the +/// `missing-unused-overrides` snapshot test. +fn override_test(test_case: &str) { + use hashbrown::HashMap; + use naga::back::pipeline_constants::PipelineConstantError; + + let source = [ + "override ov: u32;\n", + test_case, + "@compute @workgroup_size(64) +fn unused() { +} +", + ] + .concat(); + let (module, info) = parse_validate(&source); + + let overrides = HashMap::from([(String::from("ov"), 1.)]); + + // Can translate `unused` with or without the override + naga::back::pipeline_constants::process_overrides( + &module, + &info, + Some((ir::ShaderStage::Compute, "unused")), + &HashMap::new(), + ) + .unwrap(); + naga::back::pipeline_constants::process_overrides( + &module, + &info, + Some((ir::ShaderStage::Compute, "unused")), + &overrides, + ) + .unwrap(); + + // Cannot translate `used` without the override + let err = naga::back::pipeline_constants::process_overrides( + &module, + &info, + Some((ir::ShaderStage::Compute, "used")), + &HashMap::new(), + ) + .unwrap_err(); + assert!(matches!(err, PipelineConstantError::MissingValue(name) if name == "ov")); + + // Can translate `used` if the override is specified + naga::back::pipeline_constants::process_overrides( + &module, + &info, + Some((ir::ShaderStage::Compute, "used")), + &overrides, + ) + .unwrap(); +} + +#[cfg(feature = "wgsl-in")] +#[test] +fn override_in_workgroup_size() { + override_test( + " +@compute @workgroup_size(ov) +fn used() { +} +", + ); +} + +#[cfg(feature = "wgsl-in")] +#[test] +fn override_in_function() { + override_test( + " +fn foo() -> u32 { + return ov; +} + +@compute @workgroup_size(64) +fn used() { + foo(); +} +", + ); +} + +#[cfg(feature = "wgsl-in")] +#[test] +fn override_in_entrypoint() { + override_test( + " +fn foo() -> u32 { + return ov; +} + +@compute @workgroup_size(64) +fn used() { + foo(); +} +", + ); +} + +#[cfg(feature = "wgsl-in")] +#[test] +fn override_in_array_size() { + override_test( + " +var arr: array; + +@compute @workgroup_size(64) +fn used() { + _ = arr[5]; +} +", + ); +} + +#[cfg(feature = "wgsl-in")] +#[test] +fn override_in_global_init() { + override_test( + " +var foo: u32 = ov; + +@compute @workgroup_size(64) +fn used() { + _ = foo; +} +", + ); +} diff --git a/naga/tests/naga/wgsl_errors.rs b/naga/tests/naga/wgsl_errors.rs index 71e5b871715..ad3818b7f5e 100644 --- a/naga/tests/naga/wgsl_errors.rs +++ b/naga/tests/naga/wgsl_errors.rs @@ -3138,7 +3138,7 @@ fn local_const_from_override() { const c = o; } ", - r###"error: Unexpected override-expression + r###"error: Can't use pipeline-overridable constants in const-expressions ┌─ wgsl:4:23 │ 4 │ const c = o; diff --git a/naga/tests/out/msl/wgsl-missing-unused-overrides.msl b/naga/tests/out/msl/wgsl-missing-unused-overrides.msl new file mode 100644 index 00000000000..06583a1aeb5 --- /dev/null +++ b/naga/tests/out/msl/wgsl-missing-unused-overrides.msl @@ -0,0 +1,35 @@ +// language: metal2.1 +#include +#include + +using metal::uint; + +struct FragmentIn { + metal::float4 color; +}; +constant float ov_for_vertex = 1.5; + +void compute_helper( + thread uint& foo, + threadgroup type_4& arr +) { + uint phony = foo; + uint phony_1 = arr.inner[0]; + return; +} + +struct vert_mainInput { + metal::float2 pos [[attribute(0)]]; +}; +struct vert_mainOutput { + metal::float4 member [[position]]; +}; +vertex vert_mainOutput vert_main( + vert_mainInput varyings [[stage_in]] +, uint ii [[instance_id]] +, uint vi [[vertex_id]] +) { + const auto pos = varyings.pos; + return vert_mainOutput { metal::float4(pos.x * ov_for_vertex, pos.y, 0.0, 1.0) }; +} + diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index a8d1329fac4..421a1a682e8 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -371,11 +371,11 @@ impl ImplicitPipelineIds<'_> { } /// Create a validator with the given validation flags. -pub fn create_validator( +pub fn create_validator<'a>( features: wgt::Features, downlevel: wgt::DownlevelFlags, flags: naga::valid::ValidationFlags, -) -> naga::valid::Validator { +) -> naga::valid::Validator<'a> { use naga::valid::Capabilities as Caps; let mut caps = Caps::empty(); caps.set( diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 3d8100b9c0b..69f9267c3e1 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -279,6 +279,7 @@ impl super::Device { let (module, info) = naga::back::pipeline_constants::process_overrides( &stage.module.naga.module, &stage.module.naga.info, + None, stage.constants, ) .map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("HLSL: {e:?}")))?; diff --git a/wgpu-hal/src/gles/device.rs b/wgpu-hal/src/gles/device.rs index c5539eae351..75369290d0b 100644 --- a/wgpu-hal/src/gles/device.rs +++ b/wgpu-hal/src/gles/device.rs @@ -216,9 +216,14 @@ impl super::Device { multiview: context.multiview, }; - let (module, info) = naga::back::pipeline_constants::process_overrides( + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( &stage.module.naga.module, &stage.module.naga.info, + None, stage.constants, ) .map_err(|e| { diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 6ab22b0c3e3..08b1e241fed 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -134,9 +134,14 @@ impl super::Device { panic!("load_shader required a naga shader"); }; let stage_bit = map_naga_stage(naga_stage); - let (module, module_info) = naga::back::pipeline_constants::process_overrides( + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info: module_info, + unresolved: unresolved_overrides, + } = naga::back::pipeline_constants::process_overrides( &naga_shader.module, &naga_shader.info, + Some((naga_stage, stage.entry_point)), stage.constants, ) .map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {:?}", e)))?; @@ -182,6 +187,7 @@ impl super::Device { let pipeline_options = naga::back::msl::PipelineOptions { entry_point: Some((naga_stage, stage.entry_point.to_owned())), + unresolved_overrides, allow_and_force_point_size: match primitive_class { MTLPrimitiveTopologyClass::Point => true, _ => false, diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index 4ae76565120..262a0820cc3 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -889,9 +889,14 @@ impl super::Device { &self.naga_options }; - let (module, info) = naga::back::pipeline_constants::process_overrides( + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( &naga_shader.module, &naga_shader.info, + None, stage.constants, ) .map_err(|e| { From ecef54a77332d8b147c30084d72a0391ee16f9d4 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Wed, 30 Apr 2025 15:58:22 -0700 Subject: [PATCH 2/2] CI fixes --- naga/src/back/msl/writer.rs | 17 ++++-- naga/src/back/pipeline_constants.rs | 54 +++++++++++++------ naga/src/proc/mod.rs | 1 + naga/src/valid/mod.rs | 3 +- naga/tests/naga/snapshots.rs | 6 +++ .../out/msl/wgsl-missing-unused-overrides.msl | 9 ---- wgpu-hal/src/dx12/device.rs | 6 ++- 7 files changed, 65 insertions(+), 31 deletions(-) diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index c7232faf1af..a3b5cfb7161 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -6990,9 +6990,20 @@ mod workgroup_mem_init { let mut access_stack = AccessStack::new(); - let vars = module.global_variables.iter().filter(|&(handle, var)| { - !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup - }); + let vars = module + .global_variables + .iter() + .filter(|&(handle, var)| { + !fun_info[handle].is_empty() + && var.space == crate::AddressSpace::WorkGroup + && !self + .unresolved_overrides + .as_ref() + .unwrap() + .global_variables + .contains_key(&handle) + }) + .collect::>(); for (handle, var) in vars { access_stack.enter(Access::GlobalVariable(handle), |access_stack| { diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index b92e3cfd816..790e783ab75 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -13,7 +13,8 @@ use crate::{ ir, proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter, U32EvalError}, valid::{ - Capabilities, ModuleInfo, UnresolvedOverrides, ValidationError, ValidationFlags, Validator, + Capabilities, FunctionInfo, ModuleInfo, UnresolvedOverrides, ValidationError, + ValidationFlags, Validator, }, Arena, Block, Constant, Expression, FastHashMap, Function, Handle, Literal, Module, Override, Range, Scalar, Span, Statement, TypeInner, WithSpan, @@ -62,6 +63,26 @@ pub struct ProcessOverridesOutput<'a> { pub unresolved: UnresolvedOverrides, } +/// Check the global usage in `fun_info` for any globals affected by unresolved +/// overrides. +/// +/// If any is found, returns `Some`, otherwise returns `None`. +fn check_for_unresolved_global_use<'a>( + globals: impl Iterator, &'a ir::GlobalVariable)>, + unresolved: &UnresolvedOverrides, + fun_info: &FunctionInfo, +) -> Option> { + for (var_handle, _) in globals { + match unresolved.global_variables.get(&var_handle) { + Some(&o_handle) if !fun_info[var_handle].is_empty() => { + return Some(o_handle); + } + _ => {} + } + } + None +} + /// Replace overrides in `module` with constants. /// /// If no changes are needed, this just returns `Cow::Borrowed` @@ -328,11 +349,11 @@ pub fn process_overrides<'a>( } // Process functions, taking note of which ones require overrides that were - // not specified. Like expressions, callees are guarenteed to appear before + // not specified. Like expressions, callees are guaranteed to appear before // their callers. let mut functions = mem::take(&mut module.functions); for (f_handle, function) in functions.iter_mut() { - if let Some(o_handle) = process_function( + let result = if let Some(o_handle) = process_function( &mut module, &override_map, &unresolved.functions, @@ -344,6 +365,15 @@ pub fn process_overrides<'a>( function.name, overrides[o_handle].name ); + Some(o_handle) + } else { + check_for_unresolved_global_use( + module.global_variables.iter(), + &unresolved, + &module_info[f_handle], + ) + }; + if let Some(o_handle) = result { unresolved.functions.insert(f_handle, o_handle); } } @@ -370,19 +400,11 @@ pub fn process_overrides<'a>( { Some(o_handle) } else { - // See if we use any global variables that are missing overrides. - let mut missing = None; - for (var_handle, _) in module.global_variables.iter() { - let global_use = module_info.get_entry_point(ep_index)[var_handle]; - match unresolved.global_variables.get(&var_handle) { - Some(&o_handle) if !global_use.is_empty() => { - missing = Some(o_handle); - break; - } - _ => {} - } - } - missing + check_for_unresolved_global_use( + module.global_variables.iter(), + &unresolved, + module_info.get_entry_point(ep_index), + ) }; if let Some(o_handle) = result { // We found a missing override that is required by this entry point. diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index e1bc6b4758c..2b03f18dcb8 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -414,6 +414,7 @@ impl crate::Module { } #[derive(Debug)] +#[cfg_attr(not(any(hlsl_out, msl_out, spv_out, glsl_out)), allow(dead_code))] pub(super) enum U32EvalError { /// Expression is not constant. Runtime, diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 0a08661cb67..11872ada1ae 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -273,7 +273,6 @@ impl ops::Index> for ModuleInfo { /// This struct may be passed to the various backend writers. /// /// [`process_overrides`]: crate::back::pipeline_constants::process_overrides -#[cfg(any(hlsl_out, msl_out, spv_out, glsl_out))] #[derive(Clone, Debug, Default)] pub struct UnresolvedOverrides { pub(crate) global_variables: @@ -611,7 +610,7 @@ impl<'a> Validator<'a> { /// in items that appear in one of the maps in `unresolved`. /// /// [`validate`]: Validator::validate - pub(crate) fn validate_with_resolved_overrides( + pub fn validate_with_resolved_overrides( &mut self, module: &crate::Module, unresolved: &'a UnresolvedOverrides, diff --git a/naga/tests/naga/snapshots.rs b/naga/tests/naga/snapshots.rs index 3b39c81ee1d..2cfd7b40482 100644 --- a/naga/tests/naga/snapshots.rs +++ b/naga/tests/naga/snapshots.rs @@ -235,6 +235,12 @@ impl Input { return None; } + if let Ok(pat) = std::env::var("NAGA_SNAPSHOT") { + if !file_name.to_string_lossy().contains(&pat) { + return None; + } + } + let input = Input::new( subdirectory, file_name.file_stem().unwrap().to_str().unwrap(), diff --git a/naga/tests/out/msl/wgsl-missing-unused-overrides.msl b/naga/tests/out/msl/wgsl-missing-unused-overrides.msl index 06583a1aeb5..e2a7262f4d9 100644 --- a/naga/tests/out/msl/wgsl-missing-unused-overrides.msl +++ b/naga/tests/out/msl/wgsl-missing-unused-overrides.msl @@ -9,15 +9,6 @@ struct FragmentIn { }; constant float ov_for_vertex = 1.5; -void compute_helper( - thread uint& foo, - threadgroup type_4& arr -) { - uint phony = foo; - uint phony_1 = arr.inner[0]; - return; -} - struct vert_mainInput { metal::float2 pos [[attribute(0)]]; }; diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 69f9267c3e1..1db2468fe74 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -276,7 +276,11 @@ impl super::Device { let stage_bit = auxil::map_naga_stage(naga_stage); - let (module, info) = naga::back::pipeline_constants::process_overrides( + let naga::back::pipeline_constants::ProcessOverridesOutput { + module, + info, + unresolved: _, + } = naga::back::pipeline_constants::process_overrides( &stage.module.naga.module, &stage.module.naga.info, None,