diff --git a/pymc_bart/pgbart.py b/pymc_bart/pgbart.py
index 013dda8..59aed36 100644
--- a/pymc_bart/pgbart.py
+++ b/pymc_bart/pgbart.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
-
 from copy import deepcopy
 
 from numba import njit
@@ -32,9 +30,6 @@
 
 from pymc_bart.tree import LeafNode, SplitNode, Tree
 
-_log = logging.getLogger("pymc")
-
-
 class PGBART(ArrayStepShared):
     """
     Particle Gibbs BART sampling step.
@@ -74,6 +69,7 @@ def __init__(
         vars = inputvars(vars)
         value_bart = vars[0]
         self.bart = model.values_to_rvs[value_bart].owner.op
+        self.rng = np.random.default_rng()
 
         if isinstance(self.bart.X, Variable):
             self.X = self.bart.X.eval()
@@ -114,10 +110,10 @@ def __init__(
             num_observations=self.num_observations,
             shape=self.shape,
         )
-        self.normal = NormalSampler(mu_std, self.shape)
-        self.uniform = UniformSampler(0.33, 0.75, self.shape)
+        self.normal = NormalSampler(mu_std, self.shape, self.rng)
+        self.uniform = UniformSampler(0.33, 0.75, self.shape, self.rng)
         self.prior_prob_leaf_node = compute_prior_probability(self.bart.alpha)
-        self.ssv = SampleSplittingVariable(self.alpha_vec)
+        self.ssv = SampleSplittingVariable(self.alpha_vec, self.rng)
 
         self.tune = True
 
@@ -140,7 +136,7 @@ def __init__(
         self.all_particles = []
         for _ in range(self.m):
             self.a_tree.leaf_node_value = init_mean / self.m
-            p = ParticleTree(self.a_tree)
+            p = ParticleTree(self.a_tree, self.rng)
             self.all_particles.append(p)
         self.all_trees = np.array([p.tree for p in self.all_particles])
         super().__init__(vars, shared)
@@ -148,7 +144,7 @@ def astep(self, _):
         variable_inclusion = np.zeros(self.num_variates, dtype="int")
 
-        tree_ids = np.random.choice(range(self.m), replace=False, size=self.batch[~self.tune])
+        tree_ids = self.rng.choice(range(self.m), replace=False, size=self.batch[~self.tune])
         for tree_id in tree_ids:
             # Compute the sum of trees without the old tree that we are attempting to replace
             self.sum_trees_noi = self.sum_trees - self.all_particles[tree_id].tree._predict()
 
@@ -200,7 +196,7 @@ def astep(self, _):
             used_variates = new_tree.get_split_variables()
 
             if self.tune:
-                self.ssv = SampleSplittingVariable(self.alpha_vec)
+                self.ssv = SampleSplittingVariable(self.alpha_vec, self.rng)
                 for index in used_variates:
                     self.alpha_vec[index] += 1
             else:
@@ -256,7 +252,7 @@ def get_particle_tree(self, particles, normalized_weights):
         Sample a new particle, new tree and update log_weight
         """
         new_index = self.systematic(normalized_weights)[
-            discrete_uniform_sampler(self.num_particles)
+            discrete_uniform_sampler(self.num_particles, self.rng)
         ]
         new_particle = particles[new_index - 2]
         new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
@@ -290,7 +286,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:
         particles = [p0, p1]
 
         for _ in self.indices:
-            pt = ParticleTree(self.a_tree)
+            pt = ParticleTree(self.a_tree, self.rng)
             if self.tune:
                 pt.kfactor = self.uniform.random()
             else:
@@ -328,14 +324,15 @@ def competence(var, has_grad):
 class ParticleTree:
     """Particle tree."""
 
-    __slots__ = "tree", "expansion_nodes", "log_weight", "old_likelihood_logp", "kfactor"
+    __slots__ = "tree", "expansion_nodes", "log_weight", "old_likelihood_logp", "kfactor", "rng"
 
-    def __init__(self, tree):
+    def __init__(self, tree, rng):
         self.tree = tree.copy()  # keeps the tree we care about at the moment
         self.expansion_nodes = [0]
         self.log_weight = 0
         self.old_likelihood_logp = 0
         self.kfactor = 0.75
