用动态规划来解决问题

本篇文章讲述动态规划算法,前半部分通过两个简单实例介绍和引入动态规划解决问题的方式。后半部分进一步阐述动态规划在最短路径方面的应用。

引入

常常想为什么需要引入动态规划的概念。我们对于一些数据计算比较大或者有规律的问题,往往想要自动化来实现求其结果。一般情况下都会推导出数学的递推公式。往往都能转化为递归来实现,递归能够以简洁的语法来解决问题,但实际上其效率是非常低下的。而这种情况下我们可以很友好的给编译器提供一些帮助,将递归重新转化为非递归来实现。即将问题的子问题的答案记录在一个表内。

我在知乎上,看到这样一个入门实例,觉得非常好,转载过来:

以乘法计算为例,乘法的定义其实是做 n 次加法,请先忘掉九九乘法表,让你计算 9 乘以 9,如何得到 81 这个解?计算 9 乘以10 呢?9 乘以 999……以及 9 乘以 n呢?

  • 分析问题,构造状态转移方程

“状态转移方程”的学术定义亦可简单找到(比如置顶答案),略去不表。光看“方程”二字,可以明白它是一个式子。 针对以上问题,我们构造它的状态转移方程。 问题规模小的时候,我们可以容易得到以下式子:

9 乘 0=0; 9 乘 1=0+9; 9 乘 2=0+9+9; ……

可以得到:9 乘 n=0+9+…+9 (总共加了n 个 9 )。严谨的证明可以使用数学归纳法,略去不表。 现在,定义 dp(n) = 9 乘 n,改写以上式子:

dp(0) = 9 乘 0 = 0; dp(1) = 9 乘 1 = dp(0)+9; dp(2) = 9 乘 2 = dp(1)+9; ……

作差易得:dp(n) = dp(n-1) + 9;这就是状态转移方程了。 可以看到,有了状态转移方程,我们现在可以顺利求解9 乘 n(n为任意正整数)这一问题。

  • 以空间换时间

虽然能解,但当 n 很大时,计算耗时过大,看不出状态转移方程 dp(n) = dp( n - 1) + 9 与普通方程 9 乘 n = 0+9+…+9(总共加了 n 个 9 )相比有任何优势。 这时,如果 dp(n-1) 的结果已知,dp(n) = dp( n - 1)+9 只需计算一次加法,而 9 乘 n = 0+9+…+9 (总共加了 n 个9) 则需计算 n-1 次加法,效率差异一望即知。

存储计算结果,可令状态转移方程加速,而对普通方程没有意义。 以空间换时间,是令动态规划具有实用价值的必备举措。

上面这个例子简单清晰的介绍了动态规划的概念,我们要理解的就是动态规划过程中,始终有一个缓存的存在,来保存上一次计算出的数值,从而方便下一次调用。而这里,我们也稍微提醒一下自己注意递归和递推的区别。

实例 一

我们来看一个实例。这个例子是我在刷剑指 offer 中遇到的:由 0 1 2 3 4 构成的环。从数字 0 开始每次删除第 3个数字,则删除的前四个数字依次是 2,0,4,1。因此最后剩下的数字是 3。也就是讲删除一个节点之后,那么下一轮的删除的过程开始的节点就是本次删除节点的下一个节点。比如说第一次删除了 2,那么开始下一轮删除过程的开始节点就是 3 ,以此类推。

我们抽象化一下,每次找出第三个被删除的数字,首先要有一个环的概念。在不断删除数字的过程中环是不断的缩小的,通过列举数字,找出规律。在这 n 个数字中, 第一个被删除的数字是 (m-1)%n。为了简单起见,我们把(m- 1)%n 记为 k,那么删除 k 之后剩下的 n-1 个数字为 0,1,… ,k-1,k+1,… ,n-1,并且下一次删除从数字 k+1 开始计数。相当于在剩下的序列中, k+1 排在最前面,从而形成 k+1,… ,n- 1,0,I,… ,k-1 。该序列最后剩下的数字也应该是关于 n 和 m 的函数。由于这个序列的规律和前面最初的序列不一样(最初的序列是从 0 开始的连续序列),因此该函数不同于前面的函数,记为 f’(n-1,m)。最初序列最后剩下的数字 f(n, m)一定是删除一个数字之后的序列最后剩下的数字,即 f(n, m) = f’(n-1, m)。

接下来我们把剩下的这 n-1 个数字的序列 k-1, …,n-1,0,1,… ,k-1 做一个映射,映射的结果是形成一个从 0 到 n-2 的序列:  

k+1    ->    0 k+2    ->    1 … n-1    ->    n-k-2 0   ->    n-k-1 … k-1   ->   n-2

