Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4257,6 +4257,7 @@ dependencies = [
"bitflags",
"either",
"gsgdt",
"parking_lot",
"polonius-engine",
"rustc_abi",
"rustc_apfloat",
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_ssa/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use rustc_ast::expand::allocator::{
};
use rustc_data_structures::fx::{FxHashMap, FxIndexSet};
use rustc_data_structures::profiling::{get_resident_set_size, print_time_passes_entry};
use rustc_data_structures::sync::{IntoDynSyncSend, par_map};
use rustc_data_structures::unord::UnordMap;
use rustc_hir::attrs::{AttributeKind, DebuggerVisualizerType, OptimizeAttr};
use rustc_hir::def_id::{CRATE_DEF_ID, DefId, LOCAL_CRATE};
Expand All @@ -25,6 +24,7 @@ use rustc_middle::mir::BinOp;
use rustc_middle::mir::interpret::ErrorHandled;
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem, MonoItemPartitions};
use rustc_middle::query::Providers;
use rustc_middle::sync::{IntoDynSyncSend, par_map};
use rustc_middle::ty::layout::{HasTyCtxt, HasTypingEnv, LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
use rustc_middle::{bug, span_bug};
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_data_structures/src/marker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ impl_dyn_send!(
[std::sync::Mutex<T> where T: ?Sized+ DynSend]
[std::sync::mpsc::Sender<T> where T: DynSend]
[std::sync::Arc<T> where T: ?Sized + DynSync + DynSend]
[std::sync::Weak<T> where T: ?Sized + DynSync + DynSend]
[std::sync::LazyLock<T, F> where T: DynSend, F: DynSend]
[std::collections::HashSet<K, S> where K: DynSend, S: DynSend]
[std::collections::HashMap<K, V, S> where K: DynSend, V: DynSend, S: DynSend]
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_data_structures/src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@ pub use parking_lot::{
};

pub use self::atomic::AtomicU64;
pub use self::branch_key::BranchKey;
pub use self::freeze::{FreezeLock, FreezeReadGuard, FreezeWriteGuard};
#[doc(no_inline)]
pub use self::lock::{Lock, LockGuard, Mode};
pub use self::mode::{is_dyn_thread_safe, set_dyn_thread_safe_mode};
pub use self::parallel::{
broadcast, join, par_for_each_in, par_map, parallel_guard, scope, spawn, try_par_for_each_in,
};
pub use self::parallel::{ParallelGuard, broadcast, parallel_guard, spawn};
pub use self::vec::{AppendOnlyIndexVec, AppendOnlyVec};
pub use self::worker_local::{Registry, WorkerLocal};
pub use crate::marker::*;

mod branch_key;
mod freeze;
mod lock;
mod parallel;
Expand Down
50 changes: 50 additions & 0 deletions compiler/rustc_data_structures/src/sync/branch_key.rs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, first of all, thank you for your work on this. I can see a lot has been done to address the issue!

Unable to speak about entire implementation since don't know much about this part, but one note on this particular file: the logic seems a bit unclear to me due to the magic numbers and bit manipulations. Even though the file isn't large, adding a few explanatory comments would be a big help for clarity

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, we'll need a regression test to check it doesn't ICE with these changes

Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use std::cmp;

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct BranchKey(u128);

impl BranchKey {
pub const fn root() -> Self {
Self(0x80000000_00000000_00000000_00000000)
}

fn bits_branch(self, branch_num: u128, bits: u32) -> Result<Self, BranchNestingError> {
let trailing_zeros = self.0.trailing_zeros();
let allocated_shift = trailing_zeros.checked_sub(bits).ok_or(BranchNestingError(()))?;
Ok(BranchKey(
self.0 & !(1 << trailing_zeros)
| (1 << allocated_shift)
| (branch_num << (allocated_shift + 1)),
))
}

pub fn branch(self, branch_num: u128, branch_space: u128) -> BranchKey {
debug_assert!(
branch_num < branch_space,
"branch_num = {branch_num} should be less than branch_space = {branch_space}"
);
// floor(log2(n - 1)) + 1 == ceil(log2(n))
self.bits_branch(branch_num, (branch_space - 1).checked_ilog2().map_or(0, |b| b + 1))
.expect("query branch space is exhausted")
}

pub fn disjoint_cmp(self, other: Self) -> cmp::Ordering {
self.0.cmp(&other.0)
}

pub fn nest(self, then: Self) -> Result<Self, BranchNestingError> {
let trailing_zeros = then.0.trailing_zeros();
let branch_num = then.0.wrapping_shr(trailing_zeros + 1);
let bits = u128::BITS - trailing_zeros;
self.bits_branch(branch_num, bits)
}
}

#[derive(Debug)]
pub struct BranchNestingError(());

impl Default for BranchKey {
fn default() -> Self {
BranchKey::root()
}
}
182 changes: 0 additions & 182 deletions compiler/rustc_data_structures/src/sync/parallel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,54 +43,6 @@ pub fn parallel_guard<R>(f: impl FnOnce(&ParallelGuard) -> R) -> R {
ret
}

fn serial_join<A, B, RA, RB>(oper_a: A, oper_b: B) -> (RA, RB)
where
A: FnOnce() -> RA,
B: FnOnce() -> RB,
{
let (a, b) = parallel_guard(|guard| {
let a = guard.run(oper_a);
let b = guard.run(oper_b);
(a, b)
});
(a.unwrap(), b.unwrap())
}

/// Runs a list of blocks in parallel. The first block is executed immediately on
/// the current thread. Use that for the longest running block.
#[macro_export]
macro_rules! parallel {
(impl $fblock:block [$($c:expr,)*] [$block:expr $(, $rest:expr)*]) => {
parallel!(impl $fblock [$block, $($c,)*] [$($rest),*])
};
(impl $fblock:block [$($blocks:expr,)*] []) => {
$crate::sync::parallel_guard(|guard| {
$crate::sync::scope(|s| {
$(
let block = $crate::sync::FromDyn::from(|| $blocks);
s.spawn(move |_| {
guard.run(move || block.into_inner()());
});
)*
guard.run(|| $fblock);
});
});
};
($fblock:block, $($blocks:block),*) => {
if $crate::sync::is_dyn_thread_safe() {
// Reverse the order of the later blocks since Rayon executes them in reverse order
// when using a single thread. This ensures the execution order matches that
// of a single threaded rustc.
parallel!(impl $fblock [] [$($blocks),*]);
} else {
$crate::sync::parallel_guard(|guard| {
guard.run(|| $fblock);
$(guard.run(|| $blocks);)*
});
}
};
}

pub fn spawn(func: impl FnOnce() + DynSend + 'static) {
if mode::is_dyn_thread_safe() {
let func = FromDyn::from(func);
Expand All @@ -102,140 +54,6 @@ pub fn spawn(func: impl FnOnce() + DynSend + 'static) {
}
}

// This function only works when `mode::is_dyn_thread_safe()`.
pub fn scope<'scope, OP, R>(op: OP) -> R
where
OP: FnOnce(&rustc_thread_pool::Scope<'scope>) -> R + DynSend,
R: DynSend,
{
let op = FromDyn::from(op);
rustc_thread_pool::scope(|s| FromDyn::from(op.into_inner()(s))).into_inner()
}

#[inline]
pub fn join<A, B, RA: DynSend, RB: DynSend>(oper_a: A, oper_b: B) -> (RA, RB)
where
A: FnOnce() -> RA + DynSend,
B: FnOnce() -> RB + DynSend,
{
if mode::is_dyn_thread_safe() {
let oper_a = FromDyn::from(oper_a);
let oper_b = FromDyn::from(oper_b);
let (a, b) = parallel_guard(|guard| {
rustc_thread_pool::join(
move || guard.run(move || FromDyn::from(oper_a.into_inner()())),
move || guard.run(move || FromDyn::from(oper_b.into_inner()())),
)
});
(a.unwrap().into_inner(), b.unwrap().into_inner())
} else {
serial_join(oper_a, oper_b)
}
}

fn par_slice<I: DynSend>(
items: &mut [I],
guard: &ParallelGuard,
for_each: impl Fn(&mut I) + DynSync + DynSend,
) {
struct State<'a, F> {
for_each: FromDyn<F>,
guard: &'a ParallelGuard,
group: usize,
}

fn par_rec<I: DynSend, F: Fn(&mut I) + DynSync + DynSend>(
items: &mut [I],
state: &State<'_, F>,
) {
if items.len() <= state.group {
for item in items {
state.guard.run(|| (state.for_each)(item));
}
} else {
let (left, right) = items.split_at_mut(items.len() / 2);
let mut left = state.for_each.derive(left);
let mut right = state.for_each.derive(right);
rustc_thread_pool::join(move || par_rec(*left, state), move || par_rec(*right, state));
}
}

let state = State {
for_each: FromDyn::from(for_each),
guard,
group: std::cmp::max(items.len() / 128, 1),
};
par_rec(items, &state)
}

pub fn par_for_each_in<I: DynSend, T: IntoIterator<Item = I>>(
t: T,
for_each: impl Fn(&I) + DynSync + DynSend,
) {
parallel_guard(|guard| {
if mode::is_dyn_thread_safe() {
let mut items: Vec<_> = t.into_iter().collect();
par_slice(&mut items, guard, |i| for_each(&*i))
} else {
t.into_iter().for_each(|i| {
guard.run(|| for_each(&i));
});
}
});
}

/// This runs `for_each` in parallel for each iterator item. If one or more of the
/// `for_each` calls returns `Err`, the function will also return `Err`. The error returned
/// will be non-deterministic, but this is expected to be used with `ErrorGuaranteed` which
/// are all equivalent.
pub fn try_par_for_each_in<T: IntoIterator, E: DynSend>(
t: T,
for_each: impl Fn(&<T as IntoIterator>::Item) -> Result<(), E> + DynSync + DynSend,
) -> Result<(), E>
where
<T as IntoIterator>::Item: DynSend,
{
parallel_guard(|guard| {
if mode::is_dyn_thread_safe() {
let mut items: Vec<_> = t.into_iter().collect();

let error = Mutex::new(None);

par_slice(&mut items, guard, |i| {
if let Err(err) = for_each(&*i) {
*error.lock() = Some(err);
}
});

if let Some(err) = error.into_inner() { Err(err) } else { Ok(()) }
} else {
t.into_iter().filter_map(|i| guard.run(|| for_each(&i))).fold(Ok(()), Result::and)
}
})
}

pub fn par_map<I: DynSend, T: IntoIterator<Item = I>, R: DynSend, C: FromIterator<R>>(
t: T,
map: impl Fn(I) -> R + DynSync + DynSend,
) -> C {
parallel_guard(|guard| {
if mode::is_dyn_thread_safe() {
let map = FromDyn::from(map);

let mut items: Vec<(Option<I>, Option<R>)> =
t.into_iter().map(|i| (Some(i), None)).collect();

par_slice(&mut items, guard, |i| {
i.1 = Some(map(i.0.take().unwrap()));
});

items.into_iter().filter_map(|i| i.1).collect()
} else {
t.into_iter().filter_map(|i| guard.run(|| map(i))).collect()
}
})
}

pub fn broadcast<R: DynSend>(op: impl Fn(usize) -> R + DynSync) -> Vec<R> {
if mode::is_dyn_thread_safe() {
let op = FromDyn::from(op);
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_incremental/src/persist/save.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use std::fs;
use std::sync::Arc;

use rustc_data_structures::fx::FxIndexMap;
use rustc_data_structures::sync::join;
use rustc_middle::dep_graph::{
DepGraph, SerializedDepGraph, WorkProduct, WorkProductId, WorkProductMap,
};
use rustc_middle::sync::join;
use rustc_middle::ty::TyCtxt;
use rustc_serialize::Encodable as RustcEncodable;
use rustc_serialize::opaque::{FileEncodeResult, FileEncoder};
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_interface/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ pub fn try_print_query_stack(
if let Some(icx) = icx {
ty::print::with_no_queries!(print_query_stack(
QueryCtxt::new(icx.tcx),
icx.query,
icx.query.map(|i| i.id),
dcx,
limit_frames,
file,
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_interface/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use rustc_codegen_ssa::{CodegenResults, CrateInfo};
use rustc_data_structures::jobserver::Proxy;
use rustc_data_structures::steal::Steal;
use rustc_data_structures::sync::{AppendOnlyIndexVec, FreezeLock, WorkerLocal};
use rustc_data_structures::{parallel, thousands};
use rustc_data_structures::thousands;
use rustc_errors::timings::TimingSection;
use rustc_expand::base::{ExtCtxt, LintStoreExpand};
use rustc_feature::Features;
Expand All @@ -27,6 +27,7 @@ use rustc_metadata::EncodedMetadata;
use rustc_metadata::creader::CStore;
use rustc_middle::arena::Arena;
use rustc_middle::dep_graph::DepsType;
use rustc_middle::parallel;
use rustc_middle::ty::{self, CurrentGcx, GlobalCtxt, RegisteredTools, TyCtxt};
use rustc_middle::util::Providers;
use rustc_parse::lexer::StripTokens;
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_lint/src/late.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ use std::any::Any;
use std::cell::Cell;

use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_data_structures::sync::join;
use rustc_hir::def_id::{LocalDefId, LocalModDefId};
use rustc_hir::{self as hir, AmbigArg, HirId, intravisit as hir_visit};
use rustc_middle::hir::nested_filter;
use rustc_middle::sync::join;
use rustc_middle::ty::{self, TyCtxt};
use rustc_session::Session;
use rustc_session::lint::LintPass;
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_metadata/src/rmeta/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use std::sync::Arc;

use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
use rustc_data_structures::memmap::{Mmap, MmapMut};
use rustc_data_structures::sync::{join, par_for_each_in};
use rustc_data_structures::temp_dir::MaybeTempDir;
use rustc_data_structures::thousands::usize_with_underscores;
use rustc_feature::Features;
Expand All @@ -21,6 +20,7 @@ use rustc_middle::dep_graph::WorkProductId;
use rustc_middle::middle::dependency_format::Linkage;
use rustc_middle::mir::interpret;
use rustc_middle::query::Providers;
use rustc_middle::sync::{join, par_for_each_in};
use rustc_middle::traits::specialization_graph;
use rustc_middle::ty::AssocContainer;
use rustc_middle::ty::codec::TyEncoder;
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ edition = "2024"
bitflags = "2.4.1"
either = "1.5.0"
gsgdt = "0.1.2"
parking_lot = "0.12"
polonius-engine = "0.13.0"
rustc_abi = { path = "../rustc_abi" }
rustc_apfloat = "0.2.0"
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_middle/src/hir/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ use rustc_ast::visit::{VisitorResult, walk_list};
use rustc_data_structures::fingerprint::Fingerprint;
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
use rustc_data_structures::svh::Svh;
use rustc_data_structures::sync::{DynSend, DynSync, par_for_each_in, try_par_for_each_in};
use rustc_data_structures::sync::{DynSend, DynSync};
use rustc_hir::attrs::AttributeKind;
use rustc_hir::def::{DefKind, Res};
use rustc_hir::def_id::{DefId, LOCAL_CRATE, LocalDefId, LocalModDefId};
use rustc_hir::definitions::{DefKey, DefPath, DefPathHash};
use rustc_hir::intravisit::Visitor;
use rustc_hir::*;
use rustc_hir_pretty as pprust_hir;
use rustc_middle::sync::{par_for_each_in, try_par_for_each_in};
use rustc_span::def_id::StableCrateId;
use rustc_span::{ErrorGuaranteed, Ident, Span, Symbol, kw, with_metavar_spans};

Expand Down
Loading
Loading