19#include "mlir/Dialect/Affine/IR/AffineOps.h"
20#include "mlir/Dialect/Arith/IR/Arith.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/Vector/IR/VectorOps.h"
23#include "mlir/IR/PatternMatch.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27#define DEBUG_TYPE "aievec-split-load-ups-chains"
// Splits a 1024-bit -> 2048-bit aievec.ups whose operand comes straight from
// a vector.load into two half-width vector.load + aievec.ups pairs, then
// concatenates the two upscaled halves with a single vector.shuffle.
//
// Match conditions enforced below (the early-exit statements following each
// check are not visible in this excerpt; presumably `return failure();`):
//   * source and result element types are both integers;
//   * the source vector is 1024 bits wide with 16-bit lanes and the result
//     is 2048 bits wide with 32-bit lanes (so both have 64 lanes);
//   * the source is defined by a vector.load whose result has exactly one
//     use, so no other user still needs the full-width load.
53 LogicalResult matchAndRewrite(UPSOp upsOp,
54 PatternRewriter &rewriter)
const override {
56 Value source = upsOp.getSource();
57 auto srcVecTy = dyn_cast<VectorType>(source.getType());
62 auto resultVecTy = dyn_cast<VectorType>(upsOp.getResult().getType());
67 Type srcElemTy = srcVecTy.getElementType();
68 Type resultElemTy = resultVecTy.getElementType();
// Only integer element types are handled.
70 if (!srcElemTy.isInteger() || !resultElemTy.isInteger())
73 unsigned srcBitWidth = srcElemTy.getIntOrFloatBitWidth();
74 unsigned resultBitWidth = resultElemTy.getIntOrFloatBitWidth();
75 int64_t srcLanes = srcVecTy.getNumElements();
76 int64_t resultLanes = resultVecTy.getNumElements();
// Total vector widths in bits.
78 int64_t srcVectorSize = srcBitWidth * srcLanes;
79 int64_t resultVectorSize = resultBitWidth * resultLanes;
// Only the 1024-bit (16-bit lanes) -> 2048-bit (32-bit lanes) shape is
// handled; with these checks srcLanes == resultLanes == 64.
83 if (srcVectorSize != 1024 || resultVectorSize != 2048)
87 if (resultBitWidth != 32 || srcBitWidth != 16)
// The ups operand must come directly from a single-use vector.load.
91 auto loadOp = source.getDefiningOp<vector::LoadOp>();
96 if (!loadOp.getResult().hasOneUse())
99 Location loc = upsOp.getLoc();
// Both half loads reuse the original load's base memref and indices.
102 Value memRef = loadOp.getBase();
103 ValueRange indices = loadOp.getIndices();
// Each half covers srcLanes / 2 elements. Because srcLanes == resultLanes
// under the checks above, the upscaled halves also have halfLanes lanes,
// only with the wider result element type.
106 int64_t halfLanes = srcLanes / 2;
107 auto halfSrcVecTy = VectorType::get({halfLanes}, srcElemTy);
108 auto halfResultVecTy = VectorType::get({halfLanes}, resultElemTy);
// The second half starts halfLanes elements after the first. Only the
// innermost (last) index is offset, so the two halves are taken as
// contiguous along the last memref dimension.
// NOTE(review): if `indices` is empty, secondHalfIndices stays equal to
// firstHalfIndices and both halves would read the same elements. Confirm
// this case cannot reach here, or bail out before creating any ops.
112 int64_t elementOffset = halfLanes;
115 SmallVector<Value> firstHalfIndices(indices.begin(), indices.end());
118 SmallVector<Value> secondHalfIndices(indices.begin(), indices.end());
121 if (!indices.empty()) {
122 Value lastIdx = indices.back();
124 arith::ConstantIndexOp::create(rewriter, loc, elementOffset);
126 arith::AddIOp::create(rewriter, loc, lastIdx, offsetVal);
127 secondHalfIndices.back() = newLastIdx;
// Two half-width loads. The index-list arguments of these calls are on
// lines not visible in this excerpt (presumably firstHalfIndices and
// secondHalfIndices respectively).
131 auto loadHalf0 = vector::LoadOp::create(rewriter, loc, halfSrcVecTy, memRef,
135 auto loadHalf1 = vector::LoadOp::create(rewriter, loc, halfSrcVecTy, memRef,
// Upscale each half separately, reusing the original ups shift operand.
139 auto upsHalf0 = UPSOp::create(rewriter, loc, halfResultVecTy,
140 loadHalf0.getResult(), upsOp.getShift());
143 auto upsHalf1 = UPSOp::create(rewriter, loc, halfResultVecTy,
144 loadHalf1.getResult(), upsOp.getShift());
// Identity mask 0..resultLanes-1 over the two shuffle operands yields
// upsHalf0 (mask values < halfLanes) followed by upsHalf1, i.e. a concat.
148 SmallVector<int64_t> concatMask;
149 for (int64_t i = 0; i < resultLanes; ++i) {
150 concatMask.push_back(i);
153 auto concatOp = vector::ShuffleOp::create(
154 rewriter, loc, upsHalf0.getResult(), upsHalf1.getResult(), concatMask);
// Replace the original full-width ups with the concatenated result.
157 rewriter.replaceOp(upsOp, concatOp.getResult());
// Rewrites a vector.store of a single-use aievec.srs result, where the srs
// takes a 2048-bit source with 32-bit lanes to a 1024-bit value with 16-bit
// lanes. The result is two half-width (vector.shuffle + aievec.srs +
// vector.store) sequences, after which the original store is erased.
//
// NOTE(review): this excerpt does not show:
//   * the base-class and `using` lines of this struct;
//   * the early-exit statements after each check;
//   * the trailing index arguments of the two new stores.
182struct SplitVectorSrsStoreChainPattern
186 LogicalResult matchAndRewrite(vector::StoreOp storeOp,
187 PatternRewriter &rewriter)
const override {
// Match a vector.store whose stored value is produced by an aievec.srs
// that has no other users.
189 Value valueToStore = storeOp.getValueToStore();
190 auto storeVecTy = dyn_cast<VectorType>(valueToStore.getType());
195 auto srsOp = valueToStore.getDefiningOp<SRSOp>();
200 if (!srsOp.getResult().hasOneUse())
// Compare the wide srs source type against the narrow stored value type.
204 Value srsSource = srsOp.getSource();
205 auto srcVecTy = dyn_cast<VectorType>(srsSource.getType());
209 Type srcElemTy = srcVecTy.getElementType();
210 Type resultElemTy = storeVecTy.getElementType();
// Only integer element types are handled.
212 if (!srcElemTy.isInteger() || !resultElemTy.isInteger())
215 unsigned srcBitWidth = srcElemTy.getIntOrFloatBitWidth();
216 unsigned resultBitWidth = resultElemTy.getIntOrFloatBitWidth();
217 int64_t srcLanes = srcVecTy.getNumElements();
218 int64_t resultLanes = storeVecTy.getNumElements();
// Total vector widths in bits.
220 int64_t srcVectorSize = srcBitWidth * srcLanes;
221 int64_t resultVectorSize = resultBitWidth * resultLanes;
// Only 2048-bit (32-bit lanes) -> 1024-bit (16-bit lanes) is handled; with
// these checks both vectors have 64 lanes.
225 if (srcVectorSize != 2048 || resultVectorSize != 1024)
229 if (srcBitWidth != 32 || resultBitWidth != 16)
232 Location loc = storeOp.getLoc();
// Both half stores reuse the original store's base memref and indices.
235 Value memRef = storeOp.getBase();
236 ValueRange indices = storeOp.getIndices();
// Under the checks above, halfSrcLanes == halfResultLanes == 32.
239 int64_t halfSrcLanes = srcLanes / 2;
240 int64_t halfResultLanes = resultLanes / 2;
241 auto halfResultVecTy = VectorType::get({halfResultLanes}, resultElemTy);
// Masks selecting the low and high halves of srsSource. Both shuffle
// operands are srsSource, and the masks only index the first operand's
// range [0, srcLanes).
244 SmallVector<int64_t> firstHalfMask, secondHalfMask;
245 for (int64_t i = 0; i < halfSrcLanes; ++i) {
246 firstHalfMask.push_back(i);
247 secondHalfMask.push_back(halfSrcLanes + i);
250 auto srcHalf0 = vector::ShuffleOp::create(rewriter, loc, srsSource,
251 srsSource, firstHalfMask);
252 auto srcHalf1 = vector::ShuffleOp::create(rewriter, loc, srsSource,
253 srsSource, secondHalfMask);
// Downscale each half separately, reusing the original srs shift operand.
256 auto srsHalf0 = SRSOp::create(rewriter, loc, halfResultVecTy,
257 srcHalf0.getResult(), srsOp.getShift());
260 auto srsHalf1 = SRSOp::create(rewriter, loc, halfResultVecTy,
261 srcHalf1.getResult(), srsOp.getShift());
// The second half is stored halfResultLanes elements after the first, by
// offsetting only the innermost (last) index.
// NOTE(review): if `indices` is empty, both stores would target the same
// location. Confirm this case cannot reach here.
264 int64_t elementOffset = halfResultLanes;
267 SmallVector<Value> firstHalfIndices(indices.begin(), indices.end());
270 SmallVector<Value> secondHalfIndices(indices.begin(), indices.end());
273 if (!indices.empty()) {
274 Value lastIdx = indices.back();
276 arith::ConstantIndexOp::create(rewriter, loc, elementOffset);
278 arith::AddIOp::create(rewriter, loc, lastIdx, offsetVal);
279 secondHalfIndices.back() = newLastIdx;
// Emit the two half-width stores. Their index arguments are on lines not
// visible here (presumably firstHalfIndices / secondHalfIndices).
283 vector::StoreOp::create(rewriter, loc, srsHalf0.getResult(), memRef,
287 vector::StoreOp::create(rewriter, loc, srsHalf1.getResult(), memRef,
// Remove the original store. The original srs is not erased explicitly;
// with its only user gone it is presumably cleaned up as dead code by the
// greedy rewrite driver.
291 rewriter.eraseOp(storeOp);
// Pass that greedily applies SplitVectorLoadUpsChainPattern and
// SplitVectorSrsStoreChainPattern to the operation it runs on. It is
// registered on the command line as "aievec-split-load-ups-chains".
301struct SplitVectorLoadUpsChainsPass
302 :
public PassWrapper<SplitVectorLoadUpsChainsPass, OperationPass<>> {
// Gives the pass an explicit MLIR TypeID.
303 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SplitVectorLoadUpsChainsPass)
// Command-line argument used to invoke the pass.
305 StringRef getArgument()
const final {
return "aievec-split-load-ups-chains"; }
// NOTE(review): the description mentions only load + ups chains, but
// runOnOperation also applies the srs + store splitting pattern.
307 StringRef getDescription()
const final {
308 return "Split vector.load + aievec.ups chains to reduce shuffle operations";
// Dialects loaded for this pass. The patterns create vector, arith and
// aievec ops directly; memref and affine are also registered.
// NOTE(review): "®istry" looks like a mis-encoded "&registry" (the "&reg"
// HTML entity rendered as a registered-trademark sign). Verify this
// against the real source.
311 void getDependentDialects(DialectRegistry ®istry)
const override {
312 registry.insert<vector::VectorDialect, arith::ArithDialect,
313 memref::MemRefDialect, affine::AffineDialect,
314 xilinx::aievec::AIEVecDialect>();
// Collect both split patterns and apply them greedily to the root op.
// The body of the failure branch is not visible in this excerpt.
317 void runOnOperation()
override {
318 Operation *op = getOperation();
319 MLIRContext *context = &getContext();
320 RewritePatternSet patterns(context);
323 .add<SplitVectorLoadUpsChainPattern, SplitVectorSrsStoreChainPattern>(
326 if (failed(applyPatternsGreedily(op, std::move(patterns)))) {
// Factory for SplitVectorLoadUpsChainsPass. The caller owns the returned
// pass through std::unique_ptr.
// NOTE(review): in this listing the return statement appears before the
// function signature, which looks like an extraction artifact.
338 return std::make_unique<SplitVectorLoadUpsChainsPass>();
std::unique_ptr<::mlir::Pass > createSplitVectorLoadUpsChainsPass()
Create a pass that splits vector.load + aievec.ups chains (and aievec.srs + vector.store chains) to reduce shuffle operations for AIE2p targets.