import os
import json
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

# === Paths ===
vae_path = "/mnt/nfs/nina/nina/visa_task/visa_task/vae"
unet_path = "/mnt/nfs/nina/nina/visa_task/visa_task/unet"
text_encoder_path = "/mnt/nfs/nina/nina/visa_task/visa_task/text_encoder"
tokenizer_path = "/mnt/nfs/nina/nina/visa_task/visa_task/tokenizer"
json_path = "/mnt/nfs/nina/nina/visa_task/visa_task/src_VisA_filename_n_description_fixed.json"
save_output_dir = "./inference_results/"
os.makedirs(save_output_dir, exist_ok=True)

# === Device ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Load models ===
print("🔄 Loading models and tokenizer...")
vae = AutoencoderKL.from_pretrained(vae_path).to(device)
unet = UNet2DConditionModel.from_pretrained(unet_path).to(device)
text_encoder = CLIPTextModel.from_pretrained(text_encoder_path).to(device)
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)

vae.eval()
unet.eval()
text_encoder.eval()

# === Load the JSON and keep only anomaly images with their prompts ===
with open(json_path, "r") as f:
    json_data = json.load(f)

img_prompt_pairs = []
for img_info in json_data:
    if "Anomaly" in img_info["file_name"]:
        prompt = img_info.get("text", "a photo of anomaly")  # fall back to a default prompt if missing
        img_prompt_pairs.append((img_info["file_name"], prompt))

# === Image preprocessing ===
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # [-1, 1] input range expected by the SD VAE
])

# === Cap on the number of images to run (adjustable) ===
num_infer = 5
img_prompt_pairs = img_prompt_pairs[:num_infer]

print(f"🔍 Running inference on {len(img_prompt_pairs)} images...")

for img_name, prompt in img_prompt_pairs:
    img_path = img_name  # file_name in the JSON is already an absolute path
    try:
        img = Image.open(img_path).convert("RGB")
    except Exception as e:
        print(f"⚠️ Failed to read image {img_path}: {e}")
        continue

    input_tensor = transform(img).unsqueeze(0).to(device)

    # === Build the text conditioning and run the model ===
    text_inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        text_embedding = text_encoder(**text_inputs).last_hidden_state
        # Encode into latent space; 0.18215 is the SD VAE scaling factor
        latents = vae.encode(input_tensor).latent_dist.sample() * 0.18215
        output = unet(
            latents,
            timestep=torch.tensor(0, device=device),  # the required timestep argument
            encoder_hidden_states=text_embedding
        ).sample

    # Per-pixel argmax over the UNet output channels -> 2D mask on the latent grid
    mask = torch.argmax(output[0], dim=0).cpu().numpy()

    # === Save results ===
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title("Input Image")
    plt.imshow(img)
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.title(f"Predicted Mask\nPrompt: {prompt}")
    plt.imshow(mask, cmap="jet")
    plt.axis("off")

    img_base = os.path.basename(img_name)
    plt.savefig(os.path.join(save_output_dir, f"{img_base}_mask.png"))
    plt.close()

print("✅ All inference finished; see inference_results/")
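
# --- Optional sketch (not part of the original paste): the UNet operates on the
# VAE latent grid, so `mask` above is 64x64 while the input image is 512x512.
# If you want the mask at image resolution (e.g. for an overlay), one simple
# option is nearest-neighbor upsampling of the class indices, as below.
# `upsample_mask` is a hypothetical helper name, not from the original script.
import numpy as np
import torch.nn.functional as F

def upsample_mask(mask_2d, size=(512, 512)):
    """Nearest-neighbor upsample of an integer class mask (H, W) to `size`."""
    t = torch.from_numpy(mask_2d.astype(np.float32))[None, None]  # (1, 1, H, W)
    up = F.interpolate(t, size=size, mode="nearest")              # keep hard labels
    return up[0, 0].long().numpy()

# Usage inside the loop, before plotting: mask = upsample_mask(mask)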