GANs von Grund auf in Python erstellen

Foto von Michael & Diane Weidner An Unsplash

TDie Idee der Generative Adversarial Networks (GANs) wurde 2014 von Goodfellow und seinen Kollegen (1) vorgestellt und erfreute sich kurz darauf großer Beliebtheit im Bereich der Computervision und Bilderzeugung. Obwohl sich die KI in den letzten zehn Jahren rasant weiterentwickelt hat und immer mehr neue Algorithmen entwickelt wurden, sind die Einfachheit und Genialität dieses Konzepts immer noch äußerst beeindruckend. Heute möchte ich daher anhand des Versuchs, Wolken aus RGB-Satellitenbildern (Rot, Grün, Blau) zu entfernen, zeigen, wie leistungsfähig diese Netzwerke sein können.

Die Erstellung eines ausgewogenen, ausreichend großen und korrekt vorverarbeiteten CV-Datensatzes nimmt viel Zeit in Anspruch, daher habe ich mich entschlossen, die Möglichkeiten von Kaggle zu erkunden. Der Datensatz, der mir für diese Aufgabe am besten geeignet erschien, ist EuroSat (2), der eine offene Lizenz hat. Er umfasst 27000 beschriftete RGB-Bilder 64×64 Pixel von Sentinel-2 und ist für die Lösung des Issues der Mehrklassenklassifizierung konzipiert.

Bildbeispiel aus dem EuroSat-Datensatz. Lizenz.

An der Klassifizierung selbst sind wir nicht interessiert, aber eines der Hauptmerkmale des EuroSat-Datensatzes ist, dass alle seine Bilder einen klaren Himmel zeigen. Das ist genau das, was wir brauchen. Wir übernehmen diesen Ansatz aus (3), verwenden diese Sentinel-2-Aufnahmen als Ziele und erstellen Eingaben, indem wir ihnen Rauschen (Wolken) hinzufügen.

Bereiten wir additionally unsere Daten vor, bevor wir tatsächlich über GANs sprechen. Zuerst müssen wir die Daten herunterladen und alle Klassen in einem Verzeichnis zusammenführen.

🐍Der vollständige Python-Code: GitHub.

import numpy as np
import pandas as pd
import random

from os import listdir, mkdir, rename
from os.path import be part of, exists
import shutil
import datetime

import matplotlib.pyplot as plt
from highlight_text import ax_text, fig_text
from PIL import Picture

import warnings

warnings.filterwarnings('ignore')

lessons = listdir('./EuroSat')
path_target = './EuroSat/all_targets'
path_input = './EuroSat/all_inputs'

"""RUN IT ONLY ONCE TO RENAME THE FILES IN THE UNPACKED ARCHIVE"""
mkdir(path_input)
mkdir(path_target)
okay = 1
for sort in lessons:
path = be part of('./EuroSat', str(sort))
for i, f in enumerate(listdir(path)):
shutil.copyfile(be part of(path, f),
be part of(path_target, f))
rename(be part of(path_target, f), be part of(path_target, f'{okay}.jpg'))
okay += 1

Der zweite wichtige Schritt ist die Erzeugung von Rauschen. Während Sie verschiedene Ansätze verwenden können, z. B. das zufällige Ausblenden einiger Pixel oder das Hinzufügen von Gaußschem Rauschen, möchte ich in diesem Artikel etwas Neues für mich ausprobieren – Perlin-Rauschen. Es wurde in den 80er Jahren von Ken Perlin (4) bei der Entwicklung filmischer Raucheffekte erfunden. Diese Artwork von Rauschen wirkt im Vergleich zu normalem Zufallsrauschen organischer. Lassen Sie es mich einfach beweisen.

def generate_perlin_noise(width, top, scale, octaves, persistence, lacunarity):
noise = np.zeros((top, width))
for i in vary(top):
for j in vary(width):
noise(i)(j) = pnoise2(i / scale,
j / scale,
octaves=octaves,
persistence=persistence,
lacunarity=lacunarity,
repeatx=width,
repeaty=top,
base=0)
return noise

def normalize_noise(noise):
min_val = noise.min()
max_val = noise.max()
return (noise - min_val) / (max_val - min_val)

