1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
// Digit DP over the binary representations of i and j.
//
// Input:  T, then T lines "n m k p".
// Output: for each case, the value
//             sum over 0 <= i < n, 0 <= j < m of max((i xor j) - k, 0)   (mod p),
//         computed as  sum_{i^j >= k} (i^j)  -  k * #{pairs with i^j >= k}.
// NOTE(review): assumes k >= 0 and p >= 1, exactly as the original did.
#include <algorithm>
#include <cstdio>
#include <cstring>

using LL = long long;

// Memo depth: a non-negative long long has at most 63 bits and dp() is
// entered with len in [1, mxl + 1], so 70 slots are enough.
const int N = 70;

LL n, m, k;  // while dp() runs, n and m hold the inclusive bounds n-1 and m-1
int p;       // modulus
int mxl;     // number of bits the DP walks: bit length of max(n-1, m-1, k)

// d[0][...] = completion count, d[1][...] = xor-sum of the remaining bits,
// both mod p; indexed by [len][fn][fm][fk] exactly like vis.
int d[2][N][2][2][2];
bool vis[N][2][2][2];

// Number of binary digits of x (0 for x == 0). Only meaningful for x >= 0.
static int bit_length(LL x) {
    int len = 0;
    for (; x > 0; x >>= 1) ++len;
    return len;
}

// Chooses bits of i and j from the most significant one downwards; call
// number len decides bit (mxl - len).
//   fn / fm : the bits chosen so far for i / j equal those of n / m, so the
//             current bit may not exceed n's / m's bit.
//   fk      : the bits of i^j chosen so far equal those of k, so the current
//             xor bit may not be below k's bit.
// On return, as0 = number of ways to complete the remaining bits and
// as1 = sum over those completions of (i^j) restricted to the remaining
// bits, both reduced mod p.
void dp(int len, int fn, int fm, int fk, int &as0, int &as1) {
    if (len > mxl) {  // every bit fixed: exactly one (empty) completion
        as0 = 1 % p;
        as1 = 0;
        return;
    }
    int &f = d[0][len][fn][fm][fk];
    int &g = d[1][len][fn][fm][fk];
    if (vis[len][fn][fm][fk]) {
        as0 = f;
        as1 = g;
        return;
    }
    vis[len][fn][fm][fk] = true;

    const int bit = mxl - len;
    const int pn = (n >> bit) & 1;
    const int pm = (m >> bit) & 1;
    const int pk = (k >> bit) & 1;
    const LL weight = (1LL << bit) % p;  // value of this bit position, mod p

    for (int i = 0; i <= (fn ? pn : 1); ++i) {
        for (int j = 0; j <= (fm ? pm : 1); ++j) {
            const int x = i ^ j;
            if (fk && x < pk) continue;  // i^j would drop below k
            int pf = 0, pg = 0;
            dp(len + 1, fn && i == pn, fm && j == pm, fk && x == pk, pf, pg);
            // Widen before adding: f + pf can exceed INT_MAX when p > 2^30.
            f = static_cast<int>((static_cast<LL>(f) + pf) % p);
            g = static_cast<int>((static_cast<LL>(g) + weight * x % p * pf + pg) % p);
        }
    }
    as0 = f;
    as1 = g;
}

int main() {
    int T;
    if (std::scanf("%d", &T) != 1) return 0;
    while (T-- > 0) {
        if (std::scanf("%lld %lld %lld %d", &n, &m, &k, &p) != 4) break;
        if (n <= 0 || m <= 0) {
            // No (i, j) pairs at all. Without this guard n-1 / m-1 would be
            // negative and the original bit-length loop never terminated.
            std::printf("0\n");
            continue;
        }
        --n, --m;  // switch to inclusive upper bounds
        std::memset(vis, false, sizeof vis);
        std::memset(d, 0, sizeof d);
        mxl = std::max({bit_length(n), bit_length(m), bit_length(k)});

        int cnt = 0, sum = 0;
        dp(1, 1, 1, 1, cnt, sum);
        const LL ans = (static_cast<LL>(sum) - k % p * cnt % p + p) % p;
        std::printf("%lld\n", ans);  // ans is long long; "%d" here was UB
    }
    return 0;
}
|