import torch
import torch.nn as nn
class CBOW(nn.Module):
def __init__(self, vocab_size, embed_dim):
super().__init__()
self.output_dim = embed_dim
# padding은 embedding을 lookup해서 사용하지 않게 하기 위해
# vocab_size * embed_dim 의 크기를 가진 테이블 생성
self.embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
def forward(self, x):
# (batch_size, sequence) -> (batch_size, sequence, embed_dim)
# 해당 시퀀스를 Look up table을 보고 가져옴
x_embeded = self.embeddings(x)
stnc_repr = torch.mean(x_embeded, dim=1) # batch_size x embed_dim
return stnc_repr
torch.nn.Embedding이 어떻게 동작하는지 이해를 해보고 테스트를 해보았다.
input x는 위 그림과 같이 sequence당 embed 테이블의 값을 가져올 것이다.
'인공지능 > Deep Learning' 카테고리의 다른 글
RN(Relation Network) (0) | 2023.03.23 |
---|---|
Conv2d Layer의 weight와 bias의 shape은? (0) | 2023.03.04 |
Dropout 실습 (0) | 2023.03.04 |
Cross Entropy 구현하기 (0) | 2023.03.04 |
분류경계선이 선형으로 나오는 모델은 선형분류만 가능 (0) | 2023.03.02 |