Skip to content

[naga] Process overrides selectively for the active entry point #7652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: trunk
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 43 additions & 12 deletions naga-cli/src/bin/naga.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,8 @@ fn write_output(
params: &Parameters,
output_path: &str,
) -> anyhow::Result<()> {
use naga::back::pipeline_constants::ProcessOverridesOutput;

match Path::new(&output_path)
.extension()
.ok_or(CliError("Output filename has no extension"))?
Expand Down Expand Up @@ -717,9 +719,14 @@ fn write_output(
succeed, and it failed in a previous step",
))?;

let (module, info) =
naga::back::pipeline_constants::process_overrides(module, info, &params.overrides)
.unwrap_pretty();
let ProcessOverridesOutput { module, info, .. } =
naga::back::pipeline_constants::process_overrides(
module,
info,
None,
&params.overrides,
)
.unwrap_pretty();

let pipeline_options = msl::PipelineOptions::default();
let (msl, _) =
Expand Down Expand Up @@ -751,9 +758,17 @@ fn write_output(
succeed, and it failed in a previous step",
))?;

let (module, info) =
naga::back::pipeline_constants::process_overrides(module, info, &params.overrides)
.unwrap_pretty();
let naga::back::pipeline_constants::ProcessOverridesOutput {
module,
info,
unresolved: _,
} = naga::back::pipeline_constants::process_overrides(
module,
info,
None,
&params.overrides,
)
.unwrap_pretty();

let spv =
spv::write_vec(&module, &info, &params.spv_out, pipeline_options).unwrap_pretty();
Expand Down Expand Up @@ -788,9 +803,17 @@ fn write_output(
succeed, and it failed in a previous step",
))?;

let (module, info) =
naga::back::pipeline_constants::process_overrides(module, info, &params.overrides)
.unwrap_pretty();
let naga::back::pipeline_constants::ProcessOverridesOutput {
module,
info,
unresolved: _,
} = naga::back::pipeline_constants::process_overrides(
module,
info,
None,
&params.overrides,
)
.unwrap_pretty();

let mut buffer = String::new();
let mut writer = glsl::Writer::new(
Expand Down Expand Up @@ -819,9 +842,17 @@ fn write_output(
succeed, and it failed in a previous step",
))?;

let (module, info) =
naga::back::pipeline_constants::process_overrides(module, info, &params.overrides)
.unwrap_pretty();
let naga::back::pipeline_constants::ProcessOverridesOutput {
module,
info,
unresolved: _,
} = naga::back::pipeline_constants::process_overrides(
module,
info,
None,
&params.overrides,
)
.unwrap_pretty();

let mut buffer = String::new();
let pipeline_options = Default::default();
Expand Down
15 changes: 14 additions & 1 deletion naga/src/back/msl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ use alloc::{
};
use core::fmt::{Error as FmtError, Write};

use crate::{arena::Handle, ir, proc::index, valid::ModuleInfo};
use crate::{
arena::Handle,
ir,
proc::index,
valid::{ModuleInfo, UnresolvedOverrides},
};

