Dies ist eine direkte Fortsetzung von a vorheriger Beitrag zum Thema Implementierung benutzerdefinierter TPU-Operationen mit Pallas. Von besonderem Interesse sind benutzerdefinierte Kernel, die die einzigartigen Eigenschaften der TPU-Architektur so nutzen, dass die Laufzeitleistung optimiert wird. In diesem Beitrag werden wir versuchen, diese Möglichkeit zu demonstrieren, indem wir die Leistungsfähigkeit von Pallas auf die Herausforderung anwenden, sequentielle Algorithmen auszuführen, die in eine überwiegend parallelisierbare Deep Studying (DL)-Arbeitslast eingestreut sind.
Wir werden uns darauf konzentrieren Nicht maximale Unterdrückung (NMS) von Bounding-Field-Vorschlägen als repräsentativem Algorithmus und erkunden Möglichkeiten zur Optimierung seiner Implementierung. Ein wichtiger Bestandteil von Pc Imaginative and prescient (CV) Objekterkennung Lösungen (z. B. Maske RCNN) wird NMS häufig verwendet, um überlappende Begrenzungsrahmen herauszufiltern und nur die „besten“ zu behalten. NMS erhält eine Liste von Bounding-Field-Vorschlägen, eine zugehörige Liste von Bewertungen und eine Schuldschein Schwelle und geht weiter zu gierig Und iterativ Wählen Sie das verbleibende Kästchen mit der höchsten Punktzahl aus und disqualifizieren Sie alle anderen Kästchen, bei denen es einen Schuldschein hat, der den angegebenen Schwellenwert überschreitet. Die Tatsache, dass die Field ausgewählt wurde n-ter Die Iteration hängt vom Vorhergehenden ab n-1 Die Anzahl der Schritte des Algorithmus bestimmt die Reihenfolge seiner Implementierung. Bitte sehen Hier und/oder Hier Weitere Informationen zum Grundprinzip von NMS und seiner Implementierung. Obwohl wir uns entschieden haben, uns auf einen bestimmten Algorithmus zu konzentrieren, sollte der Großteil unserer Diskussion auf andere sequentielle Algorithmen übertragen werden.
Sequentielle Algorithmen auf die CPU auslagern
Das Vorhandensein eines sequentiellen Algorithmus innerhalb eines überwiegend parallelisierbaren ML-Modells (z. B. Masks R-CNN) stellt eine interessante Herausforderung dar. Während GPUs, die üblicherweise für solche Arbeitslasten verwendet werden, sich hervorragend bei der Ausführung paralleler Operationen wie der Matrixmultiplikation eignen, können sie bei der Verarbeitung sequenzieller Algorithmen im Vergleich zu CPUs deutlich schlechter abschneiden. Dies führt oft zu Berechnungsdiagrammen, die Überschneidungen zwischen GPU und CPU beinhalten, wobei die GPU die parallelen Vorgänge und die CPU die sequentiellen Vorgänge übernimmt. NMS ist ein Paradebeispiel für einen sequentiellen Algorithmus, der üblicherweise auf die CPU ausgelagert wird. Tatsächlich ist eine genaue Analyse von Fackelvision’s „CUDA“-Implementierung von NMSzeigt, dass sogar ein erheblicher Teil des Algorithmus darauf ausgeführt wird CPU.
Obwohl die Auslagerung sequenzieller Vorgänge auf die CPU zu einer verbesserten Laufzeitleistung führen kann, sind mehrere potenzielle Nachteile zu berücksichtigen:
- Die geräteübergreifende Ausführung zwischen CPU und GPU erfordert normalerweise mehrere Synchronisierungspunkte zwischen den Geräten, was häufig zu Leerlaufzeiten auf der GPU führt, während diese darauf wartet, dass die CPU ihre Aufgaben erledigt. Da die GPU in der Regel die teuerste Komponente der Trainingsplattform ist, ist es unser Ziel, diese Leerlaufzeiten zu minimieren.
- In Customary-ML-Workflows ist die CPU für die Vorbereitung und Übermittlung von Daten an das Modell verantwortlich, das sich auf der GPU befindet. Wenn die Dateneingabepipeline eine rechenintensive Verarbeitung erfordert, kann dies die CPU belasten und zu einem „Eingabemangel“ auf der GPU führen. In solchen Szenarien könnte die Auslagerung von Teilen der Modellberechnung auf die CPU dieses Drawback noch verschlimmern.
Um diese Nachteile zu vermeiden, könnten Sie various Ansätze in Betracht ziehen, z. B. den Ersatz des sequentiellen Algorithmus durch eine vergleichbare Different (z. B. die vorgeschlagene). Hier), sich mit einer langsamen/suboptimalen GPU-Implementierung des sequentiellen Algorithmus zufrieden geben oder die Arbeitslast auf der CPU ausführen – jedes davon bringt seine eigenen potenziellen Kompromisse mit sich.
Sequentielle Algorithmen auf TPU
Hier könnte die einzigartige Architektur der TPU eine Likelihood bieten. Im Gegensatz zu GPUs sind TPUs sequentielle Prozessoren. Während ihre Fähigkeit, stark vektorisierte Operationen auszuführen, sie bei der Ausführung parallelisierbarer Operationen wie der Matrixmultiplikation mit GPUs konkurrenzfähig macht, könnten sie sich aufgrund ihrer sequentiellen Natur hervorragend für die Ausführung von ML-Workloads eignen, die eine Mischung aus sequentiellen und parallelen Komponenten enthalten. Bewaffnet mit dem Pallas-Erweiterung zu JAX, unserem Neu entdeckte TPU-Kernel-Erstellung Device werden wir diese Möglichkeit evaluieren, indem wir eine benutzerdefinierte Implementierung von NMS für TPU implementieren und evaluieren.
Haftungsausschluss
Die NMS-Implementierungen, die wir im Folgenden vorstellen, dienen nur zu Demonstrationszwecken. Wir haben keine nennenswerten Anstrengungen unternommen, um sie zu optimieren oder ihre Robustheit, Haltbarkeit oder Genauigkeit zu überprüfen. Bitte bedenken Sie, dass Pallas zum Zeitpunkt des Verfassens dieses Artikels ein ist Experimental- Function – noch in der aktiven Entwicklung. Der von uns geteilte Code (basierend auf JAX Model 0.4.32) kann zu dem Zeitpunkt, an dem Sie dies lesen, veraltet sein. Beachten Sie unbedingt die aktuellsten APIs und Ressourcen, die für Ihre Pallas-Entwicklung verfügbar sind. Bitte betrachten Sie unsere Erwähnung eines Algorithmus, einer Bibliothek oder einer API nicht als Empfehlung für deren Verwendung.
Wir beginnen mit einer einfachen Implementierung von NMS in Numpy Dies dient als Grundlage für den Leistungsvergleich:
import numpy as npdef nms_cpu(packing containers, scores, max_output_size, threshold=0.1):
epsilon = 1e-5
# Convert bounding packing containers and scores to numpy
packing containers = np.array(packing containers)
scores = np.array(scores)
# coordinates of bounding packing containers
start_x = packing containers(:, 0)
start_y = packing containers(:, 1)
end_x = packing containers(:, 2)
end_y = packing containers(:, 3)
# Compute areas of bounding packing containers
areas = (end_x - start_x) * (end_y - start_y)
# Type by confidence rating of bounding packing containers
order = np.argsort(scores)
# Picked bounding packing containers
picked_boxes = ()
# Iterate over bounding packing containers
whereas order.measurement > 0 and len(picked_boxes) < max_output_size:
# The index of the remaining field with the best rating
index = order(-1)
# Choose the bounding field with largest confidence rating
picked_boxes.append(index.merchandise())
# Compute coordinates of intersection
x1 = np.most(start_x(index), start_x(order(:-1)))
x2 = np.minimal(end_x(index), end_x(order(:-1)))
y1 = np.most(start_y(index), start_y(order(:-1)))
y2 = np.minimal(end_y(index), end_y(order(:-1)))
# Compute areas of intersection and union
w = np.most(x2 - x1, 0.0)
h = np.most(y2 - y1, 0.0)
intersection = w * h
union = areas(index) + areas(order(:-1)) - intersection
# Compute the ratio between intersection and union
ratio = intersection / np.clip(union, min=epsilon)
# discard packing containers above overlap threshold
maintain = np.the place(ratio < threshold)
order = order(maintain)
return picked_boxes
Um die Leistung unserer NMS-Funktion zu bewerten, generieren wir einen Stapel zufälliger Boxen und Scores (als JAX-Tensoren) und führen das Skript auf a aus Google Cloud TPU v5e System unter Verwendung der gleichen Umgebung und des gleichen Benchmarking-Dienstprogramms wie in unserem vorheriger Beitrag. Für dieses Experiment geben wir die CPU als an JAX-Standardgerät:
import jax
from jax import random
import jax.numpy as jnpdef generate_random_boxes(run_on_cpu = False):
if run_on_cpu:
jax.config.replace('jax_default_device', jax.units('cpu')(0))
else:
jax.config.replace('jax_default_device', jax.units('tpu')(0))
n_boxes = 1024
img_size = 1024
k1, k2, k3 = random.cut up(random.key(0), 3)
# Randomly generate field sizes and positions
box_sizes = random.randint(k1,
form=(n_boxes, 2),
minval=1,
maxval=img_size)
top_left = random.randint(k2,
form=(n_boxes, 2),
minval=0,
maxval=img_size - 1)
bottom_right = jnp.clip(top_left + box_sizes, 0, img_size - 1)
# Concatenate top-left and bottom-right coordinates
rand_boxes = jnp.concatenate((top_left, bottom_right),
axis=1).astype(jnp.bfloat16)
rand_scores = jax.random.uniform(k3,
form=(n_boxes,),
minval=0.0,
maxval=1.0)
return rand_boxes, rand_scores
rand_boxes, rand_scores = generate_random_boxes(run_on_cpu=True)
time = benchmark(nms_cpu)(rand_boxes, rand_scores, max_output_size=128)
print(f'nms_cpu: {time}')
Die daraus resultierende durchschnittliche Laufzeit beträgt 2,99 Millisekunden. Beachten Sie die Annahme, dass sich die Eingabe- und Ausgabetensoren auf der CPU befinden. Wenn sie sich auf der TPU befinden, sollte auch die Zeit zum Kopieren zwischen den Geräten berücksichtigt werden.
Wenn unsere NMS-Funktion eine Komponente innerhalb eines größeren Berechnungsdiagramms ist, das auf der TPU ausgeführt wird, bevorzugen wir möglicherweise eine TPU-kompatible Implementierung, um die Nachteile der geräteübergreifenden Ausführung zu vermeiden. Der folgende Codeblock enthält eine JAX-Implementierung von NMS, die speziell für die Beschleunigung durch JIT-Kompilierung entwickelt wurde. Bezeichnet die Anzahl der Boxen mit Nbeginnen wir mit der Berechnung des IOU zwischen jedem der N(N-1) Paar Kisten und Vorbereiten einer NXN Boolescher Tensor (mask_threshold) wo die (i,jDer )-te Eintrag gibt an, ob die IOU zwischen den Boxen liegt ich Und J den vordefinierten Schwellenwert überschreiten.
Um die iterative Auswahl von Boxen zu vereinfachen, erstellen wir eine Kopie des Maskentensors (mask_threshold2), bei dem die diagonalen Elemente auf Null gesetzt werden, um zu verhindern, dass sich eine Field selbst unterdrückt. Wir definieren außerdem zwei Rating-Monitoring-Tensoren: out_scoresdas die Punktzahlen der ausgewählten Kästchen behält (und die Punktzahlen der eliminierten auf Null setzt) und verbleibende_Scoreswodurch die Bewertungen der noch berücksichtigten Boxen beibehalten werden. Wir verwenden dann die jax.lax.while_loop Funktion zum iterativen Auswählen von Boxen während der Aktualisierung out_scores Und verbleibende_Scores Tensoren. Beachten Sie, dass sich das Format der Ausgabe dieser Funktion von der vorherigen Funktion unterscheidet und möglicherweise angepasst werden muss, um in nachfolgende Schritte des Berechnungsdiagramms zu passen.
import functools# Given N packing containers, calculates mask_threshold an NxN boolean masks
# the place the (i,j) entry signifies whether or not the IOU of packing containers i and j
# exceed the edge. Returns mask_threshold, mask_threshold2
# which is equal to mask_threshold with zero diagonal and
# the scores modified so that every one values are larger than 0
def init_tensors(packing containers, scores, threshold=0.1):
epsilon = 1e-5
# Extract left, high, proper, backside coordinates
left = packing containers(:, 0)
high = packing containers(:, 1)
proper = packing containers(:, 2)
backside = packing containers(:, 3)
# Compute areas of packing containers
areas = (proper - left) * (backside - high)
# Calculate intersection factors
inter_l = jnp.most(left(None, :), left(:, None))
inter_t = jnp.most(high(None, :), high(:, None))
inter_r = jnp.minimal(proper(None, :), proper(:, None))
inter_b = jnp.minimal(backside(None, :), backside(:, None))
# Width, peak, and space of the intersection
inter_w = jnp.clip(inter_r - inter_l, 0)
inter_h = jnp.clip(inter_b - inter_t, 0)
inter_area = inter_w * inter_h
# Union of the areas
union = areas(None, :) + areas(:, None) - inter_area
# IoU calculation
iou = inter_area / jnp.clip(union, epsilon)
# Shift scores to be larger than zero
out_scores = scores - jnp.min(scores) + epsilon
# Create masks based mostly on IoU threshold
mask_threshold = iou > threshold
# Create masks excluding diagonal (i.e., self IoU is ignored)
mask_threshold2 = mask_threshold * (1-jnp.eye(mask_threshold.form(0),
dtype=mask_threshold.dtype))
return mask_threshold, mask_threshold2, out_scores
@functools.partial(jax.jit, static_argnames=('max_output_size', 'threshold'))
def nms_jax(packing containers, scores, max_output_size, threshold=0.1):
# initialize masks and rating tensors
mask_threshold, mask_threshold2, out_scores = init_tensors(packing containers,
scores,
threshold)
# The out_scores tensor will retain the scores of the chosen packing containers
# and 0 the scores of the eradicated ones
# remaining_scores will keep non-zero scores for packing containers that
# haven't been chosen or eradicated
remaining_scores = out_scores.copy()
def choose_box(state):
i, remaining_scores, out_scores = state
# select index of field with highest rating from remaining scores
index = jnp.argmax(remaining_scores)
# examine validity of chosen field
legitimate = remaining_scores(index) > 0
# If legitimate, zero all scores with IOU larger than threshold
# (together with the chosen index)
remaining_scores = jnp.the place(mask_threshold(index) *legitimate,
0,
remaining_scores)
# zero the scores of the eradicated tensors (not together with
# the chosen index)
out_scores = jnp.the place(mask_threshold2(index)*legitimate,
0,
out_scores)
i = i + 1
return i, remaining_scores, out_scores
def cond_fun(state):
i, _, _ = state
return (i < max_output_size)
i = 0
state = (i, remaining_scores, out_scores)
_, _, out_scores = jax.lax.while_loop(cond_fun, choose_box, state)
# Output the resultant scores. To extract the chosen packing containers,
# Take the max_output_size highest scores:
# min = jnp.minimal(jnp.count_nonzero(scores), max_output_size)
# indexes = jnp.argsort(out_scores, descending=True)(:min)
return out_scores
# nms_jax may be run on both the CPU the TPU
rand_boxes, rand_scores = generate_random_boxes(run_on_cpu=True)
time = benchmark(nms_jax)(rand_boxes, rand_scores, max_output_size=128)
print(f'nms_jax on CPU: {time}')
rand_boxes, rand_scores = generate_random_boxes(run_on_cpu=False)
time = benchmark(nms_jax)(rand_boxes, rand_scores, max_output_size=128)
print(f'nms_jax on TPU: {time}')
Die Laufzeiten dieser Implementierung von NMS betragen 1,231 bzw. 0,416 Millisekunden auf CPU und TPU.
Wir präsentieren nun eine benutzerdefinierte Implementierung von NMS, bei der wir explizit die Tatsache nutzen, dass Pallas-Kernel auf TPUs in a ausgeführt werden sequentiell. Unsere Implementierung verwendet zwei boolesche Matrixmasken und zwei Rating-Protecting-Tensoren, ähnlich dem Ansatz in unserer vorherigen Funktion.
Wir definieren eine Kernelfunktion, wähle_boxverantwortlich für die Auswahl des nächsten Kästchens und die Aktualisierung der Rating-Tensoren, die im Arbeitsspeicher verwaltet werden. Wir rufen den Kernel über ein eindimensionales Gitter auf, wobei die Anzahl der Schritte (dh die Gittergröße) durch das bestimmt wird max_output_size Parameter.
Beachten Sie, dass aufgrund einiger Einschränkungen (zum Zeitpunkt des Verfassens dieses Artikels) sind bei den von Pallas unterstützten Operationen einige akrobatische Maßnahmen erforderlich, um sowohl die Funktion „argmax“ als auch die Gültigkeitsprüfung für die ausgewählten Felder zu implementieren. Der Kürze halber lassen wir die technischen Particulars weg und verweisen den interessierten Leser auf die Kommentare im Code unten.
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu# argmax helper operate
def pallas_argmax(scores, n_boxes):
# we assume that the index of every field is saved within the
# least important bits of the rating (see under)
idx = jnp.max(scores.astype(float)).astype(int) % n_boxes
return idx
# Pallas kernel definition
def choose_box(scores, thresh_mask1, thresh_mask2, ret_scores,
scores_scratch, remaining_scores_scratch, *, nsteps, n_boxes):
# initialize scratch reminiscence on first step
@pl.when(pl.program_id(0) == 0)
def _():
scores_scratch(...) = scores(...)
remaining_scores_scratch(...) = scores(...)
remaining_scores = remaining_scores_scratch(...)
# select field
idx = pallas_argmax(remaining_scores, n_boxes)
# we use any to verfiy validity of the chosen field due
# to limitations on indexing in pallas
legitimate = (remaining_scores>0).any()
# updating rating tensors
remaining_scores_scratch(...) = jnp.the place(thresh_mask1(idx,...)*legitimate,
0,
remaining_scores)
scores_scratch(...) = jnp.the place(thresh_mask2(idx,...)*legitimate,
0,
scores_scratch(...))
# set return worth on closing step
@pl.when(pl.program_id(0) == nsteps - 1)
def _():
ret_scores(...) = scores_scratch(...)
@functools.partial(jax.jit, static_argnames=('max_output_size', 'threshold'))
def nms_pallas(packing containers, scores, max_output_size, threshold=0.1):
n_boxes = scores.measurement
mask_threshold, mask_threshold2, scores = init_tensors(packing containers,
scores,
threshold)
# In an effort to work across the Pallas argsort limitation
# we create a brand new scores tensor with the identical ordering of
# the enter scores tensor through which the index of every rating
# within the ordering is encoded within the least important bits
sorted = jnp.argsort(scores, descending=True)
# descending integers: n_boxes-1, ..., 2, 1, 0
descending = jnp.flip(jnp.arange(n_boxes))
# new scores in descending with the least important
# bits carrying the argsort of the enter scores
ordered_scores = n_boxes * descending + sorted
# new scores with identical ordering as enter scores
scores = jnp.empty_like(ordered_scores
).at(sorted).set(ordered_scores)
grid = (max_output_size,)
return pl.pallas_call(
functools.partial(choose_box,
nsteps=max_output_size,
n_boxes=n_boxes),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=(
pl.BlockSpec(block_shape=(n_boxes,)),
pl.BlockSpec(block_shape=(n_boxes, n_boxes)),
pl.BlockSpec(block_shape=(n_boxes, n_boxes)),
),
out_specs=pl.BlockSpec(block_shape=(n_boxes,)),
scratch_shapes=(pltpu.VMEM((n_boxes,), scores.dtype),
pltpu.VMEM((n_boxes,), scores.dtype)),
grid=grid,
),
out_shape=jax.ShapeDtypeStruct((n_boxes,), scores.dtype),
compiler_params=dict(mosaic=dict(
dimension_semantics=("arbitrary",)))
)(scores, mask_threshold, mask_threshold2)
rand_boxes, rand_scores = generate_random_boxes(run_on_cpu=False)
time = benchmark(nms_pallas)(rand_boxes, rand_scores, max_output_size=128)
print(f'nms_pallas: {time}')
Die durchschnittliche Laufzeit unseres benutzerdefinierten NMS-Operators beträgt 0,139 Millisekunden und ist damit etwa dreimal schneller als unsere JAX-native Implementierung. Dieses Ergebnis unterstreicht das Potenzial, die Implementierung sequentieller Algorithmen an die einzigartigen Eigenschaften der TPU-Architektur anzupassen.
Beachten Sie, dass wir in unserer Pallas-Kernel-Implementierung die vollständigen Eingabetensoren laden TPU-VMEM-Speicher. Angesichts der begrenzten Kapazität von VMEM wird eine Vergrößerung der Eingabegröße (d. h. eine Erhöhung der Anzahl der Begrenzungsrahmen) wahrscheinlich zu Speicherproblemen führen. Typischerweise können solche Einschränkungen durch behoben werden Aufteilung der Eingaben mit BlockSpecs. Leider würde die Anwendung dieses Ansatzes die aktuelle NMS-Implementierung zerstören. Die Implementierung von NMS über Eingabeblöcke hinweg würde ein anderes Design erfordern, was den Rahmen dieses Beitrags sprengen würde.
Die Ergebnisse unserer Experimente sind in der folgenden Tabelle zusammengefasst:
Diese Ergebnisse zeigen das Potenzial für die Ausführung vollständiger ML-Berechnungsdiagramme auf TPU, selbst wenn sie sequentielle Komponenten enthalten. Insbesondere die von unserem Pallas NMS-Betreiber demonstrierte Leistungsverbesserung unterstreicht die Möglichkeit, Kernel so anzupassen, dass die Stärken der TPU genutzt werden.
In unserem vorheriger Beitrag Wir haben von der Möglichkeit erfahren, mithilfe der Pallas-Erweiterung für JAX benutzerdefinierte TPU-Operatoren zu erstellen. Um diese Likelihood zu maximieren, müssen die Kernel-Implementierungen an die spezifischen Eigenschaften der TPU-Architektur angepasst werden. In diesem Beitrag haben wir uns auf die sequentielle Natur des TPU-Prozessors und seine Verwendung bei der Optimierung eines benutzerdefinierten NMS-Kernels konzentriert. Während die Skalierung der Lösung zur Unterstützung einer unbegrenzten Anzahl von Begrenzungsrahmen weitere Arbeit erfordern würde, bleiben die von uns besprochenen Grundprinzipien weiterhin anwendbar.
Noch in der experimentellen Section seiner Entwicklung gibt es in Pallas noch einige Einschränkungen, die möglicherweise kreative Umgehungen erfordern. Aber die Stärke und das Potenzial sind klar erkennbar und wir gehen davon aus, dass sie mit zunehmender Reife des Rahmenwerks noch zunehmen werden.