| import os |
| import subprocess |
| import re |
| import random |
| import matplotlib.pyplot as plt |
| import json |
def get_gpu_memory_usage():
    """Return the used memory of every GPU, in MB, as reported by nvidia-smi.

    Runs ``nvidia-smi --query-gpu=memory.used --format=csv,nounits,noheader``
    and parses one integer per non-blank output line, in nvidia-smi's order.

    Returns:
        list[int]: One entry per GPU. This is a best-effort probe: if
        nvidia-smi cannot be run, exits non-zero, or prints something that is
        not an integer, the error is printed and an empty list is returned.
    """
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True,
        )
        if result.returncode != 0:
            raise RuntimeError(f"nvidia-smi command failed with error: {result.stderr}")
        # Skip blank lines so empty output is "no GPUs" rather than a
        # spurious int('') parse error; int() tolerates surrounding whitespace.
        return [int(line) for line in result.stdout.splitlines() if line.strip()]
    except (OSError, subprocess.SubprocessError, RuntimeError, ValueError) as e:
        # OSError: nvidia-smi missing / not executable; ValueError: bad output.
        print(f"Error querying GPU memory usage: {e}")
        return []
|
|
def set_cuda_visible_device():
    """Restrict CUDA to the GPU that currently has the least memory in use.

    Reads per-GPU usage via ``get_gpu_memory_usage()`` and writes the index of
    the least-loaded GPU (first one on ties) to ``CUDA_VISIBLE_DEVICES``.

    nvidia-smi numbers GPUs in PCI bus order, whereas the CUDA runtime
    enumerates them fastest-first unless ``CUDA_DEVICE_ORDER=PCI_BUS_ID``.
    That variable is therefore defaulted to ``PCI_BUS_ID`` so the chosen index
    names the same physical device; an explicit user setting is respected.
    NOTE: CUDA reads these variables when it initializes, so call this before
    any CUDA work in the process.

    Returns:
        str | None: The selected GPU index as a string, or None (with a message
        printed and nothing changed) when no usage data is available.
    """
    memory_usages = get_gpu_memory_usage()
    if not memory_usages:
        print("No GPU memory usage data available.")
        return None

    # min() keeps the first minimal index, same as list.index(min(...)).
    min_memory_index = min(range(len(memory_usages)), key=memory_usages.__getitem__)

    os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
    os.environ["CUDA_VISIBLE_DEVICES"] = str(min_memory_index)
    print(f"Set CUDA_VISIBLE_DEVICES to GPU {min_memory_index} with {memory_usages[min_memory_index]} MB used.")

    return str(min_memory_index)
|
|
# Project root. An ASN_ROOT_DIR already present in the environment wins over
# the author's default checkout path. The working directory is switched to it
# so relative paths used later in this module (e.g. "./results/...") resolve
# against the project root; presumably the project-local imports below
# (``utils.factory``) rely on this too when run interactively — confirm.
os.environ.setdefault("ASN_ROOT_DIR", "/home/nickj/asn/second_order_lens")
os.chdir(os.environ["ASN_ROOT_DIR"])
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
| import os.path |
| import argparse |
| from pathlib import Path |
|
|
| from tqdm import tqdm |
| from utils.factory import create_model_and_transforms, get_tokenizer |
| from PIL import Image, ImageDraw |
|
|
def get_model(model_name="ViT-B/16", pretrained="openai", device="cuda:0"):
    """Build a model with ``create_model_and_transforms`` and bundle its metadata.

    Side effect: switches torch multiprocessing to the "file_system" sharing
    strategy, which is a process-wide setting.

    Args:
        model_name: Architecture name forwarded to ``create_model_and_transforms``.
        pretrained: Weights tag forwarded to ``create_model_and_transforms``.
        device: Device the model is moved to.

    Returns:
        dict with "model" (moved to ``device`` and put in eval mode),
        "model_name", "pretrained", "preprocess", "context_length" and
        "vocab_size" (the latter two read from the model).
    """
    torch.multiprocessing.set_sharing_strategy("file_system")
    built = create_model_and_transforms(
        model_name, pretrained=pretrained, force_quick_gelu=True,
    )
    model, _ignored, preprocess = built
    model.to(device)
    model.eval()

    bundle = {
        "model": model,
        "model_name": model_name,
        "pretrained": pretrained,
        "preprocess": preprocess,
    }
    bundle["context_length"] = model.context_length
    bundle["vocab_size"] = model.vocab_size
    return bundle