mod keywords;
pub mod sampler;
Expand Down Expand Up @@ -431,6 +436,14 @@ pub struct PipelineOptions {
/// point is not found, an error will be thrown while writing.
pub entry_point: Option<(ir::ShaderStage, String)>,

/// Information about unresolved overrides.
///
/// This struct is returned by `process_overrides`. It tells the writer
/// which items to omit from the output because they are not used and refer
/// to overrides that were not resolved to a concrete value.
#[cfg_attr(feature = "serialize", serde(skip))]
pub unresolved_overrides: UnresolvedOverrides,

/// Allow `BuiltIn::PointSize` and inject it if doesn't exist.
///
/// Metal doesn't like this for non-point primitive topologies and requires it for
Expand Down
77 changes: 68 additions & 9 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ use crate::{
proc::{
self,
index::{self, BoundsCheck},
NameKey, TypeResolution,
NameKey, ResolveArraySizeError, TypeResolution,
},
valid, FastHashMap, FastHashSet,
valid::{self, UnresolvedOverrides},
FastHashMap, FastHashSet,
};

#[cfg(test)]
Expand Down Expand Up @@ -436,6 +437,7 @@ pub struct Writer<W> {
/// Set of (struct type, struct field index) denoting which fields require
/// padding inserted **before** them (i.e. between fields at index - 1 and index)
struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
unresolved_overrides: Option<UnresolvedOverrides>,
}

impl crate::Scalar {
Expand Down Expand Up @@ -775,6 +777,7 @@ impl<W: Write> Writer<W> {
#[cfg(test)]
put_block_stack_pointers: Default::default(),
struct_member_pads: FastHashSet::default(),
unresolved_overrides: None,
}
}

Expand Down Expand Up @@ -4032,6 +4035,7 @@ impl<W: Write> Writer<W> {
);
self.wrapped_functions.clear();
self.struct_member_pads.clear();
self.unresolved_overrides = Some(pipeline_options.unresolved_overrides.clone());

writeln!(
self.out,
Expand Down Expand Up @@ -4216,8 +4220,20 @@ impl<W: Write> Writer<W> {
first_time: false,
};

match size.resolve(module.to_ctx())? {
proc::IndexableLength::Known(size) => {
match size.resolve(module.to_ctx()) {
Err(ResolveArraySizeError::NonConstArrayLength) => {
// The array size was never resolved. This _should_
// be because it is an override expression and the
// type is not needed for the entry point being
// written.
// TODO: do we want to assemble `UnresolvedOverrides.types` to make this safer?
// (And if so, do we also want to validate that those types are truly unused?)
continue;
}
Err(err @ ResolveArraySizeError::ExpectedPositiveArrayLength) => {
return Err(err.into());
}
Ok(proc::IndexableLength::Known(size)) => {
writeln!(self.out, "struct {name} {{")?;
writeln!(
self.out,
Expand All @@ -4229,7 +4245,7 @@ impl<W: Write> Writer<W> {
)?;
writeln!(self.out, "}};")?;
}
proc::IndexableLength::Dynamic => {
Ok(proc::IndexableLength::Dynamic) => {
writeln!(self.out, "typedef {base_name} {name}[1];")?;
}
}
Expand Down Expand Up @@ -5757,6 +5773,17 @@ template <typename A>
fun_handle
);

if self
.unresolved_overrides
.as_ref()
.unwrap()
.functions
.contains_key(&fun_handle)
{
log::trace!("skipping due to unresolved overrides");
continue;
}

let ctx = back::FunctionCtx {
ty: back::FunctionType::Function(fun_handle),
info: &mod_info[fun_handle],
Expand Down Expand Up @@ -5880,6 +5907,19 @@ template <typename A>
};

for ep_index in ep_range {
if self
.unresolved_overrides
.as_ref()
.unwrap()
.entry_points
.contains_key(&ep_index)
{
log::error!(
"must write the same entry point that was passed to `process_overrides`"
);
return Err(Error::Override);
}

let ep = &module.entry_points[ep_index];
let fun = &ep.function;
let fun_info = mod_info.get_entry_point(ep_index);
Expand Down Expand Up @@ -6288,7 +6328,15 @@ template <typename A>
// within the entry point.
for (handle, var) in module.global_variables.iter() {
let usage = fun_info[handle];
if usage.is_empty() || var.space == crate::AddressSpace::Private {
if usage.is_empty()
|| var.space == crate::AddressSpace::Private
|| self
.unresolved_overrides
.as_ref()
.unwrap()
.global_variables
.contains_key(&handle)
{
continue;
}

Expand Down Expand Up @@ -6942,9 +6990,20 @@ mod workgroup_mem_init {

let mut access_stack = AccessStack::new();

let vars = module.global_variables.iter().filter(|&(handle, var)| {
!fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
});
let vars = module
.global_variables
.iter()
.filter(|&(handle, var)| {
!fun_info[handle].is_empty()
&& var.space == crate::AddressSpace::WorkGroup
&& !self
.unresolved_overrides
.as_ref()
.unwrap()
.global_variables
.contains_key(&handle)
})
.collect::<Vec<_>>();

for (handle, var) in vars {
access_stack.enter(Access::GlobalVariable(handle), |access_stack| {
Expand Down
Loading