In Teilen dieser Serie haben wir uns mit Graph Faltungsnetzwerken (GCNs) und Graph Achtungsnetzwerken (GATs) befasst. Beide Architekturen funktionieren intestine, haben aber auch einige Einschränkungen! Ein großes ist, dass für große Graphen die Berechnung der Knotendarstellungen mit GCNs und GATs sehr langsam werden. Eine weitere Einschränkung besteht darin, dass GCNs und GATs, wenn sich die Graphenstruktur ändert, nicht verallgemeinert werden können. Wenn additionally Knoten zum Diagramm hinzugefügt werden, können ein GCN oder GAT keine Vorhersagen dafür treffen. Zum Glück können diese Probleme gelöst werden!
In diesem Beitrag werde ich erklären Graphsage und wie es häufige Probleme von GCNs und Gats löst. Wir trainieren Graphsage und verwenden sie für Diagrammvorhersagen, um die Leistung mit GCNs und GATs zu vergleichen.
Neu in Gnns? Sie können mit beginnen mit Submit 1 über GCNs (Enthält auch das anfängliche Setup zum Ausführen der Code -Beispiele) und Submit 2 über Gats.
Zwei wichtige Probleme mit GCNs und Gats
Ich habe es in der Einführung in Kürze angesprochen, aber lassen Sie uns ein bisschen tiefer tauchen. Was sind die Probleme mit den vorherigen GNN -Modellen?
Downside 1. Sie verallgemeinern nicht
GCNs und GATs haben Probleme mit der Verallgemeinerung auf unsichtbare Grafiken. Die Grafikstruktur muss mit den Trainingsdaten übereinstimmen. Dies ist bekannt als als transduktives Lernenwo das Modell trainiert und Vorhersagen in derselben festen Grafik macht. Es ist tatsächlich zu bestimmten Graph -Topologien übernommen. In Wirklichkeit ändern sich die Grafiken: Knoten und Kanten können hinzugefügt oder entfernt werden, und dies geschieht oft in Szenarien in realer Welt. Wir möchten induktiv Lernen).
Downside 2. Sie haben Skalierbarkeitsprobleme
Das Coaching von GCNs und GATS in groß angelegten Grafiken ist rechnerisch teuer. GCNs erfordern eine wiederholte Nachbaraggregation, die exponentiell mit der Graphengröße wächst, während GATs (Multihead-) Aufmerksamkeitsmechanismen mit zunehmenden Knoten schlecht skalieren.
In großen Produktionsempfehlungssystemen mit großen Grafiken mit Millionen von Benutzern und Produkten sind GCNs und GATS unpraktisch und langsam.
Schauen wir uns die Graphsage an, um diese Probleme zu beheben.
Graphsage (Probe und Aggregat)
Graphsage macht das Coaching viel schneller und skalierbar. Dies geschieht durch Probenahme nur eine Untergruppe von Nachbarn. Für tremendous große Grafiken ist es rechnerisch unmöglich, alle Nachbarn eines Knotens zu verarbeiten (außer wenn Sie unbegrenzt Zeit haben, was wir alle nicht…), wie bei traditionellen GCNs. Ein weiterer wichtiger Schritt der Graphsage ist Kombinieren Sie die Merkmale der abgetasteten Nachbarn mit einer Aggregationsfunktion.
Wir werden alle Schritte der Graphsage unten durchlaufen.
1. Probenahme Nachbarn
Bei tabellarischen Daten ist die Stichprobe einfach. Es ist etwas, was Sie in jedem gemeinsamen Projekt für maschinelles Lernen bei der Erstellung von Zug-, Check- und Validierungssätzen machen. Mit Diagrammen können Sie keine zufälligen Knoten auswählen. Dies kann zu getrennten Graphen, Knoten ohne Nachbarn usw. führen:

Was du dürfen Mit Diagrammen wählt eine zufällige Untergruppe von Nachbarn mit fester Größe aus. In einem sozialen Netzwerk können Sie beispielsweise 3 Freunde für jeden Benutzer (anstelle aller Freunde) probieren:

