tkj
文章143
标签102
分类0
bzoj 4558: [JLoi2016]方

bzoj 4558: [JLoi2016]方

题意:

$(n + 1) \times (m + 1)$的网格图上,有$K$个点不能用,问剩下的点能组成多少正方形。斜着的也算

题解:

loj上有更棒的样例,还有解释
我们可以先不管不能用的点,然后减去不合法的正方形。正着的正方形比较好搞,对于斜着的,我们可以用一个大正方形框起来,然后可以发现如果确定了大正方形和它上面的一个点,那么我们就可以确定一个斜着的正方形,而且是一一对应的。正着的正方形也可以归到里面。所有正方形显然就是$\displaystyle \sum_{i = 1}^{\min(n, m)} (n - i + 1)(m - i + 1)i$。不合法的正方形的个数可以用容斥来搞,就是-一个点不合法的+两个点不合法的-三个点不合法的+四个点不合法的。一个点不合法的方案就是这个不合法的点在多少正方形上。我们可以分别计算它在上下左右每条边上的方案数,减去四个角上记重的就好了。两个点就$K^2$枚举吧,注意它们作为对角线时的情况。如果确定了两个点,那么第三个和第四个点都可以出来了,就可以顺便计算三个点不合法和四个点不合法的方案数了。复杂度$O(K^2logK)$,套了一个set来看第三和第四个点是不是不能用的。

代码:

#include <bits/stdc++.h>
using namespace std;

const int mod = 100000007, inv3 = (mod + 1) / 3, inv6 = (mod + 1) / 6;
long long n, m, K, ans = 0;
struct pnt
{
    long long x, y;
} a[2010];
set<long long> mp;

long long calc(long long n, long long m, long long pos)
{
    long long mn = min(pos, m - pos), mx = max(pos, m - pos);
    if (m < n)
        return (mn + (1 + mn) * mn / 2 % mod + (mx - mn) * (mn + 1) % mod + (1 + m - mx) * (m - mx) / 2 % mod) % mod;
    else
    {
        if (n < mn) return (2 + n + 1) * n / 2 % mod;
        else
        {
            long long ans = (2 + mn + 1) * mn / 2 % mod;
            if (n < mx) return (ans + (mn + 1) * (n - mn) % mod) % mod;
            else return (ans + (mn + 1) * (mx - mn) % mod + (m - n + 1 + m - mx) * (n - mx) / 2 % mod) % mod;
        }
    }
}
int main()
{
    scanf("%lld%lld%lld", &n, &m, &K);
    for (int i = 1; i <= min(n, m); i++)
        ans = (ans + (n - i + 1) * (m - i + 1) % mod * i % mod) % mod;
    // printf("%lld\n", ans);
    for (int i = 0; i < K; i++)
    {
        scanf("%lld%lld", &a[i].x, &a[i].y);
        (ans -= calc(a[i].x, m, a[i].y)) %= mod;
        // printf("%lld\n", ans);
        (ans -= calc(m - a[i].y, n, a[i].x)) %= mod;
        // printf("%lld\n", ans);
        (ans -= calc(n - a[i].x, m, m - a[i].y)) %= mod;
        // printf("%lld\n", ans);
        (ans -= calc(a[i].y, n, n - a[i].x)) %= mod;
        // printf("%lld\n", ans);
        (ans += min(a[i].x, a[i].y) + min(a[i].x, m - a[i].y) + min(n - a[i].x, a[i].y) + min(n - a[i].x, m - a[i].y)) %= mod;
        // printf("%lld\n", ans);
        mp.insert(a[i].x * 1000010 + a[i].y);
    }
    long long hh3 = 0, hh4 = 0;
    for (int i = 0; i < K; i++)
        for (int j = i + 1; j < K; j++)
        {
            long long xx = a[j].x - a[i].x, yy = a[j].y - a[i].y;
            long long x1 = a[i].x + yy, y1 = a[i].y - xx, x2 = a[j].x + yy, y2 = a[j].y - xx;
            if (x1 >= 0 && x1 <= n && x2 >= 0 && x2 <= n && y1 >= 0 && y1 <= m && y2 >= 0 && y2 <= m)
            {
            // printf("%lld %lld\n%lld %lld\n%lld %lld\n%lld %lld\n\n", a[i].x, a[i].y, a[j].x, a[j].y, x1, y1, x2, y2);
                ans++;
                bool tf1 = mp.count(x1 * 1000010 + y1), tf2 = mp.count(x2 * 1000010 + y2);
                if (tf1) hh3++;
                if (tf2) hh3++;
                if (tf1 && tf2) hh4++;
            }
            x1 = a[i].x - yy, y1 = a[i].y + xx, x2 = a[j].x - yy, y2 = a[j].y + xx;
            if (x1 >= 0 && x1 <= n && x2 >= 0 && x2 <= n && y1 >= 0 && y1 <= m && y2 >= 0 && y2 <= m)
            {
            // printf("%lld %lld\n%lld %lld\n%lld %lld\n%lld %lld\n\n", a[i].x, a[i].y, a[j].x, a[j].y, x1, y1, x2, y2);
                ans++;
                bool tf1 = mp.count(x1 * 1000010 + y1), tf2 = mp.count(x2 * 1000010 + y2);
                if (tf1) hh3++;
                if (tf2) hh3++;
                if (tf1 && tf2) hh4++;
            }
            if (abs(xx % 2) == abs(yy % 2))
            {
                x1 = a[i].x + (xx + yy) / 2, y1 = a[i].y + (yy - xx) / 2, x2 = x1 - yy, y2 = y1 + xx;
                if (x1 >= 0 && x1 <= n && x2 >= 0 && x2 <= n && y1 >= 0 && y1 <= m && y2 >= 0 && y2 <= m)
                {
                // printf("%lld %lld\n%lld %lld\n%lld %lld\n%lld %lld\n\n", a[i].x, a[i].y, a[j].x, a[j].y, x1, y1, x2, y2);
                    ans++;
                    bool tf1 = mp.count(x1 * 1000010 + y1), tf2 = mp.count(x2 * 1000010 + y2);
                    if (tf1) hh3++;
                    if (tf2) hh3++;
                    if (tf1 && tf2) hh4++;
                }
            }
        }
    printf("%lld\n", ((ans - hh3 * inv3 % mod + hh4 * inv6 % mod) % mod + mod) % mod);
}
本文作者:tkj
本文链接:https://tkj666.github.io/137/
版权声明:本文采用 CC BY-NC-SA 3.0 CN 协议进行许可