訂閱
糾錯(cuò)
加入自媒體

擴(kuò)散模型迎來“終極簡化”!何愷明團(tuán)隊(duì)新作:像素級一步生成,速度質(zhì)量雙巔峰

作者:Yiyang Lu等

解讀:AI生成未來

亮點(diǎn)直擊

Pixel MeanFlow (pMF) ,這是一種針對一步生成(one-step generation)的創(chuàng)新圖像生成模型。pMF 的核心突破在于成功地在無隱空間(latent-free)的像素級建模中實(shí)現(xiàn)了高效的一步生成。

pMF不僅擺脫了對預(yù)訓(xùn)練潛在編碼器(如 VQ-GAN 或 VAE)的依賴,直接在原始像素空間操作,而且在生成質(zhì)量上達(dá)到了與最先進(jìn)的多步隱空間擴(kuò)散模型相媲美的水平。

解決的問題

現(xiàn)代生成模型通常在兩個(gè)核心維度上進(jìn)行權(quán)衡:

采樣效率:多步采樣雖然質(zhì)量高但推理慢。

空間選擇:隱空間(Latent Space)通過壓縮降低了維度,但引入了復(fù)雜的編碼器/解碼器,且丟失了像素級的直接控制;像素空間(Pixel Space)雖然直觀("所見即所得"),但高維數(shù)據(jù)建模難度極大。

將“一步生成”與“像素空間建模”結(jié)合是一個(gè)極具挑戰(zhàn)性的任務(wù),因?yàn)閱我簧窠?jīng)網(wǎng)絡(luò)需要同時(shí)承擔(dān)極其復(fù)雜的軌跡建模(trajectory modeling)和圖像壓縮/抽象(manifold learning)任務(wù),F(xiàn)有的方法難以兼顧這兩者。

提出的方案

pMF 的核心思想是將網(wǎng)絡(luò)的預(yù)測目標(biāo)損失函數(shù)的計(jì)算空間解耦:

預(yù)測目標(biāo) (Prediction Target) :網(wǎng)絡(luò)直接預(yù)測去噪后的“干凈”圖像 (即 -prediction);诹餍渭僭O(shè),干凈圖像位于低維流形上,更易于神經(jīng)網(wǎng)絡(luò)擬合。

損失空間 (Loss Space) :損失函數(shù)定義在速度場(velocity space)中,遵循 MeanFlow 的公式,通過最小化瞬時(shí)速度誤差來學(xué)習(xí)平均速度場 。

轉(zhuǎn)換機(jī)制:引入了一個(gè)簡單的轉(zhuǎn)換公式,在圖像流形  和平均速度場  之間建立聯(lián)系:。這一轉(zhuǎn)換使得模型能夠利用像素空間的流形結(jié)構(gòu),同時(shí)在速度空間進(jìn)行有效的軌跡匹配。

應(yīng)用的技術(shù)

Pixel-space Prediction:直接在像素空間參數(shù)化去噪圖像 ,利用低維流形假設(shè)降低學(xué)習(xí)難度,避免直接預(yù)測高頻噪聲或速度場帶來的困難。

MeanFlow Formulation:利用 Improved MeanFlow (iMF) 框架,通過瞬時(shí)速度  的損失來學(xué)習(xí)平均速度場 。

Flow Matching:基于流匹配理論,建立從噪聲分布到數(shù)據(jù)分布的概率流。

Perceptual Loss:由于模型直接輸出像素,天然適合引入感知損失(LPIPS 和 ConvNeXt 特征),進(jìn)一步提升生成圖像的視覺質(zhì)量,彌補(bǔ)了像素級 MSE 損失的不足。

達(dá)到的效果

pMF 在 ImageNet 數(shù)據(jù)集上展現(xiàn)了強(qiáng)大的性能,證明了一步無潛在生成的可行性:

ImageNet 256×256: FID 分?jǐn)?shù)達(dá)到 2.22,超越了許多多步隱空間模型。

ImageNet 512×512: FID 分?jǐn)?shù)達(dá)到 2.48。

這表明一步像素級生成模型已經(jīng)具備了極強(qiáng)的競爭力,且不需要額外的解碼器開銷(解碼器本身在隱空間模型中占據(jù)顯著計(jì)算量)。

背景

