diff --git a/README.md b/README.md
index 16d3df9..a242b56 100644
--- a/README.md
+++ b/README.md
@@ -42,20 +42,20 @@ pip install bert-pytorch
 
 ### 0. Prepare your corpus
 ```
-Welcome to the \t the jungle\n
-I can stay \t here all night\n
+Welcome to the the jungle
+I can stay here all night
 ```
 
 or tokenized corpus (tokenization is not in package)
 ```
-Wel_ _come _to _the \t _the _jungle\n
-_I _can _stay \t _here _all _night\n
+Wel_ _come _to _the _the _jungle
+_I _can _stay _here _all _night
 ```
 
 ### 1. Building vocab based on your corpus
 ```shell
-bert-vocab -c data/corpus.small -o data/vocab.small
+vocab-builder -c data/corpus.small -o data/vocab.small
 ```
 
 ### 2. Train your own BERT model
@@ -109,7 +109,7 @@ not directly captured by language modeling
 
 ## Author
 
-Junseong Kim, Scatter Lab (codertimo@gmail.com / junseong.kim@scatter.co.kr)
+Junseong Kim, Scatter Lab (codertimo@gmail.com / junseong.kim@scatterlab.co.kr)
 
 ## License
 
@@ -117,4 +117,4 @@ This project following Apache 2.0 License as written in LICENSE file
 
 Copyright 2018 Junseong Kim, Scatter Lab, respective BERT contributors
 
-Copyright (c) 2018 Alexander Rush : [The Annotated Trasnformer](https://github.com/harvardnlp/annotated-transformer)
+Copyright (c) 2018 Alexander Rush : [The Annotated Transformer](https://github.com/harvardnlp/annotated-transformer)
diff --git a/bert_pytorch/__main__.py b/bert_pytorch/__main__.py
index d4193f2..a390ab8 100644
--- a/bert_pytorch/__main__.py
+++ b/bert_pytorch/__main__.py
@@ -4,7 +4,9 @@
 
 from .model import BERT
 from .trainer import BERTTrainer
-from .dataset import BERTDataset, WordVocab
+from .dataset import BERTDataset
+
+from vocab_builder import WordVocab
 
 
 def train():
diff --git a/bert_pytorch/dataset/__init__.py b/bert_pytorch/dataset/__init__.py
index 90e9036..2341716 100644
--- a/bert_pytorch/dataset/__init__.py
+++ b/bert_pytorch/dataset/__init__.py
@@ -1,2 +1 @@
 from .dataset import BERTDataset
-from .vocab import WordVocab
diff --git a/bert_pytorch/dataset/vocab.py b/bert_pytorch/dataset/vocab.py
deleted file mode 100644
index f7346a7..0000000
--- a/bert_pytorch/dataset/vocab.py
+++ /dev/null
@@ -1,185 +0,0 @@
-import pickle
-import tqdm
-from collections import Counter
-
-
-class TorchVocab(object):
-    """Defines a vocabulary object that will be used to numericalize a field.
-    Attributes:
-        freqs: A collections.Counter object holding the frequencies of tokens
-            in the data used to build the Vocab.
-        stoi: A collections.defaultdict instance mapping token strings to
-            numerical identifiers.
-        itos: A list of token strings indexed by their numerical identifiers.
-    """
-
-    def __init__(self, counter, max_size=None, min_freq=1, specials=['<pad>', '<oov>'],
-                 vectors=None, unk_init=None, vectors_cache=None):
-        """Create a Vocab object from a collections.Counter.
-        Arguments:
-            counter: collections.Counter object holding the frequencies of
-                each value found in the data.
-            max_size: The maximum size of the vocabulary, or None for no
-                maximum. Default: None.
-            min_freq: The minimum frequency needed to include a token in the
-                vocabulary. Values less than 1 will be set to 1. Default: 1.
-            specials: The list of special tokens (e.g., padding or eos) that
-                will be prepended to the vocabulary in addition to an <unk>
-                token. Default: ['<pad>']
-            vectors: One of either the available pretrained vectors
-                or custom pretrained vectors (see Vocab.load_vectors);
-                or a list of aforementioned vectors
-            unk_init (callback): by default, initialize out-of-vocabulary word vectors
-                to zero vectors; can be any function that takes in a Tensor and
-                returns a Tensor of the same size. Default: torch.Tensor.zero_
-            vectors_cache: directory for cached vectors. Default: '.vector_cache'
-        """
-        self.freqs = counter
-        counter = counter.copy()
-        min_freq = max(min_freq, 1)
-
-        self.itos = list(specials)
-        # frequencies of special tokens are not counted when building vocabulary
-        # in frequency order
-        for tok in specials:
-            del counter[tok]
-
-        max_size = None if max_size is None else max_size + len(self.itos)
-
-        # sort by frequency, then alphabetically
-        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
-        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
-
-        for word, freq in words_and_frequencies:
-            if freq < min_freq or len(self.itos) == max_size:
-                break
-            self.itos.append(word)
-
-        # stoi is simply a reverse dict for itos
-        self.stoi = {tok: i for i, tok in enumerate(self.itos)}
-
-        self.vectors = None
-        if vectors is not None:
-            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
-        else:
-            assert unk_init is None and vectors_cache is None
-
-    def __eq__(self, other):
-        if self.freqs != other.freqs:
-            return False
-        if self.stoi != other.stoi:
-            return False
-        if self.itos != other.itos:
-            return False
-        if self.vectors != other.vectors:
-            return False
-        return True
-
-    def __len__(self):
-        return len(self.itos)
-
-    def vocab_rerank(self):
-        self.stoi = {word: i for i, word in enumerate(self.itos)}
-
-    def extend(self, v, sort=False):
-        words = sorted(v.itos) if sort else v.itos
-        for w in words:
-            if w not in self.stoi:
-                self.itos.append(w)
-                self.stoi[w] = len(self.itos) - 1
-
-
-class Vocab(TorchVocab):
-    def __init__(self, counter, max_size=None, min_freq=1):
-        self.pad_index = 0
-        self.unk_index = 1
-        self.eos_index = 2
-        self.sos_index = 3
-        self.mask_index = 4
-        super().__init__(counter, specials=["<pad>", "<unk>", "<eos>", "<sos>", "<mask>"],
-                         max_size=max_size, min_freq=min_freq)
-
-    def to_seq(self, sentece, seq_len, with_eos=False, with_sos=False) -> list:
-        pass
-
-    def from_seq(self, seq, join=False, with_pad=False):
-        pass
-
-    @staticmethod
-    def load_vocab(vocab_path: str) -> 'Vocab':
-        with open(vocab_path, "rb") as f:
-            return pickle.load(f)
-
-    def save_vocab(self, vocab_path):
-        with open(vocab_path, "wb") as f:
-            pickle.dump(self, f)
-
-
-# Building Vocab with text files
-class WordVocab(Vocab):
-    def __init__(self, texts, max_size=None, min_freq=1):
-        print("Building Vocab")
-        counter = Counter()
-        for line in tqdm.tqdm(texts):
-            if isinstance(line, list):
-                words = line
-            else:
-                words = line.replace("\n", "").replace("\t", "").split()
-
-            for word in words:
-                counter[word] += 1
-        super().__init__(counter, max_size=max_size, min_freq=min_freq)
-
-    def to_seq(self, sentence, seq_len=None, with_eos=False, with_sos=False, with_len=False):
-        if isinstance(sentence, str):
-            sentence = sentence.split()
-
-        seq = [self.stoi.get(word, self.unk_index) for word in sentence]
-
-        if with_eos:
-            seq += [self.eos_index]  # this would be index 1
-        if with_sos:
-            seq = [self.sos_index] + seq
-
-        origin_seq_len = len(seq)
-
-        if seq_len is None:
-            pass
-        elif len(seq) <= seq_len:
-            seq += [self.pad_index for _ in range(seq_len - len(seq))]
-        else:
-            seq = seq[:seq_len]
-
-        return (seq, origin_seq_len) if with_len else seq
-
-    def from_seq(self, seq, join=False, with_pad=False):
-        words = [self.itos[idx]
-                 if idx < len(self.itos)
-                 else "<%d>" % idx
-                 for idx in seq
-                 if not with_pad or idx != self.pad_index]
-
-        return " ".join(words) if join else words
-
-    @staticmethod
-    def load_vocab(vocab_path: str) -> 'WordVocab':
-        with open(vocab_path, "rb") as f:
-            return pickle.load(f)
-
-
-def build():
-    import argparse
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("-c", "--corpus_path", required=True, type=str)
-    parser.add_argument("-o", "--output_path", required=True, type=str)
-    parser.add_argument("-s", "--vocab_size", type=int, default=None)
-    parser.add_argument("-e", "--encoding", type=str, default="utf-8")
-    parser.add_argument("-m", "--min_freq", type=int, default=1)
-    args = parser.parse_args()
-
-    with open(args.corpus_path, "r", encoding=args.encoding) as f:
-        vocab = WordVocab(f, max_size=args.vocab_size, min_freq=args.min_freq)
-
-    print("VOCAB SIZE:", len(vocab))
-    vocab.save_vocab(args.output_path)
diff --git a/bert_pytorch/trainer/pretrain.py b/bert_pytorch/trainer/pretrain.py
index 0b882dd..5f24481 100644
--- a/bert_pytorch/trainer/pretrain.py
+++ b/bert_pytorch/trainer/pretrain.py
@@ -1,13 +1,12 @@
 import torch
 import torch.nn as nn
+
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 
 from ..model import BERTLM, BERT
 from .optim_schedule import ScheduledOptim
 
-import tqdm
-
 
 class BERTTrainer:
     """
@@ -59,7 +58,8 @@ def __init__(self, bert: BERT, vocab_size: int,
         self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)
 
         # Using Negative Log Likelihood Loss function for predicting the masked_token
-        self.criterion = nn.NLLLoss(ignore_index=0)
+        self.masked_criterion = nn.NLLLoss(ignore_index=0)
+        self.next_criterion = nn.NLLLoss()
 
         self.log_freq = log_freq
 
@@ -84,17 +84,11 @@ def iteration(self, epoch, data_loader, train=True):
         """
        str_code = "train" if train else "test"
 
-        # Setting the tqdm progress bar
-        data_iter = tqdm.tqdm(enumerate(data_loader),
-                              desc="EP_%s:%d" % (str_code, epoch),
-                              total=len(data_loader),
-                              bar_format="{l_bar}{r_bar}")
-
         avg_loss = 0.0
         total_correct = 0
         total_element = 0
 
-        for i, data in data_iter:
+        for i, data in enumerate(data_loader):
             # 0. batch_data will be sent into the device(GPU or cpu)
             data = {key: value.to(self.device) for key, value in data.items()}
 
@@ -102,10 +96,10 @@ def iteration(self, epoch, data_loader, train=True):
             next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])
 
             # 2-1. NLL(negative log likelihood) loss of is_next classification result
-            next_loss = self.criterion(next_sent_output, data["is_next"])
+            next_loss = self.next_criterion(next_sent_output, data["is_next"]) * 10
 
             # 2-2. NLLLoss of predicting masked token word
-            mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])
+            mask_loss = self.masked_criterion(mask_lm_output.transpose(1, 2), data["bert_label"])
 
             # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
             loss = next_loss + mask_loss
@@ -124,16 +118,18 @@ def iteration(self, epoch, data_loader, train=True):
 
             post_fix = {
                 "epoch": epoch,
-                "iter": i,
+                "iter": "[%d/%d]" % (i, len(data_loader)),
                 "avg_loss": avg_loss / (i + 1),
-                "avg_acc": total_correct / total_element * 100,
+                "mask_loss": mask_loss.item(),
+                "next_loss": next_loss.item(),
+                "avg_next_acc": total_correct / total_element * 100,
                 "loss": loss.item()
             }
 
             if i % self.log_freq == 0:
-                data_iter.write(str(post_fix))
+                print(str(post_fix))
 
-        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=",
+        print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_loader), "total_acc=",
               total_correct * 100.0 / total_element)
 
     def save(self, epoch, file_path="output/bert_trained.model"):
diff --git a/requirements.txt b/requirements.txt
index 3689708..d53ab8c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
 tqdm
 numpy
-torch>=0.4.0
\ No newline at end of file
+torch >= 0.4.0
+vocab-builder
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 4e721cf..2d1699b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 import os
 import sys
 
-__version__ = "0.0.1a4"
+__version__ = "0.0.1a5"
 
 with open("requirements.txt") as f:
     require_packages = [line[:-1] if line[-1] == "\n" else line for line in f]
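
Taken together, the `pretrain.py` hunks replace the single shared `criterion` with a `masked_criterion` that ignores the padding label and a `next_criterion` for the is-next classifier, and scale the next-sentence loss by 10 before summing. Below is a minimal, self-contained sketch of that loss arithmetic; the tensor shapes and dummy data are illustrative assumptions only, not part of the patch.

```python
import torch
import torch.nn as nn

# Illustrative shapes only (assumption): batch of 2, sequence length 4, vocab of 8.
batch_size, seq_len, vocab_size = 2, 4, 8

# NLLLoss expects log-probabilities, so the fake model outputs go through log_softmax.
next_sent_output = torch.randn(batch_size, 2).log_softmax(dim=-1)                  # is_next vs. not_next
mask_lm_output = torch.randn(batch_size, seq_len, vocab_size).log_softmax(dim=-1)  # per-token vocab log-probs

is_next = torch.randint(0, 2, (batch_size,))                      # dummy next-sentence labels
bert_label = torch.randint(0, vocab_size, (batch_size, seq_len))  # dummy masked-LM labels, 0 = padding

# Mirrors the patched BERTTrainer.__init__: two criteria instead of one shared one.
masked_criterion = nn.NLLLoss(ignore_index=0)  # padding positions (label 0) do not contribute
next_criterion = nn.NLLLoss()

# Mirrors the patched iteration(): the next-sentence loss is up-weighted by 10, then summed.
next_loss = next_criterion(next_sent_output, is_next) * 10
mask_loss = masked_criterion(mask_lm_output.transpose(1, 2), bert_label)  # (batch, vocab, seq_len) vs. (batch, seq_len)
loss = next_loss + mask_loss
print(loss.item())
```

Keeping the two criteria separate also lets the per-iteration log report `mask_loss` and `next_loss` individually, which is what the new `post_fix` fields expose.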