1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
|
class PConvBNActiv(nn.Module):
    """Partial convolution followed by an optional BatchNorm2d and activation.

    Args:
        in_channels: input channel count of the partial convolution.
        out_channels: output channel count (also the BatchNorm2d width).
        bn: when true, a ``BatchNorm2d(out_channels)`` stage is registered as
            ``self.bn``; otherwise no ``bn`` attribute exists.
        sample: convolution geometry. ``'down-7'``, ``'down-5'`` and
            ``'down-3'`` select a stride-2 convolution with kernel 7/5/3 and
            padding 3/2/1. Any other value (including the default
            ``'none-3'``) selects kernel 3, stride 1, padding 1.
        activ: ``'relu'`` registers ``nn.ReLU()``; ``'leaky'`` registers
            ``nn.LeakyReLU(negative_slope=0.2)``. Any other value (e.g.
            ``None``) registers no ``activation`` attribute.
        bias: forwarded to ``PartialConv2d``.
    """

    def __init__(self, in_channels, out_channels, bn=True, sample='none-3', activ='relu', bias=False):
        super(PConvBNActiv, self).__init__()

        # (kernel_size, stride, padding) per recognised sample mode; anything
        # else falls back to a stride-1 3x3 convolution.
        geometry = {
            'down-7': (7, 2, 3),
            'down-5': (5, 2, 2),
            'down-3': (3, 2, 1),
        }
        kernel, step, pad = geometry.get(sample, (3, 1, 1))
        self.conv = PartialConv2d(
            in_channels,
            out_channels,
            kernel_size=kernel,
            stride=step,
            padding=pad,
            bias=bias,
            multi_channel=True,
        )

        if bn:
            self.bn = nn.BatchNorm2d(out_channels)

        # Factories so a module is only instantiated for the selected mode.
        activation_factories = {
            'relu': nn.ReLU,
            'leaky': lambda: nn.LeakyReLU(negative_slope=0.2),
        }
        make_activation = activation_factories.get(activ)
        if make_activation is not None:
            self.activation = make_activation()

    def forward(self, images, masks):
        """Apply the partial convolution, then any configured BN/activation.

        Args:
            images: feature tensor passed to ``self.conv``.
            masks: mask tensor passed to ``self.conv`` alongside ``images``.

        Returns:
            ``(features, masks)``: the convolved features after the optional
            stages, and the second value returned by ``self.conv`` unchanged.
            NOTE(review): this relies on ``PartialConv2d`` returning a 2-tuple;
            if it only returns its mask when asked to (e.g. a ``return_mask``
            flag), confirm that behaviour is enabled.
        """
        features, updated_masks = self.conv(images, masks)
        # ``bn`` and ``activation`` exist only when configured in __init__,
        # so each optional stage runs only if it was registered.
        for stage_name in ('bn', 'activation'):
            if hasattr(self, stage_name):
                features = getattr(self, stage_name)(features)
        return features, updated_masks
