Skip to content

Commit 23cb884

Browse files
committed
Naive impl of generate enzyme call
Now I use functions squash
1 parent fbd9b77 commit 23cb884

File tree

3 files changed

+71
-56
lines changed

3 files changed

+71
-56
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 21 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit
44
use rustc_codegen_ssa::ModuleCodegen;
55
use rustc_codegen_ssa::back::write::ModuleConfig;
66
use rustc_codegen_ssa::common::TypeKind;
7-
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
7+
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
88
use rustc_errors::FatalError;
99
use rustc_middle::bug;
1010
use tracing::{debug, trace};
1111

1212
use crate::back::write::llvm_err;
13-
use crate::builder::{SBuilder, UNNAMED};
13+
use crate::builder::{Builder, OperandRef, PlaceRef, UNNAMED};
1414
use crate::context::SimpleCx;
1515
use crate::declare::declare_simple_fn;
1616
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
@@ -19,7 +19,7 @@ use crate::llvm::{Metadata, True};
1919
use crate::value::Value;
2020
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
2121

22-
fn get_params(fnc: &Value) -> Vec<&Value> {
22+
pub(crate) fn _get_params(fnc: &Value) -> Vec<&Value> {
2323
let param_num = llvm::LLVMCountParams(fnc) as usize;
2424
let mut fnc_args: Vec<&Value> = vec![];
2525
fnc_args.reserve(param_num);
@@ -30,7 +30,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
3030
fnc_args
3131
}
3232

33-
fn has_sret(fnc: &Value) -> bool {
33+
pub(crate) fn has_sret(fnc: &Value) -> bool {
3434
let num_args = llvm::LLVMCountParams(fnc) as usize;
3535
if num_args == 0 {
3636
false
@@ -49,9 +49,9 @@ fn has_sret(fnc: &Value) -> bool {
4949
// need to match those.
5050
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
5151
// using iterators and peek()?
52-
fn match_args_from_caller_to_enzyme<'ll>(
52+
fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
5353
cx: &SimpleCx<'ll>,
54-
builder: &SBuilder<'ll, 'll>,
54+
builder: &mut Builder<'_, 'll, 'tcx>,
5555
width: u32,
5656
args: &mut Vec<&'ll llvm::Value>,
5757
inputs: &[DiffActivity],
@@ -201,7 +201,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
201201
// Beyond sret, this article describes our challenges nicely:
202202
// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
203203
// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
204-
fn compute_enzyme_fn_ty<'ll>(
204+
pub(crate) fn compute_enzyme_fn_ty<'ll>(
205205
cx: &SimpleCx<'ll>,
206206
attrs: &AutoDiffAttrs,
207207
fn_to_diff: &'ll Value,
@@ -289,11 +289,14 @@ fn compute_enzyme_fn_ty<'ll>(
289289
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
290290
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
291291
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
292-
fn generate_enzyme_call<'ll>(
292+
pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
293+
builder: &mut Builder<'_, 'll, 'tcx>,
293294
cx: &SimpleCx<'ll>,
294295
fn_to_diff: &'ll Value,
295296
outer_fn: &'ll Value,
297+
fn_args: &[OperandRef<'tcx, &'ll Value>],
296298
attrs: AutoDiffAttrs,
299+
dest: PlaceRef<'tcx, &'ll Value>,
297300
) {
298301
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
299302
let mut ad_name: String = match attrs.mode {
@@ -366,14 +369,6 @@ fn generate_enzyme_call<'ll>(
366369
let enzyme_marker_attr = llvm::CreateAttrString(cx.llcx, "enzyme_marker");
367370
attributes::apply_to_llfn(outer_fn, Function, &[enzyme_marker_attr]);
368371

369-
// first, remove all calls from fnc
370-
let entry = llvm::LLVMGetFirstBasicBlock(outer_fn);
371-
let br = llvm::LLVMRustGetTerminator(entry);
372-
llvm::LLVMRustEraseInstFromParent(br);
373-
374-
let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap();
375-
let mut builder = SBuilder::build(cx, entry);
376-
377372
let num_args = llvm::LLVMCountParams(&fn_to_diff);
378373
let mut args = Vec::with_capacity(num_args as usize + 1);
379374
args.push(fn_to_diff);
@@ -389,40 +384,20 @@ fn generate_enzyme_call<'ll>(
389384
}
390385

391386
let has_sret = has_sret(outer_fn);
392-
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
387+
let outer_args: Vec<&llvm::Value> = fn_args.iter().map(|op| op.immediate()).collect();
393388
match_args_from_caller_to_enzyme(
394389
&cx,
395-
&builder,
390+
builder,
396391
attrs.width,
397392
&mut args,
398393
&attrs.input_activity,
399394
&outer_args,
400395
has_sret,
401396
);
402397

403-
let call = builder.call(enzyme_ty, ad_fn, &args, None);
404-
405-
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
406-
// metadata attached to it, but we just created this code oota. Given that the
407-
// differentiated function already has partly confusing metadata, and given that this
408-
// affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
409-
// dummy code which we inserted at a higher level.
410-
// FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have,
411-
// and how to best improve it for enzyme core and rust-enzyme.
412-
let md_ty = cx.get_md_kind_id("dbg");
413-
if llvm::LLVMRustHasMetadata(last_inst, md_ty) {
414-
let md = llvm::LLVMRustDIGetInstMetadata(last_inst)
415-
.expect("failed to get instruction metadata");
416-
let md_todiff = cx.get_metadata_value(md);
417-
llvm::LLVMSetMetadata(call, md_ty, md_todiff);
418-
} else {
419-
// We don't panic, since depending on whether we are in debug or release mode, we might
420-
// have no debug info to copy, which would then be ok.
421-
trace!("no dbg info");
422-
}
398+
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
423399

424-
// Now that we copied the metadata, get rid of dummy code.
425-
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
400+
builder.store_to_place(call, dest.val);
426401

427402
if cx.val_ty(call) == cx.type_void() || has_sret {
428403
if has_sret {
@@ -445,10 +420,10 @@ fn generate_enzyme_call<'ll>(
445420
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
446421
}
447422
builder.ret_void();
448-
} else {
449-
builder.ret(call);
450423
}
451424

425+
builder.store_to_place(call, dest.val);
426+
452427
// Let's crash in case that we messed something up above and generated invalid IR.
453428
llvm::LLVMRustVerifyFunction(
454429
outer_fn,
@@ -463,6 +438,7 @@ pub(crate) fn differentiate<'ll>(
463438
diff_items: Vec<AutoDiffItem>,
464439
_config: &ModuleConfig,
465440
) -> Result<(), FatalError> {
441+
// TODO(Sa4dUs): delete all this logic
466442
for item in &diff_items {
467443
trace!("{}", item);
468444
}
@@ -482,7 +458,7 @@ pub(crate) fn differentiate<'ll>(
482458
for item in diff_items.iter() {
483459
let name = item.source.clone();
484460
let fn_def: Option<&llvm::Value> = cx.get_function(&name);
485-
let Some(fn_def) = fn_def else {
461+
let Some(_fn_def) = fn_def else {
486462
return Err(llvm_err(
487463
diag_handler.handle(),
488464
LlvmError::PrepareAutoDiff {
@@ -494,7 +470,7 @@ pub(crate) fn differentiate<'ll>(
494470
};
495471
debug!(?item.target);
496472
let fn_target: Option<&llvm::Value> = cx.get_function(&item.target);
497-
let Some(fn_target) = fn_target else {
473+
let Some(_fn_target) = fn_target else {
498474
return Err(llvm_err(
499475
diag_handler.handle(),
500476
LlvmError::PrepareAutoDiff {
@@ -505,7 +481,7 @@ pub(crate) fn differentiate<'ll>(
505481
));
506482
};
507483

508-
generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
484+
// generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
509485
}
510486

511487
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,19 @@ use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
99
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
1010
use rustc_codegen_ssa::traits::*;
1111
use rustc_hir as hir;
12+
use rustc_hir::def_id::LOCAL_CRATE;
1213
use rustc_middle::mir::BinOp;
1314
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
14-
use rustc_middle::ty::{self, GenericArgsRef, Ty};
15+
use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty};
1516
use rustc_middle::{bug, span_bug};
1617
use rustc_span::{Span, Symbol, sym};
17-
use rustc_symbol_mangling::mangle_internal_symbol;
18+
use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
1819
use rustc_target::spec::{HasTargetSpec, PanicStrategy};
1920
use tracing::debug;
2021

2122
use crate::abi::FnAbiLlvmExt;
2223
use crate::builder::Builder;
24+
use crate::builder::autodiff::generate_enzyme_call;
2325
use crate::context::CodegenCx;
2426
use crate::llvm::{self, Metadata};
2527
use crate::type_::Type;
@@ -200,20 +202,56 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
200202
_ if tcx.has_attr(def_id, sym::rustc_autodiff) => {
201203
// NOTE(Sa4dUs): This is a hacky way to get the autodiff items
202204
// so we can focus on the lowering of the intrinsic call
205+
let mut source_id = None;
206+
let mut diff_attrs = None;
207+
let items: Vec<_> = tcx.hir_body_owners().map(|i| i.to_def_id()).collect();
208+
209+
// Hacky way of getting primal-diff pair, only works for code with 1 autodiff call
210+
for target_id in &items {
211+
let Some(target_attrs) = &tcx.codegen_fn_attrs(target_id).autodiff_item else {
212+
continue;
213+
};
203214

204-
// `diff_items` is empty even when autodiff is enabled, and if we're here,
205-
// it's because some function was marked as intrinsic and had the `rustc_autodiff` attr
206-
let diff_items = tcx.collect_and_partition_mono_items(()).autodiff_items;
215+
if target_attrs.is_source() {
216+
source_id = Some(*target_id);
217+
} else {
218+
diff_attrs = Some(target_attrs);
219+
}
220+
}
207221

208-
// this shouldn't happen?
209-
if diff_items.is_empty() {
210-
bug!("no autodiff items found for {def_id:?}");
222+
if source_id.is_none() || diff_attrs.is_none() {
223+
bug!("could not find source_id={source_id:?} or diff_attrs={diff_attrs:?}");
211224
}
212225

213-
// TODO(Sa4dUs): generate the enzyme call itself, based on the logic in `builder.rs`
226+
let diff_attrs = diff_attrs.unwrap().clone();
227+
228+
// Get source fn
229+
let source_id = source_id.unwrap();
230+
let fn_source = Instance::mono(tcx, source_id);
231+
let source_symbol =
232+
symbol_name_for_instance_in_crate(tcx, fn_source.clone(), LOCAL_CRATE);
233+
let fn_to_diff: Option<&'ll llvm::Value> = self.cx.get_function(&source_symbol);
234+
let Some(fn_to_diff) = fn_to_diff else { bug!("could not find source function") };
235+
236+
// Declare target fn
237+
let target_symbol =
238+
symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
239+
let fn_abi = self.cx.fn_abi_of_instance(instance, ty::List::empty());
240+
let outer_fn: &'ll Value =
241+
self.cx.declare_fn(&target_symbol, fn_abi, Some(instance));
242+
243+
// Build body
244+
generate_enzyme_call(
245+
self,
246+
self.cx,
247+
fn_to_diff,
248+
outer_fn,
249+
args, // This argument was not in the original `generate_enzyme_call`, now it's included because `get_params` is not working anymore
250+
diff_attrs.clone(),
251+
result,
252+
);
214253

215-
// Just gen the fallback body for now
216-
return Err(ty::Instance::new_raw(def_id, instance.args));
254+
return Ok(());
217255
}
218256
sym::is_val_statically_known => {
219257
let intrinsic_type = args[0].layout.immediate_llvm_type(self.cx);

tests/codegen/autodiff/scalar.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
//@ no-prefer-dynamic
33
//@ needs-enzyme
44
#![feature(autodiff)]
5+
#![feature(intrinsics)]
56

67
use std::autodiff::autodiff_reverse;
78

0 commit comments

Comments
 (0)