@@ -4,13 +4,13 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivit
4
4
use rustc_codegen_ssa:: ModuleCodegen ;
5
5
use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
6
6
use rustc_codegen_ssa:: common:: TypeKind ;
7
- use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods ;
7
+ use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
8
8
use rustc_errors:: FatalError ;
9
9
use rustc_middle:: bug;
10
10
use tracing:: { debug, trace} ;
11
11
12
12
use crate :: back:: write:: llvm_err;
13
- use crate :: builder:: { SBuilder , UNNAMED } ;
13
+ use crate :: builder:: { Builder , OperandRef , PlaceRef , UNNAMED } ;
14
14
use crate :: context:: SimpleCx ;
15
15
use crate :: declare:: declare_simple_fn;
16
16
use crate :: errors:: { AutoDiffWithoutEnable , LlvmError } ;
@@ -19,7 +19,7 @@ use crate::llvm::{Metadata, True};
19
19
use crate :: value:: Value ;
20
20
use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
21
21
22
- fn get_params ( fnc : & Value ) -> Vec < & Value > {
22
+ pub ( crate ) fn _get_params ( fnc : & Value ) -> Vec < & Value > {
23
23
let param_num = llvm:: LLVMCountParams ( fnc) as usize ;
24
24
let mut fnc_args: Vec < & Value > = vec ! [ ] ;
25
25
fnc_args. reserve ( param_num) ;
@@ -30,7 +30,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
30
30
fnc_args
31
31
}
32
32
33
- fn has_sret ( fnc : & Value ) -> bool {
33
+ pub ( crate ) fn has_sret ( fnc : & Value ) -> bool {
34
34
let num_args = llvm:: LLVMCountParams ( fnc) as usize ;
35
35
if num_args == 0 {
36
36
false
@@ -49,9 +49,9 @@ fn has_sret(fnc: &Value) -> bool {
49
49
// need to match those.
50
50
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
51
51
// using iterators and peek()?
52
- fn match_args_from_caller_to_enzyme < ' ll > (
52
+ fn match_args_from_caller_to_enzyme < ' ll , ' tcx > (
53
53
cx : & SimpleCx < ' ll > ,
54
- builder : & SBuilder < ' ll , ' ll > ,
54
+ builder : & mut Builder < ' _ , ' ll , ' tcx > ,
55
55
width : u32 ,
56
56
args : & mut Vec < & ' ll llvm:: Value > ,
57
57
inputs : & [ DiffActivity ] ,
@@ -201,7 +201,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
201
201
// Beyond sret, this article describes our challenges nicely:
202
202
// <https://yorickpeterse.com/articles/the-mess-that-is-handling-structure-arguments-and-returns-in-llvm/>
203
203
// I.e. (i32, f32) will get merged into i64, but we don't handle that yet.
204
- fn compute_enzyme_fn_ty < ' ll > (
204
+ pub ( crate ) fn compute_enzyme_fn_ty < ' ll > (
205
205
cx : & SimpleCx < ' ll > ,
206
206
attrs : & AutoDiffAttrs ,
207
207
fn_to_diff : & ' ll Value ,
@@ -289,11 +289,14 @@ fn compute_enzyme_fn_ty<'ll>(
289
289
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
290
290
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
291
291
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
292
- fn generate_enzyme_call < ' ll > (
292
+ pub ( crate ) fn generate_enzyme_call < ' ll , ' tcx > (
293
+ builder : & mut Builder < ' _ , ' ll , ' tcx > ,
293
294
cx : & SimpleCx < ' ll > ,
294
295
fn_to_diff : & ' ll Value ,
295
296
outer_fn : & ' ll Value ,
297
+ fn_args : & [ OperandRef < ' tcx , & ' ll Value > ] ,
296
298
attrs : AutoDiffAttrs ,
299
+ dest : PlaceRef < ' tcx , & ' ll Value > ,
297
300
) {
298
301
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
299
302
let mut ad_name: String = match attrs. mode {
@@ -366,14 +369,6 @@ fn generate_enzyme_call<'ll>(
366
369
let enzyme_marker_attr = llvm:: CreateAttrString ( cx. llcx , "enzyme_marker" ) ;
367
370
attributes:: apply_to_llfn ( outer_fn, Function , & [ enzyme_marker_attr] ) ;
368
371
369
- // first, remove all calls from fnc
370
- let entry = llvm:: LLVMGetFirstBasicBlock ( outer_fn) ;
371
- let br = llvm:: LLVMRustGetTerminator ( entry) ;
372
- llvm:: LLVMRustEraseInstFromParent ( br) ;
373
-
374
- let last_inst = llvm:: LLVMRustGetLastInstruction ( entry) . unwrap ( ) ;
375
- let mut builder = SBuilder :: build ( cx, entry) ;
376
-
377
372
let num_args = llvm:: LLVMCountParams ( & fn_to_diff) ;
378
373
let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
379
374
args. push ( fn_to_diff) ;
@@ -389,40 +384,20 @@ fn generate_enzyme_call<'ll>(
389
384
}
390
385
391
386
let has_sret = has_sret ( outer_fn) ;
392
- let outer_args: Vec < & llvm:: Value > = get_params ( outer_fn ) ;
387
+ let outer_args: Vec < & llvm:: Value > = fn_args . iter ( ) . map ( |op| op . immediate ( ) ) . collect ( ) ;
393
388
match_args_from_caller_to_enzyme (
394
389
& cx,
395
- & builder,
390
+ builder,
396
391
attrs. width ,
397
392
& mut args,
398
393
& attrs. input_activity ,
399
394
& outer_args,
400
395
has_sret,
401
396
) ;
402
397
403
- let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
404
-
405
- // This part is a bit iffy. LLVM requires that a call to an inlineable function has some
406
- // metadata attached to it, but we just created this code oota. Given that the
407
- // differentiated function already has partly confusing metadata, and given that this
408
- // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the
409
- // dummy code which we inserted at a higher level.
410
- // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have,
411
- // and how to best improve it for enzyme core and rust-enzyme.
412
- let md_ty = cx. get_md_kind_id ( "dbg" ) ;
413
- if llvm:: LLVMRustHasMetadata ( last_inst, md_ty) {
414
- let md = llvm:: LLVMRustDIGetInstMetadata ( last_inst)
415
- . expect ( "failed to get instruction metadata" ) ;
416
- let md_todiff = cx. get_metadata_value ( md) ;
417
- llvm:: LLVMSetMetadata ( call, md_ty, md_todiff) ;
418
- } else {
419
- // We don't panic, since depending on whether we are in debug or release mode, we might
420
- // have no debug info to copy, which would then be ok.
421
- trace ! ( "no dbg info" ) ;
422
- }
398
+ let call = builder. call ( enzyme_ty, None , None , ad_fn, & args, None , None ) ;
423
399
424
- // Now that we copied the metadata, get rid of dummy code.
425
- llvm:: LLVMRustEraseInstUntilInclusive ( entry, last_inst) ;
400
+ builder. store_to_place ( call, dest. val ) ;
426
401
427
402
if cx. val_ty ( call) == cx. type_void ( ) || has_sret {
428
403
if has_sret {
@@ -445,10 +420,10 @@ fn generate_enzyme_call<'ll>(
445
420
llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
446
421
}
447
422
builder. ret_void ( ) ;
448
- } else {
449
- builder. ret ( call) ;
450
423
}
451
424
425
+ builder. store_to_place ( call, dest. val ) ;
426
+
452
427
// Let's crash in case that we messed something up above and generated invalid IR.
453
428
llvm:: LLVMRustVerifyFunction (
454
429
outer_fn,
@@ -463,6 +438,7 @@ pub(crate) fn differentiate<'ll>(
463
438
diff_items : Vec < AutoDiffItem > ,
464
439
_config : & ModuleConfig ,
465
440
) -> Result < ( ) , FatalError > {
441
+ // TODO(Sa4dUs): delete all this logic
466
442
for item in & diff_items {
467
443
trace ! ( "{}" , item) ;
468
444
}
@@ -482,7 +458,7 @@ pub(crate) fn differentiate<'ll>(
482
458
for item in diff_items. iter ( ) {
483
459
let name = item. source . clone ( ) ;
484
460
let fn_def: Option < & llvm:: Value > = cx. get_function ( & name) ;
485
- let Some ( fn_def ) = fn_def else {
461
+ let Some ( _fn_def ) = fn_def else {
486
462
return Err ( llvm_err (
487
463
diag_handler. handle ( ) ,
488
464
LlvmError :: PrepareAutoDiff {
@@ -494,7 +470,7 @@ pub(crate) fn differentiate<'ll>(
494
470
} ;
495
471
debug ! ( ?item. target) ;
496
472
let fn_target: Option < & llvm:: Value > = cx. get_function ( & item. target ) ;
497
- let Some ( fn_target ) = fn_target else {
473
+ let Some ( _fn_target ) = fn_target else {
498
474
return Err ( llvm_err (
499
475
diag_handler. handle ( ) ,
500
476
LlvmError :: PrepareAutoDiff {
@@ -505,7 +481,7 @@ pub(crate) fn differentiate<'ll>(
505
481
) ) ;
506
482
} ;
507
483
508
- generate_enzyme_call ( & cx, fn_def, fn_target, item. attrs . clone ( ) ) ;
484
+ // generate_enzyme_call(&cx, fn_def, fn_target, item.attrs.clone());
509
485
}
510
486
511
487
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
0 commit comments