题意:
给你一个长度为$n$的序列,问有多少字串满足除了中间长度为$b - 1$的子串,前后走势完全一样。
题解:
感觉我的做法好复杂。。。
首先把原序列差分,然后就变成了有多少个字串满足除了中间长度为$b$的子串,前后完全一样。考虑以前后两个一样的串开始的两个后缀$i$和$j$,它们可以贡献答案,就要满足$|i - j| - lcp \le b$,但两个后缀也不能靠太近,否则怎么都不能腾出$b$的空位来,所以也要满足$|i - j| > b$。这样我们就可以对后缀数组分治,每次找到当前区间最小的$height$的位置,确定它两边后缀的lcp,然后用主席树求出由两边的后缀贡献的答案,递归下去搞就好了。注意最小的$height$不一定在中点,所以我们要找两边中较短的一段处理,这样复杂度才有保证。应该是$O(nlog^2n)$的。
代码:
#include <bits/stdc++.h>
using namespace std;
int n, b, a[50010], sa[50010], h[50010], rk[50010], height[50010], lg[50010], ans = 0, mn[16][50010];
map<int, int> lsh;
struct tree
{
int c;
tree *lc, *rc;
tree() : lc(NULL), rc(NULL), c(0) {}
} *root[50010];
void radix(int *a, int *b, int *s, int n, int K)
{
int *hh = new int[K + 1];
for (int i = 0; i <= K; i++)
hh[i] = 0;
for (int i = 0; i < n; i++)
hh[s[a[i]]]++;
for (int i = 1; i <= K; i++)
hh[i] += hh[i - 1];
for (int i = n - 1; i >= 0; i--)
b[--hh[s[a[i]]]] = a[i];
delete[] hh;
}
void get_sa(const int *ss, int *sa, int n, int K)
{
int *s = new int[n << 1], *hh = new int[n];
memset(s, 0, sizeof(int) * (n << 1));
memcpy(s, ss, sizeof(int) * n);
for (int i = 0; i < n; i++)
hh[i] = i;
radix(hh, sa, s, n, K);
for (int len = 1; len < n; len <<= 1)
{
int name = 0, c1 = -1, c2 = -1;
for (int i = 0; i < n; i++)
{
if (s[sa[i]] != c1 || s[sa[i] + (len >> 1)] != c2)
{
name++;
c1 = s[sa[i]], c2 = s[sa[i] + (len >> 1)];
}
hh[sa[i]] = name;
}
swap(s, hh);
for (int i = n - len; i < n; i++)
hh[i - n + len] = i;
for (int i = 0, num = len; i < n; i++)
if (sa[i] >= len)
hh[num++] = sa[i] - len;
radix(hh, sa, s, n, name);
}
}
int get_mnpos(int l, int r)
{
return
height[mn[lg[r - l]][l]] < height[mn[lg[r - l]][r - (1 << lg[r - l])]] ?
mn[lg[r - l]][l]
:
mn[lg[r - l]][r - (1 << lg[r - l])];
}
tree *ins(tree *ori, int l, int r, int p)
{
tree *i = new tree;
if (ori) *i = *ori;
i->c++;
if (l == r) return i;
int md = l + r >> 1;
if (p <= md) i->lc = ins(ori ? ori->lc : NULL, l, md, p);
else i->rc = ins(ori ? ori->rc : NULL, md + 1, r, p);
return i;
}
int get(tree *ori, tree *i, int l, int r, int ll, int rr)
{
if (ll < 0) ll = 0;
if (rr >= n) rr = n - 1;
if (!i) return 0;
if (l == ll && r == rr) return i->c - (ori ? ori->c : 0);
int md = l + r >> 1;
if (rr <= md) return get(ori ? ori->lc : NULL, i->lc, l, md, ll, rr);
else if (ll > md) return get(ori ? ori->rc : NULL, i->rc, md + 1, r, ll, rr);
else return get(ori ? ori->lc : NULL, i->lc, l, md, ll, md) + get(ori ? ori->rc : NULL, i->rc, md + 1, r, md + 1, rr);
}
void wk(int l, int r)
{
if (l >= r) return;
int pos = get_mnpos(l, r);
if (height[pos])
{
if (pos - l + 1 < r - pos)
{
for (int i = l; i <= pos; i++)
{
ans += get(root[pos], root[r], 0, n - 1, sa[i] - height[pos] - b, sa[i] - b - 1);
ans += get(root[pos], root[r], 0, n - 1, sa[i] + b + 1, sa[i] + height[pos] + b);
}
}
else
{
for (int i = pos + 1; i <= r; i++)
{
ans += get(l == 0 ? NULL : root[l - 1], root[pos], 0, n - 1, sa[i] - height[pos] - b, sa[i] - b - 1);
ans += get(l == 0 ? NULL : root[l - 1], root[pos], 0, n - 1, sa[i] + b + 1, sa[i] + height[pos] + b);
}
}
}
wk(l, pos); wk(pos + 1, r);
}
int main()
{
scanf("%d%d", &n, &b);
if (n == 1)
{
puts("0");
return 0;
}
for (int i = 0; i < n; i++)
scanf("%d", &a[i]);
for (int i = 0; i < n - 1; i++)
lsh[a[i] = a[i + 1] - a[i]] = 1;
int nn = 0;
for (map<int, int>::iterator it = lsh.begin(); it != lsh.end(); it++)
it->second = ++nn;
for (int i = 0; i < n - 1; i++)
a[i] = lsh[a[i]];
a[--n] = 0;
get_sa(a, sa, n, nn);
for (int i = 0; i < n; i++)
rk[sa[i]] = i;
sa[n] = n;
for (int i = 0, k = 0; i < n; i++)
{
if (k) k--;
int j = sa[rk[i] + 1];
while (a[i + k] == a[j + k]) k++;
height[rk[i]] = k;
}
for (int i = 0; i < n; i++)
mn[0][i] = i;
lg[1] = 0;
for (int i = 2; i <= n; i++)
lg[i] = lg[i >> 1] + 1;
for (int i = 1; i <= 15; i++)
{
for (int j = 0; j < n; j++)
{
mn[i][j] = mn[i - 1][j];
if (j + (1 << i - 1) < n && height[mn[i - 1][j + (1 << i - 1)]] < height[mn[i][j]]) mn[i][j] = mn[i - 1][j + (1 << i - 1)];
}
}
root[0] = ins(NULL, 0, n - 1, sa[0]);
for (int i = 1; i < n; i++)
root[i] = ins(root[i - 1], 0, n - 1, sa[i]);
wk(0, n - 1);
printf("%d\n", ans);
}