MLIR-AIE
FoldMulAddChainToConvOp.cpp
Go to the documentation of this file.
1//===--FoldMulAddChainToConvOp.cpp - Fold Mul Add Chain To AIEVec Conv Op--===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2023 Xilinx Inc.
8//
9//===----------------------------------------------------------------------===//
10// This is the implementation of the folding pass from mul add chain
11// to AIEVec convolution operations, compatible with the AIE2 architecture.
12//===----------------------------------------------------------------------===//
13
15
20#include "mlir/Analysis/SliceAnalysis.h"
21#include "mlir/Dialect/UB/IR/UBOps.h"
22#include "mlir/IR/PatternMatch.h"
23#include "mlir/Transforms/DialectConversion.h"
24#include "llvm/Support/Debug.h"
25
26#define DEBUG_TYPE "fold-mul-add-chain-to-conv"
27
28using namespace mlir;
29using namespace arith;
30using namespace vector;
31using namespace xilinx;
32using namespace xilinx::aievec;
33
34namespace xilinx::aievec {
35#define GEN_PASS_DEF_AIEVECCONVANALYSIS
36#include "aie/Dialect/AIEVec/Analysis/Passes.h.inc"
37} // namespace xilinx::aievec
38
39/// This analysis builds the longest possible chain of MAC operations whose
40/// operands are a vector that may or may not be shifted, and a broadcast.
41/// That is, these MACs represent `vector x scalar` ops, and are candidates to
42/// be grouped and replaced by mul_conv/fma_conv ops in AIE2.
43//
44// We build this chain recursively, climbing up through the accumulator operands of the add ops.
46 static AnalysisManager *am;
47
48 struct ConvMac {
49 // If there's a non-accumulating convolution upchain,
50 // store it here temporarily.
51 std::unique_ptr<ConvMac> topOfChainMulConv;
52 // Accumulator value, if there is one.
53 Value acc;
54 // Left-hand side (non-broadcasting) source value
55 Value lhs;
56 // Right-hand side (broadcasting) source value
57 Value rhs;
58 // Amount that lhs is shifted
59 uint8_t shift;
60 // Element in rhs that is broadcasted
61 uint8_t bcastIdx;
62 ConvMac(Value lhs, Value rhs, uint8_t shift, uint8_t bcastIdx)
63 : topOfChainMulConv(nullptr), acc(nullptr), lhs(lhs), rhs(rhs),
65 };
66
68 // Group start index within the chain
69 uint64_t fromIdx;
70 // Index in chain after group last MAC
71 uint64_t toIdx;
72 // Initial position of the signal to be convolved
73 int64_t signalShift;
74 // Initial position of the convolution filter
75 int64_t bcastShift;
76 // Distance between elements in the filter
77 int64_t bcastDist; // Must be 1 or 2
78 };
79
80 typedef SmallVector<std::unique_ptr<ConvMac>, 8> ConvMacChain;
81 typedef SmallVector<ConvMacChainGroup, 8> ConvMacChainGroupList;
82
83 std::unique_ptr<ConvMacChain> convMacChain;
85
86 /// Sort the chain of MACs by sources. When two MACs share the same sources,
87 /// sort them by the broadcast index. If they don't, sort them by the order
88 /// of the ops in the code. This function should be called after the chain
89 /// is completed, and before operating on the groups of MACs. After sorting,
90 /// MACs that can be fused into single convolution ops will be contiguous in
91 /// the chain.
92 void sortChain() {
93 if ((*convMacChain)[0]->acc) {
94 std::sort(convMacChain->begin(), convMacChain->end(),
95 [](const auto &a, const auto &b) {
96 if (a->lhs == b->lhs) {
97 if (a->rhs == b->rhs)
98 return a->bcastIdx < b->bcastIdx;
99 return a->rhs.getDefiningOp()->isBeforeInBlock(
100 b->rhs.getDefiningOp());
101 }
102 // We should probably sort by lhs load address, if it exists
103 // XXX: We assume all MACs in the same block. If they're not,
104 // XXX: this will assert.
105 return a->lhs.getDefiningOp()->isBeforeInBlock(
106 b->lhs.getDefiningOp());
107 });
108 } else {
109 // If the top of the chain is not an accumulation, bring up all related
110 // convolution MACs and sort the rest by lhs.
111 auto firstLhs = (*convMacChain)[0]->lhs;
112 std::sort(convMacChain->begin(), convMacChain->end(),
113 [&firstLhs](const auto &a, const auto &b) {
114 if (a->lhs == b->lhs) {
115 if (a->rhs == b->rhs)
116 return a->bcastIdx < b->bcastIdx;
117 return a->rhs.getDefiningOp()->isBeforeInBlock(
118 b->rhs.getDefiningOp());
119 }
120 if (a->lhs == firstLhs)
121 return true;
122 if (b->lhs == firstLhs)
123 return false;
124 return a->lhs.getDefiningOp()->isBeforeInBlock(
125 b->lhs.getDefiningOp());
126 });
127 // Float the empty accumulator to the top.
128 if ((*convMacChain)[0]->acc)
129 for (auto &convMac : *convMacChain)
130 if (!convMac->acc) {
131 std::swap((*convMacChain)[0]->acc, convMac->acc);
132 break;
133 }
134 }
135 }
136
137 // Return the list of convolution mac ops in the chain as pairs of indices
138 // indicating the position within the chain where a group starts and the
139 // position where it ends: [start, end). If they have not been precomputed
140 // yet, this method will generate them.
142 // If there's no group or it's been computed already, return stored list.
143 if (groupsInChain.size() > 0 || !convMacChain || convMacChain->size() == 0)
144 return groupsInChain;
145
146 uint64_t grpStartIdx = 0;
147 uint64_t grpCurIdx = 0;
148 Value curLhs = (*convMacChain)[0]->lhs;
149 Value curRhs = (*convMacChain)[0]->rhs;
150 for (const auto &convMac : *convMacChain) {
151 if (grpCurIdx > grpStartIdx) {
152 if (curLhs != convMac->lhs || curRhs != convMac->rhs) {
153 groupsInChain.push_back({grpStartIdx, grpCurIdx,
154 getGroupSignalShift(grpStartIdx, grpCurIdx),
155 getGroupBcastShift(grpStartIdx, grpCurIdx),
156 getGroupBcastDist(grpStartIdx, grpCurIdx)});
157 grpStartIdx = grpCurIdx;
158 curLhs = convMac->lhs;
159 curRhs = convMac->rhs;
160 }
161 }
162 grpCurIdx++;
163 }
164 if (grpStartIdx < grpCurIdx)
165 groupsInChain.push_back({grpStartIdx, grpCurIdx,
166 getGroupSignalShift(grpStartIdx, grpCurIdx),
167 getGroupBcastShift(grpStartIdx, grpCurIdx),
168 getGroupBcastDist(grpStartIdx, grpCurIdx)});
169 return groupsInChain;
170 }
171
172 // Return the signal shift for the group in the MAC chain in [fromIdx, toIdx)
173 // the top. This method verifies that the elements of the signal are
174 // contiguously accessed. If they do not, or the specified group doesn't
175 // exist, this function returns -1.
176 int64_t getGroupSignalShift(uint64_t fromIdx, uint64_t toIdx) {
177 if (fromIdx >= toIdx || toIdx > convMacChain->size())
178 return -1;
179 if (toIdx == fromIdx + 1)
180 return static_cast<int64_t>((*convMacChain)[fromIdx]->shift);
181 for (uint64_t i = fromIdx; i < toIdx - 1; i++)
182 if ((static_cast<int64_t>((*convMacChain)[i + 1]->shift) -
183 static_cast<int64_t>((*convMacChain)[i]->shift)) != 1)
184 return -1;
185 return static_cast<int64_t>((*convMacChain)[fromIdx]->shift);
186 }
187
188 // Return the shift in value of the first broadcasted element in the i-th
189 // group. If there is no chain, or the i-th group does not exist,
190 // returns -1.
191 int64_t getGroupBcastShift(uint64_t fromIdx, uint64_t toIdx) {
192 if (fromIdx >= toIdx || toIdx > convMacChain->size())
193 return -1;
194 return static_cast<int64_t>((*convMacChain)[fromIdx]->bcastIdx);
195 }
196
197 // Returns the broadcast distance between elements within the group. If the
198 // distance is not constant and equal to 1 or 2, it returns -1.
199 int64_t getGroupBcastDist(uint64_t fromIdx, uint64_t toIdx) {
200 if (fromIdx >= toIdx || toIdx > convMacChain->size())
201 return -1;
202 if (toIdx == fromIdx + 1)
203 return 1;
204 int64_t bcastDist =
205 static_cast<int64_t>((*convMacChain)[fromIdx + 1]->bcastIdx) -
206 static_cast<int64_t>((*convMacChain)[fromIdx]->bcastIdx);
207 if (bcastDist != 1 && bcastDist != 2)
208 return -1;
209 for (uint64_t i = fromIdx + 1; i < toIdx - 1; i++)
210 if ((static_cast<int64_t>((*convMacChain)[i + 1]->bcastIdx) -
211 static_cast<int64_t>((*convMacChain)[i]->bcastIdx)) != bcastDist)
212 return -1;
213 return bcastDist;
214 }
215
217 const auto &groups = getGroupsInChain();
218 if (groups.size() == 0)
219 return false;
220 for (const auto &group : groups)
221 if (group.signalShift == -1 || group.bcastShift == -1 ||
222 group.bcastDist == -1)
223 return false;
224 return true;
225 }
226
227 std::unique_ptr<ConvMac> getConvMacFromMulOp(arith::MulIOp mulOp) {
228 auto mulOpLhsDefOp = mulOp.getLhs().getDefiningOp();
229 auto mulOpRhsDefOp = mulOp.getRhs().getDefiningOp();
230 if (!mulOpLhsDefOp || !mulOpRhsDefOp)
231 return nullptr;
232
233 Value convMacRhs = nullptr;
234 uint8_t convMacBcastIdx = 0;
235
236 auto getConvMacRhs = [&](Operation *mulOpOperand) -> bool {
237 SetVector<Operation *> opBwdSlices;
238 auto opFilter = [](Operation *op) {
239 return isa<aievec::BroadcastOp>(op) || isa<aievec::ExtOp>(op) ||
240 isa<aievec::ConcatOp>(op);
241 };
242 BackwardSliceOptions backwardSliceOptions;
243 backwardSliceOptions.filter = opFilter;
244
245 (void)getBackwardSlice(mulOpOperand, &opBwdSlices, backwardSliceOptions);
246 opBwdSlices.insert(mulOpOperand);
247
248 LLVM_DEBUG(llvm::dbgs() << "opBwdSlices = [\n");
249 for ([[maybe_unused]] auto op : opBwdSlices) {
250 LLVM_DEBUG(llvm::dbgs() << *op << "\n");
251 }
252 LLVM_DEBUG(llvm::dbgs() << "]\n");
253
254 if (opBwdSlices.size() == 1) {
255 if (auto bcastOp = dyn_cast<aievec::BroadcastOp>(opBwdSlices[0])) {
256 convMacRhs = bcastOp.getSource();
257 convMacBcastIdx = bcastOp.getIdx();
258 return true;
259 }
260 } else if (opBwdSlices.size() >= 3) {
261 auto sliceSz = opBwdSlices.size();
262 if ((isa<aievec::ExtOp>(opBwdSlices[sliceSz - 3]) &&
263 isa<aievec::BroadcastOp>(opBwdSlices[sliceSz - 2]) &&
264 isa<aievec::ConcatOp>(opBwdSlices[sliceSz - 1])) ||
265 (isa<aievec::ConcatOp>(opBwdSlices[sliceSz - 3]) &&
266 isa<aievec::BroadcastOp>(opBwdSlices[sliceSz - 2]) &&
267 isa<aievec::ExtOp>(opBwdSlices[sliceSz - 1]))) {
268 convMacRhs = opBwdSlices[sliceSz - 3]->getOperand(0);
269 convMacBcastIdx =
270 dyn_cast<aievec::BroadcastOp>(opBwdSlices[sliceSz - 2]).getIdx();
271 return true;
272 }
273 }
274
275 return false;
276 };
277
278 // Obtain the broadcast operation feeding into the MulIOp
279 if (!getConvMacRhs(mulOpRhsDefOp)) {
280 if (getConvMacRhs(mulOpLhsDefOp)) {
281 std::swap(mulOpLhsDefOp, mulOpRhsDefOp);
282 }
283 }
284 if (!convMacRhs)
285 return nullptr;
286
287 // Obtain the ext or ext->shift op feeding into the MulIOp
288 aievec::ExtOp extOp;
289 aievec::ShiftOp shiftOp;
290 shiftOp = dyn_cast<aievec::ShiftOp>(mulOpLhsDefOp);
291 if (shiftOp)
292 extOp = shiftOp.getLhs().getDefiningOp<aievec::ExtOp>();
293 else
294 extOp = dyn_cast<aievec::ExtOp>(mulOpLhsDefOp);
295
296 // XXX: Actually, ExtOp might not exist but should work anyway.
297 // XXX: Should it, though?
298 if (!extOp)
299 return nullptr;
300
301 Value convMacLhs = extOp.getSource();
302 uint8_t shift = 0;
303 if (shiftOp) {
304 auto shiftConstDefOp =
305 shiftOp.getShift().getDefiningOp<arith::ConstantOp>();
306 if (shiftConstDefOp) {
307 auto shiftAttr = cast<IntegerAttr>(shiftConstDefOp.getValue());
308 auto vType = cast<VectorType>(mulOp.getResult().getType());
309 shift = 8 * shiftAttr.getInt() / getElementSizeInBits(vType);
310 }
311 }
312
313 return std::make_unique<ConvMac>(convMacLhs, convMacRhs, shift,
314 convMacBcastIdx);
315 }
316
317 std::unique_ptr<ConvMac> getConvMacFromAddOp(arith::AddIOp addOp) {
318 // Make sure at least one of them is a multiplication, and the other one
319 // is the accumulator coming form upchain.
320 auto mulOp = addOp.getLhs().getDefiningOp<arith::MulIOp>();
321 Value acc = addOp.getRhs();
322 if (!mulOp) {
323 mulOp = addOp.getRhs().getDefiningOp<arith::MulIOp>();
324 acc = addOp.getLhs();
325 }
326 if (!mulOp)
327 return nullptr;
328
329 // Get the parameters of the convolution from the operands of the MulIOp
330 auto convMac = getConvMacFromMulOp(mulOp);
331 if (!convMac)
332 return nullptr;
333
334 // If both sides are MulIOp, we might be at the top of the chain
335 auto upChainAccMulOp = acc.getDefiningOp<arith::MulIOp>();
336 if (upChainAccMulOp) {
337 auto convMac2 = getConvMacFromMulOp(upChainAccMulOp);
338 // XXX: We pre-sort the top two MACs to make sure that an undefined
339 // XXX: accumulator ends up on top of the chain.
340 // XXX: But it might not be necessary? CHECK!
341 if (convMac2 && convMac->lhs == convMac2->lhs &&
342 convMac->rhs == convMac->rhs) {
343 if (convMac->bcastIdx < convMac2->bcastIdx &&
344 convMac->shift < convMac2->shift) {
345 convMac2->topOfChainMulConv = std::move(convMac);
346 convMac2->acc = acc;
347 return convMac2;
348 } else if (convMac->bcastIdx > convMac2->bcastIdx &&
349 convMac->shift > convMac2->shift) {
350 convMac->topOfChainMulConv = std::move(convMac2);
351 convMac->acc = acc;
352 return convMac;
353 } else {
354 // WARNING: In this situation, the chain is ambiguous and picking one
355 // WARNING: option over the other may result in a successful
356 // WARNING: and/or better replacement. Here, we are assuming that
357 // WARNING: is going to be either one or the other, or it won't
358 // WARNING: matter.
359 }
360 } else {
361 convMac->topOfChainMulConv = std::move(convMac2);
362 }
363 }
364 convMac->acc = acc;
365 return convMac;
366 }
367
368 LongestConvMACChainAnalysis(arith::AddIOp addOp) {
369 std::unique_ptr<ConvMac> macConvChainElem = getConvMacFromAddOp(addOp);
370 if (!macConvChainElem)
371 return;
372
373 if (macConvChainElem->acc) {
374 auto upChainAddOp = macConvChainElem->acc.getDefiningOp<arith::AddIOp>();
375 if (upChainAddOp) {
376 auto &upChainChainAnalysis =
377 am->getChildAnalysis<LongestConvMACChainAnalysis>(upChainAddOp);
378 if (upChainChainAnalysis.convMacChain) {
379 convMacChain = std::move(upChainChainAnalysis.convMacChain);
380 convMacChain->push_back(std::move(macConvChainElem));
381 return;
382 }
383 }
384 }
385 assert(!convMacChain && "Convolution MAC chain unexpectedly not empty");
386 convMacChain = std::make_unique<ConvMacChain>();
387 if (macConvChainElem->topOfChainMulConv)
388 convMacChain->push_back(std::move(macConvChainElem->topOfChainMulConv));
389 convMacChain->push_back(std::move(macConvChainElem));
390 }
391};
392// HACK: For some reason, it's not possible to access the analysis manager from
393// HACK: within an analysis, but we need it to build the analysis recursively.
394// HACK: If there is a good reason not to do this, we should find an
395// HACK: alternative way to build the MAC chain.
396AnalysisManager *LongestConvMACChainAnalysis::am = nullptr;
397
398// This conversion pattern folds a MAC chain into mul_conv and fma_conv
399// ops. The MACs in the chain may appear in any order.
401 : public OpConversionPattern<arith::AddIOp> {
402 using OpConversionPattern<arith::AddIOp>::OpConversionPattern;
403
404 FoldMulAddChainToConvOpPattern(MLIRContext *context, AnalysisManager &am,
405 unsigned shiftParam = 0)
406 : OpConversionPattern<arith::AddIOp>(context), am(am),
408
409 LogicalResult
410 matchAndRewrite(arith::AddIOp srcOp, OpAdaptor adaptor,
411 ConversionPatternRewriter &rewriter) const override {
412 auto &convMacChainAnalysis =
413 am.getChildAnalysis<LongestConvMACChainAnalysis>(srcOp);
414 auto &convMacChain = convMacChainAnalysis.convMacChain;
415 if (!convMacChain)
416 return failure();
417
418 auto loc = srcOp.getLoc();
419 VectorType vecTy = cast<VectorType>(srcOp.getResult().getType());
420 unsigned elemWidth = cast<IntegerType>(vecTy.getElementType()).getWidth();
421 unsigned accWidth = elemWidth <= 8 ? 32 : 64;
422 int32_t M = elemWidth == 8 ? 32 : 16;
423 int32_t N = elemWidth == 8 ? 8 : 4;
424
425 Type wideElemTy = IntegerType::get(getContext(), accWidth);
426 Type accVecTy = VectorType::get(vecTy.getShape(), wideElemTy);
427
428 const auto &groups = convMacChainAnalysis.getGroupsInChain();
429 Value grpAcc = (*convMacChain)[groups[0].fromIdx]->acc;
430 if (grpAcc)
431 grpAcc = rewriter
432 .create<aievec::UPSOp>(srcOp.getLoc(), accVecTy, grpAcc,
433 /*shift=*/0)
434 .getResult();
435 for (const auto &group : groups) {
436 Value grpLhs = (*convMacChain)[group.fromIdx]->lhs;
437 Value grpRhs = (*convMacChain)[group.fromIdx]->rhs;
438 auto filterVecTy = cast<VectorType>(grpRhs.getType());
439 auto signalVecTy = cast<VectorType>(grpLhs.getType());
440 // Sort out the vector used as filter
441 // If the length of the filter is half that of the signal, concatenate
442 // the filter with itself.
443 if (2 * filterVecTy.getShape()[0] == signalVecTy.getShape()[0])
444 grpRhs =
445 rewriter
446 .create<aievec::ConcatOp>(
447 loc, signalVecTy, SmallVector<Value, 2>({grpRhs, grpRhs}))
448 .getResult();
449 // If the filter has duplicate elements, pack them.
450 if (group.bcastDist == 2)
451 // NOTE: This shuffle mode works for `vector<64xi8>`
452 grpRhs = rewriter
453 .create<aievec::ShuffleOp>(loc, signalVecTy, grpRhs,
454 grpRhs, ShuffleMode::T8_64X2_LO)
455 .getResult();
456 // If the first element of the filter to be used is not 0, shift the
457 // filter to align the first element to the beginning.
458 if (group.bcastShift) {
459 int32_t shiftBytes =
460 group.bcastShift * getElementSizeInBits(filterVecTy) >>
461 (3 + group.bcastDist - 1);
462 auto shiftBytesCst =
463 rewriter
464 .create<arith::ConstantOp>(
465 loc, rewriter.getI32IntegerAttr(shiftBytes))
466 .getResult();
467 grpRhs = rewriter
468 .create<aievec::ShiftOp>(grpRhs.getDefiningOp()->getLoc(),
469 signalVecTy, grpRhs, grpRhs,
470 shiftBytesCst)
471 .getResult();
472 }
473 // Sort out the vector used as signal
474 // If the signal to be convolved doesn't start at element 0, shift the
475 // signal to align the first element to the beginning.
476 if (group.signalShift) {
477 int32_t shiftBytes =
478 group.signalShift * getElementSizeInBits(signalVecTy) >> 3;
479 auto shiftBytesCst =
480 rewriter
481 .create<arith::ConstantOp>(
482 loc, rewriter.getI32IntegerAttr(shiftBytes))
483 .getResult();
484 grpLhs = rewriter
485 .create<aievec::ShiftOp>(loc, signalVecTy, grpLhs, grpLhs,
486 shiftBytesCst)
487 .getResult();
488 }
489 // Generate a convolution operation for the group
490 // If there is no upchain accumulator, use a mul_conv; use a mac_conv
491 // otherwise.
492 if (!grpAcc)
493 grpAcc = rewriter
494 .create<aievec::MulConvOp>(srcOp.getLoc(), accVecTy,
495 grpLhs, grpRhs, M, N)
496 .getResult();
497 else
498 grpAcc =
499 rewriter
500 .create<aievec::FMAConvOp>(srcOp.getLoc(), accVecTy, grpLhs,
501 grpRhs, grpAcc, M, N, false)
502 .getResult();
503 }
504
505 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
506 srcOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
507 rewriter.replaceOpWithNewOp<aievec::SRSOp>(srcOp, vecTy, grpAcc,
508 shiftParamOp.getResult());
509 return success();
510 }
511
512 AnalysisManager &am;
513 unsigned shiftParam;
514};
515
516namespace xilinx::aievec {
517
519 AnalysisManager &am,
520 TargetBackend backend) {
522 target.addLegalDialect<AIEVecDialect>();
523 target.addLegalDialect<arith::ArithDialect>();
524 target.addLegalDialect<ub::UBDialect>();
525 target.addDynamicallyLegalOp<arith::AddIOp>([&am](arith::AddIOp op) {
526 auto &convAnalysis = am.getChildAnalysis<LongestConvMACChainAnalysis>(op);
527 return !convAnalysis.canChainBeReplacedWithConvOps();
528 });
529}
530
531void populateAIEVecConvOpTransformationPatterns(RewritePatternSet &patterns,
532 AnalysisManager &am,
533 unsigned shiftParam,
534 TargetBackend backend) {
535 patterns.add<FoldMulAddChainToConvOpPattern>(patterns.getContext(), am,
536 shiftParam);
537}
538
539struct AIEVecConvAnalysis : public AIEVecConvAnalysisBase<AIEVecConvAnalysis> {
544
545 void runOnOperation() override {
546 markAllAnalysesPreserved();
547 AnalysisManager am = getAnalysisManager();
549 Operation *op = getOperation();
550
551 // Compute all the chains
552 op->walk([&](arith::AddIOp addOp) {
553 if (isa<VectorType>(addOp.getResult().getType()))
554 am.getChildAnalysis<LongestConvMACChainAnalysis>(addOp);
555 });
556
557 // Sort the chains, ready to split by group
558 op->walk([&](arith::AddIOp addOp) {
559 if (isa<VectorType>(addOp.getResult().getType())) {
560 auto &analysis =
561 am.getChildAnalysis<LongestConvMACChainAnalysis>(addOp);
562 if (analysis.convMacChain)
563 analysis.sortChain();
564 }
565 });
566
567 if (printResult) {
568 op->walk([&](arith::AddIOp addOp) {
569 if (isa<VectorType>(addOp.getResult().getType())) {
570 auto &macChainAnalysis =
571 am.getChildAnalysis<LongestConvMACChainAnalysis>(addOp);
572 if (macChainAnalysis.canChainBeReplacedWithConvOps()) {
573 addOp.print(llvm::outs());
574 llvm::outs() << " is at the end of a convolution MAC Chain:\n";
575 listChain(macChainAnalysis.convMacChain,
576 macChainAnalysis.getGroupsInChain());
577 }
578 }
579 });
580 }
581 }
582
583 void listChain(const std::unique_ptr<ConvMacChain> &chain,
584 const ConvMacChainGroupList &groups) const {
585 uint64_t gIdx = 0;
586 for (const auto &group : groups) {
587 llvm::outs() << "-------------- GROUP " << std::to_string(gIdx)
588 << " --------------\n";
589 llvm::outs() << " Signal Shift: " << std::to_string(group.signalShift)
590 << " Kernel Shift: " << std::to_string(group.bcastShift)
591 << " Kernel Duplication: "
592 << std::to_string(group.bcastDist) << "\n";
593 for (uint64_t i = group.fromIdx; i < group.toIdx; i++) {
594 auto shift = (*chain)[i]->shift;
595 auto bcastIdx = (*chain)[i]->bcastIdx;
596 auto lhsOp = (*chain)[i]->lhs.getDefiningOp();
597 auto rhsOp = (*chain)[i]->rhs.getDefiningOp();
598 if (!(*chain)[i]->acc)
599 llvm::outs() << " [mul_conv]\n";
600 llvm::outs() << " [Shift: " << std::to_string(shift) << "]: ";
601 lhsOp->print(llvm::outs());
602 llvm::outs() << "\n [Bcast: " << std::to_string(bcastIdx) << "]: ";
603 rhsOp->print(llvm::outs());
604 llvm::outs() << "\n";
605 }
606 gIdx++;
607 }
608 llvm::outs() << "-------------------------------------\n";
609 }
610};
611
612std::unique_ptr<Pass> createAIEVecConvolutionAnalysisPass() {
613 return std::make_unique<AIEVecConvAnalysis>();
614}
615
616} // namespace xilinx::aievec
std::unique_ptr< mlir::Pass > createAIEVecConvolutionAnalysisPass()
void populateAIEVecConvOpTransformationPatterns(RewritePatternSet &patterns, AnalysisManager &am, unsigned shiftParam, TargetBackend backend)
int32_t getElementSizeInBits(mlir::VectorType type)
Definition AIEVecUtils.h:49
void configureAIEVecConvOpTransformationLegalizations(ConversionTarget &target, AnalysisManager &am, TargetBackend backend)
TargetBackend
Definition Passes.h:27
LogicalResult matchAndRewrite(arith::AddIOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
FoldMulAddChainToConvOpPattern(MLIRContext *context, AnalysisManager &am, unsigned shiftParam=0)
ConvMac(Value lhs, Value rhs, uint8_t shift, uint8_t bcastIdx)
This analysis builds the longest possible chain of MAC operations whose operands are a vector that ma...
SmallVector< ConvMacChainGroup, 8 > ConvMacChainGroupList
const ConvMacChainGroupList & getGroupsInChain()
void sortChain()
Sort the chain of MACs by sources.
int64_t getGroupBcastShift(uint64_t fromIdx, uint64_t toIdx)
std::unique_ptr< ConvMacChain > convMacChain
LongestConvMACChainAnalysis(arith::AddIOp addOp)
std::unique_ptr< ConvMac > getConvMacFromAddOp(arith::AddIOp addOp)
SmallVector< std::unique_ptr< ConvMac >, 8 > ConvMacChain
int64_t getGroupSignalShift(uint64_t fromIdx, uint64_t toIdx)
std::unique_ptr< ConvMac > getConvMacFromMulOp(arith::MulIOp mulOp)
int64_t getGroupBcastDist(uint64_t fromIdx, uint64_t toIdx)
LongestConvMACChainAnalysis::ConvMacChain ConvMacChain
void listChain(const std::unique_ptr< ConvMacChain > &chain, const ConvMacChainGroupList &groups) const
LongestConvMACChainAnalysis::ConvMacChainGroupList ConvMacChainGroupList