1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import json
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

# === Path configuration ===
# Component checkpoints of a fine-tuned Stable-Diffusion-style pipeline,
# each saved separately via save_pretrained (VAE, UNet, text encoder, tokenizer).
vae_path = "/mnt/nfs/nina/nina/visa_task/visa_task/vae"
unet_path = "/mnt/nfs/nina/nina/visa_task/visa_task/unet"
text_encoder_path = "/mnt/nfs/nina/nina/visa_task/visa_task/text_encoder"
tokenizer_path = "/mnt/nfs/nina/nina/visa_task/visa_task/tokenizer"
# JSON listing image file names and their text descriptions (VisA dataset).
json_path = "/mnt/nfs/nina/nina/visa_task/visa_task/src_VisA_filename_n_description_fixed.json"
save_output_dir = "./inference_results/"
os.makedirs(save_output_dir, exist_ok=True)

# === Device selection ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Load models and tokenizer ===
print("🔄 載入模型與 tokenizer...")
vae = AutoencoderKL.from_pretrained(vae_path).to(device)
unet = UNet2DConditionModel.from_pretrained(unet_path).to(device)
text_encoder = CLIPTextModel.from_pretrained(text_encoder_path).to(device)
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
# Inference only: disable dropout/batchnorm training behavior.
vae.eval()
unet.eval()
text_encoder.eval()

# === Load the JSON metadata and keep only anomaly images with their prompts ===
# NOTE: the original paste had all block indentation stripped (the `with` and
# `for` bodies sat at column 0), which is a SyntaxError; restored here.
with open(json_path, "r") as f:
    json_data = json.load(f)

img_prompt_pairs = []
for img_info in json_data:
    # Entries whose file name contains "Anomaly" are the defective samples.
    if "Anomaly" in img_info["file_name"]:
        # Fall back to a generic prompt when the entry has no "text" field.
        prompt = img_info.get("text", "a photo of anomaly")
        img_prompt_pairs.append((img_info["file_name"], prompt))

# === Image preprocessing ===
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # map pixels to [-1, 1], the SD VAE input range
])

# === Cap on how many images to run (adjustable) ===
num_infer = 5
img_prompt_pairs = img_prompt_pairs[:num_infer]

print(f"🔍 即將推論 {len(img_prompt_pairs)} 張圖片...")

# === Run inference per (image, prompt) pair and save visualizations ===
# NOTE: the original paste had all loop/try/with indentation stripped, which is
# a SyntaxError; restored here.
for img_name, prompt in img_prompt_pairs:
    img_path = img_name  # file_name in the JSON is already an absolute path
    try:
        img = Image.open(img_path).convert("RGB")
    except Exception as e:
        # Best-effort: skip unreadable images instead of aborting the run.
        print(f"⚠️ 無法讀取圖片 {img_path}:{e}")
        continue

    input_tensor = transform(img).unsqueeze(0).to(device)  # (1, 3, 512, 512)

    # === Build the text conditioning ===
    text_inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)

    with torch.no_grad():
        text_embedding = text_encoder(**text_inputs).last_hidden_state
        # 0.18215 is the standard Stable Diffusion latent scaling factor.
        latents = vae.encode(input_tensor).latent_dist.sample() * 0.18215
        output = unet(
            latents,
            timestep=torch.tensor(0).to(device),  # UNet2DConditionModel requires a timestep
            encoder_hidden_states=text_embedding
        ).sample

    # NOTE(review): argmax over the UNet's output channels is interpreted as a
    # per-pixel class label, and the result is at latent resolution (512/8 = 64),
    # not 512x512 — confirm this matches how the model was fine-tuned.
    mask = torch.argmax(output[0], dim=0).cpu().numpy()

    # === Save a side-by-side figure: input image vs. predicted mask ===
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title("Input Image")
    plt.imshow(img)
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.title(f"Predicted Mask\nPrompt: {prompt}")
    plt.imshow(mask, cmap="jet")
    plt.axis("off")

    img_base = os.path.basename(img_name)
    plt.savefig(os.path.join(save_output_dir, f"{img_base}_mask.png"))
    plt.close()

print("✅ 所有推論已完成,請查看 inference_results/")