多个大文件数据排序去重的解题思路

问题

假设有四个文件,每个文件 1GB,文件里的每一行存储一个随机的 int正整数 (0 <= v < 2^32),单文件最多 1亿行。要求对这四个文件进行排序去重,输出一个有序的文件。

假设服务器规格 8核16GB内存,希望执行时间尽可能短。

在上述的前提条件下,如果排序完之后,还希望输出每个数字出现的次数呢?

文件读写效率

在 Java 下,文件的 io 方式分为 字节流 和 字符流,字符流又分为 字符输入流 和 字符输出流。
字节流和字符流的区别在于,字节流是以字节为单位读写文件,字符流是以字符为单位读写文件。

字符流的好处是可以指定编码,比如 UTF-8,GBK 等,而字节流只能使用默认的编码。

对于该题目,相比之下,字符流更适合,因为我们只需要读取每一行,然后转换成 int,不需要考虑编码的问题。所以后文会有很多
BufferReader 和 BufferWriter 相关的代码,这里就不再赘述。

常规思路

假设要对所有的数字读取出来进行排序,全部读取出来其实内存里也能放得下。4个文件400000000行,int值占4个字节,理论上只需要 1.6G 的空间,但是做快排时,递归深度太深容易栈溢出。如果数据量更大,或者内存更小,比较理想的方案是外部排序法。

那么我们可以考虑分治的思想,将一个大文件分成多个小文件,每个小文件放到内存里排序,然后再将这些小文件合并成一个大文件。这样就可以解决内存放不下的问题。

1、分割文件、子文件排序完后写入临时文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
/**
* 每个小文件的行数
*/
private static final int CHUNK_SIZE = 2500000;

/**
* 每个大文件拆成的小文件个数
*/
private static final int SPLIT_COUNT = 40;

/**
* 总的小文件个数
*/
private static final int TOTAL_CHUNK_COUNT = FILE_PATHS.length * SPLIT_COUNT;

private static final ExecutorService POOL = Executors.newFixedThreadPool(16);

private static void splitFiles() {
CountDownLatch latch = new CountDownLatch(TOTAL_CHUNK_COUNT);

AtomicInteger count = new AtomicInteger(0);
for (String filePath : FILE_PATHS) {
// 多线程读文件
POOL.submit(new Runnable() {
@Override
public void run() {
splitSingleFile(filePath, new SplitCallback() {
@Override
public void onSplit(List<Integer> lines) {
POOL.submit(new Runnable() {
@Override
public void run() {
try {
int index = count.incrementAndGet();
sortChunk(lines, index);
writeChunk(lines, index);
} catch (IOException ignored) {
} finally {
latch.countDown();
}
}
});
}
});
}
});
}

try {
latch.await();
} catch (InterruptedException e) {
throw new RuntimeException(e);
}

service.shutdown();
}

/**
* 大文件切割成若干个小文件
*/
private static void splitSingleFile(String filePath, SplitCallback callback) {
List<Integer> lines = new ArrayList<>(CHUNK_SIZE);
try (BufferedReader reader = new BufferedReader(new FileReader(filePath))) {
String line;
while ((line = reader.readLine()) != null) {
lines.add(StringUtils.stringToInt(line));
if (lines.size() >= CHUNK_SIZE) {
List<Integer> originLines = lines;
callback.onSplit(originLines);
lines = new ArrayList<>(CHUNK_SIZE);
}
}
} catch (IOException e) {
e.printStackTrace();
}

if (!lines.isEmpty()) {
callback.onSplit(lines);
}

if (DEBUG) {
log("split file " + filePath + " end");
}
}

/**
* 对每个小文件进行排序
*/
private static void sortChunk(List<Integer> lines, int index) throws IOException {
// Collections.sort(lines); // 归并
CollectionUtils.quickSort(lines); // 快排

if (DEBUG) {
log("sort chunk " + index + " end");
}
}

/**
* 将排序后的内容写入临时文件
*/
private static void writeChunk(List<Integer> lines, int index) throws IOException {
String tempFileName = "temp_" + index + ".txt";
try (BufferedWriter writer = new BufferedWriter(new FileWriter(tempFileName))) {
for (Integer line : lines) {
writer.write(Integer.toString(line));
writer.newLine();
}
}

if (DEBUG) {
log("write chunk " + index + " end");
}
}

