15#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
16#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
19#include "mlir/Dialect/Index/IR/IndexDialect.h"
20#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21#include "mlir/Dialect/Math/IR/Math.h"
22#include "mlir/Dialect/UB/IR/UBOps.h"
23#include "mlir/Dialect/Vector/IR/VectorOps.h"
24#include "mlir/IR/Attributes.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
29#include "mlir/Transforms/DialectConversion.h"
36static StringRef getArchIntrinsicString(
AIEArch arch) {
45 llvm::report_fatal_error(
"unsupported arch");
48typedef std::tuple<const char *, std::vector<Type>, std::vector<Type>>
52static auto getAIE1Intrinsics(OpBuilder &builder) {
53 Type int32Type = IntegerType::get(builder.getContext(), 32);
54 Type int128Type = IntegerType::get(builder.getContext(), 128);
55 Type int384Type = IntegerType::get(builder.getContext(), 384);
56 Type floatType = Float32Type::get(builder.getContext());
61 {
"debug_i32", {int32Type}, {}},
62 {
"llvm.aie.event0", {}, {}},
63 {
"llvm.aie.event1", {}, {}},
65 {int32Type, int32Type},
68 {int32Type, int128Type},
71 {int32Type, floatType},
73 {
"llvm.aie.get.ss", {int32Type}, {int32Type}},
77 {
"llvm.aie.get.fss", {int32Type}, {floatType}},
78 {
"llvm.aie.put.mcd", {int384Type}, {}},
79 {
"llvm.aie.get.scd", {}, {int384Type}},
80 {
"llvm.aie.lock.acquire.reg",
81 {int32Type, int32Type},
83 {
"llvm.aie.lock.release.reg",
84 {int32Type, int32Type},
90static auto getAIE2Intrinsics(OpBuilder &builder) {
91 Type int32Type = IntegerType::get(builder.getContext(), 32);
92 Type accType = VectorType::get({16}, int32Type);
94 {
"debug_i32", {int32Type}, {}},
95 {
"llvm.aie2.put.ms", {int32Type, int32Type}, {}},
96 {
"llvm.aie2.get.ss", {}, {int32Type, int32Type}},
97 {
"llvm.aie2.mcd.write.vec",
100 {
"llvm.aie2.scd.read.vec",
103 {
"llvm.aie2.acquire",
104 {int32Type, int32Type},
106 {
"llvm.aie2.release",
107 {int32Type, int32Type},
113static auto getAIE2pIntrinsics(OpBuilder &builder) {
114 Type int32Type = IntegerType::get(builder.getContext(), 32);
115 Type accType = VectorType::get({16}, int32Type);
117 {
"debug_i32", {int32Type}, {}},
118 {
"llvm.aie2p.put.ms",
119 {int32Type, int32Type},
121 {
"llvm.aie2p.get.ss",
123 {int32Type, int32Type}},
124 {
"llvm.aie2p.mcd.write.vec",
125 {accType, int32Type},
127 {
"llvm.aie2p.scd.read.vec",
130 {
"llvm.aie2p.acquire",
131 {int32Type, int32Type},
133 {
"llvm.aie2p.release",
134 {int32Type, int32Type},
140static void declareAIEIntrinsics(
AIEArch arch, OpBuilder &builder) {
142 for (
auto &i : functions) {
143 auto [name, argTypes, retTypes] = i;
145 .create<func::FuncOp>(
146 builder.getUnknownLoc(), name,
147 FunctionType::get(builder.getContext(), argTypes, retTypes))
153 registerIntrinsics(getAIE1Intrinsics(builder));
156 registerIntrinsics(getAIE2Intrinsics(builder));
159 registerIntrinsics(getAIE2pIntrinsics(builder));
162 llvm::report_fatal_error(
"unsupported arch");
165template <
typename MyAIEOp>
171 AIEOpRemoval(MLIRContext *context, ModuleOp &m, PatternBenefit benefit = 1)
176 ConversionPatternRewriter &rewriter)
const override {
177 rewriter.eraseOp(op);
183 using OpConversionPattern::OpConversionPattern;
187 PatternBenefit benefit = 1)
192 ConversionPatternRewriter &rewriter)
const override {
193 std::string funcName =
"debug_i32";
194 auto func =
module.lookupSymbol<func::FuncOp>(funcName);
196 return op.emitOpError(
"Could not find the intrinsic function ")
198 SmallVector<Value, 1>
args;
199 args.push_back(op.getArg());
200 rewriter.create<func::CallOp>(rewriter.getUnknownLoc(), func,
args);
201 rewriter.eraseOp(op);
207 using OpConversionPattern::OpConversionPattern;
211 PatternBenefit benefit = 1)
216 ConversionPatternRewriter &rewriter)
const override {
217 auto device = op->getParentOfType<DeviceOp>();
218 const auto &targetModel = device.getTargetModel();
219 std::string funcName;
220 if (targetModel.getTargetArch() == AIEArch::AIE1)
221 funcName =
"llvm.aie.put.";
222 else if (targetModel.getTargetArch() == AIEArch::AIE2)
223 funcName =
"llvm.aie2.put.";
225 funcName =
"llvm.aie2p.put.";
227 if (op.isWideStream())
229 else if (op.isFloatStream())
234 auto putMSFunc =
module.lookupSymbol<func::FuncOp>(funcName);
236 return op.emitOpError(
"Could not find the intrinsic function ")
238 SmallVector<Value, 2>
args;
239 if (targetModel.getTargetArch() == AIEArch::AIE1) {
240 args.push_back(op.getChannel());
241 args.push_back(op.getStreamValue());
243 args.push_back(op.getStreamValue());
244 args.push_back(rewriter.create<arith::ConstantOp>(
245 op.getLoc(), IntegerType::get(rewriter.getContext(), 32),
246 rewriter.getI32IntegerAttr(0)));
248 rewriter.create<func::CallOp>(rewriter.getUnknownLoc(), putMSFunc,
args);
249 rewriter.eraseOp(op);
255 using OpConversionPattern::OpConversionPattern;
259 PatternBenefit benefit = 1)
264 ConversionPatternRewriter &rewriter)
const override {
265 auto device = op->getParentOfType<DeviceOp>();
266 const auto &targetModel = device.getTargetModel();
267 std::string funcName;
268 if (targetModel.getTargetArch() == AIEArch::AIE1)
269 funcName =
"llvm.aie.get.";
270 else if (targetModel.getTargetArch() == AIEArch::AIE2)
271 funcName =
"llvm.aie2.get.";
273 funcName =
"llvm.aie2p.get.";
275 if (op.isWideStream())
277 else if (op.isFloatStream())
282 auto getSSFunc =
module.lookupSymbol<func::FuncOp>(funcName);
284 return op.emitOpError(
"Could not find the intrinsic function ")
286 SmallVector<Value, 2>
args;
287 if (targetModel.getTargetArch() == AIEArch::AIE1)
288 args.push_back(op.getChannel());
289 auto getSSCall = rewriter.create<func::CallOp>(rewriter.getUnknownLoc(),
291 rewriter.replaceOp(op, getSSCall.getResult(0));
298 using OpConversionPattern::OpConversionPattern;
302 PatternBenefit benefit = 1)
307 ConversionPatternRewriter &rewriter)
const override {
308 auto device = op->getParentOfType<DeviceOp>();
309 const auto &targetModel = device.getTargetModel();
310 std::string funcName;
311 if (targetModel.getTargetArch() == AIEArch::AIE1)
312 funcName =
"llvm.aie.put.mcd";
313 else if (targetModel.getTargetArch() == AIEArch::AIE2)
314 funcName =
"llvm.aie2.mcd.write.vec";
316 funcName =
"llvm.aie2p.mcd.write.vec";
317 auto putMCDFunc =
module.lookupSymbol<func::FuncOp>(funcName);
319 return op.emitOpError(
"Could not find the intrinsic function ")
321 SmallVector<Value, 2>
args;
322 args.push_back(op.getCascadeValue());
323 if (isa<AIE2TargetModel>(targetModel))
324 args.push_back(rewriter.create<arith::ConstantOp>(
325 op.getLoc(), IntegerType::get(rewriter.getContext(), 32),
326 rewriter.getI32IntegerAttr(1)));
328 rewriter.create<func::CallOp>(rewriter.getUnknownLoc(), putMCDFunc,
args);
329 rewriter.eraseOp(op);
335 using OpConversionPattern::OpConversionPattern;
339 PatternBenefit benefit = 1)
344 ConversionPatternRewriter &rewriter)
const override {
345 auto device = op->getParentOfType<DeviceOp>();
346 const auto &targetModel = device.getTargetModel();
347 std::string funcName;
348 if (targetModel.getTargetArch() == AIEArch::AIE1)
349 funcName =
"llvm.aie.get.scd";
350 else if (targetModel.getTargetArch() == AIEArch::AIE2)
351 funcName =
"llvm.aie2.scd.read.vec";
353 funcName =
"llvm.aie2p.scd.read.vec";
354 auto getSCDFunc =
module.lookupSymbol<func::FuncOp>(funcName);
356 return op.emitOpError(
"Could not find the intrinsic function ")
358 SmallVector<Value, 2>
args;
359 if (isa<AIE2TargetModel>(targetModel))
360 args.push_back(rewriter.create<arith::ConstantOp>(
361 op.getLoc(), IntegerType::get(rewriter.getContext(), 32),
362 rewriter.getI32IntegerAttr(1)));
364 auto getSCDCall = rewriter.create<func::CallOp>(rewriter.getUnknownLoc(),
366 rewriter.replaceOp(op, getSCDCall.getResult(0));
372 using OpConversionPattern::OpConversionPattern;
376 PatternBenefit benefit = 1)
380 ConversionPatternRewriter &rewriter)
const override {
381 if (!isa<DeviceOp>(useLock->getParentOp())) {
382 auto device = useLock->getParentOfType<DeviceOp>();
384 return module.emitOpError("Device Not found!");
386 const auto &targetModel = device.getTargetModel();
389 std::string funcName;
390 if (targetModel.getTargetArch() == AIEArch::AIE1)
391 funcName =
"llvm.aie.lock.";
392 else if (targetModel.getTargetArch() == AIEArch::AIE2)
393 funcName =
"llvm.aie2.";
395 funcName =
"llvm.aie2p.";
396 if (useLock.acquire() || useLock.acquireGE())
397 funcName +=
"acquire";
398 else if (useLock.release())
399 funcName +=
"release";
400 if (targetModel.getTargetArch() == AIEArch::AIE1)
403 auto useLockFunc =
module.lookupSymbol<func::FuncOp>(funcName);
405 return useLock.emitOpError(
"Could not find the intrinsic function!");
407 SmallVector<Value, 2>
args;
408 auto lockValue = useLock.getLockValue();
411 if (useLock.acquireGE()) {
412 lockValue = -lockValue;
414 args.push_back(rewriter.create<arith::IndexCastOp>(
415 useLock.getLoc(), IntegerType::get(rewriter.getContext(), 32),
417 args.push_back(rewriter.create<arith::ConstantOp>(
418 useLock.getLoc(), IntegerType::get(rewriter.getContext(), 32),
419 rewriter.getI32IntegerAttr(lockValue)));
421 rewriter.create<func::CallOp>(rewriter.getUnknownLoc(), useLockFunc,
424 rewriter.eraseOp(useLock);
430 using OpConversionPattern::OpConversionPattern;
435 PatternBenefit benefit = 1,
int tileCol = -1,
441 ConversionPatternRewriter &rewriter)
const override {
442 rewriter.setInsertionPointToStart(module.getBody());
443 auto t = llvm::cast<MemRefType>(buffer.getType());
444 int col = llvm::cast<TileOp>(buffer.getTile().getDefiningOp()).getCol();
445 int row = llvm::cast<TileOp>(buffer.getTile().getDefiningOp()).getRow();
446 auto symName = buffer.name().getValue();
447 mlir::ElementsAttr initValue = buffer.getInitialValueAttr();
452 rewriter.create<memref::GlobalOp>(
453 rewriter.getUnknownLoc(), symName, rewriter.getStringAttr(
"public"),
454 buffer.getType(), initValue,
false,
457 for (
auto &use : make_early_inc_range(buffer.getResult().getUses())) {
458 Operation *user = use.getOwner();
459 rewriter.setInsertionPoint(user);
460 auto allocated = rewriter.create<memref::GetGlobalOp>(
461 rewriter.getUnknownLoc(), t, symName);
463 rewriter.create<memref::AssumeAlignmentOp>(rewriter.getUnknownLoc(),
466 use.set(allocated.getResult());
469 rewriter.eraseOp(buffer);
475 using OpConversionPattern::OpConversionPattern;
483 MLIRContext *context, ModuleOp &m, IRMapping &
mapper,
484 DenseMap<Operation *, SmallVector<BufferOp, 4>> &
tileToBuffers,
491 ConversionPatternRewriter &rewriter)
const override {
493 int col = op.colIndex();
494 int row = op.rowIndex();
499 rewriter.eraseOp(op);
504 rewriter.setInsertionPointAfter(op->getParentOp());
506 std::string coreName(
"core_" + std::to_string(
col) +
"_" +
507 std::to_string(
row));
508 auto coreFunc = rewriter.create<func::FuncOp>(
509 rewriter.getUnknownLoc(), coreName,
510 FunctionType::get(rewriter.getContext(), {}, {}));
512 rewriter.cloneRegionBefore(op.getBody(), coreFunc.getBody(),
513 coreFunc.getBody().begin(),
mapper);
516 coreFunc.getBody().walk([&](Operation *childOp) {
517 rewriter.setInsertionPointAfter(childOp);
519 if (isa<EndOp>(childOp)) {
520 rewriter.create<func::ReturnOp>(rewriter.getUnknownLoc(),
522 rewriter.eraseOp(childOp);
526 rewriter.eraseOp(op);
532template <
typename OpTy>
534 SmallVector<OpTy, 16> ops;
535 for (
const auto &op : device.getOps<OpTy>())
538 for (
const auto &op : ops)
539 op->moveBefore(device);
544 using OpConversionPattern::OpConversionPattern;
548 PatternBenefit benefit = 1)
553 ConversionPatternRewriter &rewriter)
const override {
554 std::string funcName =
"llvm.aie.event" + std::to_string(op.getVal());
555 auto eventFunc =
module.lookupSymbol<func::FuncOp>(funcName);
557 return op.emitOpError(
"Could not find the intrinsic function ")
559 rewriter.create<func::CallOp>(rewriter.getUnknownLoc(), eventFunc,
561 rewriter.eraseOp(op);
569 ModuleOp m = getOperation();
570 OpBuilder builder = OpBuilder::atBlockEnd(m.getBody());
572 if (m.getOps<DeviceOp>().empty()) {
573 m.emitOpError(
"expected AIE.device operation at toplevel");
574 return signalPassFailure();
576 DeviceOp device = *m.getOps<DeviceOp>().begin();
577 const auto &targetModel = device.getTargetModel();
581 m->setAttr(LLVM::LLVMDialect::getTargetTripleAttrName(),
582 builder.getStringAttr(
583 getArchIntrinsicString(targetModel.getTargetArch())));
585 DenseMap<Operation *, SmallVector<BufferOp, 4>> tileToBuffers;
591 builder.setInsertionPointToStart(m.getBody());
592 declareAIEIntrinsics(targetModel.getTargetArch(), builder);
595 ConversionTarget target(getContext());
596 target.addLegalDialect<func::FuncDialect>();
597 target.addLegalDialect<cf::ControlFlowDialect>();
598 target.addLegalDialect<memref::MemRefDialect>();
599 target.addLegalDialect<VectorDialect>();
600 target.addLegalDialect<aievec::AIEVecDialect>();
601 target.addLegalDialect<arith::ArithDialect>();
602 target.addLegalDialect<ub::UBDialect>();
603 target.addLegalDialect<math::MathDialect>();
604 target.addLegalDialect<index::IndexDialect>();
605 target.addLegalOp<func::FuncOp, ModuleOp>();
607 RewritePatternSet patterns(&getContext());
615 if (failed(applyPartialConversion(m, target, std::move(patterns))))
616 return signalPassFailure();
618 RewritePatternSet outlinePatterns(&getContext());
622 if (failed(applyPartialConversion(m, target, std::move(outlinePatterns))))
623 return signalPassFailure();
627 outlineOps<memref::GlobalOp>(device);
628 outlineOps<func::FuncOp>(device);
630 RewritePatternSet removepatterns(&getContext());
639 if (failed(applyPartialConversion(m, target, std::move(removepatterns))))
640 return signalPassFailure();
645 return std::make_unique<AIECoreToStandardPass>();
std::vector< IntrinsicDecl > IntrinsicDecls
void outlineOps(DeviceOp device)
std::tuple< const char *, std::vector< Type >, std::vector< Type > > IntrinsicDecl
Include the generated interface declarations.
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createAIECoreToStandardPass()
LogicalResult matchAndRewrite(BufferOp buffer, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
AIEBufferToStandard(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1, int tileCol=-1, int tileRow=-1)
ModuleOp &IRMapping & mapper
LogicalResult matchAndRewrite(CoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
AIECoreToStandardFunc(MLIRContext *context, ModuleOp &m, IRMapping &mapper, DenseMap< Operation *, SmallVector< BufferOp, 4 > > &tileToBuffers, PatternBenefit benefit=1, int tileCol=1, int tileRow=1)
DenseMap< Operation *, SmallVector< BufferOp, 4 > > & tileToBuffers
void runOnOperation() override
ModuleOp & AIEDebugOpToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(DebugOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEEventOpToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(EventOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(GetCascadeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEGetCascadeToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(GetStreamOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEGetStreamToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
ModuleOp & AIEOpRemoval(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
typename MyAIEOp::Adaptor OpAdaptor
LogicalResult matchAndRewrite(MyAIEOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(PutCascadeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEPutCascadeToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(PutStreamOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEPutStreamToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
ModuleOp & AIEUseLockToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(UseLockOp useLock, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override