Tree and Graph Traversal with and without Spark GraphX

# Tree data structure

Trees and linked lists are basic data structures taught in computer science classes. Tree traversal (also known as walking the tree) is a form of graph traversal. It refers to the process of visiting (checking and/or updating) each node in a tree data structure exactly once. Traversals are classified by the order in which the nodes are visited.
There are two different strategies to traverse a tree:

- Depth-first traversal: in simple terms, walk the tree vertically, following one branch down to a leaf before backtracking.
- Breadth-first traversal: in simple terms, walk the tree horizontally, one level at a time.

Depth-first traversal comes in the following variants:

- Pre-order
  - Access the data part of the current node.
  - Traverse the left subtree by recursively calling the pre-order function.
  - Traverse the right subtree by recursively calling the pre-order function.
- In-order
  - Traverse the left subtree by recursively calling the in-order function.
  - Access the data part of the current node.
  - Traverse the right subtree by recursively calling the in-order function.
- Post-order
  - Traverse the left subtree by recursively calling the post-order function.
  - Traverse the right subtree by recursively calling the post-order function.
  - Access the data part of the current node.

Breadth-first traversal walks the tree in level order. It visits every node on a level before going to the next, lower level. This is called breadth-first search (BFS), because the search is broadened as much as possible at each depth before going to the next depth.
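The four orders can be compared side by side in a few lines of Python. This is a minimal sketch. It assumes a tree encoded as nested `(value, left, right)` tuples rather than the `binTree` class used later, and the helper names `node`, `pre_order`, `in_order`, `post_order` and `level_order` are illustrative only. The depth-first orders differ only in where the "visit" happens relative to the two recursive calls. Breadth-first traversal uses a FIFO queue (`collections.deque`), so that nodes come out level by level.

```python
from collections import deque

# Illustrative encoding (an assumption, not the article's binTree class):
# a node is a tuple (value, left, right) and None marks a missing child.
def node(value, left=None, right=None):
    return (value, left, right)

def pre_order(t):
    if t is None:
        return []
    value, left, right = t
    return [value] + pre_order(left) + pre_order(right)

def in_order(t):
    if t is None:
        return []
    value, left, right = t
    return in_order(left) + [value] + in_order(right)

def post_order(t):
    if t is None:
        return []
    value, left, right = t
    return post_order(left) + post_order(right) + [value]

def level_order(t):
    """Breadth-first: a FIFO queue hands out nodes level by level."""
    result = []
    queue = deque([t] if t is not None else [])
    while queue:
        value, left, right = queue.popleft()
        result.append(value)
        for child in (left, right):
            if child is not None:
                queue.append(child)
    return result

tree = node(1, node(2, node(4), node(5)), node(3, node(6), node(7)))
print(pre_order(tree))    # [1, 2, 4, 5, 3, 6, 7]
print(in_order(tree))     # [4, 2, 5, 1, 6, 3, 7]
print(post_order(tree))   # [4, 5, 2, 6, 7, 3, 1]
print(level_order(tree))  # [1, 2, 3, 4, 5, 6, 7]
```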

# Tree traversal Implementation

Tree traversal can be implemented in any programming language. Python is arguably the easiest to start with because it is dynamically typed. A tree node is normally an instance of a class, and in Python you do not have to declare the data type of the value the node holds or of its child links. To start, the following is a tree traversal implementation in Python.

## Python implementation

```python
'''
Developed by George Jen, Jen Tek LLC
Implement the following:
1. Create a binary tree with buildBinTree(arr) and add the array elements to the tree nodes.
   If a parent node is at index i in the array, then
   its left child is at index (2*i + 1) in the array and
   its right child is at index (2*i + 2) in the array.
2. Depth-first tree traversals:
   In-order, left to right   --> leftRootRight(node)
   In-order, right to left   --> rightRootLeft(node)
   Pre-order, left to right  --> rootLeftRight(node)
   Pre-order, right to left  --> rootRightLeft(node)
   Post-order, left to right --> leftRightRoot(node)
   Post-order, right to left --> rightLeftRoot(node)
3. Breadth-first tree traversals (function names assumed; they were lost from the original listing):
   Left to right --> breadthFirstLeftRight(node)
   Right to left --> breadthFirstRightLeft(node)
4. Discover the tree depth --> getBinTreeDepth(node, depth)
   The depth value is written into the one-element list depth.
5. Binary tree search --> binTreeSearch(node, element, return_index)
   Traverse the tree branches and report the original index in arr if found, -1 if not found.
   return_index is a one-element list that defaults to [-1]; once found, it is set to [matched node.index].
6. Plot the tree --> plotTree(node)
   Performs a breadth-first traversal and prints each level of the tree on its own line.
'''

# Tree node: the stored value, its index in the original array, and left/right child links.
class binTree(object):
    def __init__(self, node, index):
        self.node = node
        self.index = index
        self.right = None
        self.left = None

# Build a binary tree from arr. arr does not need to be sorted: the tree simply mirrors
# the array layout, so it is a complete binary tree, not a binary search tree.
def buildBinTree(arr):
    # If a parent node is at index i in the array,
    # then the left child of that node is at index (2*i + 1) and
    # the right child is at index (2*i + 2) in the array.
    if len(arr) == 0:
        return None

    # Create a node for each element of arr; they are not linked into a tree yet.
    nodes = [None] * len(arr)
    for i in range(len(arr)):
        nodes[i] = binTree(arr[i], i)

    # Now link the nodes together to form a binary tree.
    # Each child is checked separately, so an even-length arr keeps its last left child.
    for i in range(len(nodes)):
        if 2*i + 1 < len(nodes):
            nodes[i].left = nodes[2*i + 1]
        if 2*i + 2 < len(nodes):
            nodes[i].right = nodes[2*i + 2]

    # Return the root node.
    return nodes[0]

# Plot the tree: breadth-first, one line per level, shallower levels indented more.
def plotTree(node):
    def printGivenLevel(node, level, indent):
        if not node:
            return
        if level == 1:
            print(" " * indent, end="")
            print(node.node, end=" ")
        elif level > 1:
            printGivenLevel(node.left, level - 1, indent)
            printGivenLevel(node.right, level - 1, indent)

    depth = [0]
    getBinTreeDepth(node, depth)
    for d in range(1, depth[0] + 1):
        printGivenLevel(node, d, depth[0] - d)
        print(" ")

# Pre-order, left to right
def rootLeftRight(node):
    if node:
        print(node.node)
        rootLeftRight(node.left)
        rootLeftRight(node.right)

# Pre-order, right to left
def rootRightLeft(node):
    if node:
        print(node.node)
        rootRightLeft(node.right)
        rootRightLeft(node.left)

# Post-order, left to right
def leftRightRoot(node):
    if node:
        leftRightRoot(node.left)
        leftRightRoot(node.right)
        print(node.node)

# Post-order, right to left
def rightLeftRoot(node):
    if node:
        rightLeftRoot(node.right)
        rightLeftRoot(node.left)
        print(node.node)

# In-order, left to right
def leftRootRight(node):
    if node:
        leftRootRight(node.left)
        print(node.node)
        leftRootRight(node.right)

# In-order, right to left (for a binary search tree this prints the values in descending order)
def rightRootLeft(node):
    if node:
        rightRootLeft(node.right)
        print(node.node)
        rightRootLeft(node.left)

# Depth of the tree, written into the one-element list depth.
# For a complete binary tree (as built by buildBinTree) the leftmost root-to-leaf
# path is always the longest, so following the left links is enough.
def getBinTreeDepth(node, depth):
    if node:
        depth[0] += 1
        getBinTreeDepth(node.left, depth)

# Breadth-first traversal, left to right (function name assumed)
def breadthFirstLeftRight(node):
    # Print every node on the given level, left child before right child.
    def printGivenLevel(node, level):
        if not node:
            return
        if level == 1:
            print(node.node, end=" ")
        elif level > 1:
            printGivenLevel(node.left, level - 1)
            printGivenLevel(node.right, level - 1)

    depth = [0]
    getBinTreeDepth(node, depth)
    for d in range(1, depth[0] + 1):
        printGivenLevel(node, d)
        print(" ")

# Breadth-first traversal, right to left (function name assumed)
def breadthFirstRightLeft(node):
    # Print every node on the given level, right child before left child.
    def printGivenLevel(node, level):
        if not node:
            return
        if level == 1:
            print(node.node, end=" ")
        elif level > 1:
            printGivenLevel(node.right, level - 1)
            printGivenLevel(node.left, level - 1)

    depth = [0]
    getBinTreeDepth(node, depth)
    for d in range(1, depth[0] + 1):
        printGivenLevel(node, d)
        print(" ")

# Binary tree search: traverse the tree branches and record the original index in arr if found.
# return_index is a one-element list. It defaults to [-1]; once found, it is set to [matched node.index].
def binTreeSearch(node, element, return_index):
    # Stop early once another branch has already found the element.
    if return_index[0] != -1:
        return
    if node:
        if node.node == element:
            return_index[0] = node.index
            return
        binTreeSearch(node.left, element, return_index)
        binTreeSearch(node.right, element, return_index)

# Driver code to run the traversal in different strategies:
if __name__ == '__main__':
    # Build the tree from arr
    print("Now build the tree from arr, return root node")
    root = buildBinTree([1, 2, 3, 4, 5, 6, 7])
    # Traverse the tree in pre-order, in-order and post-order:
    print("Pre order tree traversal, left to right")
    rootLeftRight(root)
    print("pre order tree traversal, right to left")
    rootRightLeft(root)
    print("In order tree traversal, left to right")
    leftRootRight(root)
    print("In Order tree traversal, right to left")
    rightRootLeft(root)
    print("Post order tree traversal, left to right")
    leftRightRoot(root)
    print("Post order tree traversal, right to left")
    rightLeftRoot(root)
    print("Breadth first tree traversal, left to right")
    breadthFirstLeftRight(root)
    print("Breadth first tree traversal, right to left")
    breadthFirstRightLeft(root)
    print("Search value 5, if found, return index in original arr")
    search_index = [-1]
    binTreeSearch(root, 5, search_index)
    if search_index[0] != -1:
        print("Found, return index is {}".format(search_index[0]))
    else:
        print("Not found")

    print("Search value 10, if found, return index in original arr")
    search_index = [-1]
    binTreeSearch(root, 10, search_index)
    if search_index[0] != -1:
        print("Found, return index is {}".format(search_index[0]))
    else:
        print("Not found")

    print("Now plotting the tree")
    plotTree(root)
```

Running the driver code above produces the following output:

```
Now build the tree from arr, return root node
Pre order tree traversal, left to right
1
2
4
5
3
6
7
pre order tree traversal, right to left
1
3
7
6
2
5
4
In order tree traversal, left to right
4
2
5
1
6
3
7
In Order tree traversal, right to left
7
3
6
1
5
2
4
Post order tree traversal, left to right
4
5
2
6
7
3
1
Post order tree traversal, right to left
7
6
3
5
4
2
1
Breadth first tree traversal, left to right
1
2 3
4 5 6 7
Breadth first tree traversal, right to left
1
3 2
7 6 5 4
Search value 5, if found, return index in original arr
Found, return index is 4
Search value 10, if found, return index in original arr
Not found
Now plotting the tree
  1
 2  3
4 5 6 7
```
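The recursive functions above use one Python stack frame per tree level. That is harmless for this balanced 7-node tree, but a deep, unbalanced tree can exceed CPython's default recursion limit of 1000 frames (see `sys.getrecursionlimit()`). As a minimal sketch, a pre-order traversal can keep its own explicit stack instead. `rootLeftRightIterative` is a hypothetical name, and the `binTree` class is repeated so the block runs on its own.

```python
class binTree(object):
    # Same layout as the binTree class above, repeated so this sketch is self-contained.
    def __init__(self, node, index):
        self.node = node
        self.index = index
        self.left = None
        self.right = None

def rootLeftRightIterative(root):
    """Pre-order, left to right, without recursion (hypothetical helper name)."""
    result = []
    stack = [root] if root else []
    while stack:
        current = stack.pop()
        result.append(current.node)
        # Push the right child first so the left subtree is popped and visited first.
        if current.right:
            stack.append(current.right)
        if current.left:
            stack.append(current.left)
    return result

# A degenerate "tree" that is really a 5000-node chain of left children:
# the recursive rootLeftRight would exceed the default recursion limit here.
root = binTree(0, 0)
tail = root
for i in range(1, 5000):
    tail.left = binTree(i, i)
    tail = tail.left

values = rootLeftRightIterative(root)
print(len(values), values[:3])  # 5000 [0, 1, 2]
```

The explicit stack lives on the heap, so its size is bounded by memory rather than by the interpreter's call-stack limit.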

## Dynamic typing vs static typing

In a dynamically typed language, types belong to values and are checked at runtime, so the programmer does not declare the data type of a variable. Python is such a language: every value is an object that carries its own type, and a variable is just a name bound to an object.
In a statically typed language, types are checked at compile time, and the compiler reports an error when they do not match. The programmer is responsible for declaring suitable types, although languages such as Scala can infer many of them. C++, Java and Scala are statically typed languages.
In my view, it is easier to write a tree traversal program in Python because the tree node class needs no type declarations. In a statically typed language, you have to define the node's type, including the type of the stored value and of the child links, before you can build a tree.
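To make the contrast concrete, Python can describe the node type with optional type hints, roughly mirroring a Scala case class such as `Tree[T](attr, left, right)`. This is a minimal sketch and `TypedTree` is a hypothetical name. The hints are only checked by an external tool such as mypy; the Python interpreter ignores them at runtime, which is exactly the dynamic-typing behavior described above.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

T = TypeVar("T")

# A typed tree node: the stored value plus optional left/right children.
@dataclass
class TypedTree(Generic[T]):
    attr: T
    left: Optional[TypedTree[T]] = None
    right: Optional[TypedTree[T]] = None

def pre_order(node: Optional[TypedTree[T]]) -> list[T]:
    if node is None:
        return []
    return [node.attr] + pre_order(node.left) + pre_order(node.right)

tree: TypedTree[int] = TypedTree(1, TypedTree(2), TypedTree(3))
print(pre_order(tree))  # [1, 2, 3]

# Python does not enforce the hint at runtime: this line runs without error,
# whereas a static checker (or the Scala compiler, for Scala code) would reject it.
mixed: TypedTree[int] = TypedTree("one")
```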

## Scala Implementation

Following is my attempt to write the same tree traversal in Scala. Why Scala? Because GraphX provides its API in Scala (GraphX has no Python API), and Scala, like Java, compiles to JVM bytecode.
Since I already had the code written in Python, I simply ported it to Scala.
```scala
/*
Define the Tree case class first. It defines the data type of a tree node,
much like a struct in C++:
  Tree(stored value, left child, right child)
Then define a method to build a binary tree.
*/
case class Tree[+T](attr: T, left: Option[Tree[T]], right: Option[Tree[T]])

// Build a complete binary tree from rows of values, one row per tree level.
// Each row must hold twice as many values as the row above it.
def buildBinTree[T](lines: IndexedSeq[IndexedSeq[T]]): Option[Tree[T]] = {
  def recurseFunc(lines: IndexedSeq[IndexedSeq[T]]): IndexedSeq[Tree[T]] = lines match {
    // Bottom row: every value becomes a leaf
    case line +: IndexedSeq() => line.map(Tree(_, None, None))
    // Any other row: build the rows below first, then pair their trees up as left/right children
    case line +: rest =>
      val prevTrees = recurseFunc(rest)
      line.zip(prevTrees.grouped(2).toIndexedSeq).map {
        case (v, IndexedSeq(left, right)) => Tree(v, Some(left), Some(right))
      }
    case _ => IndexedSeq.empty
  }
  // The top row holds the single root
  recurseFunc(lines).headOption
}

// Some data to store in the tree, one tree level per line
val values = """1
2 3
4 5 6 7"""

// Build the tree. split("\n") is used instead of .lines, which clashes with
// java.lang.String.lines() on JDK 11 and later.
val tree = buildBinTree(values.split("\n").map(_.filterNot(_ == ' ').toIndexedSeq).toIndexedSeq)
/*
Output:
tree: Option[Tree[Char]] = Some(Tree(1,Some(Tree(2,Some(Tree(4,None,None)),Some(Tree(5,None,None)))),Some(Tree(3,Some(Tree(6,None,None)),Some(Tree(7,None,None))))))
*/

// Pre-order, left to right --> rootLeftRight(node)
def rootLeftRight(node: Option[Tree[Char]]): Unit =
  node.foreach { n =>
    println(n.attr)         // visit the root
    rootLeftRight(n.left)   // then the left subtree
    rootLeftRight(n.right)  // then the right subtree
  }
// Run it
rootLeftRight(tree)
/*
Output:
1
2
4
5
3
6
7
*/

// Pre-order, right to left --> rootRightLeft(node)
def rootRightLeft(node: Option[Tree[Char]]): Unit =
  node.foreach { n =>
    println(n.attr)
    rootRightLeft(n.right)
    rootRightLeft(n.left)
  }
// Run it
rootRightLeft(tree)
/*
Output:
1
3
7
6
2
5
4
*/

// Post-order, left to right --> leftRightRoot(node)
def leftRightRoot(node: Option[Tree[Char]]): Unit =
  node.foreach { n =>
    leftRightRoot(n.left)
    leftRightRoot(n.right)
    println(n.attr)
  }
// Run it
leftRightRoot(tree)
/*
Output:
4
5
2
6
7
3
1
*/

// Post-order, right to left --> rightLeftRoot(node)
def rightLeftRoot(node: Option[Tree[Char]]): Unit =
  node.foreach { n =>
    rightLeftRoot(n.right)
    rightLeftRoot(n.left)
    println(n.attr)
  }
// Run it
rightLeftRoot(tree)
/*
Output:
7
6
3
5
4
2
1
*/

// In-order, left to right --> leftRootRight(node)
def leftRootRight(node: Option[Tree[Char]]): Unit =
  node.foreach { n =>
    leftRootRight(n.left)
    println(n.attr)
    leftRootRight(n.right)
  }
// Run it
leftRootRight(tree)
/*
Output:
4
2
5
1
6
3
7
*/

// In-order, right to left --> rightRootLeft(node)
def rightRootLeft(node: Option[Tree[Char]]): Unit =
  node.foreach { n =>
    rightRootLeft(n.right)
    println(n.attr)
    rightRootLeft(n.left)
  }
// Run it
rightRootLeft(tree)
/*
Output:
7
3
6
1
5
2
4
*/

// Helper function to calculate the depth of the tree, needed by BFS.
// For a complete binary tree the leftmost path is the longest, so only left links are followed.
def getBinTreeDepth(node: Option[Tree[Char]], depth: Array[Int] = Array(0)): Unit =
  node.foreach { n =>
    getBinTreeDepth(n.left, depth)
    depth(0) = depth(0) + 1
  }

// Breadth-first traversal, left to right
// (function name assumed; it was lost from the original listing)
def breadthFirstLeftRight(node: Option[Tree[Char]]): Unit = {
  // Print every node on the given level, left child before right child
  def printGivenLevel(node: Option[Tree[Char]], level: Int): Unit =
    node.foreach { n =>
      if (level == 1) print(s"${n.attr} ")
      else if (level > 1) {
        printGivenLevel(n.left, level - 1)
        printGivenLevel(n.right, level - 1)
      }
    }

  val depth = Array(0)
  getBinTreeDepth(node, depth)
  for (d <- 1 to depth(0)) {
    printGivenLevel(node, d)
    println(" ")
  }
}
// Run it
breadthFirstLeftRight(tree)
/*
Output:
1
2 3
4 5 6 7
*/

// Breadth-first traversal, right to left
// (function name assumed; it was lost from the original listing)
def breadthFirstRightLeft(node: Option[Tree[Char]]): Unit = {
  // Print every node on the given level, right child before left child
  def printGivenLevel(node: Option[Tree[Char]], level: Int): Unit =
    node.foreach { n =>
      if (level == 1) print(s"${n.attr} ")
      else if (level > 1) {
        printGivenLevel(n.right, level - 1)
        printGivenLevel(n.left, level - 1)
      }
    }

  val depth = Array(0)
  getBinTreeDepth(node, depth)
  for (d <- 1 to depth(0)) {
    printGivenLevel(node, d)
    println(" ")
  }
}
// Run it
breadthFirstRightLeft(tree)
/*
Output:
1
3 2
7 6 5 4
*/
```

## Scalability

The Python and Scala tree traversal code above does what it is supposed to do for a tree with 7 nodes, including the root and the leaves. Is it scalable to a tree with millions or billions of nodes? Of course not. It runs on a single machine and holds the whole tree in one process's memory, and on deep, unbalanced trees the recursive functions are also limited by the call stack. It would only scale if it were rewritten to cut the tree into pieces, let multiple computers (worker nodes) each process one part of the tree, and then combine the partial results into the final traversal result.
How do you write tree traversal code that scales? It turns out to be fairly easy when the traversal runs on Apache Spark GraphX, which distributes the graph across the worker nodes of the cluster for you.
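One way to picture "cutting the tree into pieces" is to split the tree just below the root and hand each subtree to a separate worker. The sketch below is illustrative only. It assumes the array layout used by `buildBinTree` (children of index i at 2*i+1 and 2*i+2) and uses a thread pool as a stand-in for worker nodes. The names `pre_order_from` and `distributed_pre_order` are hypothetical. A real cluster framework would also have to ship each piece of data to its worker.

```python
from concurrent.futures import ThreadPoolExecutor

def pre_order_from(arr, i):
    """Pre-order of the subtree rooted at array index i, using an explicit stack."""
    out, stack = [], [i]
    while stack:
        j = stack.pop()
        if j < len(arr):
            out.append(arr[j])
            stack.append(2 * j + 2)  # pushed first so the left child is popped first
            stack.append(2 * j + 1)
    return out

def distributed_pre_order(arr, workers=2):
    if not arr:
        return []
    # Cut the tree below the root: the left subtree (index 1) and the right
    # subtree (index 2) are independent pieces that can be processed in parallel.
    pieces = [1, 2]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        partial = list(pool.map(lambda piece_root: pre_order_from(arr, piece_root), pieces))
    # Combine: pre-order is the root, then the left piece, then the right piece.
    return [arr[0]] + partial[0] + partial[1]

print(distributed_pre_order([1, 2, 3, 4, 5, 6, 7]))  # [1, 2, 4, 5, 3, 6, 7]
```

The combine step is cheap here because pre-order has a fixed shape (root, left, right). Other traversal orders need a different, but equally simple, way of stitching the pieces together.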

## Graph traversal with Apache Spark GraphX Pregel

What is the difference between a tree and a graph?
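A tree is a graph with one restriction: it is connected and has no loops (cycles), while a general graph may contain loops. Equivalently, a connected graph with n vertices is a tree exactly when it has n - 1 edges.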
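The definition can be checked directly on an edge list. This is a minimal sketch with a hypothetical `is_tree` helper. It treats edges as undirected and uses a union-find (disjoint set) structure: an edge whose two ends are already connected closes a loop, and a loop-free graph with n vertices is connected exactly when it has n - 1 edges. The edges are the parent/child links of the 7-node tree used throughout this article.

```python
def is_tree(vertices, edges):
    """True if the undirected edge list is connected and has no loop."""
    parent = {v: v for v in vertices}

    def find(v):
        # Follow parent links to the representative of v's component.
        while parent[v] != v:
            parent[v] = parent[parent[v]]  # path halving keeps the chains short
            v = parent[v]
        return v

    for src, dst in edges:
        root_src, root_dst = find(src), find(dst)
        if root_src == root_dst:
            return False  # both ends already connected: this edge closes a loop
        parent[root_src] = root_dst  # merge the two components
    # No loop found; a loop-free graph is one tree exactly when it has n - 1 edges.
    return len(edges) == len(vertices) - 1

vertices = [1, 2, 3, 4, 5, 6, 7]
tree_edges = [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]
print(is_tree(vertices, tree_edges))             # True
print(is_tree(vertices, tree_edges + [(7, 1)]))  # False: 1-3-7 forms a loop
```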
To construct the same tree in Apache Spark GraphX:
```scala
import org.apache.spark._
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

// sc is the SparkContext provided by spark-shell or a notebook.
// Vertices are (vertex id, vertex attribute); edges are Edge(source id, destination id, edge attribute).
val vertices: RDD[(VertexId, Double)] = sc.parallelize(Seq((1L,1.0),(2L,2.0),(3L,3.0),(4L,4.0),(5L,5.0),(6L,6.0),(7L,7.0)))
val edges: RDD[Edge[Int]] = sc.parallelize(Seq(Edge(1L,2L,1), Edge(1L,3L,2), Edge(2L,4L,3), Edge(2L,5L,4), Edge(3L,6L,5), Edge(3L,7L,6)))
val graph = Graph(vertices, edges)

// This graph is a tree: it is connected, has 7 vertices and 6 edges, and has no loop.
// Triangle count is a quick check: every vertex should report 0 triangles.
// (Zero triangles only rules out loops of three vertices; longer loops are not detected this way.)
graph.triangleCount.vertices.collect.foreach(println)

/*
Output: tuple pairs of (vertex id, triangle count), all zeros
(4,0)
(6,0)
(2,0)
(1,0)
(3,0)
(7,0)
(5,0)
*/

// The following code traverses the tree graph, starting from the root of the tree, vertex id 1.
val start: VertexId = 1
// Mark the start vertex with 0.0 and every other vertex with Double.PositiveInfinity (not yet visited).
val markedGraph = graph.mapVertices((id, _) => if (id == start) 0.0 else Double.PositiveInfinity)

// Vertex program: keep the smallest hop count seen so far.
val vprog = { (id: VertexId, attr: Double, msg: Double) => math.min(attr, msg) }

// Send a message across an edge only if at least one end is still unvisited (infinite), so
// already-visited vertices are never revisited and a loop cannot cause endless iteration.
// Messages can flow in either direction, so edges are effectively treated as undirected.
val sendMessage = { (triplet: EdgeTriplet[Double, Int]) =>
  var run: Iterator[(VertexId, Double)] = Iterator.empty
  if (!(triplet.srcAttr != Double.PositiveInfinity && triplet.dstAttr != Double.PositiveInfinity)) {
    if (triplet.srcAttr != Double.PositiveInfinity) {
      run = Iterator((triplet.dstId, triplet.srcAttr + 1))
    } else {
      run = Iterator((triplet.srcId, triplet.dstAttr + 1))
    }
  }
  run
}

// Merge messages arriving at the same vertex: keep the smallest hop count.
val mergeMessage = { (a: Double, b: Double) => math.min(a, b) }

// Initial message Double.PositiveInfinity, at most 20 supersteps.
val graphTraverse = markedGraph.pregel(Double.PositiveInfinity, 20)(vprog, sendMessage, mergeMessage)

// Each vertex attribute is now its hop count (BFS level) from vertex 1; print them level by level.
graphTraverse.vertices.collect.sortBy { case (id, hops) => (hops, id) }.foreach(println)
/*
Output: (vertex id, hops from vertex 1)
(1,0.0)
(2,1.0)
(3,1.0)
(4,2.0)
(5,2.0)
(6,2.0)
(7,2.0)
*/
```
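To see what the Pregel call computes, the following single-machine Python sketch replays the logic of `vprog`, `sendMessage` and `mergeMessage` in supersteps. It is a simplified model, not GraphX's implementation. It evaluates every edge in each superstep, whereas GraphX only evaluates edges next to vertices that received a message. It also skips sending when both ends are unvisited, since a message of infinity changes nothing. `pregel_bfs` is a hypothetical name.

```python
import math

INF = math.inf

def pregel_bfs(vertices, edges, start, max_iterations=20):
    """Single-machine model of the Pregel traversal: returns hop counts from start."""
    attr = {v: (0.0 if v == start else INF) for v in vertices}
    for _ in range(max_iterations):
        inbox = {}
        for src, dst in edges:
            # sendMessage: only when at least one end is still unvisited (infinite)
            if attr[src] != INF and attr[dst] != INF:
                continue
            if attr[src] != INF:
                target, msg = dst, attr[src] + 1
            elif attr[dst] != INF:
                target, msg = src, attr[dst] + 1
            else:
                continue  # both ends unvisited: an infinite message would change nothing
            # mergeMessage: keep the smallest hop count per receiving vertex
            inbox[target] = min(inbox.get(target, INF), msg)
        if not inbox:
            break  # no messages left: the traversal has converged
        for v, msg in inbox.items():
            attr[v] = min(attr[v], msg)  # vprog
    return attr

tree_edges = [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]
result = pregel_bfs(range(1, 8), tree_edges, start=1)
print(sorted(result.items(), key=lambda kv: (kv[1], kv[0])))
# [(1, 0.0), (2, 1.0), (3, 1.0), (4, 2.0), (5, 2.0), (6, 2.0), (7, 2.0)]
```

Each vertex ends up with its hop count from vertex 1, which is its breadth-first level in the tree. The supersteps play the role of the levels printed by the earlier breadth-first functions.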
The GraphX graph built above is a tree because it has no loop. The earlier standalone Python and Scala programs can only traverse a tree, that is, a graph without loops. A loop appears, for example, in a social network where “Jack is a friend of John, John is a friend of Robert, and Robert is a friend of Jack”. If the earlier traversal code followed links around such a loop, it would recurse endlessly and eventually fail with a stack overflow (a RecursionError in Python, a StackOverflowError on the JVM).
The GraphX code handles a graph with loops and will not iterate indefinitely, because it uses a marking strategy: a vertex whose attribute is no longer Double.PositiveInfinity has already been visited, and sendMessage sends nothing across an edge whose two ends are both visited. The maxIterations argument of pregel (20 here) additionally caps the number of supersteps.
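The marking idea is the same one a single-machine traversal uses to survive loops: remember which vertices have been visited and never expand them again. This is a minimal Python sketch with a hypothetical `bfs_with_marking` helper. Like `sendMessage`, it assumes edges are undirected. The edge list is the tree's edges plus the closing edge from 7 back to 1 used in the next example.

```python
from collections import deque

def bfs_with_marking(edges, start):
    """Breadth-first traversal that tolerates loops; returns hop counts from start."""
    neighbors = {}
    for src, dst in edges:  # treat each edge as undirected
        neighbors.setdefault(src, []).append(dst)
        neighbors.setdefault(dst, []).append(src)
    hops = {start: 0}  # the "marked" (visited) vertices and their distance from start
    queue = deque([start])
    while queue:
        v = queue.popleft()
        for w in sorted(neighbors.get(v, [])):
            if w not in hops:  # skip vertices that were already visited
                hops[w] = hops[v] + 1
                queue.append(w)
    return hops

edges_with_loop = [(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7), (7, 1)]
print(bfs_with_marking(edges_with_loop, 1))
# {1: 0, 2: 1, 3: 1, 7: 1, 4: 2, 5: 2, 6: 2}
```

Without the `if w not in hops` check, the loop 1-3-7-1 would keep putting the same vertices back on the queue, and the traversal would never terminate.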
Let’s turn the tree into a graph with a loop by adding the edge Edge(7L,1L,8):
```scala
val vertices = sc.parallelize(Seq((1L,1.0),(2L,2.0),(3L,3.0),(4L,4.0),(5L,5.0),(6L,6.0),(7L,7.0)))
// Same edges as before, plus Edge(7L,1L,8), which closes the loop 1 -> 3 -> 7 -> 1
val edges = sc.parallelize(Seq(Edge(1L,2L,1), Edge(1L,3L,2), Edge(2L,4L,3), Edge(2L,5L,4), Edge(3L,6L,5), Edge(3L,7L,6), Edge(7L,1L,8)))
val graph = Graph(vertices, edges)

// This is no longer a tree: vertices 1, 3 and 7 now form a triangle, so their triangle count is > 0
graph.triangleCount.vertices.collect.foreach(println)

/*
Output: tuple pairs of (vertex id, triangle count), some are > 0
(4,0)
(6,0)
(2,0)
(1,1)
(3,1)
(7,1)
(5,0)
*/

// Run the same traversal on this graph with a loop
val start: VertexId = 1
// Mark the start vertex with 0.0 and every other vertex with Double.PositiveInfinity (not yet visited).
val markedGraph = graph.mapVertices((id, _) => if (id == start) 0.0 else Double.PositiveInfinity)

// Vertex program: keep the smallest hop count seen so far.
val vprog = { (id: VertexId, attr: Double, msg: Double) => math.min(attr, msg) }

// Only send across an edge with at least one unvisited (infinite) end; this avoids an endless loop.
val sendMessage = { (triplet: EdgeTriplet[Double, Int]) =>
  var run: Iterator[(VertexId, Double)] = Iterator.empty
  if (!(triplet.srcAttr != Double.PositiveInfinity && triplet.dstAttr != Double.PositiveInfinity)) {
    if (triplet.srcAttr != Double.PositiveInfinity) {
      run = Iterator((triplet.dstId, triplet.srcAttr + 1))
    } else {
      run = Iterator((triplet.srcId, triplet.dstAttr + 1))
    }
  }
  run
}

// Merge messages arriving at the same vertex: keep the smallest hop count.
val mergeMessage = { (a: Double, b: Double) => math.min(a, b) }

val graphTraverse = markedGraph.pregel(Double.PositiveInfinity, 20)(vprog, sendMessage, mergeMessage)

graphTraverse.vertices.collect.sortBy { case (id, hops) => (hops, id) }.foreach(println)

/*
Output: the traversal terminates despite the loop. Because sendMessage treats edges
as undirected, vertex 7 is now reached in one hop through Edge(7L,1L,8).
(1,0.0)
(2,1.0)
(3,1.0)
(7,1.0)
(4,2.0)
(5,2.0)
(6,2.0)
*/
```

# Summary:

Tree and graph traversal on Spark GraphX is scalable because the traversal runs on a Spark cluster with multiple worker nodes, each processing its own partitions of the graph. In fact, the traversal code runs on GraphX's implementation of Google’s Pregel model, which was designed for large-scale graph computing. For more on Pregel, you can see my prior writing:
As always, the code used in this writing is in my GitHub repo.
Thank you for taking the time to read this writing.