2、合并临时文件,通过优先队列不断地从临时文件里的每一行取出最小的,然后写入最终的输出文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
private static void mergeSortFiles() throws IOException {
File[] tempFiles = new File[TOTAL_CHUNK_COUNT];
for (int i = 0; i < TOTAL_CHUNK_COUNT; i++) {
tempFiles[i] = new File("temp_" + (i + 1) + ".txt");
}

// 使用优先队列进行合并排序
PriorityQueue<Pair> pq = new PriorityQueue<>(tempFiles.length, Comparator.comparing(Function.identity(), Main::compareLines));

// 将每个临时文件的第一行添加到优先队列中
for (File tempFile : tempFiles) {
BufferedReader br = new BufferedReader(new FileReader(tempFile));
pq.offer(new Pair(br, StringUtils.stringToInt(br.readLine())));
}

// 合并排序后的结果输出到文件
int current = pollAndOffer(pq);
int count = 1;
try (BufferedWriter writer = new BufferedWriter(new FileWriter(OUTPUT_PATH))) {
while (!pq.isEmpty()) {
int line = pollAndOffer(pq);
if (current != line) {
writer.write(current + ":" + count + "\n");
current = line;
count = 1;
} else {
count++;
}
}
writer.write(current + ":" + count + "\n");
writer.flush();
}

if (DEBUG) {
log("merge temp file end");
}

// 删除临时文件
for (File tempFile : tempFiles) {
tempFile.delete();
}
}

private static int pollAndOffer(PriorityQueue<Pair> pq) throws IOException {
Pair pair = pq.poll();
int value = pair.value;
String nextLine = pair.br.readLine();
if (nextLine != null && nextLine.length() > 0) {
pair.value = StringUtils.stringToInt(nextLine);
pq.offer(pair);
} else {
pair.br.close();
}
return value;
}

private static int compareLines(Pair br1, Pair br2) {
return (br1.value < br2.value) ? -1 : ((br1.value == br2.value) ? 0 : 1);
}

位图思路

仔细理解一下题目要求,我们只需要对这四个文件进行排序去重,输出一个有序的文件。我们并不需要知道每个数字出现的次数,只需要知道这个数字是否出现过。那么我们可以使用位图的思想,将每个数字映射到一个bit位上,如果这个数字出现过,那么这个bit位就是1,否则就是0。这样我们就可以用一个bit位来表示一个数字是否出现过,而不需要用一个int来表示一个数字出现的次数。

最后按顺序遍历每一位,如果 bit位是1,就输出对应的数字。

那么在读文件这步就可以优化为以下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
int[] list = new int[2000000000 / 32 + 1];
CountDownLatch latch = new CountDownLatch(filePaths.length * multiple);
for (String filePath : filePaths) {
try (BufferedReader reader = new BufferedReader(new FileReader(filePath))) {
String line;
while ((line = reader.readLine()) != null) {
int value = stringToInt(line);
int index = value >> 5;
int mod = value & 31; // 取模,相当于 value % 32

list[index] |= 1 << mod;
}
} catch (IOException e) {
e.printStackTrace();
}
}

输出文件:

1
2
3
4
5
6
7
8
9
10
try (BufferedWriter writer = new BufferedWriter(new FileWriter(outputPath))) {
for (int i = 0; i < list.length; i++) {
for (int j = 0; j < 32; j++) {
if ((list[i] & (1 << j)) != 0) {
writer.write(Integer.toString((i << 5) + j));
writer.newLine();
}
}
}
}

位图的效率,比常规排序的思想,速度快了非常多。

位图并发冲突问题

使用位图时,如果多个线程同时对同一个bit位进行写操作,就会出现并发冲突问题。比如线程A和线程B同时对不同bit位进行写1操作,这个是非原子操作,可能会出现线程A写完后,线程B写完后,线程A写的1被覆盖的情况。

借鉴 CAS 的思想,我们可以使用 AtomicIntegerArray 来解决这个问题。

