pytorch/动态图下,如下代码运行正常,但是换到静态图模式下会报错,静态图下该如何实现遍历网络层并提取中间特定某几层的输出?
layers_of_interest = [5, 10, 15, 20]
result = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i in layers_of_interest:
result.append(x)
return result
解答:
如果只是查看网络中间输出的话可以使用print,如果需要取出来中间结果进行计算等操作,可以定义parameter之后将中间层结果赋值到该parameter中去
不要用自己创建的Tuple,可以用ParameterTuple 把想拿到的权重先都放到init里初始化好,后面按照循环赋值的方式进行赋值
类似这样试下,最后拿到的cell_out就是中间层的输出:
可以在 中定义 Cell.__init__(self):
self.cell_out = ParameterTuple([Parameters]) #Parameters 为设置为 需要保存的中间层输出的dtype和shape。
def construct():
layers_of_interest = [5, 10, 15, 20]
result = []
for i, layer in enumerate(self.layers):
x = layer(x)
j = 0
if i in layers_of_interest:
cell_out[j] = x
j += 1
return result