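The easiest way to find the maximum depth is to record, in an array, the length of every path from the root down to a missing child, and then return the largest value: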
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """Maximum depth of a binary tree, recording the length of every path."""

    def maxDepth(self, root: "TreeNode | None") -> int:
        """Return the number of nodes on the longest root-to-leaf path.

        Each time the traversal steps past the tree (reaches a missing
        child), the number of nodes walked to get there is appended to
        ``lengths``; the answer is the largest recorded value.

        Args:
            root: Root node of the tree, or None for an empty tree.

        Returns:
            The depth of the tree; 0 when ``root`` is None.
        """
        lengths = []
        # Explicit stack of (node, number of nodes above it) instead of
        # recursive calls, so a long skewed tree cannot exceed the
        # interpreter's recursion limit.
        pending = [(root, 0)]
        while pending:
            node, length = pending.pop()
            if not node:
                # Fell off the tree: this path is complete, record its length.
                lengths.append(length)
                continue
            length += 1
            # Right is pushed first so the left subtree is popped first.
            pending.append((node.right, length))
            pending.append((node.left, length))
        # An empty tree still records a single 0, so ``lengths`` is never empty.
        return max(lengths)
A more intuitive way is to have the recursion return the depth directly, with no auxiliary array:
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    """Maximum depth of a binary tree via direct recursion."""

    def maxDepth(self, root: "TreeNode | None") -> int:
        """Return the number of nodes on the longest root-to-leaf path.

        Args:
            root: Root node of the tree, or None for an empty tree.

        Returns:
            The depth of the tree; 0 when ``root`` is None.

        Note:
            The helper recurses once per tree level, so a very deep
            (degenerate) tree can exceed Python's recursion limit.
        """

        def dfs(node, depth=0):
            # Single conditional expression: an empty subtree reports the
            # depth reached so far; otherwise the deeper child wins.
            return (
                max(dfs(node.left, depth + 1), dfs(node.right, depth + 1))
                if node
                else depth
            )

        return dfs(root)
We return the larger of the depths reported by the left and right subtrees; the one-line conditional expression handles the base case, returning the depth reached so far when the node is empty.