2. Gesamtinformationen
Nach der Auswahl der Nachbarn aus dem vorherigen Teil kombiniert Graphsage ihre Funktionen zu einer einzigen Darstellung. Es gibt mehrere Möglichkeiten, dies zu tun (mehrere Aggregationsfunktionen). Die am häufigsten und die im Papier erklärten Typen sind mittlere AggregationAnwesend LstmUnd Pooling.
Mit der mittleren Aggregation wird der Durchschnitt über alle Merkmale der abgetasteten Nachbarn berechnet (sehr einfach und oft effektiv). In einer Formel:
LSTM -Aggregation verwendet eine Lstm (Artwork des neuronalen Netzwerks), um Nachbarfunktionen nacheinander zu verarbeiten. Es kann komplexere Beziehungen aufnehmen und ist mächtiger als die mittlere Aggregation.
Die dritte Artwork, die Poolaggregation, wendet eine nichtlineare Funktion an, um wichtige Merkmale zu extrahieren (denken Sie an Max-Pooling In einem neuronalen Netzwerk, in dem Sie auch den Maximalwert einiger Werte nehmen).
3.. Aktualisieren der Knotendarstellung
Nach der Probenahme und Aggregation der Knoten kombiniert seine früheren Funktionen mit den aggregierten Nachbarnfunktionen. Knoten werden von ihren Nachbarn lernen, aber auch ihre eigene Identität behalten, genau wie zuvor mit GCNs und Gats. Informationen können effektiv über den Diagramm fließen.
Dies ist die Formel für diesen Schritt:
Die Aggregation von Schritt 2 erfolgt über alle Nachbarn, und dann wird die Merkmalsdarstellung des Knotens verkettet. Dieser Vektor wird mit der Gewichtsmatrix multipliziert und durch die Nichtlinearität (z. B. Relu) geleitet. Als letzter Schritt kann die Normalisierung angewendet werden.
4. Wiederholen Sie für mehrere Schichten
Die ersten drei Schritte können mehrmals wiederholt werden. In diesem Fall können Informationen von entfernten Nachbarn fließen. Im Bild unten sehen Sie einen Knoten mit drei in der ersten Schicht ausgewählten Nachbarn (direkte Nachbarn) und zwei in der zweiten Schicht ausgewählte Nachbarn (Nachbarn der Nachbarn).