|
|
# Path of a single hard-coded ILSVRC validation image.
img_path = (
    "/datasets/ilsvrc_2024-01-04_1913/val/n04398044/"
    "ILSVRC2012_val_00042447.JPEG"
)
| |
def load_images(preprocess, image_folder = "/datasets/ilsvrc/current/val", count = 100, images_only = True):
    """Randomly sample files under ``image_folder`` and preprocess each one.

    Every file found by a recursive walk is a candidate (no extension filter),
    so the folder is expected to contain only images PIL can open.

    Args:
        preprocess: Callable applied to each opened PIL image. The file is
            closed right after the call, so it must fully consume the image
            (e.g. return a tensor), not hand back the lazily-loaded image.
        image_folder: Root directory searched recursively.
        count: Number of files sampled without replacement. If the folder holds
            fewer files, all of them are used.
        images_only: If True return only the preprocessed images; otherwise
            also return the sampled paths.

    Returns:
        list of preprocessed images, or ``(images, paths)`` when
        ``images_only`` is False, where ``paths[i]`` produced ``images[i]``.
    """
    # os.walk order depends on the filesystem; sort so that a fixed
    # ``random`` seed picks the same files on every machine.
    file_list = sorted(
        os.path.join(root, name)
        for root, _dirs, files in os.walk(image_folder)
        for name in files
    )

    if count > len(file_list):
        sampled_files = file_list
    else:
        sampled_files = random.sample(file_list, count)

    image_files = []
    for filename in sampled_files:
        # Release the file handle once preprocessing has read the pixels.
        with Image.open(filename) as img:
            image_files.append(preprocess(img))

    if images_only:
        return image_files
    return image_files, sampled_files
|
|
def calc_neuron_potentials(model, attn_layers = (1, 2), include_layernorm = True):
    """Score how strongly each MLP neuron can shift attention in later heads.

    For MLP neuron i of layer ``n`` and head h of an attention layer ``l`` with
    ``n + attn_layers[0] <= l < n + attn_layers[1]`` (clipped to the model
    depth), the score is ``|| W_Q_h^T @ W_K_h @ (w_i * g) ||_2`` where ``w_i`` is
    column i of layer n's ``mlp.c_proj.weight``, ``g`` is layer l's ``ln_1``
    weight (dropped when ``include_layernorm`` is False), and ``W_Q_h``/``W_K_h``
    are head h's row blocks of layer l's ``attn.in_proj_weight``.

    Args:
        model: Model exposing ``model.visual.transformer.resblocks``.
        attn_layers: Half-open ``(start, stop)`` offsets, relative to the neuron's
            layer, selecting the attention layers to score against.
        include_layernorm: Whether to scale neuron directions by ``ln_1.weight``.

    Returns:
        dict mapping ``(neuron_layer, attn_layer, head)`` to a 1-D CPU tensor
        with one score per neuron (column of ``c_proj.weight``).
    """
    resblocks = model.visual.transformer.resblocks
    embed_dim = resblocks[0].attn.out_proj.in_features
    num_heads = resblocks[0].attn.num_heads
    head_dim = embed_dim // num_heads
    layers = len(resblocks)

    results = dict()
    for neuron_layer in tqdm(range(layers), desc = "Calculating attention shifting potentials"):
        # Column i is the vector neuron i's output is multiplied by.
        neuron_projection = resblocks[neuron_layer].state_dict()["mlp.c_proj.weight"]
        first_attn = min(layers, neuron_layer + attn_layers[0])
        stop_attn = min(layers, neuron_layer + attn_layers[1])
        for l_attn in range(first_attn, stop_attn):
            ln_vector = resblocks[l_attn].ln_1.state_dict()["weight"]
            attn_matrix = resblocks[l_attn].state_dict()["attn.in_proj_weight"]
            # Query rows come first, then key rows (value rows were never used).
            W_Q = attn_matrix[:embed_dim].reshape(num_heads, head_dim, -1)
            W_K = attn_matrix[embed_dim:2*embed_dim].reshape(num_heads, head_dim, -1)

            # Scale every neuron column element-wise by the LayerNorm gain at
            # once: ln_vector[:, None] broadcasts along the neuron axis.
            if include_layernorm:
                directions = neuron_projection * ln_vector[:, None]
            else:
                directions = neuron_projection

            for head_idx in range(num_heads):
                qk = W_Q[head_idx].T @ W_K[head_idx]
                # One matmul for all neurons; column norms equal the per-neuron
                # norms the former Python loop computed one by one.
                effects = torch.norm(qk @ directions, dim=0)
                # Keep the previous contract of a CPU tensor per key.
                results[(neuron_layer, l_attn, head_idx)] = effects.cpu()
    return results
