Engineering Challenges Interpretability
The engineering challenges of scaling interpretability
Anthropic Interpretability, Jun 13, 2024
In this post, and in the above roundtable video, our researchers reflect on the close relationship between scientific and engineering progress, and discuss the technical challenges they encountered in scaling our interpretability research to much larger AI models.
Last October, the Anthropic Interpretability team published Towards Monosemanticity, a paper applying the technique of dictionary learning to a small transformer model. In May this year, we published Scaling Monosemanticity, where we applied the same technique to a model several orders of magnitude larger. We found tens of millions of “features” (combinations of neurons that relate to semantic concepts) in Claude 3 Sonnet, representing an important step forward in understanding the inner workings of AI models.

To continue making this progress, we need more engineers. This might seem surprising if you've only read our early papers (for example Frameworks and Toy Models of Superposition), which required relatively little engineering. But reading the newer research should make clear the scale of the engineering challenge we face.

Below, we share two examples of the technical engineering questions that were involved in our latest research. These illustrate the kinds of problems our engineers are tackling right now, and help explain why we think engineering will be one of the major bottlenecks to progress in AI interpretability research and, ultimately, AI safety research.

If you're an engineer, this post is aimed at you. If you’re inspired by the examples of engineering problems discussed below, we strongly encourage you to apply for our Research Engineer role.

Engineering Problem 1: Distributed Shuffle

Our Sparse Autoencoders, the tools we use to investigate “features”, are trained on the activations of transformers, and those activations need to be shuffled to stop the autoencoders from learning spurious, order-dependent patterns. When we first started training sparse autoencoders, we could fit our training data on a single GPU and trivially shuffle it.
But eventually, we wanted to scale beyond what could fit in memory (imagine starting with the easy task of shuffling a deck of cards, but then scaling it up to shuffling entire warehouses full of cards; it’s a much more difficult problem).

At this point, we could have implemented a distributed shuffle that scaled to petabytes. Instead, we decided on an approach we could implement quickly, but which didn't scale as well. We split our shuffle into K jobs where each job was responsible for 1/K of the shuffled output data. We generated a permutation, had each job do a streaming read of all of the training data, and then had it write out its share of the output. This allowed us to scale further, but the downside was obvious: each job had to read all of the training data. This first took hours, and later took days. By the time we were working on Towards Monosemanticity, we had 100TB of training data (100 billion data points, each being 1KB) and shuffling had become a major headache.

Performing a distributed shuffle that scales isn’t a novel or cutting-edge problem. But it was just one of many engineering problems we had to solve quickly to make scientific progress. In this case, we found a helpful blog post and extended the approach to many passes.

For one pass, we have N jobs. Each job reads 1/N of the dataset, shuffles it, and writes out the data in K files, each with 1/(NK) of the data. The contents of the first file written from each job represent the first 1/K of the final shuffled data, but it still needs to be shuffled. It’s the same for the second file, and so on. In one pass, we have reduced one shuffle of all the data to K shuffles, each K times smaller.

Now, if the shuffles fit in memory on a single machine, we can shuffle them and we’re done. If they don’t fit, we can just run another pass. Let’s say each job can keep 100GB of data in memory, and we write one hundred 1GB files. Each pass reduces the size of the shuffles needed by 100 times.
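The multi-pass scheme described above can be sketched in a few lines of Python. This is a toy, in-memory illustration and not our production code: the names `shuffle_pass` and `multipass_shuffle`, the `memory_limit` parameter (counted in items rather than bytes), and the use of lists in place of files are all assumptions made for the example. With fixed-size files, the sketch also makes no claim of producing an exactly uniform permutation.

```python
import random
from typing import List, Sequence


def shuffle_pass(data: Sequence[int], n_jobs: int, k_files: int,
                 rng: random.Random) -> List[List[int]]:
    """One pass: each of N jobs shuffles 1/N of the data and writes K 'files'.

    Returns K buckets; bucket j is the concatenation of file j from every job,
    i.e. the j-th 1/K of the final output, which still needs to be shuffled.
    """
    buckets: List[List[int]] = [[] for _ in range(k_files)]
    chunk = -(-len(data) // n_jobs)  # ceiling division: items per job
    for job in range(n_jobs):
        part = list(data[job * chunk:(job + 1) * chunk])
        rng.shuffle(part)  # the job's slice fits in memory
        for j in range(k_files):
            # The slice is already shuffled, so striding splits it into
            # K near-equal files without biasing which items land where.
            buckets[j].extend(part[j::k_files])
    return buckets


def multipass_shuffle(data: Sequence[int], n_jobs: int, k_files: int,
                      memory_limit: int, rng: random.Random) -> List[int]:
    """Run passes until every remaining shuffle fits in `memory_limit` items."""
    # Guard for this toy version: guarantees each pass strictly shrinks buckets.
    if k_files < 2 or memory_limit < 2 * n_jobs:
        raise ValueError("need k_files >= 2 and memory_limit >= 2 * n_jobs")
    if len(data) <= memory_limit:
        out = list(data)
        rng.shuffle(out)  # small enough: shuffle on a single machine
        return out
    result: List[int] = []
    for bucket in shuffle_pass(data, n_jobs, k_files, rng):
        result.extend(
            multipass_shuffle(bucket, n_jobs, k_files, memory_limit, rng))
    return result


if __name__ == "__main__":
    shuffled = multipass_shuffle(list(range(1000)), n_jobs=4, k_files=5,
                                 memory_limit=50, rng=random.Random(0))
    print(len(shuffled), shuffled[:10])
```

In a real system each bucket would be a set of files in shared storage, and the recursive calls would be independent jobs, since no job ever needs to read more than its own slice.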
One pass can shuffle 100GB of data, two passes can shuffle 10TB, three passes 1PB, four passes 100PB, and so on.

Since we implemented this approach, we’ve stopped thinking about shuffling. Now it’s something that happens quickly, without issues. There are certainly better approaches and faster implementations than ours. But this approach solves our bottleneck, and frees us up to tackle the next problem.

Engineering Problem 2: Feature Visualization Pipeline

Another engineering challenge has been generating the underlying data for our feature visualizations, which allow users to see specific tokens that are most strongly activated as part of individual features, along with other information (see the Feature Browser from the Towards Monosemanticity paper at this link). For each feature, we want to find a variety of dataset examples that activate it to different levels, exploring its full distribution. Doing this efficiently for millions of features is an interesting distributed systems problem. Originally, all of this ran in a single job, but we quickly scaled beyond that. Below is a sketch of our current approach.

Our dataset for visualization is 100M tokens, and we need to handle millions of features. First we “shard” over the dataset and features, splitting them into many different parts. Each job iterates over its slice of the dataset and, for its slice of features, keeps track of the K highest activating tokens for each feature and 10*K random tokens that activate the feature (we have already cached the transformer activations in S3, so we don’t need to recompute them).

Next, we shard over the features and aggregate the results from the previous pass. This gives us the highest-activating tokens for each feature across the entire dataset, as well as a random set of tokens that activate the feature. These are the examples we’ll show in the feature visualization.

For each of these examples, we need to calculate how the feature fires on surrounding tokens.
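The two aggregation passes above (per-shard tracking, then a merge sharded over features) can be illustrated with a small Python sketch for a single feature. Everything here is an assumption made for illustration: the `FeatureSummary` class and `merge` function are hypothetical names, "activates the feature" is taken to mean an activation greater than zero, and the random sample is kept with random-priority (bottom-k) sampling because it merges trivially across shards. The post does not say which sampling method is actually used.

```python
import heapq
import random
from typing import Iterable, List, Tuple

Example = Tuple[float, int]  # (activation, token position in the dataset)


class FeatureSummary:
    """State one job keeps for one feature: top-K plus 10*K random examples."""

    def __init__(self, k: int) -> None:
        self.k = k
        self.top: List[Example] = []  # min-heap holding the K largest so far
        # Min-heap of (-random_key, example). Keeping the largest negated keys
        # keeps the smallest random keys, which is a uniform sample of the
        # activating tokens seen so far.
        self.sample: List[Tuple[float, Example]] = []

    def add(self, activation: float, position: int,
            rng: random.Random) -> None:
        if activation <= 0.0:  # assumption: only positive values "activate"
            return
        example = (activation, position)
        if len(self.top) < self.k:
            heapq.heappush(self.top, example)
        elif example > self.top[0]:
            heapq.heapreplace(self.top, example)
        entry = (-rng.random(), example)
        if len(self.sample) < 10 * self.k:
            heapq.heappush(self.sample, entry)
        elif entry > self.sample[0]:
            heapq.heapreplace(self.sample, entry)


def merge(summaries: Iterable[FeatureSummary], k: int) -> FeatureSummary:
    """Second pass: combine the per-shard summaries for one feature."""
    all_top: List[Example] = []
    all_sample: List[Tuple[float, Example]] = []
    for summary in summaries:
        all_top.extend(summary.top)
        all_sample.extend(summary.sample)
    merged = FeatureSummary(k)
    merged.top = heapq.nlargest(k, all_top)
    merged.sample = heapq.nlargest(10 * k, all_sample)
    heapq.heapify(merged.top)     # restore the heap invariant so that
    heapq.heapify(merged.sample)  # add() keeps working on the result
    return merged
```

Because both structures are defined by "keep the best entries under some ordering", merging shards only requires the per-shard summaries, never the underlying activations.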
Our first approach sharded over features. Each job loads the transformer activations for the examples of the features for which it’s responsible. The problem is that these examples are randomly distributed across the dataset: there’s no easy way to read only the data the job needs. To improve this, we added a pass sharded over…