MLIR-AIE
AIEVecOptimizations.cpp
Go to the documentation of this file.
1//===- AIEVecOptimizations.cpp - Patterns to optimize AIEVec ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2023, Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10// This file contains conversions and rewrites that replace common AIEVec ops
11// with more complex and more performant AIEVec ops.
12//===----------------------------------------------------------------------===//
13
15
21
22#include "mlir/Dialect/Affine/IR/AffineOps.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/MemRef/IR/MemRef.h"
25#include "mlir/Dialect/SCF/IR/SCF.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/Pass/PassManager.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
30#include "mlir/Transforms/Passes.h"
31
32#define DEBUG_TYPE "aievec-optimize"
33
34using namespace llvm;
35using namespace mlir;
36using namespace arith;
37using namespace vector;
38using namespace xilinx;
39using namespace xilinx::aievec;
40
41//===----------------------------------------------------------------------===//
42// Utility functions
43//===----------------------------------------------------------------------===//
44namespace xilinx {
45namespace aievec {
46
47SmallVector<NamedAttribute>
48buildFMAOpSplatAttrForElemTy(aievec::aie1::FMAOp fmaOp, int64_t bcastPos,
49 int64_t step = 1);
50
51} // namespace aievec
52} // namespace xilinx
53
54static bool canFoldAIEShiftAndBroadcast(aievec::BroadcastOp op,
55 aievec::ShiftOp &shiftOp,
56 int32_t &idx) {
57 if (!op.getSource().getDefiningOp())
58 return false;
59
60 shiftOp = dyn_cast<aievec::ShiftOp>(op.getSource().getDefiningOp());
61
62 if (!shiftOp)
63 return false;
64
65 VectorType vType = cast<VectorType>(shiftOp->getResult(0).getType());
66 int32_t elemSize = getElementSizeInBits(vType);
67 auto constOp = cast<arith::ConstantOp>(shiftOp.getShift().getDefiningOp());
68 int32_t shiftBytes = cast<IntegerAttr>(constOp.getValue()).getInt();
69 idx = shiftBytes * 8 / elemSize + op.getIdx();
70
71 if (idx <= 0 || idx >= (int32_t)getVectorLaneSize(vType)) {
72 return false;
73 }
74
75 return true;
76}
77
78template <typename AIEv1MACLikeOp,
79 typename = std::enable_if_t<
80 std::is_same_v<AIEv1MACLikeOp, aievec::aie1::FMAOp> ||
81 std::is_same_v<AIEv1MACLikeOp, aievec::aie1::FMAOp::Adaptor>>>
82static bool isSingleColumnInt16VectorTimesScalarMac(AIEv1MACLikeOp fmaOp) {
83 // lhs is a 32xi16 vector
84 VectorType lhsVTy = cast<VectorType>(fmaOp.getLhs().getType());
85 auto intTy = dyn_cast<IntegerType>(lhsVTy.getElementType());
86 if (!intTy || intTy.getWidth() != 16)
87 return false;
88 if (lhsVTy.getShape()[0] != 32)
89 return false;
90 // Attributes match a Vector x Scalar mac
91 if (fmaOp.getXoffsets() != "0x73727170" ||
92 fmaOp.getXoffsetsHi() != "0x77767574" || fmaOp.getXstart() != "0" ||
93 fmaOp.getXsquare() != "0x3120" || fmaOp.getZoffsets() != "0" ||
94 fmaOp.getZoffsetsHi() != "0" || fmaOp.getZstep() != "1")
95 return false;
96 // lhs op is a concat of a vector and a dense<0> constant vector
97 if (!fmaOp.getLhs().getDefiningOp())
98 return false;
99 aievec::ConcatOp concatOp =
100 dyn_cast<aievec::ConcatOp>(fmaOp.getLhs().getDefiningOp());
101 if (!concatOp)
102 return false;
103 auto tailVec = concatOp.getSources()[1];
104 if (!tailVec.getDefiningOp())
105 return false;
106 auto constOp = dyn_cast<arith::ConstantOp>(tailVec.getDefiningOp());
107 if (!constOp)
108 return false;
109 auto cstDense = dyn_cast<DenseIntElementsAttr>(constOp.getValue());
110 if (!cstDense)
111 return false;
112 return llvm::all_of(cstDense, [](const APInt &val) { return val == 0; });
113}
114
115static bool singleColumnFMAOpCanFold(aievec::aie1::FMAOp fmaOp) {
116 auto accProdOp = fmaOp.getAcc().getDefiningOp();
117 if (!accProdOp)
118 return false;
119 auto accFmaOp = dyn_cast<aievec::aie1::FMAOp>(accProdOp);
120 if (!accFmaOp)
121 return false;
122 if (!isSingleColumnInt16VectorTimesScalarMac(accFmaOp))
123 return false;
124 return fmaOp.getRhs() == accFmaOp.getRhs() &&
125 !singleColumnFMAOpCanFold(accFmaOp);
126}
127
128//===----------------------------------------------------------------------===//
129// Lowering patterns
130//===----------------------------------------------------------------------===//
132 : public OpConversionPattern<aievec::aie1::FMAOp> {
133 using OpConversionPattern<aievec::aie1::FMAOp>::OpConversionPattern;
134
135 LogicalResult
136 matchAndRewrite(aievec::aie1::FMAOp fmaOp, OpAdaptor adaptor,
137 ConversionPatternRewriter &rewriter) const override {
138 if (!isSingleColumnInt16VectorTimesScalarMac(adaptor))
139 return failure();
140 auto accProdOp = adaptor.getAcc().getDefiningOp();
141 if (!accProdOp)
142 return failure();
143 auto accFmaOp = dyn_cast<aievec::aie1::FMAOp>(accProdOp);
144 if (!accFmaOp)
145 return failure();
146 if (!isSingleColumnInt16VectorTimesScalarMac(accFmaOp))
147 return failure();
148 if (adaptor.getRhs() != accFmaOp.getRhs())
149 return failure();
150 auto accConcatOp =
151 cast<aievec::ConcatOp>(accFmaOp.getLhs().getDefiningOp());
152 auto fmaConcatOp = cast<aievec::ConcatOp>(adaptor.getLhs().getDefiningOp());
153 unsigned fmaZstart, accFmaZstart;
154 if (adaptor.getZstart().getAsInteger(10, fmaZstart) ||
155 accFmaOp.getZstart().getAsInteger(10, accFmaZstart))
156 return failure();
157 auto start = std::min(fmaZstart, accFmaZstart);
158 auto step = std::max(fmaZstart, accFmaZstart) - start;
159 auto lowV = accConcatOp.getSources()[0];
160 auto hiV = fmaConcatOp.getSources()[0];
161 if (accFmaZstart > fmaZstart)
162 std::swap(lowV, hiV);
163 auto newConcatOp = rewriter.create<aievec::ConcatOp>(
164 fmaOp.getLoc(), adaptor.getLhs().getType(),
165 SmallVector<Value, 2>({lowV, hiV}));
166 auto newFmaOpAttr = buildFMAOpSplatAttrForElemTy(fmaOp, start, step);
167 rewriter.replaceOpWithNewOp<aievec::aie1::FMAOp>(
168 fmaOp, TypeRange({fmaOp.getResult().getType()}),
169 ValueRange({newConcatOp, adaptor.getRhs(), accFmaOp.getAcc()}),
170 newFmaOpAttr);
171 return success();
172 }
173};
174
176 : public OpConversionPattern<aievec::BroadcastOp> {
177 using OpConversionPattern<aievec::BroadcastOp>::OpConversionPattern;
178
179 LogicalResult
180 matchAndRewrite(aievec::BroadcastOp bcastOp, OpAdaptor adaptor,
181 ConversionPatternRewriter &rewriter) const override {
182 aievec::ShiftOp shiftOp = nullptr;
183 int32_t idx = 0;
184
185 if (!canFoldAIEShiftAndBroadcast(bcastOp, shiftOp, idx)) {
186 return failure();
187 }
188
189 VectorType resultType = cast<VectorType>(bcastOp.getResult().getType());
190
191 rewriter.replaceOpWithNewOp<aievec::BroadcastOp>(bcastOp, resultType,
192 shiftOp.getLhs(), idx);
193
194 return success();
195 }
196};
197
198//===----------------------------------------------------------------------===//
199// Pattern collection
200//===----------------------------------------------------------------------===//
201static void populateAIEVecV1TransformationPatterns(RewritePatternSet &patterns,
202 TargetBackend backend) {
203 patterns.add<MergeSingleColumnI16FMAOpPattern>(patterns.getContext());
204}
205
206static void populateAIEVecV2TransformationPatterns(RewritePatternSet &patterns,
207 TargetBackend backend) {
208 patterns.add<FoldAIEShiftAndBroadcast>(patterns.getContext());
209}
210
211//===----------------------------------------------------------------------===//
212// Legalizations
213//===----------------------------------------------------------------------===//
214
215static void
216configureAIEVecV1TransformationLegalizations(ConversionTarget &target,
217 TargetBackend backend) {
218 target.addLegalDialect<aievec::AIEVecDialect,
219 aievec::aie1::AIEVecAIE1Dialect>();
220 target.addDynamicallyLegalOp<aievec::aie1::FMAOp>(
221 [](aievec::aie1::FMAOp fmaOp) {
222 if (isSingleColumnInt16VectorTimesScalarMac(fmaOp))
223 return !singleColumnFMAOpCanFold(fmaOp);
224 return true;
225 });
226}
227
228static void
229configureAIEVecV2TransformationLegalizations(ConversionTarget &target,
230 TargetBackend backend) {
231 target.addDynamicallyLegalOp<xilinx::aievec::BroadcastOp>(
232 [](xilinx::aievec::BroadcastOp op) {
233 aievec::ShiftOp shiftOp = nullptr;
234 int32_t idx = 0;
235 return !canFoldAIEShiftAndBroadcast(op, shiftOp, idx);
236 });
237}
238
239//===----------------------------------------------------------------------===//
240// Lowering passes
241//===----------------------------------------------------------------------===//
243 : public PassWrapper<AIEVecTransformationPass, OperationPass<>> {
244 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AIEVecTransformationPass)
245
249
255
256 // In case we want to register this pass as a standalone pass for test
257 // purposes.
258 StringRef getArgument() const final { return "test-aievec-optimize"; }
259 StringRef getDescription() const final {
260 return "Optimize groups of simple aievec ops into complex aievec ops.";
261 }
262 void getDependentDialects(DialectRegistry &registry) const override {
263 // TODO: Review list of dependent dialects.
264 registry.insert<affine::AffineDialect, xilinx::aievec::AIEVecDialect,
265 aievec::aie1::AIEVecAIE1Dialect, arith::ArithDialect,
266 memref::MemRefDialect, scf::SCFDialect,
267 vector::VectorDialect>();
268 }
269
270 Option<std::string> aieTarget{
271 *this, "aie-target",
272 llvm::cl::desc("Select AIE version: \"aie\" or \"aie2\". This will "
273 "determine the vector size and available operations."),
274 llvm::cl::init("aie")};
275
276 Option<std::string> targetBackend{
277 *this, "target-backend",
278 llvm::cl::desc("Select translation backend: \"cpp\" or \"llvmir\". This "
279 "will determine the aievec operations used to convert "
280 "from vector dialect."),
281 llvm::cl::init("cpp")};
282
283 void runOnOperation() override {
284 auto op = getOperation();
285 MLIRContext *context = &getContext();
286 RewritePatternSet patterns(context);
287 ConversionTarget target(*context);
288 AIEArch aieVersion = AIEArch::AIE;
289 if (!aieTarget.empty()) {
290 std::string target = aieTarget;
291 if (target == "aieml" || target == "aie2") {
292 aieVersion = AIEArch::AIE2;
293 } else if (target != "aie") {
294 op->emitError() << "unknown AIE target '" << aieTarget << "'";
295 signalPassFailure();
296 return;
297 }
298 }
299
300 TargetBackend backend = TargetBackend::CPP;
301 if (!targetBackend.empty()) {
302 std::string backendStr = targetBackend;
303 if (backendStr == "llvmir") {
304 backend = TargetBackend::LLVMIR;
305 if (aieVersion == AIEArch::AIE) {
306 op->emitError() << "targetting LLVM IR is not supported for AIEv1";
307 signalPassFailure();
308 return;
309 }
310 } else if (backendStr != "cpp") {
311 op->emitError() << "unknown target backend'" << targetBackend << "'";
312 signalPassFailure();
313 return;
314 }
315 }
316
317 if (aieVersion == AIEArch::AIE) {
318 populateAIEVecV1TransformationPatterns(patterns, backend);
319 configureAIEVecV1TransformationLegalizations(target, backend);
320 } else {
321 populateAIEVecV2TransformationPatterns(patterns, backend);
322 configureAIEVecV2TransformationLegalizations(target, backend);
323 }
324
325 if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
326 signalPassFailure();
327 }
328 }
329};
330
331static std::unique_ptr<::mlir::Pass>
332createAIEVecTransformationPass(const OptimizeAIEVecOptions &options) {
333 return std::make_unique<AIEVecTransformationPass>(options);
334}
335
337 : public PassWrapper<AIEVecConvOpTransformationPass, OperationPass<>> {
338 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AIEVecConvOpTransformationPass)
339
343
350
351 // In case we want to register this pass as a standalone pass for test
352 // purposes.
353 StringRef getArgument() const final {
354 return "test-aievec-convolution-optimize";
355 }
356 StringRef getDescription() const final {
357 return "Optimize chains of macs into AIE2 conv ops.";
358 }
359 void getDependentDialects(DialectRegistry &registry) const override {
360 // TODO: Review list of dependent dialects.
361 registry.insert<affine::AffineDialect, xilinx::aievec::AIEVecDialect,
362 aievec::aie1::AIEVecAIE1Dialect, arith::ArithDialect,
363 memref::MemRefDialect, scf::SCFDialect,
364 vector::VectorDialect>();
365 }
366
367 Option<std::string> aieTarget{
368 *this, "aie-target",
369 llvm::cl::desc("Select AIE version: \"aie\" or \"aie2\". This will "
370 "determine the vector size and available operations."),
371 llvm::cl::init("aie")};
372
373 Option<std::string> targetBackend{
374 *this, "target-backend",
375 llvm::cl::desc("Select translation backend: \"cpp\" or \"llvmir\". This "
376 "will determine the aievec operations used to convert "
377 "from vector dialect."),
378 llvm::cl::init("cpp")};
379
380 Option<unsigned> shiftParam{
381 *this, "shift",
382 llvm::cl::desc("Shift parameter for rounding and saturation."),
383 llvm::cl::init(0)};
384
385 void runOnOperation() override {
386 auto op = getOperation();
387 MLIRContext *context = &getContext();
388 RewritePatternSet patterns(context);
389 ConversionTarget target(*context);
390 AIEArch aieVersion = AIEArch::AIE;
391 if (!aieTarget.empty()) {
392 std::string target = aieTarget;
393 if (target == "aieml" || target == "aie2") {
394 aieVersion = AIEArch::AIE2;
395 } else if (target != "aie") {
396 op->emitError() << "unknown AIE target '" << aieTarget << "'";
397 signalPassFailure();
398 return;
399 }
400 }
401
402 TargetBackend backend = TargetBackend::CPP;
403 if (!targetBackend.empty()) {
404 std::string backendStr = targetBackend;
405 if (backendStr == "llvmir") {
406 backend = TargetBackend::LLVMIR;
407 if (aieVersion == AIEArch::AIE) {
408 op->emitError() << "targetting LLVM IR is not supported for AIEv1";
409 signalPassFailure();
410 return;
411 }
412 } else if (backendStr != "cpp") {
413 op->emitError() << "unknown target backend'" << targetBackend << "'";
414 signalPassFailure();
415 return;
416 }
417 }
418
419 AnalysisManager am = getAnalysisManager();
420 if (aieVersion == AIEArch::AIE2) {
422 backend);
424 }
425
426 if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
427 signalPassFailure();
428 }
429 }
430};
431
432static std::unique_ptr<::mlir::Pass>
433createAIEVecConvOpTransformationPass(const OptimizeAIEVecOptions &options) {
434 return std::make_unique<AIEVecConvOpTransformationPass>(options);
435}
436
437//============================================================================//
438//=============== Main AIEVec2AIEVec Pipeline Configuration ==================//
439//============================================================================//
440
441void xilinx::aievec::buildOptimizeAIEVec(OpPassManager &pm,
442 const OptimizeAIEVecOptions &options) {
443 // Add AIEVec transformation pass.
444 pm.addPass(createAIEVecTransformationPass(options));
445
446 pm.addPass(createCSEPass());
447 pm.addPass(createCanonicalizerPass());
448
449 // Add generating aievec convolution ops pass
450 if (options.aieTarget == "aieml" || options.aieTarget == "aie2") {
452 pm.addPass(createAIEVecConvOpTransformationPass(options));
453 }
454
455 // Add post-lowering canonicalization passes.
456 pm.addPass(createCSEPass());
457 pm.addPass(createCanonicalizerPass());
458}
unsigned getVectorLaneSize(mlir::VectorType type)
Definition AIEVecUtils.h:55
SmallVector< NamedAttribute > buildFMAOpSplatAttrForElemTy(aievec::aie1::FMAOp fmaOp, int64_t bcastPos, int64_t step=1)
std::unique_ptr< mlir::Pass > createAIEVecConvolutionAnalysisPass()
void populateAIEVecConvOpTransformationPatterns(RewritePatternSet &patterns, AnalysisManager &am, unsigned shiftParam, TargetBackend backend)
int32_t getElementSizeInBits(mlir::VectorType type)
Definition AIEVecUtils.h:49
void configureAIEVecConvOpTransformationLegalizations(ConversionTarget &target, AnalysisManager &am, TargetBackend backend)
void buildOptimizeAIEVec(mlir::OpPassManager &pm, const OptimizeAIEVecOptions &options)
TargetBackend
Definition Passes.h:27
AIEArch
Definition Passes.h:21
AIEVecConvOpTransformationPass(const OptimizeAIEVecOptions &options)
void getDependentDialects(DialectRegistry &registry) const override
StringRef getDescription() const final
AIEVecTransformationPass(const OptimizeAIEVecOptions &options)
void getDependentDialects(DialectRegistry &registry) const override
StringRef getDescription() const final
StringRef getArgument() const final
Option< std::string > aieTarget
Option< std::string > targetBackend
LogicalResult matchAndRewrite(aievec::BroadcastOp bcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::aie1::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Options for the "optimize-aievec" pipeline.
Definition Passes.h:73
PassOptions::Option< unsigned > shiftParam
Definition Passes.h:85
PassOptions::Option< std::string > targetBackend
Definition Passes.h:79
PassOptions::Option< std::string > aieTarget
Definition Passes.h:74