或者借鉴 ConcurrentHashMap 的思想,使用分段锁来解决这个问题。例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
/**
* 计数分段锁,锁的个数越大,理论上在并发时锁被占用的几率越小
*/
private static final Object[] LOCKS = new Object[0x1000];

for (int i = 0; i < LOCKS.length; i++) {
LOCKS[i] = new Object();
}

private static void splitFiles() throws Exception {
int[] list = new int[2000000000 / 32 + 1];
CountDownLatch latch = new CountDownLatch(FILE_PATHS.length * SEGMENT_COUNT);
for (String filePath : FILE_PATHS) {
POOL.submit(new Runnable() {
@Override
public void run() {
try (BufferedReader reader = new BufferedReader(new FileReader(filePath))) {
String line;
while ((line = reader.readLine()) != null) {
int value = StringUtils.stringToInt(line);
int index = value >> 5;
int mod = value & 31; // 取模,相当于 value % 32

// 分段锁
synchronized (LOCKS[index & 0xFFF]) {
list[index] |= 1 << mod;
}
}
} catch (IOException e) {
e.printStackTrace();
} finally {
latch.countDown();
}
}
});
}

try {
latch.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
}

并发读单个文件

首先要解决的问题是,如何并发读单个文件,首先想到的是可以使用 RandomAccessFile 来解决这个问题,但实际上 RandomAccessFile
的效率极其低下,如果我们还在 BufferReader 的基础上,可以通过 skip() 方法来跳到指定的位置。从而实现分段并发读单个文件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
long startPos = file.length() / 2; // 假设分为两段,第二段从文件中间开始读
long endPos = file.length();
long readLength = 0; // 已读取的长度
long segmentLength = endPos - startPos; // 第二段的长度

try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
if (startPos > 0) {
// 跳到指定位置
reader.skip(startPos - 1);
char[] chars = new char[1];
reader.read(chars);
if (chars[0] != '\n') {
// 如果指针可能在行中间,跳过这个不完整的行。跳过的这行其实会被上一段读取
String line = reader.readLine();
if (line != null) {
segmentLength -= line.length() + 1;
}
}

String line;
while ((line = reader.readLine()) != null) {
readLength += line.length() + 1;

if (readLength >= segmentLength) {
break;
}
}
}
...
} catch (IOException e) {
e.printStackTrace();
}

实际上 reader.skip() 是有比较大的性能损耗的,因为它是通过不断地跳过字符来实现的。探索了源码发现可以从 BufferReader 构造方法传入的 FileReader 入手。BufferReader 本身并不感知当前读取到的具体位置,而是通过 FileReader 的 FileInputSteam 来实现的。所以我们可以通过反射来获取到 FileInputSteam 的实例,然后调用它的 skip() 方法来实现跳过指定的位置。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
FileReader fileReader = new FileReader(file);
FileInputStream inputStream = null;
try {
Class<?> readerClass = Reader.class;
Field field = readerClass.getDeclaredField("lock");
field.setAccessible(true);
is = (FileInputStream)field.get(fileReader);

BufferedReader reader = new BufferedReader(fileReader)

// 跳到指定位置
is.skip(startPos - 1);
} catch (Exception e) {
e.printStackTrace();
}

并发写文件