|
|
def calc_top_asns(shift_potentials, top_k = 10, per = "layer", layers_away = 1):
    """Rank neurons by attention-shift potential toward heads ``layers_away`` layers later.

    Args:
        shift_potentials: dict keyed by ``(neuron_layer, attn_layer, head)`` with
            a 1-D tensor of per-neuron scores (as built by
            ``calc_neuron_potentials``). Every key ``(layer, layer + layers_away,
            head)`` visited below must be present.
        top_k: Number of neuron indices kept per group.
        per: "layer" ranks each neuron by its best score over all heads;
            "head" ranks separately for every head.
        layers_away: Distance between the neuron layer and the attention layer.

    Returns:
        per="layer": ``top_asns[layer]`` is a list of up to ``top_k`` neuron
        indices, best first. per="head": ``top_asns[layer][head]`` is such a list.

    Raises:
        ValueError: If ``per`` is neither "layer" nor "head".
    """
    if per not in ("layer", "head"):
        raise ValueError(f"Invalid per value: {per}")

    num_layers = max(key[1] for key in shift_potentials) + 1
    # Head indices are 0-based, so the count is the largest index + 1
    # (without the +1 the last head was silently ignored).
    num_heads = max(key[2] for key in shift_potentials) + 1

    top_asns = []
    for layer in range(num_layers - layers_away):
        head_scores = [
            shift_potentials[(layer, layer + layers_away, head_idx)]
            for head_idx in range(num_heads)
        ]
        if per == "layer":
            # Best score any head assigns to each neuron.
            potentials = torch.max(torch.stack(head_scores, dim = 0), dim = 0).values
            _, sorted_indices = torch.sort(potentials, descending = True)
            top_asns.append(sorted_indices[:top_k].tolist())
        else:
            top_layer_asns = []
            for scores in head_scores:
                _, sorted_indices = torch.sort(scores, descending = True)
                top_layer_asns.append(sorted_indices[:top_k].tolist())
            top_asns.append(top_layer_asns)
    return top_asns
|
|
def aggregate_attn_map(attn_map, layer, head):
    """Total attention each patch receives from all patch queries, as a grid.

    Args:
        attn_map: Tensor indexed ``attn_map[batch, layer, head, query, key]``
            whose token axis is token 0 followed by a square grid of patch
            tokens (presumably token 0 is CLS — only the square grid is checked).
        layer: Layer index to read.
        head: Head index to read.

    Returns:
        Tensor of shape ``(batch, grid, grid)``: for every patch key, the sum of
        the attention it receives over all patch queries (token 0's row and
        column excluded). A batch of one gives ``(1, grid, grid)`` as before.

    Raises:
        ValueError: If the number of tokens minus one is not a perfect square.
    """
    num_tokens = attn_map.shape[-1]
    # Exact integer check instead of a float ``% 1`` test.
    num_patches = int(round(max(num_tokens - 1, 0) ** 0.5))
    if num_patches * num_patches != num_tokens - 1:
        raise ValueError("num_tokens - 1 is not a perfect square")

    patch_attn = attn_map[:, layer, head, 1:, 1:]  # (batch, query patches, key patches)
    # Sum over queries -> attention received per key patch; keep the batch axis.
    received = torch.sum(patch_attn, dim = 1)
    return received.reshape((received.shape[0], num_patches, num_patches))
|
|
def attn_map_cls_token(attn_map, layer, head):
    """Attention from token 0 (presumably CLS) to every patch, laid out as a grid.

    Args:
        attn_map: Tensor indexed ``attn_map[batch, layer, head, query, key]``
            whose token axis is token 0 followed by a square grid of patches.
        layer: Layer index to read.
        head: Head index to read.

    Returns:
        Tensor of shape ``(batch, grid, grid)`` holding query row 0 over the
        patch keys. A batch of one gives ``(1, grid, grid)`` as before.

    Raises:
        ValueError: If the number of tokens minus one is not a perfect square.
    """
    num_tokens = attn_map.shape[-1]
    # Exact integer check instead of a float ``% 1`` test.
    num_patches = int(round(max(num_tokens - 1, 0) ** 0.5))
    if num_patches * num_patches != num_tokens - 1:
        raise ValueError("num_tokens - 1 is not a perfect square")

    first_token_row = attn_map[:, layer, head, 0, 1:]  # (batch, patches)
    return first_token_row.reshape((first_token_row.shape[0], num_patches, num_patches))
