60.差分约束

差分约束

差分约束系统 是一种特殊的 n 元一次不等式组,它包含 n 个变量 x1,x2,,xn 以及 m 个约束条件,每个约束条件是由两个其中的变量做差构成的,形如 xixjck,其中 1i,jn,ij,1km 并且 ck 是常数(可以是非负数,也可以是负数)

我们要解决的问题是:求一组解 x1=a1,x2=a2,,xn=an,使得所有的约束条件得到满足,否则判断出无解

很多题目会给出(或隐性地给出)一系列的不等关系,我们可以尝试把它们转化为差分约束系统来解决。

我们设 x1x2c ,移项得 x1x2+c

观察这个不等式与最短路问题中的三角形不等式 dist[u]dist[v]+wu,v 的相似之处,利用这一点,我们可以把它转化为一个图论问题。也就是说,对于每一个 xixjc 我们从 xjxi 建一有向条边,边权为 c

这样建出的有向图,它的每个顶点都对应差分约束系统中的一个未知量,源点到每个顶点的最短路对应这些未知量的值,而每条边对应一个约束条件

image.png

那么问题来了,既然是最短路,源点在哪里呢?

实际上取哪个点为源点是无关紧要的,但是,有时候我们得到的图不是连通的,这样求出来的结果很容易出现INF。为了避免这种情形,我们习惯人为地增加一个超级源点

例如我们现在人为地新增一个 0 号点(或 n+1 号点),从它向所有顶点连一条边权为 0 的边

image.png

现在我们以 0 号点为源点求各点的最短路即可。注意,这相当于了添加了以下约束条件:

{x1x00x2x00x3x00

由于 x0 对应的是 dist[0] ,而 dist[0]=0 可知所有未知量均小于等于 0

因为这样求出来的只是一组解,显然,如果 {x1,x2,,xn} 是一组解,那么对于任意常数 d{x1+d,x2+d,,xn+d} 也是一组解

那么如果题目要求 x1,x2,,xnw 呢,那么可以把 dist[0] 设为 w 或者把从 0 号结点连向各点的边权设为 w,事实上,可以证明他们是满足 x1,x2,,xnw 的最大解(每个变量能取到的最大值)

那么如何求满足 x1,x2,,xnw 的最小解呢?只需要求最长路就行了。最长路满足三角形不等式 dist[u]dist[v]+wu,v,所以差分约束系统需要把小于等于换成大于等于。对于 SPFA 算法来说,需要初始化为 INF 而不是 INF,然后把比较符号颠倒一下即可

洛谷 P5960 【模板】差分约束

#include <bits/stdc++.h>
using namespace std;
typedef pair<int, int> pii;
const int INF = 0x3f3f3f3f;

int main() {
    int n, m; cin >> n >> m;
    vector<vector<pii>> g(n + 1);
    for (int i = 1; i <= m; i++) {
        int x, y, c; cin >> x >> y >> c;
        g[y].push_back({x, c});
    }
    for (int i = 1; i <= n; i++) g[0].push_back({i, 0});

    vector<int> dis(n + 1, INF), vis(n + 1, 0), cnt(n + 1, 0);
    
    auto spfa = [&]() -> bool {
        queue<int> q;
        q.push(0); vis[0] = 1; dis[0] = 0; cnt[0] = 1;
        while (!q.empty()) {
            int u = q.front(); q.pop(); vis[u] = 0;
            for (auto [v, w] : g[u]) {
                if (dis[v] > dis[u] + w) {
                    dis[v] = dis[u] + w;
                    if (!vis[v]) {
                        q.push(v); vis[v] = 1;
                        if (++cnt[v] > n) return false;
                    }
                }
            }
        }
        return true;
    };

    if (spfa()) {
        for (int i = 1; i <= n; i++) 
            cout << dis[i] << ' ';
    }
    else 
        printf ("NO\n");
    return 0;
}