Tree and Graph Traversal with and without Spark Graphx

Tree data structure

Trees and linked lists are basic data structure concepts taught in computer science classes. Tree traversal (also known as walking the tree) is a form of graph traversal and refers to the process of visiting (checking and/or updating) each node in a tree data structure exactly once. Such traversals are classified by the order in which the nodes are visited.

There are two different strategies to traverse a tree:

· Depth first traversal: in simple terms, walk the tree vertically, going down one branch as deep as possible before backtracking

· Breadth first traversal: in simple terms, walk the tree horizontally, one level at a time

For depth first traversal, there are the following strategies:

· Pre-order

o Access the data part of the current node.

o Traverse the left subtree by recursively calling the pre-order function.

o Traverse the right subtree by recursively calling the pre-order function.

· In-order

o Traverse the left subtree by recursively calling the in-order function.

o Access the data part of the current node.

o Traverse the right subtree by recursively calling the in-order function.

· Post-order

o Traverse the left subtree by recursively calling the post-order function.

o Traverse the right subtree by recursively calling the post-order function.

o Access the data part of the current node.
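Each of the three orders above is naturally written with recursion, which keeps its bookkeeping on the call stack. Pre-order can also be done with an explicit stack instead, as in the minimal sketch below. The nested-tuple tree representation (value, left, right) and the name preorder_iterative are assumptions for illustration. They are not part of the implementation later in this writing.

```python
# A minimal sketch: iterative pre-order traversal with an explicit stack.
# Trees are nested tuples (value, left, right) with None for a missing child;
# this representation is an assumption for illustration.

def preorder_iterative(tree):
    """Return node values in pre-order (root, left subtree, right subtree)."""
    result = []
    stack = [tree] if tree is not None else []
    while stack:
        value, left, right = stack.pop()
        result.append(value)          # access the data part of the current node
        # Push the right child first so the left subtree is popped (visited) first.
        if right is not None:
            stack.append(right)
        if left is not None:
            stack.append(left)
    return result


# The 7-node tree used throughout this writing: 1 is the root, 2 and 3 are its
# children, and 4, 5, 6, 7 are the leaves.
tree = (1, (2, (4, None, None), (5, None, None)),
           (3, (6, None, None), (7, None, None)))
print(preorder_iterative(tree))  # [1, 2, 4, 5, 3, 6, 7]
```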

Breadth first traversal:

· Traverse the tree in level-order, visit every node on a level before going to a lower level. This search is referred to as breadth-first search (BFS), as the search tree is broadened as much as possible on each depth before going to the next depth.
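Level-order traversal is commonly implemented with a first-in, first-out queue. Take a node from the front, visit it, and add its children to the back. That way every node on one level is visited before any node on the next. Below is a minimal Python sketch that uses the same assumed nested-tuple tree representation; level_order is a hypothetical name.

```python
from collections import deque

# A minimal sketch of level-order (breadth-first) traversal with a FIFO queue.
# Trees are nested tuples (value, left, right) with None for a missing child;
# this representation is an assumption for illustration.

def level_order(tree):
    """Return a list of levels, each a list of values from left to right."""
    if tree is None:
        return []
    levels = []
    queue = deque([tree])
    while queue:
        level = []
        for _ in range(len(queue)):   # exactly the nodes currently on this level
            value, left, right = queue.popleft()
            level.append(value)
            # Children go to the back of the queue: they belong to the next level.
            for child in (left, right):
                if child is not None:
                    queue.append(child)
        levels.append(level)
    return levels


tree = (1, (2, (4, None, None), (5, None, None)),
           (3, (6, None, None), (7, None, None)))
print(level_order(tree))  # [[1], [2, 3], [4, 5, 6, 7]]
```

This sketch enqueues and visits each node once. The printGivenLevel approach in the implementation below instead walks down from the root again for every level.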

Tree traversal Implementation

Tree traversal can be implemented in any programming language. Python appears to be the easiest language for the implementation because of its dynamic typing. A tree node is normally an instance of a class, and you do not need to declare the data type of the value it holds. To start, the following is a tree traversal implementation in Python:

Python implementation

'''
Developed by George Jen, Jen Tek LLC
Implement the following:
1. Create a binary tree with buildBinTree(arr) and add the array elements to the tree nodes.
if parent node is at index i in the array, 
    then 
        left child of that node is at index (2*i + 1) in the array 
        right child is at index (2*i + 2) in the array.
2. Tree Depth First Traversals:
In Order left to right-->leftRootRight(node)
In Order right to left-->rightRootLeft(node)
Pre Order left to right-->rootLeftRight(node)
Pre Order right to left-->rootRightLeft(node)
Post Order left to right-->leftRightRoot(node)
Post Order right to left-->rightLeftRoot(node)
3. Tree Breadth first traversal:
Breadth first traversal from left to right--> breadthFirstTraversalLeftRight(node)
Breadth first traversal from right to left--> breadthFirstTraversalRightLeft(node)
4. Discover tree depth --> getBinTreeDepth(node,depth)
It fills in the depth value in the one-element array depth
5. Binary tree search
Binary tree search: traverse the tree branches and return the original index in arr if found, -1 if not found.
return_index is a one-element array. It defaults to [-1]; once a match is found, it is set to [matched node.index]
binTreeSearch(node,element,return_index)
6. Plotting the tree-->plotTree(node)
It will basically perform breadth first traversal and print each node along the way
'''

#Implement DFS tree search: traverse the tree from the root to the leaves via the left and right branches

#Create Tree Structure:

class binTree(object):
    def __init__(self,node,index):
        self.node=node
        self.index=index
        self.right=None
        self.left=None

#Build a binary tree from arr. arr does not need to be sorted, and the resulting
#tree is not a binary search tree: elements are placed level by level.

def buildBinTree(arr):
    #If a parent node is at index i in the array,
    #then its left child is at index (2*i + 1) and
    #its right child is at index (2*i + 2) in the array.

    if len(arr)==0:
        return None

    #Create a node for each element of arr; the nodes are not linked yet
    nodes=[binTree(arr[i],i) for i in range(len(arr))]

    #Now link the nodes together to form a binary tree.
    #Check each child index separately, so that an array with an even number
    #of elements still gets its last left child linked.
    for i in range(len(nodes)):
        if 2*i+1<len(nodes):
            nodes[i].left=nodes[2*i+1]
        if 2*i+2<len(nodes):
            nodes[i].right=nodes[2*i+2]

    return nodes[0]

def plotTree(node):
    def printGivenLevel(node, level,line_counter,deep):
        if not node:
            return
        if level==1:
            for j in range(line_counter-2):
                print(" ",end="")
            print(node.node, end=" ")
        elif level>1:
            printGivenLevel(node.left, level-1,line_counter-1,deep)
            printGivenLevel(node.right, level-1,line_counter-1,deep)

            
    depth=[0]
    getBinTreeDepth(node,depth)
    for d in range(1,depth[0]):
        line_counter=depth[0]
        deep=[depth[0]]
        printGivenLevel(node, d, line_counter,deep)
        print(" ")

def rootLeftRight(node):
    if node:
        print(node.node)
        rootLeftRight(node.left)
        rootLeftRight(node.right)

    

def rootRightLeft(node):
    if node:
        print(node.node)
        rootRightLeft(node.right)
        rootRightLeft(node.left)

def leftRightRoot(node):
    if node:
        leftRightRoot(node.left)
        leftRightRoot(node.right)
        print(node.node)
        
def rightLeftRoot(node):
    if node:
        rightLeftRoot(node.right)
        rightLeftRoot(node.left)
        print(node.node)

def leftRootRight(node):
    if node:
        leftRootRight(node.left)
        print(node.node)
        leftRootRight(node.right)

#In order traversal, right to left
#(for a binary search tree, this prints the stored values in descending order)

def rightRootLeft(node):
    if node:
        rightRootLeft(node.right)
        print(node.node)
        rightRootLeft(node.left)

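def getBinTreeDepth(node,depth=None):
    #The longest path of a complete binary tree (such as one built by buildBinTree)
    #is its leftmost path, from the root to the leftmost leaf.
    #The result is stored in the one-element array depth[0] and also returned.
    #Note: the empty child below the leftmost leaf is counted too, so depth[0] ends up
    #as (number of levels + 1); that is why callers loop over range(1, depth[0]).
    if depth is None:
        depth=[0]   #avoid a shared mutable default argument
    if node:
        depth[0]+=1
        getBinTreeDepth(node.left,depth)
    else:
        depth[0]+=1
    return depth[0]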
    
        
        
def breadthFirstTraversalLeftRight(node):
    def printGivenLevel(node, level):
        if not node:
            return
        if level==1:
            print(node.node, end=" ")
        elif level>1:
            printGivenLevel(node.left, level-1)
            printGivenLevel(node.right, level-1)

            
    depth=[0]
    getBinTreeDepth(node,depth)
    for d in range(1,depth[0]):
        printGivenLevel(node, d)
        print(" ")

        
def breadthFirstTraversalRightLeft(node):
    def printGivenLevel(node, level):
        if not node:
            return
        if level==1:
            print(node.node,end=" ")
        elif level>1:
            printGivenLevel(node.right, level-1)
            printGivenLevel(node.left, level-1)
            
    depth=[0]
    getBinTreeDepth(node,depth)
    for d in range(1,depth[0]):
        printGivenLevel(node, d)
        print(" ")
    

#Binary tree search: traverse the tree branches and return the original index in arr if found, -1 if not found.
#return_index is a one-element array. It defaults to [-1]; once a match is found, it is set to [matched node.index]
def binTreeSearch(node,element,return_index):

    if return_index[0]!=-1:
        print(" ")
        return
    if node:
        if node.node==element:
            return_index[0]=node.index
            return
        else:
            binTreeSearch(node.left,element,return_index)
            binTreeSearch(node.right,element,return_index)

#Driver code to run the traversal in different strategies:

if __name__=='__main__':
    #build the tree from arr
    print("Now build the tree from arr, return root node")
    root=buildBinTree([1,2,3,4,5,6,7])
#Traverse the tree in pre-order, in-order and post-order:
    print("Pre order tree traversal, left to right")
    rootLeftRight(root)
    print("pre order tree traversal, right to left")
    rootRightLeft(root)
    print("In order tree traversal, left to right")
    leftRootRight(root)
    print("In Order tree traversal, right to left")
    rightRootLeft(root)
    print("Post order tree traversal, left to right")
    leftRightRoot(root)
    print("Post order tree traversal, right to left")
    rightLeftRoot(root)
    print("Breadth first tree traversal, left to right")
    breadthFirstTraversalLeftRight(root)
    print("Breadth first tree traversal, right to left")
    breadthFirstTraversalRightLeft(root)
    print("Search value 5, if found, return index in original arr")
    search_index=[-1]
    binTreeSearch(root,5,search_index)
    if search_index[0]!=-1:
        print("Found, return index is {}".format(search_index[0]))
    else:
        print("Not found, return {}".format(search_index[0]))
    print("Search value 10, if not found, return -1")
    search_index=[-1]
    binTreeSearch(root,10,search_index)
    if search_index[0]!=-1:
        print("Found, return index is {}".format(search_index[0]))
    else:
        print("Not found, return {}".format(search_index[0]))
    
    print("Now plotting the tree")
    plotTree(root)

Running the above driver code produces the following output:

Now build the tree from arr, return root node
Pre order tree traversal, left to right
1
2
4
5
3
6
7
pre order tree traversal, right to left
1
3
7
6
2
5
4
In order tree traversal, left to right
4
2
5
1
6
3
7
In Order tree traversal, right to left
7
3
6
1
5
2
4
Post order tree traversal, left to right
4
5
2
6
7
3
1
Post order tree traversal, right to left
7
6
3
5
4
2
1
Breadth first tree traversal, left to right
1  
2 3  
4 5 6 7  
Breadth first tree traversal, right to left
1  
3 2  
7 6 5 4  
Search value 5, if found, return index in original arr
 
Found, return index is 4
Search value 10, if not found, return -1
Not found, return -1
Now plotting the tree
  1  
 2  3  
4 5 6 7  

Dynamic typing vs static typing

In a dynamically typed programming language, the interpreter keeps track of each variable's data type at run time, so the programmer does not need to declare it. Python is such a language, partly because in Python everything is an object, an instance of the object class or one of its subclasses.

In a statically typed programming language, the compiler checks data types before the program runs and reports an error if they do not match. The programmer is responsible for using the proper data type when declaring a variable. C++, Java and Scala are such languages.

In my view, it is easier to write a tree traversal program in Python because no data type needs to be declared when defining the tree node class. In another programming language, you need to take care of typing before you can build a tree.
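To make the contrast concrete, here is a minimal Python sketch of a node class written with optional type hints. It loosely mirrors the Scala Tree[+T] case class used later. The class name Node is an assumption for illustration. The Python interpreter does not enforce the hints, so the same class accepts integers, strings, or a mix.

```python
from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

# Hypothetical class for illustration. The type hints are optional: they are
# checked only by external tools such as mypy, never by the interpreter.
T = TypeVar("T")


@dataclass
class Node(Generic[T]):
    value: T
    left: Optional["Node[T]"] = None
    right: Optional["Node[T]"] = None


# The same class holds an int tree and a str tree without any change.
int_tree = Node(1, Node(2), Node(3))
str_tree = Node("a", Node("b"), Node("c"))

# Because hints are not enforced at run time, even a mismatched value is accepted.
mixed = Node(1, Node("two"))
print(int_tree.left.value, str_tree.right.value, mixed.left.value)  # 2 c two
```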

Scala Implementation

The following is my attempt to write the same tree traversal in Scala. Why Scala? Because Scala is the language natively supported by Apache Spark Graphx, and it compiles to Java bytecode, so it is essentially Java code once compiled.

Since I already have the code written in Python, I simply ported it to Scala.

/*
Define the Tree case class first, which defines the data type of a tree. Think of it as a struct in C++
Define method to build a binary tree
*/
case class Tree[+T](attr : T, left :Option[Tree[T]], right : Option[Tree[T]])
/*
The above tree data structure is defined:
Tree(Stored value, left pointer, right pointer) 
*/

def buildBinTree[T](lines: IndexedSeq[IndexedSeq[T]]) = {
  def recurseFunc[T](lines: IndexedSeq[IndexedSeq[T]]): IndexedSeq[Tree[T]] = lines match {
    case line +: IndexedSeq() => line.map(Tree(_, None, None))
    case line +: rest => {
      val prevTrees = recurseFunc(rest)
      (line, prevTrees.grouped(2).toIndexedSeq).zipped
      .map{case (v, IndexedSeq(left, right)) => Tree(v, Some(left), Some(right))}
    }
    case _ => IndexedSeq.empty
  }
  recurseFunc(lines).headOption
}

/*
Get some data to be stored in a tree
*/
val values = """1
2 3
4 5 6 7
""".stripMargin
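//Build the tree, one row of characters per level.
//split("\n") is used instead of .lines, which on JDK 11+ can resolve to Java's String.lines() (a java.util.stream.Stream)
val tree = buildBinTree(values.split("\n").map(_.filterNot(_ == ' ').toIndexedSeq).toIndexedSeq)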
/*
Output:
tree: Option[Tree[Char]] = Some(Tree(1,Some(Tree(2,Some(Tree(4,None,None)),Some(Tree(5,None,None)))),Some(Tree(3,Some(Tree(6,None,None)),Some(Tree(7,None,None))))))

*/
// Pre Order left to right-->rootLeftRight(node)
def rootLeftRight(node: Option[Tree[Char]]): Unit =
  node.foreach { n =>          // nothing to do when node is None
    println(n.attr)            // visit the root
    rootLeftRight(n.left)      // then the left subtree
    rootLeftRight(n.right)     // then the right subtree
  }
//run it
rootLeftRight(tree)
/*
1
2
4
5
3
6
7
*/

//Pre Order right to left-->rootRightLeft(node)

def rootRightLeft(node: Option[Tree[Char]]): Unit =
  node.foreach { n =>          // nothing to do when node is None
    println(n.attr)            // visit the root
    rootRightLeft(n.right)     // then the right subtree
    rootRightLeft(n.left)      // then the left subtree
  }
//run it
rootRightLeft(tree)
/*
Output:
1
3
7
6
2
5
4

*/

//Post Order left to right-->leftRightRoot(node)
def leftRightRoot(node: Option[Tree[Char]]): Unit =
  node.foreach { n =>          // nothing to do when node is None
    leftRightRoot(n.left)      // left subtree first
    leftRightRoot(n.right)     // then the right subtree
    println(n.attr)            // visit the root last
  }
//run it
leftRightRoot(tree)
/*
Output:
4
5
2
6
7
3
1

*/
//Post Order right to left-->rightLeftRoot(node)
def rightLeftRoot(node: Option[Tree[Char]]): Unit =
  node.foreach { n =>          // nothing to do when node is None
    rightLeftRoot(n.right)     // right subtree first
    rightLeftRoot(n.left)      // then the left subtree
    println(n.attr)            // visit the root last
  }
//run it
rightLeftRoot(tree)
/*
Output:
7
6
3
5
4
2
1

*/
//In Order left to right-->leftRootRight(node)
def leftRootRight(node: Option[Tree[Char]]): Unit =
  node.foreach { n =>          // nothing to do when node is None
    leftRootRight(n.left)      // left subtree first
    println(n.attr)            // then visit the root
    leftRootRight(n.right)     // then the right subtree
  }
//run it
leftRootRight(tree)
/*
Output:
4
2
5
1
6
3
7

*/
//In Order right to left-->rightRootLeft(node)
def rightRootLeft(node: Option[Tree[Char]]): Unit =
  node.foreach { n =>          // nothing to do when node is None
    rightRootLeft(n.right)     // right subtree first
    println(n.attr)            // then visit the root
    rightRootLeft(n.left)      // then the left subtree
  }
//run it
rightRootLeft(tree)
/*
Output:
7
3
6
1
5
2
4

*/
//Breadth first traversal from left to right--> breadthFirstTraversalLeftRight(node)

//Helper function to calculate the depth of the tree, needed by BFS.
//In a complete binary tree the leftmost path is the longest one, so the depth
//is the number of nodes on that path; the result is stored in depth(0).
def getBinTreeDepth(node: Option[Tree[Char]], depth: Array[Int] = Array(0)): Unit =
  node.foreach { n =>
    getBinTreeDepth(n.left, depth)
    depth(0) += 1
  }
def breadthFirstTraversalLeftRight(node: Option[Tree[Char]]): Unit = {
  //Print every node on the given level, left subtree before right subtree
  def printGivenLevel(node: Option[Tree[Char]], level: Int): Unit =
    node.foreach { n =>
      if (level == 1) print(s"${n.attr} ")
      else if (level > 1) {
        printGivenLevel(n.left, level - 1)
        printGivenLevel(n.right, level - 1)
      }
    }

  val depth = Array(0)
  getBinTreeDepth(node, depth)
  for (d <- 1 to depth(0)) {
    printGivenLevel(node, d)
    println(" ")
  }
}
//Run it
breadthFirstTraversalLeftRight(tree)
/*
1  
2 3  
4 5 6 7 

*/
//Breadth first traversal from right to left--> breadthFirstTraversalRightLeft(node)
def breadthFirstTraversalRightLeft(node: Option[Tree[Char]]): Unit = {
  //Print every node on the given level, right subtree before left subtree
  def printGivenLevel(node: Option[Tree[Char]], level: Int): Unit =
    node.foreach { n =>
      if (level == 1) print(s"${n.attr} ")
      else if (level > 1) {
        printGivenLevel(n.right, level - 1)
        printGivenLevel(n.left, level - 1)
      }
    }

  val depth = Array(0)
  getBinTreeDepth(node, depth)
  for (d <- 1 to depth(0)) {
    printGivenLevel(node, d)
    println(" ")
  }
}
//run it
breadthFirstTraversalRightLeft(tree)
/*
Output:
1  
3 2  
7 6 5 4

*/

Scalability

The above tree traversal code in Python and Scala does what it is supposed to do for a tree with 7 nodes, including the root and the leaves. Is it scalable to a tree with millions or billions of nodes? Of course not. That would require code that cuts the tree into pieces, lets multiple computers (worker nodes) each process one part of the tree, and combines their work into the final traversal result.

How do you write tree traversal code that is scalable? It is actually easy when the traversal code runs on Apache Spark Graphx.

Graph traversal with Apache Spark Graphx Pregel

What is the difference between a tree and a graph?

A tree is a graph without loops (cycles). More precisely, a tree is a connected graph with exactly one path between any two nodes, while a general graph may contain loops.
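The definition can be checked on a single machine with the minimal Python sketch below. It relies on a standard fact: an undirected graph with n vertices is a tree exactly when it has n - 1 edges and none of them closes a loop. The name is_tree is hypothetical. The sketch treats edges as undirected and ignores the edge direction that Graphx stores. The edge list matches the Graphx graph built next, without the edge attributes.

```python
# A minimal sketch, assuming an undirected graph with vertices 1..n given as an
# edge list. Union-find detects whether an edge joins two already-connected
# vertices, which would close a loop.

def is_tree(n, edges):
    parent = list(range(n + 1))            # union-find parents; index 0 unused

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    if len(edges) != n - 1:                # a tree on n vertices has n - 1 edges
        return False
    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru == rv:                       # already connected: this edge closes a loop
            return False
        parent[ru] = rv
    return True


tree_edges = [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]
print(is_tree(7, tree_edges))             # True
print(is_tree(7, tree_edges + [(7, 1)]))  # False: 7 edges for 7 vertices (loop 1 -> 3 -> 7 -> 1)
```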

To construct the same tree in Apache Spark Graphx:

import org.apache.spark._
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
val vertices=sc.parallelize(Seq((1L,1.0),(2L,2.0),(3L,3.0),(4L,4.0),(5L,5.0),(6L,6.0),(7L,7.0)))
val edges=sc.parallelize(Seq(Edge(1L,2L,1),Edge(1L,3L,2), Edge(2L,4L,3), Edge(2L,5L,4), Edge(3L,6L, 5), Edge(3L,7L, 6)))
val graph=Graph(vertices,edges)

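//This graph is a tree, because there is no loop among any of the nodes.
//One quick check is triangle count: a tree has no triangles, so every vertex should report 0.
//(A zero triangle count alone does not rule out longer loops, such as a loop through 4 nodes.)
//The following code shows the triangle count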

graph.triangleCount.vertices.collect.foreach(println)

/*
Output tuple pair, 1st element is node id, 2nd element is triangle count
All zeros
(4,0)
(6,0)
(2,0)
(1,0)
(3,0)
(7,0)
(5,0)

*/
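The triangle check can be illustrated outside Spark. The minimal Python sketch below counts, for each vertex, the triangles it belongs to, treating edges as undirected. The name triangle_counts is hypothetical. The code is meant to mirror the per-vertex counts in the Graphx output, not to reproduce Graphx's implementation. The last call shows the limitation of the check: a loop through 4 nodes contains no triangle.

```python
from collections import defaultdict
from itertools import combinations

# A minimal sketch of per-vertex triangle counting on an undirected graph.

def triangle_counts(edges):
    neighbors = defaultdict(set)
    for u, v in edges:
        if u != v:                         # ignore self-loops
            neighbors[u].add(v)
            neighbors[v].add(u)
    counts = {}
    for vertex, nbrs in neighbors.items():
        # Every pair of neighbors that are themselves connected closes a triangle.
        counts[vertex] = sum(1 for a, b in combinations(sorted(nbrs), 2)
                             if b in neighbors[a])
    return counts


tree_edges = [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]
print(sorted(triangle_counts(tree_edges).items()))
# [(1, 0), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0), (7, 0)]
print(sorted(triangle_counts(tree_edges + [(7, 1)]).items()))
# [(1, 1), (2, 0), (3, 1), (4, 0), (5, 0), (6, 0), (7, 1)]
print(triangle_counts([(1, 2), (2, 3), (3, 4), (4, 1)]))
# {1: 0, 2: 0, 3: 0, 4: 0}  (a 4-node loop contains no triangle)
```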
//The following code traverses the tree graph.
//The traversal starts from the root of the tree, which is vertex id 1
val start: VertexId = 1
//Set the start vertex attribute to 0.0 and all others to Double.PositiveInfinity (not yet visited)
val markedGraph = graph.mapVertices((id, _) => if (id == start) 0.0 else
  Double.PositiveInfinity)

val vprog = { (id: VertexId, attr: Double, msg: Double) => math.min(attr,msg) }

val sendMessage = { (triplet: EdgeTriplet[Double, Int]) =>
  var run:Iterator[(VertexId, Double)] = Iterator.empty
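//A vertex whose attribute is still Double.PositiveInfinity has not been visited yet.
//A message is sent only across an edge with at least one unvisited end, so visited
//vertices are never revisited; this is what prevents an endless loop.
//The message goes from the visited end to the other end, so edge direction is ignored
//(if neither end is visited, the message is PositiveInfinity and has no effect).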
  if(!(triplet.srcAttr != Double.PositiveInfinity && triplet.dstAttr != Double.PositiveInfinity)){
    if(triplet.srcAttr != Double.PositiveInfinity){
      run = Iterator((triplet.dstId,triplet.srcAttr+1))
    }else{
      run = Iterator((triplet.srcId,triplet.dstAttr+1))
    }
  }
  run
}

val mergeMessage = { (a: Double, b: Double) => math.min(a,b) }

//Run Pregel with initial message Double.PositiveInfinity and at most 20 iterations
val graphTraverse = markedGraph.pregel(Double.PositiveInfinity, 20)(vprog, sendMessage, mergeMessage)

println(graphTraverse.vertices.collect.map(x=>x._1).mkString("\n"))
/*
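Output node ids. Every vertex reachable from the root has been visited, and each
vertex attribute now holds its distance in hops from the root. The ids are listed
in the order collect returns them, not in visiting order.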
4
6
2
1
3
7
5
*/
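To see what the Pregel call computes, here is a minimal single-machine Python simulation of the same vprog, sendMessage and mergeMessage logic. It is a sketch under simplifying assumptions. pregel_levels is a hypothetical name. Every superstep runs sendMessage over all edges rather than only the edges next to vertices that changed; that gives the same result here but is less efficient than Graphx. The values it returns are hop distances from the root, i.e. the breadth-first level of each vertex.

```python
import math

INF = math.inf


def pregel_levels(vertices, edges, start, max_iterations=20):
    """Simulate the Pregel traversal above on one machine (hypothetical helper)."""
    # mapVertices: the start vertex gets 0.0, every other vertex is unvisited (infinity).
    # The initial message is also infinity, so the first vprog call would change
    # nothing; it is skipped here.
    attr = {v: (0.0 if v == start else INF) for v in vertices}
    for _ in range(max_iterations):
        inbox = {}
        for src, dst in edges:                 # sendMessage, applied to every edge
            s, d = attr[src], attr[dst]
            if s != INF and d != INF:          # both ends visited: no message
                continue
            target, value = (dst, s + 1) if s != INF else (src, d + 1)
            inbox[target] = min(inbox.get(target, INF), value)  # mergeMessage = min
        if not inbox:                          # no messages left: Pregel stops
            break
        for v, msg in inbox.items():           # vprog = min(attr, msg)
            attr[v] = min(attr[v], msg)
    return attr


vertices = range(1, 8)
tree_edges = [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]
print(pregel_levels(vertices, tree_edges, start=1))
# {1: 0.0, 2: 1.0, 3: 1.0, 4: 2.0, 5: 2.0, 6: 2.0, 7: 2.0}
print(pregel_levels(vertices, tree_edges + [(7, 1)], start=1))
# {1: 0.0, 2: 1.0, 3: 1.0, 4: 2.0, 5: 2.0, 6: 2.0, 7: 1.0}
```

With the extra edge (7, 1), vertex 7 lands at level 1 in this simulation instead of level 2. That is because sendMessage also sends from the destination back to the source, so the traversal treats edges as undirected and 7 is one hop from 1. Since the sketch follows the same sendMessage logic, I would expect the Graphx attributes to show the same values.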

The above graph is a tree because it has no loop. The earlier standalone Python and Scala programs can only traverse a tree, i.e., a graph that does not have a loop. A general graph can have loops. For example, in social media, “Jack is a friend of John, John is a friend of Robert, Robert is a friend of Jack” forms a loop. On a graph with a loop, the earlier traversal strategy would recurse endlessly and eventually fail with a stack overflow (in Python, a RecursionError once the recursion limit is exceeded).

The code that runs on Apache Spark Graphx handles graphs with loops and will not loop indefinitely, because it uses a marking strategy: a vertex that has already been visited (its attribute is no longer Double.PositiveInfinity) is skipped.
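The same marking idea also works outside Spark. Below is a minimal Python sketch of a depth-first traversal over an adjacency list that keeps a set of visited vertices. The name dfs_order and the adjacency-list representation are assumptions for illustration. The adjacency list holds the directed edges of the looped graph used below, including Edge(7L,1L,8).

```python
# A minimal single-machine sketch of the marking strategy: keep a set of visited
# vertices and skip any vertex already in it. Without the set, the loop
# 1 -> 3 -> 7 -> 1 would make the recursion run until Python raises RecursionError.

def dfs_order(adjacency, start):
    visited = set()
    order = []

    def visit(v):
        if v in visited:          # already marked: skip it, which breaks the loop
            return
        visited.add(v)
        order.append(v)
        for w in adjacency.get(v, []):
            visit(w)

    visit(start)
    return order


# Directed edges of the graph with a loop, including the edge 7 -> 1.
adjacency = {1: [2, 3], 2: [4, 5], 3: [6, 7], 7: [1]}
print(dfs_order(adjacency, 1))  # [1, 2, 4, 5, 3, 6, 7]
```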

Let’s modify the tree into a graph with a loop by adding the edge Edge(7L,1L,8), which closes the loop 1 -> 3 -> 7 -> 1:

val vertices=sc.parallelize(Seq((1L,1.0),(2L,2.0),(3L,3.0),(4L,4.0),(5L,5.0),(6L,6.0),(7L,7.0)))
val edges=sc.parallelize(Seq(Edge(1L,2L,1),Edge(1L,3L,2), Edge(2L,4L,3), Edge(2L,5L,4), Edge(3L,6L, 5), Edge(3L,7L, 6), Edge(7L,1L,8)))
val graph=Graph(vertices,edges)

//This is no longer a tree: vertices 1, 3 and 7 now form a triangle, so their triangle count is > 0

graph.triangleCount.vertices.collect.foreach(println)

/*
Output tuple pair, 1st element is node id, 2nd element is triangle count
Some are > 0
(4,0)
(6,0)
(2,0)
(1,1)
(3,1)
(7,1)
(5,0)

*/

//run traversal on this Graph with loop
val start: VertexId = 1
//Set the start vertex attribute to 0.0 and all others to Double.PositiveInfinity (not yet visited)
val markedGraph = graph.mapVertices((id, _) => if (id == start) 0.0 else
  Double.PositiveInfinity)

val vprog = { (id: VertexId, attr: Double, msg: Double) => math.min(attr,msg) }

val sendMessage = { (triplet: EdgeTriplet[Double, Int]) =>
  var run:Iterator[(VertexId, Double)] = Iterator.empty
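//A vertex whose attribute is still Double.PositiveInfinity has not been visited yet.
//A message is sent only across an edge with at least one unvisited end, so visited
//vertices are never revisited; this is what prevents an endless loop.
//The message goes from the visited end to the other end, so edge direction is ignored
//(if neither end is visited, the message is PositiveInfinity and has no effect).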
  if(!(triplet.srcAttr != Double.PositiveInfinity && triplet.dstAttr != Double.PositiveInfinity)){
    if(triplet.srcAttr != Double.PositiveInfinity){
      run = Iterator((triplet.dstId,triplet.srcAttr+1))
    }else{
      run = Iterator((triplet.srcId,triplet.dstAttr+1))
    }
  }
  run
}

val mergeMessage = { (a: Double, b: Double) => math.min(a,b) }

val graphTraverse = markedGraph.pregel(Double.PositiveInfinity, 20)(vprog, sendMessage, mergeMessage)

println(graphTraverse.vertices.collect.map(x=>x._1).mkString("\n"))

/*
Output: the same node ids; the traversal terminates even though the graph now has a loop
4
6
2
1
3
7
5


*/

Summary:

Tree and graph traversal on Spark Graphx is scalable because the traversal runs on a Spark cluster with multiple worker nodes. The traversal code uses the Graphx Pregel API, which is modeled on Google’s Pregel, a framework designed for large-scale graph computation. For more on Pregel, see my prior writing:

https://www.linkedin.com/posts/dr-george-jen-257a7626_bulk-synchronous-parallel-google-pregel-spark-activity-6654510643561013248-G3Zi

As always, the code used in this writing is in my GitHub repo:

https://github.com/geyungjen/jentekllc

Thank you for your time viewing this writing.
