코드

Channel-wise dropout 코드 구현

gmlee729 2024. 12. 29. 18:59

CustomDropout이 원래 dropout - nn.Dropout과 학습곡선 동일하게 나옴

ChannelDropout이 바꾼거 

 

성능은... 단순 CNN에 넣으니까 둘다 잘 안나오는데 음....

 

class CustomDropout(nn.Module):
    def __init__(self, p: float = 0.5):
        super(CustomDropout, self).__init__()
        self.p = p

    def forward(self, x):
        if self.training: 
            mask = (torch.rand_like(x) > self.p).float()
            return x * mask / (1 - self.p) 
        else:
            return x  

class ChannelDropout(nn.Module):
    def __init__(self, p: float = 0.5):
        super(ChannelDropout, self).__init__()
        self.p = p

    def forward(self, x):
        if self.training:
         
            drop = (torch.rand(1, x.shape[1], 1, 1, device=x.device) > self.p).float()
            x = x * drop / (1 - self.p)            
            return x
        else:
            return x