本工作的 pMF 建立在 Flow Matching、MeanFlow 以及 JiT的基礎(chǔ)之上。

Flow Matching. Flow Matching (FM) 學(xué)習(xí)一個(gè)速度場 ,將先驗(yàn)分布  映射到數(shù)據(jù)分布 。本文考慮標(biāo)準(zhǔn)的線性插值調(diào)度:

其中數(shù)據(jù) ,噪聲 (例如高斯分布),時(shí)間 。在  時(shí),有:。該插值產(chǎn)生一個(gè)條件速度 :

FM 通過最小化 -空間中的損失函數(shù)(即“-loss”)來優(yōu)化由  參數(shù)化的網(wǎng)絡(luò) :

已有研究表明 (Lipman et al., 2023), 的潛在目標(biāo)是邊緣速度 。

在推理階段,通過求解常微分方程 (ODE): 從  到  生成樣本,其中 。這可以通過 Euler 或基于 Heun 的數(shù)值求解器來實(shí)現(xiàn)。

Flow Matching with x-prediction. 等式 (2) 中的量  是一個(gè)帶噪聲的圖像。為了便于使用在像素上操作的 Transformer,JiT 選擇通過神經(jīng)網(wǎng)絡(luò)參數(shù)化數(shù)據(jù) ,并通過以下方式將其轉(zhuǎn)換為速度 :

其中  是 Vision Transformer (ViT) 的直接輸出。這種公式被稱為 -prediction,而在訓(xùn)練中使用等式 (2) 中的 -loss。表 1 列出了這種關(guān)系。

Mean Flows. MeanFlow (MF) 框架學(xué)習(xí)一個(gè)平均速度場  用于少步/一步生成。將 FM 的  視為瞬時(shí)速度,MF 定義平均速度  為:

其中  和  是兩個(gè)時(shí)間步:。該定義引出了 MeanFlow 恒等式:

該恒等式提供了一種通過網(wǎng)絡(luò)  定義預(yù)測函數(shù)的方法:

這里,大寫  對應(yīng)于等式 (6) 的左側(cè),而在右側(cè),JVP 表示用于計(jì)算  的 Jacobian-vector product,“sg”表示停止梯度(stop-gradient)。本文遵循 iMF的 JVP 計(jì)算和實(shí)現(xiàn),這不是本文的重點(diǎn)。根據(jù)等式 (7) 的定義,iMF 像等式 (3) 一樣最小化 -loss,即 。這種公式可以被視為帶有 -loss 的 -prediction(參見表 1)。

Pixel MeanFlow

為了實(shí)現(xiàn)一步、無潛在生成,本文提出了 Pixel MeanFlow (pMF)。pMF 的核心設(shè)計(jì)是在 、 和  的不同場之間建立聯(lián)系。本文希望網(wǎng)絡(luò)像 JiT一樣直接輸出 ,而一步建模則像 MeanFlow一樣在  和  空間上進(jìn)行。

去噪圖像場

如前所述,iMF 和 JiT均可視為在最小化瞬時(shí)速度  的損失(-loss),區(qū)別在于 iMF 執(zhí)行的是平均速度預(yù)測(-prediction),而 JiT 執(zhí)行的是原始數(shù)據(jù)預(yù)測(-prediction);谶@一觀察,本工作在平均速度  與一種廣義形式的  之間建立了一種映射聯(lián)系。

考慮等式 (5) 中定義的平均速度場 :該場代表了一個(gè)由數(shù)據(jù)分布 、先驗(yàn)分布  以及時(shí)間調(diào)度決定的底層真實(shí)量,它與具體的網(wǎng)絡(luò)參數(shù)  無關(guān)。由此,本文推導(dǎo)出一個(gè)誘導(dǎo)場(induced field),定義如下:

如下文詳述,該場  扮演了類似于“去噪圖像”的角色。需要注意的是,本工作定義的  與以往文獻(xiàn)中提及的  不同,它是一個(gè)受兩個(gè)時(shí)間戳  索引的二元變量:對于給定的觀測值 ,本文的  是一個(gè)隨  變化的二維場,而非僅受  索引的一維軌跡。

廣義流形假設(shè)

