from io import BytesIO

import matplotlib.pyplot as plt
import requests
import torch
from PIL import Image
from torchvision import models
from torchvision import transforms
from torchvision.models import vit_b_16
# Step 1: Load an image from a URL
image_url = "https://upload.wikimedia.org/wikipedia/commons/2/26/YellowLabradorLooking_new.jpg"
# Send a browser-like User-Agent to avoid being blocked by the website
headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36"}
# A timeout keeps the script from hanging forever on a stalled connection
response = requests.get(image_url, headers=headers, timeout=30)
# Fail right here with a clear HTTP error instead of hitting a confusing
# NameError on `image` further down when the download did not succeed
response.raise_for_status()
# Force 3-channel RGB: the normalization below uses three per-channel values
# and would fail on grayscale / RGBA / palette images
image = Image.open(BytesIO(response.content)).convert("RGB")

# Display the image
plt.imshow(image)
plt.axis("off")
plt.title("Original Image")
plt.show()

# Step 2: Preprocess the image into the tensor layout the model expects
# NOTE(review): resize/crop sizes and mean/std are the standard ImageNet
# preprocessing values published for torchvision's ViT-B/16 weights --
# confirm against the weights actually loaded below.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# unsqueeze(0) adds a leading batch dimension so a single image forms a batch
input_batch = preprocess(image).unsqueeze(0)

# Step 3: Load a pre-trained Vision Transformer model
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` -- switch once the minimum supported version is confirmed.
model = vit_b_16(pretrained=True)
model.eval()  # Set the model to evaluation mode (no training happening here)

# Forward pass through the model
with torch.no_grad():  # No gradients are needed, as we are only doing inference
    output = model(input_batch)
# Output: this will be a classification result (e.g., scores over ImageNet classes)
# Step 4: Interpret the output.
# Let's get the predicted label from the ImageNet dataset.
# Step 4: Interpret the output
# Load ImageNet labels for interpretation
# (assumes the JSON document is a list ordered by ImageNet class index --
# TODO confirm against the upstream repository)
labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
labels_response = requests.get(labels_url, timeout=30)
# Surface HTTP failures here rather than as a JSON decoding error
labels_response.raise_for_status()
imagenet_labels = labels_response.json()

# Get the index of the highest score along the class dimension
_, predicted_class = torch.max(output, 1)

# Display the predicted class
predicted_label = imagenet_labels[predicted_class.item()]
print(f"Predicted Label: {predicted_label}")

# Visualize the result
plt.imshow(image)
plt.axis("off")
plt.title(f"Predicted: {predicted_label}")
plt.show()