把映射定义为 p,则 p(x) = (x-k-1)%n,即如果映射前的数字是 x,则映射后的数字是 (x-k-1)%n。举个栗子,假设 x = k + 1,那么 p(x) = (x-k-1)%n = (k+1-k-1) % n = 0 % n = 0 ,即为映射之后的数字 0 。对应的逆映射是 p-1(x)=(x+k+1) % n。这里涉及到数学知识,开始还懵逼不知道什么是反映射,看成反函数就好了。

由于映射之后的序列和最初的序列有同样的形式,都是从 0 开始的连续序列,因此仍然可以用函数 f 来表示,记为 f(n-1,m)。根据我们的映射规则,映射之前的序列最后剩下的数字 f’(n-1,m)= p-1 [f(n-1,m)] = [f(n-1,m)+k+1]%n。把 k = (m-1)%n 代入得到 f(n,m) = f’(n-1,m) = [f(n-1,m) + m] % n。明显当 n = 1 的时候,也就是说环中最开始只有一个数字 0 ,所以最后剩下的数字也是 0 。得到数学递归公式如下:

我们根据公式写出算法代码如下:

public static int lastRemaining1(int n, int m) {
    if (n < 1 || m < 1) {
        return -1;
    }
  
    int last = 0;
    for (int i = 2; i <=n ; i++) {
        last = (last + m)%i;
    }
    return last;
}

实例二

我们再来看一个实例 :输入一个整型数组,数组里有正数也有负数。数组中一个或连续的多个整数组成一个子数组。求所有子数组的和的最大值。要求时间复杂度为 O(n)。例如输入的数组为{1, -2, 3, 10, -4, 7, 2, -5},和最大的子数组为{3, 10, -4, 7, 2}。因此输出为该子数组的和 18 。

我们还是用动态规划来解决这个问题,如下这个状态方程还是很容易得到的。每一次保留上一次计算结果,作为下一次计算的预备数。而当进行下一次计算的时候,我们还要继续对上一次保留的预备数进行判断,如果大于 0,我们继续将数组上的值累加上去,如果上一次预备数小于 0 ,我们抛弃上一次计算得到的预备数,重新用数组项的数值作为新的预备数。状态方程如下:

public static int lastRemaining2(int a[]){
        int last = 0;// 上一次结算结果
        int now = 0; // 目前最大值
        int sum = 0; // 最终最大值
  
        for (int i = 0; i < a.length; i++) {
            if (last <= 0)
                now = a[i];
            else
                now = last + a[i];
            if(sum < now) // 注释 1 处。留一个临时变量来保存最大值
                sum = now;
            last = now;
        }
        return sum;
    }

上面整个一段代码在注释一之前是没什么问题的。之所以加上注释一处的代码,因为,我们最开始手工计算的时候输出的最大值应该是 18。然后程序运行输出却是 13 。 发现原因主要是最后一次还是加上了 -5 。所以我们用一个变量保存目前为止最大值,而这个目前最大值的更新与否,就要不停的比较了。比如说我们倒数第二次计算出数值和为 18 。倒数第一次按照程序逻辑我们还是要去加上 -5 。这一步没有问题,是符合流程的。计算出的结果为 13 。却不是我们期待的,所以存在一个 13 与 18 的比较、最大值更新与否的问题。

目前为止,我们应该说很顺利的引入了动态规划的概念。

进一步应用

如下所示路线图,我们的目的是要找出 A 到 E 之间最短的距离。

我们仍然采用动态规划的思想来解决这个问题

STEP 1:描述最优解结构

用0到10分别表示这10个节点,要使得 0 到 10 之间的距离最短,令 f(i) 为到第 i 个节点的最短距离,则 f(10) = min( f(7)+3 ,f(8)+4,f(9)+3 ),同样的道理,我们得到 f(7),f(8),f(9).

STEP 2 :递归定义最优解的值

f(i) = min(f(i)+Dij)

其中 j 表示与 i 边有连接的点,并且 f(0) = 0;

STEP3:自底向上方式计算每个节点的最优值

最终我们利用递推公式分别求得 f(1) 到 f(10) 就好了。

按照动态规划的思想,我们在每一个阶段寻找新的节点的时候,一定要要求它是最优的子结构。

算法核心代码如下:

// 计算最短距离,并不包括计算其路径
public static int[] calMinDistance(int distance[][]) {  
        int dist[] = new int[distance.length];  
        dist[0] = 0;  
        for (int i = 1; i < distance.length; i++) {  
            int k = Integer.MAX_VALUE;  
            for (int j = 0; j < i; j++) {  
                if (distance[j][i] != 0) {  
                    if ((dist[j] + distance[j][i]) < k) {  
                        k = dist[j] + distance[j][i];  
                    }  
                }  
            }  
            dist[i] = k;  
        }  
        return dist;  
    }  

参考

Back

评论或交流:wuchangfeng2015@gmail.com