Ejemplo: casos adversarios con VGG16#

Open in Colab

Utilizando el conocimiento de Física Diferenciable, queremos realizar ataques adversarios a una red VGG16 entrenada. Queremos pensar cómo debemos modificar la imagen de entrada para que la red se confunda de clase! Esto lo podemos hacer si pensamos a la red VGG16 como un simulador, que depende de la imagen de entrada y cuyo objetivo es una determinada clase objetivo.

Planteamos el problema de optimización como cuál es la modificación (perturbación) que debo hacer a la imagen de entrada, de manera tal de que la predicción sea la que yo quiera, es decir una clase objetivo seleccionada a priori (independientemente de que la imagen de entrada sea esa clase objetivo o no).

Eso produce un ejemplo adversario: no entrenamos el modelo, sino que optimizamos la imagen para que engañe a la red.

Preparación#

Usamos un modelo VGG16 preentrenado de torchvision. Sus pesos permanecerán congelados durante todo el notebook porque el objetivo no es entrenar la red, sino encontrar una perturbacion a la imagen seleccionada tal que engañe a la red.

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import models, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("dispositivo:", device)

with open("imagenet.json") as f:
    imagenet_labels = json.load(f)

try:
    weights = models.VGG16_Weights.DEFAULT
    model = models.vgg16(weights=weights)
except AttributeError:
    model = models.vgg16(pretrained=True)

model = model.to(device).eval()
# acá congelamos los pesos del modelo, ya que no los vamos a actualizar
for param in model.parameters():
    param.requires_grad = False

imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

to_tensor = transforms.ToTensor()
dispositivo: cpu
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
Cell In[1], line 14
     10 
     11 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     12 print("dispositivo:", device)
     13 
---> 14 with open("imagenet.json") as f:
     15     imagenet_labels = json.load(f)
     16 
     17 try:

FileNotFoundError: [Errno 2] No such file or directory: 'imagenet.json'

Cargar y mostrar la imagen#

Usamos PIL para evitar la confusión entre BGR y RGB que suele aparecer con cv2.

import requests
from io import BytesIO
def load_image(url):
    response = requests.get(url)
    img = Image.open(BytesIO(response.content)).convert("RGB")
    return img
!unset proxy_https
im = load_image("https://t1.gstatic.com/licensed-image?q=tbn:ANd9GcSQgHnpqrWBnv3Y7q-Q9sXdNxFAxXCvyaDg38Qs04T7rcLw-Bk8UJQngdqO_T2NGp1jwnHY-AaeJYms9cZndhE")
im = load_image("http://t1.gstatic.com/licensed-image?q=tbn:ANd9GcQ4acvCNG_ou7PKgWfo4gyBlind6CGzskM_1MwF00Hsq0Zu9QxFs0Wo1MoFqDwb7-olbTqj2aOvpTCl4nOTMww")
plt.imshow(im)
pil_image= im.convert("RGB").resize((224, 224))
image=to_tensor(pil_image).unsqueeze(0).to(device)
plt.figure(figsize=(4, 4))
plt.imshow(pil_image)
plt.axis("off")
plt.title("Imagen original")
plt.show()
../_images/75da82a58094c76a711bb9b511a778656741f0b7b2a3ae0d4b8a351c2ddeeef6.png

Funciones auxiliares#

Estas funciones condensan el proceso completo.

  • normalize_batch: adapta la imagen al formato que espera VGG16.

  • predict_topk: devuelve las clases más probables.

  • targeted_attack_step: realiza un paso de optimización sobre la imagen.

Ese último punto es donde aparece la analogía más clara con física diferenciable: la imagen actual actúa como estado de entrada, la red actúa como sistema diferenciable, la pérdida mide qué tan lejos estamos del objetivo y backward() propaga el gradiente hasta los píxeles.

def normalize_batch(batch):
    return (batch - imagenet_mean) / imagenet_std


def predict_topk(batch, k=5):
    with torch.no_grad():
        logits = model(normalize_batch(batch))
        probs = F.softmax(logits, dim=1)
        top_probs, top_indices = probs.topk(k, dim=1)

    rows = []
    for prob, idx in zip(top_probs[0].cpu().tolist(), top_indices[0].cpu().tolist()):
        rows.append((idx, imagenet_labels[str(idx)] if str(idx) in imagenet_labels else imagenet_labels[idx], prob))
    return rows


