∑ i = 1 n [ a + ⌊ i c ⌋ = b + ⌊ i d ⌋ ] \sum_{i=1}^{n}[a+\lfloor\frac{i}{c}\rfloor=b+\lfloor\frac{i}{d}\rfloor] i=1∑n[a+⌊ci⌋=b+⌊di⌋]
不妨设
c
<
d
c
设式子左右边分别为 f ( i ) , g ( i ) f(i),g(i) f(i),g(i),则答案即为 f ( i ) − g ( i ) = 0 f(i)-g(i)=0 f(i)−g(i)=0 的个数。
不难发现 f ( i ) − g ( i ) f(i)-g(i) f(i)−g(i) 这个值相邻的差的绝对值不超过 1 1 1,所以不会出现 − 1 -1 −1 和 1 1 1 相邻的情况。
不难想到, ⌊ i c ⌋ \lfloor\frac{i}{c}\rfloor ⌊ci⌋ 的增长速度更快,故 f ( i ) − g ( i ) f(i)-g(i) f(i)−g(i) 不严格递增,且当 f ( i ) − g ( i ) = 2 f(i)-g(i)=2 f(i)−g(i)=2 时,后面只可能出现 1 1 1 或更大的数,不可能出现 0 0 0。
故二分出第一个出现的 0 0 0、 1 1 1、 2 2 2 的位置 l , p , r l,p,r l,p,r;
则有 ∀ i ∈ [ l , p − 1 ] \forall i\in[l,p-1] ∀i∈[l,p−1], f ( i ) − g ( i ) ∈ [ − 1 , 0 ] f(i)-g(i)\in[-1,0] f(i)−g(i)∈[−1,0]; ∀ i ∈ [ p , r − 1 ] \forall i\in [p,r-1] ∀i∈[p,r−1], f ( i ) − g ( i ) ∈ [ 0 , 1 ] f(i)-g(i)\in[0,1] f(i)−g(i)∈[0,1]。
O ( 1 ) \mathcal O(1) O(1) 求 ∑ i = l r ⌊ i c ⌋ \sum\limits_{i=l}^{r}\lfloor\frac{i}{c}\rfloor i=l∑r⌊ci⌋ 后用长度减去后加上和,讨论一下即可。
时间复杂度 O ( T log m ) \mathcal O(T\log m) O(Tlogm)。
#include
using namespace std;
#define int long long
typedef long long ll;
#define ha putchar(' ')
#define he putchar('\n')
inline int read() {
int x = 0, f = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-')
f = -1;
c = getchar();
}
while (c >= '0' && c <= '9')
x = x * 10 + c - '0', c = getchar();
return x * f;
}
inline void write(int x) {
if (x < 0) {
putchar('-');
x = -x;
}
if (x > 9)
write(x / 10);
putchar(x % 10 + '0');
}
int m, ax, bx, ay, by;
int calc(int x, int y, int p) {
int l = 0, r = m / x, res = r + 1;
while (l <= r) {
int mid = (l + r) >> 1;
if (mid - mid * x / y >= p) res = mid, r = mid - 1;
else l = mid + 1;
}
return res * x;
}
int S(int n, int x, int c) {
int k = n / c, res = c * k * (k + 1) / 2;
if (n != (k + 1) * c - 1) res = res - k * ((k + 1) * c - 1 - n);
return n * x + res;
}
int sum(int l, int r) {
if (r < l) return 0;
return S(r, ax, bx) - S(l - 1, ax, bx) - S(r, ay, by) + S(l - 1, ay, by);
}
signed main() {
int T = read();
while (T--) {
m = read(), ax = read(), bx = read(), ay = read(), by = read();
if (bx == by) {
if (ax == ay) write(m), he;
else write(0), he;
continue;
}
if (bx > by)
swap(ax, ay), swap(bx, by);
if (ax > ay) {
write(0), he;
continue;
}
int l = calc(bx, by, ay - ax), p = calc(bx, by, ay - ax + 1), r = calc(bx, by, ay - ax + 2);
l = max(l, 1ll);
if (l > m) {
write(0), he;
continue;
}
if (p > m) {
write(m + 1 - l + sum(l + 1, m)), he;
continue;
}
if (r > m) r = m + 1;
write(p - l + sum(l + 1, p - 1) + r - p - sum(p, r - 1)), he;
}
return 0;
}