圖 1 通過模擬從預(yù)訓(xùn)練 FM 模型獲得的一條 ODE 軌跡,可視化了  場和  場。如圖所示, 由含噪圖像組成,因?yàn)樽鳛樗俣葓觯?nbsp;包含噪聲和數(shù)據(jù)成分。相比之下, 場具有去噪圖像的外觀:它們是接近干凈的圖像,或者是因過度去噪而顯得模糊的圖像。接下來,本文討論流形假設(shè)如何推廣到這個(gè)量 。

注意 MF 中的時(shí)間步  滿足:。本文首先展示在  和  處的邊界情況可以近似滿足流形假設(shè);然后討論  的情況。

邊界情況 I: . 當(dāng)  時(shí),平均速度  退化為瞬時(shí)速度 ,即 。在這種情況下,等式 (8) 變?yōu)椋?/p>

這本質(zhì)上是 JiT 中使用的 -prediction 目標(biāo)。直觀地說,這個(gè)  是 JiT 要預(yù)測的去噪圖像。如果噪聲水平很高,這個(gè)去噪圖像可能是模糊的。正如經(jīng)典圖像去噪研究中廣泛觀察到的那樣,可以假設(shè)這些去噪圖像近似位于低維(或較低維)流形上。

邊界情況 II: . 等式 (5) 中  的定義給出:。將其代入等式 (8) 得到:

即,它是 ODE 軌跡的終點(diǎn)。對于真實(shí)的 ODE 軌跡,有 ,即它應(yīng)遵循圖像分布。因此,本文可以假設(shè)  近似位于圖像流形上。

一般情況: . 與邊界情況不同,量  不保證對應(yīng)于來自數(shù)據(jù)流形的(可能模糊的)圖像樣本。然而,根據(jù)經(jīng)驗(yàn),本文的模擬(圖 1 右)表明  看起來像去噪圖像。這與速度空間量(圖 1 中的 )形成鮮明對比,后者噪聲明顯更多。這種比較表明,通過神經(jīng)網(wǎng)絡(luò)對  進(jìn)行建?赡鼙葘Ω须s的  進(jìn)行建模更容易。實(shí)驗(yàn)表明,對于像素空間模型,-prediction 表現(xiàn)有效,而 -prediction 則嚴(yán)重退化。

算法

等式 (8) 中的誘導(dǎo)場  提供了 MeanFlow 網(wǎng)絡(luò)的一種重參數(shù)化。具體來說,本文讓網(wǎng)絡(luò)  直接輸出 ,并通過等式 (8) 計(jì)算相應(yīng)的速度場 :

這里, 是網(wǎng)絡(luò)的直接輸出,遵循 JiT。這個(gè)公式是等式 (4) 的自然擴(kuò)展。

本文將 (11) 中的  納入 iMF 公式,即使用帶有 -loss 的等式 (7)。具體來說,本文的優(yōu)化目標(biāo)是:

其中 。

從概念上講,這是帶有 -prediction 的 -loss,其中  通過  的關(guān)系轉(zhuǎn)換為  空間以回歸 。表 1 總結(jié)了這種關(guān)系。相應(yīng)的偽代碼在 Alg. 1 中。

帶有感知損失的像素平均流

網(wǎng)絡(luò)  直接將含噪輸入  映射到去噪圖像。這使得在訓(xùn)練時(shí)能夠?qū)崿F(xiàn)“所見即所得”的行為。因此,除了  損失外,本文還可以進(jìn)一步結(jié)合感知損失;跐撛诘姆椒ㄔ tokenizer 重建訓(xùn)練期間受益于感知損失,而基于像素的方法尚未能利用這一優(yōu)勢。

形式上,由于  是像素中的去噪圖像,本文直接對其應(yīng)用感知損失(例如 LPIPS)。本文的總體訓(xùn)練目標(biāo)是 ,其中  表示  和真實(shí)干凈圖像  之間的感知損失, 是權(quán)重超參數(shù)。在實(shí)踐中,僅當(dāng)添加的噪聲低于某個(gè)閾值(即 )時(shí)才應(yīng)用感知損失,以使去噪圖像不會(huì)太模糊。本文研究了基于 VGG 分類器的標(biāo)準(zhǔn) LPIPS 損失和基于 ConvNeXt-V2 的變體。

與前人工作的關(guān)系

本文的 pMF 與幾種先前的少步/一步方法密切相關(guān),討論如下。

