J. N. Tsitsiklis and Z.-Q. Luo, “Communication complexity of convex optimization,” Journal of Complexity, vol. 3, no. 3, pp. 231–243, Sep. 1987, doi: 10.1016/0885-064x(87)90013-6.
问题描述
两个用户各自有一个凸函数fi,相互交互最少的二进制消息,从而找到fi+f2的最优点
基本定义
-
F:定义域[0,1]n上凸函数的一个集合
-
I(f;ϵ)∈[0,1]n:定义域上,给定误差ϵ后f最小值对应的自变量集合(f(x)≤f(y)+ε,∀y∈[0,1]n)
-
C(f1,f2;ϵ,π):在协议π和精度ϵ下,两个函数通过交换信息找到集合I(f1+f2;ε)中元素所需的消息数目
-
C(F;ε,π):该协议在最坏情况下找到目标所需交换的消息数量
C(F;ε,π)=supf1,f2∈FC(f1,f2;ε,π) -
C(F;ε):最优协议下所需的交换消息的数量,又称为ϵ-communication complexity
C(F;ε)=infπ∈I(ε)C(F;ε,π) -
消息传输的模式,通信T次
-
每次传播信息的计算
mi(t)=Mi,t(fi,mj(0),…,mj(t−1)) -
最终最优点的确定
x=Q(f1,m2(0),….,m2(T−1))
-
Straightforward Lower Bound
Lemma 1: If F⊂G then C(F;ε)≤C(G;ε)
简单函数所需传输的消息数量更少
Proposition:C(FQ;ε)≥O(n(logn+log(1/ε)))
其中FQ表示带有f(x)=‖x−x⋆‖2形式的二次函数的集合,其中x⋆∈[0,1]n。根据Lemma知道,选择最简单的函数能找到下界。考虑f1=0,所以f2的最小值需要控制在ϵ1/2的精度内,因此至少需要(An/ε1/2)Bn个半径为ϵ1/2Euclidean ball来覆盖中[0,1]n。因此最终Q的解集的势至少就是(An/ε1/2)Bn。由于函数的值域的势不会超过定义域的势,所以Q的解集的势不超过2T,也就有T≥O(n(logn+log(1/ε))。
Naive Upper Bounds
The method of the centers of gravity (MCG) 在求解凸函数势需要最小次数的梯度计算。将MCG方法扩展到了分布式的场景,得到上界。
一维下的最优算法
算法核心在于用消息指示不同的计算步骤,而不是传递数据。
算法首先定义两个区间,分别表示
- [a,b]:f1+f2最优点所在的区间,x⋆∈[a,b]
- [c,d]:f′(x⋆),f′1(a+b2),f′2(a+b2)所在的区间
以区间[c,d]为基准,分别计算消息m1,m2
- f′1(a+b2)∈[c,c+d2]则m1=0,否则m1=1
- −f′2(a+b2)∈[c,c+d2]则m2=0,否则m2=1
根据消息m1,m2的不同组合,分别缩减区间[a,b]或者[c,d]。缩减的设计总从两个原则
- (f1+f2)′=f′1+f′2,导值的正负性来找最小值
- 通过压缩(f1+f2)′(a+b2)趋于零,从而确定a+b2就是最小值
代码:
import numpy as np
import matplotlib.pyplot as plt
def f1(x):
return (x - 2) ** 2
def df1(x):
return 2 * (x - 2)
def f2(x):
return (x + 1) ** 2
def df2(x):
return 2 * (x + 1)
a, b, c, d = -1, 1, -3, 3
eps = 0.1
while b - a > eps and d - c > eps:
if df1((a + b) / 2) <= (c + d) / 2:
m1 = 0
else:
m1 = 1
if -df2((a + b) / 2) <= (c + d) / 2:
m2 = 0
else:
m2 = 1
if m1 == 0 and m2 == 1:
a = (a + b) / 2
elif m1 == 1 and m2 == 0:
b = (a + b) / 2
elif m1 == 1 and m2 == 1:
c = (c + d) / 2
elif m1 == 0 and m2 == 0:
d = (c + d) / 2
print('传输消息+2')
print(a, b, c, d)
if b - a <= eps:
optimum = a + eps
else:
optimum = f1((a + b) / 2) + f2((a + b) / 2)
print(optimum)
print(f1(0.5) + f2(0.5))
# 直观画图结果
x = np.linspace(-1, 2, 100)
y = f1(x) + f2(x)
plt.plot(x, y)
plt.show()