def generate_clouds(width, top, base_scale, octaves, persistence, lacunarity):
clouds = np.zeros((top, width))
for octave in vary(1, octaves + 1):
scale = base_scale / octave
layer = generate_perlin_noise(width, top, scale, 1, persistence, lacunarity)
clouds += layer * (persistence ** octave)

clouds = normalize_noise(clouds)
return clouds

def overlay_clouds(picture, clouds, alpha=0.5):

clouds_rgb = np.stack((clouds) * 3, axis=-1)

picture = picture.astype(float) / 255.0
clouds_rgb = clouds_rgb.astype(float)

blended = picture * (1 - alpha) + clouds_rgb * alpha

blended = (blended * 255).astype(np.uint8)
return blended

width, top = 64, 64
octaves = 12 #variety of noise layers mixed
persistence = 0.5 #decrease persistence reduces the amplitude of higher-frequency octaves
lacunarity = 2 #larger lacunarity will increase the frequency of higher-frequency octaves
for i in vary(len(listdir(path_target))):
base_scale = random.uniform(5,120) #noise frequency
alpha = random.uniform(0,1) #transparency

clouds = generate_clouds(width, top, base_scale, octaves, persistence, lacunarity)

img = np.asarray(Picture.open(be part of(path_target, f'{i+1}.jpg')))
picture = Picture.fromarray(overlay_clouds(img,clouds, alpha))
picture.save(be part of(path_input,f'{i+1}.jpg'))
print(f'Processed {i+1}/{len(listdir(path_target))}')

idx = np.random.randint(27000)
fig,ax = plt.subplots(1,2)
ax(0).imshow(np.asarray(Picture.open(be part of(path_target, f'{idx}.jpg'))))
ax(1).imshow(np.asarray(Picture.open(be part of(path_input, f'{idx}.jpg'))))
ax(0).set_title("Goal")
ax(0).axis('off')
ax(1).set_title("Enter")
ax(1).axis('off')
plt.present()
Bild von Autor.

Wie Sie oben sehen können, sind die Wolken auf den Bildern sehr realistisch, sie haben unterschiedliche „Dichte“ und Texturen und ähneln den echten.

Wenn Sie Perlin-Rauschen genauso faszinierend finden wie ich, finden Sie hier ein wirklich cooles Video dazu, wie dieses Rauschen in der GameDev-Branche eingesetzt werden kann:

Da wir jetzt über einen gebrauchsfertigen Datensatz verfügen, sprechen wir über GANs.

Um diesen Gedanken besser zu veranschaulichen, stellen wir uns vor, Sie reisen durch Südostasien und brauchen dringend einen Kapuzenpullover, weil es draußen zu kalt ist. Auf dem nächsten Straßenmarkt finden Sie ein kleines Geschäft mit Markenkleidung. Der Verkäufer bringt Ihnen einen schönen Kapuzenpullover zum Anprobieren und sagt, es sei von der bekannten Marke ExpensiveButNotWorthIt. Sie sehen ihn sich genauer an und kommen zu dem Schluss, dass es sich offensichtlich um eine Fälschung handelt. Der Verkäufer sagt: „Second mal, ich habe den ECHTEN.“ Er kommt mit einem anderen Kapuzenpullover zurück, der noch mehr wie der Markenpullover aussieht, aber dennoch eine Fälschung ist. Nach mehreren Versuchen dieser Artwork bringt Ihnen der Verkäufer eine nicht zu unterscheidende Kopie des legendären ExpensiveButNotWorthIt und Sie kaufen ihn bereitwillig. So funktionieren die GANs im Grunde!

Im Fall von GANs nennt man Sie Diskriminator (D). Das Ziel eines Diskriminators ist es, zwischen einem echten und einem gefälschten Objekt zu unterscheiden oder die Aufgabe der binären Klassifizierung zu lösen. Der Verkäufer wird Generator (G) genannt, da er versucht, eine hochwertige Fälschung zu erzeugen. Diskriminator und Generator werden unabhängig voneinander trainiert, um sich gegenseitig zu übertreffen. Daher erhalten wir am Ende eine hochwertige Fälschung.

GAN-Architektur. Lizenz.

