由一道 OJ 引发的关于 double 类型的一些思考

昨天在做 1104 Sum of Number Segments 这道题的时候发现了一些奇怪的现象,经过多种方法试验得出的结论为:在大量或高精度计算中要避免使用double类型

很早就听说过,在任何关于金额或者积分等系统的设计中,使用double类型来存储数据是绝对禁止的。原因就在于计算机内部的加减乘除运算是通过加法器二进制运算来完成的,而二进制是无法准确表示一个浮点数的,只能在有限的精度内逼近这个值,比如:
$$
4.182
$$
这个数字在JavaBigDecimal中的值被表示为:
$$
4.1820000000000003836930773104541003704071044921875
$$
小小的误差在巨量的运算下累积,错误会被不断放大,导致难以预料的后果。

1
2
3
4
5
BigDecimal a = new BigDecimal("4.182"); // 以字符串形式精确表示的 4.182

BigDecimal b = new BigDecimal(4.182); // 实际上 b 的值为 4.18200000000000038369307...

double c = 4.182; // 同 b

不出意外,运算结果如下:

1
2
3
4
// 计算 4.182^32
System.out.printf("%f\n", a.pow(32)); // 76610608422533170888.500669 精确
System.out.printf("%f\n", b.pow(32)); // 76610608422533395814.068151 相差甚远
System.out.printf("%f\n", Math.pow(c, 32)); // 76610608422533400000.000000 精度严重损失

OK,说完背景,让我们再回到OJ的问题上来。

1104 Sum of Number Segments

题目

时间限制 200 ms 内存限制 64 MB

Given a sequence of positive numbers, a segment is defined to be a consecutive subsequence. For example, given the sequence { 0.1, 0.2, 0.3, 0.4 }, we have 10 segments: (0.1) (0.1, 0.2) (0.1, 0.2, 0.3) (0.1, 0.2, 0.3, 0.4) (0.2) (0.2, 0.3) (0.2, 0.3, 0.4) (0.3) (0.3, 0.4) and (0.4).

Now given a sequence, you are supposed to find the sum of all the numbers in all the segments. For the previous example, the sum of all the 10 segments is 0.1 + 0.3 + 0.6 + 1.0 + 0.2 + 0.5 + 0.9 + 0.3 + 0.7 + 0.4 = 5.0.

Input Specification:

Each input file contains one test case. For each case, the first line gives a positive integer N, the size of the sequence which is no more than 105. The next line contains N positive numbers in the sequence, each no more than 1.0, separated by a space.

Output Specification:

For each test case, print in one line the sum of all the numbers in all the segments, accurate up to 2 decimal places.

Sample Input:

1
2
4
0.1 0.2 0.3 0.4

Sample Output:

1
5.00

分析

这道题并不难,但是涉及到大量的double类型相乘相加运算,使得数据量大于一定阈值时,结果出现难以预料的差异——总而言之,这道题是有问题的。

当我看到这道题时,首先想到是万能的DFS回溯剪枝,但敲完代码后发现回溯太慢了会直接超时。思考之后发觉该题其实是有规律的,这是一道数学题。

注意:我的方法与网上绝大部分人的方法 / 标准答案不同

我不否认标准答案比我的方法效率高,但我认为算法题应该鼓励解题思路的多样性,而该题因为忽略了double类型精度损失的特性扼杀了这一点,后面会详细讨论。

动态规划

以Sample Input为例,从后向前遍历,箭头指示的为当前项,则每一项对应的Segment的即红框部分,每一项比后一项多出来的即黄框部分。

由此可以得出递推公式:
$$
A_{n}=A_{n+1}+(N-n+1)S_{n}
$$
其中$n$为输入数据当前项的序号,$N$为输入数据总项数,$A_{n}$为对应的Segment的和,$S_{n}$为输入数据当前项的值。

然后将$A$的所有项相加,即可得到答案。

Source Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include<iostream>
#include<vector>
using namespace std;

int main() {
int N;
cin >> N;
vector<double> seq(N + 1);
for (int i = 1; i <= N; i++) {
scanf("%lf", &seq[i]);
}
double sum = seq[N];
for (int i = N - 1; i > 0; i--) {
seq[i] = seq[i + 1] + seq[i] * (N - i + 1);
sum += seq[i];
}
printf("%.2f", sum);
return 0;
}

