# shufflenet_slim.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddlex.paddleseg.cvlibs import manager, param_init
  18. from paddlex.paddleseg.models import layers
  19. from paddlex.paddleseg.utils import utils
  20. __all__ = ['ShuffleNetV2']
  21. @manager.MODELS.add_component
  22. class ShuffleNetV2(nn.Layer):
  23. def __init__(self, num_classes, pretrained=None, align_corners=False):
  24. super().__init__()
  25. self.pretrained = pretrained
  26. self.num_classes = num_classes
  27. self.align_corners = align_corners
  28. self.conv_bn0 = _ConvBNReLU(3, 36, 3, 2, 1)
  29. self.conv_bn1 = _ConvBNReLU(36, 18, 1, 1, 0)
  30. self.block1 = nn.Sequential(
  31. SFNetV2Module(
  32. 36, stride=2, out_channels=72),
  33. SFNetV2Module(
  34. 72, stride=1),
  35. SFNetV2Module(
  36. 72, stride=1),
  37. SFNetV2Module(
  38. 72, stride=1))
  39. self.block2 = nn.Sequential(
  40. SFNetV2Module(
  41. 72, stride=2),
  42. SFNetV2Module(
  43. 144, stride=1),
  44. SFNetV2Module(
  45. 144, stride=1),
  46. SFNetV2Module(
  47. 144, stride=1),
  48. SFNetV2Module(
  49. 144, stride=1),
  50. SFNetV2Module(
  51. 144, stride=1),
  52. SFNetV2Module(
  53. 144, stride=1),
  54. SFNetV2Module(
  55. 144, stride=1))
  56. self.depthwise_separable0 = _SeparableConvBNReLU(144, 64, 3, stride=1)
  57. self.depthwise_separable1 = _SeparableConvBNReLU(82, 64, 3, stride=1)
  58. weight_attr = paddle.ParamAttr(
  59. learning_rate=1.,
  60. regularizer=paddle.regularizer.L2Decay(coeff=0.),
  61. initializer=nn.initializer.XavierUniform())
  62. self.deconv = nn.Conv2DTranspose(
  63. 64,
  64. self.num_classes,
  65. 2,
  66. stride=2,
  67. padding=0,
  68. weight_attr=weight_attr,
  69. bias_attr=True)
  70. self.init_weight()
  71. def forward(self, x):
  72. ## Encoder
  73. conv1 = self.conv_bn0(x) # encoder 1
  74. shortcut = self.conv_bn1(conv1) # shortcut 1
  75. pool = F.max_pool2d(
  76. conv1, kernel_size=3, stride=2, padding=1) # encoder 2
  77. # Block 1
  78. conv = self.block1(pool) # encoder 3
  79. # Block 2
  80. conv = self.block2(conv) # encoder 4
  81. ### decoder
  82. conv = self.depthwise_separable0(conv)
  83. shortcut_shape = paddle.shape(shortcut)[2:]
  84. conv_b = F.interpolate(
  85. conv,
  86. shortcut_shape,
  87. mode='bilinear',
  88. align_corners=self.align_corners)
  89. concat = paddle.concat(x=[shortcut, conv_b], axis=1)
  90. decode_conv = self.depthwise_separable1(concat)
  91. logit = self.deconv(decode_conv)
  92. return [logit]
  93. def init_weight(self):
  94. for layer in self.sublayers():
  95. if isinstance(layer, nn.Conv2D):
  96. param_init.normal_init(layer.weight, std=0.001)
  97. elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)):
  98. param_init.constant_init(layer.weight, value=1.0)
  99. param_init.constant_init(layer.bias, value=0.0)
  100. if self.pretrained is not None:
  101. utils.load_pretrained_model(self, self.pretrained)
  102. class _ConvBNReLU(nn.Layer):
  103. def __init__(self,
  104. in_channels,
  105. out_channels,
  106. kernel_size,
  107. stride,
  108. padding,
  109. groups=1,
  110. **kwargs):
  111. super().__init__()
  112. weight_attr = paddle.ParamAttr(
  113. learning_rate=1, initializer=nn.initializer.KaimingUniform())
  114. self._conv = nn.Conv2D(
  115. in_channels,
  116. out_channels,
  117. kernel_size,
  118. padding=padding,
  119. stride=stride,
  120. groups=groups,
  121. weight_attr=weight_attr,
  122. bias_attr=False,
  123. **kwargs)
  124. self._batch_norm = layers.SyncBatchNorm(out_channels)
  125. def forward(self, x):
  126. x = self._conv(x)
  127. x = self._batch_norm(x)
  128. x = F.relu(x)
  129. return x
  130. class _ConvBN(nn.Layer):
  131. def __init__(self,
  132. in_channels,
  133. out_channels,
  134. kernel_size,
  135. stride,
  136. padding,
  137. groups=1,
  138. **kwargs):
  139. super().__init__()
  140. weight_attr = paddle.ParamAttr(
  141. learning_rate=1, initializer=nn.initializer.KaimingUniform())
  142. self._conv = nn.Conv2D(
  143. in_channels,
  144. out_channels,
  145. kernel_size,
  146. padding=padding,
  147. stride=stride,
  148. groups=groups,
  149. weight_attr=weight_attr,
  150. bias_attr=False,
  151. **kwargs)
  152. self._batch_norm = layers.SyncBatchNorm(out_channels)
  153. def forward(self, x):
  154. x = self._conv(x)
  155. x = self._batch_norm(x)
  156. return x
  157. class _SeparableConvBNReLU(nn.Layer):
  158. def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
  159. super().__init__()
  160. self.depthwise_conv = _ConvBN(
  161. in_channels,
  162. out_channels=in_channels,
  163. kernel_size=kernel_size,
  164. padding=int(kernel_size / 2),
  165. groups=in_channels,
  166. **kwargs)
  167. self.piontwise_conv = _ConvBNReLU(
  168. in_channels,
  169. out_channels,
  170. kernel_size=1,
  171. groups=1,
  172. stride=1,
  173. padding=0)
  174. def forward(self, x):
  175. x = self.depthwise_conv(x)
  176. x = self.piontwise_conv(x)
  177. return x
  178. class SFNetV2Module(nn.Layer):
  179. def __init__(self, input_channels, stride, out_channels=None):
  180. super().__init__()
  181. if stride == 1:
  182. branch_channel = int(input_channels / 2)
  183. else:
  184. branch_channel = input_channels
  185. if out_channels is None:
  186. self.in_channels = int(branch_channel)
  187. else:
  188. self.in_channels = int(out_channels / 2)
  189. self._depthwise_separable_0 = _SeparableConvBNReLU(
  190. input_channels, self.in_channels, 3, stride=stride)
  191. self._conv = _ConvBNReLU(
  192. branch_channel, self.in_channels, 1, stride=1, padding=0)
  193. self._depthwise_separable_1 = _SeparableConvBNReLU(
  194. self.in_channels, self.in_channels, 3, stride=stride)
  195. self.stride = stride
  196. def forward(self, input):
  197. if self.stride == 1:
  198. shortcut, branch = paddle.split(x=input, num_or_sections=2, axis=1)
  199. else:
  200. branch = input
  201. shortcut = self._depthwise_separable_0(input)
  202. branch_1x1 = self._conv(branch)
  203. branch_dw1x1 = self._depthwise_separable_1(branch_1x1)
  204. output = paddle.concat(x=[shortcut, branch_dw1x1], axis=1)
  205. # channel shuffle
  206. out_shape = paddle.shape(output)
  207. h, w = out_shape[2], out_shape[3]
  208. output = paddle.reshape(x=output, shape=[0, 2, self.in_channels, h, w])
  209. output = paddle.transpose(x=output, perm=[0, 2, 1, 3, 4])
  210. output = paddle.reshape(
  211. x=output, shape=[0, 2 * self.in_channels, h, w])
  212. return output
  213. if __name__ == '__main__':
  214. import numpy as np
  215. import os
  216. np.random.seed(100)
  217. paddle.seed(100)
  218. net = ShuffleNetV2(10)
  219. img = np.random.random(size=(4, 3, 100, 100)).astype('float32')
  220. img = paddle.to_tensor(img)
  221. out = net(img)
  222. print(out)
  223. net.forward = paddle.jit.to_static(net.forward)
  224. save_path = os.path.join('.', 'model')
  225. in_var = paddle.ones([4, 3, 100, 100])
  226. paddle.jit.save(net, save_path, input_spec=[in_var])