絶滅

どうでもいい

Project Euler - Problem 251

Cardano Triplets

Problem 251 (日本語訳)

3 個の正整数の組 (a,b,c) が次の式を満たすときこれをカルダノトリプレット(Cardano Triplet) と呼ぶ:

$$\sqrt[3]{ a + b \sqrt{c} } + \sqrt[3]{ a - b \sqrt{c} } = 1$$

例えば, (2,1,5) はカルダノトリプレットである.

a+b+c ≤ 1000 を満たすカルダノトリプレットは 149 ある.

a+b+c ≤ 110,000,000 を満たすカルダノトリプレットはいくつあるか.

(日本語訳より)




簡単そうに見えて全然解けなかったけどメモリと時間をジャブジャブ使ったら解けた.




冒頭の式をこねると根号がはずれる.

$\displaystyle b^{2} c - a^{2} = \left( \frac{ 2a - 1 }{ 3 } \right)^{3} $

右辺は整数にならないといけないので

$ a = 3 k - 1 \ (a = 1, 2, ...) $

とおける.よって問題は

$ b^{2} c = k^{2} ( 8 k - 3 ) $
$ 3 k + b + c \leq 110000001 $

を満たす自然数の組 (k, b, c) を求めることと言い換えることができる.

プロットしてみる

f:id:my316g:20180205041414p:plain

緑とオレンジの交差部分における格子点の数が答えになる

どうやら k の upper bound は 1e7 と 2e7 の間くらいにあって,微分とか色々やると 16224760 くらいだとわかる.

k2 (8 k - 3) のそれぞれについて b2 で割れるか試していると日が暮れるので,素因数分解をして b の候補を絞っていく.

たとえば k = 6 であれば k2 (8 k - 3) = 62 * 45 = 22 * 34 * 5 = b2 c で,これを満たす b として考えられるのは b = 1, 2, 3, 6, 9, 18 である.

というわけで k とか 8 k - 3 とかの約数を求める必要があるが,エラトステネスっぽく上手くやる

map<int, int> fac[16224760];

みたいなヤバイ配列を用意して,約数をどんどん突っ込む

突っ込んだら各 fac[k] に対して簡単な再帰関数で全ての b を求めてやって,不等式を満たすなら解としてカウントしていく.

typedef long long ll;
typedef pair<int, int> pii;

#define AMAX 16224760
#define AAMAX (AMAX << 3)
#define MAX 110000000

bitset<AAMAX> p;
bitset<AAMAX> a8;
map<int, int> fac[AMAX];

ll ans = 0;

void dfs(vector<pii>& v, int pos, const int& a, ll b) {

    if (pos == v.size()) {
        ll apb = 3 * a - 1 + b;
        ll c = 1;
        if (apb + c > MAX) return;
        for (int i = 0; i < v.size(); i++) {
            pii& p = v[i];
            for (int j = 0; j < p.second; j++) {
                c *= p.first;
                if (apb + c > MAX) return;
            }
        }

        ans++;
        return;
    }

    for (int i = 0; i <= v[pos].second; i += 2) {
        v[pos].second -= i;
        dfs(v, pos + 1, a, b);
        v[pos].second += i;
        b *= v[pos].first;
        if (3 * a + b > MAX) return;
    }
}

void triplet(map<int, int>& f, const int& a) {
    vector<pii> v;
    for (const auto& e : f) {
        v.push_back(pii(e.first, e.second));
    }
    dfs(v, 0, a, 1);
}

int main()
{
    clock_t start, end;
    start = clock();

    //cin.tie(0);
    //ios::sync_with_stdio(false);

    p.flip();
    p[0] = false; p[1] = false;
    for (int i = 2; i * i < AAMAX; i++) {
        if (p[i]) {
            for (int j = i * i; j < AAMAX; j += i) {
                p[j] = false;
            }
        }
    }

    for (int a = 1; a < AMAX; a++) a8[(a << 3) - 3] = true;
    
    int i;
    // k^2
    for (i = 2; i * i < AMAX; i++) {
        if (p[i]) {
            for (ll j = i; j < AMAX; j *= i) {
                for (int k = j; k < AMAX; k += j) {
                    fac[k][i] += 2;
                }
            }
        }
    }
    for (; i < AMAX; i++) {
        if (p[i]) {
            for (int k = i; k < AMAX; k += i) {
                fac[k][i] += 2;
            }
        }
    }

    // 8k - 3
    for (i = 3; i * i < AAMAX; i++) {
        if (p[i]) {
            for (ll j = i; j < AAMAX; j *= i) {
                for (int k = j; k < AAMAX; k += j) {
                    if (a8[k]) fac[(k + 3) >> 3][i]++;
                }
            }
        }
    }
    for (; i < AAMAX; i++) {
        if (p[i]) {
            for (int k = i; k < AAMAX; k += i) {
                if (a8[k]) fac[(k + 3) >> 3][i]++;
            }
        }
    }

    for (int a = 1; a < AMAX; a++) {
        triplet(fac[a], a);
        if (!(a & 65535)) cout << a << " " << ans << endl;
    }

    cout << ans << endl;

    end = clock();
    printf("%d msec.\n", end - start);

    return 0;
}

