Numpy中内置的函数diag是一个变化莫测的函数。
这是np.diag函数的源代码:
- def diag(v, k=0):
- v = asanyarray(v)
- s = v.shape
- if len(s) == 1:
- n = s[0]+abs(k)
- res = zeros((n, n), v.dtype)
- if k >= 0:
- i = k
- else:
- i = (-k) * n
- res[:n-k].flat[i::n+1] = v
- return res
- elif len(s) == 2:
- return diagonal(v, k)
- else:
- raise ValueError("Input must be 1- or 2-d.")
我们可以看出np.diag函数可以传入的参数有 v 和 k。
对于v:
v是一个数组。(一维或者二维)
当v是一个一维数组时,结果形成一个以一维数组为对角线元素的矩阵;
当v是一个二维矩阵时,结果输出矩阵的对角线元素。
对于k:
k默认等于零,意味着取对角线,位置不偏移。
如果k > 0,那么取或者放对角线上面第k斜行。
如果k < 0,那么取或者放对角线下面第k斜行。
使用案例帮助理解:
假设现在有这样一个数组array:
- >>> array
- array([[1, 2, 3],
- [4, 5, 6],
- [7, 8, 9]])
v :二维数组,k:0
- >>> np.diag(a)
- array([1, 5, 9])
v:一维数组,k:0
- # 把上面的array([1, 5, 9])作为输入, 即np.diag(array) = [1, 5, 9]
- >>> np.diag(np.diag(a))
- array([[1, 0, 0],
- [0, 5, 0],
- [0, 0, 9]])
v:二维数组,k:1
- >>> np.diag(array, 1)
- array([2, 6])
v:一维数组,k:1
- # 把上面的array([1, 5, 9])作为输入, 即np.diag(array) = [1, 5, 9]
- >>> np.diag(np.diag(array), 1)
- array([0 1 0 0]
- [0 0 5 0]
- [0 0 0 9]
- [0 0 0 0]])