def show_predictions(batch, title, k=5):
    print(title)
    for idx, label, prob in predict_topk(batch, k=k):
        print(f"  clase {idx:4d}: {label:25s}  prob={100 * prob:6.2f}%")


def targeted_attack_step(current_image, target_index, step_size):
    attack_image = current_image.clone().detach().requires_grad_(True)
    logits = model(normalize_batch(attack_image))
    target = torch.tensor([target_index], device=device)
    loss = F.cross_entropy(logits, target)
    loss.backward()

    with torch.no_grad():
        updated_image = attack_image - step_size * attack_image.grad
        updated_image = updated_image.clamp(0.0, 1.0)

    return updated_image.detach(), loss.item(), attack_image.grad.detach()

Inspeccionar la predicción original#

Antes de atacar la imagen, observamos qué cree ver el modelo.

show_predictions(image, "Predicciones antes del ataque")
Predicciones antes del ataque
  clase  783: screw                      prob= 99.68%
  clase  677: nail                       prob=  0.32%
  clase  784: screwdriver                prob=  0.00%
  clase  506: coil                       prob=  0.00%
  clase  543: dumbbell                   prob=  0.00%

Elegir una clase objetivo#

Fijamos una clase manualmente como objetivo.

En un ataque dirigido, la pregunta es: ¿cómo debo modificar la imagen para que el clasificador crea que pertenece a esta clase?

target_index = 2
target_label = imagenet_labels[str(target_index)] if str(target_index) in imagenet_labels else imagenet_labels[target_index]
print("Clase objetivo:", target_index, target_label)
Clase objetivo: 2 great white shark

Entender un solo paso de gradiente#

Este es el núcleo conceptual del notebook.

Calculamos el gradiente de la pérdida con respecto a cada píxel. Ese gradiente responde a la pregunta:

si cambio un poco este píxel, cómo cambia la pérdida de la clase objetivo?

Esa es exactamente la lógica de los sistemas diferenciables: obtener sensibilidad de la salida con respecto a las entradas o parámetros.

one_step_image, one_step_loss, one_step_grad = targeted_attack_step(image, target_index, step_size=0.01)
print("Pérdida tras calcular un paso:", one_step_loss)
print("Forma del gradiente:", tuple(one_step_grad.shape))
print("Mínimo/máximo del gradiente:", one_step_grad.min().item(), one_step_grad.max().item())
Pérdida tras calcular un paso: 31.317428588867188
Forma del gradiente: (1, 3, 224, 224)
Mínimo/máximo del gradiente: -0.8111703395843506 1.074661135673523
grad_vis = one_step_grad[0].detach().cpu().numpy()
grad_vis = np.transpose(grad_vis, (1, 2, 0))
grad_vis = grad_vis - grad_vis.min()
grad_vis = grad_vis / max(grad_vis.max(), 1e-8)

plt.figure(figsize=(4, 4))
plt.imshow(grad_vis)
plt.axis("off")
plt.title("Visualización del gradiente en los píxeles")
plt.show()
../_images/230c4182903a17de0a613eda0b14c8c5c0d713388f0c94ea374034757ad510dc.png

Ejecutar el ataque completo#

Ahora repetimos el paso de gradiente muchas veces.

En cada iteración:

  1. pasamos la imagen actual por VGG16,

  2. calculamos la pérdida para la clase objetivo,

  3. hacemos backpropagation,

  4. actualizamos los píxeles,

  5. recortamos los valores al rango válido [0, 1].

num_steps = 100
step_size = 0.01

attack_image = image.clone().detach()
loss_history = []
target_prob_history = []

for step in range(num_steps):
    attack_image, loss_value, _ = targeted_attack_step(attack_image, target_index, step_size)
    loss_history.append(loss_value)

    with torch.no_grad():
        probs = F.softmax(model(normalize_batch(attack_image)), dim=1)
        target_prob = probs[0, target_index].item()
        target_prob_history.append(target_prob)

    print(f"paso {step + 1:02d}: pérdida={loss_value:.4f}, prob_objetivo={100 * target_prob:.2f}%")
