Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 3d2e36c

Browse files
committed
upstream rustc_codegen_llvm changes for enzyme/autodiff
1 parent 3fee0f1 commit 3d2e36c

File tree

13 files changed

+561
-27
lines changed

13 files changed

+561
-27
lines changed

‎compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
use std::fmt::{self, Display, Formatter};
77
use std::str::FromStr;
88

9-
use crate::expand::typetree::TypeTree;
109
use crate::expand::{Decodable, Encodable, HashStable_Generic};
1110
use crate::ptr::P;
1211
use crate::{Ty, TyKind};
@@ -79,10 +78,6 @@ pub struct AutoDiffItem {
7978
/// The name of the function being generated
8079
pub target: String,
8180
pub attrs: AutoDiffAttrs,
82-
/// Describe the memory layout of input types
83-
pub inputs: Vec<TypeTree>,
84-
/// Describe the memory layout of the output type
85-
pub output: TypeTree,
8681
}
8782
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
8883
pub struct AutoDiffAttrs {
@@ -262,22 +257,14 @@ impl AutoDiffAttrs {
262257
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
263258
}
264259

265-
pub fn into_item(
266-
self,
267-
source: String,
268-
target: String,
269-
inputs: Vec<TypeTree>,
270-
output: TypeTree,
271-
) -> AutoDiffItem {
272-
AutoDiffItem { source, target, inputs, output, attrs: self }
260+
pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
261+
AutoDiffItem { source, target, attrs: self }
273262
}
274263
}
275264

276265
impl fmt::Display for AutoDiffItem {
277266
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278267
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
279-
write!(f, " with attributes: {:?}", self.attrs)?;
280-
write!(f, " with inputs: {:?}", self.inputs)?;
281-
write!(f, " with output: {:?}", self.output)
268+
write!(f, " with attributes: {:?}", self.attrs)
282269
}
283270
}

‎compiler/rustc_codegen_gcc/src/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ use gccjit::{CType, Context, OptimizationLevel};
9393
#[cfg(feature = "master")]
9494
use gccjit::{TargetInfo, Version};
9595
use rustc_ast::expand::allocator::AllocatorKind;
96+
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
9697
use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule};
9798
use rustc_codegen_ssa::back::write::{
9899
CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryFn,
@@ -439,6 +440,14 @@ impl WriteBackendMethods for GccCodegenBackend {
439440
) -> Result<ModuleCodegen<Self::Module>, FatalError> {
440441
back::write::link(cgcx, dcx, modules)
441442
}
443+
fn autodiff(
444+
_cgcx: &CodegenContext<Self>,
445+
_module: &ModuleCodegen<Self::Module>,
446+
_diff_fncs: Vec<AutoDiffItem>,
447+
_config: &ModuleConfig,
448+
) -> Result<(), FatalError> {
449+
unimplemented!()
450+
}
442451
}
443452

444453
/// This is the entrypoint for a hot plugged rustc_codegen_gccjit

‎compiler/rustc_codegen_llvm/messages.ftl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
2+
13
codegen_llvm_copy_bitcode = failed to copy bitcode to object file: {$err}
24
35
codegen_llvm_dynamic_linking_with_lto =
@@ -47,6 +49,8 @@ codegen_llvm_parse_bitcode_with_llvm_err = failed to parse bitcode for LTO modul
4749
codegen_llvm_parse_target_machine_config =
4850
failed to parse target machine config to target machine: {$error}
4951
52+
codegen_llvm_prepare_autodiff = failed to prepare autodiff: src: {$src}, target: {$target}, {$error}
53+
codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare autodiff: {$llvm_err}, src: {$src}, target: {$target}, {$error}
5054
codegen_llvm_prepare_thin_lto_context = failed to prepare thin LTO context
5155
codegen_llvm_prepare_thin_lto_context_with_llvm_err = failed to prepare thin LTO context: {$llvm_err}
5256

‎compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,14 @@ pub(crate) fn run_pass_manager(
604604
debug!("running the pass manager");
605605
let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
606606
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
607-
unsafe { write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) }?;
607+
608+
// If this rustc version was built with enzyme/autodiff enabled, and if users applied the
609+
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
610+
let first_run = true;
611+
debug!("running llvm pm opt pipeline");
612+
unsafe {
613+
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?;
614+
}
608615
debug!("lto done");
609616
Ok(())
610617
}

