import math

import torch
import torch.nn.functional as F
from torch import nn


class PositionEncodingSine2D(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the "Attention Is All You Need" paper, generalized to work on images.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
        super(PositionEncodingSine2D, self).__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, isTarget=False):
        '''
        input x: B, C, H, W
        return pos: B, C, H, W
        isTarget is unused in this implementation.
        '''
        # Cumulative sums over height and width give each position its (row, col) index.
        not_mask = torch.ones(x.size()[0], x.size()[2], x.size()[3]).to(x.device)
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Geometric progression of frequencies, as in the original sinusoidal encoding.
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sine (even dims) and cosine (odd dims), then concatenate the y and x halves.
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

        return pos
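

# Illustrative sketch (hypothetical helper, not used by the model): the encoding depends
# only on the input's shape and device, and returns 2 * num_pos_feats channels in the same
# B, C, H, W layout. The shape values below are examples, not requirements.
def _sketch_position_encoding_shape():
    pe = PositionEncodingSine2D(num_pos_feats=128)    # -> 256 output channels
    pos = pe(torch.randn(2, 256, 30, 30))
    return pos.shape                                  # torch.Size([2, 256, 30, 30])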


class EncoderLayerInnerAttention(nn.Module):
    """
    Transformer encoder layer with intra-image (self) attention: x first cross-attends
    with the auxiliary feature map, then x and y each attend only to their own positions.
    """
    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation, pos_weight, feat_weight):
        super(EncoderLayerInnerAttention, self).__init__()

        self.pos_weight = pos_weight
        self.feat_weight = feat_weight
        self.inner_encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation)
        self.posEncoder = PositionEncodingSine2D(d_model // 2)

        self.cross_encoder_layer = EncoderLayerCrossAttention(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation)

    def forward(self, x, y, featmask, x_mask=None, y_mask=None):
        '''
        input x: B, C, H, W
        input y: B, C, H, W
        input x_mask: B, 1, H, W, mask == True will be ignored
        input y_mask: B, 1, H, W, mask == True will be ignored
        '''
        # Let x cross-attend with the auxiliary feature map before the joint self-attention.
        x = self.cross_encoder_layer(x, featmask, None)[0]

        bx, cx, hx, wx = x.size()
        by, cy, hy, wy = y.size()

        posx = self.posEncoder(x)
        posy = self.posEncoder(y)

        # Weighted sum of content features and positional encodings.
        featx = self.feat_weight * x + self.pos_weight * posx
        featy = self.feat_weight * y + self.pos_weight * posy

        # Flatten to sequences of shape (H*W, B, C), the layout expected by nn.TransformerEncoderLayer.
        featx = featx.flatten(2).permute(2, 0, 1)
        featy = featy.flatten(2).permute(2, 0, 1)
        x_mask = x_mask.flatten(2).squeeze(1) if x_mask is not None else torch.zeros(bx, hx * wx, dtype=torch.bool, device=x.device)
        y_mask = y_mask.flatten(2).squeeze(1) if y_mask is not None else torch.zeros(by, hy * wy, dtype=torch.bool, device=y.device)

        len_seq_x, len_seq_y = featx.size()[0], featy.size()[0]

        output = torch.cat([featx, featy], dim=0)
        src_key_padding_mask = torch.cat((x_mask, y_mask), dim=1)
        with torch.no_grad():
            # Block-diagonal attention mask: True entries are blocked, so each image
            # only attends to its own positions.
            src_mask = torch.ones(hx * wx + hy * wy, hx * wx + hy * wy, dtype=torch.bool, device=x.device)
            src_mask[:hx * wx, :hx * wx] = False
            src_mask[hx * wx:, hx * wx:] = False

        output = self.inner_encoder_layer(output, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)

        outx, outy = output.narrow(0, 0, len_seq_x), output.narrow(0, len_seq_x, len_seq_y)
        outx, outy = outx.permute(1, 2, 0).view(bx, cx, hx, wx), outy.permute(1, 2, 0).view(by, cy, hy, wy)
        x_mask, y_mask = x_mask.view(bx, 1, hx, wx), y_mask.view(by, 1, hy, wy)

        return outx, outy, x_mask, y_mask
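

# Illustrative sketch (hypothetical helper, not used by the model): builds the same
# block-diagonal mask as EncoderLayerInnerAttention.forward, where True entries are
# blocked, so attention stays within each image.
def _sketch_inner_attention_mask(len_x, len_y, device='cpu'):
    mask = torch.ones(len_x + len_y, len_x + len_y, dtype=torch.bool, device=device)
    mask[:len_x, :len_x] = False    # x attends to x
    mask[len_x:, len_x:] = False    # y attends to y
    return mask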


class EncoderLayerCrossAttention(nn.Module):
    """
    Transformer encoder layer with cross-image attention: positions of x only attend to
    positions of y, and vice versa (intra-image attention is masked out).
    """
    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation):
        super(EncoderLayerCrossAttention, self).__init__()

        self.cross_encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation)

    def forward(self, featx, featy, featmask, x_mask=None, y_mask=None):
        '''
        input featx: B, C, H, W
        input featy: B, C, H, W
        input x_mask: B, 1, H, W, mask == True will be ignored
        input y_mask: B, 1, H, W, mask == True will be ignored
        featmask is unused in this layer; it is kept so all encoder layers share one interface.
        '''
        bx, cx, hx, wx = featx.size()
        by, cy, hy, wy = featy.size()

        # Flatten to sequences of shape (H*W, B, C), the layout expected by nn.TransformerEncoderLayer.
        featx = featx.flatten(2).permute(2, 0, 1)
        featy = featy.flatten(2).permute(2, 0, 1)
        x_mask = x_mask.flatten(2).squeeze(1) if x_mask is not None else torch.zeros(bx, hx * wx, dtype=torch.bool, device=featx.device)
        y_mask = y_mask.flatten(2).squeeze(1) if y_mask is not None else torch.zeros(by, hy * wy, dtype=torch.bool, device=featy.device)

        len_seq_x, len_seq_y = featx.size()[0], featy.size()[0]

        output = torch.cat([featx, featy], dim=0)
        src_key_padding_mask = torch.cat((x_mask, y_mask), dim=1)
        with torch.no_grad():
            # Complement of the inner-attention mask: intra-image attention is blocked,
            # only cross-image attention remains.
            src_mask = torch.zeros(hx * wx + hy * wy, hx * wx + hy * wy, dtype=torch.bool, device=featx.device)
            src_mask[:hx * wx, :hx * wx] = True
            src_mask[hx * wx:, hx * wx:] = True

        output = self.cross_encoder_layer(output, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)

        outx, outy = output.narrow(0, 0, len_seq_x), output.narrow(0, len_seq_x, len_seq_y)
        outx, outy = outx.permute(1, 2, 0).view(bx, cx, hx, wx), outy.permute(1, 2, 0).view(by, cy, hy, wy)
        x_mask, y_mask = x_mask.view(bx, 1, hx, wx), y_mask.view(by, 1, hy, wy)

        return outx, outy, x_mask, y_mask
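

# Illustrative sketch (hypothetical helper, not used by the model): the complementary
# mask built in EncoderLayerCrossAttention.forward, which blocks intra-image attention
# so that x only attends to y and y only attends to x.
def _sketch_cross_attention_mask(len_x, len_y, device='cpu'):
    mask = torch.zeros(len_x + len_y, len_x + len_y, dtype=torch.bool, device=device)
    mask[:len_x, :len_x] = True    # block x -> x
    mask[len_x:, len_x:] = True    # block y -> y
    return mask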


class EncoderLayerEmpty(nn.Module):
    """
    Identity layer ('N' in layer_type): returns its inputs unchanged.
    """
    def __init__(self):
        super(EncoderLayerEmpty, self).__init__()

    def forward(self, featx, featy, featmask, x_mask=None, y_mask=None):
        '''
        input featx: B, C, H, W
        input featy: B, C, H, W
        input x_mask: B, 1, H, W, mask == True will be ignored
        input y_mask: B, 1, H, W, mask == True will be ignored
        '''
        return featx, featy, x_mask, y_mask


class EncoderLayerBlock(nn.Module):
    """
    A block of two encoder layers, selected by a pair of layer-type codes:
    'C' = cross-image attention, 'I' = inner (intra-image) attention, 'N' = identity.
    """
    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation, pos_weight, feat_weight, layer_type):
        super(EncoderLayerBlock, self).__init__()

        cross_encoder_layer = EncoderLayerCrossAttention(d_model, nhead, dim_feedforward, dropout, activation)
        att_encoder_layer = EncoderLayerInnerAttention(d_model, nhead, dim_feedforward, dropout, activation, pos_weight, feat_weight)

        if layer_type[0] == 'C':
            self.layer1 = cross_encoder_layer
        elif layer_type[0] == 'I':
            self.layer1 = att_encoder_layer
        elif layer_type[0] == 'N':
            self.layer1 = EncoderLayerEmpty()
        else:
            raise ValueError("Unknown layer type: {}".format(layer_type[0]))

        if layer_type[1] == 'C':
            self.layer2 = cross_encoder_layer
        elif layer_type[1] == 'I':
            self.layer2 = att_encoder_layer
        elif layer_type[1] == 'N':
            self.layer2 = EncoderLayerEmpty()
        else:
            raise ValueError("Unknown layer type: {}".format(layer_type[1]))

    def forward(self, featx, featy, featmask, x_mask=None, y_mask=None):
        '''
        input featx: B, C, H, W
        input featy: B, C, H, W
        input x_mask: B, 1, H, W, mask == True will be ignored
        input y_mask: B, 1, H, W, mask == True will be ignored
        '''
        featx, featy, x_mask, y_mask = self.layer1(featx, featy, featmask, x_mask, y_mask)
        featx, featy, x_mask, y_mask = self.layer2(featx, featy, featmask, x_mask, y_mask)

        return featx, featy, x_mask, y_mask
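

# Illustrative sketch (hypothetical helper, not used by the model): builds one encoder
# block from a pair of layer-type codes; ('I', 'C') means intra-image attention followed
# by cross-image attention. The hyper-parameter values below are examples only.
def _sketch_build_block(layer_type=('I', 'C')):
    return EncoderLayerBlock(d_model=256, nhead=2, dim_feedforward=256, dropout=0.1,
                             activation='relu', pos_weight=0.1, feat_weight=1,
                             layer_type=layer_type)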


class Decoder(nn.Module):
    """
    Lightweight decoder: two stride-2 transposed convolutions followed by 4x bilinear
    upsampling, for an overall 16x spatial upsampling of the feature map.
    """
    def __init__(self, in_channels=256, out_channels=3):
        super(Decoder, self).__init__()

        self.deconv1 = nn.ConvTranspose2d(in_channels, 128, 2, stride=2)
        self.relu1 = nn.ReLU()

        self.deconv2 = nn.ConvTranspose2d(128, out_channels, 2, stride=2)

    def forward(self, x):
        x = self.deconv1(x)    # (B, 128, 2H, 2W)
        x = self.relu1(x)

        x = self.deconv2(x)    # (B, out_channels, 4H, 4W)

        x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)

        return x
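

# Illustrative sketch (hypothetical helper, example shapes): a 30x30 feature map is decoded
# to a 480x480 map, i.e. 16x spatial upsampling (2x, 2x, then 4x bilinear).
def _sketch_decoder_shapes():
    dec = Decoder(in_channels=256, out_channels=3)
    out = dec(torch.randn(1, 256, 30, 30))
    return out.shape    # torch.Size([1, 3, 480, 480])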


class ClsBranch(nn.Module):
    """
    Classification branch: a 3x3 convolution (no padding) to a single channel, followed by
    an MLP. The MLP input size of 28 * 28 assumes a 30 x 30 input feature map.
    """
    def __init__(self, in_dim):
        super(ClsBranch, self).__init__()

        self.conv = nn.Conv2d(in_dim, 1, 3)
        self.relu = nn.ReLU()

        self.mlp = nn.Sequential(*[nn.Linear(28 * 28, 32),
                                   nn.ReLU(),
                                   nn.Linear(32, 1),
                                   nn.Sigmoid()])

    def forward(self, x):
        x = self.conv(x)                      # (B, 1, H - 2, W - 2)
        x = self.relu(x)
        x = torch.flatten(x, start_dim=1)     # (B, (H - 2) * (W - 2)), must equal 28 * 28
        x = self.mlp(x)
        return x
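

# Illustrative sketch (hypothetical helper, example shapes): with a 30x30 feature map, the
# valid 3x3 convolution yields 28x28 = 784 values per sample, matching the MLP input size.
def _sketch_cls_branch_shapes():
    cls_branch = ClsBranch(in_dim=256)
    out = cls_branch(torch.randn(2, 256, 30, 30))
    return out.shape    # torch.Size([2, 1])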


class Encoder(nn.Module):
    """
    Full transformer encoder: feature projection, a stack of encoder blocks, a shared
    decoder for the two reconstruction outputs, and a classification branch.
    """
    def __init__(self, feat_dim, pos_weight=0.1, feat_weight=1, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1, activation='relu', layer_type=['I', 'C', 'I', 'C', 'I', 'C'], drop_feat=0.1):
        super(Encoder, self).__init__()

        self.num_layers = num_layers
        self.feat_proj = nn.Conv2d(feat_dim, d_model, kernel_size=1)
        self.drop_feat = nn.Dropout2d(p=drop_feat)
        # layer_type is consumed two codes per block, so it must contain 2 * num_layers entries.
        self.encoder_blocks = nn.ModuleList([EncoderLayerBlock(d_model, nhead, dim_feedforward, dropout, activation, pos_weight, feat_weight, layer_type[i * 2: i * 2 + 2]) for i in range(num_layers)])

        self.decoder = Decoder(d_model, 3)
        self.cls_branch = ClsBranch(in_dim=d_model)
        self.sigmoid = nn.Sigmoid()
        self.eps = 1e-7

    def forward(self, x, y, fmask, x_mask=None, y_mask=None):
        '''
        input x: B, C, H, W
        input y: B, C, H, W
        input x_mask: B, 1, H, W, mask == True will be ignored
        input y_mask: B, 1, H, W, mask == True will be ignored
        '''
        featx = self.feat_proj(x)
        featx = self.drop_feat(featx)

        bx, cx, hx, wx = featx.size()

        featy = self.feat_proj(y)
        featy = self.drop_feat(featy)

        by, cy, hy, wy = featy.size()

        featmask = self.feat_proj(fmask)

        for i in range(self.num_layers):
            featx, featy, x_mask, y_mask = self.encoder_blocks[i](featx, featy, featmask, x_mask, y_mask)

        out_cls = self.cls_branch(featy)
        outx = self.sigmoid(self.decoder(featx))
        outy = self.sigmoid(self.decoder(featy))

        # Clamp away from 0 and 1 to keep downstream losses (e.g. BCE) numerically stable.
        outx = torch.clamp(outx, min=self.eps, max=1 - self.eps)
        outy = torch.clamp(outy, min=self.eps, max=1 - self.eps)

        return outx, outy, out_cls
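

# Illustrative sketch (hypothetical helper, example hyper-parameters): with 30x30 feature
# maps, so that ClsBranch's 28x28 assumption holds, the encoder returns two 16x-upsampled
# reconstructions of shape (B, 3, 480, 480) and a (B, 1) classification score.
def _sketch_encoder_shapes():
    enc = Encoder(feat_dim=1024, d_model=256, nhead=2, num_layers=3, dim_feedforward=256,
                  layer_type=['I', 'C', 'I', 'C', 'I', 'N'])
    x = torch.randn(2, 1024, 30, 30)
    outx, outy, out_cls = enc(x, x, x)
    return outx.shape, outy.shape, out_cls.shape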


class TransEncoder(nn.Module):
    """
    Transformer encoder wrapper: 'tiny', 'small', 'base' and 'large' variants.
    """
    def __init__(self, feat_dim=1024, pos_weight=0.1, feat_weight=1, dropout=0.1, activation='relu', mode='small', layer_type=['I', 'C', 'I', 'C', 'I', 'N'], drop_feat=0.1):
        super(TransEncoder, self).__init__()

        if mode == 'tiny':
            d_model = 128
            nhead = 2
            num_layers = 3
            dim_feedforward = 256

        elif mode == 'small':
            d_model = 256
            nhead = 2
            num_layers = 3
            dim_feedforward = 256

        elif mode == 'base':
            d_model = 512
            nhead = 8
            num_layers = 3
            dim_feedforward = 2048

        elif mode == 'large':
            d_model = 512
            nhead = 8
            num_layers = 6
            dim_feedforward = 2048

        else:
            raise ValueError("Unknown mode: {}".format(mode))

        self.net = Encoder(feat_dim, pos_weight, feat_weight, d_model, nhead, num_layers, dim_feedforward, dropout, activation, layer_type, drop_feat)

    def forward(self, x, y, fmask, x_mask=None, y_mask=None):
        '''
        input x: B, C, H, W
        input y: B, C, H, W
        '''
        outx, outy, out_cls = self.net(x, y, fmask, x_mask, y_mask)

        return outx, outy, out_cls


if __name__ == '__main__':

    # Minimal smoke test. The 30 x 30 spatial size matches the 28 * 28 assumption made by
    # ClsBranch (3 x 3 convolution without padding).
    feat_dim = 256
    mode = 'small'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    x = torch.randn(2, feat_dim, 30, 30, device=device)
    y = torch.randn(2, feat_dim, 30, 30, device=device)
    fmask = torch.randn(2, feat_dim, 30, 30, device=device)

    net = TransEncoder(feat_dim=feat_dim, mode=mode).to(device)
    print(net)

    with torch.no_grad():
        outx, outy, out_cls = net(x, y, fmask)
    print(outx.shape, outy.shape, out_cls.shape)