paso 01: pérdida=31.3174, prob_objetivo=0.00%
paso 02: pérdida=27.9523, prob_objetivo=0.00%
paso 03: pérdida=22.9955, prob_objetivo=0.00%
paso 04: pérdida=17.4814, prob_objetivo=0.00%
paso 05: pérdida=13.6885, prob_objetivo=0.00%
paso 06: pérdida=11.0174, prob_objetivo=0.01%
paso 07: pérdida=9.1128, prob_objetivo=0.03%
paso 08: pérdida=7.9584, prob_objetivo=0.07%
paso 09: pérdida=7.3109, prob_objetivo=0.11%
paso 10: pérdida=6.7813, prob_objetivo=0.18%
paso 11: pérdida=6.3123, prob_objetivo=0.35%
paso 12: pérdida=5.6564, prob_objetivo=0.79%
paso 13: pérdida=4.8388, prob_objetivo=1.72%
paso 14: pérdida=4.0629, prob_objetivo=3.16%
paso 15: pérdida=3.4538, prob_objetivo=7.05%
paso 16: pérdida=2.6526, prob_objetivo=13.56%
paso 17: pérdida=1.9984, prob_objetivo=25.63%
paso 18: pérdida=1.3616, prob_objetivo=37.76%
paso 19: pérdida=0.9739, prob_objetivo=49.23%
paso 20: pérdida=0.7086, prob_objetivo=57.22%
paso 21: pérdida=0.5582, prob_objetivo=62.37%
paso 22: pérdida=0.4721, prob_objetivo=65.40%
paso 23: pérdida=0.4247, prob_objetivo=68.90%
paso 24: pérdida=0.3725, prob_objetivo=72.54%
paso 25: pérdida=0.3211, prob_objetivo=76.78%
paso 26: pérdida=0.2643, prob_objetivo=79.36%
paso 27: pérdida=0.2312, prob_objetivo=80.94%
paso 28: pérdida=0.2114, prob_objetivo=82.08%
paso 29: pérdida=0.1975, prob_objetivo=83.10%
paso 30: pérdida=0.1851, prob_objetivo=83.96%
paso 31: pérdida=0.1748, prob_objetivo=84.72%
paso 32: pérdida=0.1658, prob_objetivo=85.32%
paso 33: pérdida=0.1587, prob_objetivo=85.89%
paso 34: pérdida=0.1521, prob_objetivo=86.43%
paso 35: pérdida=0.1459, prob_objetivo=86.89%
paso 36: pérdida=0.1406, prob_objetivo=87.32%
paso 37: pérdida=0.1356, prob_objetivo=87.69%
paso 38: pérdida=0.1313, prob_objetivo=88.05%
paso 39: pérdida=0.1272, prob_objetivo=88.36%
paso 40: pérdida=0.1238, prob_objetivo=88.64%
paso 41: pérdida=0.1206, prob_objetivo=88.99%
paso 42: pérdida=0.1166, prob_objetivo=89.23%
paso 43: pérdida=0.1139, prob_objetivo=89.55%
paso 44: pérdida=0.1103, prob_objetivo=89.80%
paso 45: pérdida=0.1076, prob_objetivo=90.05%
paso 46: pérdida=0.1048, prob_objetivo=90.27%
paso 47: pérdida=0.1023, prob_objetivo=90.48%
paso 48: pérdida=0.1001, prob_objetivo=90.68%
paso 49: pérdida=0.0978, prob_objetivo=90.90%
paso 50: pérdida=0.0954, prob_objetivo=91.10%
paso 51: pérdida=0.0932, prob_objetivo=91.29%
paso 52: pérdida=0.0912, prob_objetivo=91.47%
paso 53: pérdida=0.0892, prob_objetivo=91.63%
paso 54: pérdida=0.0874, prob_objetivo=91.79%
paso 55: pérdida=0.0857, prob_objetivo=91.94%
paso 56: pérdida=0.0840, prob_objetivo=92.12%
paso 57: pérdida=0.0821, prob_objetivo=92.27%
paso 58: pérdida=0.0805, prob_objetivo=92.42%
paso 59: pérdida=0.0788, prob_objetivo=92.56%
paso 60: pérdida=0.0773, prob_objetivo=92.69%
paso 61: pérdida=0.0760, prob_objetivo=92.82%
paso 62: pérdida=0.0745, prob_objetivo=92.95%
paso 63: pérdida=0.0731, prob_objetivo=93.07%
paso 64: pérdida=0.0718, prob_objetivo=93.18%
paso 65: pérdida=0.0707, prob_objetivo=93.27%
paso 66: pérdida=0.0697, prob_objetivo=93.40%
paso 67: pérdida=0.0683, prob_objetivo=93.50%
paso 68: pérdida=0.0672, prob_objetivo=93.59%
paso 69: pérdida=0.0662, prob_objetivo=93.68%
paso 70: pérdida=0.0653, prob_objetivo=93.79%
paso 71: pérdida=0.0641, prob_objetivo=93.88%
paso 72: pérdida=0.0631, prob_objetivo=93.97%
paso 73: pérdida=0.0622, prob_objetivo=94.05%
paso 74: pérdida=0.0614, prob_objetivo=94.13%
paso 75: pérdida=0.0605, prob_objetivo=94.22%
paso 76: pérdida=0.0596, prob_objetivo=94.29%
paso 77: pérdida=0.0588, prob_objetivo=94.38%
paso 78: pérdida=0.0578, prob_objetivo=94.45%
paso 79: pérdida=0.0571, prob_objetivo=94.54%
paso 80: pérdida=0.0562, prob_objetivo=94.61%
paso 81: pérdida=0.0554, prob_objetivo=94.68%
paso 82: pérdida=0.0546, prob_objetivo=94.75%
paso 83: pérdida=0.0540, prob_objetivo=94.82%
paso 84: pérdida=0.0532, prob_objetivo=94.88%
paso 85: pérdida=0.0526, prob_objetivo=94.95%
paso 86: pérdida=0.0519, prob_objetivo=95.00%
paso 87: pérdida=0.0512, prob_objetivo=95.06%
paso 88: pérdida=0.0506, prob_objetivo=95.12%
paso 89: pérdida=0.0500, prob_objetivo=95.18%
paso 90: pérdida=0.0494, prob_objetivo=95.23%
paso 91: pérdida=0.0488, prob_objetivo=95.29%
paso 92: pérdida=0.0482, prob_objetivo=95.34%
paso 93: pérdida=0.0477, prob_objetivo=95.40%
paso 94: pérdida=0.0471, prob_objetivo=95.45%
paso 95: pérdida=0.0466, prob_objetivo=95.50%
paso 96: pérdida=0.0461, prob_objetivo=95.55%
paso 97: pérdida=0.0456, prob_objetivo=95.59%
paso 98: pérdida=0.0451, prob_objetivo=95.64%
paso 99: pérdida=0.0446, prob_objetivo=95.68%
paso 100: pérdida=0.0441, prob_objetivo=95.73%

