洛谷P2408 不同子串个数

原题链接:洛谷P2408

题目描述:给定一个字符串 $s$,求 $s$ 的不同子串的个数。

后缀自动机

$\mathrm{Ans} = \sum_{i=2}^{\mathrm{size}}(maxlen[i]-minlen[i]+1)$,其中 $minlen[i]=maxlen[link[i]]+1$,状态 $1$ 为初始状态(只对应空串,不计入答案)。每个状态中的子串互不相同,且长度连续,不同状态所含的子串也互不相同,因此只要累加每个状态所含的子串个数即可。

参考代码(C++):
#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

typedef long long LL;

// N bounds the input length; M bounds the number of automaton states
// (the tables below are sized for twice the input length).
const int N = 100010;
const int M = N * 2;

int n;        // length of the input string, as read from the input
char str[N];  // the input string, 0-based

// Suffix automaton over a 26-letter alphabet.
// States are numbered from 1; state 1 is the initial state and index 0 is
// used as "no state" (a zero transition or a zero link means "absent").
struct SAM {
    int maxlen[M];     // maxlen[v]: length of the longest string in state v
    int trans[M][26];  // trans[v][c]: target of the c-transition of v, 0 if none
    int link[M];       // link[v]: suffix link of v
    int sz;            // index of the most recently allocated state
    int last;          // state that represents the whole string added so far

    // Clears every table and allocates the initial state.
    SAM() : sz(1), last(1) {
        memset(maxlen, 0, sizeof maxlen);
        memset(trans, 0, sizeof trans);
        memset(link, 0, sizeof link);
    }

    // Appends one character to the string; ch is the letter index (0..25).
    void extend(int ch) {
        const int fresh = ++sz;
        int walk = last;
        maxlen[fresh] = maxlen[walk] + 1;
        last = fresh;

        // Climb the suffix links, adding the missing ch-transitions.
        while (walk != 0 && trans[walk][ch] == 0) {
            trans[walk][ch] = fresh;
            walk = link[walk];
        }

        // Fell off the top: no shorter suffix has a ch-transition.
        if (walk == 0) {
            link[fresh] = 1;
            return;
        }

        const int target = trans[walk][ch];
        if (maxlen[target] == maxlen[walk] + 1) {
            // The existing state can serve as the suffix link directly.
            link[fresh] = target;
            return;
        }

        // Otherwise split `target`: the copy keeps its transitions and link
        // but is capped at length maxlen[walk] + 1.
        const int copy = ++sz;
        maxlen[copy] = maxlen[walk] + 1;
        memcpy(trans[copy], trans[target], sizeof trans[copy]);
        link[copy] = link[target];

        // Redirect the transitions that pointed at `target` to the copy.
        while (walk != 0 && trans[walk][ch] == target) {
            trans[walk][ch] = copy;
            walk = link[walk];
        }
        link[target] = copy;
        link[fresh] = copy;
    }
} sam;

// Reads the length n and the string, feeds the string into the suffix
// automaton and prints the number of distinct substrings.
int main() {
    scanf("%d%s", &n, str);

    // Never process more characters than were actually read: past the
    // terminator `str[i] - 'a'` is negative and would index trans[][] out of
    // bounds. The automaton itself assumes lowercase letters only.
    const int len = static_cast<int>(strlen(str));
    if (n > len) n = len;

    for (int i = 0; i < n; i++) sam.extend(str[i] - 'a');

    LL ans = 0;
    for (int i = 1; i <= sam.sz; i++) {
        // Equals maxlen[i] - minlen[i] + 1, the number of substrings in state i.
        // The initial state (i == 1) contributes 0.
        ans += sam.maxlen[i] - sam.maxlen[sam.link[i]];
    }
    cout << ans << '\n';
    return 0;
}

后缀数组

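$\mathrm{Ans}=\frac{n(n+1)}{2}-\sum_{i=2}^{n}height[i]$,其中 $height[i]$ 表示排名为 $i$ 与排名为 $i-1$ 的两个后缀的最长公共前缀长度。(全部子串的个数减去被重复计数的个数:每个子串都是某个后缀的前缀,排名为 $i$ 的后缀有 $height[i]$ 个前缀已经在排名为 $i-1$ 的后缀中出现过。)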

参考代码(C++):
#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

typedef long long LL;

const int N = 100010;
int n;                // length of the string; characters live in s[1..n]
int sa[N];            // sa[r]: start position of the suffix with rank r
int rk[N];            // rk[i]: rank of the suffix starting at i (1..n once SA() is done)
int cnt[N];           // counting-sort buckets
int oldrk[N << 1];    // ranks of the previous round; twice as long so x + k stays in range
int id[N];            // suffixes ordered by their second key
int px[N];            // px[i]: first key of id[i]
char s[N];            // the string, 1-based
int ht[N];            // ht[r]: LCP length of the suffixes ranked r and r - 1

// True when positions x and y had the same pair (rank at the position, rank
// k positions further) in the previous round. Positions past the end read
// rank 0 from oldrk.
inline bool cmp(int x, int y, int k) {
    return oldrk[x] == oldrk[y] && oldrk[x + k] == oldrk[y + k];
}

// Builds sa[], rk[] and ht[] for s[1..n] by prefix doubling with counting
// sorts, then fills the height array.
//
// Reads the globals n and s; overwrites sa, rk, ht and the scratch arrays
// cnt, oldrk, id, px. It may be called repeatedly: all state that a previous
// call could have left behind is cleared first.
void SA() {
    int i, p = 0, k, m = 300;  // m: current upper bound on rank values (300 > any byte value)

    // The algorithm relies on sa[0], rk[0], ht[0..1] and every rank past
    // position n being 0.
    memset(sa, 0, sizeof sa);
    memset(rk, 0, sizeof rk);
    memset(oldrk, 0, sizeof oldrk);
    memset(ht, 0, sizeof ht);
    if (n <= 0) return;

    // Round 0: sort the suffixes by their first character. The cast keeps
    // the bucket index non-negative where plain char is signed.
    memset(cnt, 0, sizeof cnt);
    for (i = 1; i <= n; i++) cnt[rk[i] = static_cast<unsigned char>(s[i])]++;
    for (i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
    for (i = n; i >= 1; i--) sa[cnt[rk[i]]--] = i;

    // Doubling rounds. The body runs at least once so that rk[] always ends
    // up holding ranks (also for n == 1), and stops as soon as all ranks are
    // distinct.
    k = 1;
    do {
        // Order by second key: suffixes with no second half come first,
        // the rest follow in the order of the previous round.
        for (p = 0, i = n; i > n - k; i--) id[++p] = i;
        for (i = 1; i <= n; i++) {
            if (sa[i] > k) id[++p] = sa[i] - k;
        }

        // Stable counting sort by first key.
        memset(cnt, 0, sizeof cnt);
        for (i = 1; i <= n; i++) cnt[px[i] = rk[id[i]]]++;
        for (i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
        for (i = n; i >= 1; i--) sa[cnt[px[i]]--] = id[i];

        // Re-rank: neighbours in sa[] with equal key pairs share a rank.
        memcpy(oldrk, rk, sizeof rk);
        for (p = 0, i = 1; i <= n; i++) {
            rk[sa[i]] = cmp(sa[i], sa[i - 1], k) ? p : ++p;
        }

        if (p == n) break;  // every suffix already has its own rank
        m = p;
        k <<= 1;
    } while (k < n);

    // Height array: ht[rk[i]] is at least ht[rk[i - 1]] - 1, so the scan
    // resumes from there instead of restarting at 0.
    for (i = 1; i <= n; i++) {
        if (rk[i] == 1) continue;  // the smallest suffix has no predecessor
        const int j = sa[rk[i] - 1];
        int h = ht[rk[i - 1]];
        if (h > 0) h--;
        // Both suffixes are bounded explicitly, so the result does not
        // depend on what is stored in s[] after position n.
        while (i + h <= n && j + h <= n && s[i + h] == s[j + h]) h++;
        ht[rk[i]] = h;
    }
}

// Reads the length n and the string (stored 1-based), builds the suffix
// array and prints the number of distinct substrings.
int main() {
    scanf("%d%s", &n, s + 1);

    // Keep n consistent with the characters actually read; otherwise the
    // bytes after the terminator would be counted as part of the string.
    const int len = static_cast<int>(strlen(s + 1));
    if (n > len) n = len;
    if (n < 0) n = 0;

    SA();

    // n(n+1)/2 substrings in total, minus those shared with the previous
    // suffix in sorted order.
    LL ans = 1ll * (n + 1) * n / 2;
    for (int i = 2; i <= n; i++) ans -= ht[i];
    cout << ans << '\n';
    return 0;
}
updated 2020-05-11
加载评论
点击刷新