Commit b5c901b

feat: bert model implementation
1 parent 2da522d commit b5c901b

File tree

2 files changed: +162 -0 lines changed
bert/config.py

Lines changed: 18 additions & 0 deletions
```python
from typing import NamedTuple


class BertConfig(NamedTuple):
    vocab_size: int
    type_vocab_size: int
    max_position_embeddings: int

    hidden_size: int
    hidden_act: str
    initializer_range: float
    intermediate_size: int
    num_attention_heads: int
    num_hidden_layers: int

    layer_norm_eps: float
    hidden_dropout_prob: float
    attention_probs_dropout_prob: float
```
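
For reference, a `BertConfig` is constructed by passing every field explicitly (it is a `NamedTuple` with no defaults). The hyperparameter values below are illustrative, roughly BERT-base numbers, not values shipped in this commit:

```python
from bert.config import BertConfig

# Illustrative, roughly BERT-base values; the commit itself defines no defaults.
config = BertConfig(
    vocab_size=30522,
    type_vocab_size=2,
    max_position_embeddings=512,
    hidden_size=768,
    hidden_act="gelu",
    initializer_range=0.02,
    intermediate_size=3072,
    num_attention_heads=12,
    num_hidden_layers=12,
    layer_norm_eps=1e-12,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)
```

Note that `hidden_act` and `initializer_range` are carried in the config but not yet consumed by `bert/model.py`, which hard-codes `nn.GELU()` and relies on PyTorch's default initialization.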

bert/model.py

Lines changed: 144 additions & 0 deletions
```python
import math
from typing import Tuple

import torch
from torch import Tensor, nn
from torch.nn import functional as fnn

from .config import BertConfig


class BertEmbedding(nn.Module):
    def __init__(self, config: BertConfig):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids: Tensor, token_type_ids: Tensor, position_ids: Tensor) -> Tensor:
        word_embeds = self.word_embeddings(input_ids)
        token_type_embeds = self.token_type_embeddings(token_type_ids)
        position_embeds = self.position_embeddings(position_ids)

        # sum the three embeddings, then layer norm + dropout
        embed_output = word_embeds + token_type_embeds + position_embeds
        embed_output = self.layer_norm(embed_output)
        embed_output = self.dropout(embed_output)
        return embed_output


class BertMultiHeadAttention(nn.Module):
    def __init__(self, config: BertConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.head_hidden_size = config.hidden_size // config.num_attention_heads
        self.hidden_size = config.hidden_size

        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)
        self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        # query, key, value linear projections
        query_output = self.query(hidden_states)
        key_output = self.key(hidden_states)
        value_output = self.value(hidden_states)

        seq_len = hidden_states.size(1)

        # split hidden_states into num_heads pieces (hidden_size = num_attention_heads * head_hidden_size)
        # ops #1: (batch, seq_len, hidden_size) -> (batch, seq_len, num_attention_heads, head_hidden_size)
        # ops #2: (batch, seq_len, num_attention_heads, head_hidden_size) -> (batch, num_attention_heads, seq_len, head_hidden_size)
        # output: (batch, num_attention_heads, seq_len, head_hidden_size)
        query_output = query_output.view(-1, seq_len, self.num_attention_heads, self.head_hidden_size)
        query_output = query_output.transpose(1, 2)
        key_output = key_output.view(-1, seq_len, self.num_attention_heads, self.head_hidden_size)
        key_output = key_output.transpose(1, 2)
        value_output = value_output.view(-1, seq_len, self.num_attention_heads, self.head_hidden_size)
        value_output = value_output.transpose(1, 2)

        # attention_ops: (batch, num_attention_heads, seq_len, head_hidden_size) x (batch, num_attention_heads, head_hidden_size, seq_len)
        # output: (batch, num_attention_heads, seq_len, seq_len)
        attention_scores = torch.matmul(query_output, key_output.transpose(2, 3))
        attention_scores = attention_scores / math.sqrt(self.head_hidden_size)

        # TODO: attention mask
        # TODO: head mask

        # normalize attention scores to probs
        attention_probs = fnn.softmax(attention_scores, dim=-1)
        attention_probs = self.attention_dropout(attention_probs)

        # context_ops: (batch, num_attention_heads, seq_len, seq_len) x (batch, num_attention_heads, seq_len, head_hidden_size)
        # output: (batch, num_attention_heads, seq_len, head_hidden_size)
        context_encoded_output = torch.matmul(attention_probs, value_output)

        # merge multi-head output back into a single hidden vector per position
        # ops #1: (batch, num_attention_heads, seq_len, head_hidden_size) -> (batch, seq_len, num_attention_heads, head_hidden_size)
        # ops #2: (batch, seq_len, num_attention_heads, head_hidden_size) -> (batch, seq_len, hidden_size)
        # output: (batch, seq_len, hidden_size)
        context_encoded_output = context_encoded_output.transpose(1, 2).contiguous()
        context_encoded_output = context_encoded_output.view(-1, seq_len, self.hidden_size)

        # output linear projection + dropout, then residual connection and layer norm
        context_encoded_output = self.dense(context_encoded_output)
        context_encoded_output = self.dropout(context_encoded_output)
        context_encoded_output = self.layer_norm(hidden_states + context_encoded_output)

        return context_encoded_output


class BertLayer(nn.Module):
    def __init__(self, config: BertConfig):
        super().__init__()
        self.attention = BertMultiHeadAttention(config)

        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_activation_fn = nn.GELU()

        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.output_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        context_encoded_output = self.attention(hidden_states, attention_mask)

        # position-wise feed-forward: expand to intermediate_size, apply GELU
        intermediate_output = self.intermediate_dense(context_encoded_output)
        intermediate_output = self.intermediate_activation_fn(intermediate_output)

        # project back to hidden_size + dropout, then residual connection and layer norm
        layer_output = self.output_dense(intermediate_output)
        layer_output = self.output_dropout(layer_output)
        layer_output = self.output_layer_norm(context_encoded_output + layer_output)
        return layer_output


class BertModel(nn.Module):
    def __init__(self, config: BertConfig):
        super().__init__()
        self.config = config

        self.embedding = BertEmbedding(config)
        self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

        self.pooler_dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.pooler_activation_fn = nn.Tanh()

    def forward(
        self, input_ids: Tensor, attention_mask: Tensor, token_type_ids: Tensor, position_ids: Tensor
    ) -> Tuple[Tensor, Tensor]:
        hidden_states = self.embedding(input_ids, token_type_ids, position_ids)

        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)

        # pool the final hidden state of the first ([CLS]) token
        pooled_output = self.pooler_dense(hidden_states[:, 0])
        pooled_output = self.pooler_activation_fn(pooled_output)

        return pooled_output, hidden_states
```
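
A minimal smoke test of the forward pass could look like the sketch below, assuming `bert/` is importable as a package. The batch size, sequence length, and config values are assumptions for illustration; `attention_mask` is accepted but not applied until the TODO in `BertMultiHeadAttention` is resolved:

```python
import torch

from bert.config import BertConfig
from bert.model import BertModel

# Small illustrative config; these values are assumptions, not part of the commit.
config = BertConfig(
    vocab_size=1000, type_vocab_size=2, max_position_embeddings=128,
    hidden_size=64, hidden_act="gelu", initializer_range=0.02,
    intermediate_size=256, num_attention_heads=4, num_hidden_layers=2,
    layer_norm_eps=1e-12, hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
)

model = BertModel(config).eval()
batch_size, seq_len = 2, 16

input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
token_type_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len)
attention_mask = torch.ones(batch_size, seq_len)  # currently unused inside the model

with torch.no_grad():
    pooled_output, hidden_states = model(input_ids, attention_mask, token_type_ids, position_ids)

print(pooled_output.shape)   # torch.Size([2, 64])    -- tanh-pooled [CLS] representation
print(hidden_states.shape)   # torch.Size([2, 16, 64]) -- final-layer hidden states
```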
