Skip to content

Commit 5e8ac69

Browse files
committed
various updates
1 parent 638406d commit 5e8ac69

File tree

7 files changed

+100
-48
lines changed

7 files changed

+100
-48
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use rustc_middle::bug;
1919
use rustc_middle::dep_graph::WorkProduct;
2020
use rustc_middle::middle::exported_symbols::{SymbolExportInfo, SymbolExportLevel};
2121
use rustc_session::config::{self, CrateType, Lto};
22-
use tracing::{debug, info};
22+
use tracing::{debug, info, trace};
2323

2424
use crate::back::write::{
2525
self, CodegenDiagnosticsStage, DiagnosticHandlers, bitcode_section_name, save_temp_bitcode,
@@ -606,12 +606,33 @@ pub(crate) fn run_pass_manager(
606606

607607
// If this rustc version was build with enzyme/autodiff enabled, and if users applied the
608608
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
609-
let first_run = true;
610-
debug!("running llvm pm opt pipeline");
609+
trace!("running llvm pm opt pipeline");
611610
unsafe {
612-
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?;
611+
write::llvm_optimize(
612+
cgcx,
613+
dcx,
614+
module,
615+
config,
616+
opt_level,
617+
opt_stage,
618+
write::AutodiffStage::DuringAD,
619+
)?;
613620
}
614-
debug!("lto done");
621+
// FIXME(ZuseZ4): Make this more granular
622+
if cfg!(llvm_enzyme) && !thin {
623+
unsafe {
624+
write::llvm_optimize(
625+
cgcx,
626+
dcx,
627+
module,
628+
config,
629+
opt_level,
630+
llvm::OptStage::FatLTO,
631+
write::AutodiffStage::PostAD,
632+
)?;
633+
}
634+
}
635+
trace!("lto done");
615636
Ok(())
616637
}
617638

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -530,14 +530,24 @@ fn get_instr_profile_output_path(config: &ModuleConfig) -> Option<CString> {
530530
config.instrument_coverage.then(|| c"default_%m_%p.profraw".to_owned())
531531
}
532532

533+
// PreAD will run llvm opts but disable size increasing opts (vectorization, loop unrolling)
534+
// DuringAD is the same as above, but also runs the enzyme opt and autodiff passes.
535+
// PostAD will run all opts, including size increasing opts.
536+
#[derive(Debug, Eq, PartialEq)]
537+
pub(crate) enum AutodiffStage {
538+
PreAD,
539+
DuringAD,
540+
PostAD,
541+
}
542+
533543
pub(crate) unsafe fn llvm_optimize(
534544
cgcx: &CodegenContext<LlvmCodegenBackend>,
535545
dcx: DiagCtxtHandle<'_>,
536546
module: &ModuleCodegen<ModuleLlvm>,
537547
config: &ModuleConfig,
538548
opt_level: config::OptLevel,
539549
opt_stage: llvm::OptStage,
540-
skip_size_increasing_opts: bool,
550+
autodiff_stage: AutodiffStage,
541551
) -> Result<(), FatalError> {
542552
// Enzyme:
543553
// The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized
@@ -550,7 +560,7 @@ pub(crate) unsafe fn llvm_optimize(
550560
let unroll_loops;
551561
let vectorize_slp;
552562
let vectorize_loop;
553-
let run_enzyme;
563+
let run_enzyme = cfg!(llvm_enzyme) && autodiff_stage == AutodiffStage::DuringAD;
554564

555565
// When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
556566
// optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
@@ -559,17 +569,15 @@ pub(crate) unsafe fn llvm_optimize(
559569
// FIXME(ZuseZ4): Before shipping on nightly,
560570
// we should make this more granular, or at least check that the user has at least one autodiff
561571
// call in their code, to justify altering the compilation pipeline.
562-
if skip_size_increasing_opts && cfg!(llvm_enzyme) {
572+
if cfg!(llvm_enzyme) && autodiff_stage != AutodiffStage::PostAD {
563573
unroll_loops = false;
564574
vectorize_slp = false;
565575
vectorize_loop = false;
566-
run_enzyme = true;
567576
} else {
568577
unroll_loops =
569578
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
570579
vectorize_slp = config.vectorize_slp;
571580
vectorize_loop = config.vectorize_loop;
572-
run_enzyme = false;
573581
}
574582
trace!(?unroll_loops, ?vectorize_slp, ?vectorize_loop, ?run_enzyme);
575583
let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
@@ -691,18 +699,14 @@ pub(crate) unsafe fn optimize(
691699
_ => llvm::OptStage::PreLinkNoLTO,
692700
};
693701

694-
// If we know that we will later run AD, then we disable vectorization and loop unrolling
695-
let skip_size_increasing_opts = cfg!(llvm_enzyme);
702+
// If we know that we will later run AD, then we disable vectorization and loop unrolling.
703+
// Otherwise we pretend AD is already done and run the normal opt pipeline (=PostAD).
704+
// FIXME(ZuseZ4): Make this more granular, only set PreAD if we actually have autodiff
705+
// usages, not just if we build rustc with autodiff support.
706+
let autodiff_stage =
707+
if cfg!(llvm_enzyme) { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
696708
return unsafe {
697-
llvm_optimize(
698-
cgcx,
699-
dcx,
700-
module,
701-
config,
702-
opt_level,
703-
opt_stage,
704-
skip_size_increasing_opts,
705-
)
709+
llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, autodiff_stage)
706710
};
707711
}
708712
Ok(())

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ pub(crate) fn differentiate<'ll>(
285285
module: &'ll ModuleCodegen<ModuleLlvm>,
286286
cgcx: &CodegenContext<LlvmCodegenBackend>,
287287
diff_items: Vec<AutoDiffItem>,
288-
config: &ModuleConfig,
288+
_config: &ModuleConfig,
289289
) -> Result<(), FatalError> {
290290
for item in &diff_items {
291291
trace!("{}", item);
@@ -320,29 +320,29 @@ pub(crate) fn differentiate<'ll>(
320320

321321
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
322322

323-
if let Some(opt_level) = config.opt_level {
324-
let opt_stage = match cgcx.lto {
325-
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
326-
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
327-
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
328-
_ => llvm::OptStage::PreLinkNoLTO,
329-
};
330-
// This is our second opt call, so now we run all opts,
331-
// to make sure we get the best performance.
332-
let skip_size_increasing_opts = false;
333-
trace!("running Module Optimization after differentiation");
334-
unsafe {
335-
llvm_optimize(
336-
cgcx,
337-
diag_handler.handle(),
338-
module,
339-
config,
340-
opt_level,
341-
opt_stage,
342-
skip_size_increasing_opts,
343-
)?
344-
};
345-
}
323+
//if let Some(opt_level) = config.opt_level {
324+
// let opt_stage = match cgcx.lto {
325+
// Lto::Fat => llvm::OptStage::PreLinkFatLTO,
326+
// Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
327+
// _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
328+
// _ => llvm::OptStage::PreLinkNoLTO,
329+
// };
330+
// // This is our second opt call, so now we run all opts,
331+
// // to make sure we get the best performance.
332+
// let skip_size_increasing_opts = false;
333+
// trace!("running Module Optimization after differentiation");
334+
// unsafe {
335+
// llvm_optimize(
336+
// cgcx,
337+
// diag_handler.handle(),
338+
// module,
339+
// config,
340+
// opt_level,
341+
// opt_stage,
342+
// skip_size_increasing_opts,
343+
// )?
344+
// };
345+
//}
346346
trace!("done with differentiate()");
347347

348348
Ok(())

compiler/rustc_llvm/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ libc = "0.2.73"
1414
# pinned `cc` in `rustc_codegen_ssa` if you update `cc` here.
1515
cc = "=1.2.7"
1616
# tidy-alphabetical-end
17+
18+
[lints.rust]
19+
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] }

compiler/rustc_llvm/build.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,25 @@ fn main() {
193193
cfg.define(&flag, None);
194194
}
195195

196+
// This doesn't work
197+
if tracked_env_var_os("llvm_enzyme").is_some() {
198+
// If we're just running `check`, there's no need for LLVM to be built.
199+
loop {
200+
println!("ENZYME!");
201+
}
202+
}
203+
204+
// This doesn't work either
205+
if cfg!(llvm_enzyme) {
206+
// Enzyme is a fork of LLVM, so we need to use a different build script.
207+
loop {
208+
println!("ENZYME!");
209+
}
210+
}
211+
212+
// FIXME: Only enable it, if the user build rustc with autodiff support
213+
cfg.define("ENZYME", None);
214+
196215
if tracked_env_var_os("LLVM_RUSTLLVM").is_some() {
197216
cfg.define("LLVM_RUSTLLVM", None);
198217
}

compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,8 +688,11 @@ struct LLVMRustSanitizerOptions {
688688
bool SanitizeKernelAddressRecover;
689689
};
690690

691+
691692
// This symbol won't be available or used when Enzyme is not enabled
692-
extern "C" void registerEnzyme(llvm::PassBuilder &PB) __attribute__((weak));
693+
#ifdef ENZYME
694+
extern "C" void registerEnzyme(llvm::PassBuilder &PB);
695+
#endif
693696

694697
extern "C" LLVMRustResult LLVMRustOptimize(
695698
LLVMModuleRef ModuleRef, LLVMTargetMachineRef TMRef,
@@ -1015,6 +1018,7 @@ extern "C" LLVMRustResult LLVMRustOptimize(
10151018
}
10161019

10171020
// now load "-enzyme" pass:
1021+
#ifdef ENZYME
10181022
if (RunEnzyme) {
10191023
registerEnzyme(PB);
10201024
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
@@ -1023,6 +1027,7 @@ extern "C" LLVMRustResult LLVMRustOptimize(
10231027
return LLVMRustResult::Failure;
10241028
}
10251029
}
1030+
#endif
10261031

10271032
// Upgrade all calls to old intrinsics first.
10281033
for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)

tests/codegen/autodiff.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ fn square(x: &f64) -> f64 {
1515
// CHECK-NEXT:invertstart:
1616
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
1717
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val
18-
// CHECK-NEXT: %1 = load double, ptr %"x'", align 8, !alias.scope !17816, !noalias !17819
18+
// CHECK-NEXT: %1 = load double, ptr %"x'", align 8
1919
// CHECK-NEXT: %2 = fadd fast double %1, %0
20-
// CHECK-NEXT: store double %2, ptr %"x'", align 8, !alias.scope !17816, !noalias !17819
20+
// CHECK-NEXT: store double %2, ptr %"x'", align 8
2121
// CHECK-NEXT: ret double %_0
2222
// CHECK-NEXT:}
2323

0 commit comments

Comments
 (0)