死磕sparkSQL源码之TreeNode

本文主要探讨Spark SQL中的InternalRow和TreeNode体系。InternalRow是表示数据行的抽象类,有BaseGenericInternalRow、UnsafeRow和JoinedRow等实现。TreeNode作为所有树结构的基类,提供了丰富的集合和树遍历操作接口,如map、flatmap、transform等。文章还介绍了TreeNode的泛型定义和相关方法,并预告将对Expression和QueryPlan子类进行讲解。

InternalRow体系

学习TreeNode之前,我们先了解下InternalRow。

对于我们一般接触到的数据库关系表来说,我们对于数据库中的数据操作都是按照“行”为单位的。在spark sql内部实现中,InternalRow是用来表示这一行行数据的类。看下源码中的解释,InternalRow作为一个抽象类,包numFields 和 update 方法,以及各列数据对应的 get 与 set 方法,但具体的实现逻辑体现在不同的子类中

/**
* An abstract class for row used internally in Spark SQL, which only contains the columns as
* internal types.
一个抽象类,用于表示spark SQL内部行,只包含内部类型的多个列(其实就是表示一行行数据的类)
*/

详细代码这里就不贴了,整理下一些重要接口的功能含义好了,注意InternalRow中都是根据下标来访问和操作列元素的 。

InternalRow实现类包括,BaseGenericinternalRow、UnsafeRow 和 JoinedRow 3 个直接子类

  • BaseGenericinternalRow:也是一个抽象类,实现了SpecializedGetters类中定义的所有GET方法,但是最终还是调用genericGet方法实现最终逻辑,genericGet方法在BaseGenericinternalRow内中只是定义了一个接口,最终实现在BaseGenericinternalRow的子类中。
  • JoinedRow:该类主要用于join操作,两个InternalRow放在一起形成新的InternalRow,在sparksql 聚合和join相关操作中,会用的比较多
  • UnsafeRow:不采用 Java 对象存储的方式,避免了 JVM 中垃圾回收( GC )的代价 。 此外,UnsafeRow 对行数据进行了特定的编码,使得存储更加高效 。

TreeNode体系

接下来正式开始进行TreeNode的学习

TreeNode是Spark SQL中所有树结构的基类,定义了一系列通用的集合操作和树遍历的操作接口。我们先看下TreeNode的代码

abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with TreePatternBits {


}

首先TreeNode是一个抽象类,一个泛型类;这里TreeNode[BaseType <: TreeNode[BaseType]]这种书写方式,不知道大家会不会很陌生,反正我一开始看的时候,觉得不知道咋回事,那么我们来一起理解写,这个具体是什么含义:

  • 首先,我们很明确这个TreeNode是个泛型,我们把[]中的看作一个T,其实就是TreeNode[T],这个没问题
  • 接下里,我们要理解下“<:”这个符号的含义,这属于scala泛型中的知识,上边界和下边界。上边界是“<:”,下边界是“>:”;上边界,拿代码中的定义的含义解释就是BaseType必须是TreeNode[BaseType]的子类。也就是说TreeNode的泛型类型用BaseType表示,泛型类型比如是TreeNode类的子类

另外,TreeNode还继承了Product接口,对于product接口相关使用介绍,请看这篇文章(scala之product特质理解_大家都叫我船长的博客-CSDN博客),看完应该就明白了。

接下来,开始详细看看TreeNode一些重要方法:

  • 返回子节点,只定义了接口,具体实现在之类中
 /**
   * Returns a Seq of the children of this node.
   * Children should not change. Immutability required for containsChild optimization
   */
  def children: Seq[BaseType]
  •  返回子节点的set集合
lazy val containsChild: Set[TreeNode[_]] = children.toSet
  • 比较两个TreeNode是否相等
def fastEquals(other: TreeNode[_]): Boolean = {
    this.eq(other) || this == other
  }
  • 查找第一个符合f条件的TreeNode
def find(f: BaseType => Boolean): Option[BaseType] = if (f(this)) {
    Some(this)
  } else {
    children.foldLeft(Option.empty[BaseType]) { (l, r) => l.orElse(r.find(f)) }
  }
  • 将函数f 递归 应用于TreeNode节点以及所有子节点(先应用于parent,后应用child)
