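Spark source code: analyzing the sample computation flow

This article walks through Spark's sample computation flow in depth, from the parameters down to the source code. It covers the two samplers, BernoulliSampler (Bernoulli distribution) and PoissonSampler (Poisson distribution), and shows how random sampling is implemented.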

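1. Parameters

    RDD.sample takes three parameters, with the following meanings:

    withReplacement: whether an element can be sampled more than once (sampling with replacement).

    fraction: the expected size of the sample as a fraction of the RDD's size. When withReplacement=false, it is the probability that each element is chosen, and it must lie in [0, 1]. When withReplacement=true, it is the expected number of times each element is chosen, and it must be >= 0.

    seed: the seed for the random number generator.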
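The two meanings of fraction are easier to see in a short derivation. It uses standard probability and assumes an RDD of N elements in which each element is drawn independently, which is how the samplers below behave. X_i is whether element i is kept, and K_i is how many copies of it are emitted.

```latex
\text{Without replacement: } X_i \sim \mathrm{Bernoulli}(f),\ 0 \le f \le 1, \qquad
\sum_{i=1}^{N} X_i \sim \mathrm{Binomial}(N, f), \qquad
\mathbb{E}\Big[\sum_{i=1}^{N} X_i\Big] = fN

\text{With replacement: } K_i \sim \mathrm{Poisson}(f),\ f \ge 0, \qquad
\sum_{i=1}^{N} K_i \sim \mathrm{Poisson}(fN), \qquad
\mathbb{E}\Big[\sum_{i=1}^{N} K_i\Big] = fN
```

In both cases fraction fixes only the expected sample size, not the exact one. For example, with N = 1000 and f = 0.1 the expected size is 100. The standard deviation is about sqrt(1000 · 0.1 · 0.9) ≈ 9.5 without replacement and sqrt(100) = 10 with replacement.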

2. Source code walkthrough

    Following sample into the source, we arrive at RDD's sample method:

def sample(
    withReplacement: Boolean,
    fraction: Double,
    seed: Long = Utils.random.nextLong): RDD[T] = {
  require(fraction >= 0,
    s"Fraction must be nonnegative, but got ${fraction}")

  withScope {
    require(fraction >= 0.0, "Negative fraction value: " + fraction)
    if (withReplacement) {
      new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed)
    } else {
      new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed)
    }
  }
}

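    The first part of the method only validates fraction (twice, in fact); the real work is in the if/else. Both branches return a PartitionwiseSampledRDD. Its constructor receives a PoissonSampler when withReplacement=true and a BernoulliSampler when withReplacement=false. The third argument, true, is preservesPartitioning, so the parent's partitioner is kept. When withReplacement=false, fraction is the probability of selecting each element, so let's look at BernoulliSampler first.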
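The article does not quote PartitionwiseSampledRDD itself. Based on my reading of Spark's source (treat this as an assumption and check it against your Spark version), the RDD works in four steps. It derives one seed per partition from `seed` with a `java.util.Random`, clones the sampler for each partition, calls `setSeed` with that partition's seed, and runs the sampler over the partition's iterator.

Below is a hypothetical Java sketch of that driver. The class `PartitionwiseSketch` is made up for illustration, and a plain Bernoulli trial stands in for the cloned sampler.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical sketch (not Spark code) of partition-wise seeded sampling.
// Assumption: one seed per partition is derived from the user seed, and each
// partition is sampled with its own RNG, like a cloned sampler after setSeed.
final class PartitionwiseSketch {

    // Derive one seed per partition from the user-supplied seed (deterministic).
    static long[] partitionSeeds(long seed, int numPartitions) {
        Random random = new Random(seed);
        long[] seeds = new long[numPartitions];
        for (int i = 0; i < numPartitions; i++) {
            seeds[i] = random.nextLong();
        }
        return seeds;
    }

    // Sample every partition independently; the number of partitions is preserved.
    static <T> List<List<T>> sample(List<List<T>> partitions, double fraction, long seed) {
        long[] seeds = partitionSeeds(seed, partitions.size());
        List<List<T>> result = new ArrayList<>();
        for (int p = 0; p < partitions.size(); p++) {
            Random rng = new Random(seeds[p]); // fresh RNG per partition
            List<T> kept = new ArrayList<>();
            for (T item : partitions.get(p)) {
                if (rng.nextDouble() <= fraction) { // one Bernoulli trial per element
                    kept.add(item);
                }
            }
            result.add(kept);
        }
        return result;
    }
}
```

With this design the same seed reproduces the same sample. Each partition can also be sampled on its own executor without sharing RNG state.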

2.1 BernoulliSampler (Bernoulli distribution)

class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {

  /** epsilon slop to avoid failure from floating point jitter */
  require(
    fraction >= (0.0 - RandomSampler.roundingEpsilon)
      && fraction <= (1.0 + RandomSampler.roundingEpsilon),
    s"Sampling fraction ($fraction) must be on interval [0, 1]")

  private val rng: Random = RandomSampler.newDefaultRNG

  override def setSeed(seed: Long): Unit = rng.setSeed(seed)

  private lazy val gapSampling: GapSampling =
    new GapSampling(fraction, rng, RandomSampler.rngEpsilon)

  override def sample(): Int = {
    if (fraction <= 0.0) {
      0
    } else if (fraction >= 1.0) {
      1
    } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
      gapSampling.sample()
    } else {
      if (rng.nextDouble() <= fraction) {
        1
      } else {
        0
      }
    }
  }

  override def clone: BernoulliSampler[T] = new BernoulliSampler[T](fraction)
}

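    The key part of BernoulliSampler is its sample method, which PartitionwiseSampledRDD ultimately invokes for every element of a partition. If fraction <= 0.0 it returns 0, meaning the element is not selected. If fraction >= 1.0 it returns 1, meaning the element is always selected. These are edge cases; normally fraction lies in (0, 1).

    The next condition is fraction <= RandomSampler.defaultMaxGapSamplingFraction, a constant equal to 0.4. Spark's own comment on this constant explains the choice. At or below it the gap-sampling optimization is applied, and above it plain Bernoulli trials are assumed to be faster. The optimal cutoff depends on the cost of the RNG and is best found by experiment. Below the threshold the sampler calls the sample method of a GapSampling object, which implements gap sampling: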

class GapSampling(
    f: Double,
    rng: Random = RandomSampler.newDefaultRNG,
    epsilon: Double = RandomSampler.rngEpsilon) extends Serializable {

  require(f > 0.0  &&  f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)")
  require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")

  private val lnq = math.log1p(-f)

  /** Return 1 if the next item should be sampled. Otherwise, return 0. */

  def sample(): Int = {
    if (countForDropping > 0) {
      countForDropping -= 1
      0
    } else {
      advance()
      1
    }
  }

  private var countForDropping: Int = 0

  /**
   * Decide the number of elements that won't be sampled,
   * according to geometric dist P(k) = (f)(1-f)^k.
   */
  private def advance(): Unit = {
    val u = math.max(rng.nextDouble(), epsilon)
    countForDropping = (math.log(u) / lnq).toInt
  }

  /** advance to first sample as part of object construction. */
  advance()
  // Attempting to invoke this closer to the top with other object initialization
  // was causing it to break in strange ways, so I'm invoking it last, which seems to
  // work reliably.
}

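    Its sample function simply returns 0 or 1, driven by countForDropping and advance(). The field lnq is math.log1p(-f), i.e. ln(1 − fraction). advance() draws a uniform random number u, clamped to at least epsilon so the logarithm stays finite. It then sets countForDropping to ln(u) / ln(1 − fraction), the logarithm of u to base 1 − fraction, truncated to an integer.

    As the comment says, this is gap sampling. Instead of flipping a coin for every element, it draws in one step how many elements to drop before the next selected one. Since ln(1 − fraction) < 0, P(countForDropping ≥ k) = P(u ≤ (1 − fraction)^k) = (1 − fraction)^k. That is exactly the chance that k independent coin flips with success probability fraction all fail. The result therefore matches per-element Bernoulli sampling while consuming only one random number per selected element. sample() returns 0 for the next countForDropping calls, then returns 1 and draws a new gap.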
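To experiment with the gap logic outside Spark, here is a Java sketch that mirrors the GapSampling code quoted above. The class name `GapSamplingSketch` is made up. epsilon is a constructor argument; in the usage below I assume 5e-11 as a stand-in for RandomSampler.rngEpsilon.

```java
import java.util.Random;

// Java sketch mirroring the quoted GapSampling logic (hypothetical class, not Spark code).
final class GapSamplingSketch {
    private final Random rng;
    private final double lnq;      // ln(1 - f), like math.log1p(-f) in the Scala code
    private final double epsilon;  // lower clamp for u so that ln(u) stays finite
    private int countForDropping = 0;

    GapSamplingSketch(double f, long seed, double epsilon) {
        if (!(f > 0.0 && f < 1.0)) throw new IllegalArgumentException("f must be in (0, 1)");
        this.rng = new Random(seed);
        this.lnq = Math.log1p(-f);
        this.epsilon = epsilon;
        advance(); // skip ahead to the first sampled element
    }

    // Returns 1 if the next element is sampled, otherwise 0.
    int sample() {
        if (countForDropping > 0) {
            countForDropping -= 1;
            return 0;
        }
        advance();
        return 1;
    }

    // Gap length k with P(k) = f (1 - f)^k, computed as floor(ln u / ln(1 - f)).
    private void advance() {
        double u = Math.max(rng.nextDouble(), epsilon);
        countForDropping = (int) (Math.log(u) / lnq);
    }
}
```

Over a million calls with f = 0.1, the share of 1s should land very close to 0.1, the same rate a per-element coin flip would give.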

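    Back in BernoulliSampler.sample: when fraction is above the 0.4 threshold, the GapSampling branch is skipped and the rest is simple. The sampler draws a uniform random number in [0, 1) and compares it with fraction. It returns 1 if the number is less than or equal to fraction, and 0 otherwise.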
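The branch structure of BernoulliSampler.sample() can be sketched in Java. The sketch also counts how many random numbers each strategy consumes, which shows why gap sampling pays off for small fractions. The class `BernoulliSketch` is hypothetical. I assume 0.4 for RandomSampler.defaultMaxGapSamplingFraction and 5e-11 for RandomSampler.rngEpsilon.

```java
import java.util.Random;

// Hypothetical sketch of BernoulliSampler.sample()'s branches (not Spark code),
// instrumented with a counter of random draws.
final class BernoulliSketch {
    static final double MAX_GAP_FRACTION = 0.4; // assumed defaultMaxGapSamplingFraction
    static final double RNG_EPSILON = 5e-11;    // assumed rngEpsilon

    private final double fraction;
    private final Random rng;
    private final boolean gap;        // true when the gap-sampling branch applies
    private int countForDropping = 0;
    long draws = 0;                   // random numbers consumed so far

    BernoulliSketch(double fraction, long seed) {
        this.fraction = fraction;
        this.rng = new Random(seed);
        this.gap = fraction > 0.0 && fraction <= MAX_GAP_FRACTION; // implies fraction < 1
        if (gap) advance();
    }

    int sample() {
        if (fraction <= 0.0) return 0; // never selected
        if (fraction >= 1.0) return 1; // always selected
        if (gap) {                     // skip the elements a coin flip would reject
            if (countForDropping > 0) {
                countForDropping--;
                return 0;
            }
            advance();
            return 1;
        }
        return nextDouble() <= fraction ? 1 : 0; // one Bernoulli trial per element
    }

    private double nextDouble() {
        draws++;
        return rng.nextDouble();
    }

    private void advance() {
        double u = Math.max(nextDouble(), RNG_EPSILON);
        countForDropping = (int) (Math.log(u) / Math.log1p(-fraction));
    }
}
```

Take one million elements and compare the two strategies. A fraction of 0.6 goes through the coin-flip branch and uses exactly one draw per element. A fraction of 0.05 goes through the gap branch and uses roughly one draw per kept element, about 50,000. Both keep elements at the requested rate.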

2.2 PoissonSampler (Poisson distribution)

class PoissonSampler[T](
    fraction: Double,
    useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] {

  def this(fraction: Double) = this(fraction, useGapSamplingIfPossible = true)

  /** Epsilon slop to avoid failure from floating point jitter. */
  require(
    fraction >= (0.0 - RandomSampler.roundingEpsilon),
    s"Sampling fraction ($fraction) must be >= 0")

  // PoissonDistribution throws an exception when fraction <= 0
  // If fraction is <= 0, Iterator.empty is used below, so we can use any placeholder value.
  private val rng = new PoissonDistribution(if (fraction > 0.0) fraction else 1.0)
  private val rngGap = RandomSampler.newDefaultRNG

  override def setSeed(seed: Long) {
    rng.reseedRandomGenerator(seed)
    rngGap.setSeed(seed)
  }

  private lazy val gapSamplingReplacement =
    new GapSamplingReplacement(fraction, rngGap, RandomSampler.rngEpsilon)

  override def sample(): Int = {
    if (fraction <= 0.0) {
      0
    } else if (useGapSamplingIfPossible &&
               fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
      gapSamplingReplacement.sample()
    } else {
      rng.sample()
    }
  }

  override def sample(items: Iterator[T]): Iterator[T] = {
    if (fraction <= 0.0) {
      Iterator.empty
    } else {
      val useGapSampling = useGapSamplingIfPossible &&
        fraction <= RandomSampler.defaultMaxGapSamplingFraction

      items.flatMap { item =>
        val count = if (useGapSampling) gapSamplingReplacement.sample() else rng.sample()
        if (count == 0) Iterator.empty else Iterator.fill(count)(item)
      }
    }
  }

  override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction, useGapSamplingIfPossible)
}

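PoissonSampler mainly relies on GapSamplingReplacement.sample and PoissonDistribution.sample. Its sample() uses gap sampling when useGapSamplingIfPossible is true (the default for the one-argument constructor) and fraction <= 0.4; otherwise it draws from PoissonDistribution. sample(items) then emits each item count times, and drops the item when count is 0. The two methods are: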

GapSamplingReplacement.sample

class GapSamplingReplacement(
    val f: Double,
    val rng: Random = RandomSampler.newDefaultRNG,
    epsilon: Double = RandomSampler.rngEpsilon) extends Serializable {

  require(f > 0.0, s"Sampling fraction ($f) must be > 0")
  require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")

  protected val q = math.exp(-f)

  /**
   * Sample from Poisson distribution, conditioned such that the sampled value is >= 1.
   * This is an adaptation from the algorithm for Generating Poisson distributed random variables:
   * http://en.wikipedia.org/wiki/Poisson_distribution
   */
  protected def poissonGE1: Int = {
    // simulate that the standard poisson sampling
    // gave us at least one iteration, for a sample of >= 1
    var pp = q + ((1.0 - q) * rng.nextDouble())
    var r = 1

    // now continue with standard poisson sampling algorithm
    pp *= rng.nextDouble()
    while (pp > q) {
      r += 1
      pp *= rng.nextDouble()
    }
    r
  }
  private var countForDropping: Int = 0

  def sample(): Int = {
    if (countForDropping > 0) {
      countForDropping -= 1
      0
    } else {
      val r = poissonGE1
      advance()
      r
    }
  }

  /**
   * Skip elements with replication factor zero (i.e. elements that won't be sampled).
   * Samples 'k' from geometric distribution  P(k) = (1-q)(q)^k, where q = e^(-f), that is
   * q is the probability of Poisson(0; f)
   */
  private def advance(): Unit = {
    val u = math.max(rng.nextDouble(), epsilon)
    countForDropping = (math.log(u) / (-f)).toInt
  }

  /** advance to first sample as part of object construction. */
  advance()
  // Attempting to invoke this closer to the top with other object initialization
  // was causing it to break in strange ways, so I'm invoking it last, which seems to
  // work reliably.
}

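This is gap sampling with replacement. q = e^(-f) is the probability that a Poisson(f) count is 0. Since ln q = -f, advance() computes ln(u) / (-f), the logarithm of u to base q. That value follows the geometric distribution P(k) = (1 − q)q^k and gives the number of consecutive elements whose count is 0. For the next element, poissonGE1 draws a Poisson count conditioned to be at least 1. It does this by forcing the first uniform factor of the classic multiply-uniforms Poisson algorithm into [q, 1).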
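Here is a Java sketch that mirrors the GapSamplingReplacement code above, useful for checking its statistics. The class name `GapReplacementSketch` is hypothetical. epsilon is passed in, and the usage assumes 5e-11 as a stand-in for RandomSampler.rngEpsilon.

```java
import java.util.Random;

// Java sketch mirroring the quoted GapSamplingReplacement logic (hypothetical, not Spark code).
final class GapReplacementSketch {
    private final double f;
    private final double q;          // e^{-f} = P(Poisson(f) == 0)
    private final double epsilon;
    private final Random rng;
    private int countForDropping = 0;

    GapReplacementSketch(double f, long seed, double epsilon) {
        if (!(f > 0.0)) throw new IllegalArgumentException("f must be > 0");
        this.f = f;
        this.q = Math.exp(-f);
        this.epsilon = epsilon;
        this.rng = new Random(seed);
        advance(); // number of leading elements with count 0
    }

    // Replication count for the next element: 0 while skipping, else a Poisson draw >= 1.
    int sample() {
        if (countForDropping > 0) {
            countForDropping--;
            return 0;
        }
        int r = poissonGE1();
        advance();
        return r;
    }

    // Multiply-uniforms Poisson draw whose first factor is forced into [q, 1), so r >= 1.
    private int poissonGE1() {
        double pp = q + (1.0 - q) * rng.nextDouble();
        int r = 1;
        pp *= rng.nextDouble();
        while (pp > q) {
            r++;
            pp *= rng.nextDouble();
        }
        return r;
    }

    // Gap length k with P(k) = (1 - q) q^k, computed as floor(ln u / -f).
    private void advance() {
        double u = Math.max(rng.nextDouble(), epsilon);
        countForDropping = (int) (Math.log(u) / (-f));
    }
}
```

Per element, the counts should behave like Poisson(f). The average count should be close to f, and the share of zeros close to e^(-f).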

PoissonDistribution.sample

public int sample() {
    return (int) FastMath.min(nextPoisson(mean), Integer.MAX_VALUE);
}

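This one is not gap sampling. Apache Commons Math's PoissonDistribution draws a count directly from a Poisson distribution with mean fraction for every element (nextPoisson, capped at Integer.MAX_VALUE). PoissonSampler takes this path when fraction > 0.4 or when gap sampling is disabled.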
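The non-gap path of PoissonSampler.sample(items) can be sketched in Java: every element gets a Poisson(fraction) count and is emitted that many times. The class `PoissonReplicationSketch` is hypothetical. Knuth's textbook algorithm stands in for Commons Math's nextPoisson, whose actual implementation may differ (for example, for large means). Knuth's method is only suitable for small means like typical sampling fractions.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical sketch of sampling with replacement on the non-gap path (not Spark code).
final class PoissonReplicationSketch {

    // Knuth: multiply uniforms until the product drops to e^{-mean}; count the steps.
    static int knuthPoisson(double mean, Random rng) {
        double limit = Math.exp(-mean);
        int k = 0;
        double p = rng.nextDouble();
        while (p > limit) {
            k++;
            p *= rng.nextDouble();
        }
        return k;
    }

    // Emit each item `count` times (count == 0 drops it), like the flatMap in PoissonSampler.
    static <T> List<T> sample(List<T> items, double fraction, long seed) {
        List<T> out = new ArrayList<>();
        if (fraction <= 0.0) return out;
        Random rng = new Random(seed);
        for (T item : items) {
            int count = knuthPoisson(fraction, rng);
            for (int i = 0; i < count; i++) out.add(item);
        }
        return out;
    }
}
```

With replacement, fraction may exceed 1. A fraction of 1.5 should produce roughly 1.5 output elements per input element, with some inputs repeated.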
