import os import subprocess import re import random import matplotlib.pyplot as plt import json def get_gpu_memory_usage(): """Returns a list of GPU memory usage in MB.""" try: # Run nvidia-smi command and capture the output result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) # Check if the command was successful if result.returncode != 0: raise RuntimeError(f"nvidia-smi command failed with error: {result.stderr}") # Parse the output to get memory usage memory_usages = [int(x) for x in result.stdout.strip().split('\n')] return memory_usages except Exception as e: print(f"Error querying GPU memory usage: {e}") return [] def set_cuda_visible_device(): """Sets the CUDA_VISIBLE_DEVICES environment variable to the GPU with the smallest memory usage.""" memory_usages = get_gpu_memory_usage() if not memory_usages: print("No GPU memory usage data available.") return # Find the index of the GPU with the smallest memory usage min_memory_index = memory_usages.index(min(memory_usages)) # Set the CUDA_VISIBLE_DEVICES environment variable os.environ["CUDA_VISIBLE_DEVICES"] = str(min_memory_index) print(f"Set CUDA_VISIBLE_DEVICES to GPU {min_memory_index} with {memory_usages[min_memory_index]} MB used.") return str(min_memory_index) os.environ["ASN_ROOT_DIR"] = "/home/nickj/asn/second_order_lens" os.chdir(os.environ["ASN_ROOT_DIR"]) import numpy as np import torch from PIL import Image import os.path import argparse from pathlib import Path from tqdm import tqdm from utils.factory import create_model_and_transforms, get_tokenizer from PIL import Image, ImageDraw def get_model(model_name = "ViT-B/16", pretrained = "openai", device = "cuda:0"): torch.multiprocessing.set_sharing_strategy("file_system") model, _, preprocess = create_model_and_transforms( model_name, pretrained=pretrained, force_quick_gelu=True, ) model.to(device) model.eval() context_length = model.context_length vocab_size = model.vocab_size return { "model": model, "model_name": model_name, "pretrained": pretrained, "preprocess": preprocess, "context_length": context_length, "vocab_size": vocab_size } img_path = "/datasets/ilsvrc_2024-01-04_1913/val/n04398044/ILSVRC2012_val_00042447.JPEG" # img_path = "./sample.jpeg" def load_images(preprocess, image_folder = "/datasets/ilsvrc/current/val", count = 100, images_only = True): file_list = [] for root, dirs, files in os.walk(image_folder): for file in files: file_list.append(os.path.join(root, file)) if count > len(file_list): sampled_files = file_list else: sampled_files = random.sample(file_list, count) image_files = [] for filename in sampled_files: image_files.append(preprocess(Image.open(filename))) if images_only: return image_files else: return image_files, sampled_files def calc_neuron_potentials(model, attn_layers = (1, 2), include_layernorm = True): # Calculates the attention-shifting potential scores for every neuron to the attention heads defined by the given layers (relative to the MLP layer) embed_dim = model.visual.transformer.resblocks[0].attn.out_proj.in_features num_heads = model.visual.transformer.resblocks[0].attn.num_heads head_dim = embed_dim // num_heads layers = len(model.visual.transformer.resblocks) results = dict() for neuron_layer in tqdm(range(layers), desc = "Calculating attention shifting potentials"): neuron_projection = model.visual.transformer.resblocks[neuron_layer].state_dict()["mlp.c_proj.weight"] for l_attn in range(min(layers, neuron_layer + attn_layers[0]), min(layers, neuron_layer + attn_layers[1])): ln_vector = model.visual.transformer.resblocks[l_attn].ln_1.state_dict()["weight"] attn_matrix = model.visual.transformer.resblocks[l_attn].state_dict()["attn.in_proj_weight"] W_Q, W_K, W_V = (attn_matrix[:embed_dim].reshape(num_heads, head_dim, -1), attn_matrix[embed_dim:2*embed_dim].reshape(num_heads, head_dim, -1), attn_matrix[2*embed_dim:].reshape(num_heads, head_dim, -1)) for head_idx in range(num_heads): W_Q_h, W_K_h = W_Q[head_idx], W_K[head_idx] effects = [] for i in range(neuron_projection.shape[1]): if include_layernorm: neuron_attn_effect = torch.norm(W_Q_h.T @ W_K_h @ (neuron_projection[:, i] * ln_vector)) else: neuron_attn_effect = torch.norm(W_Q_h.T @ W_K_h @ neuron_projection[:, i]) effects.append(neuron_attn_effect) results[(neuron_layer, l_attn, head_idx)] = torch.tensor(effects) return results def calc_top_asns(shift_potentials, top_k = 10, per = "layer", layers_away = 1): num_layers = max([key[1] for key in shift_potentials.keys()]) + 1 # the last layer has no ASNs by definition num_heads = max([key[2] for key in shift_potentials.keys()]) top_asns = [] for layer in range(num_layers - layers_away): if per == "layer": potentials = [] for head_idx in range(num_heads): potentials.append(shift_potentials[(layer, layer + layers_away, head_idx)]) potentials = torch.max(torch.stack(potentials, dim = 0), dim = 0).values _, sorted_indices = torch.sort(potentials, descending = True) top_asns.append(sorted_indices[:top_k].tolist()) elif per == "head": top_layer_asns = [] for head_idx in range(num_heads): _, sorted_indices = torch.sort(shift_potentials[(layer, layer + layers_away, head_idx)], descending = True) top_layer_asns.append(sorted_indices[:top_k].tolist()) top_asns.append(top_layer_asns) else: raise ValueError(f"Invalid per value: {per}") return top_asns def aggregate_attn_map(attn_map, layer, head): num_tokens = attn_map.shape[-1] assert (num_tokens - 1) ** 0.5 % 1 == 0, "num_tokens - 1 is not a perfect square" num_patches = int((num_tokens - 1) ** 0.5) aggregate_scores = torch.sum(attn_map[:, layer, head, 1:, 1:], dim = 1).reshape((1, num_patches, num_patches)) return aggregate_scores def attn_map_cls_token(attn_map, layer, head): # Gets the attention map for the CLS token num_tokens = attn_map.shape[-1] assert (num_tokens - 1) ** 0.5 % 1 == 0, "num_tokens - 1 is not a perfect square" num_patches = int((num_tokens - 1) ** 0.5) attn_map_reshaped = attn_map[:, layer, head, 0, 1:].reshape((1, num_patches, num_patches)) return attn_map_reshaped def visualize_attn_shift(attn_map1, attn_map2, image, display=True, out=None, min_diff=None, max_diff=None): import matplotlib.pyplot as plt import numpy as np # Subtract attn_map1 from attn_map2 diff_map = attn_map2 - attn_map1 # Convert the image to RGBA image = image.convert("RGBA") overlay = Image.new("RGBA", image.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(overlay) # Calculate the size of each attention block block_size_x = image.size[0] / diff_map.shape[0] block_size_y = image.size[1] / diff_map.shape[1] # Create a colormap cmap = plt.get_cmap('coolwarm_r') # 'cool' colormap for lighter to darker # Get the min and max values for scaling the colormap if max_diff is None: max_diff = diff_map.max() if min_diff is None: min_diff = diff_map.min() for i in range(diff_map.shape[0]): for j in range(diff_map.shape[1]): # Get the color from the colormap intensity = diff_map[i, j] normalized_intensity = (intensity - min_diff) / (max_diff - min_diff) # Scale to [0, 1] rgba_color = cmap(1 - normalized_intensity) # Invert the normalized intensity color = tuple(int(c * 255) for c in rgba_color[:3]) + (int(rgba_color[3] * 128),) # Draw the rectangle on the overlay with transparency draw.rectangle( [j * block_size_x, i * block_size_y, (j + 1) * block_size_x, (i + 1) * block_size_y], fill=color # Add transparency to the color ) # Composite the overlay with the original image combined = Image.alpha_composite(image, overlay) if display: # Display the result combined.show() # Show the color scale plt.figure(figsize=(6, 1)) plt.imshow([np.linspace(min_diff, max_diff, 256)], cmap='coolwarm_r', aspect='auto') plt.gca().set_visible(False) plt.colorbar(orientation="horizontal") plt.show() if out is not None: combined.save(out) return combined def visualize_attn_shift_binary(attn_map1, attn_map2, image, display=True, out=None): # Creates a visualization of the attention shift where green = positive, red = negative values. # This is useful when there are outliers in the difference map causing the middle values around 0 to be messed into one color # Subtract attn_map1 from attn_map2 diff_map = attn_map2 - attn_map1 # Normalize the difference map to range [0, 1] for visualization diff_map_normalized = (diff_map - diff_map.min()) / (diff_map.max() - diff_map.min()) # Convert the image to RGBA image = image.convert("RGBA") overlay = Image.new("RGBA", image.size, (0, 0, 0, 0)) draw = ImageDraw.Draw(overlay) # Calculate the size of each attention block block_size_x = image.size[0] / diff_map.shape[0] block_size_y = image.size[1] / diff_map.shape[1] for i in range(diff_map.shape[0]): for j in range(diff_map.shape[1]): # Calculate the color intensity based on the difference intensity = diff_map_normalized[i, j] alpha = int(255 * 0.5) # Tone down the alpha to 50% if diff_map[i, j] > 0: color = (0, int(255 * intensity), 0, alpha) # Green for positive else: color = (int(255 * (1 - intensity)), 0, 0, alpha) # Red for negative # Draw the rectangle on the overlay draw.rectangle( [j * block_size_x, i * block_size_y, (j + 1) * block_size_x, (i + 1) * block_size_y], fill=color ) # Composite the overlay with the original image combined = Image.alpha_composite(image, overlay) if display: # Display the result combined.show() if out is not None: combined.save(out) return combined def is_outlier(mean, std, value): return value < mean - 2 * std or value > mean + 2 * std def get_neuron_activations(images, prs_group, model, device = "cuda:0"): # Returns neuron activations in shape (num_images, num_layers, num_patches, num_neurons) random_neuron_acts = [] for image in tqdm(images, desc="Processing images"): prs_group.reinit() image_input = image.unsqueeze(0).to(device) representation = model.encode_image( image_input, attn_method="head", normalize=False ) prs_group.finalize() gelu_outs = prs_group.post_gelu_outputs() random_neuron_acts.append(gelu_outs) random_neuron_acts = torch.stack(random_neuron_acts, dim = 0) return random_neuron_acts def normalize_array(arr): min_val = np.min(arr) max_val = np.max(arr) # Avoid division by zero if all values are the same if max_val - min_val == 0: return np.zeros_like(arr) normalized_arr = (arr - min_val) / (max_val - min_val) return normalized_arr def np_l2(arr1, arr2): return np.linalg.norm(arr1 - arr2) def best_class(classifier, representation): cs = torch.cosine_similarity(classifier, representation.permute(1, 0), dim = 0) return torch.argmax(cs).item(), cs[torch.argmax(cs).item()].item() def load_group_attn_shifts(timestamp): # Load from Supp1B results_dir = "./results/supp1B" # dirs = [os.path.join(results_dir, d) for d in os.listdir(results_dir) # if os.path.isdir(os.path.join(results_dir, d))] # latest_dir = max(dirs, key=os.path.getmtime) latest_dir = os.path.join(results_dir, timestamp) print(f"Using latest results directory: {latest_dir}") # Load metadata with open(os.path.join(latest_dir, "metadata.json"), "r") as f: metadata = json.load(f) # Load memory-mapped files attn_maps = np.memmap(os.path.join(latest_dir, "attention_maps.mmap"), dtype=np.float32, mode='r', shape=tuple(metadata["attention_maps_shape"])) resblocks = np.memmap(os.path.join(latest_dir, "resblocks.mmap"), dtype=np.float32, mode='r', shape=tuple(metadata["resblocks_shape"])) # Get file list from metadata file_list = metadata.get("file_list", []) # Get top_k values top_k_values = metadata.get("top_k_values", [0]) return { "attn_maps": attn_maps, "resblocks": resblocks, "metadata": metadata, "file_list": file_list, "top_k_values": top_k_values, "num_layers": metadata.get("num_layers", 0), "num_images": metadata.get("num_images", 0), "num_heads": metadata.get("num_heads", 0) } def load_individual_attn_shifts(timestamp, supp = "supp1D"): results_dir = f"./results/{supp}" # dirs = [os.path.join(results_dir, d) for d in os.listdir(results_dir) # if os.path.isdir(os.path.join(results_dir, d))] # latest_dir = max(dirs, key=os.path.getmtime) latest_dir = os.path.join(results_dir, timestamp) print(f"Using latest results directory: {latest_dir}") # Load metadata with open(os.path.join(latest_dir, "metadata.json"), "r") as f: metadata = json.load(f) # Load memory-mapped files attn_maps = np.memmap(os.path.join(latest_dir, "attention_maps.mmap"), dtype=np.float32, mode='r', shape=tuple(metadata["attention_maps_shape"])) baseline_attn_maps = np.memmap(os.path.join(latest_dir, "baseline_attention_maps.mmap"), dtype=np.float32, mode='r', shape=tuple(metadata["baseline_attention_maps_shape"])) neuron_activations = np.memmap(os.path.join(latest_dir, "neuron_activations.mmap"), dtype=np.float32, mode='r', shape=tuple(metadata["neuron_activations_shape"])) baseline_neuron_activations = np.memmap(os.path.join(latest_dir, "baseline_neuron_activations.mmap"), dtype=np.float32, mode='r', shape=tuple(metadata["baseline_neuron_activations_shape"])) ablated_neurons = np.memmap(os.path.join(latest_dir, "ablated_neurons.mmap"), dtype=np.float32, mode='r', shape=tuple(metadata["ablated_neurons_shape"])) # Get file list from metadata file_list = metadata.get("file_list", []) # Get k value k = metadata.get("k", 25) return { "attn_maps": attn_maps, "baseline_attn_maps": baseline_attn_maps, "neuron_activations": neuron_activations, "baseline_neuron_activations": baseline_neuron_activations, "ablated_neurons": ablated_neurons, "metadata": metadata, "file_list": file_list, "k": k, "num_layers": metadata.get("num_layers", 12), "num_images": metadata.get("num_images", 100), "model_name": metadata.get("model_name", "ViT-B-16"), "pretrained": metadata.get("pretrained", "openai") } def find_register_neurons_cuda(model, preprocess, prs_group, register_norm_threshold = 30, highest_layer = -1, device = "cuda:0", processed_image_cnt = 500): num_layers = len(model.visual.transformer.resblocks) highest_layer = num_layers - 1 if highest_layer == -1 else highest_layer num_neurons = model.visual.transformer.resblocks[0].mlp.state_dict()["c_proj.weight"].shape[1] random_images = load_images(preprocess, count=processed_image_cnt) neuron_scores = torch.zeros((len(random_images), num_layers, num_neurons), device=device) alignment_scores = torch.zeros((len(random_images), num_layers, num_neurons), device=device) image_count = 0 for i in tqdm(range(len(random_images)), desc="Processing random images"): image = random_images[i].unsqueeze(0).to(device) prs_group.reinit() with torch.inference_mode(): representation = model.encode_image( image, attn_method="head", normalize=False ) prs_group.finalize() baseline_neuron_acts = prs_group.post_gelu_outputs().to(device) baseline_resblock_outputs = prs_group.resblock_outputs().to(device) # Calculate norm map using torch norm_map = torch.norm(baseline_resblock_outputs[-1], dim=1) filtered_norms = norm_map.clone() filtered_norms[filtered_norms < register_norm_threshold] = 0 # Get register locations as a tensor register_locations = torch.where(filtered_norms > register_norm_threshold)[0] if len(register_locations) == 0: continue image_count += 1 # Process all layers vectorized for layer in range(num_layers): # Get absolute activations for all neurons in this layer act_layer = torch.abs(baseline_neuron_acts[layer]) # Shape: [seq_len, num_neurons] # Check sparsity condition for all neurons at once sparse_neurons = torch.sum(act_layer < 0.5, dim=0) >= 0.5 * act_layer.shape[0] # Shape: [num_neurons] # Skip computation if no neurons meet the condition if not torch.any(sparse_neurons): continue # Get values at register locations for all neurons simultaneously # This creates a tensor of shape [num_register_locations, num_neurons] register_values = act_layer[register_locations] # For neurons that pass sparsity condition, compute mean at register locations # First, compute mean for all neurons (this is fast) neuron_means = register_values.mean(dim=0) # Shape: [num_neurons] # Then zero out means for neurons that don't pass sparsity condition neuron_means = neuron_means * sparse_neurons.float() # Store in score tensor neuron_scores[i, layer] = neuron_means # Rest of the code remains the same mean_neuron_scores = neuron_scores[:image_count].mean(dim=0) mean_alignment_scores = alignment_scores[:image_count].mean(dim=0) # Flatten and find top values flattened_scores = mean_neuron_scores.flatten() sorted_values, sorted_indices = torch.sort(flattened_scores, descending=True) flattened_alignment = mean_alignment_scores.flatten() sorted_alignment_values, sorted_alignment_indices = torch.sort(flattened_alignment, descending=True) # Convert indices to layer/neuron pairs top_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_indices] top_alignment_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_alignment_indices] register_norms = [ (layer, neuron, sorted_values[i].item()) for i, (layer, neuron) in enumerate(top_indices) if layer <= highest_layer ] best_alignment_scores = [ (layer, neuron, sorted_alignment_values[i].item()) for i, (layer, neuron) in enumerate(top_alignment_indices) if layer <= highest_layer ] return register_norms, best_alignment_scores def find_register_neurons(model, preprocess, prs_group, register_norm_threshold = 30, highest_layer = -1, device = "cuda:0", processed_image_cnt = 500): num_layers = len(model.visual.transformer.resblocks) highest_layer = num_layers - 1 if highest_layer == -1 else highest_layer num_neurons = model.visual.transformer.resblocks[0].mlp.state_dict()["c_proj.weight"].shape[1] random_images = load_images(preprocess, count = processed_image_cnt) neuron_scores = torch.zeros((len(random_images), num_layers, num_neurons)) for i in tqdm(range(len(random_images)), desc="Processing random images"): image = random_images[i].unsqueeze(0).to(device) prs_group.reinit() with torch.no_grad(): representation = model.encode_image( image, attn_method="head", normalize=False ) prs_group.finalize() # Gather neuron activations and resblock outputs baseline_neuron_acts = prs_group.post_gelu_outputs().cpu().numpy() baseline_resblock_outputs = prs_group.resblock_outputs().cpu().numpy() # Calculate norms of the last resblock outputs. Only consider patches of the activation maps that correspond with registers norms = np.linalg.norm(baseline_resblock_outputs[-1], axis=1) norms[norms < register_norm_threshold] = 0 register_locations = np.where(norms > register_norm_threshold)[0] # register_neurons = [] for layer in range(num_layers): for neuron in range(num_neurons): neuron_map = baseline_neuron_acts[layer, :, neuron] mask = np.zeros_like(neuron_map, dtype=bool) mask[register_locations] = True neuron_map[~mask] = 0 if np.any(neuron_map < 0): continue # dist = np.linalg.norm(normalize_array(norms) - normalize_array(neuron_map)) # register_neurons.append((layer, neuron, dist.item(), neuron_map[register_locations].mean())) neuron_scores[i, layer, neuron] = torch.tensor(neuron_map[register_locations].mean()) mean_neuron_scores = neuron_scores.mean(dim=0) # Flatten the 2D tensor to find global top values flattened_scores = mean_neuron_scores.flatten() sorted_values, sorted_indices = torch.sort(flattened_scores, descending=True) # Convert flat indices back to 2D coordinates (layer, neuron) top_indices = [(idx.item() // num_neurons, idx.item() % num_neurons) for idx in sorted_indices] return [(layer, neuron, sorted_values[i].item()) for i, (layer, neuron) in enumerate(top_indices) if layer <= highest_layer] def plot_attn_maps(attn_maps, image_idx): num_layers, num_heads, patch_height, patch_width = attn_maps.shape print(f"Shape of image_shifts: {attn_maps.shape}") # Create a grid of plots for all layers and heads fig, axes = plt.subplots(num_layers, num_heads, figsize=(2*num_heads, 2*num_layers)) fig.suptitle(f'Attention Shift Maps for Image #{image_idx}', fontsize=16) # Import the correct module for make_axes_locatable from mpl_toolkits.axes_grid1 import make_axes_locatable # Plot each layer-head combination for layer in range(num_layers): # Determine min and max for this layer for consistent colorbar scaling within the layer layer_vmin = attn_maps[layer].min().item() layer_vmax = attn_maps[layer].max().item() for head in range(num_heads): # Get the current axis (handle both 2D and 1D cases) if num_layers == 1 and num_heads == 1: ax = axes elif num_layers == 1: ax = axes[head] elif num_heads == 1: ax = axes[layer] else: ax = axes[layer, head] # Plot the attention shift map with layer-specific normalization im = ax.imshow(attn_maps[layer, head], cmap='viridis', vmin=layer_vmin, vmax=layer_vmax) # Remove ticks for cleaner appearance ax.set_xticks([]) ax.set_yticks([]) # Add layer and head labels only on the edges if head == 0: ax.set_ylabel(f'Layer {layer}') if layer == num_layers-1: ax.set_xlabel(f'Head {head}') # Add a colorbar for each layer (only once per row) if head == num_heads-1: # Create a colorbar that's properly sized relative to the plot divider = make_axes_locatable(ax) cax = divider.append_axes("right", size="5%", pad=0.05) plt.colorbar(im, cax=cax) # Adjust layout to make room for the colorbars plt.tight_layout() return plt def calculate_iou(output, target): intersection = output * (output == target) area_inter = intersection.sum().item() area_pred = output.sum().item() area_target = target.sum().item() union = area_pred + area_target - area_inter iou = area_inter / union return area_inter, union, iou def calculate_pixel_accuracy(output, target): correct = output * (output == target) correct = correct.sum().item() total = target.sum().item() return correct, total, correct / total