A practical guide to migrating your PyTorch training code to TPUs, with step-by-step debugging and optimization tips.
Recently, I’ve been working on training models for speculative decoding.
I’m experimenting with both the training pipeline and the model architecture to get better performance, so I have to re-train these models many times. I’ve chosen Llama 3.1 8B as the target model, and my drafter model is a single-layer transformer with 450M parameters. Even though these models are small, the compute required to train them is still significant: on a single H100, it takes around 4 days to train on a 1.4M-row dataset for 3 epochs. I have access to up to 8 H100s through the university’s resources, but scheduling those GPUs can take days, and running multiple experiments still takes a lot of time. They help greatly during development and for small experiments that skip the full training, but I needed more compute to scale up my experiments, especially for bigger models.
The training pipeline is open-source and available at SpecForge and the architecture follows EAGLE 3.
I stumbled upon Google’s TRC program while I was looking for GPU grants. Google gives researchers access to its spare TPUs while they are not otherwise in use. The application process was pretty easy: I filled out a short form, and two days later I had access to hundreds of TPUs (in different regions). The exact limits may change from user to user, but the allocation I was most interested in was 64 spot* Cloud TPU v6e chips.
| Feature | Google TPU v6e (Trillium) | NVIDIA H100 (SXM) |
|---|---|---|
| Compute (BF16) | 918 TFLOPs | 990 TFLOPs (1,979 Sparse) |
| HBM Capacity | 32 GB | 80 GB |
| HBM Bandwidth | 1.6 TB/s | 3.35 TB/s |
| Interconnect | 800 GB/s | 900 GB/s (NVLink) |
The amount of compute I would get by upgrading from 8 H100s to 64 v6e chips would help a lot with scaling my experiments, so I decided to migrate my codebase to support TPUs. To start, I chose a 4-chip v6e VM. This should give us performance roughly comparable to 4 H100 GPUs and a total of 128 GB of HBM. Notice that each TPU chip has significantly less memory, so we need to be more careful about how we distribute our models across chips.
I used torch==2.9.0 and torch_xla[tpu]==2.9.0, the latest stable releases as of Feb 2, 2026. First, we need to remove all .cuda() calls and route device placement through a small utility that calls to_device(...). Next, we should ensure our models run in bfloat16, as TPUs are optimized for bf16 operations. Since we no longer use torch.distributed for collectives, we hide all tensor-parallel logic behind a flag that is disabled when training runs on TPUs. However, SPMD on multiple nodes still requires initializing torch.distributed for coordination, and it can’t use the xla backend, so you must initialize it with gloo. Note that if you forget to opt out of dist.all_gather or other collective calls, your training can become very slow, because those tensors get broadcast over the gloo backend instead of the TPU interconnect.
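Concretely, the device plumbing ends up looking roughly like the sketch below. The names is_tpu, to_device, and the TRAIN_BACKEND flag are illustrative choices of mine, not SpecForge’s API.

```python
import os

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm


def is_tpu() -> bool:
    # However you flag the backend in your own config / environment.
    return os.environ.get("TRAIN_BACKEND", "cuda") == "tpu"


def to_device(batch: dict) -> dict:
    device = xm.xla_device() if is_tpu() else torch.device("cuda")
    return {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}


if is_tpu():
    # Multi-node SPMD still needs torch.distributed for coordination,
    # but it must use gloo instead of nccl or xla.
    dist.init_process_group(backend="gloo")
```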
Tip: I’ve experienced some hard-to-debug issues due to a race condition when SPMD isn’t initialized early enough.
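In practice, I now enable SPMD at the very top of the entry point, before any tensor touches the XLA device. A minimal sketch using the documented runtime API:

```python
import torch_xla.runtime as xr

# Must run before any XLA tensor is created; otherwise device initialization
# can race with the SPMD setup and fail in confusing ways.
xr.use_spmd()
```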
Since I mentioned SPMD (Single Program Multiple Data), let me explain what it is. The GPU code runs using Fully Sharded Data Parallelism (FSDP): the weights of each FSDP unit are sharded among the devices and gathered on each device only when they are needed for the next layer’s computation. This reduces the memory footprint of the model, since the full weights are only materialized temporarily. For more information, check the official PyTorch tutorial on FSDP.
I wanted to keep the same strategy for multi-TPU training. Even though pytorch_xla provides an FSDP wrapper, I prefer to shard things manually: for our single-layer transformer the wrapper was ineffective, collapsing the whole model into a single FSDP unit. Moreover, our TPUs (v6e) have only 32 GB of HBM each, whereas an H100 GPU has 80 GB of VRAM. This discrepancy forces us to shard our model more aggressively using SPMD, an automatic parallelization system for common ML workloads.
Overview of the SPMD, source: pytorch.org/blog/pytorch-xla-spmd. The user defines a logical mesh that maps to the physical TPU topology. By applying sharding annotations to tensors (weights and inputs), the user guides the XLA compiler on how to distribute data across the mesh. The compiler then automatically partitions the computation graph, inserting the necessary collective operations such as all-reduces and all-gathers to ensure mathematical correctness across all devices.
SPMD allows us to shard tensors among devices without explicitly specifying communication and collective operations such as all-gather or all-reduce.
Developers specify how tensors are sharded using sharding specs. For example, if you use the sharding spec (None,) on an RMSNorm weight, you replicate it on all devices. Now assume your mesh axes are named ('fsdp', 'model') with shape (4, 1): if you use ('fsdp',), you shard the weights across all 4 devices, meaning each one holds 1/4 of them. For more details, check PyTorch XLA’s SPMD guide.
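As a concrete example, here is a small sketch of those two specs using the SPMD API from that guide (with SPMD enabled as above); the linear and norm weights below are stand-ins for illustration, not SpecForge code.

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

device = xm.xla_device()
num_devices = xr.global_runtime_device_count()  # 4 on a v6e-4
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("fsdp", "model"))

linear = torch.nn.Linear(4096, 4096).to(device)
# ('fsdp', None): shard the 2-D weight along its first dimension,
# so each chip holds 1/4 of the rows.
xs.mark_sharding(linear.weight, mesh, ("fsdp", None))

norm_weight = torch.ones(4096, device=device)
# (None,): replicate this 1-D RMSNorm-style weight on every chip.
xs.mark_sharding(norm_weight, mesh, (None,))
```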
To migrate our models to SPMD, I created some helpers that automatically shard the model using xs.mark_sharding(...). The initial version sharded two-dimensional weights along their first dimension, while leaving one-dimensional weights such as RMSNorm replicated to prevent excessive communication.
I decided to fully utilize the TPUs with the (4, 1) topology: split all weights and inputs along the first mesh dimension, so each device holds roughly 1/4 of the model weights and does 1/4 of the computation. Also make sure to use mp.MpDeviceLoader and shard the inputs with ('fsdp', None) to split each batch across devices.
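Roughly like this (a sketch; train_loader and mesh come from the earlier setup):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.distributed.parallel_loader as mp

device = xm.xla_device()
train_device_loader = mp.MpDeviceLoader(
    train_loader,
    device,
    # Shard each batch along the 'fsdp' mesh axis so every chip receives
    # 1/4 of the batch; the sequence dimension stays replicated.
    input_sharding=xs.ShardingSpec(mesh, ("fsdp", None)),
)
```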
On the first run, I got torch._inductor.exc.InductorError: LoweringException coming from the compiled kernels. So I disabled dynamic compilation by replacing every @torch.compile with a @maybe_compile decorator that skips compilation on the TPU target, and we got our first successful run.
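The decorator itself is trivial; a sketch (maybe_compile and is_tpu are our own helper names):

```python
import torch


def maybe_compile(fn):
    """Drop-in replacement for @torch.compile that is a no-op on TPU."""
    if is_tpu():
        # Let XLA's lazy tracing build the graph instead of Inductor.
        return fn
    return torch.compile(fn)
```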
TPU Runtime Utilization
┏━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Chip ┃ HBM Usage (GiB) ┃ Duty cycle ┃
┡━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ 0 │ 9.92 GiB / 31.25 GiB │ 0.88% │
│ 1 │ 9.92 GiB / 31.25 GiB │ 0.89% │
│ 2 │ 9.92 GiB / 31.25 GiB │ 0.89% │
│ 3 │ 9.92 GiB / 31.25 GiB │ 0.89% │
└──────┴──────────────────────┴────────────┘
Looking at tpu-info logs, it seems like we are able to split our model into 4 TPU cores successfully. But we have an issue,
TensorCore Utilization
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Core ID ┃ TensorCore Utilization ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0 │ 0.00% │
│ 1 │ 0.00% │
│ 2 │ 0.00% │
│ 3 │ 0.00% │
└─────────┴────────────────────────┘
We see our TensorCore utilization is really low and our runs are very slow.
The first thing we notice when we run our TPU program is that it is really slow. Unlike GPUs, TPUs do not run in eager mode by default: PyTorch XLA lazily traces operations into a computation graph, which is then compiled to be optimized for the MXUs (matrix multiplication units). This compilation is expensive and adds a big overhead to starting our training pipeline; in our case, a single iteration took several minutes. By setting PT_XLA_DEBUG_LEVEL=2, we can inspect those compilations. The first thing it complains about is graph breaks. TPUs are really good at fusing operations (automatic fused kernels), so breaking the graph with operations that require moving data to the CPU triggers extra compilations and less efficient computation graphs. In addition to graph breaks, we see that a re-compilation is triggered after each forward pass,
Compilation Analysis: Compilation Cause
Compilation Analysis: torch_xla.sync in parallel loader at step end
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: ad9b7364e3d7a77aa6178b3269100fd6
Compilation Analysis: Number of Graph Inputs: 360
Compilation Analysis: Number of Graph Outputs: 41
The graph hash keeps changing, so we are forced to recompile, which is quite slow. The metrics also show 4 compilations that together take several minutes,
Metric: CompileTime
TotalSamples: 4
Accumulator: 04m32s050ms558.590us
Metric: ExecuteReplicatedTime
TotalSamples: 4
Accumulator: 01s390ms444.752us
...
Counter: MarkStep
Value: 1
Let’s try to understand why the graph is forced to recompile on every iteration. We will set the following environment variables to save the IR (Intermediate Representation) of the compiled graphs and compare how the graphs differ from each other.
export XLA_IR_DEBUG=1
export XLA_SAVE_TENSORS_FMT=text
export XLA_SAVE_TENSORS_FILE=/tmp/save.ir
Our first couple of graphs are pretty small. It is not ideal to have such small graphs, but we will come back to that later. There is one graph that is much bigger than the others and keeps recompiling,
## BEGIN_GRAPH
IR {
%0 = s64[] xla::device_data(), [email protected]:570, xla_shape=s64[]
%1 = s64[] xla::device_data(), [email protected]:570, xla_shape=s64[]
%2 = s64[4,473]{1,0} xla::device_data(), scope=cpu_data_to_xla_device.2, location=convert_fn@xla_model.py:1294, xla_shape=s64[4,473]{1,0}
%3 = s64[4,473,1]{2,1,0} aten::view(%2), location=generate_eagle3_data@eagle3_target_model.py:600, xla_shape=s64[4,473,1]{2,1,0}
%4 = (s64[4,473,1]{2,1,0}) aten::split(%3), location=get_dp_data_shard_from_tp@train_eagle3.py:799, xla_shape=(s64[4,473,1]{2,1,0})
%5 = s64[] aten::sum(%4), [email protected]:570, xla_shape=s64[]
%6 = s64[] aten::clamp(%5, %1, %0), [email protected]:570, xla_shape=s64[]
%7 = f32[] xla::cast(%6), [email protected]:568, xla_shape=f32[]
%8 = bf16[] prim::Constant(), [email protected]:51, xla_shape=bf16[]
%9 = bf16[] aten::expand(%8), [email protected]:51, xla_shape=bf16[]
...
As this graph is huge, it is hard to understand what is going on. It contains all the computation: the target LLM’s forward pass, the drafter’s forward and backward passes, and the optimizer step. If we search for other compilations triggered from the same line, generate_eagle3_data@eagle3_target_model.py:570, we see a prime suspect,
%6 = s64[4,1244,1]{2,1,0} aten::view(%5), location=generate_eagle3_data@eagle3_target_model.py:570, xla_shape=s64[4,1244,1]{2,1,0}
...
%6 = s64[4,491,1]{2,1,0} aten::view(%5), location=generate_eagle3_data@eagle3_target_model.py:570, xla_shape=s64[4,491,1]{2,1,0}
...
The shape of one of the graph inputs changes each time, causing a recompilation. The shape (4, 491, 1) is oddly familiar: it comes from generate_eagle3_data and corresponds to (batch_size, seq_len, 1), meaning these are our input IDs. So the varying sequence length of our inputs is causing the recompilations. The best solution is to group similar-length inputs together (which we already do using DistributedLengthGroupedSampler), but that still yields different sequence lengths. The best I can think of is to pad the inputs to the nearest multiple of 128 (as MXUs operate on 128x128 tiles). So we update the DataLoader with this call,
max_length = max(item["input_ids"].shape[1] for item in features)
max_length = min(max_length, self.max_seq_len)
if is_tpu():
    # Round sequence lengths up to the nearest multiple of 128 so batches
    # fall into a small set of shape buckets and trigger fewer recompilations.
    padding_lookup = [128 * i for i in range(20)]
    max_length = [x for x in padding_lookup if x >= max_length][0]
batch_input_ids = torch.cat(
    [self.paddingtensor2D(item["input_ids"], max_length) for item in features]
)
This still causes frequent compilations during the initial stages of the training pipeline, but as time passes the recompilations stop and we keep hitting cached computation graphs. For simplicity, predictability, and rapid development, I decided to pad everything to 2048 for the rest of this debugging session; I will come back to comparing fixed-length and dynamic-length results later. After running the code again with a 2048 max sequence length and waiting a couple of minutes (and a few iterations) for the initial compilations, each iteration now takes seconds instead of minutes.
TPU Runtime Utilization TensorCore Utilization
┏━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Chip ┃ HBM Usage (GiB) ┃ Duty cycle ┃ ┃ Core ID ┃ TensorCore Utilization ┃
┡━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0 │ 8.93 GiB / 31.25 GiB │ 92.95% │ │ 0 │ 28.14% │
│ 1 │ 8.93 GiB / 31.25 GiB │ 92.95% │ │ 1 │ 28.32% │
│ 2 │ 8.93 GiB / 31.25 GiB │ 92.95% │ │ 2 │ 28.50% │
│ 3 │ 8.93 GiB / 31.25 GiB │ 92.95% │ │ 3 │ 28.55% │
└──────┴──────────────────────┴────────────┘ └─────────┴────────────────────────┘
This, by the way, is what happens if you forget to replace dist calls such as dist.all_reduce(accuracies, op=dist.ReduceOp.AVG) with xm.all_reduce. Since we use the gloo backend, dist.all_reduce moves the tensors to the CPU and runs the operation over the regular network. In contrast, xm.all_reduce uses the high-speed TPU interconnect and keeps the operation inside the HLO computation graph, allowing XLA to optimize the execution lazily without breaking the trace. So be careful!
Training Epoch 0: 0%| | 85/322442 [06:38<212:42:52, 2.38s/it]
TPU Runtime Utilization TensorCore Utilization
┏━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Chip ┃ HBM Usage (GiB) ┃ Duty cycle ┃ ┃ Core ID ┃ TensorCore Utilization ┃
┡━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0 │ 8.93 GiB / 31.25 GiB │ 34.47% │ │ 0 │ 11.65% │
│ 1 │ 8.93 GiB / 31.25 GiB │ 34.46% │ │ 1 │ 11.91% │
│ 2 │ 8.93 GiB / 31.25 GiB │ 34.46% │ │ 2 │ 13.17% │
│ 3 │ 8.93 GiB / 31.25 GiB │ 34.46% │ │ 3 │ 12.81% │
└──────┴──────────────────────┴────────────┘ └─────────┴────────────────────────┘
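The guarded version I ended up with looks roughly like this. It is a sketch: is_tpu() is the same helper as before, and the average is expressed as a scaled sum because xm.all_reduce has no AVG reduce type.

```python
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

if is_tpu():
    # Stays inside the XLA graph and runs over the TPU interconnect.
    accuracies = xm.all_reduce(xm.REDUCE_SUM, accuracies, scale=1.0 / xr.world_size())
else:
    # CUDA path: NCCL all-reduce with averaging.
    dist.all_reduce(accuracies, op=dist.ReduceOp.AVG)
```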
However, we are still severely under-utilizing our TPU cores. We can inspect the metrics by calling met.short_metrics_report() and met.clear_metrics() after each batch.
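In code, the per-batch dump is just this (a sketch; met is torch_xla.debug.metrics):

```python
import torch_xla.debug.metrics as met

# At the end of every batch / optimizer step:
print(met.short_metrics_report())
met.clear_metrics()  # reset so each batch's report is isolated
```

Let’s take a look at what one batch reports,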
Training Epoch 0: 0%| | 74/322442 [04:25<84:17:51, 1.06it/s]
Counter: CachedCompile
Value: 438
Metric: ExecuteReplicatedTime
TotalSamples: 6
Accumulator: 01s039ms247.686us
ValueRate: 01s473ms641.554us / second
Rate: 8.50216 / second
Metric: TransferToDeviceTime
TotalSamples: 3
Accumulator: 206.900us
ValueRate: 026ms712.659us / second
Rate: 372.827 / second
Metric: TransferFromDeviceTime
TotalSamples: 6
Accumulator: 860ms592.740us
ValueRate: 01s160ms288.419us / second
Rate: 8.09887 / second
Counter: MarkStep
Value: 75
Counter: aten::_local_scalar_dense
Value: 223
Counter: aten::nonzero
Value: 76
As you can see, our CachedCompile counter keeps increasing, meaning we are not re-compiling the computation graph. But:
- We have more CachedCompile hits than MarkStep counts, meaning we are splitting our single training loop into 4 computation graphs.
- We still see TransferToDeviceTime and TransferFromDeviceTime samples on every batch.
- We have aten::nonzero and aten::_local_scalar_dense calls that force data movement between CPU and device.

Transfers between CPU and device happen through calls such as .cpu() or .item(), which we call explicitly here,
accuracies = accuracies.cpu().tolist()
plosses = plosses.cpu().tolist()
Our initial logging frequency was every step; we could easily optimize this by logging less frequently, say every 100 steps. However, PyTorch XLA provides a better solution: xm.add_step_closure, which runs a callback at the end of the step, once its arguments have actually been computed by the graph. It is designed exactly for things like logging metrics.
def logging_closure(accuracies, losses, current_step):
    acces_list = accuracies.tolist()
    plosses_list = losses.tolist()
    avg_loss = sum(plosses_list) / len(plosses_list)
    avg_acc = sum(acces_list) / len(acces_list)
    print(f"Step {current_step} - Accuracy: {avg_acc:.2f}, loss: {avg_loss:.2f}")
    # Or log to external providers such as WandB...

if global_step % (args.log_interval * args.draft_accumulation_steps) == 0:
    xm.add_step_closure(
        logging_closure,
        args=(acces_tensor, plosses_tensor, current_step_val),
    )
This small change increases our throughput from 2.4 it/s to 4.3 it/s!
Training Epoch 0: 0%| | 97/322442 [02:34<20:57:43, 4.27it/s]
Counter: CachedCompile
Value: 192
Metric: ExecuteReplicatedTime
TotalSamples: 2
Accumulator: 178ms749.887us
ValueRate: 01m16s143ms388.265us / second
Rate: 856.748 / second
...
Metric: TransferToDeviceTime
TotalSamples: 2
Metric: TransferFromDeviceTime
TotalSamples: 4
Counter: aten::_local_scalar_dense
Value: 98
Counter: aten::nonzero
Value: 99
And our utilization also looks much better.
TPU Runtime Utilization TensorCore Utilization
┏━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Chip ┃ HBM Usage (GiB) ┃ Duty cycle ┃ ┃ Core ID ┃ TensorCore Utilization ┃
┡━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ ┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0 │ 8.37 GiB / 31.25 GiB │ 77.69% │ │ 0 │ 24.15% │
│ 1 │ 8.37 GiB / 31.25 GiB │ 77.69% │ │ 1 │ 23.83% │
│ 2 │ 8.37 GiB / 31.25 GiB │ 77.70% │ │ 2 │ 23.64% │
│ 3 │ 8.37 GiB / 31.25 GiB │ 77.70% │ │ 3 │ 23.51% │
└──────┴──────────────────────┴────────────┘ └─────────┴────────────────────────┘
We’re now down to 2 graphs from 4: one small graph that takes around 382us, and one big graph that takes 177ms, which is our forward + backward pass. Now let’s track down the remaining transfers, the aten::_local_scalar_dense and aten::nonzero calls, and the unintended executions, so we can merge our loop into a single big computation graph.
Compilation Analysis: Compilation Cause
Compilation Analysis: most likely user code trying to access tensor value before torch_xla.sync
Compilation Analysis: Graph Info:
Compilation Analysis: Graph Hash: dd86da112d0f633a545849a2943f2f29
Compilation Analysis: Number of Graph Inputs: 1
Compilation Analysis: Number of Graph Outputs: 1
Compilation Analysis: Python Frame Triggered Execution:
Compilation Analysis: _ignore_causal_mask_sdpa (/home/dogac/miniconda/envs/specforge-tpu/lib/python3.11/site-packages/transformers/masking_utils.py:255)
Compilation Analysis: sdpa_mask_recent_torch (/home/dogac/miniconda/envs/specforge-tpu/lib/python3.11/site-packages/transformers/masking_utils.py:374)
Compilation Analysis: create_causal_mask (/home/dogac/miniconda/envs/specforge-tpu/lib/python3.11/site-packages/transformers/masking_utils.py:825)
Compilation Analysis: forward (/home/dogac/specforge/specforge/modeling/target/custom_backend/llama.py:323)
This is an optimization inside the transformers library that allows skipping some heavy computation for the causal mask:
padding_mask.all()
if not _is_torch_xpu_available or query_length == 1
else padding_mask[:, :query_length].all()
However, the .all() call materializes the value and breaks our computation graph. Since this does not work well on TPUs, I override the allow_is_causal_skip flag by monkey-patching it, so the causal mask is always computed. After this change I did not see a significant runtime benefit, but now there is only 1 computation graph, CachedCompile matches the number of steps, and the aten::_local_scalar_dense calls disappear.
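One way to do the patch, sketched below, is to replace the helper that shows up in the traceback so it never allows the skip. This is my own approach; check the exact function name against your transformers version.

```python
import transformers.masking_utils as masking_utils

# Always compute the causal mask: the skip check calls .all(), which pulls a
# scalar back to the host and breaks the XLA graph.
masking_utils._ignore_causal_mask_sdpa = lambda *args, **kwargs: False
```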
Now let’s take a look at aten::nonzero calls,
target_head = target_head[..., t2d]
Here, t2d is a mapping from the verifier model’s token IDs to the drafter’s. The size of target_head is dynamic because it depends on the number of ones in the t2d tensor. Since these IDs are fixed, we can pre-compute the indices and avoid the expensive nonzero operation. This also lets the compiler know the exact dimensions during the forward pass.
self.t2d_indices = torch.nonzero(self.t2d).squeeze()  # Pre-compute once at the start
target_head = torch.index_select(target_head, -1, self.t2d_indices)
After this change, the aten::nonzero calls disappear as well.
Next, increasing our batch size from 4 to 16 raised our effective throughput from 4.3 it/s to 5.2 it/s (in batch-of-4 terms) while pushing the duty cycle up to 99%. From now on, I will report the speed for processing batches of 16 elements, which is 1.30 it/s.
However, we are still not quite there: we run at around 25% efficiency, and the metrics have only taken us this far. So let’s enable the TensorBoard profiler. In the profiler, we can inspect memory, see the slowest HLO ops, do roofline analysis, and inspect computation graphs.
import torch_xla.debug.profiler as xp
server = xp.start_server(9012)
# On another shell, run `tensorboard --logdir ~/tensorboard --port 6006`
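Traces can be captured from the TensorBoard UI ("Capture Profile"), or programmatically from another process, roughly like this:

```python
import os
import torch_xla.debug.profiler as xp

# Capture ~30 seconds of device activity into the TensorBoard logdir while
# the training process (serving the profiler on port 9012) is running.
xp.trace("localhost:9012", os.path.expanduser("~/tensorboard"), duration_ms=30_000)
```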
Let’s take a look at the roofline analysis to understand how close we are to our hardware limits. Currently, we are using 50% of our HBM bandwidth and only 22.6% of our FLOPs.
| Step | Total Time per core (us) | Normalized FLOP Rate (GFLOP/s) | Bound by | HBM BW (GiB/s) | Roofline efficiency (%) | FLOP Rate / Peak (%) | Max memory BW utilization (%) |
|---|---|---|---|---|---|---|---|
| Total | 1,006,078 | 207,779.58 | HBM | 771.74 | 50.6% | 22.6% | 50.6% |
One optimization I wanted to test was removing the custom causal mask we pass to the model and letting the underlying SDPA kernel handle it. My intuition was that disabling allow_is_causal_skip caused additional computation. By setting causal_mask to None instead of building it myself, I saw a significant improvement, to 1.89 it/s. Looking at the roofline analysis again, our HBM utilization dropped while our FLOPs utilization increased.
| Step | Total Time per core (us) | Normalized FLOP Rate (GFLOP/s) | Bound by | HBM BW (GiB/s) | Roofline efficiency (%) | FLOP Rate / Peak (%) | Max memory BW utilization (%) |
|---|---|---|---|---|---|---|---|
| 1 | 1,003,930 | 298,421.26 | HBM | 622.86 | 40.8% | 32.5% | 40.8% |
When we look at the HLO op stats, we notice some very expensive all-to-all and all-reduce operations. They are triggered because the weights are sharded across TPUs, and hence so are their outputs. The next computation might need a different view of a tensor, requiring a re-shard (the all-to-all), or it might need to aggregate partial results from the other TPUs, requiring all-reduce operations such as averaging gradients across devices.
%all-to-all.2 = s32[2,2048,4,32064]{1,3,0,2:T(8,128)} all-to-all(s32[2,2048,4,32064]{1,3,0,2:T(8,128)} %multiply_add_fusion), channel_id=485, replica_groups=0, dimensions={2}
%all-reduce.2 = bf16[8,2048,32000]{1,2,0:T(8,128)(2,1)} all-reduce(bf16[8,2048,32000]{1,2,0:T(8,128)(2,1)} %fusion.151), channel_id=479, replica_groups=0, use_global_device_ids=true, to_apply=%add.2.clone
Those dimensions are the vocabularies: 32,064 * 4 = 128,256 for the target model and 32,000 for the draft model. Since we are well below the per-chip memory maximum, we can sacrifice some memory and replicate those big weights on every chip, eliminating the need for these expensive communications. After replicating the LM head weights, training speed increased to 2.13 it/s.
| Step | Total Time per core (us) | Normalized FLOP Rate (GFLOP/s) | Bound by | HBM BW (GiB/s) | Roofline efficiency (%) | FLOP Rate / Peak (%) | Max memory BW utilization (%) |
|---|---|---|---|---|---|---|---|
| 1 | 970,951 | 343,915.37 | HBM | 652.71 | 42.8% | 37.5% | 42.8% |
Moreover, I tried switching the down_proj and o_proj layers to row parallelism, similar to Tensor Parallel (TP) workloads. However, it made no difference in performance, hinting that the compiler already optimizes that communication.
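For reference, both experiments boil down to different partition specs on the same ('fsdp', 'model') mesh. A sketch, where model and layer are illustrative handles into the Llama modules and mesh comes from the earlier setup:

```python
import torch_xla.distributed.spmd as xs

# Replicate the big vocabulary projection on every chip: it costs HBM,
# but removes the all-to-all / all-reduce around the LM head.
xs.mark_sharding(model.lm_head.weight, mesh, (None, None))

# Row-parallel style sharding for down_proj and o_proj: shard the input
# dimension of the weight instead of the output dimension.
xs.mark_sharding(layer.mlp.down_proj.weight, mesh, (None, "fsdp"))
xs.mark_sharding(layer.self_attn.o_proj.weight, mesh, (None, "fsdp"))
```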
Automatic Mixed Precision (AMP) is a technique for training models more efficiently by using lower-precision FP16 or BF16 for operations that are safe at lower precision, while automatically switching to FP32 for operations that need the extra precision. Note that our TPU specs report performance in BF16, because the MXUs operate on BF16 inputs. FP32 work falls back to the Vector Processing Unit (VPU), which has significantly less compute power since it is not arranged as a 2D grid. We can enable automatic casting of operations to the right numerical type using torch’s AMP autocast,
with torch.autocast(device_type="xla", dtype=torch.bfloat16):
After the change, our speed has increased to 2.17it/s.
Torch XLA lazily compiles your model code into computation graphs. However, we can also use TorchDynamo, a Python-level JIT compiler, to optimize the Python bytecode.
self.model = torch.compile(model, backend="openxla")
This increased our speed to 2.38 it/s. However, note that not every compilation is beneficial. For example, compiling our loss function the way we did in the CUDA case results in a slowdown and excessive memory usage; our speed dropped to 2.15 it/s after compiling it.
@torch.compile(backend="openxla")
def _compute_loss(logits, target_p, position_mask):
    logits = logits.float()
    out_logp = nn.LogSoftmax(dim=2)(logits)
    plogp = target_p * out_logp
    loss = -torch.sum(position_mask * plogp, 2).mean()
    return loss
Looking at the memory metrics in the profiler, forcing the compilation of a specific function creates memory fragmentation. The Torch XLA compiler can successfully merge parts of the forward and backward passes on its own, but forcing a separate compilation of one function can prevent those fusions and the more efficient computation graph they produce. So @torch.compile is not always a win for XLA workloads.
Let’s look at our performance and compare the v6e-4 node with a single H100 and a 4xH100 node. The first benchmark is the simple case of next-token prediction, which I will denote TTT=1*. We used two different attention kernels on CUDA; note that the attention kernel only changes the drafter model, not the verifier, which is hard-coded to use SDPA. The speedup reported for the TPU is relative to the best-performing 4-GPU configuration.
TTT=1, Sequence Length = 2048 (it/s)
| | H100 (sdpa) | 4xH100 (sdpa) | H100 (flex) | 4xH100 (flex) | v6e-4 |
|---|---|---|---|---|---|
| bs=32 | OOM | 40.96 | OOM | 42.88 | 37.76 (0.88x) |
| bs=16 | OOM | 40.16 | 10.95 | 41.76 (3.81x) | 38.40 (0.91x) |
| bs=4 | 10.4 | 31.2 (3.08x) | 10.84 | 32.8 | 32.36 (0.98x) |
| bs=1 | 8.85 | N/A | 9.20 | N/A | N/A |
TTT=1, Sequence Length = Dynamic (it/s)
| | H100 (sdpa) | 4xH100 (sdpa) | H100 (flex) | 4xH100 (flex) | v6e-4 |
|---|---|---|---|---|---|
| bs=32 | OOM | 115.52 | OOM | 119.04 | 112.0 (0.94x) |
| bs=16 | OOM | 104 | 38.08 | 106.5 (2.79x) | 101.7 (0.95x) |
| bs=4 | 30.4 | 63.2 (2.07x) | 29.4 | 63.72 (2.16x) | 92.00 (1.44x) |
| bs=1 | 19.20 | N/A | 22.0 | N/A | N/A |
We observe that, since we did not use a custom attention kernel for our TPUs (unlike our CUDA code), the TPU runs can’t quite catch up with the flex attention baseline.
The next benchmark is more involved: following EAGLE-3’s training-time test (TTT) procedure, the drafter is unrolled for multiple steps per batch, here TTT=8.
TTT=8, Sequence Length = 2048 (it/s)
| | H100 (sdpa) | 4xH100 (sdpa) | H100 (flex) | 4xH100 (flex) | v6e-4 |
|---|---|---|---|---|---|
| bs=8 | OOM | 17.36 | OOM | 23.68 | 16.00 (0.67x) |
| bs=4 | 4.68 | 15.32 (3.27x) | 6.64 | 20.60 (3.10x) | 15.72 (0.76x) |
| bs=1 | 4.11 | N/A | 5.51 | N/A | N/A |
TTT=8, Sequence Length = Dynamic (it/s)
| | H100 (sdpa) | 4xH100 (sdpa) | H100 (flex) | 4xH100 (flex) | v6e-4 |
|---|---|---|---|---|---|
| bs=8 | OOM | 43.84 | OOM | 48.88 | 43.63 (0.89x) |
| bs=4 | 15.2 | 33.00 (2.17x) | 15.28 | 34.2 (2.23x) | 41.60 (1.21x) |
| bs=1 | 10.70 | N/A | 11.75 | N/A | N/A |
Here we see that TPUs are less efficient for this workload; a few factors contribute, and the profiler helps untangle them.
Let’s take a look at our roofline analysis. Most of our operations are pretty close to Pareto optimality, but some are still inefficient. In this graph, the yellow dots (loop fusion) are memory-bound, whereas the blue dots (convolution fusion) are compute-bound. I suspect the yellow dots come from attention, whereas the blue dots are other operations such as the MLP layers. Let’s try to validate that.
We can take a look at the HLO Op Profiler to see which operations have spent the most amount of time.
Here we see how much time is spent on compute (FLOPs), how much on memory (HBM), and how much is wasted. Let’s inspect the HLO graph of an operation with a high ratio of wasted time.
The input shape (4, 32, 2048, 64) validates our claim that this loop fusion is part of attention: 32 is our number of heads, and 64 comes from splitting the head dimension of 128 in two during RoPE. Also note that the last dimension of 64 is padded to 128, which means our MXUs can be under-utilized. An attention implementation tuned for these small models is therefore crucial for better performance.
In this post, I walked through how speculative decoding models work and how to train one. I then migrated the training code to run on TPUs, optimizing it step by step until performance nearly matched multi-GPU setups. That said, there’s still room for improvement—particularly by using an attention kernel optimized for XLA. PyTorch XLA has experimental flash attention kernels, but integrating them into our architecture requires some additional work.
Looking at the remaining inefficiencies in the convolution fusion, I noticed that some computation graphs include many u32[] and s32[] dependencies. I’m not yet sure why these appear, but removing them would likely improve performance further.
In the future I will also train drafters for much larger models. Hopefully their performance will be closer to optimal: since their weights are bigger, they should suffer less from these parallelization issues.
If you have any questions or comments, please feel free to reach out to me.
@misc{eldenk2026migratingtotpu,
title={Training language models on {TPUs} shouldn't be scary},
author={Doğaç Eldenk},
year={2026},
month=feb,
howpublished={\url{https://dogac.dev/blog/2026/migrating-to-tpu/}},
note={Blog post}
}