メモリ 3.5GB,実行時間 110 sec の最悪なプログラムができた.どう解くのが正解なんだろう




(2/05 18:48 追記)

Window 幅を 220 程度にして区間ふるい的な方法で約数列挙したら使用メモリが 256MB 程度まで減った.時間は 120 sec ほどと少し遅くなったが……

#include "bits/stdc++.h"

using namespace std;
typedef long long ll;

#define AMAX (1 << 24)
#define AAMAX (AMAX << 4)
#define MAX 110000000
#define WINDOW (1 << 20)

bitset<AAMAX> p;
bitset<AAMAX> a8;
vector<map<int, int>> fac;

ll ans = 0;

//区間ふるい : [L, R) における a^2 (8a - 3) の約数を列挙する
void seg_sieve(vector<map<int, int>>& f, ll L, ll R) {

    f.clear();
    f.resize(R - L);

    // a^2
    ll i;
    for (i = 0; i * i < R; i++) {
        if (p[i]) {
            for (ll j = i; j < R; j *= i) {
                ll q = (L + j - 1) / j * j;
                for (ll k = q; k < R; k += j) {
                    f[k - L][i] += 2;
                }
            }
        }
    }
    for (; i < R; i++) {
        if (p[i]) {
            ll q = (L + i - 1) / i * i;
            for (ll k = q; k < R; k += i) {
                f[k - L][i] += 2;
            }
        }
    }

    // 8k - 3
    ll L8 = 8 * L - 3;
    ll R8 = 8 * R - 3;
    for (i = 3; i * i < R8; i++) {
        if (p[i]) {
            for (ll j = i; j < R8; j *= i) {
                ll q = (L8 + j - 1) / j * j;
                for (int k = q; k < R8; k += j) {
                    if (a8[k]) {
                        fac[((k + 3) >> 3) - L][i]++;
                    }
                }
            }
        }
    }
    for (; i < R8; i++) {
        if (p[i]) {
            ll q = (L8 + i - 1) / i * i;
            for (int k = q; k < R8; k += i) {
                if (a8[k]) fac[((k + 3) >> 3) - L][i]++;
            }
        }
    }
}

void dfs(vector<pii>& v, int pos, const int& a, ll b) {

    if (pos == v.size()) {
        ll apb = 3 * a - 1 + b;
        ll c = 1;
        if (apb + c > MAX) return;
        for (int i = 0; i < v.size(); i++) {
            pii& p = v[i];
            for (int j = 0; j < p.second; j++) {
                c *= p.first;
                if (apb + c > MAX) return;

            }
        }

        ans++;
        return;
    }

    for (int i = 0; i <= v[pos].second; i += 2) {
        v[pos].second -= i;
        dfs(v, pos + 1, a, b);
        v[pos].second += i;
        b *= v[pos].first;
        if (3 * a + b > MAX) return;
    }
}

void triplet(map<int, int>& f, const int& a) {
    vector<pii> v;
    for (const auto& e : f) {
        v.push_back(pii(e.first, e.second));
    }
    dfs(v, 0, a, 1);
}

int main()
{
    clock_t start, end;
    start = clock();

    //cin.tie(0);
    //ios::sync_with_stdio(false);

    // sieve
    p.flip();
    p[0] = false; p[1] = false;
    for (int i = 2; i * i < AAMAX; i++)  if (p[i]) for (int j = i * i; j < AAMAX; j += i) p[j] = false;

    // 8a - 3 flag
    for (int a = 1; a < AMAX; a++) a8[(a << 3) - 3] = true;

    // enum divisor (seg_sieve)
    for (ll L = 1; L < AMAX; L += WINDOW) {
        ll R = L + WINDOW;
        seg_sieve(fac, L, R);
        // get b and c
        for (int a = L; a < min(L + WINDOW, (ll)AMAX); a++) {
            triplet(fac[a - L], a);
        }
        cout << L + WINDOW << " " << ans << endl;
    }

    cout << ans << endl;

    end = clock();
    printf("%d msec.\n", end - start);

    return 0;
}