def foreach(f: BaseType => Unit): Unit = {
    f(this)
    children.foreach(_.foreach(f))
  }
  • 函数f 递归 应用于TreeNode节点以及所有子节点(先应用于child,后应用parent)
def foreachUp(f: BaseType => Unit): Unit = {
    children.foreach(_.foreachUp(f))
    f(this)
  }
  • 通过前序遍历的方式,将函数f递归应用于当前节点以及所有子节点,返回seq
def map[A](f: BaseType => A): Seq[A] = {
    val ret = new collection.mutable.ArrayBuffer[A]()
    foreach(ret += f(_))
    ret.toSeq
  }
  • flatmap和上面的map整体一致,但是这里的函数f的返回值必须是集合类型,这里需要注意
def flatMap[A](f: BaseType => TraversableOnce[A]): Seq[A] = {
    val ret = new collection.mutable.ArrayBuffer[A]()
    foreach(ret ++= f(_))  //f返回的结果必须是一个集合
    ret.toSeq
  }
def collect[B](pf: PartialFunction[BaseType, B]): Seq[B] = {
    val ret = new collection.mutable.ArrayBuffer[B]()
    val lifted = pf.lift
    foreach(node => lifted(node).foreach(ret.+=))
    ret.toSeq
  }
  • 返回当前节点的所有子节点
def collectLeaves(): Seq[BaseType] = {
    this.collect { case p if p.children.isEmpty => p }
  }
  • 先序的方式访问所有节点,且返回第一个pf作用后结果不为None的节点
def collectFirst[B](pf: PartialFunction[BaseType, B]): Option[B] = {
    val lifted = pf.lift
    lifted(this).orElse {
      children.foldLeft(Option.empty[B]) { (l, r) => l.orElse(r.collectFirst(pf)) }
    }
  }
  • mapProductIterator其实功能和productIterator.map(f).toArray一致
protected def mapProductIterator[B: ClassTag](f: Any => B): Array[B] = {
    val arr = Array.ofDim[B](productArity)
    var i = 0
    while (i < arr.length) {
      arr(i) = f(productElement(i))
      i += 1
    }
    arr
  }
  • 将当前节点的子节点替换为新的子节点
inal def withNewChildren(newChildren: Seq[BaseType]): BaseType = {
    val childrenIndexedSeq = asIndexedSeq(children)
    val newChildrenIndexedSeq = asIndexedSeq(newChildren)
    assert(newChildrenIndexedSeq.size == childrenIndexedSeq.size, "Incorrect number of children")
    if (childrenIndexedSeq.isEmpty ||
        childrenFastEquals(newChildrenIndexedSeq, childrenIndexedSeq)) {
      this
    } else {
      CurrentOrigin.withOrigin(origin) {
        val res = withNewChildrenInternal(newChildrenIndexedSeq)
        res.copyTagsFrom(this)
        res
      }
    }
  }
  • transfrom,调用transformDown,传入一个rule偏函数
def transform(rule: PartialFunction[BaseType, BaseType]): BaseType = {
    transformDown(rule)
  }
  •  transformDown,调用transformDownWithPruning,先序的方式使用rule作用于每个子节点,使用新的节点替换之前的,对节点不影响的,保留原来的节点
def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = {
    transformDownWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
  }

