From bd98c4b1e07ae21904d1802b65469f10e795978e Mon Sep 17 00:00:00 2001
From: uchiiii
Date: Thu, 27 Mar 2025 13:40:17 +0900
Subject: [PATCH] memo

---
 common/modules/embedding/config.py    | 1 +
 common/modules/embedding/embedding.py | 1 +
 projects/twhin/models/models.py       | 7 +++++++
 3 files changed, 9 insertions(+)

diff --git a/common/modules/embedding/config.py b/common/modules/embedding/config.py
index 2f5df15..6fe5a9d 100644
--- a/common/modules/embedding/config.py
+++ b/common/modules/embedding/config.py
@@ -29,6 +29,7 @@ class EmbeddingBagConfig(base_config.BaseConfig):
   name: str = pydantic.Field(..., description="name of embedding bag")
   num_embeddings: int = pydantic.Field(..., description="size of embedding dictionary")
   embedding_dim: int = pydantic.Field(..., description="size of each embedding vector")
+  # here: a pretrained snapshot can be plugged in via this field
   pretrained: EmbeddingSnapshot = pydantic.Field(None, description="Snapshot properties")
   vocab: str = pydantic.Field(
     None, description="Directory to parquet files of mapping from entity ID to table index."
diff --git a/common/modules/embedding/embedding.py b/common/modules/embedding/embedding.py
index b0a085e..69695da 100644
--- a/common/modules/embedding/embedding.py
+++ b/common/modules/embedding/embedding.py
@@ -36,6 +36,7 @@ def __init__(
         )
       )
 
+    # https://pytorch.org/torchrec/modules-api-reference.html#torchrec.modules.embedding_modules.EmbeddingBagCollection
     self.ebc = EmbeddingBagCollection(
       device="meta",
       tables=tables,
diff --git a/projects/twhin/models/models.py b/projects/twhin/models/models.py
index b1dfaa7..42fbc49 100644
--- a/projects/twhin/models/models.py
+++ b/projects/twhin/models/models.py
@@ -18,6 +18,7 @@ def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig)
     super().__init__()
     self.batch_size = data_config.per_replica_batch_size
     self.table_names = [table.name for table in model_config.embeddings.tables]
+    # just a list of EmbeddingBags
    self.large_embeddings = LargeEmbeddings(model_config.embeddings)
     self.embedding_dim = model_config.embeddings.tables[0].embedding_dim
     self.num_tables = len(model_config.embeddings.tables)
@@ -31,6 +32,8 @@ def __init__(self, model_config: TwhinModelConfig, data_config: TwhinDataConfig)
     )
 
   def forward(self, batch: EdgeBatch):
+    # B: batch size
+    # D: embedding dimension
     # B x D
     trans_embs = self.all_trans_embs.data[batch.rels]
 
@@ -48,12 +51,15 @@ def forward(self, batch: EdgeBatch):
     x = torch.sum(x, 1)
 
     # B x 2 x D
+    # [:, 0, :]: source
+    # [:, 1, :]: target
     x = x.reshape(self.batch_size, 2, self.embedding_dim)
 
     # translated
     translated = x[:, 1, :] + trans_embs
 
     negs = []
+    # negative sampling within the same batch
     if self.in_batch_negatives:
       # construct dot products for negatives via matmul
       for relation in range(self.num_relations):
@@ -80,6 +86,7 @@ def forward(self, batch: EdgeBatch):
         sampled_rhs = rhs_matrix[rhs_indices]
 
         # RS
+        # wait, the translation isn't applied to the negative samples?
         negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t()))
         negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t()))
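
To go with the torchrec link added in embedding.py: a minimal sketch of how an
EmbeddingBagCollection is constructed and called. The table name "user", the
feature name "user_id", and the sizes are illustrative assumptions, not values
from this repo; the point is that EBC pools (sums, by default) each example's
embeddings per feature and consumes a KeyedJaggedTensor.

import torch
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig, KeyedJaggedTensor

# one toy table; names and sizes are hypothetical
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="user",
            num_embeddings=100,
            embedding_dim=16,
            feature_names=["user_id"],
        )
    ],
)

# 3 examples, one id each; `lengths` gives ids-per-example
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["user_id"],
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([1, 1, 1]),
)

pooled = ebc(kjt)               # KeyedTensor, keyed by feature name
print(pooled["user_id"].shape)  # torch.Size([3, 16])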
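
To go with the "B x 2 x D" and "translated" comments in models.py forward(): a
minimal sketch of the TransE-style scoring those comments describe. B, D,
num_relations, and the random tensors are assumptions for illustration, not
repo values.

import torch

B, D, num_relations = 4, 8, 2           # illustrative sizes

# stacked (source, target) embeddings, as after the reshape: B x 2 x D
x = torch.randn(B, 2, D)
source = x[:, 0, :]                     # B x D
target = x[:, 1, :]                     # B x D

# one translation vector per relation; index with each edge's relation id
all_trans_embs = torch.randn(num_relations, D)
rels = torch.randint(num_relations, (B,))
trans_embs = all_trans_embs[rels]       # B x D

# TransE-style positive score: dot(source, target + relation translation)
translated = target + trans_embs
pos_scores = (source * translated).sum(-1)  # B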
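
To go with the in-batch negative sampling block and the final memo question: a
hedged sketch of the matmul trick (per-relation masking omitted; S, the sizes,
and the random tensors are assumptions). Each source is scored against S
targets drawn from the same batch, and, as the memo observes, no trans_embs is
added to the sampled negatives in the snippet shown. Note that because
lhs . (rhs + t) = lhs . rhs + lhs . t, applying the translation would only add
a per-row offset that is independent of which negative was sampled.

import torch

B, S, D = 4, 3, 8                       # S = negatives per example (assumption)
source = torch.randn(B, D)              # plays the role of x[:, 0, :]
target = torch.randn(B, D)              # plays the role of x[:, 1, :]
trans_embs = torch.randn(B, D)          # per-edge relation translation

# sample S target rows from within the same batch
rhs_indices = torch.randint(B, (S,))
sampled_rhs = target[rhs_indices]       # S x D

# every source against every sampled target in one matmul: B x S -> B*S
negs_rhs = torch.flatten(source @ sampled_rhs.t())

# the translated variant the memo asks about: same matmul plus a per-row offset
offset = (source * trans_embs).sum(-1, keepdim=True)      # B x 1
negs_rhs_translated = source @ sampled_rhs.t() + offset   # B x S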