Der Trainingsprozess sieht ursprünglich folgendermaßen aus:

  1. Beispiel für Eingangsrauschen (in unserem Fall Bilder mit Wolken).
  2. Geben Sie G das Rauschen weiter und sammeln Sie die Vorhersage.
  3. Berechnen Sie den D-Verlust, indem Sie zwei Vorhersagen einholen, eine für die Ausgabe von G und eine für die realen Daten.
  4. Aktualisieren Sie die Gewichte von D.
  5. Probieren Sie das Eingangsrauschen erneut aus.
  6. Geben Sie G das Rauschen weiter und sammeln Sie die Vorhersage.
  7. Berechnen Sie den G-Verlust, indem Sie seine Vorhersage in D einspeisen.
  8. Aktualisieren Sie die Gewichte von G.
Trainingsschleife für GANs. Quelle: (1).

Mit anderen Worten können wir eine Wertfunktion V(G,D) definieren:

Quelle: (1).

wo wir den Begriff minimieren wollen log(1-D(G(z))) um G zu trainieren und zu maximieren log D(x) um D zu trainieren (in dieser Notation x – reale Datenprobe und z – Rauschen).

Versuchen wir jetzt, es in PyTorch zu implementieren!

Im Originalartikel sprechen die Autoren von der Verwendung von Multilayer Perceptron (MLP); es wird auch oft einfach als ANN bezeichnet, aber ich möchte einen etwas komplizierteren Ansatz ausprobieren – ich möchte die UNet-Architektur (5) als Generator und ResNet (6) als Diskriminator verwenden. Dies sind beides bekannte CNN-Architekturen, daher werde ich sie hier nicht erklären (lassen Sie mich in den Kommentaren wissen, ob ich einen separaten Artikel schreiben soll).

Lasst sie uns bauen. Diskriminator:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.useful as F
from torch.utils.knowledge import Dataset, DataLoader
from torchvision import transforms
from torch.utils.knowledge import Subset
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride = 1, downsample = None):
tremendous(ResidualBlock, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = stride, padding = 1),
nn.BatchNorm2d(out_channels),
nn.ReLU())
self.conv2 = nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(out_channels))
self.downsample = downsample
self.relu = nn.ReLU()
self.out_channels = out_channels

def ahead(self, x):
residual = x
out = self.conv1(x)
out = self.conv2(out)
if self.downsample:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out

class ResNet(nn.Module):
def __init__(self, block=ResidualBlock, all_connections=(3,4,6,3)):
tremendous(ResNet, self).__init__()
self.inputs = 16
self.conv1 = nn.Sequential(
nn.Conv2d(3, 16, kernel_size = 3, stride = 1, padding = 1),
nn.BatchNorm2d(16),
nn.ReLU()) #16x64x64
self.maxpool = nn.MaxPool2d(kernel_size = 2, stride = 2) #16x32x32

self.layer0 = self.makeLayer(block, 16, all_connections(0), stride = 1) #connections = 3, form: 16x32x32
self.layer1 = self.makeLayer(block, 32, all_connections(1), stride = 2)#connections = 4, form: 32x16x16
self.layer2 = self.makeLayer(block, 128, all_connections(2), stride = 2)#connections = 6, form: 1281x8x8
self.layer3 = self.makeLayer(block, 256, all_connections(3), stride = 2)#connections = 3, form: 256x4x4
self.avgpool = nn.AvgPool2d(4, stride=1)
self.fc = nn.Linear(256, 1)

def makeLayer(self, block, outputs, connections, stride=1):
downsample = None
if stride != 1 or self.inputs != outputs:
downsample = nn.Sequential(
nn.Conv2d(self.inputs, outputs, kernel_size=1, stride=stride),
nn.BatchNorm2d(outputs),
)
layers = ()
layers.append(block(self.inputs, outputs, stride, downsample))
self.inputs = outputs
for i in vary(1, connections):
layers.append(block(self.inputs, outputs))

return nn.Sequential(*layers)

def ahead(self, x):
x = self.conv1(x)
x = self.maxpool(x)
x = self.layer0(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.avgpool(x)
x = x.view(-1, 256)
x = self.fc(x).flatten()
return F.sigmoid(x)

Generator:


class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
tremendous(DoubleConv, self).__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)

def ahead(self, x):
return self.double_conv(x)