Consistency Models (CM): 學(xué)習(xí)從含噪樣本  直接到生成圖像的映射。在本文的符號中,這對應(yīng)于固定終點(diǎn) 。此外,CM 通常采用預(yù)處理器 (Pre-conditioner),其形式為 。除非  為零,否則網(wǎng)絡(luò)不執(zhí)行純粹的 -prediction。

Consistency Trajectory Models (CTM): 制定了一個(gè)雙時(shí)間量。與基于導(dǎo)數(shù)公式的 MeanFlow 不同,CTM 依賴于在訓(xùn)練期間對 ODE 進(jìn)行積分。

Flow Map Matching (FMM): 也是基于雙時(shí)間量。在本文符號中,F(xiàn)low Map 扮演位移的角色,即 。該量通常不位于低維流形上(例如  是含噪圖像)。

實(shí)驗(yàn)

本文通過 2D 玩具實(shí)驗(yàn)(圖 2)證明,當(dāng)?shù)讓訑?shù)據(jù)位于低維流形上時(shí),在 MeanFlow 中使用 -prediction 是更可取的。實(shí)驗(yàn)設(shè)置遵循。

形式上,本文考慮定義在 2D 空間上的底層數(shù)據(jù)分布(此處為 Swiss roll)。數(shù)據(jù)使用  列正交矩陣投影到  維觀測空間。本文在  維觀測空間上訓(xùn)練 MeanFlow 模型,其中 。本文比較了-prediction 與本文的 -prediction。

圖 2 顯示,-prediction 表現(xiàn)相當(dāng)不錯(cuò),而當(dāng)  增加時(shí),-prediction 迅速退化。本文觀察到這種性能差距反映在訓(xùn)練損失的差異上:-prediction 產(chǎn)生的訓(xùn)練損失低于 -prediction 對應(yīng)物。這表明對于容量有限的網(wǎng)絡(luò),預(yù)測  更容易。

ImageNet 實(shí)驗(yàn)

本文默認(rèn)在分辨率 256×256 的 ImageNet 上進(jìn)行消融實(shí)驗(yàn)。報(bào)告基于 50,000 個(gè)生成樣本的 FID。所有模型均通過單次函數(shù)評估 (1-NFE) 生成原始像素圖像。

網(wǎng)絡(luò)的預(yù)測目標(biāo)

本文的方法基于流形假設(shè),即  位于低維流形上且更容易預(yù)測。本文在表 2 中驗(yàn)證了這一假設(shè)。

64×64 分辨率: patch 維度為 48 ()。這個(gè)維度遠(yuǎn)低于網(wǎng)絡(luò)容量。結(jié)果顯示 pMF 在 -prediction 和 -prediction 下都表現(xiàn)良好。

256×256 分辨率: patch 維度為 768 ()。這導(dǎo)致高維觀測空間,神經(jīng)網(wǎng)絡(luò)更難建模。在這種情況下,只有 -prediction 表現(xiàn)良好(FID 9.56),這表明  位于較低維流形上,因此更適合學(xué)習(xí)。相比之下,-prediction 發(fā)生災(zāi)難性失。‵ID 164.89):作為一個(gè)含噪量, 在高維空間中具有全支撐,更難建模。

消融研究

優(yōu)化器 本文發(fā)現(xiàn)優(yōu)化器的選擇在 pMF 中起著重要作用。在圖 3a 中,本文比較了標(biāo)準(zhǔn) Adam 優(yōu)化器與最近提出的 Muon。Muon 表現(xiàn)出更快的收斂速度和大幅提升的 FID(從 11.86 提升至 8.71)。在一步生成設(shè)置中,更快的收斂優(yōu)勢被進(jìn)一步放大,因?yàn)楦玫木W(wǎng)絡(luò)能提供更準(zhǔn)確的停止梯度目標(biāo)。

感知損失 在圖 3b 中,本文進(jìn)一步結(jié)合感知損失。使用標(biāo)準(zhǔn) VGG-based LPIPS 將 FID 從 9.56 提升至 5.62;結(jié)合 ConvNeXt-V2 變體進(jìn)一步將 FID 提升至 3.53?傮w而言,結(jié)合感知損失帶來了約 6 個(gè) FID 點(diǎn)的提升。