9. Comparar predicciones antes y después#

show_predictions(image, "Predicciones antes del ataque")
print()
show_predictions(attack_image, "Predicciones después del ataque")
Predicciones antes del ataque
  clase  783: screw                      prob= 99.68%
  clase  677: nail                       prob=  0.32%
  clase  784: screwdriver                prob=  0.00%
  clase  506: coil                       prob=  0.00%
  clase  543: dumbbell                   prob=  0.00%

Predicciones después del ataque
  clase    2: great white shark          prob= 95.73%
  clase    3: tiger shark                prob=  1.25%
  clase    4: hammerhead shark           prob=  0.74%
  clase  833: submarine                  prob=  0.61%
  clase  394: sturgeon                   prob=  0.45%

La perturbación suele ser pequeña en magnitud, pero puede cambiar mucho la salida del modelo.

original_np = image[0].detach().cpu().permute(1, 2, 0).numpy()
adversarial_np = attack_image[0].detach().cpu().permute(1, 2, 0).numpy()
perturbation = adversarial_np - original_np

perturbation_vis = perturbation - perturbation.min()
perturbation_vis = perturbation_vis / max(perturbation_vis.max(), 1e-8)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(original_np)
axes[0].set_title("Original")
axes[1].imshow(adversarial_np)
axes[1].set_title("Adversaria")
axes[2].imshow(perturbation_vis)
axes[2].set_title("Perturbación")

for ax in axes:
    ax.axis("off")

plt.tight_layout()
plt.show()
../_images/249ebad4647ae406d95a52c0362beb0c5f5fbfa3b426da7d0964dafce1f96bf1.png

Si el ataque funciona, la probabilidad de la clase objetivo debería subir mientras la pérdida baja.

fig, axes = plt.subplots(1, 2, figsize=(10, 3))
axes[0].plot(loss_history)
axes[0].set_title("Pérdida objetivo")
axes[0].set_xlabel("Paso")
axes[0].set_ylabel("Cross-entropy")

axes[1].plot(target_prob_history)
axes[1].set_title("Probabilidad de la clase objetivo")
axes[1].set_xlabel("Paso")
axes[1].set_ylabel("Probabilidad")

plt.tight_layout()
plt.show()
../_images/9637a99c1a8b04af2161d15b1f35c9ee4a09f5613fb1dfe32b2125a955689e65.png