class UNet(nn.Module):
def __init__(self):
tremendous().__init__()
self.conv_1 = DoubleConv(3, 32) # 32x64x64
self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2) # 32x32x32

self.conv_2 = DoubleConv(32, 64) #64x32x32
self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2) #64x16x16

self.conv_3 = DoubleConv(64, 128) #128x16x16
self.pool_3 = nn.MaxPool2d(kernel_size=2, stride=2) #128x8x8

self.conv_4 = DoubleConv(128, 256) #256x8x8
self.pool_4 = nn.MaxPool2d(kernel_size=2, stride=2) #256x4x4

self.conv_5 = DoubleConv(256, 512) #512x2x2

#DECODER
self.upconv_1 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2) #256x4x4
self.conv_6 = DoubleConv(512, 256) #256x4x4

self.upconv_2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2) #128x8x8
self.conv_7 = DoubleConv(256, 128) #128x8x8

self.upconv_3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2) #64x16x16
self.conv_8 = DoubleConv(128, 64) #64x16x16

self.upconv_4 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2) #32x32x32
self.conv_9 = DoubleConv(64, 32) #32x32x32

self.output = nn.Conv2d(32, 3, kernel_size = 3, stride = 1, padding = 1) #3x64x64

def ahead(self, batch):

conv_1_out = self.conv_1(batch)
conv_2_out = self.conv_2(self.pool_1(conv_1_out))
conv_3_out = self.conv_3(self.pool_2(conv_2_out))
conv_4_out = self.conv_4(self.pool_3(conv_3_out))
conv_5_out = self.conv_5(self.pool_4(conv_4_out))

conv_6_out = self.conv_6(torch.cat((self.upconv_1(conv_5_out), conv_4_out), dim=1))
conv_7_out = self.conv_7(torch.cat((self.upconv_2(conv_6_out), conv_3_out), dim=1))
conv_8_out = self.conv_8(torch.cat((self.upconv_3(conv_7_out), conv_2_out), dim=1))
conv_9_out = self.conv_9(torch.cat((self.upconv_4(conv_8_out), conv_1_out), dim=1))

output = self.output(conv_9_out)

return F.sigmoid(output)

Jetzt müssen wir unsere Daten in Practice/Take a look at aufteilen und sie in einen Torch-Datensatz packen:

class dataset(Dataset):
def __init__(self, batch_size, images_paths, targets, img_size = 64):
self.batch_size = batch_size
self.img_size = img_size
self.images_paths = images_paths
self.targets = targets
self.len = len(self.images_paths) // batch_size

self.remodel = transforms.Compose((
transforms.ToTensor(),
))

self.batch_im = (self.images_paths(idx * self.batch_size:(idx + 1) * self.batch_size) for idx in vary(self.len))
self.batch_t = (self.targets(idx * self.batch_size:(idx + 1) * self.batch_size) for idx in vary(self.len))

def __getitem__(self, idx):
pred = torch.stack((
self.remodel(Picture.open(be part of(path_input,file_name)))
for file_name in self.batch_im(idx)
))
goal = torch.stack((
self.remodel(Picture.open(be part of(path_target,file_name)))
for file_name in self.batch_im(idx)
))
return pred, goal

def __len__(self):
return self.len

Perfekt. Es ist Zeit, die Trainingsschleife zu schreiben. Bevor wir das tun, definieren wir unsere Verlustfunktionen und den Optimierer:

gadget = torch.gadget("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 64
num_epochs = 15
learning_rate_D = 1e-5
learning_rate_G = 1e-4

discriminator = ResNet()
generator = UNet()

bce = nn.BCEWithLogitsLoss()
l1loss = nn.L1Loss()

optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate_D)
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate_G)

scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=10, gamma=0.1)
scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=10, gamma=0.1)

Wie Sie sehen, unterscheiden sich diese Verluste von dem Bild mit dem GAN-Algorithmus. Insbesondere habe ich L1Loss hinzugefügt. Die Idee ist, dass wir nicht einfach ein zufälliges Bild aus Rauschen erzeugen, sondern die meisten Informationen aus der Eingabe behalten und nur das Rauschen entfernen möchten. Der G-Verlust beträgt additionally:

G_Verlust = log(1 − D(G(z))) + 𝝀 |G(z)-y|

statt nur

