Ejemplo: casos adversarios con VGG16#
Utilizando el conocimiento de Física Diferenciable, queremos realizar ataques adversarios a una red VGG16 entrenada. Queremos pensar cómo debemos modificar la imagen de entrada para que la red se confunda de clase! Esto lo podemos hacer si pensamos a la red VGG16 como un simulador, que depende de la imagen de entrada y cuyo objetivo es una determinada clase objetivo.
Planteamos el problema de optimización como cuál es la modificación (perturbación) que debo hacer a la imagen de entrada, de manera tal de que la predicción sea la que yo quiera, es decir una clase objetivo seleccionada a priori (independientemente de que la imagen de entrada sea esa clase objetivo o no).
Eso produce un ejemplo adversario: no entrenamos el modelo, sino que optimizamos la imagen para que engañe a la red.
Preparación#
Usamos un modelo VGG16 preentrenado de torchvision. Sus pesos permanecerán congelados durante todo el notebook porque el objetivo no es entrenar la red, sino encontrar una perturbacion a la imagen seleccionada tal que engañe a la red.
import json
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("dispositivo:", device)
with open("imagenet.json") as f:
imagenet_labels = json.load(f)
try:
weights = models.VGG16_Weights.DEFAULT
model = models.vgg16(weights=weights)
except AttributeError:
model = models.vgg16(pretrained=True)
model = model.to(device).eval()
# acá congelamos los pesos del modelo, ya que no los vamos a actualizar
for param in model.parameters():
param.requires_grad = False
imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
to_tensor = transforms.ToTensor()
dispositivo: cpu
---------------------------------------------------------------------------
FileNotFoundError Traceback (most recent call last)
Cell In[1], line 14
10
11 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 print("dispositivo:", device)
13
---> 14 with open("imagenet.json") as f:
15 imagenet_labels = json.load(f)
16
17 try:
FileNotFoundError: [Errno 2] No such file or directory: 'imagenet.json'
Cargar y mostrar la imagen#
Usamos PIL para evitar la confusión entre BGR y RGB que suele aparecer con cv2.
import requests
from io import BytesIO
def load_image(url):
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert("RGB")
return img
!unset proxy_https
im = load_image("https://t1.gstatic.com/licensed-image?q=tbn:ANd9GcSQgHnpqrWBnv3Y7q-Q9sXdNxFAxXCvyaDg38Qs04T7rcLw-Bk8UJQngdqO_T2NGp1jwnHY-AaeJYms9cZndhE")
im = load_image("http://t1.gstatic.com/licensed-image?q=tbn:ANd9GcQ4acvCNG_ou7PKgWfo4gyBlind6CGzskM_1MwF00Hsq0Zu9QxFs0Wo1MoFqDwb7-olbTqj2aOvpTCl4nOTMww")
plt.imshow(im)
pil_image= im.convert("RGB").resize((224, 224))
image=to_tensor(pil_image).unsqueeze(0).to(device)
plt.figure(figsize=(4, 4))
plt.imshow(pil_image)
plt.axis("off")
plt.title("Imagen original")
plt.show()
Funciones auxiliares#
Estas funciones condensan el proceso completo.
normalize_batch: adapta la imagen al formato que espera VGG16.predict_topk: devuelve las clases más probables.targeted_attack_step: realiza un paso de optimización sobre la imagen.
Ese último punto es donde aparece la analogía más clara con física diferenciable: la imagen actual actúa como estado de entrada, la red actúa como sistema diferenciable, la pérdida mide qué tan lejos estamos del objetivo y backward() propaga el gradiente hasta los píxeles.
def normalize_batch(batch):
return (batch - imagenet_mean) / imagenet_std
def predict_topk(batch, k=5):
with torch.no_grad():
logits = model(normalize_batch(batch))
probs = F.softmax(logits, dim=1)
top_probs, top_indices = probs.topk(k, dim=1)
rows = []
for prob, idx in zip(top_probs[0].cpu().tolist(), top_indices[0].cpu().tolist()):
rows.append((idx, imagenet_labels[str(idx)] if str(idx) in imagenet_labels else imagenet_labels[idx], prob))
return rows
def show_predictions(batch, title, k=5):
print(title)
for idx, label, prob in predict_topk(batch, k=k):
print(f" clase {idx:4d}: {label:25s} prob={100 * prob:6.2f}%")
def targeted_attack_step(current_image, target_index, step_size):
attack_image = current_image.clone().detach().requires_grad_(True)
logits = model(normalize_batch(attack_image))
target = torch.tensor([target_index], device=device)
loss = F.cross_entropy(logits, target)
loss.backward()
with torch.no_grad():
updated_image = attack_image - step_size * attack_image.grad
updated_image = updated_image.clamp(0.0, 1.0)
return updated_image.detach(), loss.item(), attack_image.grad.detach()
Inspeccionar la predicción original#
Antes de atacar la imagen, observamos qué cree ver el modelo.
show_predictions(image, "Predicciones antes del ataque")
Predicciones antes del ataque
clase 783: screw prob= 99.68%
clase 677: nail prob= 0.32%
clase 784: screwdriver prob= 0.00%
clase 506: coil prob= 0.00%
clase 543: dumbbell prob= 0.00%
Elegir una clase objetivo#
Fijamos una clase manualmente como objetivo.
En un ataque dirigido, la pregunta es: ¿cómo debo modificar la imagen para que el clasificador crea que pertenece a esta clase?
target_index = 2
target_label = imagenet_labels[str(target_index)] if str(target_index) in imagenet_labels else imagenet_labels[target_index]
print("Clase objetivo:", target_index, target_label)
Clase objetivo: 2 great white shark
Entender un solo paso de gradiente#
Este es el núcleo conceptual del notebook.
Calculamos el gradiente de la pérdida con respecto a cada píxel. Ese gradiente responde a la pregunta:
si cambio un poco este píxel, cómo cambia la pérdida de la clase objetivo?
Esa es exactamente la lógica de los sistemas diferenciables: obtener sensibilidad de la salida con respecto a las entradas o parámetros.
one_step_image, one_step_loss, one_step_grad = targeted_attack_step(image, target_index, step_size=0.01)
print("Pérdida tras calcular un paso:", one_step_loss)
print("Forma del gradiente:", tuple(one_step_grad.shape))
print("Mínimo/máximo del gradiente:", one_step_grad.min().item(), one_step_grad.max().item())
Pérdida tras calcular un paso: 31.317428588867188
Forma del gradiente: (1, 3, 224, 224)
Mínimo/máximo del gradiente: -0.8111703395843506 1.074661135673523
grad_vis = one_step_grad[0].detach().cpu().numpy()
grad_vis = np.transpose(grad_vis, (1, 2, 0))
grad_vis = grad_vis - grad_vis.min()
grad_vis = grad_vis / max(grad_vis.max(), 1e-8)
plt.figure(figsize=(4, 4))
plt.imshow(grad_vis)
plt.axis("off")
plt.title("Visualización del gradiente en los píxeles")
plt.show()
Ejecutar el ataque completo#
Ahora repetimos el paso de gradiente muchas veces.
En cada iteración:
pasamos la imagen actual por VGG16,
calculamos la pérdida para la clase objetivo,
hacemos backpropagation,
actualizamos los píxeles,
recortamos los valores al rango válido
[0, 1].
num_steps = 100
step_size = 0.01
attack_image = image.clone().detach()
loss_history = []
target_prob_history = []
for step in range(num_steps):
attack_image, loss_value, _ = targeted_attack_step(attack_image, target_index, step_size)
loss_history.append(loss_value)
with torch.no_grad():
probs = F.softmax(model(normalize_batch(attack_image)), dim=1)
target_prob = probs[0, target_index].item()
target_prob_history.append(target_prob)
print(f"paso {step + 1:02d}: pérdida={loss_value:.4f}, prob_objetivo={100 * target_prob:.2f}%")
paso 01: pérdida=31.3174, prob_objetivo=0.00%
paso 02: pérdida=27.9523, prob_objetivo=0.00%
paso 03: pérdida=22.9955, prob_objetivo=0.00%
paso 04: pérdida=17.4814, prob_objetivo=0.00%
paso 05: pérdida=13.6885, prob_objetivo=0.00%
paso 06: pérdida=11.0174, prob_objetivo=0.01%
paso 07: pérdida=9.1128, prob_objetivo=0.03%
paso 08: pérdida=7.9584, prob_objetivo=0.07%
paso 09: pérdida=7.3109, prob_objetivo=0.11%
paso 10: pérdida=6.7813, prob_objetivo=0.18%
paso 11: pérdida=6.3123, prob_objetivo=0.35%
paso 12: pérdida=5.6564, prob_objetivo=0.79%
paso 13: pérdida=4.8388, prob_objetivo=1.72%
paso 14: pérdida=4.0629, prob_objetivo=3.16%
paso 15: pérdida=3.4538, prob_objetivo=7.05%
paso 16: pérdida=2.6526, prob_objetivo=13.56%
paso 17: pérdida=1.9984, prob_objetivo=25.63%
paso 18: pérdida=1.3616, prob_objetivo=37.76%
paso 19: pérdida=0.9739, prob_objetivo=49.23%
paso 20: pérdida=0.7086, prob_objetivo=57.22%
paso 21: pérdida=0.5582, prob_objetivo=62.37%
paso 22: pérdida=0.4721, prob_objetivo=65.40%
paso 23: pérdida=0.4247, prob_objetivo=68.90%
paso 24: pérdida=0.3725, prob_objetivo=72.54%
paso 25: pérdida=0.3211, prob_objetivo=76.78%
paso 26: pérdida=0.2643, prob_objetivo=79.36%
paso 27: pérdida=0.2312, prob_objetivo=80.94%
paso 28: pérdida=0.2114, prob_objetivo=82.08%
paso 29: pérdida=0.1975, prob_objetivo=83.10%
paso 30: pérdida=0.1851, prob_objetivo=83.96%
paso 31: pérdida=0.1748, prob_objetivo=84.72%
paso 32: pérdida=0.1658, prob_objetivo=85.32%
paso 33: pérdida=0.1587, prob_objetivo=85.89%
paso 34: pérdida=0.1521, prob_objetivo=86.43%
paso 35: pérdida=0.1459, prob_objetivo=86.89%
paso 36: pérdida=0.1406, prob_objetivo=87.32%
paso 37: pérdida=0.1356, prob_objetivo=87.69%
paso 38: pérdida=0.1313, prob_objetivo=88.05%
paso 39: pérdida=0.1272, prob_objetivo=88.36%
paso 40: pérdida=0.1238, prob_objetivo=88.64%
paso 41: pérdida=0.1206, prob_objetivo=88.99%
paso 42: pérdida=0.1166, prob_objetivo=89.23%
paso 43: pérdida=0.1139, prob_objetivo=89.55%
paso 44: pérdida=0.1103, prob_objetivo=89.80%
paso 45: pérdida=0.1076, prob_objetivo=90.05%
paso 46: pérdida=0.1048, prob_objetivo=90.27%
paso 47: pérdida=0.1023, prob_objetivo=90.48%
paso 48: pérdida=0.1001, prob_objetivo=90.68%
paso 49: pérdida=0.0978, prob_objetivo=90.90%
paso 50: pérdida=0.0954, prob_objetivo=91.10%
paso 51: pérdida=0.0932, prob_objetivo=91.29%
paso 52: pérdida=0.0912, prob_objetivo=91.47%
paso 53: pérdida=0.0892, prob_objetivo=91.63%
paso 54: pérdida=0.0874, prob_objetivo=91.79%
paso 55: pérdida=0.0857, prob_objetivo=91.94%
paso 56: pérdida=0.0840, prob_objetivo=92.12%
paso 57: pérdida=0.0821, prob_objetivo=92.27%
paso 58: pérdida=0.0805, prob_objetivo=92.42%
paso 59: pérdida=0.0788, prob_objetivo=92.56%
paso 60: pérdida=0.0773, prob_objetivo=92.69%
paso 61: pérdida=0.0760, prob_objetivo=92.82%
paso 62: pérdida=0.0745, prob_objetivo=92.95%
paso 63: pérdida=0.0731, prob_objetivo=93.07%
paso 64: pérdida=0.0718, prob_objetivo=93.18%
paso 65: pérdida=0.0707, prob_objetivo=93.27%
paso 66: pérdida=0.0697, prob_objetivo=93.40%
paso 67: pérdida=0.0683, prob_objetivo=93.50%
paso 68: pérdida=0.0672, prob_objetivo=93.59%
paso 69: pérdida=0.0662, prob_objetivo=93.68%
paso 70: pérdida=0.0653, prob_objetivo=93.79%
paso 71: pérdida=0.0641, prob_objetivo=93.88%
paso 72: pérdida=0.0631, prob_objetivo=93.97%
paso 73: pérdida=0.0622, prob_objetivo=94.05%
paso 74: pérdida=0.0614, prob_objetivo=94.13%
paso 75: pérdida=0.0605, prob_objetivo=94.22%
paso 76: pérdida=0.0596, prob_objetivo=94.29%
paso 77: pérdida=0.0588, prob_objetivo=94.38%
paso 78: pérdida=0.0578, prob_objetivo=94.45%
paso 79: pérdida=0.0571, prob_objetivo=94.54%
paso 80: pérdida=0.0562, prob_objetivo=94.61%
paso 81: pérdida=0.0554, prob_objetivo=94.68%
paso 82: pérdida=0.0546, prob_objetivo=94.75%
paso 83: pérdida=0.0540, prob_objetivo=94.82%
paso 84: pérdida=0.0532, prob_objetivo=94.88%
paso 85: pérdida=0.0526, prob_objetivo=94.95%
paso 86: pérdida=0.0519, prob_objetivo=95.00%
paso 87: pérdida=0.0512, prob_objetivo=95.06%
paso 88: pérdida=0.0506, prob_objetivo=95.12%
paso 89: pérdida=0.0500, prob_objetivo=95.18%
paso 90: pérdida=0.0494, prob_objetivo=95.23%
paso 91: pérdida=0.0488, prob_objetivo=95.29%
paso 92: pérdida=0.0482, prob_objetivo=95.34%
paso 93: pérdida=0.0477, prob_objetivo=95.40%
paso 94: pérdida=0.0471, prob_objetivo=95.45%
paso 95: pérdida=0.0466, prob_objetivo=95.50%
paso 96: pérdida=0.0461, prob_objetivo=95.55%
paso 97: pérdida=0.0456, prob_objetivo=95.59%
paso 98: pérdida=0.0451, prob_objetivo=95.64%
paso 99: pérdida=0.0446, prob_objetivo=95.68%
paso 100: pérdida=0.0441, prob_objetivo=95.73%
9. Comparar predicciones antes y después#
show_predictions(image, "Predicciones antes del ataque")
print()
show_predictions(attack_image, "Predicciones después del ataque")
Predicciones antes del ataque
clase 783: screw prob= 99.68%
clase 677: nail prob= 0.32%
clase 784: screwdriver prob= 0.00%
clase 506: coil prob= 0.00%
clase 543: dumbbell prob= 0.00%
Predicciones después del ataque
clase 2: great white shark prob= 95.73%
clase 3: tiger shark prob= 1.25%
clase 4: hammerhead shark prob= 0.74%
clase 833: submarine prob= 0.61%
clase 394: sturgeon prob= 0.45%
La perturbación suele ser pequeña en magnitud, pero puede cambiar mucho la salida del modelo.
original_np = image[0].detach().cpu().permute(1, 2, 0).numpy()
adversarial_np = attack_image[0].detach().cpu().permute(1, 2, 0).numpy()
perturbation = adversarial_np - original_np
perturbation_vis = perturbation - perturbation.min()
perturbation_vis = perturbation_vis / max(perturbation_vis.max(), 1e-8)
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(original_np)
axes[0].set_title("Original")
axes[1].imshow(adversarial_np)
axes[1].set_title("Adversaria")
axes[2].imshow(perturbation_vis)
axes[2].set_title("Perturbación")
for ax in axes:
ax.axis("off")
plt.tight_layout()
plt.show()
Si el ataque funciona, la probabilidad de la clase objetivo debería subir mientras la pérdida baja.
fig, axes = plt.subplots(1, 2, figsize=(10, 3))
axes[0].plot(loss_history)
axes[0].set_title("Pérdida objetivo")
axes[0].set_xlabel("Paso")
axes[0].set_ylabel("Cross-entropy")
axes[1].plot(target_prob_history)
axes[1].set_title("Probabilidad de la clase objetivo")
axes[1].set_xlabel("Paso")
axes[1].set_ylabel("Probabilidad")
plt.tight_layout()
plt.show()