替代方案:預(yù)處理器 本文比較了三種預(yù)處理器變體:(i) 線性;(ii) EDM 風(fēng)格;(iii) sCM 風(fēng)格。表 3a 顯示,盡管 EDM 和 sCM 風(fēng)格優(yōu)于樸素線性變體,但在本文考慮的極高維輸入機(jī)制中,簡單的 -prediction 更可取且性能更好。這是因?yàn)槌?nbsp;,否則網(wǎng)絡(luò)預(yù)測會(huì)偏離  空間,可能位于更高維流形上。

替代方案:時(shí)間采樣器 本文研究了限制時(shí)間采樣的替代設(shè)計(jì):僅 (即 Flow Matching),僅 (類似 CM),或兩者的組合。表 3b 顯示這些受限的時(shí)間采樣器都不足以解決本文考慮的挑戰(zhàn)性場景。這表明 MeanFlow 方法利用  點(diǎn)之間的關(guān)系來學(xué)習(xí)場,限制時(shí)間采樣可能會(huì)破壞這種公式。

高分辨率生成 在表4中,本文研究了分辨率 256、512 和 1024 下的 pMF。通過增加 patch size(例如 )來保持序列長度不變 (),導(dǎo)致極大的 patch 維度(例如 12288)。結(jié)果顯示 pMF 可以有效處理這種極具挑戰(zhàn)性的情況。即使觀測空間是高維的,模型始終預(yù)測 ,其潛在維度不會(huì)成比例增長。

可擴(kuò)展性 表 5 顯示,增加模型大小和訓(xùn)練周期均能提升結(jié)果。

系統(tǒng)級比較

ImageNet 256×256. 表 6 顯示本文的方法達(dá)到了 2.22 FID。據(jù)本文所知,該類別中(一步、無潛在擴(kuò)散/流)唯一的方法是最近提出的 EPG,其 FID 為 8.82。與領(lǐng)先的 GAN 相比,pMF 實(shí)現(xiàn)了相當(dāng)?shù)?FID,但計(jì)算量大幅降低(例如 StyleGAN-XL 的計(jì)算量是 pMF-H/16 的 5.8 倍)。

ImageNet 512×512. 表 7 顯示 pMF 在 512×512 下達(dá)到 2.48 FID。值得注意的是,其計(jì)算成本(參數(shù)量和 Gflops)與 256×256 對應(yīng)物相當(dāng)。唯一的開銷來自 patch embedding 和預(yù)測層。

結(jié)論

本質(zhì)上,圖像生成模型是從噪聲到圖像像素的映射。由于生成建模的固有挑戰(zhàn),該問題通常被分解為更易處理的子問題,涉及多個(gè)步驟和階段。雖然有效,但這些設(shè)計(jì)偏離了深度學(xué)習(xí)的端到端精神。

本文關(guān)于 pMF 的研究表明,神經(jīng)網(wǎng)絡(luò)是具有高度表現(xiàn)力的映射,當(dāng)設(shè)計(jì)得當(dāng)時(shí),能夠?qū)W習(xí)復(fù)雜的端到端映射,例如直接從噪聲到像素。除了其實(shí)際潛力外,本文希望本工作將鼓勵(lì)未來對直接、端到端生成建模的探索。

參考文獻(xiàn)

[1] One-step Latent-free Image Generation with Pixel Mean Flows

       原文標(biāo)題 : 擴(kuò)散模型迎來“終極簡化”!何愷明團(tuán)隊(duì)新作:像素級一步生成,速度質(zhì)量雙巔峰

聲明: 本文由入駐維科號的作者撰寫,觀點(diǎn)僅代表作者本人,不代表OFweek立場。如有侵權(quán)或其他問題,請聯(lián)系舉報(bào)。

發(fā)表評論

0條評論,0人參與

請輸入評論內(nèi)容...

請輸入評論/評論長度6~500個(gè)字

您提交的評論過于頻繁,請輸入驗(yàn)證碼繼續(xù)

  • 看不清,點(diǎn)擊換一張  刷新

暫無評論

暫無評論

    人工智能 獵頭職位 更多
    掃碼關(guān)注公眾號
    OFweek人工智能網(wǎng)
    獲取更多精彩內(nèi)容
    文章糾錯(cuò)
    x
    *文字標(biāo)題:
    *糾錯(cuò)內(nèi)容:
    聯(lián)系郵箱:
    *驗(yàn) 證 碼:

    粵公網(wǎng)安備 44030502002758號