MLIR-AIE
SplitVectorLoadUpsChains.cpp
Go to the documentation of this file.
1//===- SplitVectorLoadUpsChains.cpp - Split Load+UPS Chains ----*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2025 Advanced Micro Devices Inc.
8//
9//===----------------------------------------------------------------------===//
10// This pass optimizes chains of vector.load followed by aievec.ups operations
11// for AIE2p targets. Instead of loading a 1024-bit vector and then shuffling
12// it into two halves for separate UPS operations (3 shuffles total), it splits
13// both the load and UPS into two 512-bit halves, requiring only 1 shuffle for
14// concatenation.
15//===----------------------------------------------------------------------===//
16
19#include "mlir/Dialect/Affine/IR/AffineOps.h"
20#include "mlir/Dialect/Arith/IR/Arith.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/Vector/IR/VectorOps.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26
27#define DEBUG_TYPE "aievec-split-load-ups-chains"
28
29using namespace mlir;
30using namespace xilinx::aievec;
31
32namespace {
33
34/// Pattern to optimize vector.load + aievec.ups chains by splitting them.
35///
36/// This pattern detects cases where a 1024-bit vector is loaded and then
37/// passed to an aievec.ups operation that produces a 2048-bit result.
38/// Instead of the inefficient approach of:
39/// 1. Load 1024 bits
40/// 2. Shuffle to split into 2×512 bits
41/// 3. Apply 2× UPS operations
42/// 4. Shuffle to concatenate results
43///
44/// It transforms to:
45/// 1. Load 2×512 bits directly
46/// 2. Apply 2× UPS operations immediately
47/// 3. Shuffle once to concatenate results
48///
49/// This reduces shuffle operations from 3 to 1.
50struct SplitVectorLoadUpsChainPattern : public OpRewritePattern<UPSOp> {
52
53 LogicalResult matchAndRewrite(UPSOp upsOp,
54 PatternRewriter &rewriter) const override {
55 // Get source value and its type
56 Value source = upsOp.getSource();
57 auto srcVecTy = dyn_cast<VectorType>(source.getType());
58 if (!srcVecTy)
59 return failure();
60
61 // Get result type
62 auto resultVecTy = dyn_cast<VectorType>(upsOp.getResult().getType());
63 if (!resultVecTy)
64 return failure();
65
66 // Check if this is a 1024-bit -> 2048-bit integer UPS
67 Type srcElemTy = srcVecTy.getElementType();
68 Type resultElemTy = resultVecTy.getElementType();
69
70 if (!srcElemTy.isInteger() || !resultElemTy.isInteger())
71 return failure();
72
73 unsigned srcBitWidth = srcElemTy.getIntOrFloatBitWidth();
74 unsigned resultBitWidth = resultElemTy.getIntOrFloatBitWidth();
75 int64_t srcLanes = srcVecTy.getNumElements();
76 int64_t resultLanes = resultVecTy.getNumElements();
77
78 int64_t srcVectorSize = srcBitWidth * srcLanes;
79 int64_t resultVectorSize = resultBitWidth * resultLanes;
80
81 // Only optimize the 1024-bit -> 2048-bit case
82 // (e.g., v64int16 -> v64acc32)
83 if (srcVectorSize != 1024 || resultVectorSize != 2048)
84 return failure();
85
86 // Check that the UPS result width is 32 and source width is 16
87 if (resultBitWidth != 32 || srcBitWidth != 16)
88 return failure();
89
90 // Check if source is directly from a vector.load
91 auto loadOp = source.getDefiningOp<vector::LoadOp>();
92 if (!loadOp)
93 return failure();
94
95 // Ensure the load is only used by this UPS operation
96 if (!loadOp.getResult().hasOneUse())
97 return failure();
98
99 Location loc = upsOp.getLoc();
100
101 // Get load operation details
102 Value memRef = loadOp.getBase();
103 ValueRange indices = loadOp.getIndices();
104
105 // Create element type for half-sized vector (v32int16)
106 int64_t halfLanes = srcLanes / 2;
107 auto halfSrcVecTy = VectorType::get({halfLanes}, srcElemTy);
108 auto halfResultVecTy = VectorType::get({halfLanes}, resultElemTy);
109
110 // Calculate offset for second half load
111 // For v64int16, we need to offset by 32 elements (64 bytes for i16)
112 int64_t elementOffset = halfLanes;
113
114 // Create indices for first half load (same as original)
115 SmallVector<Value> firstHalfIndices(indices.begin(), indices.end());
116
117 // Create indices for second half load
118 SmallVector<Value> secondHalfIndices(indices.begin(), indices.end());
119
120 // Adjust the last index by the element offset
121 if (!indices.empty()) {
122 Value lastIdx = indices.back();
123 Value offsetVal =
124 arith::ConstantIndexOp::create(rewriter, loc, elementOffset);
125 Value newLastIdx =
126 arith::AddIOp::create(rewriter, loc, lastIdx, offsetVal);
127 secondHalfIndices.back() = newLastIdx;
128 }
129
130 // Create first half load
131 auto loadHalf0 = vector::LoadOp::create(rewriter, loc, halfSrcVecTy, memRef,
132 firstHalfIndices);
133
134 // Create second half load
135 auto loadHalf1 = vector::LoadOp::create(rewriter, loc, halfSrcVecTy, memRef,
136 secondHalfIndices);
137
138 // Create UPS for first half
139 auto upsHalf0 = UPSOp::create(rewriter, loc, halfResultVecTy,
140 loadHalf0.getResult(), upsOp.getShift());
141
142 // Create UPS for second half
143 auto upsHalf1 = UPSOp::create(rewriter, loc, halfResultVecTy,
144 loadHalf1.getResult(), upsOp.getShift());
145
146 // Concatenate the two halves using vector.shuffle
147 // The mask is sequential from 0 to 63 to concatenate [half0; half1]
148 SmallVector<int64_t> concatMask;
149 for (int64_t i = 0; i < resultLanes; ++i) {
150 concatMask.push_back(i);
151 }
152
153 auto concatOp = vector::ShuffleOp::create(
154 rewriter, loc, upsHalf0.getResult(), upsHalf1.getResult(), concatMask);
155
156 // Replace the original UPS operation with the concatenated result
157 rewriter.replaceOp(upsOp, concatOp.getResult());
158
159 // The original load will be removed by dead code elimination
160 // since it no longer has any uses
161
162 return success();
163 }
164};
165
166/// Pattern to optimize aievec.srs + vector.store chains by splitting them.
167///
168/// This pattern detects cases where a 2048-bit vector is passed to an
169/// aievec.srs operation that produces a 1024-bit result, which is then stored.
170/// Instead of the inefficient approach of:
171/// 1. Shuffle to split 2048-bit into 2×1024 bits
172/// 2. Apply 2× SRS operations
173/// 3. Shuffle to concatenate results
174/// 4. Store 1024 bits
175///
176/// It transforms to:
177/// 1. Split source via shuffle into 2×1024 bits (for SRS input)
178/// 2. Apply 2× SRS operations to get 2×512 bits
179/// 3. Store 2×512 bits directly
180///
181/// This reduces shuffle operations from 3 to 1.
182struct SplitVectorSrsStoreChainPattern
183 : public OpRewritePattern<vector::StoreOp> {
184 using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
185
186 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
187 PatternRewriter &rewriter) const override {
188 // Get the value being stored and its type
189 Value valueToStore = storeOp.getValueToStore();
190 auto storeVecTy = dyn_cast<VectorType>(valueToStore.getType());
191 if (!storeVecTy)
192 return failure();
193
194 // Check if the value comes from an aievec.srs operation
195 auto srsOp = valueToStore.getDefiningOp<SRSOp>();
196 if (!srsOp)
197 return failure();
198
199 // Ensure the SRS is only used by this store operation
200 if (!srsOp.getResult().hasOneUse())
201 return failure();
202
203 // Get source and result types of SRS
204 Value srsSource = srsOp.getSource();
205 auto srcVecTy = dyn_cast<VectorType>(srsSource.getType());
206 if (!srcVecTy)
207 return failure();
208
209 Type srcElemTy = srcVecTy.getElementType();
210 Type resultElemTy = storeVecTy.getElementType();
211
212 if (!srcElemTy.isInteger() || !resultElemTy.isInteger())
213 return failure();
214
215 unsigned srcBitWidth = srcElemTy.getIntOrFloatBitWidth();
216 unsigned resultBitWidth = resultElemTy.getIntOrFloatBitWidth();
217 int64_t srcLanes = srcVecTy.getNumElements();
218 int64_t resultLanes = storeVecTy.getNumElements();
219
220 int64_t srcVectorSize = srcBitWidth * srcLanes;
221 int64_t resultVectorSize = resultBitWidth * resultLanes;
222
223 // Only optimize the 2048-bit -> 1024-bit case
224 // (e.g., v64acc32 -> v64int16)
225 if (srcVectorSize != 2048 || resultVectorSize != 1024)
226 return failure();
227
228 // Check that the SRS source width is 32 and result width is 16
229 if (srcBitWidth != 32 || resultBitWidth != 16)
230 return failure();
231
232 Location loc = storeOp.getLoc();
233
234 // Get store operation details
235 Value memRef = storeOp.getBase();
236 ValueRange indices = storeOp.getIndices();
237
238 // Create element types for half-sized vectors
239 int64_t halfSrcLanes = srcLanes / 2;
240 int64_t halfResultLanes = resultLanes / 2;
241 auto halfResultVecTy = VectorType::get({halfResultLanes}, resultElemTy);
242
243 // Split the SRS source into two halves using shuffle
244 SmallVector<int64_t> firstHalfMask, secondHalfMask;
245 for (int64_t i = 0; i < halfSrcLanes; ++i) {
246 firstHalfMask.push_back(i);
247 secondHalfMask.push_back(halfSrcLanes + i);
248 }
249
250 auto srcHalf0 = vector::ShuffleOp::create(rewriter, loc, srsSource,
251 srsSource, firstHalfMask);
252 auto srcHalf1 = vector::ShuffleOp::create(rewriter, loc, srsSource,
253 srsSource, secondHalfMask);
254
255 // Create SRS for first half
256 auto srsHalf0 = SRSOp::create(rewriter, loc, halfResultVecTy,
257 srcHalf0.getResult(), srsOp.getShift());
258
259 // Create SRS for second half
260 auto srsHalf1 = SRSOp::create(rewriter, loc, halfResultVecTy,
261 srcHalf1.getResult(), srsOp.getShift());
262
263 // Calculate offset for second half store
264 int64_t elementOffset = halfResultLanes;
265
266 // Create indices for first half store (same as original)
267 SmallVector<Value> firstHalfIndices(indices.begin(), indices.end());
268
269 // Create indices for second half store
270 SmallVector<Value> secondHalfIndices(indices.begin(), indices.end());
271
272 // Adjust the last index by the element offset
273 if (!indices.empty()) {
274 Value lastIdx = indices.back();
275 Value offsetVal =
276 arith::ConstantIndexOp::create(rewriter, loc, elementOffset);
277 Value newLastIdx =
278 arith::AddIOp::create(rewriter, loc, lastIdx, offsetVal);
279 secondHalfIndices.back() = newLastIdx;
280 }
281
282 // Create first half store
283 vector::StoreOp::create(rewriter, loc, srsHalf0.getResult(), memRef,
284 firstHalfIndices);
285
286 // Create second half store
287 vector::StoreOp::create(rewriter, loc, srsHalf1.getResult(), memRef,
288 secondHalfIndices);
289
290 // Erase the original store operation
291 rewriter.eraseOp(storeOp);
292
293 // The original SRS will be removed by dead code elimination
294 // since it no longer has any uses
295
296 return success();
297 }
298};
299
300/// Pass to split vector.load + aievec.ups chains for better performance
301struct SplitVectorLoadUpsChainsPass
302 : public PassWrapper<SplitVectorLoadUpsChainsPass, OperationPass<>> {
303 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SplitVectorLoadUpsChainsPass)
304
305 StringRef getArgument() const final { return "aievec-split-load-ups-chains"; }
306
307 StringRef getDescription() const final {
308 return "Split vector.load + aievec.ups chains to reduce shuffle operations";
309 }
310
311 void getDependentDialects(DialectRegistry &registry) const override {
312 registry.insert<vector::VectorDialect, arith::ArithDialect,
313 memref::MemRefDialect, affine::AffineDialect,
314 xilinx::aievec::AIEVecDialect>();
315 }
316
317 void runOnOperation() override {
318 Operation *op = getOperation();
319 MLIRContext *context = &getContext();
320 RewritePatternSet patterns(context);
321
322 patterns
323 .add<SplitVectorLoadUpsChainPattern, SplitVectorSrsStoreChainPattern>(
324 context);
325
326 if (failed(applyPatternsGreedily(op, std::move(patterns)))) {
327 signalPassFailure();
328 }
329 }
330};
331
332} // namespace
333
334namespace xilinx {
335namespace aievec {
336
337std::unique_ptr<::mlir::Pass> createSplitVectorLoadUpsChainsPass() {
338 return std::make_unique<SplitVectorLoadUpsChainsPass>();
339}
340
341} // namespace aievec
342} // namespace xilinx
std::unique_ptr<::mlir::Pass > createSplitVectorLoadUpsChainsPass()
Create a pass that splits vector.load + aievec.ups chains to reduce shuffle operations for AIE2p targets.