由于要求写入的是有序的,那么多线程写入要如何保证有序呢?我们可以多线程,每个线程负责把一段区间内的数字有序写到文件,然后按顺序合并临时文件。这样就可以保证有序了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
/**
* 输出到文件,先并发写入多个临时文件,再合并
*/
private static void output() {
int multiple = 4; // 分段数

List<File> tempFiles = new ArrayList<>();
for (int i = 0; i < multiple; i++) {
File file = new File("temp" + i + ".txt");
tempFiles.add(file);
}

CountDownLatch latch = new CountDownLatch(multiple);

// 批量写入字节长度
int maxLength = 2520;
int maxIndex = maxLength - 20;
char lineBreak = '\n';
char colon = ':';

// 多线程分批写文件
for (int i = 0; i < multiple; i++) {
int fileIndex = i;
POOL.submit(new Runnable() {
@Override
public void run() {
WriteRunnable runnable = null;
// 这里有个细节,第一个文件可以直接写到输出文件,后面的文件先写到临时文件,合并的时候就可以少合并一个文件
File file = fileIndex == 0 ? new File(OUTPUT_PATH) : tempFiles.get(fileIndex);
try (BufferedWriter writer = new BufferedWriter(new FileWriter(file), 128 * 1024)) {
// 生产者-消费者 模式
runnable = new WriteRunnable(writer);
POOL.submit(runnable);

int index = 0;
char[] charArray = new char[maxLength];
int startPos = (int)((long)fileIndex * BIT_ARRAY.length / multiple);
int endPos = (int)(((long)fileIndex + 1) * BIT_ARRAY.length / multiple);
for (int i = startPos; i < endPos; i++) {
if (BIT_ARRAY[i] != 0) {
if (index > maxIndex) {
char[] res = copyOfRange(charArray, 0, index);
// 添加到队列
runnable.queue.offer(res);
index = 0;
}

// 填充数字
index += StringUtils.stringSize(i);
StringUtils.getChars(i, index, charArray);

// 换行
charArray[index++] = lineBreak;
}
}

if (index > 0) {
char[] res = copyOfRange(charArray, 0, index);
runnable.queue.offer(res);
}

while (runnable.queue.size() > 0) {
Thread.yield();
}
} catch (Exception e) {
e.printStackTrace();
} finally {
latch.countDown();

if (runnable != null) {
runnable.run = false;
}
}
}
});
}

try {
latch.await();
} catch (InterruptedException e) {
e.printStackTrace();
}

combineFiles(tempFiles, 1, multiple);

for (File file : tempFiles) {
file.delete();
}
}

/**
* 合并多个文件
*/
private static void combineFiles(List<File> tempFiles, int start, int end) {
if (start >= end) {
return;
}
StringBuilder sb = new StringBuilder();
for (int i = start; i < end; i++) {
sb.append(tempFiles.get(i).getAbsolutePath()).append(" ");
}

try {
List<String> command = new ArrayList<>();
command.add("/bin/sh");
command.add("-c");
command.add("cat " + sb + " >> " + OUTPUT_PATH);

// 合并文件
ProcessBuilder processBuilder = new ProcessBuilder(command);
Process process = processBuilder.start();
process.waitFor();
} catch (Exception e) {
e.printStackTrace();
}
}

