MLIR-AIE
FoldMulAddChainToConvOp.cpp
Go to the documentation of this file.
1//===--FoldMulAddChainToConvOp.cpp - Fold Mul Add Chain To AIEVec Conv Op--===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2023 Xilinx Inc.
8//
9//===----------------------------------------------------------------------===//
10// This is the implementation of the folding pass from mul add chain
11// to AIEVec convolution operations, compatible with the AIE2 architecture.
12//===----------------------------------------------------------------------===//
13
15
20#include "mlir/Analysis/SliceAnalysis.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/Support/Debug.h"
24
25#define DEBUG_TYPE "fold-mul-add-chain-to-conv"
26
27using namespace mlir;
28using namespace arith;
29using namespace vector;
30using namespace xilinx;
31using namespace xilinx::aievec;
32
33namespace xilinx::aievec {
34#define GEN_PASS_DEF_AIEVECCONVANALYSIS
35#include "aie/Dialect/AIEVec/Analysis/Passes.h.inc"
36} // namespace xilinx::aievec
37
38/// This analysis builds the longest possible chain of MAC operations whose
39/// operands are a vector that may or may not be shifted, and a broadcast.
40/// That is, these MACs represent `vector x scalar` ops, and are candidates to
41/// be grouped and replaced by mul_conv/fma_conv ops in AIE2.
42//
43// We build this chain recursively, climbing up the
45 static AnalysisManager *am;
46
47 struct ConvMac {
48 // If there's a non-accumulating convolution upchain,
49 // store it here temporarily.
50 std::unique_ptr<ConvMac> topOfChainMulConv;
51 // Accumulator value, if there is one.
52 Value acc;
53 // Left-hand side (non-broadcasting) source value
54 Value lhs;
55 // Right-hand side (broadcasting) source value
56 Value rhs;
57 // Amount that lhs is shifted
58 uint8_t shift;
59 // Element in rhs that is broadcasted
60 uint8_t bcastIdx;
61 ConvMac(Value lhs, Value rhs, uint8_t shift, uint8_t bcastIdx)
62 : topOfChainMulConv(nullptr), acc(nullptr), lhs(lhs), rhs(rhs),
64 };
65
67 // Group start index within the chain
68 uint64_t fromIdx;
69 // Index in chain after group last MAC
70 uint64_t toIdx;
71 // Initial position of the signal to be convolved
72 int64_t signalShift;
73 // Initial position of the convolution filter
74 int64_t bcastShift;
75 // Distance between elements in the filter
76 int64_t bcastDist; // Must be 1 or 2
77 };
78
79 typedef SmallVector<std::unique_ptr<ConvMac>, 8> ConvMacChain;
80 typedef SmallVector<ConvMacChainGroup, 8> ConvMacChainGroupList;
81
82 std::unique_ptr<ConvMacChain> convMacChain;
84
85 /// Sort the chain of MACs by sources. When two MACs share the same sources,
86 /// sort them by the broadcast index. If they don't, sort them by the order
87 /// of the ops in the code. This function should be called after the chain
88 /// is completed, and before operating on the groups of MACs. After sorting,
89 /// MACs that can be fused into single convolution ops will be contiguous in
90 /// the chain.
91 void sortChain() {
92 if ((*convMacChain)[0]->acc) {
93 std::sort(convMacChain->begin(), convMacChain->end(),
94 [](const auto &a, const auto &b) {
95 if (a->lhs == b->lhs) {
96 if (a->rhs == b->rhs)
97 return a->bcastIdx < b->bcastIdx;
98 return a->rhs.getDefiningOp()->isBeforeInBlock(
99 b->rhs.getDefiningOp());
100 }
101 // We should probably sort by lhs load address, if it exists
102 // XXX: We assume all MACs in the same block. If they're not,
103 // XXX: this will assert.
104 return a->lhs.getDefiningOp()->isBeforeInBlock(
105 b->lhs.getDefiningOp());
106 });
107 } else {
108 // If the top of the chain is not an accumulation, bring up all related
109 // convolution MACs and sort the rest by lhs.
110 auto firstLhs = (*convMacChain)[0]->lhs;
111 std::sort(convMacChain->begin(), convMacChain->end(),
112 [&firstLhs](const auto &a, const auto &b) {
113 if (a->lhs == b->lhs) {
114 if (a->rhs == b->rhs)
115 return a->bcastIdx < b->bcastIdx;
116 return a->rhs.getDefiningOp()->isBeforeInBlock(
117 b->rhs.getDefiningOp());
118 }
119 if (a->lhs == firstLhs)
120 return true;
121 if (b->lhs == firstLhs)
122 return false;
123 return a->lhs.getDefiningOp()->isBeforeInBlock(
124 b->lhs.getDefiningOp());
125 });
126 // Float the empty accumulator to the top.
127 if ((*convMacChain)[0]->acc)
128 for (auto &convMac : *convMacChain)
129 if (!convMac->acc) {
130 std::swap((*convMacChain)[0]->acc, convMac->acc);
131 break;
132 }
133 }
134 }
135
136 // Return the list of convolution mac ops in the chain as pairs of indices
137 // indicating the position within the chain where a group starts and the
138 // position where it ends: [start, end). If they have not been precomputed
139 // yet, this method will generate them.
141 // If there's no group or it's been computed already, return stored list.
142 if (groupsInChain.size() > 0 || !convMacChain || convMacChain->size() == 0)
143 return groupsInChain;
144
145 uint64_t grpStartIdx = 0;
146 uint64_t grpCurIdx = 0;
147 Value curLhs = (*convMacChain)[0]->lhs;
148 Value curRhs = (*convMacChain)[0]->rhs;
149 for (const auto &convMac : *convMacChain) {
150 if (grpCurIdx > grpStartIdx) {
151 if (curLhs != convMac->lhs || curRhs != convMac->rhs) {
152 groupsInChain.push_back({grpStartIdx, grpCurIdx,
153 getGroupSignalShift(grpStartIdx, grpCurIdx),
154 getGroupBcastShift(grpStartIdx, grpCurIdx),
155 getGroupBcastDist(grpStartIdx, grpCurIdx)});
156 grpStartIdx = grpCurIdx;
157 curLhs = convMac->lhs;
158 curRhs = convMac->rhs;
159 }
160 }
161 grpCurIdx++;
162 }
163 if (grpStartIdx < grpCurIdx)
164 groupsInChain.push_back({grpStartIdx, grpCurIdx,
165 getGroupSignalShift(grpStartIdx, grpCurIdx),
166 getGroupBcastShift(grpStartIdx, grpCurIdx),
167 getGroupBcastDist(grpStartIdx, grpCurIdx)});
168 return groupsInChain;
169 }
170
171 // Return the signal shift for the group in the MAC chain in [fromIdx, toIdx)
172 // the top. This method verifies that the elements of the signal are
173 // contiguously accessed. If they do not, or the specified group doesn't
174 // exist, this function returns -1.
175 int64_t getGroupSignalShift(uint64_t fromIdx, uint64_t toIdx) {
176 if (fromIdx >= toIdx || toIdx > convMacChain->size())
177 return -1;
178 if (toIdx == fromIdx + 1)
179 return static_cast<int64_t>((*convMacChain)[fromIdx]->shift);
180 for (uint64_t i = fromIdx; i < toIdx - 1; i++)
181 if ((static_cast<int64_t>((*convMacChain)[i + 1]->shift) -
182 static_cast<int64_t>((*convMacChain)[i]->shift)) != 1)
183 return -1;
184 return static_cast<int64_t>((*convMacChain)[fromIdx]->shift);
185 }
186
187 // Return the shift in value of the first broadcasted element in the i-th
188 // group. If there is no chain, or the i-th group does not exist,
189 // returns -1.
190 int64_t getGroupBcastShift(uint64_t fromIdx, uint64_t toIdx) {
191 if (fromIdx >= toIdx || toIdx > convMacChain->size())
192 return -1;
193 return static_cast<int64_t>((*convMacChain)[fromIdx]->bcastIdx);
194 }
195
196 // Returns the broadcast distance between elements within the group. If the
197 // distance is not constant and equal to 1 or 2, it returns -1.
198 int64_t getGroupBcastDist(uint64_t fromIdx, uint64_t toIdx) {
199 if (fromIdx >= toIdx || toIdx > convMacChain->size())
200 return -1;
201 if (toIdx == fromIdx + 1)
202 return 1;
203 int64_t bcastDist =
204 static_cast<int64_t>((*convMacChain)[fromIdx + 1]->bcastIdx) -
205 static_cast<int64_t>((*convMacChain)[fromIdx]->bcastIdx);
206 if (bcastDist != 1 && bcastDist != 2)
207 return -1;
208 for (uint64_t i = fromIdx + 1; i < toIdx - 1; i++)
209 if ((static_cast<int64_t>((*convMacChain)[i + 1]->bcastIdx) -
210 static_cast<int64_t>((*convMacChain)[i]->bcastIdx)) != bcastDist)
211 return -1;
212 return bcastDist;
213 }
214
216 const auto &groups = getGroupsInChain();
217 if (groups.size() == 0)
218 return false;
219 for (const auto &group : groups)
220 if (group.signalShift == -1 || group.bcastShift == -1 ||
221 group.bcastDist == -1)
222 return false;
223 return true;
224 }
225
226 std::unique_ptr<ConvMac> getConvMacFromMulOp(arith::MulIOp mulOp) {
227 auto mulOpLhsDefOp = mulOp.getLhs().getDefiningOp();
228 auto mulOpRhsDefOp = mulOp.getRhs().getDefiningOp();
229 if (!mulOpLhsDefOp || !mulOpRhsDefOp)
230 return nullptr;
231
232 Value convMacRhs = nullptr;
233 uint8_t convMacBcastIdx = 0;
234
235 auto getConvMacRhs = [&](Operation *mulOpOperand) -> bool {
236 SetVector<Operation *> opBwdSlices;
237 auto opFilter = [](Operation *op) {
238 return isa<aievec::BroadcastOp>(op) || isa<aievec::ExtOp>(op) ||
239 isa<aievec::ConcatOp>(op);
240 };
241 BackwardSliceOptions backwardSliceOptions;
242 backwardSliceOptions.filter = opFilter;
243
244 getBackwardSlice(mulOpOperand, &opBwdSlices, backwardSliceOptions);
245 opBwdSlices.insert(mulOpOperand);
246
247 LLVM_DEBUG(llvm::dbgs() << "opBwdSlices = [\n");
248 for ([[maybe_unused]] auto op : opBwdSlices) {
249 LLVM_DEBUG(llvm::dbgs() << *op << "\n");
250 }
251 LLVM_DEBUG(llvm::dbgs() << "]\n");
252
253 if (opBwdSlices.size() == 1) {
254 if (auto bcastOp = dyn_cast<aievec::BroadcastOp>(opBwdSlices[0])) {
255 convMacRhs = bcastOp.getSource();
256 convMacBcastIdx = bcastOp.getIdx();
257 return true;
258 }
259 } else if (opBwdSlices.size() >= 3) {
260 auto sliceSz = opBwdSlices.size();
261 if ((isa<aievec::ExtOp>(opBwdSlices[sliceSz - 3]) &&
262 isa<aievec::BroadcastOp>(opBwdSlices[sliceSz - 2]) &&
263 isa<aievec::ConcatOp>(opBwdSlices[sliceSz - 1])) ||
264 (isa<aievec::ConcatOp>(opBwdSlices[sliceSz - 3]) &&
265 isa<aievec::BroadcastOp>(opBwdSlices[sliceSz - 2]) &&
266 isa<aievec::ExtOp>(opBwdSlices[sliceSz - 1]))) {
267 convMacRhs = opBwdSlices[sliceSz - 3]->getOperand(0);
268 convMacBcastIdx =
269 dyn_cast<aievec::BroadcastOp>(opBwdSlices[sliceSz - 2]).getIdx();
270 return true;
271 }
272 }
273
274 return false;
275 };
276
277 // Obtain the broadcast operation feeding into the MulIOp
278 if (!getConvMacRhs(mulOpRhsDefOp)) {
279 if (getConvMacRhs(mulOpLhsDefOp)) {
280 std::swap(mulOpLhsDefOp, mulOpRhsDefOp);
281 }
282 }
283 if (!convMacRhs)
284 return nullptr;
285
286 // Obtain the ext or ext->shift op feeding into the MulIOp
287 aievec::ExtOp extOp;
288 aievec::ShiftOp shiftOp;
289 shiftOp = dyn_cast<aievec::ShiftOp>(mulOpLhsDefOp);
290 if (shiftOp)
291 extOp = shiftOp.getLhs().getDefiningOp<aievec::ExtOp>();
292 else
293 extOp = dyn_cast<aievec::ExtOp>(mulOpLhsDefOp);
294
295 // XXX: Actually, ExtOp might not exist but should work anyway.
296 // XXX: Should it, though?
297 if (!extOp)
298 return nullptr;
299
300 Value convMacLhs = extOp.getSource();
301 uint8_t shift = 0;
302 if (shiftOp) {
303 auto shiftConstDefOp =
304 shiftOp.getShift().getDefiningOp<arith::ConstantOp>();
305 if (shiftConstDefOp) {
306 auto shiftAttr = cast<IntegerAttr>(shiftConstDefOp.getValue());
307 auto vType = cast<VectorType>(mulOp.getResult().getType());
308 shift = 8 * shiftAttr.getInt() / getElementSizeInBits(vType);
309 }
310 }
311
312 return std::make_unique<ConvMac>(convMacLhs, convMacRhs, shift,
313 convMacBcastIdx);
314 }
315
316 std::unique_ptr<ConvMac> getConvMacFromAddOp(arith::AddIOp addOp) {
317 // Make sure at least one of them is a multiplication, and the other one
318 // is the accumulator coming form upchain.
319 auto mulOp = addOp.getLhs().getDefiningOp<arith::MulIOp>();
320 Value acc = addOp.getRhs();
321 if (!mulOp) {
322 mulOp = addOp.getRhs().getDefiningOp<arith::MulIOp>();
323 acc = addOp.getLhs();
324 }
325 if (!mulOp)
326 return nullptr;
327
328 // Get the parameters of the convolution from the operands of the MulIOp
329 auto convMac = getConvMacFromMulOp(mulOp);
330 if (!convMac)
331 return nullptr;
332
333 // If both sides are MulIOp, we might be at the top of the chain
334 auto upChainAccMulOp = acc.getDefiningOp<arith::MulIOp>();
335 if (upChainAccMulOp) {
336 auto convMac2 = getConvMacFromMulOp(upChainAccMulOp);
337 // XXX: We pre-sort the top two MACs to make sure that an undefined
338 // XXX: accumulator ends up on top of the chain.
339 // XXX: But it might not be necessary? CHECK!
340 if (convMac2 && convMac->lhs == convMac2->lhs &&
341 convMac->rhs == convMac->rhs) {
342 if (convMac->bcastIdx < convMac2->bcastIdx &&
343 convMac->shift < convMac2->shift) {
344 convMac2->topOfChainMulConv = std::move(convMac);
345 convMac2->acc = acc;
346 return convMac2;
347 } else if (convMac->bcastIdx > convMac2->bcastIdx &&
348 convMac->shift > convMac2->shift) {
349 convMac->topOfChainMulConv = std::move(convMac2);
350 convMac->acc = acc;
351 return convMac;
352 } else {
353 // WARNING: In this situation, the chain is ambiguous and picking one
354 // WARNING: option over the other may result in a successful
355 // WARNING: and/or better replacement. Here, we are assuming that
356 // WARNING: is going to be either one or the other, or it won't
357 // WARNING: matter.
358 }
359 } else {
360 convMac->topOfChainMulConv = std::move(convMac2);
361 }
362 }
363 convMac->acc = acc;
364 return convMac;
365 }
366
367 LongestConvMACChainAnalysis(arith::AddIOp addOp) {
368 std::unique_ptr<ConvMac> macConvChainElem = getConvMacFromAddOp(addOp);
369 if (!macConvChainElem)
370 return;
371
372 if (macConvChainElem->acc) {
373 auto upChainAddOp = macConvChainElem->acc.getDefiningOp<arith::AddIOp>();
374 if (upChainAddOp) {
375 auto &upChainChainAnalysis =
376 am->getChildAnalysis<LongestConvMACChainAnalysis>(upChainAddOp);
377 if (upChainChainAnalysis.convMacChain) {
378 convMacChain = std::move(upChainChainAnalysis.convMacChain);
379 convMacChain->push_back(std::move(macConvChainElem));
380 return;
381 }
382 }
383 }
384 assert(!convMacChain && "Convolution MAC chain unexpectedly not empty");
385 convMacChain = std::make_unique<ConvMacChain>();
386 if (macConvChainElem->topOfChainMulConv)
387 convMacChain->push_back(std::move(macConvChainElem->topOfChainMulConv));
388 convMacChain->push_back(std::move(macConvChainElem));
389 }
390};
391// HACK: For some reason, it's not possible to access the analysis manager from
392// HACK: within an analysis, but we need it to build the analysis recursively.
393// HACK: If there is a good reason not to do this, we should find an
394// HACK: alternative way to build the MAC chain.
395AnalysisManager *LongestConvMACChainAnalysis::am = nullptr;
396
397// This conversion pattern folds a MAC chain into mul_conv and mac_conv
398// ops. We can handle the mul MAC with a random order.
400 : public OpConversionPattern<arith::AddIOp> {
401 using OpConversionPattern<arith::AddIOp>::OpConversionPattern;
402
403 FoldMulAddChainToConvOpPattern(MLIRContext *context, AnalysisManager &am,
404 unsigned shiftParam = 0)
405 : OpConversionPattern<arith::AddIOp>(context), am(am),
407
408 LogicalResult
409 matchAndRewrite(arith::AddIOp srcOp, OpAdaptor adaptor,
410 ConversionPatternRewriter &rewriter) const override {
411 auto &convMacChainAnalysis =
412 am.getChildAnalysis<LongestConvMACChainAnalysis>(srcOp);
413 auto &convMacChain = convMacChainAnalysis.convMacChain;
414 if (!convMacChain)
415 return failure();
416
417 auto loc = srcOp.getLoc();
418 VectorType vecTy = cast<VectorType>(srcOp.getResult().getType());
419 unsigned elemWidth = cast<IntegerType>(vecTy.getElementType()).getWidth();
420 unsigned accWidth = elemWidth <= 8 ? 32 : 64;
421 int32_t M = elemWidth == 8 ? 32 : 16;
422 int32_t N = elemWidth == 8 ? 8 : 4;
423
424 Type wideElemTy = IntegerType::get(getContext(), accWidth);
425 Type accVecTy = VectorType::get(vecTy.getShape(), wideElemTy);
426
427 const auto &groups = convMacChainAnalysis.getGroupsInChain();
428 Value grpAcc = (*convMacChain)[groups[0].fromIdx]->acc;
429 if (grpAcc)
430 grpAcc = rewriter
431 .create<aievec::UPSOp>(srcOp.getLoc(), accVecTy, grpAcc,
432 /*shift=*/0)
433 .getResult();
434 for (const auto &group : groups) {
435 Value grpLhs = (*convMacChain)[group.fromIdx]->lhs;
436 Value grpRhs = (*convMacChain)[group.fromIdx]->rhs;
437 auto filterVecTy = cast<VectorType>(grpRhs.getType());
438 auto signalVecTy = cast<VectorType>(grpLhs.getType());
439 // Sort out the vector used as filter
440 // If the length of the filter is half that of the signal, concatenate
441 // the filter with itself.
442 if (2 * filterVecTy.getShape()[0] == signalVecTy.getShape()[0])
443 grpRhs =
444 rewriter
445 .create<aievec::ConcatOp>(
446 loc, signalVecTy, SmallVector<Value, 2>({grpRhs, grpRhs}))
447 .getResult();
448 // If the filter has duplicate elements, pack them.
449 if (group.bcastDist == 2)
450 // NOTE: This shuffle mode works for `vector<64xi8>`
451 grpRhs = rewriter
452 .create<aievec::ShuffleOp>(loc, signalVecTy, grpRhs,
453 grpRhs, ShuffleMode::T8_64X2_LO)
454 .getResult();
455 // If the first element of the filter to be used is not 0, shift the
456 // filter to align the first element to the beginning.
457 if (group.bcastShift) {
458 int32_t shiftBytes =
459 group.bcastShift * getElementSizeInBits(filterVecTy) >>
460 (3 + group.bcastDist - 1);
461 auto shiftBytesCst =
462 rewriter
463 .create<arith::ConstantOp>(
464 loc, rewriter.getI32IntegerAttr(shiftBytes))
465 .getResult();
466 grpRhs = rewriter
467 .create<aievec::ShiftOp>(grpRhs.getDefiningOp()->getLoc(),
468 signalVecTy, grpRhs, grpRhs,
469 shiftBytesCst)
470 .getResult();
471 }
472 // Sort out the vector used as signal
473 // If the signal to be convolved doesn't start at element 0, shift the
474 // signal to align the first element to the beginning.
475 if (group.signalShift) {
476 int32_t shiftBytes =
477 group.signalShift * getElementSizeInBits(signalVecTy) >> 3;
478 auto shiftBytesCst =
479 rewriter
480 .create<arith::ConstantOp>(
481 loc, rewriter.getI32IntegerAttr(shiftBytes))
482 .getResult();
483 grpLhs = rewriter
484 .create<aievec::ShiftOp>(loc, signalVecTy, grpLhs, grpLhs,
485 shiftBytesCst)
486 .getResult();
487 }
488 // Generate a convolution operation for the group
489 // If there is no upchain accumulator, use a mul_conv; use a mac_conv
490 // otherwise.
491 if (!grpAcc)
492 grpAcc = rewriter
493 .create<aievec::MulConvOp>(srcOp.getLoc(), accVecTy,
494 grpLhs, grpRhs, M, N)
495 .getResult();
496 else
497 grpAcc =
498 rewriter
499 .create<aievec::FMAConvOp>(srcOp.getLoc(), accVecTy, grpLhs,
500 grpRhs, grpAcc, M, N, false)
501 .getResult();
502 }
503
504 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
505 srcOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
506 rewriter.replaceOpWithNewOp<aievec::SRSOp>(srcOp, vecTy, grpAcc,
507 shiftParamOp.getResult());
508 return success();
509 }
510
511 AnalysisManager &am;
512 unsigned shiftParam;
513};
514
515namespace xilinx::aievec {
516
518 AnalysisManager &am,
519 TargetBackend backend) {
521 target.addLegalDialect<AIEVecDialect>();
522 target.addLegalDialect<arith::ArithDialect>();
523 target.addDynamicallyLegalOp<arith::AddIOp>([&am](arith::AddIOp op) {
524 auto &convAnalysis = am.getChildAnalysis<LongestConvMACChainAnalysis>(op);
525 return !convAnalysis.canChainBeReplacedWithConvOps();
526 });
527}
528
529void populateAIEVecConvOpTransformationPatterns(RewritePatternSet &patterns,
530 AnalysisManager &am,
531 unsigned shiftParam,
532 TargetBackend backend) {
533 patterns.add<FoldMulAddChainToConvOpPattern>(patterns.getContext(), am,
534 shiftParam);
535}
536
537struct AIEVecConvAnalysis : public AIEVecConvAnalysisBase<AIEVecConvAnalysis> {
542
543 void runOnOperation() override {
544 markAllAnalysesPreserved();
545 AnalysisManager am = getAnalysisManager();
547 Operation *op = getOperation();
548
549 // Compute all the chains
550 op->walk([&](arith::AddIOp addOp) {
551 if (isa<VectorType>(addOp.getResult().getType()))
552 am.getChildAnalysis<LongestConvMACChainAnalysis>(addOp);
553 });
554
555 // Sort the chains, ready to split by group
556 op->walk([&](arith::AddIOp addOp) {
557 if (isa<VectorType>(addOp.getResult().getType())) {
558 auto &analysis =
559 am.getChildAnalysis<LongestConvMACChainAnalysis>(addOp);
560 if (analysis.convMacChain)
561 analysis.sortChain();
562 }
563 });
564
565 if (printResult) {
566 op->walk([&](arith::AddIOp addOp) {
567 if (isa<VectorType>(addOp.getResult().getType())) {
568 auto &macChainAnalysis =
569 am.getChildAnalysis<LongestConvMACChainAnalysis>(addOp);
570 if (macChainAnalysis.canChainBeReplacedWithConvOps()) {
571 addOp.print(llvm::outs());
572 llvm::outs() << " is at the end of a convolution MAC Chain:\n";
573 listChain(macChainAnalysis.convMacChain,
574 macChainAnalysis.getGroupsInChain());
575 }
576 }
577 });
578 }
579 }
580
581 void listChain(const std::unique_ptr<ConvMacChain> &chain,
582 const ConvMacChainGroupList &groups) const {
583 uint64_t gIdx = 0;
584 for (const auto &group : groups) {
585 llvm::outs() << "-------------- GROUP " << std::to_string(gIdx)
586 << " --------------\n";
587 llvm::outs() << " Signal Shift: " << std::to_string(group.signalShift)
588 << " Kernel Shift: " << std::to_string(group.bcastShift)
589 << " Kernel Duplication: "
590 << std::to_string(group.bcastDist) << "\n";
591 for (uint64_t i = group.fromIdx; i < group.toIdx; i++) {
592 auto shift = (*chain)[i]->shift;
593 auto bcastIdx = (*chain)[i]->bcastIdx;
594 auto lhsOp = (*chain)[i]->lhs.getDefiningOp();
595 auto rhsOp = (*chain)[i]->rhs.getDefiningOp();
596 if (!(*chain)[i]->acc)
597 llvm::outs() << " [mul_conv]\n";
598 llvm::outs() << " [Shift: " << std::to_string(shift) << "]: ";
599 lhsOp->print(llvm::outs());
600 llvm::outs() << "\n [Bcast: " << std::to_string(bcastIdx) << "]: ";
601 rhsOp->print(llvm::outs());
602 llvm::outs() << "\n";
603 }
604 gIdx++;
605 }
606 llvm::outs() << "-------------------------------------\n";
607 }
608};
609
610std::unique_ptr<Pass> createAIEVecConvolutionAnalysisPass() {
611 return std::make_unique<AIEVecConvAnalysis>();
612}
613
614} // namespace xilinx::aievec
std::unique_ptr< mlir::Pass > createAIEVecConvolutionAnalysisPass()
void populateAIEVecConvOpTransformationPatterns(RewritePatternSet &patterns, AnalysisManager &am, unsigned shiftParam, TargetBackend backend)
int32_t getElementSizeInBits(mlir::VectorType type)
Definition AIEVecUtils.h:49
void configureAIEVecConvOpTransformationLegalizations(ConversionTarget &target, AnalysisManager &am, TargetBackend backend)
TargetBackend
Definition Passes.h:27
LogicalResult matchAndRewrite(arith::AddIOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
FoldMulAddChainToConvOpPattern(MLIRContext *context, AnalysisManager &am, unsigned shiftParam=0)
ConvMac(Value lhs, Value rhs, uint8_t shift, uint8_t bcastIdx)
This analysis builds the longest possible chain of MAC operations whose operands are a vector that ma...
SmallVector< ConvMacChainGroup, 8 > ConvMacChainGroupList
const ConvMacChainGroupList & getGroupsInChain()
void sortChain()
Sort the chain of MACs by sources.
int64_t getGroupBcastShift(uint64_t fromIdx, uint64_t toIdx)
std::unique_ptr< ConvMacChain > convMacChain
LongestConvMACChainAnalysis(arith::AddIOp addOp)
std::unique_ptr< ConvMac > getConvMacFromAddOp(arith::AddIOp addOp)
SmallVector< std::unique_ptr< ConvMac >, 8 > ConvMacChain
int64_t getGroupSignalShift(uint64_t fromIdx, uint64_t toIdx)
std::unique_ptr< ConvMac > getConvMacFromMulOp(arith::MulIOp mulOp)
int64_t getGroupBcastDist(uint64_t fromIdx, uint64_t toIdx)
LongestConvMACChainAnalysis::ConvMacChain ConvMacChain
void listChain(const std::unique_ptr< ConvMacChain > &chain, const ConvMacChainGroupList &groups) const
LongestConvMACChainAnalysis::ConvMacChainGroupList ConvMacChainGroupList