
執筆者: кемо
最終更新: 2025/09/25
やあ
頑張って画像生成AI作るよ。ちなみに流行りの拡散モデルじゃないよ。
GANとは画像や音声の生成に使われる学習モデルの1つです。昔はかなり耳にしましたが、今はもはや忘れ去られてるんじゃないかというレベルですね。まあ、全然そんなことなく普通に今も使われてるし、通用する技術です。
で、本題のGANについて少しだけ詳しく話します。GANとは「敵対的生成ネットワーク」(カッコイイ)の略で、生成器・識別器と呼ばれる2つのモデルを競争させながら学習します。こんな説明でわかるわけない。まず生成器・識別器とは何なのかを簡単にまとめましょう。
この程度の認識でも問題ないんじゃないですかね。「なんで急に勝ち負けの話をし始めたのか」というところから次の要素を説明します。一回の学習の流れを見ましょう。
これを繰り返すことで生成器はより本物に近く生成するようになり、識別器は本物と生成物を見抜けるようになっていきます。このいたちごっここそがGANです。
# 生成器
class Generator(nn.Module):
def __init__(self):
super().__init__()
self.main = nn.Sequential(
nn.ConvTranspose2d(latent_size, 1024, 4, 1, 0, bias=False),
nn.BatchNorm2d(1024),
nn.ReLU(True),
nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(True),
nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
nn.Tanh()
)
def forward(self, x):
return self.main(x)これが生成器
# 判别器
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.main = nn.Sequential(
nn.Conv2d(3, 32, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(32, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 512, 4, 2, 1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 1024, 4, 2, 1, bias=False),
nn.BatchNorm2d(1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(1024, 1, 4, 1, 0, bias=False)
)
def forward(self, x):
return self.main(x).view(-1)これが識別器です。
なにも面白くない一般NNです。こんなコードを載せる意味はあるのか……。
dataset =datasets.ImageFolder(
root="dataset",
transform=transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
G = Generator().to(device)
D = Discriminator().to(device)
loss_fn = nn.BCEWithLogitsLoss()
opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
os.makedirs("gan_images", exist_ok=True)
fixed_noise = torch.randn(64, latent_size, 1, 1, device=device)
for epoch in range(epochs):
for i, (real_images, _) in enumerate(dataloader):
real_images = real_images.to(device)
batch_size = real_images.size(0)
noise = torch.randn(batch_size, latent_size, 1, 1, device=device)
fake_images = G(noise)
D_real = D(real_images)
D_fake = D(fake_images.detach())
loss_D = loss_fn(D_real, torch.ones_like(D_real)) + loss_fn(D_fake, torch.zeros_like(D_fake))
D.zero_grad()
loss_D.backward()
opt_D.step()
D_fake = D(fake_images)
loss_G = loss_fn(D_fake, torch.ones_like(D_fake))
G.zero_grad()
loss_G.backward()
opt_G.step()
if i % 50 == 0:
print(f"[{epoch}/{epochs}] [{i}/{len(dataloader)}] Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")
with torch.no_grad():
samples = G(fixed_noise).detach().cpu()
vutils.save_image(samples, f"gan_images/epoch_{epoch+1}.png", normalize=True)
if (epoch + 1) % 100 == 0 or epoch == epochs - 1:
os.makedirs("gan_models", exist_ok=True)
torch.save({
'generator': G.state_dict(),
'discriminator': D.state_dict(),
'optimizer_G': opt_G.state_dict(),
'optimizer_D': opt_D.state_dict(),
'epoch': epoch
}, f"gan_models/gan_checkpoint_epoch_{epoch+1}.pth")これが学習コード。これを実行して放置するだけのお手軽学習です。時間はそんなかからないと思いますよ。正直、これやったの去年なので記憶にないです、今更記事書いてますが。
ここからですよ、ここから。学習するためのデータが必要です。そんなものいくらでも転がっているのですが、ここでデータも作りたいというカスの欲求によってデータ作成が開始されました。まず何を学習させるのか。個人的にはケモミミ美少女を学習させて、大量のケモミミ美少女を生み出したかったのですが、どう考えてもそんなもの用意できません。これについては現代の圧倒的な技術力と資金力によって生み出された最強AIたちにやってもらいましょう。この記事の前に誰かが書いていたはずです。
とりあえず、ある程度見栄えが良くて、データが作りやすそうなものがいいでしょう。熟考の末採用されたのはマインクラフトのスクリーンショットでした。地形が単純で多様性もあって、ゲームなので外に出なくていいという判断です。
データは最低でも1万枚は欲しい。これを馬鹿正直にスクショしていては体力が持ちません。なので一定の間隔ごとにスクショするスクリプトを用意しました。そして…
は?馬鹿なのか?
なぜここを自動化しなかったのか。どう考えてもコマンドで自動で移動しながらスクショしたほうがいいに決まってます。この愚かな判断によって私は十数時間のマイクラプレイを強いられました。しかもUIなしでひたすら歩き回るだけ。何が面白いんだこれ。
この虚無作業ののち無事に1万枚のスクショが生み出されました。





うーん。遠目から見たらマイクラですね。もうちょっとデータ作って、学習スクリプトを調整すればよくなりそうです。
まず計画をきちんと練るべきです。機械学習はデータが9割とかいいますが本当にそうです。これがケモミミ美少女メイドちゃんとかなモチベーションを保てたのですが、虚無マイクラは辛いよ…。
でも結構マイクラってきれいなんですよね。作りこみのすごさを実感しました。
では、ケモミミがあるように。