G_Verlust = log(1 − D(G(z)))

𝝀 ist ein beliebiger Koeffizient, der zwei Komponenten der Verluste ausgleicht.

Lassen Sie uns abschließend die Daten aufteilen, um den Trainingsprozess zu starten:

test_ratio, train_ratio = 0.3, 0.7
num_test = int(len(listdir(path_target))*test_ratio)
num_train = int((int(len(listdir(path_target)))-num_test))

img_size = (64, 64)

print("Variety of prepare samples:", num_train)
print("Variety of check samples:", num_test)

random.seed(231)
train_idxs = np.array(random.pattern(vary(num_test+num_train), num_train))
masks = np.ones(num_train+num_test, dtype=bool)
masks(train_idxs) = False

photos = {}
options = random.pattern(listdir(path_input),num_test+num_train)
targets = random.pattern(listdir(path_target),num_test+num_train)

random.Random(231).shuffle(options)
random.Random(231).shuffle(targets)

train_input_img_paths = np.array(options)(train_idxs)
train_target_img_path = np.array(targets)(train_idxs)
test_input_img_paths = np.array(options)(masks)
test_target_img_path = np.array(targets)(masks)

train_loader = dataset(batch_size=batch_size, img_size=img_size, images_paths=train_input_img_paths, targets=train_target_img_path)
test_loader = dataset(batch_size=batch_size, img_size=img_size, images_paths=test_input_img_paths, targets=test_target_img_path)

Jetzt können wir unsere Trainingsschleife ausführen:

train_loss_G, train_loss_D, val_loss_G, val_loss_D = (), (), (), ()
all_loss_G, all_loss_D = (), ()
best_generator_epoch_val_loss, best_discriminator_epoch_val_loss = -np.inf, -np.inf
for epoch in vary(num_epochs):

discriminator.prepare()
generator.prepare()

discriminator_epoch_loss, generator_epoch_loss = 0, 0

for inputs, targets in train_loader:
inputs, true = inputs, targets

'''1. Coaching the Discriminator (ResNet)'''
optimizer_D.zero_grad()

pretend = generator(inputs).detach()

pred_fake = discriminator(pretend).to(gadget)
loss_fake = bce(pred_fake, torch.zeros(batch_size, gadget=gadget))

pred_real = discriminator(true).to(gadget)
loss_real = bce(pred_real, torch.ones(batch_size, gadget=gadget))

loss_D = (loss_fake+loss_real)/2

loss_D.backward()
optimizer_D.step()

discriminator_epoch_loss += loss_D.merchandise()
all_loss_D.append(loss_D.merchandise())

'''2. Coaching the Generator (UNet)'''
optimizer_G.zero_grad()

pretend = generator(inputs)
pred_fake = discriminator(pretend).to(gadget)

loss_G_bce = bce(pred_fake, torch.ones_like(pred_fake, gadget=gadget))
loss_G_l1 = l1loss(pretend, targets)*100
loss_G = loss_G_bce + loss_G_l1
loss_G.backward()
optimizer_G.step()

generator_epoch_loss += loss_G.merchandise()
all_loss_G.append(loss_G.merchandise())

discriminator_epoch_loss /= len(train_loader)
generator_epoch_loss /= len(train_loader)
train_loss_D.append(discriminator_epoch_loss)
train_loss_G.append(generator_epoch_loss)

discriminator.eval()
generator.eval()

discriminator_epoch_val_loss, generator_epoch_val_loss = 0, 0

with torch.no_grad():
for inputs, targets in test_loader:
inputs, targets = inputs, targets

pretend = generator(inputs)
pred = discriminator(pretend).to(gadget)

loss_G_bce = bce(pretend, torch.ones_like(pretend, gadget=gadget))
loss_G_l1 = l1loss(pretend, targets)*100
loss_G = loss_G_bce + loss_G_l1
loss_D = bce(pred.to(gadget), torch.zeros(batch_size, gadget=gadget))

discriminator_epoch_val_loss += loss_D.merchandise()
generator_epoch_val_loss += loss_G.merchandise()

discriminator_epoch_val_loss /= len(test_loader)
generator_epoch_val_loss /= len(test_loader)

val_loss_D.append(discriminator_epoch_val_loss)
val_loss_G.append(generator_epoch_val_loss)

