如题,已知一个数列,你需要进行下面三种操作:
将某区间每一个数乘上 x x x
将某区间每一个数加上 x x x
求出某区间每一个数的和
第一行包含三个整数 n , m , p n,m,p n,m,p,分别表示该数列数字的个数、操作的总个数和模数。
第二行包含 n n n 个用空格分隔的整数,其中第 i i i 个数字表示数列第 i i i 项的初始值。
接下来 m m m 行每行包含若干个整数,表示一个操作,具体如下:
操作
1
1
1: 格式:1 x y k
含义:将区间
[
x
,
y
]
[x,y]
[x,y] 内每个数乘上
k
k
k
操作
2
2
2: 格式:2 x y k
含义:将区间
[
x
,
y
]
[x,y]
[x,y] 内每个数加上
k
k
k
操作
3
3
3: 格式:3 x y
含义:输出区间
[
x
,
y
]
[x,y]
[x,y] 内每个数的和对
p
p
p 取模所得的结果
输出包含若干行整数,即为所有操作 3 3 3 的结果。
5 5 38
1 5 4 2 3
2 1 4 1
3 2 5
1 2 4 2
2 3 5 5
3 1 4
17
2
【数据范围】
对于
30
%
30\%
30% 的数据:
n
≤
8
n \le 8
n≤8,
m
≤
10
m \le 10
m≤10
对于
70
%
70\%
70% 的数据:$n \le 10^3
,
,
, m \le 10^4$
对于
100
%
100\%
100% 的数据:$ n \le 10^5
,
,
, m \le 10^5$
除样例外, p = 571373 p = 571373 p=571373
样例说明:
故输出应为 17 17 17、 2 2 2( 40 m o d 38 = 2 40 \bmod 38 = 2 40mod38=2 )
s[pos].add = (s[pos].add + k) % mod;
s[pos].sum = (s[pos].sum + k * (s[pos].r - s[pos].l + 1)) % mod;
这里就有点不一样了。
先把 mul
和 sum
乘上 k
。
对于之前已经有的 add
,把它乘上 k
即可。在这里,我们把乘之后的值直接更新add的值。
你想, add
其实应该加到 sum
里面,所有乘上 k
后,运用乘法分配律, (sum + add) * k == sum * k + add * k
。
这样来实现 add
和 sum
有序进行。
s[pos].add = (s[pos].add * k) % mod;
s[pos].mul = (s[pos].mul * k) % mod;
s[pos].sum = (s[pos].sum * k) % mod;
现在要下传两个标记: add
和 mul
。
sum
:因为 add
之前已经乘过,所以在子孩子乘过 mul
后直接加就行。
mul
:直接乘。
add
:因为 add
的值是要包括乘之后的值,所以子孩子要先乘上 mul
。
s[pos << 1].sum = (s[pos << 1].sum * s[pos].mul + s[pos].add * (s[pos << 1].r - s[pos << 1].l + 1)) % mod;
s[pos << 1].mul = (s[pos << 1].mul * s[pos].mul) % mod;
s[pos << 1].add = (s[pos << 1].add * s[pos].mul + s[pos].add) % mod;
在此注释: <<
和 |
是位运算,n << 1 == n * 2
,n << 1 | 1 == n * 2 + 1
(再具体的自己百度)。
import java.io.*;
class Node {
int l;
int r;
long sum;
long add;
long mul;
public Node(int l, int r, long sum, long add, long mul) {
this.l = l;
this.r = r;
this.sum = sum;
this.add = add;
this.mul = mul;
}
}
public class Main {
static BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(System.out));
static int MAXN = 100010;
static int[] a = new int[MAXN];
static Node[] s = new Node[MAXN * 4];
static int n;
static int m;
static int mod;
public static void main(String[] args) throws IOException {
Read read = new Read();
String[] s0 = read.getStringLine().split(" ");
n = Integer.parseInt(s0[0]);
m = Integer.parseInt(s0[1]);
mod = Integer.parseInt(s0[2]);
String[] s2 = read.getStringLine().split(" ");
for (int i = 1; i <= n; i++) {
a[i] = Integer.parseInt(s2[i - 1]);
}
for (int i = 1; i < s.length; i++) {
s[i] = new Node(0,0,0,0,1);
}
buildTree(1, 1, n);
for (int i = 1; i <= m; i++) {
int opt;
int x;
int y;
String[] si = read.getStringLine().split(" ");
opt = Integer.parseInt(si[0]);
x = Integer.parseInt(si[1]);
y = Integer.parseInt(si[2]);
if (opt == 1) {
int k = Integer.parseInt(si[3]);
ChangeMul(1, x, y, k);
} else if (opt == 2) {
int k = Integer.parseInt(si[3]);
ChangeAdd(1, x, y, k);
} else if (opt == 3) {
writer.write(AskRange(1, x, y) + "\n");
}
}
writer.flush();
writer.close();
}
static void update(int pos) {
s[pos].sum = (s[pos << 1].sum + s[pos << 1 | 1].sum) % mod;
}
static void pushdown(int pos) {
s[pos << 1].sum = (s[pos << 1].sum * s[pos].mul + s[pos].add * (s[pos << 1].r - s[pos << 1].l + 1)) % mod;
s[pos << 1 | 1].sum = (s[pos << 1 | 1].sum * s[pos].mul + s[pos].add * (s[pos << 1 | 1].r - s[pos << 1 | 1].l + 1)) % mod;
s[pos << 1].mul = (s[pos << 1].mul * s[pos].mul) % mod;
s[pos << 1 | 1].mul = (s[pos << 1 | 1].mul * s[pos].mul) % mod;
s[pos << 1].add = (s[pos << 1].add * s[pos].mul + s[pos].add) % mod;
s[pos << 1 | 1].add = (s[pos << 1 | 1].add * s[pos].mul + s[pos].add) % mod;
s[pos].add = 0;
s[pos].mul = 1;
}
static void buildTree(int pos, int l, int r) { //建树
s[pos].l = l;
s[pos].r = r;
s[pos].mul = 1;
if (l == r) {
s[pos].sum = a[l] % mod;
return;
}
int mid = (l + r) >> 1;
buildTree(pos << 1, l, mid);
buildTree(pos << 1 | 1, mid + 1, r);
update(pos);
}
static void ChangeMul(int pos, int x, int y, int k) { //区间乘法
if (x <= s[pos].l && s[pos].r <= y) {
s[pos].add = (s[pos].add * k) % mod;
s[pos].mul = (s[pos].mul * k) % mod;
s[pos].sum = (s[pos].sum * k) % mod;
return;
}
pushdown(pos);
int mid = (s[pos].l + s[pos].r) >> 1;
if (x <= mid) {
ChangeMul(pos << 1, x, y, k);
}
if (y > mid) {
ChangeMul(pos << 1 | 1, x, y, k);
}
update(pos);
return;
}
static void ChangeAdd(int pos, int x, int y, int k) { //区间加法
if (x <= s[pos].l && s[pos].r <= y) {
s[pos].add = (s[pos].add + k) % mod;
s[pos].sum = (s[pos].sum + (long) k * (s[pos].r - s[pos].l + 1)) % mod;
return;
}
pushdown(pos);
int mid = (s[pos].l + s[pos].r) >> 1;
if (x <= mid) {
ChangeAdd(pos << 1, x, y, k);
}
if (y > mid) {
ChangeAdd(pos << 1 | 1, x, y, k);
}
update(pos);
return;
}
static long AskRange(int pos, int x, int y) { //区间询问
if (x <= s[pos].l && s[pos].r <= y) {
return s[pos].sum;
}
pushdown(pos);
long val = 0;
int mid = (s[pos].l + s[pos].r) >> 1;
if (x <= mid) {
val = (val + AskRange(pos << 1, x, y)) % mod;
}
if (y > mid) {
val = (val + AskRange(pos << 1 | 1, x, y)) % mod;
}
return val;
}
}
class Read {
BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
StreamTokenizer st = new StreamTokenizer(new InputStreamReader(System.in));
public int nextInt() throws IOException {
st.nextToken();
return (int) st.nval;
}
public double nextDouble() throws IOException {
st.nextToken();
return st.nval;
}
public String nextString() throws IOException {
st.nextToken();
return st.sval;
}
public String getStringLine() throws IOException {
return reader.readLine();
}
}
参考这位大佬提供的C++语言版本的模板