Zusammenfassend lässt sich sagen, dass die wichtigsten Stärken von Graphsage ihre Skalierbarkeit sind (die Abtastung macht es für large Grafiken effizient). Flexibilität können Sie es verwenden für Induktives Lernen (Funktioniert intestine, wenn es zur Vorhersage von unsichtbaren Knoten und Grafiken verwendet wird); Die Aggregation hilft bei der Verallgemeinerung, da sie laute Merkmale glättet. und die Mehrschichten erlauben das Modell, aus fernen Knoten zu lernen.
Cool! Und das Beste, Graphsage wird in implementiert PygSo können wir es leicht in Pytorch verwenden.
Vorhersage mit Graphsage
In den vorherigen Beiträgen haben wir einen MLP, GCN und GAT auf dem implementiert Cora Datensatz (CC BY-SA). Um Ihren Geist ein wenig zu aktualisieren, ist Cora ein Datensatz mit wissenschaftlichen Veröffentlichungen, in denen Sie das Thema jedes Papiers mit insgesamt sieben Klassen vorhersagen müssen. Dieser Datensatz ist relativ gering und ist daher möglicherweise nicht der beste Satz für das Testen von Graphsage. Wir werden dies trotzdem tun, nur um vergleichen zu können. Mal sehen, wie intestine Graphsage funktioniert.
Interessante Teile des Codes, die ich gerne mit Graphsage hervorhebt:
- Der
NeighborLoader
Dadurch wird die Auswahl der Nachbarn für jede Schicht ausgeführt:
from torch_geometric.loader import NeighborLoader
# 10 neighbors sampled within the first layer, 10 within the second layer
num_neighbors = (10, 10)
# pattern information from the prepare set
train_loader = NeighborLoader(
information,
num_neighbors=num_neighbors,
batch_size=batch_size,
input_nodes=information.train_mask,
)
- Der Aggregationstyp ist in der implementiert
SAGEConv
Schicht. Der Normal istimply
Sie können dies an ändernmax
oderlstm
:
from torch_geometric.nn import SAGEConv
SAGEConv(in_c, out_c, aggr='imply')
- Ein weiterer wichtiger Unterschied besteht darin, dass Graphsage in Mini -Chargen und GCN und GAT auf dem vollständigen Datensatz trainiert wird. Dies berührt die Essenz von Graphsage, da die Nachbarabtastung von Graphsage es ermöglicht, in Mini -Chargen zu trainieren, benötigen wir nicht mehr den vollständigen Diagramm. GCNs und GATs benötigen das komplette Diagramm für die korrekte Merkmalsausbreitung und Berechnung der Aufmerksamkeitswerte. Deshalb trainieren wir GCNs und GATS im vollständigen Diagramm.
- Der Relaxation des Codes ist ähnlich wie zuvor, außer dass wir eine Klasse haben, in der alle verschiedenen Modelle basierend auf dem instanziiert werden
model_type
(GCN, Gat oder Salbei). Dies erleichtert es einfach, kleine Änderungen zu vergleichen oder vorzunehmen.
Dies ist das vollständige Skript. Wir trainieren 100 Epochen und wiederholen das Experiment 10 Mal, um die durchschnittliche Genauigkeit und Standardabweichung für jedes Modell zu berechnen:
import torch
import torch.nn.purposeful as F
from torch_geometric.nn import SAGEConv, GCNConv, GATConv
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
# dataset_name could be 'Cora', 'CiteSeer', 'PubMed'
dataset_name = 'Cora'
hidden_dim = 64
num_layers = 2
num_neighbors = (10, 10)
batch_size = 128
num_epochs = 100
model_types = ('GCN', 'GAT', 'SAGE')
dataset = Planetoid(root='information', identify=dataset_name)
information = dataset(0)
machine = torch.machine('cuda' if torch.cuda.is_available() else 'cpu')
information = information.to(machine)
class GNN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers, model_type='SAGE', gat_heads=8):
tremendous().__init__()
self.convs = torch.nn.ModuleList()
self.model_type = model_type
self.gat_heads = gat_heads
def get_conv(in_c, out_c, is_final=False):
if model_type == 'GCN':
return GCNConv(in_c, out_c)
elif model_type == 'GAT':
heads = 1 if is_final else gat_heads
concat = False if is_final else True
return GATConv(in_c, out_c, heads=heads, concat=concat)
else:
return SAGEConv(in_c, out_c, aggr='imply')
if model_type == 'GAT':
self.convs.append(get_conv(in_channels, hidden_channels))
in_dim = hidden_channels * gat_heads
for _ in vary(num_layers - 2):
self.convs.append(get_conv(in_dim, hidden_channels))
in_dim = hidden_channels * gat_heads
self.convs.append(get_conv(in_dim, out_channels, is_final=True))
else:
self.convs.append(get_conv(in_channels, hidden_channels))
for _ in vary(num_layers - 2):
self.convs.append(get_conv(hidden_channels, hidden_channels))
self.convs.append(get_conv(hidden_channels, out_channels))
def ahead(self, x, edge_index):
for conv in self.convs(:-1):
x = F.relu(conv(x, edge_index))
x = self.convs(-1)(x, edge_index)
return x
@torch.no_grad()
def check(mannequin):
mannequin.eval()
out = mannequin(information.x, information.edge_index)
pred = out.argmax(dim=1)
accs = ()
for masks in (information.train_mask, information.val_mask, information.test_mask):
accs.append(int((pred(masks) == information.y(masks)).sum()) / int(masks.sum()))
return accs
outcomes = {}
for model_type in model_types:
print(f'Coaching {model_type}')
outcomes(model_type) = ()
for i in vary(10):
mannequin = GNN(dataset.num_features, hidden_dim, dataset.num_classes, num_layers, model_type, gat_heads=8).to(machine)
optimizer = torch.optim.Adam(mannequin.parameters(), lr=0.01, weight_decay=5e-4)
if model_type == 'SAGE':
train_loader = NeighborLoader(
information,
num_neighbors=num_neighbors,
batch_size=batch_size,
input_nodes=information.train_mask,
)
def prepare():
mannequin.prepare()
total_loss = 0
for batch in train_loader:
batch = batch.to(machine)
optimizer.zero_grad()
out = mannequin(batch.x, batch.edge_index)
loss = F.cross_entropy(out, batch.y(:out.dimension(0)))
loss.backward()
optimizer.step()
total_loss += loss.merchandise()
return total_loss / len(train_loader)
else:
def prepare():
mannequin.prepare()
optimizer.zero_grad()
out = mannequin(information.x, information.edge_index)
loss = F.cross_entropy(out(information.train_mask), information.y(information.train_mask))
loss.backward()
optimizer.step()
return loss.merchandise()
best_val_acc = 0
best_test_acc = 0
for epoch in vary(1, num_epochs + 1):
loss = prepare()
train_acc, val_acc, test_acc = check(mannequin)
if val_acc > best_val_acc:
best_val_acc = val_acc
best_test_acc = test_acc
if epoch % 10 == 0:
print(f'Epoch {epoch:02d} | Loss: {loss:.4f} | Practice: {train_acc:.4f} | Val: {val_acc:.4f} | Check: {test_acc:.4f}')
outcomes(model_type).append((best_val_acc, best_test_acc))
for model_name, model_results in outcomes.gadgets():
model_results = torch.tensor(model_results)
print(f'{model_name} Val Accuracy: {model_results(:, 0).imply():.3f} ± {model_results(:, 0).std():.3f}')
print(f'{model_name} Check Accuracy: {model_results(:, 1).imply():.3f} ± {model_results(:, 1).std():.3f}')
Und hier sind die Ergebnisse:
GCN Val Accuracy: 0.791 ± 0.007
GCN Check Accuracy: 0.806 ± 0.006
GAT Val Accuracy: 0.790 ± 0.007
GAT Check Accuracy: 0.800 ± 0.004
SAGE Val Accuracy: 0.899 ± 0.005
SAGE Check Accuracy: 0.907 ± 0.004
Beeindruckende Verbesserung! Auch in diesem kleinen Datensatz übertrifft Graphsage Gat und GCN leicht! Ich wiederholte diesen Check für CiteSeer- und PubMed -Datensätze, und Graphsage kam immer am besten heraus.
Was ich hier gerne bemerkte, ist, dass GCN immer noch sehr nützlich ist, es ist eine der effektivsten Basislinien (wenn die Grafikstruktur dies zulässt). Außerdem habe ich nicht viel Hyperparameter-Tuning durchgeführt, sondern nur einige Standardwerte (wie 8 Köpfe für die Multi-Head-Aufmerksamkeit). In größeren, komplexeren und lauteren Graphen werden die Vorteile von Graphsage klarer als in diesem Beispiel. Wir haben keine Leistungstests durchgeführt, da für diese kleinen Graphen nicht schneller als GCN.
Abschluss
Graphsage bringt uns sehr nette Verbesserungen und Vorteile im Vergleich zu GATS und GCNs. Induktives Lernen ist möglich, Graphsage kann sich recht intestine ändern. Und wir haben es in diesem Beitrag nicht getestet, aber die Nachbar -Probenahme ermöglicht es, Characteristic -Darstellungen für größere Grafiken mit guter Leistung zu erstellen.
Verwandt
Optimierung von Verbindungen: Mathematische Optimierung innerhalb der Grafiken
Graph Neural Networks Teil 1. Graph Faltungsnetzwerke erläutert
Graph Neural Networks Teil 2. Diagramm -Aufmerksamkeitsnetzwerke vs. GCNs