print(f"------Epoch ({epoch+1}/{num_epochs})------nTrain Loss D: {discriminator_epoch_loss:.4f}, Val Loss D: {discriminator_epoch_val_loss:.4f}")
print(f'Practice Loss G: {generator_epoch_loss:.4f}, Val Loss G: {generator_epoch_val_loss:.4f}')

if discriminator_epoch_val_loss > best_discriminator_epoch_val_loss:
discriminator_epoch_val_loss = best_discriminator_epoch_val_loss
torch.save(discriminator.state_dict(), "discriminator.pth")
if generator_epoch_val_loss > best_generator_epoch_val_loss:
generator_epoch_val_loss = best_generator_epoch_val_loss
torch.save(generator.state_dict(), "generator.pth")
#scheduler_D.step()
#scheduler_G.step()

fig, ax = plt.subplots(1,3)
ax(0).imshow(np.transpose(inputs.numpy()(7), (1,2,0)))
ax(1).imshow(np.transpose(targets.numpy()(7), (1,2,0)))
ax(2).imshow(np.transpose(pretend.detach().numpy()(7), (1,2,0)))
plt.present()

Nachdem der Code fertig ist, können wir die Verluste plotten. Dieser Code wurde teilweise übernommen von diese coole Webseite:

from matplotlib.font_manager import FontProperties

background_color = '#001219'
font = FontProperties(fname='LexendDeca-VariableFont_wght.ttf')
fig, ax = plt.subplots(1, 2, figsize=(16, 9))
fig.set_facecolor(background_color)
ax(0).set_facecolor(background_color)
ax(1).set_facecolor(background_color)

ax(0).plot(vary(len(all_loss_G)), all_loss_G, colour='#bc6c25', lw=0.5)
ax(1).plot(vary(len(all_loss_D)), all_loss_D, colour='#00b4d8', lw=0.5)

ax(0).scatter(
(np.array(all_loss_G).argmax(), np.array(all_loss_G).argmin()),
(np.array(all_loss_G).max(), np.array(all_loss_G).min()),
s=30, colour='#bc6c25',
)
ax(1).scatter(
(np.array(all_loss_D).argmax(), np.array(all_loss_D).argmin()),
(np.array(all_loss_D).max(), np.array(all_loss_D).min()),
s=30, colour='#00b4d8',
)

ax_text(
np.array(all_loss_G).argmax()+60, np.array(all_loss_G).max()+0.1,
f'{spherical(np.array(all_loss_G).max(),1)}',
fontsize=13, colour='#bc6c25',
font=font,
ax=ax(0)
)
ax_text(
np.array(all_loss_G).argmin()+60, np.array(all_loss_G).min()-0.1,
f'{spherical(np.array(all_loss_G).min(),1)}',
fontsize=13, colour='#bc6c25',
font=font,
ax=ax(0)
)

ax_text(
np.array(all_loss_D).argmax()+60, np.array(all_loss_D).max()+0.01,
f'{spherical(np.array(all_loss_D).max(),1)}',
fontsize=13, colour='#00b4d8',
font=font,
ax=ax(1)
)
ax_text(
np.array(all_loss_D).argmin()+60, np.array(all_loss_D).min()-0.005,
f'{spherical(np.array(all_loss_D).min(),1)}',
fontsize=13, colour='#00b4d8',
font=font,
ax=ax(1)
)
for i in vary(2):
ax(i).tick_params(axis='x', colours='white')
ax(i).tick_params(axis='y', colours='white')
ax(i).spines('left').set_color('white')
ax(i).spines('backside').set_color('white')
ax(i).set_xlabel('Epoch', colour='white', fontproperties=font, fontsize=13)
ax(i).set_ylabel('Loss', colour='white', fontproperties=font, fontsize=13)

ax(0).set_title('Generator', colour='white', fontproperties=font, fontsize=18)
ax(1).set_title('Discriminator', colour='white', fontproperties=font, fontsize=18)
plt.savefig('Loss.jpg')
plt.present()
# ax(0).set_axis_off()
# ax(1).set_axis_off()

Von admin

Schreibe einen Kommentar

Deine E-Mail-Adresse wird nicht veröffentlicht. Erforderliche Felder sind mit * markiert