class PUNet(nn.Module):
    """Partial-convolution U-Net: a 7-level encoder/decoder with skip connections.

    Each encoder stage is a ``PConvBNActiv`` using a ``'down-*'`` sample mode
    (stride 2). Each decoder stage does three things:

    * upsamples the running features/masks to the spatial size of the
      matching encoder output;
    * concatenates that encoder output along dim 1;
    * applies a stride-1 ``PConvBNActiv``.

    The final decoder stage has no batch norm and no activation. Its output is
    squashed with ``tanh``.

    Args:
        in_channels: channel count of ``input_images``. Because the partial
            convolutions are built with ``multi_channel=True``,
            ``input_masks`` presumably has the same channel count; TODO
            confirm against ``PartialConv2d``.
        out_channels: channel count of the returned tensor.
        up_sampling_node: ``mode`` passed to ``F.interpolate`` in the decoder
            (default ``'nearest'``). The name is kept for backward
            compatibility.
        init_weights: if true, ``self.apply(weights_init())`` is run over
            every submodule.
    """

    def __init__(self, in_channels=3, out_channels=3, up_sampling_node='nearest', init_weights=True):
        super(PUNet, self).__init__()
        # NOTE(review): nothing in this class reads this flag; confirm whether a
        # ``train()`` override (or external code) is expected to honour it.
        self.freeze_ec_bn = False
        self.up_sampling_node = up_sampling_node

        # Encoder. The attribute names become state_dict keys, so they must
        # stay exactly as they are for existing checkpoints to load.
        self.ec_images_1 = PConvBNActiv(in_channels, 64, bn=False, sample='down-7')
        self.ec_images_2 = PConvBNActiv(64, 128, sample='down-5')
        self.ec_images_3 = PConvBNActiv(128, 256, sample='down-5')
        self.ec_images_4 = PConvBNActiv(256, 512, sample='down-3')
        self.ec_images_5 = PConvBNActiv(512, 512, sample='down-3')
        self.ec_images_6 = PConvBNActiv(512, 512, sample='down-3')
        self.ec_images_7 = PConvBNActiv(512, 512, sample='down-3')

        # Decoder stage k consumes the upsampled output of stage k+1 (or of
        # ec_images_7 when k == 7), concatenated with the output of encoder
        # level k-1. Its input width is therefore (previous width + skip width).
        self.dc_images_7 = PConvBNActiv(512 + 512, 512, activ='leaky')
        self.dc_images_6 = PConvBNActiv(512 + 512, 512, activ='leaky')
        self.dc_images_5 = PConvBNActiv(512 + 512, 512, activ='leaky')
        self.dc_images_4 = PConvBNActiv(512 + 256, 256, activ='leaky')
        self.dc_images_3 = PConvBNActiv(256 + 128, 128, activ='leaky')
        self.dc_images_2 = PConvBNActiv(128 + 64, 64, activ='leaky')
        # The skip for the last stage is the raw network input, which has
        # ``in_channels`` channels (not ``out_channels``).
        self.dc_images_1 = PConvBNActiv(64 + in_channels, out_channels, bn=False, sample='none-3', activ=None, bias=True)

        self.tanh = nn.Tanh()

        if init_weights:
            self.apply(weights_init())

    def forward(self, input_images, input_masks):
        """Run the encoder/decoder and return the ``tanh``-activated output.

        Args:
            input_images: tensor fed to ``ec_images_1``; assumed to be
                (N, in_channels, H, W). TODO confirm.
            input_masks: mask tensor paired with ``input_images``. It is passed
                alongside the images through every partial convolution and
                reused as the final skip connection.

        Returns:
            ``tanh`` of the ``dc_images_1`` output (values in (-1, 1)). When
            ``up_sampling_node`` is a size-preserving mode such as
            ``'nearest'``, its spatial size matches ``input_images``.
        """
        # Encoder: level 0 is the raw input and level k is the output of
        # ec_images_k. Every level is kept because the decoder uses it as a
        # skip connection.
        skip_images = [input_images]
        skip_masks = [input_masks]
        for level in range(1, 8):
            encoder = getattr(self, 'ec_images_{:d}'.format(level))
            images, masks = encoder(skip_images[-1], skip_masks[-1])
            skip_images.append(images)
            skip_masks.append(masks)

        # Decoder: walk back up from the bottleneck (level 7) to level 1.
        dc_images, dc_images_masks = skip_images[7], skip_masks[7]
        for level in range(7, 0, -1):
            skip_image = skip_images[level - 1]
            skip_mask = skip_masks[level - 1]
            # Upsample to the skip tensor's exact spatial size rather than by a
            # fixed factor of 2. With the kernel/padding pairs PConvBNActiv
            # uses, each stride-2 stage yields ceil(size / 2) (assuming
            # PartialConv2d follows nn.Conv2d's output-size rule). A fixed x2
            # therefore only lines up when H and W are divisible by 2**7. For
            # such sizes both forms sample with scale 0.5, so results are
            # identical.
            dc_images = F.interpolate(dc_images, size=skip_image.shape[-2:], mode=self.up_sampling_node)
            dc_images_masks = F.interpolate(dc_images_masks, size=skip_mask.shape[-2:], mode=self.up_sampling_node)
            dc_images = torch.cat((dc_images, skip_image), dim=1)
            dc_images_masks = torch.cat((dc_images_masks, skip_mask), dim=1)
            decoder = getattr(self, 'dc_images_{:d}'.format(level))
            dc_images, dc_images_masks = decoder(dc_images, dc_images_masks)

        return self.tanh(dc_images)
|