import torch
from torch import nn
import torch.nn.functional as F
import math
### --- Positional encoding --- ###
### --- Borrowed from DETR --- ###
class PositionEncodingSine2D(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None):
super(PositionEncodingSine2D, self).__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
    def forward(self, x, isTarget=False):
        '''
        input x: B, C, H, W (isTarget is unused; kept for interface compatibility)
        return pos: B, C, H, W
        '''
        not_mask = torch.ones(x.size(0), x.size(2), x.size(3), device=x.device)
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
## no diff between source and target
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
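        # geometric frequency schedule temperature^(2*floor(i/2)/num_pos_feats), as in the
        # sinusoidal encoding of "Attention is all you need": channel pairs share a
        # frequency, with sin on even channels and cos on odd channels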
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
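        # pos has 2 * num_pos_feats channels ([y encodings; x encodings]); the encoder
        # below builds this module with num_pos_feats = d_model // 2 so pos can be
        # added directly to the d_model-dimensional features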
return pos

class EncoderLayerInnerAttention(nn.Module):
    """
    Encoder block that first cross-attends x with the feature mask, then runs
    self-attention separately within x and within y (no attention across the two).
    """
def __init__(self, d_model, nhead, dim_feedforward, dropout, activation, pos_weight, feat_weight):
super(EncoderLayerInnerAttention, self).__init__()
self.pos_weight = pos_weight
self.feat_weight = feat_weight
self.inner_encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation = activation)
self.posEncoder = PositionEncodingSine2D(d_model // 2)
self.cross_encoder_layer = EncoderLayerCrossAttention(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward,
dropout=dropout, activation=activation)
    def forward(self, x, y, featmask, x_mask=None, y_mask=None):
        '''
        input x: B, C, H, W
        input y: B, C, H, W
        input featmask: B, C, Hm, Wm, feature-mask tokens that x cross-attends to
        input x_mask: B, 1, H, W, positions with mask == True are ignored
        input y_mask: B, 1, H, W, positions with mask == True are ignored
        '''
# do cross attention on x and featmask
x = self.cross_encoder_layer(x, featmask, None)[0]
bx, cx, hx, wx = x.size()
by, cy, hy, wy = y.size()
posx = self.posEncoder(x)
posy = self.posEncoder(y)
featx = self.feat_weight * x + self.pos_weight * posx
featy = self.feat_weight * y + self.pos_weight * posy
## input of transformer should be : seq_len * batch_size * feat_dim
featx = featx.flatten(2).permute(2, 0, 1)
featy = featy.flatten(2).permute(2, 0, 1)
        x_mask = x_mask.flatten(2).squeeze(1) if x_mask is not None else torch.zeros(bx, hx * wx, dtype=torch.bool, device=x.device)
        y_mask = y_mask.flatten(2).squeeze(1) if y_mask is not None else torch.zeros(by, hy * wy, dtype=torch.bool, device=y.device)
## input of transformer: (seq_len*2) * batch_size * feat_dim
len_seq_x, len_seq_y = featx.size()[0], featy.size()[0]
output = torch.cat([featx, featy], dim=0)
src_key_padding_mask = torch.cat((x_mask, y_mask), dim=1)
        # attention mask: True = blocked; only positions within the same sequence
        # (x with x, y with y) may attend to each other
        with torch.no_grad():
            src_mask = torch.ones(hx * wx + hy * wy, hx * wx + hy * wy, dtype=torch.bool, device=x.device)
            src_mask[:hx * wx, :hx * wx] = False
            src_mask[hx * wx:, hx * wx:] = False
output = self.inner_encoder_layer(output, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
outx, outy = output.narrow(0, 0, len_seq_x), output.narrow(0, len_seq_x, len_seq_y)
outx, outy = outx.permute(1, 2, 0).view(bx, cx, hx, wx), outy.permute(1, 2, 0).view(by, cy, hy, wy)
        x_mask, y_mask = x_mask.view(bx, 1, hx, wx), y_mask.view(by, 1, hy, wy)
return outx, outy, x_mask, y_mask

class EncoderLayerCrossAttention(nn.Module):
    """
    Encoder block where x attends only to y and y attends only to x
    (cross-attention enforced through a block attention mask).
    """
def __init__(self, d_model, nhead, dim_feedforward, dropout, activation):
super(EncoderLayerCrossAttention, self).__init__()
self.cross_encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation = activation)
    def forward(self, featx, featy, featmask, x_mask=None, y_mask=None):
        '''
        input featx: B, C, H, W
        input featy: B, C, H, W
        input featmask: unused here, kept for a uniform layer interface
        input x_mask: B, 1, H, W, positions with mask == True are ignored
        input y_mask: B, 1, H, W, positions with mask == True are ignored
        '''
bx, cx, hx, wx = featx.size()
by, cy, hy, wy = featy.size()
## input of transformer should be : seq_len * batch_size * feat_dim
featx = featx.flatten(2).permute(2, 0, 1)
featy = featy.flatten(2).permute(2, 0, 1)
        x_mask = x_mask.flatten(2).squeeze(1) if x_mask is not None else torch.zeros(bx, hx * wx, dtype=torch.bool, device=featx.device)
        y_mask = y_mask.flatten(2).squeeze(1) if y_mask is not None else torch.zeros(by, hy * wy, dtype=torch.bool, device=featy.device)
## input of transformer: (seq_len*2) * batch_size * feat_dim
len_seq_x, len_seq_y = featx.size()[0], featy.size()[0]
output = torch.cat([featx, featy], dim=0)
src_key_padding_mask = torch.cat((x_mask, y_mask), dim=1)
        # attention mask: True = blocked; within-sequence attention is disabled so
        # x positions attend only to y positions and vice versa
        with torch.no_grad():
            src_mask = torch.zeros(hx * wx + hy * wy, hx * wx + hy * wy, dtype=torch.bool, device=featx.device)
            src_mask[:hx * wx, :hx * wx] = True
            src_mask[hx * wx:, hx * wx:] = True
output = self.cross_encoder_layer(output, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)
outx, outy = output.narrow(0, 0, len_seq_x), output.narrow(0, len_seq_x, len_seq_y)
outx, outy = outx.permute(1, 2, 0).view(bx, cx, hx, wx), outy.permute(1, 2, 0).view(by, cy, hy, wy)
        x_mask, y_mask = x_mask.view(bx, 1, hx, wx), y_mask.view(by, 1, hy, wy)
return outx, outy, x_mask, y_mask

class EncoderLayerEmpty(nn.Module):
    """
    Identity layer: returns its inputs unchanged (used for the 'N' layer type).
    """
def __init__(self):
super(EncoderLayerEmpty, self).__init__()
    def forward(self, featx, featy, featmask, x_mask=None, y_mask=None):
        '''
        input featx: B, C, H, W
        input featy: B, C, H, W
        input featmask: unused
        input x_mask: B, 1, H, W, positions with mask == True are ignored
        input y_mask: B, 1, H, W, positions with mask == True are ignored
        '''
return featx, featy, x_mask, y_mask

class EncoderLayerBlock(nn.Module):
    """
    Pair of encoder layers selected by layer_type:
    'C' = cross attention, 'I' = inner (self) attention, 'N' = identity.
    """
    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation, pos_weight, feat_weight, layer_type):
        super(EncoderLayerBlock, self).__init__()
        # note: if both entries of layer_type pick the same letter, the two slots share
        # the same module instance (and hence its weights)
        cross_encoder_layer = EncoderLayerCrossAttention(d_model, nhead, dim_feedforward, dropout, activation)
        att_encoder_layer = EncoderLayerInnerAttention(d_model, nhead, dim_feedforward, dropout, activation, pos_weight, feat_weight)
        if layer_type[0] == 'C':
            self.layer1 = cross_encoder_layer
        elif layer_type[0] == 'I':
            self.layer1 = att_encoder_layer
        elif layer_type[0] == 'N':
            self.layer1 = EncoderLayerEmpty()
        else:
            raise ValueError('unknown layer type: {}'.format(layer_type[0]))
        if layer_type[1] == 'C':
            self.layer2 = cross_encoder_layer
        elif layer_type[1] == 'I':
            self.layer2 = att_encoder_layer
        elif layer_type[1] == 'N':
            self.layer2 = EncoderLayerEmpty()
        else:
            raise ValueError('unknown layer type: {}'.format(layer_type[1]))
    def forward(self, featx, featy, featmask, x_mask=None, y_mask=None):
        '''
        input featx: B, C, H, W
        input featy: B, C, H, W
        input featmask: B, C, Hm, Wm, feature-mask tokens (used by the inner-attention layers)
        input x_mask: B, 1, H, W, positions with mask == True are ignored
        input y_mask: B, 1, H, W, positions with mask == True are ignored
        '''
featx, featy, x_mask, y_mask = self.layer1(featx, featy, featmask, x_mask, y_mask)
featx, featy, x_mask, y_mask = self.layer2(featx, featy, featmask, x_mask, y_mask)
return featx, featy, x_mask, y_mask

### --- Decoder, classification head and encoder --- ###
class Decoder(nn.Module):
    """
    Upsamples the encoded feature map back to image resolution with two
    transposed convolutions followed by bilinear interpolation.
    """
    def __init__(self, in_channels=256, out_channels=3):
        super(Decoder, self).__init__()
        self.deconv1 = nn.ConvTranspose2d(in_channels, 128, 2, stride=2)   # 2x upsample (e.g. 30 -> 60)
        self.relu1 = nn.ReLU()
        self.deconv2 = nn.ConvTranspose2d(128, out_channels, 2, stride=2)  # 2x upsample (e.g. 60 -> 120)
def forward(self, x):
x = self.deconv1(x)
x = self.relu1(x)
x = self.deconv2(x)
x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)
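        # overall upsampling: 2 (deconv1) * 2 (deconv2) * 4 (interpolate) = 16x the
        # input feature resolution (e.g. 30x30 features -> 480x480 output)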
return x

class ClsBranch(nn.Module):
    # branch to predict whether the object is present or not
def __init__(self, in_dim):
super(ClsBranch, self).__init__()
self.conv = nn.Conv2d(in_dim, 1, 3)
self.relu = nn.ReLU()
self.mlp = nn.Sequential(*[nn.Linear(28*28, 32),
nn.ReLU(),
nn.Linear(32, 1),
nn.Sigmoid()])
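        # NOTE: the 28*28 flattened size assumes 30x30 input feature maps; the 3x3
        # conv above (no padding) reduces them to 28x28 before flattening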
def forward(self, x):
x = self.conv(x)
x = self.relu(x)
x = torch.flatten(x, start_dim=1)
x = self.mlp(x)
return x

class Encoder(nn.Module):
    """
    Full transformer encoder: feature projection, a stack of encoder blocks,
    a mask decoder and a classification branch.
    """
    def __init__(self, feat_dim, pos_weight=0.1, feat_weight=1, d_model=512, nhead=8, num_layers=6, dim_feedforward=2048, dropout=0.1, activation='relu', layer_type=['I', 'C', 'I', 'C', 'I', 'C'], drop_feat=0.1):
        super(Encoder, self).__init__()
        self.num_layers = num_layers
        self.feat_proj = nn.Conv2d(feat_dim, d_model, kernel_size=1)
        self.drop_feat = nn.Dropout2d(p=drop_feat)
        # layer_type must provide two entries per block, i.e. len(layer_type) == 2 * num_layers
        self.encoder_blocks = nn.ModuleList([EncoderLayerBlock(d_model, nhead, dim_feedforward, dropout, activation, pos_weight, feat_weight, layer_type[i * 2 : i * 2 + 2]) for i in range(num_layers)])
        self.decoder = Decoder(d_model, 3)
        self.cls_branch = ClsBranch(in_dim=d_model)
        self.sigmoid = nn.Sigmoid()
        self.eps = 1e-7
    def forward(self, x, y, fmask, x_mask=None, y_mask=None):
        '''
        input x: B, C, H, W
        input y: B, C, H, W
        input fmask: B, C, Hm, Wm, feature mask projected and passed to the encoder blocks
        input x_mask: B, 1, H, W, positions with mask == True are ignored
        input y_mask: B, 1, H, W, positions with mask == True are ignored
        '''
        featx = self.feat_proj(x)
        featx = self.drop_feat(featx)
        bx, cx, hx, wx = featx.size()
        featy = self.feat_proj(y)
featy = self.drop_feat(featy)
by, cy, hy, wy = featy.size()
featmask = self.feat_proj(fmask)
        for i in range(self.num_layers):
featx, featy, x_mask, y_mask = self.encoder_blocks[i](featx, featy, featmask, x_mask, y_mask)
out_cls = self.cls_branch(featy)
outx = self.sigmoid(self.decoder(featx))
outy = self.sigmoid(self.decoder(featy))
outx = torch.clamp(outx, min=self.eps, max=1-self.eps)
outy = torch.clamp(outy, min=self.eps, max=1-self.eps)
return outx, outy, out_cls

### --- Transformer Encoder --- ###
class TransEncoder(nn.Module):
    """
    Transformer encoder wrapper exposing tiny, small, base and large configurations.
    """
    def __init__(self, feat_dim=1024, pos_weight=0.1, feat_weight=1, dropout=0.1, activation='relu', mode='small', layer_type=['I', 'C', 'I', 'C', 'I', 'N'], drop_feat=0.1):
        super(TransEncoder, self).__init__()
        if mode == 'tiny':
            d_model = 128
            nhead = 2
            num_layers = 3
            dim_feedforward = 256
        elif mode == 'small':
            d_model = 256
            nhead = 2
            num_layers = 3
            dim_feedforward = 256
        elif mode == 'base':
            d_model = 512
            nhead = 8
            num_layers = 3
            dim_feedforward = 2048
        elif mode == 'large':
            # note: 6 blocks require a 12-entry layer_type (two entries per block)
            d_model = 512
            nhead = 8
            num_layers = 6
            dim_feedforward = 2048
        else:
            raise ValueError('unknown mode: {}'.format(mode))
        self.net = Encoder(feat_dim, pos_weight, feat_weight, d_model, nhead, num_layers, dim_feedforward, dropout, activation, layer_type, drop_feat)
    def forward(self, x, y, fmask, x_mask=None, y_mask=None):
        '''
        input x: B, C, H, W
        input y: B, C, H, W
        input fmask: B, C, Hm, Wm
        '''
        outx, outy, out_cls = self.net(x, y, fmask, x_mask, y_mask)
        return outx, outy, out_cls

if __name__ == '__main__':
    feat_dim = 256
    mode = 'small'
    net = TransEncoder(feat_dim=feat_dim, mode=mode)
    print(net)
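    # A minimal forward-pass sketch (not part of the original script). The 30x30
    # spatial size is an assumption chosen so that ClsBranch's 28*28 linear layer
    # matches its conv output; the random tensors are purely illustrative.
    if torch.cuda.is_available():
        device = torch.device('cuda')
        net = net.to(device).eval()   # disable dropout for the smoke test
        x = torch.randn(2, feat_dim, 30, 30, device=device)        # source features
        y = torch.randn(2, feat_dim, 30, 30, device=device)        # target features
        fmask = torch.randn(2, feat_dim, 30, 30, device=device)    # feature mask
        with torch.no_grad():
            outx, outy, out_cls = net(x, y, fmask)
        # expected shapes: outx / outy -> (2, 3, 480, 480), out_cls -> (2, 1)
        print(outx.shape, outy.shape, out_cls.shape)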