Partition trees

Lei Yan

2019/03/08

在学习CS61A的时候遇到了如下问题:

给定正整数n,用最大不超过m的部分来划分n,一共有多少种方法?

举个例子令n = 2, m = 2,则一共有两种方法,分别是:
2 = 2
2 = 1 + 1
这个例子比较简单,一下子就可以算出来。接下来考虑一个复杂点的,比如n = 6, m = 4,此时该如何计算呢?下面先看一张图:
kzcvB4.md.jpg

计算可分为两部分,一部分是用到m = 4(即上图的前两行),另一部分是没有用到m = 4(即上图的后七行)。在用到m = 4的那部分,可以看出划分的个数就等于count_partitions(n-m, m);在不用m = 4的那部分,最大可以用到m - 1 = 3,也就是count_partitions(n, m-1)。这样递归下去,总会出现n - m <= 0或者m = 0的情况,如果n = 0则只有一种划分方式,就是0,如果n < 0 or m = 0,则一种划分方式都不存在。把上面的分析过程翻译为代码就是下面的函数:

def count_partitions(n, m):
    if n == 0:
        return 1
    elif n < 0 or m == 0:
        return 0
    else:
        return count_partitions(n-m, m) + count_partitions(n, m-1)

count_partitions(6, 4)
## 9

上面的函数只是计算出来划分的个数,那么能不能把所有的划分结果都输出呢?为了解决这个问题,我们需要树(tree)这种数据结构。

# 首先是tree的constructor
def tree(label, branches=[]):
    """
    Input : 
           label   : the root of tree
           branches: branches of tree
    Output:a tree(a list of lists) 
    """
    for branch in branches:
        assert is_tree(branch), 'branches must be tree'
    return [label] + list(branches)

# 接下来的两个函数是selector
def label(tree):
    """Get root of a tree"""
    return tree[0]

def branches(tree):
    """Get the branches of a tree"""
    return tree[1:]

# 下面是两个判断函数

# 判断是否是 tree 的思路很简单
# 首先 tree 必须是一个 list 而且要有 root
# 其次 tree 的每一个 branch 都必须是树,这里就用了递归来做了
def is_tree(tree):
    if type(tree) != list or len(tree) < 1:
        return False
    for branch in branches(tree):
        if not is_tree(branch):
            return False
    return True

# 判断是否是叶子,如果 tree 有 branch 那就不是叶子
def is_leaf(tree):
    return not branches(tree)

# 下面是把前面的划分给组成一颗树
# 两个递归基的分析和第一个程序一样
# 不同的是下面的函数会生成一个划分树(二叉树)
# 树的根是 m,两个分支是 m, m-1
# 然后递归的往下分
# 递归基的情况,如果 n = 0 and m != 0,那说明刚好
# 划分完,于是留下一个 True 的叶子
# 反之,留下 False 的叶子
# 从一个 True 叶子出发到达根部(不包括),就是一个 n 的划分
# False叶子不是一个划分
def partition_tree(n, m):
    """Return a partition of n using parts of up to m"""
    if n == 0:
        return tree(True) 
    elif n < 0 or m == 0:
        return tree(False)
    else: # left 对应至少用一个 m, right 至多只用 m-1
        left, right = partition_tree(n-m, m), partition_tree(n, m-1)
        return tree(m, [left, right])

下面的函数可以打印出所有的划分结果

# 由上面的分析可知,我们只打印叶子是 True 的
# 这样就完成了递归基
# 接下来是递归主体
# 只有一点需要注意,那就是因为 left 部分至少用了一个 m
# 所以它的 partition 需要把 m 加进去
# S.join(iterable) 
# Return a string which is the concatenation of the strings in the
# iterable. The separator between elements is S.
def print_parts(tree, partition=[]):
    if is_leaf(tree):
        if label(tree):
            print(' + '.join(partition)) 
    else:                                
        left, right = branches(tree) # 'tree' is binary
        m = str(label(tree))
        print_parts(left, partition + [m])
        print_parts(right, partition)

接下来我们来用上面的print_parts()函数来打印上面图片的例子

print_parts(partition_tree(6, 4))
## 4 + 2
## 4 + 1 + 1
## 3 + 3
## 3 + 2 + 1
## 3 + 1 + 1 + 1
## 2 + 2 + 2
## 2 + 2 + 1 + 1
## 2 + 1 + 1 + 1 + 1
## 1 + 1 + 1 + 1 + 1 + 1