我们先来看下np.split的实现方法:
- @array_function_dispatch(_split_dispatcher)
- def split(ary, indices_or_sections, axis=0):
- try:
- len(indices_or_sections)
- except TypeError:
- sections = indices_or_sections
- N = ary.shape[axis]
- if N % sections:
- raise ValueError(
- 'array split does not result in an equal division') from None
- return array_split(ary, indices_or_sections, axis)
当然有兴趣的可以继续看array_split的具体操作方法。
从split的定义可以看到,参数是(数组,数或数组,维度)返回值是列表,里面的每组元素是数组。看示例:
参数是整数
- import numpy as np
-
- a=np.arange(10)
- np.split(a,5)
- #[array([0, 1]), array([2, 3]), array([4, 5]), array([6, 7]), array([8, 9])]
将0-9,均分成5份。如果不能被整除,比如是4,将出现错误:
ValueError: array split does not result in an equal division
参数是列表,按照里面每个值来分段
- a=np.arange(10)
- np.split(a,[4,8])
- #[array([0, 1, 2, 3]), array([4, 5, 6, 7]), array([8, 9])]
变形成二维数组来切分
- a=np.arange(40).reshape(8,5)
- np.split(a,4)
- '''
- [array([[0, 1, 2, 3, 4],
- [5, 6, 7, 8, 9]]),
- array([[10, 11, 12, 13, 14],
- [15, 16, 17, 18, 19]]),
- array([[20, 21, 22, 23, 24],
- [25, 26, 27, 28, 29]]),
- array([[30, 31, 32, 33, 34],
- [35, 36, 37, 38, 39]])]
- '''
axis=1的结果
- a=np.arange(40).reshape(5,8)
- np.split(a,4,axis=1)
- '''
- [array([[ 0, 1],
- [ 8, 9],
- [16, 17],
- [24, 25],
- [32, 33]]),
- array([[ 2, 3],
- [10, 11],
- [18, 19],
- [26, 27],
- [34, 35]]),
- array([[ 4, 5],
- [12, 13],
- [20, 21],
- [28, 29],
- [36, 37]]),
- array([[ 6, 7],
- [14, 15],
- [22, 23],
- [30, 31],
- [38, 39]])]
- '''
接下来的nd.split的用法跟np.split虽然用法很像,还是存在一些区别需要注意。
还是贴下nd.split的方法:
- def split(data=None, num_outputs=_Null, axis=_Null, squeeze_axis=_Null, out=None, name=None, **kwargs):
- return (0,)
跟np.split的区别就是必须指定axis,不然会报错。
- from mxnet import nd
-
- a=nd.arange(40).reshape(8,5)
- nd.split(a,4,axis=0)
-
- '''
- [
- [[0. 1. 2. 3. 4.]
- [5. 6. 7. 8. 9.]]
-
, -
- [[10. 11. 12. 13. 14.]
- [15. 16. 17. 18. 19.]]
-
, -
- [[20. 21. 22. 23. 24.]
- [25. 26. 27. 28. 29.]]
-
, -
- [[30. 31. 32. 33. 34.]
- [35. 36. 37. 38. 39.]]
-
] - '''
三维的例子亦如是,如:切分第二维切4份
- a=nd.arange(40).reshape(2,4,5)
- nd.split(a,4,axis=1)
- '''
- [
- [[[ 0. 1. 2. 3. 4.]]
-
- [[20. 21. 22. 23. 24.]]]
-
, -
- [[[ 5. 6. 7. 8. 9.]]
-
- [[25. 26. 27. 28. 29.]]]
-
, -
- [[[10. 11. 12. 13. 14.]]
-
- [[30. 31. 32. 33. 34.]]]
-
, -
- [[[15. 16. 17. 18. 19.]]
-
- [[35. 36. 37. 38. 39.]]]
-
] - '''
每个元素的形状是nd.split(a,4,axis=1)[1].shape #(2,1,5)
除了参数名称不一样,个数也不一样,比如squeeze_axis这个新增的参数,可以减掉一维。
- a=nd.arange(40).reshape(2,4,5)
- nd.split(a,2,axis=0,squeeze_axis=1)
- '''
- [
- [[ 0. 1. 2. 3. 4.]
- [ 5. 6. 7. 8. 9.]
- [10. 11. 12. 13. 14.]
- [15. 16. 17. 18. 19.]]
-
, -
- [[20. 21. 22. 23. 24.]
- [25. 26. 27. 28. 29.]
- [30. 31. 32. 33. 34.]
- [35. 36. 37. 38. 39.]]
-
] - '''
看出有什么不同了吗?少了一维,本来里面每个元素是
再看一例:
- a=nd.arange(40).reshape(2,4,5)
- nd.split(a,4,axis=1,squeeze_axis=1)
- '''
- [
- [[ 0. 1. 2. 3. 4.]
- [20. 21. 22. 23. 24.]]
-
, -
- [[ 5. 6. 7. 8. 9.]
- [25. 26. 27. 28. 29.]]
-
, -
- [[10. 11. 12. 13. 14.]
- [30. 31. 32. 33. 34.]]
-
, -
- [[15. 16. 17. 18. 19.]
- [35. 36. 37. 38. 39.]]
-
] - '''
如果没有squeeze_axis=1这个参数,里面的元素形状是
所以这个其实就是将所在切分的维,有且仅有1,那么就减掉这个维度。这个其实是有意义的,毕竟属于没数据的占着空的维度,可以去掉。