Embedding

import torch
import torch.nn as nn

class CBOW(nn.Module):

    def __init__(self, vocab_size, embed_dim):
        super().__init__()

        self.output_dim = embed_dim
        # padding은 embedding을 lookup해서 사용하지 않게 하기 위해
        # vocab_size * embed_dim 의 크기를 가진 테이블 생성
        self.embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=0)

    def forward(self, x):

        # (batch_size, sequence) -> (batch_size, sequence, embed_dim)
        # 해당 시퀀스를 Look up table을 보고 가져옴
        x_embeded = self.embeddings(x)
        stnc_repr = torch.mean(x_embeded, dim=1) # batch_size x embed_dim
        return stnc_repr

torch.nn.Embedding이 어떻게 동작하는지 이해를 해보고 테스트를 해보았다.

 

input x는 위 그림과 같이 sequence당 embed 테이블의 값을 가져올 것이다.