结果

测试点 2,也就是最大数量测试($10^5$)答案错误,这一点让我比较疑惑。

对结果的讨论

确认代码无误后,我尝试在网上寻找答案,结果发现几乎所有人都使用了下面这种算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
#include<iostream>
using namespace std;

int main() {
int N;
cin >> N;
double v, ans = 0;
for (int i = 1; i <= N; i++) {
scanf("%lf", &v);
ans += v * i * (N - i + 1);
}
printf("%.2f", ans);
return 0;
}

需要承认的是,在搜索之前我确实没想到这种方法。

这种方法——我暂且称之为标准答案——和我的方法耗时差不多,但是由于不用储存数据,占用内存很小。

我花了很久思考我自己写的算法问题在哪。

无果,那就写个测试函数分别用两种算法测试同一组数据试试吧。

比较测试(使用double类型)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include<iostream>
#include<random>
#include<vector>
#include<ctime>

using namespace std;

// 我的方法
double fun1(vector<double> &seq) {
int N = seq.size() - 1;
double sum = seq[N];
for (int i = N - 1; i > 0; i--) {
seq[i] = seq[i + 1] + seq[i] * (N - i + 1);
sum += seq[i];
}
return sum;
}

// 标准答案
double fun2(vector<double> &seq) {
double ans = 0;
int N = seq.size() - 1;
for (int i = 1; i <= N; i++) {
ans += seq[i] * i * (N - i + 1);
}
return ans;
}

int main() {
srand(unsigned(time(NULL)));
const int N = 100000; // 此处 N 设置为题目要求最大值
vector<double> seq(N + 1);
int cnt = 0; // 测试计数
while (true) {
cout << cnt++ << endl;
for (int i = 1; i <= N; i++) {
// 使用3位随机小数
seq[i] = 0.001 * (rand() % 1000);
}
// 如果两种方法答案误差在0.01以上(题目条件)
double a = fun2(seq), b = fun1(seq);
if (fabs(a - b) > 0.01) {
printf("%.2f\n", a);
printf("%.2f\n", b);
}
}
return 0;
}

上述程序运行后,很快便出现了输出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
0
82821280742034.05
82821280742034.59
1
82505311440121.89
82505311440122.61
2
82563594625344.42
82563594625344.14
3
82862178812165.50
82862178812165.14
4
82948375961733.53
82948375961733.02
5
82932380057476.22
82932380057475.92

通过连续的几百个输出可以看到两种方法得出的结果必不相同,且差值还是比较明显的,至少不符合OJ系统对于精确度的要求。

但是这说明我的算法有问题吗?并不是。

将输入数据的量 N 从$10^5$改到$10^4$或更小后,两种算法输出便完全相同了。(注意$10^4$并不是上界)

多年的直觉告诉我这种情况与算法无关了,是语言细节出了问题。

比较测试(使用long long类型)

让我们来绕过double,试验一下两种算法在精确计算的情况下输出是否相同吧。

将输入的数据类型全部改为long long之后,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include<iostream>
#include<random>
#include<vector>
#include<ctime>

using namespace std;

// 我的方法
long long fun1(vector<long long> &seq) {
int N = seq.size() - 1;
long long sum = seq[N];
for (int i = N - 1; i > 0; i--) {
seq[i] = seq[i + 1] + seq[i] * (N - i + 1);
sum += seq[i];
}
return sum;
}

// 标准答案
long long fun2(vector<long long> &seq) {
long long ans = 0;
int N = seq.size() - 1;
for (int i = 1; i <= N; i++) {
ans += seq[i] * i * (N - i + 1);
}
return ans;
}

int main() {
srand(unsigned(time(NULL)));
const int N = 100000; // 此处 N 设置为题目要求最大值
vector<long long> seq(N + 1);
int cnt = 0; // 测试计数
while (true) {
cout << cnt++ << endl;
for (int i = 1; i <= N; i++) {
// 使用3位随机数
seq[i] = rand() % 1000;
}
// 如果两种方法答案不相等
long long a = fun2(seq), b = fun1(seq);
if (a != b) {
printf("%lld\n", a);
printf("%lld\n", b);
}
}
return 0;
}

