Link Prediction With PyTorch BigGraph
In a previous post we talked about graph representations and network embeddings. In this post, we'll look at a concrete implementation of those ideas: PyTorch-BigGraph from Facebook (GitHub link in the sources at the end of this post), and in particular at how it trains and uses network embeddings to perform link prediction.
Link Prediction
PyTorch BigGraph (PBG) performs link prediction by 1) learning an embedding for each entity, and 2) learning a function for each relation type that takes two entity embeddings and assigns them a score, 3) with the goal of having positive relations achieve higher scores than negative ones.
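As a concrete illustration, here is a minimal sketch of such a scoring function. This is not PBG's actual API; the diagonal operator and dot-product comparison below only mirror, in spirit, the ComplexDiagonalDynamicOperator and DotComparator used in the script later in this post.

import torch

def edge_score(src_emb, rel_diag, dst_emb):
    # Transform the destination embedding with a relation-specific
    # diagonal operator (elementwise multiplication), then compare the
    # result with the source embedding via a dot product.
    # A higher score means a more plausible edge.
    return (src_emb * (dst_emb * rel_diag)).sum(dim=-1)

dim = 400
src, rel, dst = torch.randn(dim), torch.randn(dim), torch.randn(dim)
print(edge_score(src, rel, dst))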
Train Embeddings
The way PBG trains network embeddings is pretty simple: it calculates the scores of positive and negative samples, then aggregates them into a single loss value. During training, PBG minimizes that loss to derive the embeddings. Note that PBG considers each entity to have its own embedding, which is entirely independent from any other entity's embedding.
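For illustration, here is a minimal sketch of one such aggregation: a margin ranking loss, one of the loss functions PBG supports, which pushes positive scores above negative scores by a margin. The score tensors below are random stand-ins for scores produced by a comparator like the one sketched above.

import torch

def margin_ranking_loss(pos_scores, neg_scores, margin=0.1):
    # For each (positive, negative) pair, penalize the model whenever the
    # negative score comes within `margin` of the positive score.
    return torch.clamp(margin + neg_scores - pos_scores, min=0).sum()

pos_scores = torch.randn(8, requires_grad=True)  # scores of observed edges
neg_scores = torch.randn(8, requires_grad=True)  # scores of corrupted edges
loss = margin_ranking_loss(pos_scores, neg_scores)
loss.backward()  # gradients flow back into the entity embeddings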
Scaling up
What is powerful about PBG is that it scales. In order for PBG to operate on large-scale graphs, the graph is broken up into small pieces, on which training can happen in a distributed manner. This is first achieved by splitting the entities of each type into a certain number of subsets, called partitions. Then, for each relation type, its edges are divided into buckets: for each pair of partitions (one from the left- and one from the right-hand side entity types for that relation type) a bucket is created, which contains the edges of that type whose left- and right-hand side entities are in those partitions.
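The following sketch illustrates this bucketing. Note that the hash-based partition assignment is an assumption made here for illustration; PBG itself assigns each entity a partition and an offset within it during data preparation.

import zlib
from collections import defaultdict

NUM_PARTITIONS = 4

def partition_of(entity):
    # Illustrative assumption: a deterministic hash of the entity name
    # assigns each entity to one of NUM_PARTITIONS partitions.
    return zlib.crc32(entity.encode()) % NUM_PARTITIONS

def bucket_edges(edges):
    # A bucket is identified by the pair (lhs partition, rhs partition);
    # it holds every edge whose endpoints fall in those two partitions,
    # so each bucket can be trained on independently.
    buckets = defaultdict(list)
    for lhs, rel, rhs in edges:
        buckets[partition_of(lhs), partition_of(rhs)].append((lhs, rel, rhs))
    return buckets

edges = [("France", "capital", "Paris"), ("Japan", "capital", "Tokyo")]
for bucket, contents in sorted(bucket_edges(edges).items()):
    print(bucket, contents)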
Training Yourself
Install the torchbiggraph package and run the FB15k example, which is based on Freebase entities and relationships.
pip install torchbiggraph
torchbiggraph_example_fb15k
Once you have a trained model, run the following script to compute an edge score for link prediction, here for the edge (France, /location/country/capital, Paris):
import json
import h5py
import torch
from torchbiggraph.model import ComplexDiagonalDynamicOperator, DotComparator
# Load count of dynamic relations
with open("data/FB15k/dynamic_rel_count.txt", "rt") as tf:
dynamic_rel_count = int(tf.read().strip())
# Load the operator's state dict
with h5py.File("model/fb15k/model.v50.h5", "r") as hf:
operator_state_dict = {
"real": torch.from_numpy(hf["model/relations/0/operator/rhs/real"][...]),
"imag": torch.from_numpy(hf["model/relations/0/operator/rhs/imag"][...]),
}
operator = ComplexDiagonalDynamicOperator(400, dynamic_rel_count)  # 400 = embedding dimension
operator.load_state_dict(operator_state_dict)
comparator = DotComparator()
# Load the names of the entities, ordered by offset.
with open("data/FB15k/entity_names_all_0.json", "rt") as tf:
entity_names = json.load(tf)
src_entity_offset = entity_names.index("/m/0f8l9c") # France
dest_entity_offset = entity_names.index("/m/05qtj") # Paris
# Load the names of the relation types, ordered by index.
with open("data/FB15k/dynamic_rel_names.json", "rt") as tf:
rel_type_names = json.load(tf)
rel_type_index = rel_type_names.index("/location/country/capital")
# Load the trained embeddings
with h5py.File("model/fb15k/embeddings_all_0.v50.h5", "r") as hf:
src_embedding = torch.from_numpy(hf["embeddings"][src_entity_offset, :])
dest_embedding = torch.from_numpy(hf["embeddings"][dest_entity_offset, :])
# Calculate the scores
scores, _, _ = comparator(
    comparator.prepare(src_embedding.view(1, 1, 400)),
    comparator.prepare(
        operator(
            dest_embedding.view(1, 400),
            torch.tensor([rel_type_index]),
        ).view(1, 1, 400),
    ),
    torch.empty(1, 0, 400),  # Left-hand side negatives, not needed
    torch.empty(1, 0, 400),  # Right-hand side negatives, not needed
)
print(scores)
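A single score is mostly meaningful relative to other candidates. As a hypothetical extension of the script above (not part of the original example), one could score every entity as a candidate destination for the same source and relation and rank them, reusing the same (num_chunks, num_candidates, dim) reshaping convention:

# Hypothetical sketch: rank all entities as candidate objects of
# (France, /location/country/capital, ?), reusing the objects loaded above.
with h5py.File("model/fb15k/embeddings_all_0.v50.h5", "r") as hf:
    all_embeddings = torch.from_numpy(hf["embeddings"][...])  # (num_entities, 400)

num_entities = all_embeddings.shape[0]
candidate_scores, _, _ = comparator(
    comparator.prepare(src_embedding.view(1, 1, 400).expand(1, num_entities, 400)),
    comparator.prepare(
        operator(
            all_embeddings,
            torch.full((num_entities,), rel_type_index, dtype=torch.long),
        ).view(1, num_entities, 400),
    ),
    torch.empty(1, 0, 400),
    torch.empty(1, 0, 400),
)
top_scores, top_offsets = candidate_scores.flatten().topk(10)
print([entity_names[i] for i in top_offsets.tolist()])  # Paris should rank high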
Sources: https://torchbiggraph.readthedocs.io/en/latest/, https://github.com/facebookresearch/PyTorch-BigGraph