새소식

연구 및 프로젝트

TorchCFM 사용법

  • -

1. Flow Matcher 클래스들

ConditionalFlowMatcher

from torchcfm import ConditionalFLowMatcher

# init
fm = ConditionalFlowMatcher(sigma=0.0)

t, xt, ut = fm.sample_location_and_conditional_flow(x0=None, x1=data)

 

ExactOptimalTransportConditionalFlowMatcher

from torchcfm import ExactOptimalTransportConditionalFlowMatcher

# 더 정교하고 계산 비용 높음
fm_ot = ExactOptimalTransportConditionalFlowMatcher(sigma=0.0
t, xt, ut = fm_ot.sample_location_and_conditional_flow(x0=x0, x1=x1)

 

2. Sampling Methods

from torchcfm import sample_ode

with torch.no_grad():
    x0 = torch.randn(100, 2)  # 초기 노이즈
    samples = sample_ode(
        model=trained_model,
        x0=x0,
        method='dopri5',     # ODE 솔버 방법
        step_size=0.1,       # 스텝 크기
        atol=1e-5,          # 절대 허용 오차
        rtol=1e-3           # 상대 허용 오차
    )

#available methods
methods = ['euler', 'midpoint', 'rk4', 'dopri5', 'dopri8']

# 더 빠르지만 덜 정확한 샘플링
samples = sample_euler(
    model=trained_model,
    x0=x0,
    step_size=0.01,
    num_steps=100
)

sample_location_and_conditional_flow

  • t : 시간 스텝
  • xt : 시간 t에서의 interpolated 데이터
  • ut : 시간 t에서의 조건부 벡터 필드
def sample_location_and_conditional_flow(self, x0=None, x1=None):
    """
    x0: 시작점 (source) - None이면 가우시안 노이즈에서 시작
    x1: 끝점 (target) - 실제 데이터 샘플
    """
import torch
from torchcfm import ConditionalFlowMatcher

# Example
fm = ConditionalFlowMatcher(sigma=0.0)
x1 = torch.randn(32, 2)  # 배치 크기 32, 2차원 데이터

t, xt, ut = fm.sample_location_and_conditional_flow(x0=None, x1=x1)

print(f"t shape: {t.shape}")    # [32, 1] - 각 샘플에 대한 시간
print(f"xt shape: {xt.shape}")  # [32, 2] - 시간 t에서의 interpolated 위치
print(f"ut shape: {ut.shape}")  # [32, 2] - 시간 t에서의 벡터 필드
# 특정 소스에서 타겟으로의 변환
x0 = torch.zeros(64, 5)      # 원점에서 시작
x1 = torch.randn(64, 5)      # 가우시안 분포로 이동

t, xt, ut = fm.sample_location_and_conditional_flow(x0=x0, x1=x1)
# sigma=0.0 (deterministic)
fm_det = ConditionalFlowMatcher(sigma=0.0)
t, xt, ut = fm_det.sample_location_and_conditional_flow(x0=None, x1=x1)

# sigma>0 (stochastic)
fm_stoch = ConditionalFlowMatcher(sigma=0.1)
t, xt, ut = fm_stoch.sample_location_and_conditional_flow(x0=None, x1=x1)
# 이 경우 xt에 추가적인 노이즈가 더해짐

 

 

3. Loss

 

 

 

4. 주요 Parameters

  • sigma : 노이즈 레벨 (0.0=deterministic, >0=stochastic)
  • method: ODE 솔버 ('euler', 'dopri5', 'rk4' 등)
  • step_size: 적분 스텝 크기
  • atol/rtol: 허용 오차 (정확도 vs 속도 trade-off)
  • num_steps: 이산화 스텝 수 (euler 방법용)
Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.