优化后的Sieve of Eratosthenes筛法

之前写了Sieve of Eratosthenes筛法,但总觉得算法还不够好,因为我们为每个数都留了一个bit,而显然的,除了2以外的偶数都不可能是质数,那么如果我们只为奇数申请bit位来标记的话,效率应该有进一步的提升。

在原来的算法中,我们初始化时是这样的

初始化
Sieve of Eratosthenes筛法初始化

那么,我们真的需要把偶数放进来吗?

Why on earth we need even numbers while we pick primes
Why on earth do we need even numbers while we pick primes

只留下奇数,我们的计算量会少一半

只留下奇数的表
只留下奇数的表

根据Sieve of Eratosthenes筛法,找出100内的质数,只需要循环到sqrt(100)即可。于是我们从3开始循环(为什么不要2呢,因为已经没有其他的偶数了,而且,大家都知道2是唯一一个既是偶数又是质数的数,我们在程序里甚至不必给2保留标志位)

Sieve-3
找出3的倍数并标记

Sieve-5
找出5的倍数并标记

Sieve-7
找出7的倍数并标记

下一个数是11,不必再循环它和比它更大的数了,我们可以验证,此时的确没有还没被标记的11的倍数了

Sieve-11
验证

于是,我们就得到了100以内的所有质数(浅绿色标记)

Primes-in-100
100以内的所有质数

接下来是算法实现。为了方便计算(也是为了和上文一致),我们直接从3开始保存。3的下标是0,5的下标是1,以此类推,我们可以知道数字n与它对应的bit位t的关系是

t = (n - 2) >> 1

然后,我们从i=3开始循环,每次增加2。之后在我们每次从j = i*i开始标记时,因为i一定是奇数, 而j从i*i开始, 那么j也一定是奇数,又因为 奇数+奇数=偶数, 所以我们直接加上2*i来确保得到奇数。

于是修改后的sieve主要部分如下:

// 计算需要多少个字节, (n & 0xF)确保有足够的位
size_t length = n >> CHAR_BIT_LOG;
if (n & 0xF) length += 1;

// 只保存奇数
length = (size_t)ceil(length / 2.0);

result = (char *)malloc(length);
memset(result, 0, length);
for (int i = 3; i <= (int)floor(sqrt(n)); i+=2) {
    if (!isset(result, (i - 2) >> 1)) {
        int j = i*i;
        while (j <= n) {
            setbit(result, (j - 2) >> 1);
            // 因为i一定是奇数, 而j从i*i开始, 那么j也一定是奇数
            // 又因为 奇数+奇数=偶数, 所以我们直接加上2*i来确保得到奇数
           j += 2*i;
        }
    }
}

与先前版本的对比,在n=400000000时,优化后的sieve速度翻了一倍以上

Optimum Sieve of Eratosthenes
Optimum Sieve of Eratosthenes

以下是源代码,

//
//  main.cpp
//  optimum sieve
//
//  Created by Ryza 16/1/28.
//  Copyright © 2016[data deleted]. All rights reserved.
//

#include <iostream>
#include <functional>
#include <chrono>
#include <cmath>
 
using namespace std;
 
#define CHAR_BIT_LOG 3
#define MASK (~(~0 << CHAR_BIT_LOG))
#define setbit(a, x) ((a)[(x) >> CHAR_BIT_LOG] |= 1 << ((x) & MASK))
#define isset(a, x) ((a)[(x) >> CHAR_BIT_LOG] & (1 << ((x) & MASK)))
 
char * sieve(unsigned long long n) {
    char * result = NULL;
    if (n >= 2) {
        // 计算需要多少个字节, 检查(n & 0xF)以确保有足够的位
        size_t length = n >> CHAR_BIT_LOG;
        if (n & 0xF) length += 1;

        result = (char *)malloc(length);
        memset(result, 0, length);
        for (int i = 2; i <= (int)floor(sqrt(n)); ++i) {
            if (!isset(result, i)) {
                int j = i*i;
                while (j <= n) {
                    setbit(result, j);
                    j += i;
                }
            }
        }
    }
    return result;
}

char * sieve2(unsigned long long n) {
    char * result = NULL;
    if (n >= 2) {
        // 计算需要多少个字节, 检查(n & 0xF)以确保有足够的位
        size_t length = n >> CHAR_BIT_LOG;
        if (n & 0xF) length += 1;

        // 只保存奇数
        length = (size_t)ceil(length / 2.0);
        result = (char *)malloc(length);
        memset(result, 0, length);
        for (int i = 3; i <= (int)floor(sqrt(n)); i+=2) {
            if (!isset(result, (i - 2) >> 1)) {
                int j = i*i;
                while (j <= n) {
                    setbit(result, (j - 2) >> 1);
                    // 因为i一定是奇数, 而j从i*i开始, 那么j也一定是奇数
                    // 又因为 奇数+奇数=偶数, 所以我们直接加上2*i来确保得到奇数
                    j += 2*i;
                }
            }
        }
    }
    return result;
}

template <class Return, class Param>
auto runtime(function<Return(Param)> func, Param param) -> Return {
    Return result;
    auto start = chrono::high_resolution_clock::now();
    result = func(param);
    auto end = chrono::high_resolution_clock::now();
    printf("%lf\n", ((chrono::duration<double>)((end - start))).count());
    return result;
}

int main(int argc, const char * argv[]) {
    int upperbound = 42;
    char * result = runtime<char *, unsigned long long>(sieve, upperbound);
    for (int i = 3; i < upperbound; i++) {
        if (!isset(result, i)) {
            printf("%d ", i);
        }
    }
    printf("\n");
    free(result);

    result = runtime<char *, unsigned long long>(sieve2, upperbound);
    for (int i = 0; i < upperbound / 2; i++) {
        size_t prime = ((i<<1) + 3);
        if (!isset(result, i) && (prime < upperbound)) {
            printf("%zu ", prime);
        }
    }
    printf("\n");
    free(result);

    return 0;
}

Leave a Reply

Your email address will not be published. Required fields are marked *

2 × 3 =