def transformDownWithPruning(cond: TreePatternBits => Boolean,
    ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType])
  : BaseType = {
    if (!cond.apply(this) || isRuleIneffective(ruleId)) {
      return this
    }
    val afterRule = CurrentOrigin.withOrigin(origin) {
      // 如果 this 是 BaseType 或其子类,则对 this 应用 rule 再返回应用 rule 后的结果,否则返回 this
      rule.applyOrElse(this, identity[BaseType])
    }

    // Check if unchanged and then possibly return old copy to avoid gc churn.
    if (this fastEquals afterRule) {
      // 如果应用了 rule 后节点无变化,则递归将 rule 应用于 children
      val rewritten_plan = mapChildren(_.transformDownWithPruning(cond, ruleId)(rule))
      if (this eq rewritten_plan) {
        markRuleAsIneffective(ruleId)
        this
      } else {
        rewritten_plan
      }
    } else {
      // If the transform function replaces this node with a new one, carry over the tags.
      // 如果应用了 rule 后节点有变化,则本节点换成变化后的节点(children 不变),再将 rule 递归应用于子节点。也就是从根节点往下来应用 rule 替换节点
      afterRule.copyTagsFrom(this)
      afterRule.mapChildren(_.transformDownWithPruning(cond, ruleId)(rule))
    }
  }
  • transformWithPruning,底层调用transformDownWithPruning(功能是返回此节点的副本,其中“规则”已递归应用于树。当“规则”不适用于给定节点时,它将保持不变)
def transformWithPruning(cond: TreePatternBits => Boolean,
    ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType])
  : BaseType = {
    transformDownWithPruning(cond, ruleId)(rule)
  }

def transformDownWithPruning(cond: TreePatternBits => Boolean,
    ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType])
  : BaseType = {
    if (!cond.apply(this) || isRuleIneffective(ruleId)) {
      return this
    }
    val afterRule = CurrentOrigin.withOrigin(origin) {
      // 如果 this 是 BaseType 或其子类,则对 this 应用 rule 再返回应用 rule 后的结果,否则返回 this
      rule.applyOrElse(this, identity[BaseType])
    }

    // Check if unchanged and then possibly return old copy to avoid gc churn.
    if (this fastEquals afterRule) {
      // 如果应用了 rule 后节点无变化,则递归将 rule 应用于 children
      val rewritten_plan = mapChildren(_.transformDownWithPruning(cond, ruleId)(rule))
      if (this eq rewritten_plan) {
        markRuleAsIneffective(ruleId)
        this
      } else {
        rewritten_plan
      }
    } else {
      // If the transform function replaces this node with a new one, carry over the tags.
      // 如果应用了 rule 后节点有变化,则本节点换成变化后的节点(children 不变),再将 rule 递归应用于子节点。也就是从根节点往下来应用 rule 替换节点
      afterRule.copyTagsFrom(this)
      afterRule.mapChildren(_.transformDownWithPruning(cond, ruleId)(rule))
    }
  }
  • transformUp 用后序遍历方式将规则作用于所有节点,调用transformUpWithPruning
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
    transformUpWithPruning(AlwaysProcess.fn, UnknownRuleId)(rule)
  }

def transformUpWithPruning(cond: TreePatternBits => Boolean,
    ruleId: RuleId = UnknownRuleId)(rule: PartialFunction[BaseType, BaseType])
  : BaseType = {
    if (!cond.apply(this) || isRuleIneffective(ruleId)) {
      return this
    }
    val afterRuleOnChildren = mapChildren(_.transformUpWithPruning(cond, ruleId)(rule))
    val newNode = if (this fastEquals afterRuleOnChildren) {
      CurrentOrigin.withOrigin(origin) {
        rule.applyOrElse(this, identity[BaseType])
      }
    } else {
      CurrentOrigin.withOrigin(origin) {
        rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
      }
    }
    if (this eq newNode) {
      markRuleAsIneffective(ruleId)
      this
    } else {
      // If the transform function replaces this node with a new one, carry over the tags.
      newNode.copyTagsFrom(this)
      newNode
    }
  }
  • mapChildren 返回 f 应用于所有子节点后该节点的 copy。
def mapChildren(f: BaseType => BaseType): BaseType = {
    if (containsChild.nonEmpty) {
      withNewChildren(children.map(f))
    } else {
      this
    }
  }

上面罗列的方法,基本就是TreeNode常用的,还有一些不常用的非核心的,这里就不一一介绍了,大家有兴趣的可以自己去看看源码。

另外TreeNode有两个子类,分别是Expression和QueryPlan,这篇文章我们就先讲到这里,后面会对这两个子类也会进行一一介绍的。

有兴趣的可以关注我,后面一起学习sparkSql源码,另外文章中有错误的地方,感谢指出哈。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值