‎compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 128 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use libc::{c_char, c_int, c_void, size_t};
88
use llvm::{
99
LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols,
1010
};
11+
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
1112
use rustc_codegen_ssa::back::link::ensure_removed;
1213
use rustc_codegen_ssa::back::versioned_llvm_target;
1314
use rustc_codegen_ssa::back::write::{
@@ -28,7 +29,7 @@ use rustc_session::config::{
2829
use rustc_span::InnerSpan;
2930
use rustc_span::symbol::sym;
3031
use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel};
31-
use tracing::debug;
32+
use tracing::{debug, trace};
3233

3334
use crate::back::lto::ThinBuffer;
3435
use crate::back::owned_target_machine::OwnedTargetMachine;
@@ -517,9 +518,38 @@ pub(crate) unsafe fn llvm_optimize(
517518
config: &ModuleConfig,
518519
opt_level: config::OptLevel,
519520
opt_stage: llvm::OptStage,
521+
skip_size_increasing_opts: bool,
520522
) -> Result<(), FatalError> {
521-
let unroll_loops =
522-
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
523+
// Enzyme:
524+
// The whole point of compiler-based AD is to differentiate optimized IR instead of unoptimized
525+
// source code. However, benchmarks show that optimizations increasing the code size
526+
// tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code
527+
// and finally re-optimize the module, now with all optimizations available.
528+
// FIXME(ZuseZ4): In a future update we could figure out how to only optimize individual functions getting
529+
// differentiated.
530+
531+
let unroll_loops;
532+
let vectorize_slp;
533+
let vectorize_loop;
534+
535+
// When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
536+
// optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly,
537+
// we should make this more granular, or at least check that the user has at least one autodiff
538+
// call in their code, to justify altering the compilation pipeline.
539+
if skip_size_increasing_opts && cfg!(llvm_enzyme) {
540+
unroll_loops = false;
541+
vectorize_slp = false;
542+
vectorize_loop = false;
543+
} else {
544+
unroll_loops =
545+
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
546+
vectorize_slp = config.vectorize_slp;
547+
vectorize_loop = config.vectorize_loop;
548+
}
549+
trace!(
550+
"Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}",
551+
unroll_loops, vectorize_slp, vectorize_loop
552+
);
523553
let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed();
524554
let pgo_gen_path = get_pgo_gen_path(config);
525555
let pgo_use_path = get_pgo_use_path(config);
@@ -583,8 +613,8 @@ pub(crate) unsafe fn llvm_optimize(
583613
using_thin_buffers,
584614
config.merge_functions,
585615
unroll_loops,
586-
config.vectorize_slp,
587-
config.vectorize_loop,
616+
vectorize_slp,
617+
vectorize_loop,
588618
config.no_builtins,
589619
config.emit_lifetime_markers,
590620
sanitizer_options.as_ref(),
@@ -606,6 +636,83 @@ pub(crate) unsafe fn llvm_optimize(
606636
result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses))
607637
}
608638

639+
pub(crate) fn differentiate(
640+
module: &ModuleCodegen<ModuleLlvm>,
641+
cgcx: &CodegenContext<LlvmCodegenBackend>,
642+
diff_items: Vec<AutoDiffItem>,
643+
config: &ModuleConfig,
644+
) -> Result<(), FatalError> {
645+
for item in &diff_items {
646+
trace!("{}", item);
647+
}
648+
649+
let llmod = module.module_llvm.llmod();
650+
let llcx = &module.module_llvm.llcx;
651+
let diag_handler = cgcx.create_dcx();
652+
653+
// Before dumping the module, we want all the tt to become part of the module.
654+
for item in diff_items.iter() {
655+
let name = CString::new(item.source.clone()).unwrap();
656+
let fn_def: Option<&llvm::Value> =
657+
unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()) };
658+
let fn_def = match fn_def {
659+
Some(x) => x,
660+
None => {
661+
return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
662+
src: item.source.clone(),
663+
target: item.target.clone(),
664+
error: "could not find source function".to_owned(),
665+
}));
666+
}
667+
};
668+
let target_name = CString::new(item.target.clone()).unwrap();
669+
debug!("target name: {:?}", &target_name);
670+
let fn_target: Option<&llvm::Value> =
671+
unsafe { llvm::LLVMGetNamedFunction(llmod, target_name.as_ptr()) };
672+
let fn_target = match fn_target {
673+
Some(x) => x,
674+
None => {
675+
return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff {
676+
src: item.source.clone(),
677+
target: item.target.clone(),
678+
error: "could not find target function".to_owned(),
679+
}));
680+
}
681+
};
682+
683+
crate::builder::generate_enzyme_call(llmod, llcx, fn_def, fn_target, item.attrs.clone());
684+
}
685+
686+
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
687+
688+
if let Some(opt_level) = config.opt_level {
689+
let opt_stage = match cgcx.lto {
690+
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
691+
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
692+
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
693+
_ => llvm::OptStage::PreLinkNoLTO,
694+
};
695+
// This is our second opt call, so now we run all opts,
696+
// to make sure we get the best performance.
697+
let skip_size_increasing_opts = false;
698+
trace!("running Module Optimization after differentiation");
699+
unsafe {
700+
llvm_optimize(
701+
cgcx,
702+
diag_handler.handle(),
703+
module,
704+
config,
705+
opt_level,
706+
opt_stage,
707+
skip_size_increasing_opts,
708+
)?
709+
};
710+
}
711+
trace!("done with differentiate()");
712+
713+
Ok(())
714+
}
715+
609716
// Unsafe due to LLVM calls.
610717
pub(crate) unsafe fn optimize(
611718
cgcx: &CodegenContext<LlvmCodegenBackend>,
@@ -628,14 +735,29 @@ pub(crate) unsafe fn optimize(
628735
unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) };
629736
}
630737

738+
// FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts
739+
631740
if let Some(opt_level) = config.opt_level {
632741
let opt_stage = match cgcx.lto {
633742
Lto::Fat => llvm::OptStage::PreLinkFatLTO,
634743
Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO,
635744
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
636745
_ => llvm::OptStage::PreLinkNoLTO,
637746
};
638-
return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) };
747+
748+
// If we know that we will later run AD, then we disable vectorization and loop unrolling
749+
let skip_size_increasing_opts = cfg!(llvm_enzyme);
750+
return unsafe {
751+
llvm_optimize(
752+
cgcx,
753+
dcx,
754+
module,
755+
config,
756+
opt_level,
757+
opt_stage,
758+
skip_size_increasing_opts,
759+
)
760+
};
639761
}
640762
Ok(())
641763
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /