Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit c8905ea

Browse files
committed
Auto merge of #147128 - matthiaskrgr:rollup-mqey4c4, r=matthiaskrgr
Rollup of 6 pull requests Successful merges: - #140482 (std::net: update tcp deferaccept delay type to Duration.) - #141469 (Allow `&raw [mut | const]` for union field in safe code) - #144197 (TypeTree support in autodiff) - #146675 (Allow shared access to `Exclusive<T>` when `T: Sync`) - #147113 (Reland "Add LSX accelerated implementation for source file analysis") - #147120 (Fix --extra-checks=spellcheck to prevent cargo install every time) r? `@ghost` `@rustbot` modify labels: rollup
2 parents 8d72d3e + 4eb6b8f commit c8905ea

File tree

82 files changed

+1631
-62
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

82 files changed

+1631
-62
lines changed

‎compiler/rustc_ast/src/expand/autodiff_attrs.rs‎

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
use std::fmt::{self, Display, Formatter};
77
use std::str::FromStr;
88

9+
use crate::expand::typetree::TypeTree;
910
use crate::expand::{Decodable, Encodable, HashStable_Generic};
1011
use crate::{Ty, TyKind};
1112

@@ -84,6 +85,8 @@ pub struct AutoDiffItem {
8485
/// The name of the function being generated
8586
pub target: String,
8687
pub attrs: AutoDiffAttrs,
88+
pub inputs: Vec<TypeTree>,
89+
pub output: TypeTree,
8790
}
8891

8992
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
@@ -275,14 +278,22 @@ impl AutoDiffAttrs {
275278
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
276279
}
277280

278-
pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
279-
AutoDiffItem { source, target, attrs: self }
281+
pub fn into_item(
282+
self,
283+
source: String,
284+
target: String,
285+
inputs: Vec<TypeTree>,
286+
output: TypeTree,
287+
) -> AutoDiffItem {
288+
AutoDiffItem { source, target, inputs, output, attrs: self }
280289
}
281290
}
282291

283292
impl fmt::Display for AutoDiffItem {
284293
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285294
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
286-
write!(f, " with attributes: {:?}", self.attrs)
295+
write!(f, " with attributes: {:?}", self.attrs)?;
296+
write!(f, " with inputs: {:?}", self.inputs)?;
297+
write!(f, " with output: {:?}", self.output)
287298
}
288299
}

‎compiler/rustc_ast/src/expand/typetree.rs‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub enum Kind {
3131
Half,
3232
Float,
3333
Double,
34+
F128,
3435
Unknown,
3536
}
3637

‎compiler/rustc_codegen_gcc/src/builder.rs‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
13831383
_src_align: Align,
13841384
size: RValue<'gcc>,
13851385
flags: MemFlags,
1386+
_tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend
13861387
) {
13871388
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
13881389
let size = self.intcast(size, self.type_size_t(), false);

‎compiler/rustc_codegen_gcc/src/intrinsic/mod.rs‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
770770
scratch_align,
771771
bx.const_usize(self.layout.size.bytes()),
772772
MemFlags::empty(),
773+
None,
773774
);
774775

775776
bx.lifetime_end(scratch, scratch_size);

‎compiler/rustc_codegen_llvm/src/abi.rs‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
246246
scratch_align,
247247
bx.const_usize(copy_bytes),
248248
MemFlags::empty(),
249+
None,
249250
);
250251
bx.lifetime_end(llscratch, scratch_size);
251252
}

‎compiler/rustc_codegen_llvm/src/back/lto.rs‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
563563
config::AutoDiff::Enable => {}
564564
// We handle this below
565565
config::AutoDiff::NoPostopt => {}
566+
// Disables TypeTree generation
567+
config::AutoDiff::NoTT => {}
566568
}
567569
}
568570
// This helps with handling enums for now.

‎compiler/rustc_codegen_llvm/src/builder.rs‎

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
22
use std::ops::Deref;
33
use std::{iter, ptr};
44

5+
use rustc_ast::expand::typetree::FncTree;
56
pub(crate) mod autodiff;
67
pub(crate) mod gpu_offload;
78

@@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11071108
src_align: Align,
11081109
size: &'ll Value,
11091110
flags: MemFlags,
1111+
tt: Option<FncTree>,
11101112
) {
11111113
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
11121114
let size = self.intcast(size, self.type_isize(), false);
11131115
let is_volatile = flags.contains(MemFlags::VOLATILE);
1114-
unsafe {
1116+
let memcpy = unsafe {
11151117
llvm::LLVMRustBuildMemCpy(
11161118
self.llbuilder,
11171119
dst,
@@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11201122
src_align.bytes() as c_uint,
11211123
size,
11221124
is_volatile,
1123-
);
1125+
)
1126+
};
1127+
1128+
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
1129+
// a memcpy during autodiff, it needs to know the structure of the data being
1130+
// copied to properly track derivatives. For example, copying an array of floats
1131+
// vs. copying a struct with mixed types requires different derivative handling.
1132+
// The TypeTree tells Enzyme exactly what memory layout to expect.
1133+
if let Some(tt) = tt {
1134+
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
11241135
}
11251136
}
11261137

‎compiler/rustc_codegen_llvm/src/builder/autodiff.rs‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::ptr;
22

33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
4+
use rustc_ast::expand::typetree::FncTree;
45
use rustc_codegen_ssa::common::TypeKind;
56
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
67
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
@@ -294,6 +295,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
294295
fn_args: &[&'ll Value],
295296
attrs: AutoDiffAttrs,
296297
dest: PlaceRef<'tcx, &'ll Value>,
298+
fnc_tree: FncTree,
297299
) {
298300
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
299301
let mut ad_name: String = match attrs.mode {
@@ -370,6 +372,10 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
370372
fn_args,
371373
);
372374

375+
if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
376+
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
377+
}
378+
373379
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
374380

375381
builder.store_to_place(call, dest.val);

‎compiler/rustc_codegen_llvm/src/intrinsic.rs‎

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,6 +1212,9 @@ fn codegen_autodiff<'ll, 'tcx>(
12121212
&mut diff_attrs.input_activity,
12131213
);
12141214

1215+
let fnc_tree =
1216+
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));
1217+
12151218
// Build body
12161219
generate_enzyme_call(
12171220
bx,
@@ -1222,6 +1225,7 @@ fn codegen_autodiff<'ll, 'tcx>(
12221225
&val_arr,
12231226
diff_attrs.clone(),
12241227
result,
1228+
fnc_tree,
12251229
);
12261230
}
12271231

‎compiler/rustc_codegen_llvm/src/lib.rs‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ mod llvm_util;
6868
mod mono_item;
6969
mod type_;
7070
mod type_of;
71+
mod typetree;
7172
mod va_arg;
7273
mod value;
7374

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /