-
-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use explicitly added ApplyDeferred
stages when computing automatically inserted sync points.
#16782
Changes from all commits
31fc75f
242ce70
776d062
306cc07
d17c236
53a4d32
54206f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,41 +80,109 @@ impl ScheduleBuildPass for AutoInsertApplyDeferredPass { | |
let mut sync_point_graph = dependency_flattened.clone(); | ||
let topo = graph.topsort_graph(dependency_flattened, ReportCycles::Dependency)?; | ||
|
||
// Returns `true` when `node` itself has at least one condition registered (per
// `graph.set_conditions_at`), or when any node reached by following incoming
// hierarchy edges has one. The check recurses upwards through every parent, so a
// condition anywhere among the ancestors makes this return `true`.
// NOTE(review): presumably `node` must be a system-set node for `set_conditions_at`
// to be valid; confirm against `ScheduleGraph`.
fn set_has_conditions(graph: &ScheduleGraph, node: NodeId) -> bool { | ||
!graph.set_conditions_at(node).is_empty() | ||
|| graph | ||
.hierarchy() | ||
.graph() | ||
.edges_directed(node, Direction::Incoming) | ||
// The first tuple element of each incoming edge is treated as the parent.
.any(|(parent, _)| set_has_conditions(graph, parent)) | ||
} | ||
|
||
// Returns `true` when the system `node` has at least one entry in
// `graph.system_conditions`, or when any parent reached through an incoming
// hierarchy edge has conditions according to `set_has_conditions`.
//
// Panics if `node` is not a system (see the `assert!` below). The indexing into
// `system_conditions` only makes sense for system nodes, which is why the guard exists.
fn system_has_conditions(graph: &ScheduleGraph, node: NodeId) -> bool { | ||
assert!(node.is_system()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the concern here for this assert? Since the other function has no inverse assert? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The other one does this implicitly. For the systems, there's no corresponding function, but we can implement it ourselves pretty simply. But we need to make sure this only happens for systems. In practice, we will never call this with a set, but better to be safe than sorry. |
||
!graph.system_conditions[node.index()].is_empty() | ||
|| graph | ||
.hierarchy() | ||
.graph() | ||
.edges_directed(node, Direction::Incoming) | ||
.any(|(parent, _)| set_has_conditions(graph, parent)) | ||
} | ||
|
||
let mut system_has_conditions_cache = HashMap::default(); | ||
|
||
// Returns `true` when `system` can serve as an explicit sync point: the system
// stored at its index satisfies `is_apply_deferred`, and `system_has_conditions`
// reports no conditions for it.
//
// `system_has_conditions_cache` memoizes the `system_has_conditions` result per
// system index, so the hierarchy walk runs at most once for each system that
// passes the `is_apply_deferred` check (the `&&` short-circuits otherwise).
// NOTE(review): `graph.systems[index].get().unwrap()` panics if `get()` yields
// `None`; presumably every system is present at this stage of the build. Confirm.
fn is_valid_explicit_sync_point( | ||
graph: &ScheduleGraph, | ||
system: NodeId, | ||
system_has_conditions_cache: &mut HashMap<usize, bool>, | ||
) -> bool { | ||
let index = system.index(); | ||
is_apply_deferred(graph.systems[index].get().unwrap()) | ||
&& !*system_has_conditions_cache | ||
.entry(index) | ||
.or_insert_with(|| system_has_conditions(graph, system)) | ||
} | ||
|
||
// calculate the number of sync points each sync point is from the beginning of the graph | ||
// use the same sync point if the distance is the same | ||
let mut distances: HashMap<usize, Option<u32>> = | ||
let mut distances: HashMap<usize, u32> = | ||
HashMap::with_capacity_and_hasher(topo.len(), Default::default()); | ||
// Keep track of any explicit sync nodes for a specific distance. | ||
let mut distance_to_explicit_sync_node: HashMap<u32, NodeId> = HashMap::default(); | ||
for node in &topo { | ||
let add_sync_after = graph.systems[node.index()].get().unwrap().has_deferred(); | ||
let node_system = graph.systems[node.index()].get().unwrap(); | ||
|
||
let node_needs_sync = | ||
if is_valid_explicit_sync_point(graph, *node, &mut system_has_conditions_cache) { | ||
distance_to_explicit_sync_node.insert( | ||
distances.get(&node.index()).copied().unwrap_or_default(), | ||
*node, | ||
); | ||
|
||
// This node just did a sync, so the only reason to do another sync is if one was | ||
// explicitly scheduled afterwards. | ||
false | ||
} else { | ||
node_system.has_deferred() | ||
}; | ||
|
||
for target in dependency_flattened.neighbors_directed(*node, Direction::Outgoing) { | ||
let add_sync_on_edge = add_sync_after | ||
&& !is_apply_deferred(graph.systems[target.index()].get().unwrap()) | ||
&& !self.no_sync_edges.contains(&(*node, target)); | ||
|
||
let weight = if add_sync_on_edge { 1 } else { 0 }; | ||
|
||
let edge_needs_sync = node_needs_sync | ||
&& !self.no_sync_edges.contains(&(*node, target)) | ||
|| is_valid_explicit_sync_point( | ||
graph, | ||
target, | ||
&mut system_has_conditions_cache, | ||
); | ||
|
||
let weight = if edge_needs_sync { 1 } else { 0 }; | ||
|
||
// Use whichever distance is larger, either the current distance, or the distance to | ||
// the parent plus the weight. | ||
let distance = distances | ||
.get(&target.index()) | ||
.unwrap_or(&None) | ||
.or(Some(0)) | ||
.map(|distance| { | ||
distance.max( | ||
distances.get(&node.index()).unwrap_or(&None).unwrap_or(0) + weight, | ||
) | ||
}); | ||
.copied() | ||
.unwrap_or_default() | ||
.max(distances.get(&node.index()).copied().unwrap_or_default() + weight); | ||
|
||
distances.insert(target.index(), distance); | ||
} | ||
} | ||
|
||
if add_sync_on_edge { | ||
let sync_point = | ||
self.get_sync_point(graph, distances[&target.index()].unwrap()); | ||
sync_point_graph.add_edge(*node, sync_point); | ||
sync_point_graph.add_edge(sync_point, target); | ||
// Find any edges which have a different number of sync points between them and make sure | ||
// there is a sync point between them. | ||
for node in &topo { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You need to do this in a second iteration because until then you don't necessarily have all explicit sync points collected before adding new ones, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not quite. The real problem is that we don't know the "distance" for the explicit sync node until it is "expanded" (the outer loop has it as |
||
let node_distance = distances.get(&node.index()).copied().unwrap_or_default(); | ||
for target in dependency_flattened.neighbors_directed(*node, Direction::Outgoing) { | ||
let target_distance = distances.get(&target.index()).copied().unwrap_or_default(); | ||
if node_distance == target_distance { | ||
// These nodes are the same distance, so they don't need an edge between them. | ||
continue; | ||
} | ||
|
||
// edge is now redundant | ||
sync_point_graph.remove_edge(*node, target); | ||
if is_apply_deferred(graph.systems[target.index()].get().unwrap()) { | ||
// We don't need to insert a sync point since ApplyDeferred is a sync point | ||
// already! | ||
continue; | ||
} | ||
let sync_point = distance_to_explicit_sync_node | ||
.get(&target_distance) | ||
.copied() | ||
.unwrap_or_else(|| self.get_sync_point(graph, target_distance)); | ||
|
||
sync_point_graph.add_edge(*node, sync_point); | ||
sync_point_graph.add_edge(sync_point, target); | ||
|
||
sync_point_graph.remove_edge(*node, target); | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This and the other two functions kinda bloat this up. It might be worth making them methods of
ScheduleGraph
since that is the first argument for all of them. Though the third, with its HashMap parameter, is probably even less useful in any other (future) context. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I kinda don't want to do this. If in the future this is a desirable thing to know for other cases, then we should probably use some dynamic programming to figure out which sets/systems overall have conditions. This is only "efficient" because we expect few
ApplyDeferred
stages in the first place, so querying for those in particular is probably cheaper.