• 2022牛客多校训练第二场 J题 Link with Arithmetic Progression


    题目链接

    Link with Arithmetic Progression

    题目大意

    给定一个数组,让我们找到一条拟合它的直线,要求均方误差最小。相信大家在数学课或者机器学习课上都学过相关的知识。

    题解

    T a r g e t = m i n ( ∑ i = 1 n [ a 1 + ( i − 1 ) × d − a i ] 2 ) Target = min(\sum^{n}_{i = 1} [a_1 + (i - 1) \times d - a_i]^2) Target=min(i=1n[a1+(i1)×dai]2)
    这就是我们要求的值,我们需要让其最小,我们可以先考虑把 a 1 a_1 a1求出来,很明显,随着 a 1 a_1 a1从负无穷到正无穷的变化过程中,Target值是先下降后上升的,也就是说 a 1 a_1 a1是一个凹函数,因此我们可以考虑用三分求 a 1 a_1 a1,然后在考虑在 a 1 a_1 a1确定的情况下,如何确定 d d d
    我们对公式进行变形:
    = [ a 1 + ( i − 1 ) × d ] 2 − 2 × [ a 1 + ( i − 1 ) × d ] × a i + a i 2 = [a_1 + (i - 1)\times d]^2 - 2 \times [a_1 + (i - 1) \times d] \times a_i + {a_i}^2 =[a1+(i1)×d]22×[a1+(i1)×d]×ai+ai2
    然后,将带有 d d d的合并同类项。
    = ( i − 1 ) 2 × d 2 + 2 × ( i − 1 ) × ( a 1 − a i ) × d + a 1 2 + a i 2 = (i - 1) ^2 \times d^2 + 2 \times (i - 1) \times (a_1 - a_i) \times d + {a_1}^2 + {a_i}^2 =(i1)2×d2+2×(i1)×(a1ai)×d+a12+ai2
    根据二次函数的知识,我们可以很轻松的确定 d = − b 2 ∗ ( i − 1 ) 2 d = \frac {-b} {2 * (i - 1) ^ 2} d=2(i1)2b
    然后,就可以进行计算了。
    话不多说,上代码:

    代码

    #include
    #include
    #include 
    #include 
    using namespace std;
    #define int long long
    #define double long double
    const int maxn = 100010;
    int w[maxn];
    int n;
    
    namespace GTI
    {
        char gc(void)
           {
            const int S = 1 << 16;
            static char buf[S], *s = buf, *t = buf;
            if (s == t) t = buf + fread(s = buf, 1, S, stdin);
            if (s == t) return EOF;
            return *s++;
        }
        int gti(void)
           {
            int a = 0, b = 1, c = gc();
            for (; !isdigit(c); c = gc()) b ^= (c == '-');
            for (; isdigit(c); c = gc()) a = a * 10 + c - '0';
            return b ? a : -a;
        }
    }
    using GTI::gti;
    double check(double a1)
    {
    	double a = 0, b = 0;
    	for(int i = 1; i <= n; i ++){
    		a += (i - 1) * (i - 1);
    		b += 2 * (i - 1) * (a1 - w[i]);
    	}
    	double d = - b / (2 * a);
    	double res = 0;
    	for(int i = 1; i <= n; i ++){
    		res += (a1 + (i - 1) * d - w[i]) * (a1 + (i - 1) * d - w[i]);
    	}
    	return res;
    }
    signed main()
    {
    	int t; t = gti();
    	while(t --)
    	{
    		n = gti();
    		for(int i = 1; i <= n; i ++) w[i] = gti();
    		double l = -1e10, r = 1e10;
    		while(r - l > 1e-5) {
    			double len = r - l;
    			double mid_l = l + len / 3, mid_r = r - len / 3;
    			if(check(mid_l) >= check(mid_r)) l = mid_l;
    			else r = mid_r;
    		}
    		printf("%.10Lf\n", check(r));
    	}
    	return 0;
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
  • 相关阅读:
    TensorFlow之分类模型-1
    Redis的简介
    洛谷千题详解 | P1014 [NOIP1999 普及组] Cantor 表【C++、Java语言】
    YOLOv5利用Labelimg标注自己数据集
    腾讯云TI平台持续升级,TI-ACC训练加速性能较原生框架提升超30%
    Java实现计算两个日期之间的工作日天数
    网络安全红队详细接收
    工作小记系列2:Kubevirt简介
    1089 不能被3整除的数
    Matlab设置figure中标题/图例英文不同字体
  • 原文地址:https://blog.csdn.net/m0_51171995/article/details/126087968