+        self.rng = rng
 
     def sample_tree(
         self,
@@ -355,7 +352,7 @@ def sample_tree(
             # Probability that this node will remain a leaf node
             prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
 
-            if prob_leaf < np.random.random():
+            if prob_leaf < self.rng.random():
                 index_selected_predictor = grow_tree(
                     self.tree,
                     index_leaf_node,
@@ -368,6 +365,7 @@ def sample_tree(
                     normal,
                     self.kfactor,
                     shape,
+                    self.rng,
                 )
                 if index_selected_predictor is not None:
                     new_indexes = self.tree.idx_leaf_nodes[-2:]
@@ -393,7 +391,7 @@ def sample_leafs(self, sum_trees, m, normal, shape):
 
 
 class SampleSplittingVariable:
-    def __init__(self, alpha_vec):
+    def __init__(self, alpha_vec, rng):
        """
         Sample splitting variables proportional to `alpha_vec`.
 
@@ -401,9 +399,10 @@ def __init__(self, alpha_vec):
         This enforces sparsity.
         """
         self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))
+        self.rng = rng
 
     def rvs(self):
-        rnd = np.random.random()
+        rnd = self.rng.random()
         for i, val in self.enu:
             if rnd <= val:
                 return i
@@ -449,6 +448,7 @@ def grow_tree(
     normal,
     kfactor,
     shape,
+    rng,
 ):
     current_node = tree.get_node(index_leaf_node)
     idx_data_points = current_node.idx_data_points
@@ -456,7 +456,7 @@
     index_selected_predictor = ssv.rvs()
     selected_predictor = available_predictors[index_selected_predictor]
     available_splitting_values = X[idx_data_points, selected_predictor]
-    split_value = get_split_value(available_splitting_values, idx_data_points, missing_data)
+    split_value = get_split_value(available_splitting_values, idx_data_points, missing_data, rng)
 
     if split_value is None:
         index_selected_predictor = None
@@ -511,7 +511,7 @@ def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X):
 
     return left_node_idx_data_points, right_node_idx_data_points
 
 
-def get_split_value(available_splitting_values, idx_data_points, missing_data):
+def get_split_value(available_splitting_values, idx_data_points, missing_data, rng):
     if missing_data:
         idx_data_points = idx_data_points[~np.isnan(available_splitting_values)]
@@ -521,7 +521,9 @@
 
     split_value = None
     if available_splitting_values.size > 0:
-        idx_selected_splitting_values = discrete_uniform_sampler(len(available_splitting_values))
+        idx_selected_splitting_values = discrete_uniform_sampler(
+            len(available_splitting_values), rng
+        )
         split_value = available_splitting_values[idx_selected_splitting_values]
 
     return split_value
@@ -561,21 +563,22 @@ def fast_mean(ari):
     return res / count
 
 
-def discrete_uniform_sampler(upper_value):
+def discrete_uniform_sampler(upper_value, rng):
     """Draw from the uniform distribution with bounds [0, upper_value).
 
     This is the same as np.random.randint(upper_value) but faster.
     """
-    return int(np.random.random() * upper_value)
+    return int(rng.random() * upper_value)
 
 
 class NormalSampler:
     """Cache samples from a standard normal distribution."""
 
-    def __init__(self, scale, shape):
+    def __init__(self, scale, shape, rng):
         self.size = 1000
         self.scale = scale
         self.shape = shape
+        self.rng = rng
         self.update()
 
     def random(self):
@@ -587,17 +590,18 @@ def random(self):
 
     def update(self):
         self.idx = 0
-        self.cache = np.random.normal(loc=0.0, scale=self.scale, size=(self.shape, self.size))
+        self.cache = self.rng.normal(loc=0.0, scale=self.scale, size=(self.shape, self.size))
 
 
 class UniformSampler:
     """Cache samples from a uniform distribution."""
 
-    def __init__(self, lower_bound, upper_bound, shape):
+    def __init__(self, lower_bound, upper_bound, shape, rng):
         self.size = 1000
         self.upper_bound = upper_bound
         self.lower_bound = lower_bound
         self.shape = shape
+        self.rng = rng
         self.update()
 
     def random(self):
@@ -609,7 +613,7 @@ def random(self):
 
     def update(self):
         self.idx = 0
-        self.cache = np.random.uniform(
+        self.cache = self.rng.uniform(
             self.lower_bound, self.upper_bound, size=(self.shape, self.size)
         )