题意:
$(n + 1) \times (m + 1)$的网格图上,有$K$个点不能用,问剩下的点能组成多少正方形。斜着的也算
题解:
loj上有更棒的样例,还有解释
我们可以先不管不能用的点,然后减去不合法的正方形。正着的正方形比较好搞,对于斜着的,我们可以用一个大正方形框起来,然后可以发现如果确定了大正方形和它上面的一个点,那么我们就可以确定一个斜着的正方形,而且是一一对应的。正着的正方形也可以归到里面。所有正方形显然就是$\displaystyle \sum_{i = 1}^{\min(n, m)} (n - i + 1)(m - i + 1)i$。不合法的正方形的个数可以用容斥来搞,就是-一个点不合法的+两个点不合法的-三个点不合法的+四个点不合法的。一个点不合法的方案就是这个不合法的点在多少正方形上。我们可以分别计算它在上下左右每条边上的方案数,减去四个角上记重的就好了。两个点就$K^2$枚举吧,注意它们作为对角线时的情况。如果确定了两个点,那么第三个和第四个点都可以出来了,就可以顺便计算三个点不合法和四个点不合法的方案数了。复杂度$O(K^2logK)$,套了一个set
来看第三和第四个点是不是不能用的。
代码:
#include <bits/stdc++.h>
using namespace std;
const int mod = 100000007, inv3 = (mod + 1) / 3, inv6 = (mod + 1) / 6;
long long n, m, K, ans = 0;
struct pnt
{
long long x, y;
} a[2010];
set<long long> mp;
long long calc(long long n, long long m, long long pos)
{
long long mn = min(pos, m - pos), mx = max(pos, m - pos);
if (m < n)
return (mn + (1 + mn) * mn / 2 % mod + (mx - mn) * (mn + 1) % mod + (1 + m - mx) * (m - mx) / 2 % mod) % mod;
else
{
if (n < mn) return (2 + n + 1) * n / 2 % mod;
else
{
long long ans = (2 + mn + 1) * mn / 2 % mod;
if (n < mx) return (ans + (mn + 1) * (n - mn) % mod) % mod;
else return (ans + (mn + 1) * (mx - mn) % mod + (m - n + 1 + m - mx) * (n - mx) / 2 % mod) % mod;
}
}
}
int main()
{
scanf("%lld%lld%lld", &n, &m, &K);
for (int i = 1; i <= min(n, m); i++)
ans = (ans + (n - i + 1) * (m - i + 1) % mod * i % mod) % mod;
// printf("%lld\n", ans);
for (int i = 0; i < K; i++)
{
scanf("%lld%lld", &a[i].x, &a[i].y);
(ans -= calc(a[i].x, m, a[i].y)) %= mod;
// printf("%lld\n", ans);
(ans -= calc(m - a[i].y, n, a[i].x)) %= mod;
// printf("%lld\n", ans);
(ans -= calc(n - a[i].x, m, m - a[i].y)) %= mod;
// printf("%lld\n", ans);
(ans -= calc(a[i].y, n, n - a[i].x)) %= mod;
// printf("%lld\n", ans);
(ans += min(a[i].x, a[i].y) + min(a[i].x, m - a[i].y) + min(n - a[i].x, a[i].y) + min(n - a[i].x, m - a[i].y)) %= mod;
// printf("%lld\n", ans);
mp.insert(a[i].x * 1000010 + a[i].y);
}
long long hh3 = 0, hh4 = 0;
for (int i = 0; i < K; i++)
for (int j = i + 1; j < K; j++)
{
long long xx = a[j].x - a[i].x, yy = a[j].y - a[i].y;
long long x1 = a[i].x + yy, y1 = a[i].y - xx, x2 = a[j].x + yy, y2 = a[j].y - xx;
if (x1 >= 0 && x1 <= n && x2 >= 0 && x2 <= n && y1 >= 0 && y1 <= m && y2 >= 0 && y2 <= m)
{
// printf("%lld %lld\n%lld %lld\n%lld %lld\n%lld %lld\n\n", a[i].x, a[i].y, a[j].x, a[j].y, x1, y1, x2, y2);
ans++;
bool tf1 = mp.count(x1 * 1000010 + y1), tf2 = mp.count(x2 * 1000010 + y2);
if (tf1) hh3++;
if (tf2) hh3++;
if (tf1 && tf2) hh4++;
}
x1 = a[i].x - yy, y1 = a[i].y + xx, x2 = a[j].x - yy, y2 = a[j].y + xx;
if (x1 >= 0 && x1 <= n && x2 >= 0 && x2 <= n && y1 >= 0 && y1 <= m && y2 >= 0 && y2 <= m)
{
// printf("%lld %lld\n%lld %lld\n%lld %lld\n%lld %lld\n\n", a[i].x, a[i].y, a[j].x, a[j].y, x1, y1, x2, y2);
ans++;
bool tf1 = mp.count(x1 * 1000010 + y1), tf2 = mp.count(x2 * 1000010 + y2);
if (tf1) hh3++;
if (tf2) hh3++;
if (tf1 && tf2) hh4++;
}
if (abs(xx % 2) == abs(yy % 2))
{
x1 = a[i].x + (xx + yy) / 2, y1 = a[i].y + (yy - xx) / 2, x2 = x1 - yy, y2 = y1 + xx;
if (x1 >= 0 && x1 <= n && x2 >= 0 && x2 <= n && y1 >= 0 && y1 <= m && y2 >= 0 && y2 <= m)
{
// printf("%lld %lld\n%lld %lld\n%lld %lld\n%lld %lld\n\n", a[i].x, a[i].y, a[j].x, a[j].y, x1, y1, x2, y2);
ans++;
bool tf1 = mp.count(x1 * 1000010 + y1), tf2 = mp.count(x2 * 1000010 + y2);
if (tf1) hh3++;
if (tf2) hh3++;
if (tf1 && tf2) hh4++;
}
}
}
printf("%lld\n", ((ans - hh3 * inv3 % mod + hh4 * inv6 % mod) % mod + mod) % mod);
}