|
|
def visualize_attn_shift(attn_map1, attn_map2, image, display=True, out=None, min_diff=None, max_diff=None):
    """Overlay the per-cell change ``attn_map2 - attn_map1`` on ``image`` as a heatmap.

    Each cell ``[i, j]`` of the difference map is painted as a translucent
    rectangle over the matching image region: rows (i) span the image height,
    columns (j) span its width. A value v is normalized as
    ``t = (v - min_diff) / (max_diff - min_diff)`` and colored with
    ``coolwarm_r(1 - t)``; the colorbar shown with ``display`` uses the same
    mapping.

    Args:
        attn_map1: 2-D map indexable as ``m[i, j]``.
        attn_map2: Map of the same shape; the difference is ``attn_map2 - attn_map1``.
        image: PIL image to draw on (converted to RGBA).
        display: If True, show the composited image and a matching colorbar.
        out: Optional path the composited image is saved to.
        min_diff: Lower color-scale limit; defaults to the difference map's min.
        max_diff: Upper color-scale limit; defaults to the difference map's max.
            NOTE: if the limits are equal the normalization divides by zero.

    Returns:
        The composited RGBA ``PIL.Image``.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    diff_map = attn_map2 - attn_map1
    rows, cols = diff_map.shape[0], diff_map.shape[1]

    image = image.convert("RGBA")
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)

    # PIL sizes are (width, height). Columns (j -> x) must divide the width
    # and rows (i -> y) the height; these were swapped before, distorting
    # non-square maps.
    block_size_x = image.size[0] / cols
    block_size_y = image.size[1] / rows

    cmap = plt.get_cmap('coolwarm_r')

    if max_diff is None:
        max_diff = diff_map.max()
    if min_diff is None:
        min_diff = diff_map.min()

    for i in range(rows):
        for j in range(cols):
            normalized_intensity = (diff_map[i, j] - min_diff) / (max_diff - min_diff)
            rgba_color = cmap(1 - normalized_intensity)
            # RGB from the colormap; alpha is the colormap alpha scaled to <= 128.
            color = tuple(int(c * 255) for c in rgba_color[:3]) + (int(rgba_color[3] * 128),)
            draw.rectangle(
                [j * block_size_x, i * block_size_y, (j + 1) * block_size_x, (i + 1) * block_size_y],
                fill=color
            )

    combined = Image.alpha_composite(image, overlay)

    if display:
        combined.show()

        # The overlay colors v with cmap(1 - t), i.e. with the reversed map,
        # so the legend must use the reversed map as well (plain
        # 'coolwarm_r' drew the scale upside down).
        plt.figure(figsize=(6, 1))
        plt.imshow([np.linspace(min_diff, max_diff, 256)], cmap=cmap.reversed(), aspect='auto')
        plt.gca().set_visible(False)
        plt.colorbar(orientation="horizontal")
        plt.show()

    if out is not None:
        combined.save(out)

    return combined
|
|
def visualize_attn_shift_binary(attn_map1, attn_map2, image, display=True, out=None):
    """Overlay the sign and relative size of ``attn_map2 - attn_map1`` on ``image``.

    The difference map is min-max normalized to [0, 1]. Cells with a positive
    difference are painted green with brightness ``normalized``; all other
    cells red with brightness ``1 - normalized``; every cell uses alpha 127.
    Rows (i) span the image height and columns (j) its width. A constant
    difference map (zero range) normalizes to all zeros, like
    ``normalize_array``.

    Args:
        attn_map1: 2-D map indexable as ``m[i, j]``.
        attn_map2: Map of the same shape; the difference is ``attn_map2 - attn_map1``.
        image: PIL image to draw on (converted to RGBA).
        display: If True, show the composited image.
        out: Optional path the composited image is saved to.

    Returns:
        The composited RGBA ``PIL.Image``.
    """
    diff_map = attn_map2 - attn_map1
    rows, cols = diff_map.shape[0], diff_map.shape[1]

    # Min-max normalize; guard the zero-range case, which previously divided
    # by zero and then failed converting NaN to int below.
    shifted = diff_map - diff_map.min()
    value_range = diff_map.max() - diff_map.min()
    diff_map_normalized = shifted / value_range if value_range != 0 else shifted

    image = image.convert("RGBA")
    overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)

    # PIL sizes are (width, height): columns divide the width, rows the height.
    block_size_x = image.size[0] / cols
    block_size_y = image.size[1] / rows

    alpha = int(255 * 0.5)
    for i in range(rows):
        for j in range(cols):
            intensity = diff_map_normalized[i, j]
            if diff_map[i, j] > 0:
                color = (0, int(255 * intensity), 0, alpha)
            else:
                color = (int(255 * (1 - intensity)), 0, 0, alpha)
            draw.rectangle(
                [j * block_size_x, i * block_size_y, (j + 1) * block_size_x, (i + 1) * block_size_y],
                fill=color
            )

    combined = Image.alpha_composite(image, overlay)

    if display:
        combined.show()

    if out is not None:
        combined.save(out)

    return combined
|
|
def is_outlier(mean, std, value):
    """Return True when ``value`` lies strictly outside ``mean ± 2 * std``.

    Values exactly on a bound are not outliers; NaN comparisons are False, so
    a NaN value is never reported as an outlier.
    """
    margin = 2 * std
    lower_bound = mean - margin
    upper_bound = mean + margin
    return value < lower_bound or value > upper_bound
|
|
|
|
def get_neuron_activations(images, prs_group, model, device = "cuda:0"):
    """Collect ``prs_group.post_gelu_outputs()`` for each image, one forward pass at a time.

    ``prs_group`` is reset before and finalized after every forward pass.

    Args:
        images: Iterable of preprocessed image tensors without a batch axis
            (a batch axis of 1 is added here).
        prs_group: Hook container providing ``reinit``, ``finalize`` and
            ``post_gelu_outputs``.
        model: Model exposing ``encode_image``.
        device: Device each image is moved to.

    Returns:
        ``torch.stack`` of the per-image activations along a new leading axis.
    """
    random_neuron_acts = []
    for image in tqdm(images, desc="Processing images"):
        prs_group.reinit()
        image_input = image.unsqueeze(0).to(device)
        # Only the hooked activations are used; don't build an autograd graph
        # per image (find_register_neurons runs the same pass under no_grad).
        with torch.no_grad():
            model.encode_image(
                image_input, attn_method="head", normalize=False
            )
        prs_group.finalize()
        random_neuron_acts.append(prs_group.post_gelu_outputs())
    return torch.stack(random_neuron_acts, dim = 0)
|
|
def normalize_array(arr):
    """Min-max scale ``arr`` into [0, 1].

    A constant input (zero range) yields ``np.zeros_like(arr)`` instead of a
    division by zero.
    """
    lowest, highest = np.min(arr), np.max(arr)
    spread = highest - lowest
    if spread == 0:
        return np.zeros_like(arr)
    return (arr - lowest) / spread
|
|
def np_l2(arr1, arr2):
    """Return the L2 norm of ``arr1 - arr2`` (``np.linalg.norm`` with default ord)."""
    difference = arr1 - arr2
    return np.linalg.norm(difference)
|
|
def best_class(classifier, representation):
    """Return ``(index, score)`` of the classifier entry most cosine-similar to ``representation``.

    Similarity is computed along dim 0 between ``classifier`` and
    ``representation.permute(1, 0)`` (e.g. a ``(1, D)`` representation becomes
    ``(D, 1)`` and is broadcast). The index is ``torch.argmax`` of the
    similarities and the score is the similarity there, both as Python numbers.
    """
    similarities = torch.cosine_similarity(classifier, representation.permute(1, 0), dim = 0)
    winner = torch.argmax(similarities).item()
    return winner, similarities[winner].item()
|
|
def load_group_attn_shifts(timestamp, results_dir = "./results/supp1B"):
    """Open the outputs of one group-shift run as read-only memmaps.

    Args:
        timestamp: Run sub-directory inside ``results_dir``. An absolute path
            replaces ``results_dir`` entirely (``os.path.join`` semantics).
        results_dir: Parent directory of the runs; defaults to the previously
            hard-coded ``./results/supp1B`` (relative to the working directory).

    Returns:
        dict with "attn_maps" and "resblocks" (float32 memmaps opened with mode
        'r' and shaped by "attention_maps_shape"/"resblocks_shape" from
        metadata.json), the raw "metadata", and "file_list" (default []),
        "top_k_values" (default [0]), "num_layers", "num_images" and
        "num_heads" (default 0 each) taken from the metadata.
    """
    latest_dir = os.path.join(results_dir, timestamp)
    print(f"Using latest results directory: {latest_dir}")

    # JSON text is UTF-8; don't depend on the locale's default encoding.
    with open(os.path.join(latest_dir, "metadata.json"), "r", encoding="utf-8") as f:
        metadata = json.load(f)

    attn_maps = np.memmap(os.path.join(latest_dir, "attention_maps.mmap"),
                          dtype=np.float32,
                          mode='r',
                          shape=tuple(metadata["attention_maps_shape"]))

    resblocks = np.memmap(os.path.join(latest_dir, "resblocks.mmap"),
                          dtype=np.float32,
                          mode='r',
                          shape=tuple(metadata["resblocks_shape"]))

    return {
        "attn_maps": attn_maps,
        "resblocks": resblocks,
        "metadata": metadata,
        "file_list": metadata.get("file_list", []),
        "top_k_values": metadata.get("top_k_values", [0]),
        "num_layers": metadata.get("num_layers", 0),
        "num_images": metadata.get("num_images", 0),
        "num_heads": metadata.get("num_heads", 0)
    }
|
|
def load_individual_attn_shifts(timestamp, supp = "supp1D"):
    """Open the outputs of one individual-shift run as read-only memmaps.

    Reads ``./results/<supp>/<timestamp>/metadata.json`` (relative to the
    working directory; an absolute ``timestamp`` replaces the prefix, per
    ``os.path.join``) and the five float32 ``.mmap`` arrays beside it, each
    shaped by its ``*_shape`` entry in the metadata.

    Returns:
        dict with "attn_maps", "baseline_attn_maps", "neuron_activations",
        "baseline_neuron_activations", "ablated_neurons" (in that order), the
        raw "metadata", and "file_list" ([]), "k" (25), "num_layers" (12),
        "num_images" (100), "model_name" ("ViT-B-16") and "pretrained"
        ("openai"), where the bracketed defaults apply when the metadata lacks
        the key.
    """
    run_dir = os.path.join(f"./results/{supp}", timestamp)
    print(f"Using latest results directory: {run_dir}")

    with open(os.path.join(run_dir, "metadata.json"), "r") as meta_file:
        metadata = json.load(meta_file)

    # (returned key, file on disk, metadata key holding the array's shape)
    array_specs = (
        ("attn_maps", "attention_maps.mmap", "attention_maps_shape"),
        ("baseline_attn_maps", "baseline_attention_maps.mmap", "baseline_attention_maps_shape"),
        ("neuron_activations", "neuron_activations.mmap", "neuron_activations_shape"),
        ("baseline_neuron_activations", "baseline_neuron_activations.mmap", "baseline_neuron_activations_shape"),
        ("ablated_neurons", "ablated_neurons.mmap", "ablated_neurons_shape"),
    )
    loaded = {
        key: np.memmap(os.path.join(run_dir, filename),
                       dtype=np.float32,
                       mode='r',
                       shape=tuple(metadata[shape_key]))
        for key, filename, shape_key in array_specs
    }

    return {
        **loaded,
        "metadata": metadata,
        "file_list": metadata.get("file_list", []),
        "k": metadata.get("k", 25),
        "num_layers": metadata.get("num_layers", 12),
        "num_images": metadata.get("num_images", 100),
        "model_name": metadata.get("model_name", "ViT-B-16"),
        "pretrained": metadata.get("pretrained", "openai"),
    }
|
|
def find_register_neurons_cuda(model, preprocess, prs_group, register_norm_threshold = 30, highest_layer = -1, device = "cuda:0", processed_image_cnt = 500):
    """Rank MLP neurons by mean |activation| on high-norm ("register") tokens, on ``device``.

    For each of ``processed_image_cnt`` images sampled by ``load_images``:

    1. Run ``model.encode_image`` with ``prs_group`` hooks attached.
    2. Register tokens are positions whose ``resblock_outputs()[-1]`` row has
       an L2 norm (over dim 1) above ``register_norm_threshold``. Images with
       none are skipped.
    3. Per layer, a neuron is "sparse" when at least half of all tokens have
       ``|activation| < 0.5``. Its score is its mean ``|activation|`` over the
       register tokens; non-sparse neurons (and layers with no sparse neuron)
       score 0.

    Scores are averaged over the images that had register tokens (all zeros if
    none did).

    Args:
        model: Model exposing ``visual.transformer.resblocks`` and ``encode_image``.
        preprocess: Image preprocessing callable passed to ``load_images``.
        prs_group: Hook container providing ``reinit``, ``finalize``,
            ``post_gelu_outputs`` and ``resblock_outputs``.
        register_norm_threshold: Norm above which a token counts as a register.
        highest_layer: Largest layer index reported; -1 means the last layer.
        device: Device for images and score tensors.
        processed_image_cnt: Number of images to sample.

    Returns:
        ``(register_norms, best_alignment_scores)``, both lists of
        ``(layer, neuron, score)`` sorted by score descending and restricted to
        ``layer <= highest_layer``. NOTE: nothing fills the alignment scores,
        so every score in ``best_alignment_scores`` is 0.
    """
    num_layers = len(model.visual.transformer.resblocks)
    highest_layer = num_layers - 1 if highest_layer == -1 else highest_layer
    num_neurons = model.visual.transformer.resblocks[0].mlp.state_dict()["c_proj.weight"].shape[1]
    random_images = load_images(preprocess, count=processed_image_cnt)
    # Rows [:image_count] hold the images that had register tokens, in order.
    neuron_scores = torch.zeros((len(random_images), num_layers, num_neurons), device=device)
    # Placeholder kept for the return contract; never written below.
    alignment_scores = torch.zeros((len(random_images), num_layers, num_neurons), device=device)
    image_count = 0

    for i in tqdm(range(len(random_images)), desc="Processing random images"):
        image = random_images[i].unsqueeze(0).to(device)
        prs_group.reinit()
        with torch.inference_mode():
            model.encode_image(
                image, attn_method="head", normalize=False
            )
            prs_group.finalize()

        baseline_neuron_acts = prs_group.post_gelu_outputs().to(device)
        baseline_resblock_outputs = prs_group.resblock_outputs().to(device)

        # Presumably (tokens, width) for the final block — confirm against prs_group.
        norm_map = torch.norm(baseline_resblock_outputs[-1], dim=1)
        register_locations = torch.where(norm_map > register_norm_threshold)[0]
        if len(register_locations) == 0:
            continue

        # Write into the next compact row. Indexing by ``i`` (as before) left
        # zero rows for skipped images and pushed real rows past the
        # ``[:image_count]`` slice that is averaged below.
        row = image_count
        image_count += 1

        for layer in range(num_layers):
            act_layer = torch.abs(baseline_neuron_acts[layer])  # assumed (tokens, neurons)
            sparse_neurons = torch.sum(act_layer < 0.5, dim=0) >= 0.5 * act_layer.shape[0]
            if not torch.any(sparse_neurons):
                continue
            neuron_means = act_layer[register_locations].mean(dim=0)
            neuron_scores[row, layer] = neuron_means * sparse_neurons.float()

    if image_count > 0:
        mean_neuron_scores = neuron_scores[:image_count].mean(dim=0)
        mean_alignment_scores = alignment_scores[:image_count].mean(dim=0)
    else:
        # Mean over zero images would be NaN everywhere.
        mean_neuron_scores = torch.zeros((num_layers, num_neurons), device=device)
        mean_alignment_scores = torch.zeros((num_layers, num_neurons), device=device)

    def _rank(scores):
        # Sort flattened (layer, neuron) scores descending, keep layer <= highest_layer.
        sorted_values, sorted_indices = torch.sort(scores.flatten(), descending=True)
        return [
            (idx // num_neurons, idx % num_neurons, value)
            for idx, value in zip(sorted_indices.tolist(), sorted_values.tolist())
            if idx // num_neurons <= highest_layer
        ]

    return _rank(mean_neuron_scores), _rank(mean_alignment_scores)
|
|
def find_register_neurons(model, preprocess, prs_group, register_norm_threshold = 30, highest_layer = -1, device = "cuda:0", processed_image_cnt = 500):
    """Rank MLP neurons by their mean activation on high-norm ("register") tokens.

    For each image sampled by ``load_images``, register tokens are positions
    whose ``resblock_outputs()[-1]`` row has an L2 norm (over axis 1) above
    ``register_norm_threshold``. For every (layer, neuron) the image's score is
    the neuron's mean activation over those tokens, or 0 if any of those
    activations is negative. Images without register tokens are skipped;
    scores are averaged over the remaining images (all zeros if none remain).

    Args:
        model: Model exposing ``visual.transformer.resblocks`` and ``encode_image``.
        preprocess: Image preprocessing callable passed to ``load_images``.
        prs_group: Hook container providing ``reinit``, ``finalize``,
            ``post_gelu_outputs`` and ``resblock_outputs``.
        register_norm_threshold: Norm above which a token counts as a register.
        highest_layer: Largest layer index reported; -1 means the last layer.
        device: Device the images are moved to.
        processed_image_cnt: Number of images to sample.

    Returns:
        list of ``(layer, neuron, score)`` sorted by score descending, restricted
        to ``layer <= highest_layer``.
    """
    num_layers = len(model.visual.transformer.resblocks)
    highest_layer = num_layers - 1 if highest_layer == -1 else highest_layer
    num_neurons = model.visual.transformer.resblocks[0].mlp.state_dict()["c_proj.weight"].shape[1]

    random_images = load_images(preprocess, count = processed_image_cnt)
    neuron_scores = torch.zeros((len(random_images), num_layers, num_neurons))
    image_count = 0
    for i in tqdm(range(len(random_images)), desc="Processing random images"):
        image = random_images[i].unsqueeze(0).to(device)

        prs_group.reinit()
        with torch.no_grad():
            model.encode_image(
                image, attn_method="head", normalize=False
            )
            prs_group.finalize()

        # Assumed (layers, tokens, neurons), matching the former
        # ``[layer, :, neuron]`` indexing — confirm against prs_group.
        baseline_neuron_acts = prs_group.post_gelu_outputs().cpu().numpy()
        baseline_resblock_outputs = prs_group.resblock_outputs().cpu().numpy()

        norms = np.linalg.norm(baseline_resblock_outputs[-1], axis=1)
        register_locations = np.where(norms > register_norm_threshold)[0]
        if len(register_locations) == 0:
            # A mean over zero tokens is NaN and used to poison every average.
            continue

        # (layers, registers, neurons): activations at register tokens only.
        register_acts = baseline_neuron_acts[:, register_locations, :]
        image_scores = register_acts.mean(axis=1)
        # Neurons with any negative activation on a register token score 0.
        image_scores[(register_acts < 0).any(axis=1)] = 0
        neuron_scores[image_count] = torch.tensor(image_scores)
        image_count += 1

    if image_count > 0:
        mean_neuron_scores = neuron_scores[:image_count].mean(dim=0)
    else:
        mean_neuron_scores = torch.zeros((num_layers, num_neurons))

    sorted_values, sorted_indices = torch.sort(mean_neuron_scores.flatten(), descending=True)
    return [
        (idx // num_neurons, idx % num_neurons, value)
        for idx, value in zip(sorted_indices.tolist(), sorted_values.tolist())
        if idx // num_neurons <= highest_layer
    ]
|
|
|
|
def plot_attn_maps(attn_maps, image_idx):
    """Draw a (layers x heads) grid of attention maps for one image.

    Args:
        attn_maps: 4-D array/tensor ``(num_layers, num_heads, height, width)``.
        image_idx: Image number shown in the figure title.

    Every row shares one color scale (that layer's min/max across heads) and
    gets a colorbar to the right of its last panel. The first column is
    labelled "Layer k"; the bottom row's panels are labelled "Head h".

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure current.
    """
    num_layers, num_heads, patch_height, patch_width = attn_maps.shape
    print(f"Shape of image_shifts: {attn_maps.shape}")

    # squeeze=False always yields a 2-D axes array, whatever the grid size.
    fig, axes = plt.subplots(num_layers, num_heads, figsize=(2*num_heads, 2*num_layers), squeeze=False)
    fig.suptitle(f'Attention Shift Maps for Image #{image_idx}', fontsize=16)

    from mpl_toolkits.axes_grid1 import make_axes_locatable

    for layer, row_axes in enumerate(axes):
        layer_maps = attn_maps[layer]
        row_min = layer_maps.min().item()
        row_max = layer_maps.max().item()

        for head, ax in enumerate(row_axes):
            im = ax.imshow(attn_maps[layer, head], cmap='viridis', vmin=row_min, vmax=row_max)
            ax.set_xticks([])
            ax.set_yticks([])
            if head == 0:
                ax.set_ylabel(f'Layer {layer}')
            if layer == num_layers-1:
                ax.set_xlabel(f'Head {head}')

        # One colorbar per row, attached to the row's last panel.
        cax = make_axes_locatable(row_axes[-1]).append_axes("right", size="5%", pad=0.05)
        plt.colorbar(im, cax=cax)

    plt.tight_layout()
    return plt
|
|
def calculate_iou(output, target):
    """Intersection-over-union of a predicted mask and a target mask.

    Assumes same-shape 0/1 masks (arrays or tensors supporting ``==``, ``*``
    and ``.sum().item()``); with other values the product below is not a true
    intersection.

    Returns:
        ``(intersection_area, union_area, iou)``. When both masks are empty the
        union is 0 and ``iou`` is ``float('nan')`` (undefined) rather than
        raising ``ZeroDivisionError``.
    """
    # For 0/1 masks: 1 exactly where both masks are 1.
    intersection = output * (output == target)
    area_inter = intersection.sum().item()
    area_pred = output.sum().item()
    area_target = target.sum().item()
    union = area_pred + area_target - area_inter
    iou = area_inter / union if union else float("nan")
    return area_inter, union, iou
|
|
def calculate_pixel_accuracy(output, target):
    """Share of the target's positive pixels that ``output`` also marks positive.

    Assumes same-shape 0/1 masks (arrays or tensors supporting ``==``, ``*``
    and ``.sum().item()``).

    Returns:
        ``(correct, total, correct / total)`` where ``total`` is the number of
        positive target pixels. When the target is empty the ratio is
        ``float('nan')`` rather than raising ``ZeroDivisionError``.
    """
    # For 0/1 masks: counts pixels that are 1 in both masks.
    correct = (output * (output == target)).sum().item()
    total = target.sum().item()
    accuracy = correct / total if total else float("nan")
    return correct, total, accuracy
|
|