unet_3plus.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddlex.paddleseg.cvlibs import manager
  18. from paddlex.paddleseg.models.layers.layer_libs import SyncBatchNorm
  19. from paddlex.paddleseg.cvlibs.param_init import kaiming_normal_init
  20. @manager.MODELS.add_component
  21. class UNet3Plus(nn.Layer):
  22. """
  23. The UNet3+ implementation based on PaddlePaddle.
  24. The original article refers to
  25. Huang H , Lin L , Tong R , et al. "UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation"
  26. (https://arxiv.org/abs/2004.08790).
  27. Args:
  28. in_channels (int, optional): The channel number of input image. Default: 3.
  29. num_classes (int, optional): The unique number of target classes. Default: 2.
  30. is_batchnorm (bool, optional): Use batchnorm after conv or not. Default: True.
  31. is_deepsup (bool, optional): Use deep supervision or not. Default: False.
  32. is_CGM (bool, optional): Use classification-guided module or not.
  33. If True, is_deepsup must be True. Default: False.
  34. """
  35. def __init__(self,
  36. in_channels=3,
  37. num_classes=2,
  38. is_batchnorm=True,
  39. is_deepsup=False,
  40. is_CGM=False):
  41. super(UNet3Plus, self).__init__()
  42. # parameters
  43. self.is_deepsup = True if is_CGM else is_deepsup
  44. self.is_CGM = is_CGM
  45. # internal definition
  46. self.filters = [64, 128, 256, 512, 1024]
  47. self.cat_channels = self.filters[0]
  48. self.cat_blocks = 5
  49. self.up_channels = self.cat_channels * self.cat_blocks
  50. # layers
  51. self.encoder = Encoder(in_channels, self.filters, is_batchnorm)
  52. self.decoder = Decoder(self.filters, self.cat_channels,
  53. self.up_channels)
  54. if self.is_deepsup:
  55. self.deepsup = DeepSup(self.up_channels, self.filters, num_classes)
  56. if self.is_CGM:
  57. self.cls = nn.Sequential(
  58. nn.Dropout(p=0.5),
  59. nn.Conv2D(self.filters[4], 2, 1),
  60. nn.AdaptiveMaxPool2D(1),
  61. nn.Sigmoid())
  62. else:
  63. self.outconv1 = nn.Conv2D(
  64. self.up_channels, num_classes, 3, padding=1)
  65. # initialise weights
  66. for sublayer in self.sublayers():
  67. if isinstance(sublayer, nn.Conv2D):
  68. kaiming_normal_init(sublayer.weight)
  69. elif isinstance(sublayer, (nn.BatchNorm, nn.SyncBatchNorm)):
  70. kaiming_normal_init(sublayer.weight)
  71. def dotProduct(self, seg, cls):
  72. B, N, H, W = seg.shape
  73. seg = seg.reshape((B, N, H * W))
  74. clssp = paddle.ones([1, N])
  75. ecls = (cls * clssp).reshape([B, N, 1])
  76. final = seg * ecls
  77. final = final.reshape((B, N, H, W))
  78. return final
  79. def forward(self, inputs):
  80. hs = self.encoder(inputs)
  81. hds = self.decoder(hs)
  82. if self.is_deepsup:
  83. out = self.deepsup(hds)
  84. if self.is_CGM:
  85. # classification-guided module
  86. cls_branch = self.cls(hds[-1]).squeeze(3).squeeze(
  87. 2) # (B,N,1,1)->(B,N)
  88. cls_branch_max = cls_branch.argmax(axis=1)
  89. cls_branch_max = cls_branch_max.reshape(
  90. (-1, 1)).astype('float')
  91. out = [self.dotProduct(d, cls_branch_max) for d in out]
  92. else:
  93. out = [self.outconv1(hds[0])] # d1->320*320*num_classes
  94. return out
  95. class Encoder(nn.Layer):
  96. def __init__(self, in_channels, filters, is_batchnorm):
  97. super(Encoder, self).__init__()
  98. self.conv1 = UnetConv2D(in_channels, filters[0], is_batchnorm)
  99. self.poolconv2 = MaxPoolConv2D(filters[0], filters[1], is_batchnorm)
  100. self.poolconv3 = MaxPoolConv2D(filters[1], filters[2], is_batchnorm)
  101. self.poolconv4 = MaxPoolConv2D(filters[2], filters[3], is_batchnorm)
  102. self.poolconv5 = MaxPoolConv2D(filters[3], filters[4], is_batchnorm)
  103. def forward(self, inputs):
  104. h1 = self.conv1(inputs) # h1->320*320*64
  105. h2 = self.poolconv2(h1) # h2->160*160*128
  106. h3 = self.poolconv3(h2) # h3->80*80*256
  107. h4 = self.poolconv4(h3) # h4->40*40*512
  108. hd5 = self.poolconv5(h4) # h5->20*20*1024
  109. return [h1, h2, h3, h4, hd5]
  110. class Decoder(nn.Layer):
  111. def __init__(self, filters, cat_channels, up_channels):
  112. super(Decoder, self).__init__()
  113. '''stage 4d'''
  114. # h1->320*320, hd4->40*40, Pooling 8 times
  115. self.h1_PT_hd4 = nn.MaxPool2D(8, 8, ceil_mode=True)
  116. self.h1_PT_hd4_cbr = ConvBnReLU2D(filters[0], cat_channels)
  117. # h2->160*160, hd4->40*40, Pooling 4 times
  118. self.h2_PT_hd4 = nn.MaxPool2D(4, 4, ceil_mode=True)
  119. self.h2_PT_hd4_cbr = ConvBnReLU2D(filters[1], cat_channels)
  120. # h3->80*80, hd4->40*40, Pooling 2 times
  121. self.h3_PT_hd4 = nn.MaxPool2D(2, 2, ceil_mode=True)
  122. self.h3_PT_hd4_cbr = ConvBnReLU2D(filters[2], cat_channels)
  123. # h4->40*40, hd4->40*40, Concatenation
  124. self.h4_Cat_hd4_cbr = ConvBnReLU2D(filters[3], cat_channels)
  125. # hd5->20*20, hd4->40*40, Upsample 2 times
  126. self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
  127. self.hd5_UT_hd4_cbr = ConvBnReLU2D(filters[4], cat_channels)
  128. # fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
  129. self.cbr4d_1 = ConvBnReLU2D(up_channels, up_channels) # 16
  130. '''stage 3d'''
  131. # h1->320*320, hd3->80*80, Pooling 4 times
  132. self.h1_PT_hd3 = nn.MaxPool2D(4, 4, ceil_mode=True)
  133. self.h1_PT_hd3_cbr = ConvBnReLU2D(filters[0], cat_channels)
  134. # h2->160*160, hd3->80*80, Pooling 2 times
  135. self.h2_PT_hd3 = nn.MaxPool2D(2, 2, ceil_mode=True)
  136. self.h2_PT_hd3_cbr = ConvBnReLU2D(filters[1], cat_channels)
  137. # h3->80*80, hd3->80*80, Concatenation
  138. self.h3_Cat_hd3_cbr = ConvBnReLU2D(filters[2], cat_channels)
  139. # hd4->40*40, hd4->80*80, Upsample 2 times
  140. self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
  141. self.hd4_UT_hd3_cbr = ConvBnReLU2D(up_channels, cat_channels)
  142. # hd5->20*20, hd4->80*80, Upsample 4 times
  143. self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14
  144. self.hd5_UT_hd3_cbr = ConvBnReLU2D(filters[4], cat_channels)
  145. # fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
  146. self.cbr3d_1 = ConvBnReLU2D(up_channels, up_channels) # 16
  147. '''stage 2d '''
  148. # h1->320*320, hd2->160*160, Pooling 2 times
  149. self.h1_PT_hd2 = nn.MaxPool2D(2, 2, ceil_mode=True)
  150. self.h1_PT_hd2_cbr = ConvBnReLU2D(filters[0], cat_channels)
  151. # h2->160*160, hd2->160*160, Concatenation
  152. self.h2_Cat_hd2_cbr = ConvBnReLU2D(filters[1], cat_channels)
  153. # hd3->80*80, hd2->160*160, Upsample 2 times
  154. self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
  155. self.hd3_UT_hd2_cbr = ConvBnReLU2D(up_channels, cat_channels)
  156. # hd4->40*40, hd2->160*160, Upsample 4 times
  157. self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14
  158. self.hd4_UT_hd2_cbr = ConvBnReLU2D(up_channels, cat_channels)
  159. # hd5->20*20, hd2->160*160, Upsample 8 times
  160. self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14
  161. self.hd5_UT_hd2_cbr = ConvBnReLU2D(filters[4], cat_channels)
  162. # fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
  163. self.cbr2d_1 = ConvBnReLU2D(up_channels, up_channels) # 16
  164. '''stage 1d'''
  165. # h1->320*320, hd1->320*320, Concatenation
  166. self.h1_Cat_hd1_cbr = ConvBnReLU2D(filters[0], cat_channels)
  167. # hd2->160*160, hd1->320*320, Upsample 2 times
  168. self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
  169. self.hd2_UT_hd1_cbr = ConvBnReLU2D(up_channels, cat_channels)
  170. # hd3->80*80, hd1->320*320, Upsample 4 times
  171. self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14
  172. self.hd3_UT_hd1_cbr = ConvBnReLU2D(up_channels, cat_channels)
  173. # hd4->40*40, hd1->320*320, Upsample 8 times
  174. self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14
  175. self.hd4_UT_hd1_cbr = ConvBnReLU2D(up_channels, cat_channels)
  176. # hd5->20*20, hd1->320*320, Upsample 16 times
  177. self.hd5_UT_hd1 = nn.Upsample(
  178. scale_factor=16, mode='bilinear') # 14*14
  179. self.hd5_UT_hd1_cbr = ConvBnReLU2D(filters[4], cat_channels)
  180. # fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
  181. self.cbr1d_1 = ConvBnReLU2D(up_channels, up_channels) # 16
  182. def forward(self, inputs):
  183. h1, h2, h3, h4, hd5 = inputs
  184. h1_PT_hd4 = self.h1_PT_hd4_cbr(self.h1_PT_hd4(h1))
  185. h2_PT_hd4 = self.h2_PT_hd4_cbr(self.h2_PT_hd4(h2))
  186. h3_PT_hd4 = self.h3_PT_hd4_cbr(self.h3_PT_hd4(h3))
  187. h4_Cat_hd4 = self.h4_Cat_hd4_cbr(h4)
  188. hd5_UT_hd4 = self.hd5_UT_hd4_cbr(self.hd5_UT_hd4(hd5))
  189. # hd4->40*40*up_channels
  190. hd4 = self.cbr4d_1(
  191. paddle.concat(
  192. [h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4], 1))
  193. h1_PT_hd3 = self.h1_PT_hd3_cbr(self.h1_PT_hd3(h1))
  194. h2_PT_hd3 = self.h2_PT_hd3_cbr(self.h2_PT_hd3(h2))
  195. h3_Cat_hd3 = self.h3_Cat_hd3_cbr(h3)
  196. hd4_UT_hd3 = self.hd4_UT_hd3_cbr(self.hd4_UT_hd3(hd4))
  197. hd5_UT_hd3 = self.hd5_UT_hd3_cbr(self.hd5_UT_hd3(hd5))
  198. # hd3->80*80*up_channels
  199. hd3 = self.cbr3d_1(
  200. paddle.concat(
  201. [h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3], 1))
  202. h1_PT_hd2 = self.h1_PT_hd2_cbr(self.h1_PT_hd2(h1))
  203. h2_Cat_hd2 = self.h2_Cat_hd2_cbr(h2)
  204. hd3_UT_hd2 = self.hd3_UT_hd2_cbr(self.hd3_UT_hd2(hd3))
  205. hd4_UT_hd2 = self.hd4_UT_hd2_cbr(self.hd4_UT_hd2(hd4))
  206. hd5_UT_hd2 = self.hd5_UT_hd2_cbr(self.hd5_UT_hd2(hd5))
  207. # hd2->160*160*up_channels
  208. hd2 = self.cbr2d_1(
  209. paddle.concat([
  210. h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2
  211. ], 1))
  212. h1_Cat_hd1 = self.h1_Cat_hd1_cbr(h1)
  213. hd2_UT_hd1 = self.hd2_UT_hd1_cbr(self.hd2_UT_hd1(hd2))
  214. hd3_UT_hd1 = self.hd3_UT_hd1_cbr(self.hd3_UT_hd1(hd3))
  215. hd4_UT_hd1 = self.hd4_UT_hd1_cbr(self.hd4_UT_hd1(hd4))
  216. hd5_UT_hd1 = self.hd5_UT_hd1_cbr(self.hd5_UT_hd1(hd5))
  217. # hd1->320*320*up_channels
  218. hd1 = self.cbr1d_1(
  219. paddle.concat([
  220. h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1
  221. ], 1))
  222. return [hd1, hd2, hd3, hd4, hd5]
  223. class DeepSup(nn.Layer):
  224. def __init__(self, up_channels, filters, num_classes):
  225. super(DeepSup, self).__init__()
  226. self.convup5 = ConvUp2D(filters[4], num_classes, 16)
  227. self.convup4 = ConvUp2D(up_channels, num_classes, 8)
  228. self.convup3 = ConvUp2D(up_channels, num_classes, 4)
  229. self.convup2 = ConvUp2D(up_channels, num_classes, 2)
  230. self.outconv1 = nn.Conv2D(up_channels, num_classes, 3, padding=1)
  231. def forward(self, inputs):
  232. hd1, hd2, hd3, hd4, hd5 = inputs
  233. d5 = self.convup5(hd5) # 16->256
  234. d4 = self.convup4(hd4) # 32->256
  235. d3 = self.convup3(hd3) # 64->256
  236. d2 = self.convup2(hd2) # 128->256
  237. d1 = self.outconv1(hd1) # 256
  238. return [d1, d2, d3, d4, d5]
  239. class ConvBnReLU2D(nn.Sequential):
  240. def __init__(self, in_channels, out_channels):
  241. super(ConvBnReLU2D, self).__init__(
  242. nn.Conv2D(
  243. in_channels, out_channels, 3, padding=1),
  244. nn.BatchNorm(out_channels),
  245. nn.ReLU())
  246. class ConvUp2D(nn.Sequential):
  247. def __init__(self, in_channels, out_channels, scale_factor):
  248. super(ConvUp2D, self).__init__(
  249. nn.Conv2D(
  250. in_channels, out_channels, 3, padding=1),
  251. nn.Upsample(
  252. scale_factor=scale_factor, mode='bilinear'))
  253. class MaxPoolConv2D(nn.Sequential):
  254. def __init__(self, in_channels, out_channels, is_batchnorm):
  255. super(MaxPoolConv2D, self).__init__(
  256. nn.MaxPool2D(kernel_size=2),
  257. UnetConv2D(in_channels, out_channels, is_batchnorm))
  258. class UnetConv2D(nn.Layer):
  259. def __init__(self,
  260. in_channels,
  261. out_channels,
  262. is_batchnorm,
  263. num_conv=2,
  264. kernel_size=3,
  265. stride=1,
  266. padding=1):
  267. super(UnetConv2D, self).__init__()
  268. self.num_conv = num_conv
  269. for i in range(num_conv):
  270. conv = (nn.Sequential(nn.Conv2D(in_channels, out_channels, kernel_size, stride, padding),
  271. nn.BatchNorm(out_channels),
  272. nn.ReLU()) \
  273. if is_batchnorm else \
  274. nn.Sequential(nn.Conv2D(in_channels, out_channels, kernel_size, stride, padding),
  275. nn.ReLU()))
  276. setattr(self, 'conv%d' % (i + 1), conv)
  277. in_channels = out_channels
  278. # initialise the blocks
  279. for children in self.children():
  280. children.weight_attr = paddle.framework.ParamAttr(
  281. initializer=paddle.nn.initializer.KaimingNormal)
  282. children.bias_attr = paddle.framework.ParamAttr(
  283. initializer=paddle.nn.initializer.KaimingNormal)
  284. def forward(self, inputs):
  285. x = inputs
  286. for i in range(self.num_conv):
  287. conv = getattr(self, 'conv%d' % (i + 1))
  288. x = conv(x)
  289. return x