From 65a2ae453e6a7d583cbab515d7339a2b044c87c1 Mon Sep 17 00:00:00 2001 From: codertimo Date: Tue, 23 Oct 2018 22:28:48 +0900 Subject: [PATCH 01/10] Removing WordVocab on dataset.vocab --- README.md | 14 +-- bert_pytorch/__main__.py | 4 +- bert_pytorch/dataset/__init__.py | 1 - bert_pytorch/dataset/vocab.py | 185 ------------------------------- bert_pytorch/trainer/pretrain.py | 19 ++-- requirements.txt | 3 +- 6 files changed, 19 insertions(+), 207 deletions(-) delete mode 100644 bert_pytorch/dataset/vocab.py diff --git a/README.md b/README.md index 16d3df9..a242b56 100644 --- a/README.md +++ b/README.md @@ -42,20 +42,20 @@ pip install bert-pytorch ### 0. Prepare your corpus ``` -Welcome to the \t the jungle\n -I can stay \t here all night\n +Welcome to the the jungle +I can stay here all night ``` or tokenized corpus (tokenization is not in package) ``` -Wel_ _come _to _the \t _the _jungle\n -_I _can _stay \t _here _all _night\n +Wel_ _come _to _the _the _jungle +_I _can _stay _here _all _night ``` ### 1. Building vocab based on your corpus ```shell -bert-vocab -c data/corpus.small -o data/vocab.small +vocab-builder -c data/corpus.small -o data/vocab.small ``` ### 2. Train your own BERT model @@ -109,7 +109,7 @@ not directly captured by language modeling ## Author -Junseong Kim, Scatter Lab (codertimo@gmail.com / junseong.kim@scatter.co.kr) +Junseong Kim, Scatter Lab (codertimo@gmail.com / junseong.kim@scatterlab.co.kr) ## License @@ -117,4 +117,4 @@ This project following Apache 2.0 License as written in LICENSE file Copyright 2018 Junseong Kim, Scatter Lab, respective BERT contributors -Copyright (c) 2018 Alexander Rush : [The Annotated Trasnformer](https://github.com/harvardnlp/annotated-transformer) +Copyright (c) 2018 Alexander Rush : [The Annotated Transformer](https://github.com/harvardnlp/annotated-transformer) diff --git a/bert_pytorch/__main__.py b/bert_pytorch/__main__.py index d4193f2..a390ab8 100644 --- a/bert_pytorch/__main__.py +++ b/bert_pytorch/__main__.py @@ -4,7 +4,9 @@ from .model import BERT from .trainer import BERTTrainer -from .dataset import BERTDataset, WordVocab +from .dataset import BERTDataset + +from vocab_builder import WordVocab def train(): diff --git a/bert_pytorch/dataset/__init__.py b/bert_pytorch/dataset/__init__.py index 90e9036..2341716 100644 --- a/bert_pytorch/dataset/__init__.py +++ b/bert_pytorch/dataset/__init__.py @@ -1,2 +1 @@ from .dataset import BERTDataset -from .vocab import WordVocab diff --git a/bert_pytorch/dataset/vocab.py b/bert_pytorch/dataset/vocab.py deleted file mode 100644 index f7346a7..0000000 --- a/bert_pytorch/dataset/vocab.py +++ /dev/null @@ -1,185 +0,0 @@ -import pickle -import tqdm -from collections import Counter - - -class TorchVocab(object): - """Defines a vocabulary object that will be used to numericalize a field. - Attributes: - freqs: A collections.Counter object holding the frequencies of tokens - in the data used to build the Vocab. - stoi: A collections.defaultdict instance mapping token strings to - numerical identifiers. - itos: A list of token strings indexed by their numerical identifiers. - """ - - def __init__(self, counter, max_size=None, min_freq=1, specials=['', ''], - vectors=None, unk_init=None, vectors_cache=None): - """Create a Vocab object from a collections.Counter. - Arguments: - counter: collections.Counter object holding the frequencies of - each value found in the data. - max_size: The maximum size of the vocabulary, or None for no - maximum. Default: None. 
- min_freq: The minimum frequency needed to include a token in the - vocabulary. Values less than 1 will be set to 1. Default: 1. - specials: The list of special tokens (e.g., padding or eos) that - will be prepended to the vocabulary in addition to an - token. Default: [''] - vectors: One of either the available pretrained vectors - or custom pretrained vectors (see Vocab.load_vectors); - or a list of aforementioned vectors - unk_init (callback): by default, initialize out-of-vocabulary word vectors - to zero vectors; can be any function that takes in a Tensor and - returns a Tensor of the same size. Default: torch.Tensor.zero_ - vectors_cache: directory for cached vectors. Default: '.vector_cache' - """ - self.freqs = counter - counter = counter.copy() - min_freq = max(min_freq, 1) - - self.itos = list(specials) - # frequencies of special tokens are not counted when building vocabulary - # in frequency order - for tok in specials: - del counter[tok] - - max_size = None if max_size is None else max_size + len(self.itos) - - # sort by frequency, then alphabetically - words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) - words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) - - for word, freq in words_and_frequencies: - if freq < min_freq or len(self.itos) == max_size: - break - self.itos.append(word) - - # stoi is simply a reverse dict for itos - self.stoi = {tok: i for i, tok in enumerate(self.itos)} - - self.vectors = None - if vectors is not None: - self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache) - else: - assert unk_init is None and vectors_cache is None - - def __eq__(self, other): - if self.freqs != other.freqs: - return False - if self.stoi != other.stoi: - return False - if self.itos != other.itos: - return False - if self.vectors != other.vectors: - return False - return True - - def __len__(self): - return len(self.itos) - - def vocab_rerank(self): - self.stoi = {word: i for i, word in enumerate(self.itos)} - - def extend(self, v, sort=False): - words = sorted(v.itos) if sort else v.itos - for w in words: - if w not in self.stoi: - self.itos.append(w) - self.stoi[w] = len(self.itos) - 1 - - -class Vocab(TorchVocab): - def __init__(self, counter, max_size=None, min_freq=1): - self.pad_index = 0 - self.unk_index = 1 - self.eos_index = 2 - self.sos_index = 3 - self.mask_index = 4 - super().__init__(counter, specials=["", "", "", "", ""], - max_size=max_size, min_freq=min_freq) - - def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list: - pass - - def from_seq(self, seq, join=False, with_pad=False): - pass - - @staticmethod - def load_vocab(vocab_path: str) -> 'Vocab': - with open(vocab_path, "rb") as f: - return pickle.load(f) - - def save_vocab(self, vocab_path): - with open(vocab_path, "wb") as f: - pickle.dump(self, f) - - -# Building Vocab with text files -class WordVocab(Vocab): - def __init__(self, texts, max_size=None, min_freq=1): - print("Building Vocab") - counter = Counter() - for line in tqdm.tqdm(texts): - if isinstance(line, list): - words = line - else: - words = line.replace("\n", "").replace("\t", "").split() - - for word in words: - counter[word] += 1 - super().__init__(counter, max_size=max_size, min_freq=min_freq) - - def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False): - if isinstance(sentence, str): - sentence = sentence.split() - - seq = [self.stoi.get(word, self.unk_index) for word in sentence] - - if with_eos: - seq += [self.eos_index] # this would be 
index 1 - if with_sos: - seq = [self.sos_index] + seq - - origin_seq_len = len(seq) - - if seq_len is None: - pass - elif len(seq) <= seq_len: - seq += [self.pad_index for _ in range(seq_len - len(seq))] - else: - seq = seq[:seq_len] - - return (seq, origin_seq_len) if with_len else seq - - def from_seq(self, seq, join=False, with_pad=False): - words = [self.itos[idx] - if idx < len(self.itos) - else "<%d>" % idx - for idx in seq - if not with_pad or idx != self.pad_index] - - return " ".join(words) if join else words - - @staticmethod - def load_vocab(vocab_path: str) -> 'WordVocab': - with open(vocab_path, "rb") as f: - return pickle.load(f) - - -def build(): - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("-c", "--corpus_path", required=True, type=str) - parser.add_argument("-o", "--output_path", required=True, type=str) - parser.add_argument("-s", "--vocab_size", type=int, default=None) - parser.add_argument("-e", "--encoding", type=str, default="utf-8") - parser.add_argument("-m", "--min_freq", type=int, default=1) - args = parser.parse_args() - - with open(args.corpus_path, "r", encoding=args.encoding) as f: - vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq) - - print("VOCAB SIZE:", len(vocab)) - vocab.save_vocab(args.output_path) diff --git a/bert_pytorch/trainer/pretrain.py b/bert_pytorch/trainer/pretrain.py index 0b882dd..d1c0894 100644 --- a/bert_pytorch/trainer/pretrain.py +++ b/bert_pytorch/trainer/pretrain.py @@ -1,13 +1,12 @@ import torch import torch.nn as nn + from torch.optim import Adam from torch.utils.data import DataLoader from ..model import BERTLM, BERT from .optim_schedule import ScheduledOptim -import tqdm - class BERTTrainer: """ @@ -84,17 +83,11 @@ def iteration(self, epoch, data_loader, train=True): """ str_code = "train" if train else "test" - # Setting the tqdm progress bar - data_iter = tqdm.tqdm(enumerate(data_loader), - desc="EP_%s:%d" % (str_code, epoch), - total=len(data_loader), - bar_format="{l_bar}{r_bar}") - avg_loss = 0.0 total_correct = 0 total_element = 0 - for i, data in data_iter: + for i, data in enumerate(data_loader): # 0. 
batch_data will be sent into the device(GPU or cpu) data = {key: value.to(self.device) for key, value in data.items()} @@ -124,14 +117,16 @@ def iteration(self, epoch, data_loader, train=True): post_fix = { "epoch": epoch, - "iter": i, + "iter": "[%d/%d]" % (i, len(data_loader)), "avg_loss": avg_loss / (i + 1), - "avg_acc": total_correct / total_element * 100, + "mask_loss": mask_loss.item(), + "next_loss": next_loss.item(), + "avg_next_acc": total_correct / total_element * 100, "loss": loss.item() } if i % self.log_freq == 0: - data_iter.write(str(post_fix)) + print(str(post_fix)) print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=", total_correct * 100.0 / total_element) diff --git a/requirements.txt b/requirements.txt index 3689708..d53ab8c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ tqdm numpy -torch>=0.4.0 \ No newline at end of file +torch >= 0.4.0 +vocab-builder \ No newline at end of file From 2a0b28218f4fde216cbb7750eb584c2ada0d487b Mon Sep 17 00:00:00 2001 From: codertimo Date: Thu, 25 Oct 2018 09:54:01 +0900 Subject: [PATCH 02/10] Fix issue #32 miss padding issue --- bert_pytorch/trainer/pretrain.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/bert_pytorch/trainer/pretrain.py b/bert_pytorch/trainer/pretrain.py index d1c0894..aac13c6 100644 --- a/bert_pytorch/trainer/pretrain.py +++ b/bert_pytorch/trainer/pretrain.py @@ -58,7 +58,8 @@ def __init__(self, bert: BERT, vocab_size: int, self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps) # Using Negative Log Likelihood Loss function for predicting the masked_token - self.criterion = nn.NLLLoss(ignore_index=0) + self.masked_criterion = nn.NLLLoss(ignore_index=0) + self.next_criterion = nn.NLLLoss() self.log_freq = log_freq @@ -95,10 +96,10 @@ def iteration(self, epoch, data_loader, train=True): next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"]) # 2-1. NLL(negative log likelihood) loss of is_next classification result - next_loss = self.criterion(next_sent_output, data["is_next"]) + next_loss = self.next_criterion(next_sent_output, data["is_next"]) # 2-2. NLLLoss of predicting masked token word - mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"]) + mask_loss = self.masked_criterion(mask_lm_output.transpose(1, 2), data["bert_label"]) # 2-3. 
Adding next_loss and mask_loss : 3.4 Pre-training Procedure loss = next_loss + mask_loss @@ -128,7 +129,7 @@ def iteration(self, epoch, data_loader, train=True): if i % self.log_freq == 0: print(str(post_fix)) - print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=", + print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_loader), "total_acc=", total_correct * 100.0 / total_element) def save(self, epoch, file_path="output/bert_trained.model"): From 9ace5fbbb1db3c7bde23ddc2f0b7cba959e28160 Mon Sep 17 00:00:00 2001 From: codertimo Date: Thu, 25 Oct 2018 10:03:39 +0900 Subject: [PATCH 03/10] Bumping version 0.0.1a5 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4e721cf..2d1699b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ import os import sys -__version__ = "0.0.1a4" +__version__ = "0.0.1a5" with open("requirements.txt") as f: require_packages = [line[:-1] if line[-1] == "\n" else line for line in f] From 8f98c858332ac0708a4b8d42cd9c755d26154183 Mon Sep 17 00:00:00 2001 From: codertimo Date: Thu, 25 Oct 2018 15:01:40 +0900 Subject: [PATCH 04/10] Weighted Loss --- bert_pytorch/trainer/pretrain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bert_pytorch/trainer/pretrain.py b/bert_pytorch/trainer/pretrain.py index aac13c6..ef2dd8c 100644 --- a/bert_pytorch/trainer/pretrain.py +++ b/bert_pytorch/trainer/pretrain.py @@ -102,7 +102,7 @@ def iteration(self, epoch, data_loader, train=True): mask_loss = self.masked_criterion(mask_lm_output.transpose(1, 2), data["bert_label"]) # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure - loss = next_loss + mask_loss + loss = next_loss + 10 * mask_loss # 3. backward and optimization only in train if train: From f5c128a34b2d860d13233942dee02bf88c9a1ac6 Mon Sep 17 00:00:00 2001 From: codertimo Date: Thu, 25 Oct 2018 15:04:13 +0900 Subject: [PATCH 05/10] Change Next loss weighted sum --- bert_pytorch/trainer/pretrain.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bert_pytorch/trainer/pretrain.py b/bert_pytorch/trainer/pretrain.py index ef2dd8c..5f24481 100644 --- a/bert_pytorch/trainer/pretrain.py +++ b/bert_pytorch/trainer/pretrain.py @@ -96,13 +96,13 @@ def iteration(self, epoch, data_loader, train=True): next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"]) # 2-1. NLL(negative log likelihood) loss of is_next classification result - next_loss = self.next_criterion(next_sent_output, data["is_next"]) + next_loss = self.next_criterion(next_sent_output, data["is_next"]) * 10 # 2-2. NLLLoss of predicting masked token word mask_loss = self.masked_criterion(mask_lm_output.transpose(1, 2), data["bert_label"]) # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure - loss = next_loss + 10 * mask_loss + loss = next_loss + mask_loss # 3. 
backward and optimization only in train if train: From 0867d0f026e68445247dfd3f21001a320dd5cb7c Mon Sep 17 00:00:00 2001 From: codertimo Date: Fri, 2 Nov 2018 17:31:30 +0900 Subject: [PATCH 06/10] Adding feature based on official bert code --- bert_pytorch/__main__.py | 8 +- bert_pytorch/dataset/dataset.py | 49 ++------ bert_pytorch/model/attention/multi_head.py | 4 +- bert_pytorch/model/bert.py | 4 +- bert_pytorch/model/embedding/bert.py | 2 +- bert_pytorch/model/embedding/position.py | 21 +--- bert_pytorch/model/language_model.py | 2 +- bert_pytorch/model/transformer.py | 8 +- bert_pytorch/model/utils/__init__.py | 2 +- bert_pytorch/model/utils/feed_forward.py | 12 +- bert_pytorch/model/utils/sublayer.py | 5 +- bert_pytorch/trainer/optimizer/__init__.py | 1 + bert_pytorch/trainer/optimizer/adamw.py | 110 ++++++++++++++++++ .../trainer/{ => optimizer}/optim_schedule.py | 2 +- bert_pytorch/trainer/pretrain.py | 14 ++- 15 files changed, 161 insertions(+), 83 deletions(-) create mode 100644 bert_pytorch/trainer/optimizer/__init__.py create mode 100644 bert_pytorch/trainer/optimizer/adamw.py rename bert_pytorch/trainer/{ => optimizer}/optim_schedule.py (97%) diff --git a/bert_pytorch/__main__.py b/bert_pytorch/__main__.py index a390ab8..926610f 100644 --- a/bert_pytorch/__main__.py +++ b/bert_pytorch/__main__.py @@ -21,6 +21,7 @@ def train(): parser.add_argument("-l", "--layers", type=int, default=8, help="number of layers") parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads") parser.add_argument("-s", "--seq_len", type=int, default=20, help="maximum sequence len") + parser.add_argument("-d", "--dropout", type=float, default=0.1, help="dropout rate") parser.add_argument("-b", "--batch_size", type=int, default=64, help="number of batch_size") parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs") @@ -30,7 +31,7 @@ def train(): parser.add_argument("--log_freq", type=int, default=10, help="printing loss every n iter: setting n") parser.add_argument("--corpus_lines", type=int, default=None, help="total number of lines in corpus") parser.add_argument("--cuda_devices", type=int, nargs='+', default=None, help="CUDA device ids") - parser.add_argument("--on_memory", type=bool, default=True, help="Loading on memory: true or false") + parser.add_argument("--on_memory", type=bool, default=False, help="Loading on memory: true or false") parser.add_argument("--lr", type=float, default=1e-3, help="learning rate of adam") parser.add_argument("--adam_weight_decay", type=float, default=0.01, help="weight_decay of adam") @@ -39,6 +40,9 @@ def train(): args = parser.parse_args() + # Logging Parameter + print(args) + print("Loading Vocab", args.vocab_path) vocab = WordVocab.load_vocab(args.vocab_path) print("Vocab Size: ", len(vocab)) @@ -57,7 +61,7 @@ def train(): if test_dataset is not None else None print("Building BERT model") - bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads) + bert = BERT(len(vocab), hidden=args.hidden, n_layers=args.layers, attn_heads=args.attn_heads, dropout=args.dropout) print("Creating BERT Trainer") trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader, diff --git a/bert_pytorch/dataset/dataset.py b/bert_pytorch/dataset/dataset.py index 7d787f3..6ed03ac 100644 --- a/bert_pytorch/dataset/dataset.py +++ b/bert_pytorch/dataset/dataset.py @@ -8,31 +8,18 @@ class BERTDataset(Dataset): def __init__(self, corpus_path, 
vocab, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True): self.vocab = vocab self.seq_len = seq_len - - self.on_memory = on_memory - self.corpus_lines = corpus_lines self.corpus_path = corpus_path self.encoding = encoding + self.datas = [] with open(corpus_path, "r", encoding=encoding) as f: - if self.corpus_lines is None and not on_memory: - for _ in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines): - self.corpus_lines += 1 - - if on_memory: - self.lines = [line[:-1].split("\t") - for line in tqdm.tqdm(f, desc="Loading Dataset", total=corpus_lines)] - self.corpus_lines = len(self.lines) - - if not on_memory: - self.file = open(corpus_path, "r", encoding=encoding) - self.random_file = open(corpus_path, "r", encoding=encoding) - - for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)): - self.random_file.__next__() + for i, line in enumerate(f): + next_line = f.readline() + if line != "\n" and next_line != "\n": + self.datas.append((line[:-1], next_line[:-1])) def __len__(self): - return self.corpus_lines + return self.datas def __getitem__(self, item): t1, t2, is_next_label = self.random_sent(item) @@ -99,27 +86,7 @@ def random_sent(self, index): return t1, self.get_random_line(), 0 def get_corpus_line(self, item): - if self.on_memory: - return self.lines[item][0], self.lines[item][1] - else: - line = self.file.__next__() - if line is None: - self.file.close() - self.file = open(self.corpus_path, "r", encoding=self.encoding) - line = self.file.__next__() - - t1, t2 = line[:-1].split("\t") - return t1, t2 + return self.datas[item][0], self.datas[item][1] def get_random_line(self): - if self.on_memory: - return self.lines[random.randrange(len(self.lines))][1] - - line = self.file.__next__() - if line is None: - self.file.close() - self.file = open(self.corpus_path, "r", encoding=self.encoding) - for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)): - self.random_file.__next__() - line = self.random_file.__next__() - return line[:-1].split("\t")[1] + return self.datas[random.randrange(len(self.datas))][1] diff --git a/bert_pytorch/model/attention/multi_head.py b/bert_pytorch/model/attention/multi_head.py index c8a47f9..de72a3a 100644 --- a/bert_pytorch/model/attention/multi_head.py +++ b/bert_pytorch/model/attention/multi_head.py @@ -34,4 +34,6 @@ def forward(self, query, key, value, mask=None): # 3) "Concat" using a view and apply a final linear. 
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k) - return self.output_linear(x) + # 4) Applying Output Linear Model + x = self.output_linear(x) + return x diff --git a/bert_pytorch/model/bert.py b/bert_pytorch/model/bert.py index c4cec4a..eb81d03 100644 --- a/bert_pytorch/model/bert.py +++ b/bert_pytorch/model/bert.py @@ -31,7 +31,9 @@ def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0 # multi-layers transformer blocks, deep network self.transformer_blocks = nn.ModuleList( - [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)]) + [TransformerBlock(hidden=hidden, attn_heads=attn_heads, + feed_forward_hidden=hidden * 4, dropout=dropout) + for _ in range(n_layers)]) def forward(self, x, segment_info): # attention masking for padded token diff --git a/bert_pytorch/model/embedding/bert.py b/bert_pytorch/model/embedding/bert.py index bcd5115..63cceca 100644 --- a/bert_pytorch/model/embedding/bert.py +++ b/bert_pytorch/model/embedding/bert.py @@ -8,7 +8,7 @@ class BERTEmbedding(nn.Module): """ BERT Embedding which is consisted with under features 1. TokenEmbedding : normal embedding matrix - 2. PositionalEmbedding : adding positional information using sin, cos + 2. PositionalEmbedding : adding positional information 2. SegmentEmbedding : adding sentence segment info, (sent_A:1, sent_B:2) sum of all these features are output of BERTEmbedding diff --git a/bert_pytorch/model/embedding/position.py b/bert_pytorch/model/embedding/position.py index d55c224..63ab27e 100644 --- a/bert_pytorch/model/embedding/position.py +++ b/bert_pytorch/model/embedding/position.py @@ -1,25 +1,10 @@ import torch.nn as nn -import torch -import math -class PositionalEmbedding(nn.Module): +class PositionalEmbedding(nn.Embedding): def __init__(self, d_model, max_len=512): - super().__init__() - - # Compute the positional encodings once in log space. 
- pe = torch.zeros(max_len, d_model).float() - pe.require_grad = False - - position = torch.arange(0, max_len).float().unsqueeze(1) - div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() - - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - - pe = pe.unsqueeze(0) - self.register_buffer('pe', pe) + super().__init__(max_len, d_model) def forward(self, x): - return self.pe[:, :x.size(1)] + return self.weight.data[:x.size(1)] diff --git a/bert_pytorch/model/language_model.py b/bert_pytorch/model/language_model.py index 608f42a..a97c601 100644 --- a/bert_pytorch/model/language_model.py +++ b/bert_pytorch/model/language_model.py @@ -39,7 +39,7 @@ def __init__(self, hidden): self.softmax = nn.LogSoftmax(dim=-1) def forward(self, x): - return self.softmax(self.linear(x[:, 0])) + return self.softmax(self.linear(x[:, 0]).tanh()) class MaskedLanguageModel(nn.Module): diff --git a/bert_pytorch/model/transformer.py b/bert_pytorch/model/transformer.py index 288de26..f538bfb 100644 --- a/bert_pytorch/model/transformer.py +++ b/bert_pytorch/model/transformer.py @@ -1,7 +1,7 @@ import torch.nn as nn from .attention import MultiHeadedAttention -from .utils import SublayerConnection, PositionwiseFeedForward +from .utils import SublayerConnection, FeedForward class TransformerBlock(nn.Module): @@ -19,8 +19,8 @@ def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout): """ super().__init__() - self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden) - self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout) + self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout) + self.feed_forward = FeedForward(d_model=hidden, d_ff=feed_forward_hidden) self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout) self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout) self.dropout = nn.Dropout(p=dropout) @@ -28,4 +28,4 @@ def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout): def forward(self, x, mask): x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask)) x = self.output_sublayer(x, self.feed_forward) - return self.dropout(x) + return x diff --git a/bert_pytorch/model/utils/__init__.py b/bert_pytorch/model/utils/__init__.py index e7bddc6..68ebee6 100644 --- a/bert_pytorch/model/utils/__init__.py +++ b/bert_pytorch/model/utils/__init__.py @@ -1,4 +1,4 @@ -from .feed_forward import PositionwiseFeedForward +from .feed_forward import FeedForward from .layer_norm import LayerNorm from .sublayer import SublayerConnection from .gelu import GELU diff --git a/bert_pytorch/model/utils/feed_forward.py b/bert_pytorch/model/utils/feed_forward.py index a225c5e..f09ef4d 100644 --- a/bert_pytorch/model/utils/feed_forward.py +++ b/bert_pytorch/model/utils/feed_forward.py @@ -2,15 +2,17 @@ from .gelu import GELU -class PositionwiseFeedForward(nn.Module): +class FeedForward(nn.Module): "Implements FFN equation." 
- def __init__(self, d_model, d_ff, dropout=0.1): - super(PositionwiseFeedForward, self).__init__() + def __init__(self, d_model, d_ff): + super(FeedForward, self).__init__() self.w_1 = nn.Linear(d_model, d_ff) self.w_2 = nn.Linear(d_ff, d_model) - self.dropout = nn.Dropout(dropout) self.activation = GELU() def forward(self, x): - return self.w_2(self.dropout(self.activation(self.w_1(x)))) + x = self.w_1(x) + x = self.activation(x) + x = self.w_2(x) + return x diff --git a/bert_pytorch/model/utils/sublayer.py b/bert_pytorch/model/utils/sublayer.py index 6e36793..b46b807 100644 --- a/bert_pytorch/model/utils/sublayer.py +++ b/bert_pytorch/model/utils/sublayer.py @@ -5,7 +5,6 @@ class SublayerConnection(nn.Module): """ A residual connection followed by a layer norm. - Note for code simplicity the norm is first as opposed to last. """ def __init__(self, size, dropout): @@ -13,6 +12,6 @@ def __init__(self, size, dropout): self.norm = LayerNorm(size) self.dropout = nn.Dropout(dropout) - def forward(self, x, sublayer): + def forward(self, x, sublayer, dropout=True): "Apply residual connection to any sublayer with the same size." - return x + self.dropout(sublayer(self.norm(x))) + return self.norm(x + self.dropout(sublayer(x)) if dropout else sublayer(x)) diff --git a/bert_pytorch/trainer/optimizer/__init__.py b/bert_pytorch/trainer/optimizer/__init__.py new file mode 100644 index 0000000..a18f15f --- /dev/null +++ b/bert_pytorch/trainer/optimizer/__init__.py @@ -0,0 +1 @@ +from .adamw import AdamW diff --git a/bert_pytorch/trainer/optimizer/adamw.py b/bert_pytorch/trainer/optimizer/adamw.py new file mode 100644 index 0000000..f496778 --- /dev/null +++ b/bert_pytorch/trainer/optimizer/adamw.py @@ -0,0 +1,110 @@ +import math +import torch +from torch.optim.optimizer import Optimizer + +""" +egg-west/AdamW-pytorch : Implementation and experiment for AdamW on pytorch +from https://github.com/egg-west/AdamW-pytorch +""" + + +class AdamW(Optimizer): + """Implements Adam algorithm. + It has been proposed in `Adam: A Method for Stochastic Optimization`_. + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, + weight_decay=0, amsgrad=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # if group['weight_decay'] != 0: + # grad = grad.add(group['weight_decay'], p.data) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + if amsgrad: + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. 
of gradient + denom = max_exp_avg_sq.sqrt().add_(group['eps']) + else: + denom = exp_avg_sq.sqrt().add_(group['eps']) + + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 + + # p.data.addcdiv_(-step_size, exp_avg, denom) + p.data.add_(-step_size, torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom)) + + return loss diff --git a/bert_pytorch/trainer/optim_schedule.py b/bert_pytorch/trainer/optimizer/optim_schedule.py similarity index 97% rename from bert_pytorch/trainer/optim_schedule.py rename to bert_pytorch/trainer/optimizer/optim_schedule.py index 5ccd222..a1a8fc1 100644 --- a/bert_pytorch/trainer/optim_schedule.py +++ b/bert_pytorch/trainer/optimizer/optim_schedule.py @@ -2,7 +2,7 @@ import numpy as np -class ScheduledOptim(): +class ScheduledOptim: '''A simple wrapper class for learning rate scheduling''' def __init__(self, optimizer, d_model, n_warmup_steps): diff --git a/bert_pytorch/trainer/pretrain.py b/bert_pytorch/trainer/pretrain.py index 5f24481..aedbc10 100644 --- a/bert_pytorch/trainer/pretrain.py +++ b/bert_pytorch/trainer/pretrain.py @@ -5,7 +5,8 @@ from torch.utils.data import DataLoader from ..model import BERTLM, BERT -from .optim_schedule import ScheduledOptim +from .optimizer.optim_schedule import ScheduledOptim +from .optimizer.adamw import AdamW class BERTTrainer: @@ -54,7 +55,7 @@ def __init__(self, bert: BERT, vocab_size: int, self.test_data = test_dataloader # Setting the Adam optimizer with hyper-param - self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay) + self.optim = AdamW(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay) self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps) # Using Negative Log Likelihood Loss function for predicting the masked_token @@ -96,7 +97,7 @@ def iteration(self, epoch, data_loader, train=True): next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"]) # 2-1. NLL(negative log likelihood) loss of is_next classification result - next_loss = self.next_criterion(next_sent_output, data["is_next"]) * 10 + next_loss = self.next_criterion(next_sent_output, data["is_next"]) # 2-2. 
NLLLoss of predicting masked token word mask_loss = self.masked_criterion(mask_lm_output.transpose(1, 2), data["bert_label"]) @@ -127,7 +128,12 @@ def iteration(self, epoch, data_loader, train=True): } if i % self.log_freq == 0: - print(str(post_fix)) + print(post_fix) + + # Logging for PaperSpace matrix monitor + # index = epoch * len(data_loader) + i + # for code in ["avg_loss", "mask_loss", "next_loss", "avg_next_acc"]: + # print(json.dumps({"chart": code, "y": post_fix[code], "x": index})) print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_loader), "total_acc=", total_correct * 100.0 / total_element) From c981348c9e22795ebb42be4542e8b52a372dc708 Mon Sep 17 00:00:00 2001 From: codertimo Date: Fri, 2 Nov 2018 17:45:56 +0900 Subject: [PATCH 07/10] Adding run-able python script --- run.py | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 run.py diff --git a/run.py b/run.py new file mode 100644 index 0000000..1bbee8c --- /dev/null +++ b/run.py @@ -0,0 +1,2 @@ +from bert_pytorch.__main__ import train +train() From 5d33758aa107901568cafb29a9a48bc8b207b78e Mon Sep 17 00:00:00 2001 From: codertimo Date: Fri, 2 Nov 2018 17:52:11 +0900 Subject: [PATCH 08/10] Fix Typo error --- bert_pytorch/dataset/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bert_pytorch/dataset/dataset.py b/bert_pytorch/dataset/dataset.py index 6ed03ac..61e78f7 100644 --- a/bert_pytorch/dataset/dataset.py +++ b/bert_pytorch/dataset/dataset.py @@ -19,7 +19,7 @@ def __init__(self, corpus_path, vocab, seq_len, encoding="utf-8", corpus_lines=N self.datas.append((line[:-1], next_line[:-1])) def __len__(self): - return self.datas + return len(self.datas) def __getitem__(self, item): t1, t2, is_next_label = self.random_sent(item) From 91e635bb40d0d646fdeae613df64c5e3da53deb3 Mon Sep 17 00:00:00 2001 From: codertimo Date: Fri, 2 Nov 2018 17:56:44 +0900 Subject: [PATCH 09/10] Change pretrain --- bert_pytorch/trainer/pretrain.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bert_pytorch/trainer/pretrain.py b/bert_pytorch/trainer/pretrain.py index aedbc10..d50e003 100644 --- a/bert_pytorch/trainer/pretrain.py +++ b/bert_pytorch/trainer/pretrain.py @@ -56,7 +56,7 @@ def __init__(self, bert: BERT, vocab_size: int, # Setting the Adam optimizer with hyper-param self.optim = AdamW(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay) - self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps) + # self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps) # Using Negative Log Likelihood Loss function for predicting the masked_token self.masked_criterion = nn.NLLLoss(ignore_index=0) @@ -107,9 +107,9 @@ def iteration(self, epoch, data_loader, train=True): # 3. 
backward and optimization only in train if train: - self.optim_schedule.zero_grad() + self.optim.zero_grad() loss.backward() - self.optim_schedule.step_and_update_lr() + self.optim.step() # next sentence prediction accuracy correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item() From 768736f0209cffcba5c6430f18de26a2dab565af Mon Sep 17 00:00:00 2001 From: codertimo Date: Fri, 2 Nov 2018 18:35:15 +0900 Subject: [PATCH 10/10] output projection weight and embedding weight sync --- bert_pytorch/model/language_model.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bert_pytorch/model/language_model.py b/bert_pytorch/model/language_model.py index a97c601..8c53b2b 100644 --- a/bert_pytorch/model/language_model.py +++ b/bert_pytorch/model/language_model.py @@ -18,7 +18,8 @@ def __init__(self, bert: BERT, vocab_size): super().__init__() self.bert = bert self.next_sentence = NextSentencePrediction(self.bert.hidden) - self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size) + self.mask_lm = MaskedLanguageModel(self.bert.hidden, vocab_size, + embedding=self.bert.embedding.token) def forward(self, x, segment_label): x = self.bert(x, segment_label) @@ -48,13 +49,15 @@ class MaskedLanguageModel(nn.Module): n-class classification problem, n-class = vocab_size """ - def __init__(self, hidden, vocab_size): + def __init__(self, hidden, vocab_size, embedding=None): """ :param hidden: output size of BERT model :param vocab_size: total vocab size """ super().__init__() self.linear = nn.Linear(hidden, vocab_size) + if embedding is not None: + self.linear.weight.data = embedding.weight.data self.softmax = nn.LogSoftmax(dim=-1) def forward(self, x):
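The last patch above shares the masked-language-model output projection with the token embedding matrix. A minimal, self-contained sketch of that weight-tying idea follows; it is not part of the patch series, and the `TiedMaskedLM` name and the toy sizes are illustrative only. The sketch ties the `nn.Parameter` objects directly, which is the common PyTorch idiom, whereas the patch copies `.weight.data`, sharing the initial values.

```python
import torch
import torch.nn as nn


class TiedMaskedLM(nn.Module):
    """Sketch: the vocab projection reuses the token-embedding weight matrix."""

    def __init__(self, vocab_size=100, hidden=32):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden)  # weight: (vocab_size, hidden)
        self.proj = nn.Linear(hidden, vocab_size, bias=False)    # weight: (vocab_size, hidden)
        self.proj.weight = self.token_embedding.weight            # tie the two parameters
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden) -> log-probabilities over the vocabulary
        return self.softmax(self.proj(hidden_states))


if __name__ == "__main__":
    model = TiedMaskedLM()
    x = torch.randn(2, 5, 32)
    print(model(x).shape)  # torch.Size([2, 5, 100])
```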