static class MyRunnable implements Runnable {

BufferedWriter writer;

Queue<char[]> queue = new ConcurrentLinkedQueue<>();

boolean run = true;

public MyRunnable(BufferedWriter writer) {
this.writer = writer;
}

@Override
public void run() {
try {
while (run) {
char[] str = queue.poll();
if (str != null) {
writer.write(str);
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

上面取巧通过 cat 指令来合并文件,效率比 nio 高一些。也可以通过 nio 的方式来合并文件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
public static void mergeFiles(List<String> fileNames, String outputFileName) {
try {
FileOutputStream outputStream = new FileOutputStream(outputFileName);
FileChannel outputChannel = outputStream.getChannel();

for (String fileName : fileNames) {
Path path = Paths.get(fileName);
FileChannel inputChannel = FileChannel.open(path);

ByteBuffer buffer = ByteBuffer.allocate(1024);

while (inputChannel.read(buffer) > 0) {
buffer.flip();
outputChannel.write(buffer);
buffer.clear();
}

inputChannel.close();
}

outputChannel.close();
} catch (Exception e) {
e.printStackTrace();
}
}

细致优化思路

parseInt() 方法的优化

String.parseInt() 方法里做了很多边界检查,在海量数据的循环里面,这一步的耗时也不容忽视。由于我们知道输入的数字都是正整数,所以可以做一些优化。

1
2
3
4
5
6
7
8
9
10
11
12
public static int stringToInt(String str) {
int num = 0;
int i = 0;
int len = str.length();

while (i < len) {
num = num * 10 + str.charAt(i) - '0';
i++;
}

return num;
}

避免重复的字符串和数字之间的转换

例如写入的时候,直接放在 char[] 数组里,而不是先转换成字符串,BufferWriter write 的时候又会再转换成 char[] 数组。虽然字符串里面也是 char[] 数组,但是会有一些额外的开销。

多线程并发均匀分布

由于文字是有序的,如果我们简单的按写入行数分割,会发现前面的线程因为写入的数字都比较小,写入明显要比后面的线程快,所以我们可以调整参数,让前面的线程写入的行数多一些,尽可能让每个线程的写入时间均匀分布,减少等待时间,例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
for (int i = 0; i < SEGMENT_COUNT; i++) {
int startPos = (int)(getReadSegmentPosition(i - 1, SEGMENT_COUNT) * totalLength);
int endPos = (int)(getReadSegmentPosition(i, SEGMENT_COUNT) * totalLength);

...
}

// 调整参数,让每个线程写入的数据量分布更均匀
private static int[] writeRatios = new int[] {50, 99, 148, 195, 241, 286, 330, 373};

// 获取第 index 个线程读取的结束位置
private static float getWriteSegmentPosition(int index, int length) {
if (index < 0) {
return 0;
}
return (float)writeRatios[index % length] / writeRatios[writeRatios.length - 1];
}

虽然多线程的调度并不可能每次都均匀,但是相比之前的效率提升还是非常明显的。

并发量控制

虽然8核理论上可以同时执行8个线程,但实际上每个线程的执行过程中,可能会有一些 io 等待时间,所以并发数要大于8才能达到充分利用,例如并发数16,但也不能过度提升并发数,过度反而会降低效率,因为线程的切换也是需要时间的。所以在并发量上需要根据服务器具体性能做一些调整。

JVM 参数调优

JVM 参数也会影响效率,例如 -Xms -Xmx -Xmn -XX:SurvivorRatio -XX:NewRatio 等参数,可以根据服务器的具体情况做一些调整。

例如我们可以预估程序跑的时候大概需要多大的堆空间,把 -Xms 设为该数值,降低频繁GC的概率,或者 -Xms 和 -Xmx 设为相同的值,避免程序运行时扩容、缩容堆空间。

在该题目的背景下,如果能降低 GC 带来的影响是比较理想的,因为内存空间足够,我们不需要太快回收垃圾,可以把垃圾回收器设置为串行回收器,降低一些影响。

进一步优化

在 io 读写方式不变的情况下,上述思路比较难有突破的空间了。Java 中 io 方式大概可以被分为三种:普通 IO(字节流、字符流),FileChannel(文件通道),mmap(内存映射)。

FileWriter,FileReader 存在于 java.io 包中,FileChannel 存在于 java.nio 包中,FileChannel 是 NIO 里面的一种,它的底层是通过操作系统的文件通道来实现的,它的效率比普通 IO 高很多,但是它的效率还是比不上 mmap。

由于系统内存足够,4个1G的文件可以直接全部映射到内存里,这样就可以直接操作内存,而不需要通过系统调用来操作文件。

1
2
3
4
5
6
7
8
9
10
11
12
MappedByteBuffer buffer = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size());
int totalLength = buffer.limit();
int currentNumber = 0;
for (int totalIndex=0; totalIndex < totalLength; totalIndex++) {
byte b = buffer.get(totalIndex);
if (b == '\n') {
// 记录数字
currentNumber = 0;
} else {
currentNumber = curNum * 10 + (b - '0');
}
}

打印重复数量

开头提到,在排好序之后,如果还希望输出每个数字出现的次数,那么位图就不行了,因为它只能记录 0和1。系统内存足够的情况下,最多 20亿个数字,我们可以直接创建一个 20亿长度的 byte数组来存储,约暂用2个G的内存,byte 最大可以存储到 127,因为数字是完全随机的,理论上重复数超过 127 的可能性很小。

假设真的重复数超过 127 怎么办?

1、可以使用 short 数组,short 最大可以存储到 32767,但是这样会占用 4G 的内存。

2、我们可以另起一个 HashMap,key 是数字,value 是出现的次数,当超过 127 时,剩下的数量记录在 HashMap 里面,这样就可以解决重复数超过 127 的问题。