运行!

不出所料,Console的输出只有连续的计数——两种算法结果完全相同

谁更精确?

其实故事到这里已经可以告一段落了,但是这个问题困扰我研究到半夜3点,我非要抬一下杠

Java提供了可以比较精确地进行数学计算的大数类BigDeciaml,于是,测试程序可以稍微改一下,使用Java跑出“精确答案”。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import java.io.*;
import java.math.BigDecimal;

public class Test {
private static int N;
private static BufferedReader br;

static {
try {
br = new BufferedReader(
new InputStreamReader(
new FileInputStream("src\\Text.txt")));
} catch (FileNotFoundException e) {
e.printStackTrace();
}
}

public static void main(String[] args) throws IOException {
N = Integer.parseInt(br.readLine());
String[] data = br.readLine().split(" ");
BigDecimal[] seq = new BigDecimal[data.length + 1];

for (int i = 1; i < seq.length; i++) {
seq[i] = new BigDecimal(data[i - 1]);
}

System.out.printf("%f\n", fun2(seq));
System.out.printf("%f\n", fun1(seq));
}

// 我的算法
public static BigDecimal fun1(BigDecimal[] seq) {
BigDecimal sum = seq[N];
for (int i = N - 1; i > 0; i--) {
seq[i] = seq[i + 1].add(seq[i].multiply(new BigDecimal(N - i + 1)));
sum = sum.add(seq[i]);
}
return sum;
}

// 标准答案
public static BigDecimal fun2(BigDecimal[] seq) {
BigDecimal ans = new BigDecimal(0);
for (int i = 1; i <= N; i++) {
ans = ans.add(seq[i].multiply(
new BigDecimal(i).multiply(
new BigDecimal(N - i + 1))));
}
return ans;
}
}

所以我用C++0.001 * rand() % 1000这句代码,抓了一个样例(10万个测试数据),在两个平台分别用两种算法进行测试,结果如下:

测试方法 测试结果 相对误差
Java BigDecimal (我的方法) 82725496576006.158000 0
Java BigDecimal (标准答案) 82725496576006.158000 0
C++ (我的方法) 82725496576006.140625 0.017375
C++ (标准答案) 82725496576005.750000 0.408000

由上表,两种算法如果不考虑精度损失的情况下都是完全正确的,但是OJ系统判定的答案却只考虑了一种解法,且与精确答案不同——不如说C++在没有大数类的情况下,想要计算出精确结果是非常不容易的,应该避免在这种大量计算中使用double类型进行判定。

后记

一道小小的OJ,做题20分钟,查错1小时,查资料1小时,怀疑人生1小时,做测试3小时,写博客2小时。

手动再见。


2020.6.15 更新:经过反馈,PAT已经修改了该题数据

但是!

本文提到的方法全部木大,现在我也不知道怎么才能AC了…


2020.6.16 再次更新:已AC,但我反馈的不是这个意思…

总体思路就是让小数点后移,避免double类型的不连续性及不精确性影响到答案。

先上代码:

Solution 1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
#include<iostream>
using namespace std;

int main() {
int N;
cin >> N;
double v;
long long ans = 0;
for (int i = 1; i <= N; i++) {
scanf("%lf", &v);
ans += (long long) (v * 1000) * i * (N - i + 1);
}
printf("%.2f", ans / 1000.0);
return 0;
}

Solution 2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#include<iostream>
#include<vector>
using namespace std;

int main() {
int N;
cin >> N;
vector<long long> seq(N + 1);
double temp;
for (int i = 1; i <= N; i++) {
scanf("%lf", &temp);
seq[i] = temp * 1000;
}
long long sum = seq[N];
for (int i = N - 1; i > 0; i--) {
seq[i] = seq[i + 1] + seq[i] * (N - i + 1);
sum += seq[i];
}
printf("%.2f", sum / 1000.0);
return 0;
}

然而,我的本意是让数据直接给整型,避免使用容易产生误差的浮点型,因为这样题目难度可以保持不变。而现在的方案加大了题目难度…

看了一下今天这道题的提交情况,挺惨烈的。

ε=ε=ε=(~ ̄▽ ̄)~