14#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
18#include "mlir/Dialect/Index/IR/IndexDialect.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/Dialect/Math/IR/Math.h"
21#include "mlir/Dialect/Vector/IR/VectorOps.h"
22#include "mlir/IR/Attributes.h"
23#include "mlir/IR/IRMapping.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
27#include "mlir/Transforms/DialectConversion.h"
// Maps an AIE architecture to a string. The per-architecture return
// statements are not visible in this listing; elsewhere in this file the
// result is stored as the module's LLVM target-triple attribute.
34static StringRef getArchIntrinsicString(
AIEArch arch) {
// Last statement of the function: abort with a fatal error.
// Presumably reached only for an arch not handled above -- TODO confirm.
43 llvm::report_fatal_error(
"unsupported arch");
// Record describing one intrinsic declaration: a C-string plus two lists of
// types. declareAIEIntrinsics unpacks such entries as
// [name, argTypes, retTypes]. The alias name itself is on a line that is not
// visible here (presumably IntrinsicDecl).
46typedef std::tuple<const char *, std::vector<Type>, std::vector<Type>>
// Builds the table of intrinsic declarations for the AIE1 architecture.
// Each entry has the form {name, {types}, {types}}; declareAIEIntrinsics
// unpacks the fields as name, argument types and result types.
// `builder` is used only to obtain the MLIRContext for type construction.
50static auto getAIE1Intrinsics(OpBuilder &builder) {
// Signless integer types of width 32, 128 and 384 bits, and f32.
51 Type int32Type = IntegerType::get(builder.getContext(), 32);
52 Type int128Type = IntegerType::get(builder.getContext(), 128);
53 Type int384Type = IntegerType::get(builder.getContext(), 384);
54 Type floatType = Float32Type::get(builder.getContext());
// debug_i32: one i32 argument, no results.
59 {
"debug_i32", {int32Type}, {}},
// Event intrinsics: no arguments, no results.
60 {
"llvm.aie.event0", {}, {}},
61 {
"llvm.aie.event1", {}, {}},
// The next three type lists belong to entries whose name lines are not
// visible here (presumably the "llvm.aie.put.*" stream intrinsics for
// 32-bit, 128-bit and float payloads -- TODO confirm).
63 {int32Type, int32Type},
66 {int32Type, int128Type},
69 {int32Type, floatType},
// Stream reads: one i32 argument, returning i32 / f32 respectively.
71 {
"llvm.aie.get.ss", {int32Type}, {int32Type}},
75 {
"llvm.aie.get.fss", {int32Type}, {floatType}},
// Cascade write takes an i384 value; cascade read returns one.
76 {
"llvm.aie.put.mcd", {int384Type}, {}},
77 {
"llvm.aie.get.scd", {}, {int384Type}},
// Lock acquire/release: two i32 arguments each (the result lists are on
// lines not visible here).
78 {
"llvm.aie.lock.acquire.reg",
79 {int32Type, int32Type},
81 {
"llvm.aie.lock.release.reg",
82 {int32Type, int32Type},
// Builds the table of intrinsic declarations for the AIE2 architecture.
// Entries have the form {name, {types}, {types}}; declareAIEIntrinsics
// unpacks the fields as name, argument types and result types.
88static auto getAIE2Intrinsics(OpBuilder &builder) {
89 Type int32Type = IntegerType::get(builder.getContext(), 32);
// 16-element vector of i32, used by the cascade (mcd/scd) entries.
90 Type accType = VectorType::get({16}, int32Type);
// debug_i32: one i32 argument, no results.
92 {
"debug_i32", {int32Type}, {}},
// Stream write: two i32 arguments, no results.
93 {
"llvm.aie2.put.ms", {int32Type, int32Type}, {}},
// Stream read: no arguments, two i32 results.
94 {
"llvm.aie2.get.ss", {}, {int32Type, int32Type}},
// Cascade write/read: the type lists for these two entries are on lines
// not visible in this listing.
95 {
"llvm.aie2.mcd.write.vec",
98 {
"llvm.aie2.scd.read.vec",
// Lock acquire/release: two i32 arguments each (result lists not visible).
101 {
"llvm.aie2.acquire",
102 {int32Type, int32Type},
104 {
"llvm.aie2.release",
105 {int32Type, int32Type},
// Builds the table of intrinsic declarations for the AIE2p architecture.
// Mirrors the AIE2 table with an "llvm.aie2p." name prefix. Entries have the
// form {name, {types}, {types}}, unpacked by declareAIEIntrinsics as name,
// argument types and result types.
111static auto getAIE2pIntrinsics(OpBuilder &builder) {
112 Type int32Type = IntegerType::get(builder.getContext(), 32);
// 16-element vector of i32, used by the cascade (mcd/scd) entries.
113 Type accType = VectorType::get({16}, int32Type);
// debug_i32: one i32 argument, no results.
115 {
"debug_i32", {int32Type}, {}},
// Stream write: two i32 arguments (result list not visible).
116 {
"llvm.aie2p.put.ms",
117 {int32Type, int32Type},
// Stream read: the list below closes the entry, i.e. it is the result-type
// list (two i32); the argument list is on a line not visible here.
119 {
"llvm.aie2p.get.ss",
121 {int32Type, int32Type}},
// Cascade write: takes the vector value plus one i32.
122 {
"llvm.aie2p.mcd.write.vec",
123 {accType, int32Type},
// Cascade read: type lists not visible in this listing.
125 {
"llvm.aie2p.scd.read.vec",
// Lock acquire/release: two i32 arguments each (result lists not visible).
128 {
"llvm.aie2p.acquire",
129 {int32Type, int32Type},
131 {
"llvm.aie2p.release",
132 {int32Type, int32Type},
// Declares the intrinsic functions for `arch` as func::FuncOp operations,
// created through `builder` at its current insertion point.
// `registerIntrinsics` is defined on a line not visible in this listing
// (presumably a local lambda wrapping the loop below -- TODO confirm).
138static void declareAIEIntrinsics(
AIEArch arch, OpBuilder &builder) {
// For every (name, argTypes, retTypes) entry in `functions`...
140 for (
auto &i : functions) {
141 auto [name, argTypes, retTypes] = i;
// ...create a func::FuncOp with an unknown location, the entry's name and a
// function type built from its argument/result type lists. The chained
// call that follows `.create(...)` is not visible here.
143 .create<func::FuncOp>(
144 builder.getUnknownLoc(), name,
145 FunctionType::get(builder.getContext(), argTypes, retTypes))
// Register the table matching the architecture (the selecting switch/case
// labels are not visible in this listing).
151 registerIntrinsics(getAIE1Intrinsics(builder));
154 registerIntrinsics(getAIE2Intrinsics(builder));
157 registerIntrinsics(getAIE2pIntrinsics(builder));
// Presumably the fallback for an arch with no table -- TODO confirm.
160 llvm::report_fatal_error(
"unsupported arch");
// Generic conversion pattern, parameterized on the op type MyAIEOp, whose
// rewrite simply erases the matched op.
163template <
typename MyAIEOp>
// Constructor: takes the context, the enclosing module and a pattern benefit
// (default 1). NOTE(review): the signature appears twice below; this looks
// like an artifact of how the listing was produced, not real duplication.
169 AIEOpRemoval(MLIRContext *context, ModuleOp &m, PatternBenefit benefit = 1)
169 AIEOpRemoval(MLIRContext *context, ModuleOp &m, PatternBenefit benefit = 1) {
…}
// matchAndRewrite: remove the op; no replacement is created.
174 ConversionPatternRewriter &rewriter)
const override {
175 rewriter.eraseOp(op);
// Pattern that lowers a debug op to a call of the "debug_i32" function.
// (The struct header and constructor body are not visible in this listing.)
181 using OpConversionPattern::OpConversionPattern;
185 PatternBenefit benefit = 1)
190 ConversionPatternRewriter &rewriter)
const override {
// The callee is looked up by symbol name in the module.
191 std::string funcName =
"debug_i32";
192 auto func =
module.lookupSymbol<func::FuncOp>(funcName);
// Error path; the guarding condition is not visible (presumably a null
// check on `func`).
194 return op.emitOpError(
"Could not find the intrinsic function ")
// Forward the op's single argument to the call.
196 SmallVector<Value, 1>
args;
197 args.push_back(op.getArg());
// The call is created with an unknown location, then the original op is
// erased.
198 rewriter.create<func::CallOp>(rewriter.getUnknownLoc(), func,
args);
199 rewriter.eraseOp(op);
// Pattern that lowers a stream-put op to a call of the architecture-specific
// "put" intrinsic function. (Struct header and constructor body are not
// visible in this listing.)
205 using OpConversionPattern::OpConversionPattern;
209 PatternBenefit benefit = 1)
214 ConversionPatternRewriter &rewriter)
const override {
// The target architecture comes from the enclosing DeviceOp's target model.
215 auto device = op->getParentOfType<DeviceOp>();
216 const auto &targetModel = device.getTargetModel();
// Choose the intrinsic name prefix by architecture; the final branch
// (its `else` is not visible) yields the AIE2p prefix.
217 std::string funcName;
218 if (targetModel.getTargetArch() == AIEArch::AIE1)
219 funcName =
"llvm.aie.put.";
220 else if (targetModel.getTargetArch() == AIEArch::AIE2)
221 funcName =
"llvm.aie2.put.";
223 funcName =
"llvm.aie2p.put.";
// A suffix is chosen depending on whether the stream is wide or float; the
// appended strings are on lines not visible here.
225 if (op.isWideStream())
227 else if (op.isFloatStream())
// Resolve the intrinsic by symbol name in the module.
232 auto putMSFunc =
module.lookupSymbol<func::FuncOp>(funcName);
// Error path; the guarding condition is not visible (presumably a null
// check on `putMSFunc`).
234 return op.emitOpError(
"Could not find the intrinsic function ")
// Argument order depends on the architecture: AIE1 passes
// (channel, value); the other branch (its `else` is not visible) passes
// (value, i32 constant 0).
236 SmallVector<Value, 2>
args;
237 if (targetModel.getTargetArch() == AIEArch::AIE1) {
238 args.push_back(op.getChannel());
239 args.push_back(op.getStreamValue());
241 args.push_back(op.getStreamValue());
242 args.push_back(rewriter.create<arith::ConstantOp>(
243 op.getLoc(), IntegerType::get(rewriter.getContext(), 32),
244 rewriter.getI32IntegerAttr(0)));
// Emit the call (unknown location) and erase the original op.
246 rewriter.create<func::CallOp>(rewriter.getUnknownLoc(), putMSFunc,
args);
247 rewriter.eraseOp(op);
// Pattern that lowers a stream-get op to a call of the architecture-specific
// "get" intrinsic function and replaces the op with the call's first result.
// (Struct header and constructor body are not visible in this listing.)
253 using OpConversionPattern::OpConversionPattern;
257 PatternBenefit benefit = 1)
262 ConversionPatternRewriter &rewriter)
const override {
// The target architecture comes from the enclosing DeviceOp's target model.
263 auto device = op->getParentOfType<DeviceOp>();
264 const auto &targetModel = device.getTargetModel();
// Choose the intrinsic name prefix by architecture; the final branch
// (its `else` is not visible) yields the AIE2p prefix.
265 std::string funcName;
266 if (targetModel.getTargetArch() == AIEArch::AIE1)
267 funcName =
"llvm.aie.get.";
268 else if (targetModel.getTargetArch() == AIEArch::AIE2)
269 funcName =
"llvm.aie2.get.";
271 funcName =
"llvm.aie2p.get.";
// A suffix is chosen depending on whether the stream is wide or float; the
// appended strings are on lines not visible here.
273 if (op.isWideStream())
275 else if (op.isFloatStream())
// Resolve the intrinsic by symbol name in the module.
280 auto getSSFunc =
module.lookupSymbol<func::FuncOp>(funcName);
// Error path; the guarding condition is not visible (presumably a null
// check on `getSSFunc`).
282 return op.emitOpError(
"Could not find the intrinsic function ")
// Only AIE1 passes the channel as an argument; otherwise the argument list
// stays empty.
284 SmallVector<Value, 2>
args;
285 if (targetModel.getTargetArch() == AIEArch::AIE1)
286 args.push_back(op.getChannel());
// Emit the call (remaining arguments of create() are not visible) and
// replace the op with result #0 of the call.
287 auto getSSCall = rewriter.create<func::CallOp>(rewriter.getUnknownLoc(),
289 rewriter.replaceOp(op, getSSCall.getResult(0));
// Pattern that lowers a cascade-put op to a call of the architecture-specific
// cascade-write intrinsic function. (Struct header and constructor body are
// not visible in this listing.)
296 using OpConversionPattern::OpConversionPattern;
300 PatternBenefit benefit = 1)
305 ConversionPatternRewriter &rewriter)
const override {
// The target architecture comes from the enclosing DeviceOp's target model.
306 auto device = op->getParentOfType<DeviceOp>();
307 const auto &targetModel = device.getTargetModel();
// Full intrinsic name per architecture; the final branch (its `else` is not
// visible) yields the AIE2p name.
308 std::string funcName;
309 if (targetModel.getTargetArch() == AIEArch::AIE1)
310 funcName =
"llvm.aie.put.mcd";
311 else if (targetModel.getTargetArch() == AIEArch::AIE2)
312 funcName =
"llvm.aie2.mcd.write.vec";
314 funcName =
"llvm.aie2p.mcd.write.vec";
// Resolve the intrinsic by symbol name in the module.
315 auto putMCDFunc =
module.lookupSymbol<func::FuncOp>(funcName);
// Error path; the guarding condition is not visible (presumably a null
// check on `putMCDFunc`).
317 return op.emitOpError(
"Could not find the intrinsic function ")
// First argument is always the cascade value.
319 SmallVector<Value, 2>
args;
320 args.push_back(op.getCascadeValue());
// For AIE2 target models an extra i32 constant 1 is appended.
// NOTE(review): the name above is selected via getTargetArch(), but this
// extra argument is selected via isa<AIE2TargetModel>. The AIE2p
// declaration of mcd.write.vec takes (vector, i32), so confirm that AIE2p
// target models also satisfy isa<AIE2TargetModel>.
321 if (isa<AIE2TargetModel>(targetModel))
322 args.push_back(rewriter.create<arith::ConstantOp>(
323 op.getLoc(), IntegerType::get(rewriter.getContext(), 32),
324 rewriter.getI32IntegerAttr(1)));
// Emit the call (unknown location) and erase the original op.
326 rewriter.create<func::CallOp>(rewriter.getUnknownLoc(), putMCDFunc,
args);
327 rewriter.eraseOp(op);
// Pattern that lowers a cascade-get op to a call of the architecture-specific
// cascade-read intrinsic function and replaces the op with the call's first
// result. (Struct header and constructor body are not visible here.)
333 using OpConversionPattern::OpConversionPattern;
337 PatternBenefit benefit = 1)
342 ConversionPatternRewriter &rewriter)
const override {
// The target architecture comes from the enclosing DeviceOp's target model.
343 auto device = op->getParentOfType<DeviceOp>();
344 const auto &targetModel = device.getTargetModel();
// Full intrinsic name per architecture; the final branch (its `else` is not
// visible) yields the AIE2p name.
345 std::string funcName;
346 if (targetModel.getTargetArch() == AIEArch::AIE1)
347 funcName =
"llvm.aie.get.scd";
348 else if (targetModel.getTargetArch() == AIEArch::AIE2)
349 funcName =
"llvm.aie2.scd.read.vec";
351 funcName =
"llvm.aie2p.scd.read.vec";
// Resolve the intrinsic by symbol name in the module.
352 auto getSCDFunc =
module.lookupSymbol<func::FuncOp>(funcName);
// Error path; the guarding condition is not visible (presumably a null
// check on `getSCDFunc`).
354 return op.emitOpError(
"Could not find the intrinsic function ")
// For AIE2 target models a single i32 constant 1 is passed; otherwise the
// call takes no arguments.
// NOTE(review): the name is selected via getTargetArch(), the argument via
// isa<AIE2TargetModel> -- confirm both agree for AIE2p devices.
356 SmallVector<Value, 2>
args;
357 if (isa<AIE2TargetModel>(targetModel))
358 args.push_back(rewriter.create<arith::ConstantOp>(
359 op.getLoc(), IntegerType::get(rewriter.getContext(), 32),
360 rewriter.getI32IntegerAttr(1)));
// Emit the call (remaining arguments of create() are not visible) and
// replace the op with result #0 of the call.
362 auto getSCDCall = rewriter.create<func::CallOp>(rewriter.getUnknownLoc(),
364 rewriter.replaceOp(op, getSCDCall.getResult(0));
// Pattern that lowers a use-lock op to a call of the architecture-specific
// lock acquire/release intrinsic function, then erases the op.
// (Struct header and constructor body are not visible in this listing.)
370 using OpConversionPattern::OpConversionPattern;
374 PatternBenefit benefit = 1)
378 ConversionPatternRewriter &rewriter)
const override {
// The call is only generated when the op's immediate parent is not the
// DeviceOp itself.
379 if (!isa<DeviceOp>(useLock->getParentOp())) {
380 auto device = useLock->getParentOfType<DeviceOp>();
// Error path; the guarding condition is not visible (presumably taken when
// no enclosing DeviceOp was found).
382 return module.emitOpError("Device Not found!");
384 const auto &targetModel = device.getTargetModel();
// Build the intrinsic name: architecture prefix + "acquire"/"release".
// The final prefix branch (its `else` is not visible) is the AIE2p one.
387 std::string funcName;
388 if (targetModel.getTargetArch() == AIEArch::AIE1)
389 funcName =
"llvm.aie.lock.";
390 else if (targetModel.getTargetArch() == AIEArch::AIE2)
391 funcName =
"llvm.aie2.";
393 funcName =
"llvm.aie2p.";
// Both plain acquire and acquire-greater-equal map to "acquire".
394 if (useLock.acquire() || useLock.acquireGE())
395 funcName +=
"acquire";
396 else if (useLock.release())
397 funcName +=
"release";
// AIE1-only adjustment; its body is not visible here. The AIE1 table
// declares names ending in ".reg", so presumably that suffix is appended
// -- TODO confirm.
398 if (targetModel.getTargetArch() == AIEArch::AIE1)
// Resolve the intrinsic by symbol name in the module.
401 auto useLockFunc =
module.lookupSymbol<func::FuncOp>(funcName);
// Error path; the guarding condition is not visible (presumably a null
// check on `useLockFunc`).
403 return useLock.emitOpError(
"Could not find the intrinsic function!");
405 SmallVector<Value, 2>
args;
406 auto lockValue = useLock.getLockValue();
// For acquire-greater-equal the lock value is passed negated.
409 if (useLock.acquireGE()) {
410 lockValue = -lockValue;
// First argument: a value cast to i32 via arith::IndexCastOp (the operand
// being cast is on a line not visible here).
412 args.push_back(rewriter.create<arith::IndexCastOp>(
413 useLock.getLoc(), IntegerType::get(rewriter.getContext(), 32),
// Second argument: the (possibly negated) lock value as an i32 constant.
415 args.push_back(rewriter.create<arith::ConstantOp>(
416 useLock.getLoc(), IntegerType::get(rewriter.getContext(), 32),
417 rewriter.getI32IntegerAttr(lockValue)));
// Emit the call (remaining arguments of create() are not visible).
419 rewriter.create<func::CallOp>(rewriter.getUnknownLoc(), useLockFunc,
// Finally remove the original use-lock op.
422 rewriter.eraseOp(useLock);
// Pattern that turns a buffer op into a module-level memref::GlobalOp and
// rewrites every use of the buffer to a memref::GetGlobalOp referring to
// that global. (Struct header and constructor body are not visible here.)
428 using OpConversionPattern::OpConversionPattern;
// Constructor parameters include a pattern benefit and a tile column that
// defaults to -1; the remaining parameters are on lines not visible here.
433 PatternBenefit benefit = 1,
int tileCol = -1,
439 ConversionPatternRewriter &rewriter)
const override {
// The global is created at the start of the module body.
440 rewriter.setInsertionPointToStart(module.getBody());
// llvm::cast asserts on mismatch: the buffer must have a MemRefType and its
// tile operand must be defined by a TileOp.
441 auto t = llvm::cast<MemRefType>(buffer.getType());
442 int col = llvm::cast<TileOp>(buffer.getTile().getDefiningOp()).getCol();
443 int row = llvm::cast<TileOp>(buffer.getTile().getDefiningOp()).getRow();
// The global reuses the buffer's name and its initial-value attribute.
444 auto symName = buffer.name().getValue();
445 mlir::ElementsAttr initValue = buffer.getInitialValueAttr();
// Create the global with "public" visibility (trailing arguments of
// create() are not visible in this listing).
450 rewriter.create<memref::GlobalOp>(
451 rewriter.getUnknownLoc(), symName, rewriter.getStringAttr(
"public"),
452 buffer.getType(), initValue,
false,
// Redirect each use of the buffer result. make_early_inc_range is used
// because use.set() below modifies the use list being iterated.
455 for (
auto &use : make_early_inc_range(buffer.getResult().getUses())) {
456 Operation *user = use.getOwner();
// Materialize the global's address immediately before the user.
457 rewriter.setInsertionPoint(user);
458 auto allocated = rewriter.create<memref::GetGlobalOp>(
459 rewriter.getUnknownLoc(), t, symName);
// An alignment assumption is emitted as well (its operands are not visible).
461 rewriter.create<memref::AssumeAlignmentOp>(rewriter.getUnknownLoc(),
464 use.set(allocated.getResult());
// All uses are redirected; drop the original buffer op.
467 rewriter.eraseOp(buffer);
// Pattern that outlines the body of a core op into a func::FuncOp named
// "core_<col>_<row>" and erases the core op. (Struct header and parts of the
// constructor are not visible in this listing.)
473 using OpConversionPattern::OpConversionPattern;
// Constructor parameters include the context, the module, an IRMapping used
// when cloning the region, and a tile-to-buffers map.
481 MLIRContext *context, ModuleOp &m, IRMapping &
mapper,
482 DenseMap<Operation *, SmallVector<BufferOp, 4>> &
tileToBuffers,
489 ConversionPatternRewriter &rewriter)
const override {
// Coordinates of the core being lowered.
491 int col = op.colIndex();
492 int row = op.rowIndex();
// Early erase; the condition guarding it is not visible here (presumably
// the core does not match the requested tile -- TODO confirm).
497 rewriter.eraseOp(op);
// The new function is placed after the core's parent op.
502 rewriter.setInsertionPointAfter(op->getParentOp());
504 std::string coreName(
"core_" + std::to_string(
col) +
"_" +
505 std::to_string(
row));
// Function with no arguments and no results.
506 auto coreFunc = rewriter.create<func::FuncOp>(
507 rewriter.getUnknownLoc(), coreName,
508 FunctionType::get(rewriter.getContext(), {}, {}));
// Clone the core's region to the front of the function body, using
// `mapper` for the cloning.
510 rewriter.cloneRegionBefore(op.getBody(), coreFunc.getBody(),
511 coreFunc.getBody().begin(),
mapper);
// Walk the cloned body and replace each EndOp with a func::ReturnOp
// (inserted right after it; its operands are not visible here).
514 coreFunc.getBody().walk([&](Operation *childOp) {
515 rewriter.setInsertionPointAfter(childOp);
517 if (isa<EndOp>(childOp)) {
518 rewriter.create<func::ReturnOp>(rewriter.getUnknownLoc(),
520 rewriter.eraseOp(childOp);
// The core op itself is no longer needed.
524 rewriter.eraseOp(op);
// Moves every op of type OpTy found directly in `device` to just before the
// device op. (The function signature line is not visible in this listing.)
530template <
typename OpTy>
// Ops are first gathered into a local vector (the statement filling it is
// not visible here), presumably so they are not moved while the device's
// op list is being iterated.
532 SmallVector<OpTy, 16> ops;
533 for (
const auto &op : device.getOps<OpTy>())
// Second pass: hoist each gathered op in front of the device.
536 for (
const auto &op : ops)
537 op->moveBefore(device);
// Pattern that lowers an event op to a call of the function named
// "llvm.aie.event<N>", where N is the op's value, then erases the op.
// (Struct header and constructor body are not visible in this listing.)
542 using OpConversionPattern::OpConversionPattern;
546 PatternBenefit benefit = 1)
551 ConversionPatternRewriter &rewriter)
const override {
// NOTE(review): the name always uses the "llvm.aie.event" prefix regardless
// of architecture. Among the intrinsic tables visible in this file only
// the AIE1 one lists llvm.aie.event0/1 -- confirm the behavior for AIE2
// and AIE2p devices, where the lookup below may fail.
552 std::string funcName =
"llvm.aie.event" + std::to_string(op.getVal());
553 auto eventFunc =
module.lookupSymbol<func::FuncOp>(funcName);
// Error path; the guarding condition is not visible (presumably a null
// check on `eventFunc`).
555 return op.emitOpError(
"Could not find the intrinsic function ")
// Emit the call (remaining arguments of create() are not visible) and
// erase the original op.
557 rewriter.create<func::CallOp>(rewriter.getUnknownLoc(), eventFunc,
559 rewriter.eraseOp(op);
// Pass body (the function header is not visible in this listing). Requires a
// device op at module top level, declares the intrinsics for its
// architecture, then runs three partial conversions over the module.
567 ModuleOp m = getOperation();
568 OpBuilder builder = OpBuilder::atBlockEnd(m.getBody());
// Fail the pass when the module contains no DeviceOp.
570 if (m.getOps<DeviceOp>().empty()) {
571 m.emitOpError(
"expected AIE.device operation at toplevel");
572 return signalPassFailure();
// Only the first DeviceOp in the module is considered.
574 DeviceOp device = *m.getOps<DeviceOp>().begin();
575 const auto &targetModel = device.getTargetModel();
// Record the architecture string as the module's LLVM target triple
// attribute.
579 m->setAttr(LLVM::LLVMDialect::getTargetTripleAttrName(),
580 builder.getStringAttr(
581 getArchIntrinsicString(targetModel.getTargetArch())));
583 DenseMap<Operation *, SmallVector<BufferOp, 4>> tileToBuffers;
// Intrinsic declarations are emitted at the start of the module body.
589 builder.setInsertionPointToStart(m.getBody());
590 declareAIEIntrinsics(targetModel.getTargetArch(), builder);
// Dialects and ops considered legal by the conversions below.
593 ConversionTarget target(getContext());
594 target.addLegalDialect<func::FuncDialect>();
595 target.addLegalDialect<cf::ControlFlowDialect>();
596 target.addLegalDialect<memref::MemRefDialect>();
597 target.addLegalDialect<VectorDialect>();
598 target.addLegalDialect<arith::ArithDialect>();
599 target.addLegalDialect<math::MathDialect>();
600 target.addLegalDialect<index::IndexDialect>();
601 target.addLegalOp<func::FuncOp, ModuleOp>();
// First conversion (the patterns added to this set are not visible here).
603 RewritePatternSet patterns(&getContext());
611 if (failed(applyPartialConversion(m, target, std::move(patterns))))
612 return signalPassFailure();
// Second conversion, using a separate pattern set (contents not visible;
// presumably the core-outlining pattern, going by the variable name).
614 RewritePatternSet outlinePatterns(&getContext());
618 if (failed(applyPartialConversion(m, target, std::move(outlinePatterns))))
619 return signalPassFailure();
// Hoist globals and functions out of the device op, placing them before it.
623 outlineOps<memref::GlobalOp>(device);
624 outlineOps<func::FuncOp>(device);
// Third conversion (contents not visible; presumably removes the remaining
// AIE ops, going by the variable name).
626 RewritePatternSet removepatterns(&getContext());
635 if (failed(applyPartialConversion(m, target, std::move(removepatterns))))
636 return signalPassFailure();
// Factory body: returns a newly allocated AIECoreToStandardPass (the
// enclosing function header is not visible in this listing).
641 return std::make_unique<AIECoreToStandardPass>();
std::vector< IntrinsicDecl > IntrinsicDecls
void outlineOps(DeviceOp device)
std::tuple< const char *, std::vector< Type >, std::vector< Type > > IntrinsicDecl
Include the generated interface declarations.
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createAIECoreToStandardPass()
LogicalResult matchAndRewrite(BufferOp buffer, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
AIEBufferToStandard(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1, int tileCol=-1, int tileRow=-1)
ModuleOp &IRMapping & mapper
LogicalResult matchAndRewrite(CoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
AIECoreToStandardFunc(MLIRContext *context, ModuleOp &m, IRMapping &mapper, DenseMap< Operation *, SmallVector< BufferOp, 4 > > &tileToBuffers, PatternBenefit benefit=1, int tileCol=1, int tileRow=1)
DenseMap< Operation *, SmallVector< BufferOp, 4 > > & tileToBuffers
void runOnOperation() override
ModuleOp & AIEDebugOpToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(DebugOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEEventOpToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(EventOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(GetCascadeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEGetCascadeToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(GetStreamOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEGetStreamToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
ModuleOp & AIEOpRemoval(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
typename MyAIEOp::Adaptor OpAdaptor
LogicalResult matchAndRewrite(MyAIEOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(PutCascadeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEPutCascadeToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(PutStreamOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEPutStreamToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
ModuleOp & AIEUseLockToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(UseLockOp useLock, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override