Let us look at the code first (I have split it into segments for a clearer view).
I recommend writing the code down and sketching the tree to make it easier to follow.
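Tree (listed in preorder, where x marks an empty child): 1 2 3 x 5 x x 4 x x 6 x x
Drawn out, it looks like this:

        1
       / \
      2   6
     / \
    3   4
     \
      5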
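If you would rather let the computer build the tree than sketch it by hand, here is a minimal sketch in Python. It assumes the string is a preorder listing with x for a missing child. The names TreeNode and build_tree are hypothetical helpers introduced for illustration and are not part of the original code.

```python
from typing import Iterator, Optional


class TreeNode:
    # Minimal binary-tree node (assumed shape: val, left, right)
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_tree(serialized: str) -> Optional[TreeNode]:
    """Rebuild a tree from a preorder string where "x" marks an empty child."""
    tokens: Iterator[str] = iter(serialized.split())

    def build() -> Optional[TreeNode]:
        token = next(tokens)
        if token == "x":           # empty child
            return None
        node = TreeNode(int(token))
        node.left = build()        # preorder: the whole left subtree comes next
        node.right = build()       # then the whole right subtree
        return node

    return build()


root = build_tree("1 2 3 x 5 x x 4 x x 6 x x")
print(root.val, root.left.val, root.right.val)                            # 1 2 6
print(root.left.left.val, root.left.left.right.val, root.left.right.val)  # 3 5 4
```

Reading the tokens in preorder means each call consumes its own value, then its entire left subtree, then its entire right subtree, which is exactly how the string is laid out.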
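def dfs(root):
    if not root:
        return 0
    return max(dfs(root.left), dfs(root.right)) + 1

segment 1 (S1): def dfs(root):  entering the function for the current node
segment 2 (S2): if not root: return 0  the base case for an empty subtree
segment 3 (S3): dfs(root.left)  the height of the left subtree
segment 4 (S4): dfs(root.right)  the height of the right subtree
segment 5 (S5): max(S3, S4) + 1  the larger of the two, plus 1 for the current node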
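To run the function on the example tree, it needs a node type with val, left and right attributes. The TreeNode class below is an assumption (a minimal stand-in for whatever node class you use). The dfs body computes the same thing as the original, but each segment is split into a named local so S3, S4 and S5 are visible.

```python
from typing import Optional


class TreeNode:
    # Assumed minimal node type with val, left and right attributes
    def __init__(self, val: int, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def dfs(root: Optional[TreeNode]) -> int:
    # S1: a call starts for the current node (or for None)
    if not root:                    # S2: an empty subtree has height 0
        return 0
    s3 = dfs(root.left)             # S3: height of the left subtree
    s4 = dfs(root.right)            # S4: height of the right subtree
    return max(s3, s4) + 1          # S5: the taller side plus this node


# The example tree 1 2 3 x 5 x x 4 x x 6 x x, built by hand
root = TreeNode(1,
                TreeNode(2,
                         TreeNode(3, None, TreeNode(5)),
                         TreeNode(4)),
                TreeNode(6))
print(dfs(root))  # 4
```

`if not root` works here because a TreeNode instance is always truthy (the class defines neither `__bool__` nor `__len__`); `if root is None` is a more explicit spelling of the same check.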
Our initial root.val is 1,
so segment 1 (S1) executes,
segment 2 (S2) is skipped (root is not None),
and segment 3 (S3) executes, calling dfs(root.left).
Now root.val = 2,
so S1 executes,
S2 is skipped,
and S3 executes.
Now root.val = 3,
so S1 executes,
S2 is skipped,
and S3 executes.
S3 takes us back to S1, but this time root is None (3 has nothing to its left),
so S2 executes and returns 0.
So for root.val = 3, S3 = 0,
and the code moves on to S4 (dfs(root.right)).
Now root.val = 5,
so S1 executes,
S2 is skipped,
and S3 executes.
S3 takes us back to S1, this time with root = None (5 has nothing to its left),
so S2 executes and returns 0.
So for root.val = 5, S3 = 0.
The same happens in S4 (5 has nothing to its right either),
so for root.val = 5, S4 = 0.
The code now moves to S5:
max(root.val = 5’s S3, root.val = 5’s S4) + 1,
which is max(0, 0) + 1 = 1.
This 1 is returned to root.val = 3’s S4.
Looking back, root.val = 3’s S3 returned 0, and now its S4 has returned 1,
so the code moves to S5:
max(root.val = 3’s S3, root.val = 3’s S4) + 1,
which is max(0, 1) + 1 = 2.
This 2 is returned to root.val = 2’s S3.
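At this point the whole subtree under 3 has been handled. If you want to check the hand trace, the program can print it for you. This sketch (dfs_traced and its log parameter are hypothetical additions, not part of the original code) indents each call by its recursion depth and records one line per step with the same S1 to S5 labels. It runs over the whole tree, so it also covers the steps that follow.

```python
from typing import List, Optional


class TreeNode:
    # Assumed minimal node type
    def __init__(self, val: int, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def dfs_traced(root: Optional[TreeNode], log: List[str], depth: int = 0) -> int:
    indent = "  " * depth                         # deeper recursive calls are indented further
    if not root:                                  # S2: base case
        log.append(f"{indent}None: S2 returns 0")
        return 0
    log.append(f"{indent}{root.val}: S1 runs, S2 skipped")
    s3 = dfs_traced(root.left, log, depth + 1)    # S3: left subtree
    s4 = dfs_traced(root.right, log, depth + 1)   # S4: right subtree
    s5 = max(s3, s4) + 1                          # S5: combine
    log.append(f"{indent}{root.val}: S5 = max({s3}, {s4}) + 1 = {s5}")
    return s5


root = TreeNode(1, TreeNode(2, TreeNode(3, None, TreeNode(5)), TreeNode(4)), TreeNode(6))
log: List[str] = []
height = dfs_traced(root, log)
print("\n".join(log))
print(height)
```

The first lines of the log walk down through 1, 2 and 3, then hit None on 3’s left, and the S5 lines appear in the order 5, 3, 4, 2, 6, 1: children always finish before their parent.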
Now that root.val = 2’s S3 is done,
the code moves to S4, which visits node 4 and returns 1 (the same cycle that happened for 5, since 4 is also a leaf).
Then for S5:
max(root.val = 2’s S3, root.val = 2’s S4) + 1,
which is max(2, 1) + 1 = 3.
This 3 is returned to root.val = 1’s S3.
For root.val = 1’s S4, node 6 is visited and returns 1 (again the same cycle as for 5, since 6 is a leaf).
So now we have max(3, 1) + 1,
which returns the height of the tree as 4, i.e. the number of nodes on the longest root-to-leaf path (1, 2, 3, 5).
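The 4 here counts nodes. Some textbooks define height by counting edges instead, in which case this tree has height 3. Under that convention only the base case changes: an empty subtree returns -1 rather than 0, so a single leaf gets height 0. A minimal sketch comparing the two, where height_in_edges is a hypothetical name and TreeNode is an assumed minimal node type:

```python
from typing import Optional


class TreeNode:
    # Assumed minimal node type
    def __init__(self, val: int, left: Optional["TreeNode"] = None,
                 right: Optional["TreeNode"] = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def dfs(root: Optional[TreeNode]) -> int:
    # Height counted in nodes, as in the walkthrough
    if not root:
        return 0
    return max(dfs(root.left), dfs(root.right)) + 1


def height_in_edges(root: Optional[TreeNode]) -> int:
    # Height counted in edges: the empty tree is -1, so a leaf comes out as 0
    if not root:
        return -1
    return max(height_in_edges(root.left), height_in_edges(root.right)) + 1


root = TreeNode(1, TreeNode(2, TreeNode(3, None, TreeNode(5)), TreeNode(4)), TreeNode(6))
print(dfs(root), height_in_edges(root))  # 4 3
```

For any non-empty tree the two results differ by exactly 1, so which one you want depends only on the definition your problem uses.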