トップ画像
画像生成AI作るぞ

執筆者: кемо

最終更新: 2025/09/25

はじめに

やあ

頑張って画像生成AI作るよ。ちなみに流行りの拡散モデルじゃないよ。

GANとは

GANとは画像や音声の生成に使われる学習モデルの1つです。昔はかなり耳にしましたが、今はもはや忘れ去られてるんじゃないかというレベルですね。まあ、全然そんなことなく普通に今も使われてるし、通用する技術です。

で、本題のGANについて少しだけ詳しく話します。GANとは「敵対的生成ネットワーク」(カッコイイ)の略で、生成器・識別器と呼ばれる2つのモデルを競争させながら学習します。こんな説明でわかるわけない。まず生成器・識別器とは何なのかを簡単にまとめましょう。

  • 生成器:対象を生成する。識別機に本物と判断させたら勝ち
  • 識別器:対象を識別する。生成器の生成物を見抜けたら勝ち

この程度の認識でも問題ないんじゃないですかね。「なんで急に勝ち負けの話をし始めたのか」というところから次の要素を説明します。一回の学習の流れを見ましょう。

  1. 学習開始
  2. 生成器が対象を生成する(最初はランダムなノイズ)
  3. 識別器が生成物と我々が用意した本物を判別する
  4. 判別結果をもとに両モデルが学習(生成器はよりうまく生成し、識別器はより精確に判別できるように)

これを繰り返すことで生成器はより本物に近く生成するようになり、識別器は本物と生成物を見抜けるようになっていきます。このいたちごっここそがGANです

おまけ程度のコード

# 生成器
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(latent_size, 1024, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            
            nn.ConvTranspose2d(32, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self, x):
        return self.main(x)

これが生成器

# 判别器
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(512, 1024, 4, 2, 1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(1024, 1, 4, 1, 0, bias=False)
        )
    def forward(self, x):
        return self.main(x).view(-1)

これが識別器です。

なにも面白くない一般NNです。こんなコードを載せる意味はあるのか……。

    dataset =datasets.ImageFolder(
        root="dataset",
        transform=transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    )

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    G = Generator().to(device)
    D = Discriminator().to(device)
    loss_fn = nn.BCEWithLogitsLoss()
    opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

    os.makedirs("gan_images", exist_ok=True)
    fixed_noise = torch.randn(64, latent_size, 1, 1, device=device)

    for epoch in range(epochs):
        for i, (real_images, _) in enumerate(dataloader):
            real_images = real_images.to(device)
            batch_size = real_images.size(0)

            noise = torch.randn(batch_size, latent_size, 1, 1, device=device)
            fake_images = G(noise)
            D_real = D(real_images)
            D_fake = D(fake_images.detach())
            loss_D = loss_fn(D_real, torch.ones_like(D_real)) + loss_fn(D_fake, torch.zeros_like(D_fake))
            D.zero_grad()
            loss_D.backward()
            opt_D.step()

            D_fake = D(fake_images)
            loss_G = loss_fn(D_fake, torch.ones_like(D_fake))
            G.zero_grad()
            loss_G.backward()
            opt_G.step()

            if i % 50 == 0:
                print(f"[{epoch}/{epochs}] [{i}/{len(dataloader)}] Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")
        
        with torch.no_grad():
            samples = G(fixed_noise).detach().cpu()
            vutils.save_image(samples, f"gan_images/epoch_{epoch+1}.png", normalize=True)

        if (epoch + 1) % 100 == 0 or epoch == epochs - 1:
            os.makedirs("gan_models", exist_ok=True)
            torch.save({
                'generator': G.state_dict(),
                'discriminator': D.state_dict(),
                'optimizer_G': opt_G.state_dict(),
                'optimizer_D': opt_D.state_dict(),
                'epoch': epoch
            }, f"gan_models/gan_checkpoint_epoch_{epoch+1}.pth")

これが学習コード。これを実行して放置するだけのお手軽学習です。時間はそんなかからないと思いますよ。正直、これやったの去年なので記憶にないです、今更記事書いてますが。

地獄のデータセット作成

ここからですよ、ここから。学習するためのデータが必要です。そんなものいくらでも転がっているのですが、ここでデータも作りたいというカスの欲求によってデータ作成が開始されました。まず何を学習させるのか。個人的にはケモミミ美少女を学習させて、大量のケモミミ美少女を生み出したかったのですが、どう考えてもそんなもの用意できません。これについては現代の圧倒的な技術力と資金力によって生み出された最強AIたちにやってもらいましょう。この記事の前に誰かが書いていたはずです。

とりあえず、ある程度見栄えが良くて、データが作りやすそうなものがいいでしょう。熟考の末採用されたのはマインクラフトのスクリーンショットでした。地形が単純で多様性もあって、ゲームなので外に出なくていいという判断です。

データは最低でも1万枚は欲しい。これを馬鹿正直にスクショしていては体力が持ちません。なので一定の間隔ごとにスクショするスクリプトを用意しました。そして…

マイクラをプレイします

は?馬鹿なのか?

なぜここを自動化しなかったのか。どう考えてもコマンドで自動で移動しながらスクショしたほうがいいに決まってます。この愚かな判断によって私は十数時間のマイクラプレイを強いられました。しかもUIなしでひたすら歩き回るだけ。何が面白いんだこれ。

この虚無作業ののち無事に1万枚のスクショが生み出されました。

成果物

うーん。遠目から見たらマイクラですね。もうちょっとデータ作って、学習スクリプトを調整すればよくなりそうです。

まとめ

まず計画をきちんと練るべきです。機械学習はデータが9割とかいいますが本当にそうです。これがケモミミ美少女メイドちゃんとかなモチベーションを保てたのですが、虚無マイクラは辛いよ…。

でも結構マイクラってきれいなんですよね。作りこみのすごさを実感しました。

では、ケモミミがあるように。

取得に失敗しました

2023年度 入部

GitHub