MLIR-AIE
AIECoreToStandard.cpp
Go to the documentation of this file.
1//===- AIECoreToStandard.cpp ------------------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2019 Xilinx Inc.
8//
9//===----------------------------------------------------------------------===//
10
14
15#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
16#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
19#include "mlir/Dialect/DLTI/DLTI.h"
20#include "mlir/Dialect/Index/IR/IndexDialect.h"
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/Dialect/Math/IR/Math.h"
23#include "mlir/Dialect/Ptr/IR/PtrOps.h"
24#include "mlir/Dialect/UB/IR/UBOps.h"
25#include "mlir/Dialect/Vector/IR/VectorOps.h"
26#include "mlir/IR/Attributes.h"
27#include "mlir/IR/IRMapping.h"
28#include "mlir/IR/PatternMatch.h"
29#include "mlir/Pass/Pass.h"
30#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
31#include "mlir/Transforms/DialectConversion.h"
32
33namespace xilinx::AIE {
34#define GEN_PASS_DEF_AIECORETOSTANDARD
35#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc"
36} // namespace xilinx::AIE
37
38using namespace mlir;
39using namespace mlir::vector;
40using namespace xilinx;
41using namespace xilinx::AIE;
42
43static StringRef getArchIntrinsicString(AIEArch arch) {
44 switch (arch) {
45 case AIEArch::AIE1:
46 return "aie";
47 case AIEArch::AIE2:
48 return "aie2";
49 case AIEArch::AIE2p:
50 return "aie2p";
51 }
52 llvm::report_fatal_error("unsupported arch");
53}
54
55typedef std::tuple<const char *, std::vector<Type>, std::vector<Type>>
57typedef std::vector<IntrinsicDecl> IntrinsicDecls;
58
59static auto getAIE1Intrinsics(OpBuilder &builder) {
60 Type int32Type = IntegerType::get(builder.getContext(), 32);
61 Type int128Type = IntegerType::get(builder.getContext(), 128);
62 Type int384Type = IntegerType::get(builder.getContext(), 384);
63 Type floatType = Float32Type::get(builder.getContext());
64
65 // Note that not all of these are valid for a particular design, or needed.
66 // For right now, we will just accept the noise.
67 IntrinsicDecls functions = {
68 {"debug_i32", {int32Type}, {}},
69 {"llvm.aie.event0", {}, {}},
70 {"llvm.aie.event1", {}, {}},
71 {"llvm.aie.put.ms",
72 {int32Type, int32Type},
73 {}}, //(%channel, %value) -> ()
74 {"llvm.aie.put.wms",
75 {int32Type, int128Type},
76 {}}, //(%channel, %value) -> ()
77 {"llvm.aie.put.fms",
78 {int32Type, floatType},
79 {}}, //(%channel, %value) -> ()
80 {"llvm.aie.get.ss", {int32Type}, {int32Type}}, //(%channel, %value) -> ()
81 {"llvm.aie.get.wss",
82 {int32Type},
83 {int128Type}}, //(%channel, %value) -> ()
84 {"llvm.aie.get.fss", {int32Type}, {floatType}}, //(%channel, %value) -> ()
85 {"llvm.aie.put.mcd", {int384Type}, {}},
86 {"llvm.aie.get.scd", {}, {int384Type}},
87 {"llvm.aie.lock.acquire.reg",
88 {int32Type, int32Type},
89 {}}, //(%lock_id, %lock_val) -> ()
90 {"llvm.aie.lock.release.reg",
91 {int32Type, int32Type},
92 {}}, //(%lock_id, %lock_val) -> ()
93 };
94 return functions;
95}
96
97static auto getAIE2Intrinsics(OpBuilder &builder) {
98 Type int32Type = IntegerType::get(builder.getContext(), 32);
99 Type accType = VectorType::get({16}, int32Type);
100 IntrinsicDecls functions = {
101 {"debug_i32", {int32Type}, {}},
102 {"llvm.aie2.event", {int32Type}, {}},
103 {"llvm.aie2.put.ms", {int32Type, int32Type}, {}}, //(%value, %tlast) -> ()
104 {"llvm.aie2.get.ss", {}, {int32Type, int32Type}}, //() -> (%value, %tlast)
105 {"llvm.aie2.mcd.write.vec",
106 {accType, int32Type},
107 {}}, // (%value, %enable) -> ()
108 {"llvm.aie2.scd.read.vec",
109 {int32Type},
110 {accType}}, // (%enable) -> (%value)
111 {"llvm.aie2.acquire",
112 {int32Type, int32Type},
113 {}}, //(%lock_id, %lock_val) -> ()
114 {"llvm.aie2.release",
115 {int32Type, int32Type},
116 {}}, //(%lock_id, %lock_val) -> ()
117 {"llvm.aie2.set.ctrl.reg",
118 {int32Type, int32Type},
119 {}}, //(%reg_id, %value) -> ()
120 };
121 return functions;
122}
123
124static auto getAIE2pIntrinsics(OpBuilder &builder) {
125 Type int32Type = IntegerType::get(builder.getContext(), 32);
126 Type accType = VectorType::get({16}, int32Type);
127 IntrinsicDecls functions = {
128 {"debug_i32", {int32Type}, {}},
129 {"llvm.aie2p.event", {int32Type}, {}},
130 {"llvm.aie2p.put.ms",
131 {int32Type, int32Type},
132 {}}, //(%value, %tlast) -> ()
133 {"llvm.aie2p.get.ss",
134 {},
135 {int32Type, int32Type}}, //() -> (%value, %tlast)
136 {"llvm.aie2p.mcd.write.vec",
137 {accType, int32Type},
138 {}}, // (%value, %enable) -> ()
139 {"llvm.aie2p.scd.read.vec",
140 {int32Type},
141 {accType}}, // (%enable) -> (%value)
142 {"llvm.aie2p.acquire",
143 {int32Type, int32Type},
144 {}}, //(%lock_id, %lock_val) -> ()
145 {"llvm.aie2p.release",
146 {int32Type, int32Type},
147 {}}, //(%lock_id, %lock_val) -> ()
148 {"llvm.aie2p.set.ctrl.reg",
149 {int32Type, int32Type},
150 {}}, //(%reg_id, %value) -> ()
151 };
152 return functions;
153}
154
155static void declareAIEIntrinsics(AIEArch arch, OpBuilder &builder) {
156 auto registerIntrinsics = [&builder](IntrinsicDecls functions) {
157 for (auto &i : functions) {
158 auto [name, argTypes, retTypes] = i;
159 func::FuncOp::create(
160 builder, builder.getUnknownLoc(), name,
161 FunctionType::get(builder.getContext(), argTypes, retTypes))
162 .setPrivate();
163 }
164 };
165 switch (arch) {
166 case AIEArch::AIE1:
167 registerIntrinsics(getAIE1Intrinsics(builder));
168 return;
169 case AIEArch::AIE2:
170 registerIntrinsics(getAIE2Intrinsics(builder));
171 return;
172 case AIEArch::AIE2p:
173 registerIntrinsics(getAIE2pIntrinsics(builder));
174 return;
175 }
176 llvm::report_fatal_error("unsupported arch");
177}
178
179template <typename MyAIEOp>
182 using OpAdaptor = typename MyAIEOp::Adaptor;
183 ModuleOp &module;
184
185 AIEOpRemoval(MLIRContext *context, ModuleOp &m, PatternBenefit benefit = 1)
186 : OpConversionPattern<MyAIEOp>(context, benefit), module(m) {}
187
188 LogicalResult
189 matchAndRewrite(MyAIEOp op, OpAdaptor adaptor,
190 ConversionPatternRewriter &rewriter) const override {
191 rewriter.eraseOp(op);
192 return success();
193 }
194};
195
197 using OpConversionPattern::OpConversionPattern;
198 ModuleOp &module;
199
200 AIEDebugOpToStdLowering(MLIRContext *context, ModuleOp &m,
201 PatternBenefit benefit = 1)
202 : OpConversionPattern(context, benefit), module(m) {}
203
204 LogicalResult
205 matchAndRewrite(DebugOp op, OpAdaptor adaptor,
206 ConversionPatternRewriter &rewriter) const override {
207 std::string funcName = "debug_i32";
208 auto func = module.lookupSymbol<func::FuncOp>(funcName);
209 if (!func)
210 return op.emitOpError("Could not find the intrinsic function ")
211 << funcName;
212 SmallVector<Value, 1> args;
213 args.push_back(op.getArg());
214 func::CallOp::create(rewriter, rewriter.getUnknownLoc(), func, args);
215 rewriter.eraseOp(op);
216 return success();
217 }
218};
219
221 using OpConversionPattern::OpConversionPattern;
222 ModuleOp &module;
223
224 AIEPutStreamToStdLowering(MLIRContext *context, ModuleOp &m,
225 PatternBenefit benefit = 1)
226 : OpConversionPattern(context, benefit), module(m) {}
227
228 LogicalResult
229 matchAndRewrite(PutStreamOp op, OpAdaptor adaptor,
230 ConversionPatternRewriter &rewriter) const override {
231 auto device = op->getParentOfType<DeviceOp>();
232 const auto &targetModel = device.getTargetModel();
233 std::string funcName;
234 if (targetModel.getTargetArch() == AIEArch::AIE1)
235 funcName = "llvm.aie.put.";
236 else if (targetModel.getTargetArch() == AIEArch::AIE2)
237 funcName = "llvm.aie2.put.";
238 else
239 funcName = "llvm.aie2p.put.";
240
241 if (op.isWideStream())
242 funcName += "wms";
243 else if (op.isFloatStream())
244 funcName += "fms";
245 else
246 funcName += "ms";
247
248 auto putMSFunc = module.lookupSymbol<func::FuncOp>(funcName);
249 if (!putMSFunc)
250 return op.emitOpError("Could not find the intrinsic function ")
251 << funcName;
252 SmallVector<Value, 2> args;
253 if (targetModel.getTargetArch() == AIEArch::AIE1) {
254 args.push_back(op.getChannel());
255 args.push_back(op.getStreamValue());
256 } else {
257 args.push_back(op.getStreamValue());
258 args.push_back(arith::ConstantOp::create(
259 rewriter, op.getLoc(), IntegerType::get(rewriter.getContext(), 32),
260 rewriter.getI32IntegerAttr(0))); // tlast
261 }
262 func::CallOp::create(rewriter, rewriter.getUnknownLoc(), putMSFunc, args);
263 rewriter.eraseOp(op);
264 return success();
265 }
266};
267
269 using OpConversionPattern::OpConversionPattern;
270 ModuleOp &module;
271
272 AIEGetStreamToStdLowering(MLIRContext *context, ModuleOp &m,
273 PatternBenefit benefit = 1)
274 : OpConversionPattern(context, benefit), module(m) {}
275
276 LogicalResult
277 matchAndRewrite(GetStreamOp op, OpAdaptor adaptor,
278 ConversionPatternRewriter &rewriter) const override {
279 auto device = op->getParentOfType<DeviceOp>();
280 const auto &targetModel = device.getTargetModel();
281 std::string funcName;
282 if (targetModel.getTargetArch() == AIEArch::AIE1)
283 funcName = "llvm.aie.get.";
284 else if (targetModel.getTargetArch() == AIEArch::AIE2)
285 funcName = "llvm.aie2.get.";
286 else
287 funcName = "llvm.aie2p.get.";
288
289 if (op.isWideStream())
290 funcName += "wss";
291 else if (op.isFloatStream())
292 funcName += "fss";
293 else
294 funcName += "ss";
295
296 auto getSSFunc = module.lookupSymbol<func::FuncOp>(funcName);
297 if (!getSSFunc)
298 return op.emitOpError("Could not find the intrinsic function ")
299 << funcName;
300 SmallVector<Value, 2> args;
301 if (targetModel.getTargetArch() == AIEArch::AIE1)
302 args.push_back(op.getChannel());
303 auto getSSCall = func::CallOp::create(rewriter, rewriter.getUnknownLoc(),
304 getSSFunc, args);
305 rewriter.replaceOp(op, getSSCall.getResult(0));
306 // Capture TLAST in AIEv2?
307 return success();
308 }
309};
310
312 using OpConversionPattern::OpConversionPattern;
313 ModuleOp &module;
314
315 AIEPutCascadeToStdLowering(MLIRContext *context, ModuleOp &m,
316 PatternBenefit benefit = 1)
317 : OpConversionPattern(context, benefit), module(m) {}
318
319 LogicalResult
320 matchAndRewrite(PutCascadeOp op, OpAdaptor adaptor,
321 ConversionPatternRewriter &rewriter) const override {
322 auto device = op->getParentOfType<DeviceOp>();
323 const auto &targetModel = device.getTargetModel();
324 std::string funcName;
325 if (targetModel.getTargetArch() == AIEArch::AIE1)
326 funcName = "llvm.aie.put.mcd";
327 else if (targetModel.getTargetArch() == AIEArch::AIE2)
328 funcName = "llvm.aie2.mcd.write.vec";
329 else
330 funcName = "llvm.aie2p.mcd.write.vec";
331 auto putMCDFunc = module.lookupSymbol<func::FuncOp>(funcName);
332 if (!putMCDFunc)
333 return op.emitOpError("Could not find the intrinsic function ")
334 << funcName;
335 SmallVector<Value, 2> args;
336 Value cascadeValue = op.getCascadeValue();
337
338 // Check if we need a bitcast for the input value
339 Type expectedInputType = putMCDFunc.getFunctionType().getInput(0);
340 Type actualInputType = cascadeValue.getType();
341
342 if (expectedInputType != actualInputType) {
343 // Create a bitcast operation to convert from actual input type to
344 // expected type
345 cascadeValue = vector::BitCastOp::create(rewriter, op.getLoc(),
346 expectedInputType, cascadeValue);
347 }
348
349 args.push_back(cascadeValue);
350 if (isa<AIE2TargetModel>(targetModel))
351 args.push_back(arith::ConstantOp::create(
352 rewriter, op.getLoc(), IntegerType::get(rewriter.getContext(), 32),
353 rewriter.getI32IntegerAttr(1))); // enable
354
355 func::CallOp::create(rewriter, rewriter.getUnknownLoc(), putMCDFunc, args);
356 rewriter.eraseOp(op);
357 return success();
358 }
359};
360
362 using OpConversionPattern::OpConversionPattern;
363 ModuleOp &module;
364
365 AIEGetCascadeToStdLowering(MLIRContext *context, ModuleOp &m,
366 PatternBenefit benefit = 1)
367 : OpConversionPattern(context, benefit), module(m) {}
368
369 LogicalResult
370 matchAndRewrite(GetCascadeOp op, OpAdaptor adaptor,
371 ConversionPatternRewriter &rewriter) const override {
372 auto device = op->getParentOfType<DeviceOp>();
373 const auto &targetModel = device.getTargetModel();
374 std::string funcName;
375 if (targetModel.getTargetArch() == AIEArch::AIE1)
376 funcName = "llvm.aie.get.scd";
377 else if (targetModel.getTargetArch() == AIEArch::AIE2)
378 funcName = "llvm.aie2.scd.read.vec";
379 else
380 funcName = "llvm.aie2p.scd.read.vec";
381 auto getSCDFunc = module.lookupSymbol<func::FuncOp>(funcName);
382 if (!getSCDFunc)
383 return op.emitOpError("Could not find the intrinsic function ")
384 << funcName;
385 SmallVector<Value, 2> args;
386 if (isa<AIE2TargetModel>(targetModel))
387 args.push_back(arith::ConstantOp::create(
388 rewriter, op.getLoc(), IntegerType::get(rewriter.getContext(), 32),
389 rewriter.getI32IntegerAttr(1))); // enable
390
391 auto getSCDCall = func::CallOp::create(rewriter, rewriter.getUnknownLoc(),
392 getSCDFunc, args);
393 Value result = getSCDCall.getResult(0);
394
395 // Check if we need a bitcast
396 Type expectedType = op.getResult().getType();
397 Type intrinsicReturnType = result.getType();
398
399 if (expectedType != intrinsicReturnType) {
400 // Create a bitcast operation to convert from intrinsic return type to
401 // expected type
402 result = vector::BitCastOp::create(rewriter, op.getLoc(), expectedType,
403 result);
404 }
405
406 rewriter.replaceOp(op, result);
407 return success();
408 }
409};
410
412 using OpConversionPattern::OpConversionPattern;
413 ModuleOp &module;
414
415 AIEUseLockToStdLowering(MLIRContext *context, ModuleOp &m,
416 PatternBenefit benefit = 1)
417 : OpConversionPattern(context, benefit), module(m) {}
418 LogicalResult
419 matchAndRewrite(UseLockOp useLock, OpAdaptor adaptor,
420 ConversionPatternRewriter &rewriter) const override {
421 if (!isa<DeviceOp>(useLock->getParentOp())) {
422 auto device = useLock->getParentOfType<DeviceOp>();
423 if (!device) {
424 return module.emitOpError("Device Not found!");
425 }
426 const auto &targetModel = device.getTargetModel();
427
428 // Generate the intrinsic name
429 std::string funcName;
430 if (targetModel.getTargetArch() == AIEArch::AIE1)
431 funcName = "llvm.aie.lock.";
432 else if (targetModel.getTargetArch() == AIEArch::AIE2)
433 funcName = "llvm.aie2.";
434 else
435 funcName = "llvm.aie2p.";
436 if (useLock.acquire() || useLock.acquireGE())
437 funcName += "acquire";
438 else if (useLock.release())
439 funcName += "release";
440 if (targetModel.getTargetArch() == AIEArch::AIE1)
441 funcName += ".reg";
442
443 auto useLockFunc = module.lookupSymbol<func::FuncOp>(funcName);
444 if (!useLockFunc)
445 return useLock.emitOpError("Could not find the intrinsic function!");
446
447 SmallVector<Value, 2> args;
448 auto lockValue = useLock.getLockValue();
449
450 // AIE2 acquire greater equal is encoded as a negative value.
451 if (useLock.acquireGE()) {
452 lockValue = -lockValue;
453 }
454 args.push_back(arith::IndexCastOp::create(
455 rewriter, useLock.getLoc(),
456 IntegerType::get(rewriter.getContext(), 32), useLock.getLock()));
457 args.push_back(
458 arith::ConstantOp::create(rewriter, useLock.getLoc(),
459 IntegerType::get(rewriter.getContext(), 32),
460 rewriter.getI32IntegerAttr(lockValue)));
461
462 func::CallOp::create(rewriter, rewriter.getUnknownLoc(), useLockFunc,
463 args);
464 }
465 rewriter.eraseOp(useLock);
466 return success();
467 }
468};
469
471 using OpConversionPattern::OpConversionPattern;
472 ModuleOp &module;
473 int tileCol = 0;
474 int tileRow = 0;
475 AIEBufferToStandard(MLIRContext *context, ModuleOp &m,
476 PatternBenefit benefit = 1, int tileCol = -1,
477 int tileRow = -1)
478 : OpConversionPattern(context, benefit), module(m), tileCol(tileCol),
479 tileRow(tileRow) {}
480 LogicalResult
481 matchAndRewrite(BufferOp buffer, OpAdaptor adaptor,
482 ConversionPatternRewriter &rewriter) const override {
483 rewriter.setInsertionPointToStart(module.getBody());
484 auto t = llvm::cast<MemRefType>(buffer.getType());
485 int col = llvm::cast<TileOp>(buffer.getTile().getDefiningOp()).getCol();
486 int row = llvm::cast<TileOp>(buffer.getTile().getDefiningOp()).getRow();
487 auto symName = buffer.name().getValue();
488 mlir::ElementsAttr initValue = buffer.getInitialValueAttr();
489 // Don't emit initialization for cores that don't "own" the buffer (to
490 // prevent duplication in the data section of the elf/object file)
491 if ((tileRow != row && tileRow != -1) || (tileCol != col && tileCol != -1))
492 initValue = nullptr;
493 memref::GlobalOp::create(rewriter, rewriter.getUnknownLoc(), symName,
494 rewriter.getStringAttr("public"), buffer.getType(),
495 initValue, /*constant*/ false,
496 /*alignment*/ nullptr);
497
498 for (auto &use : make_early_inc_range(buffer.getResult().getUses())) {
499 Operation *user = use.getOwner();
500 rewriter.setInsertionPoint(user);
501 auto allocated = memref::GetGlobalOp::create(
502 rewriter, rewriter.getUnknownLoc(), t, symName);
503 // Assume that buffers are aligned so they can be vectorized.
504 memref::AssumeAlignmentOp::create(rewriter, rewriter.getUnknownLoc(),
505 allocated, 32);
506
507 use.set(allocated.getResult());
508 }
509
510 rewriter.eraseOp(buffer);
511 return success();
512 }
513};
514
516 using OpConversionPattern::OpConversionPattern;
517 ModuleOp &module;
518 IRMapping &mapper;
519 DenseMap<Operation *, SmallVector<BufferOp, 4>> &tileToBuffers;
520 int tileCol = 0;
521 int tileRow = 0;
522
524 MLIRContext *context, ModuleOp &m, IRMapping &mapper,
525 DenseMap<Operation *, SmallVector<BufferOp, 4>> &tileToBuffers,
526 PatternBenefit benefit = 1, int tileCol = 1, int tileRow = 1)
527 : OpConversionPattern(context, benefit), module(m), mapper(mapper),
529
530 LogicalResult
531 matchAndRewrite(CoreOp op, OpAdaptor adaptor,
532 ConversionPatternRewriter &rewriter) const override {
533
534 int col = op.colIndex();
535 int row = op.rowIndex();
536
537 // Only pull code for the indicated function
538 if ((tileRow != row && tileRow != -1) ||
539 (tileCol != col && tileCol != -1)) {
540 rewriter.eraseOp(op);
541 return success();
542 }
543
544 // The parent should be an AIE.device op.
545 rewriter.setInsertionPointAfter(op->getParentOp());
546
547 std::string coreName("core_" + std::to_string(col) + "_" +
548 std::to_string(row));
549 auto coreFunc =
550 func::FuncOp::create(rewriter, rewriter.getUnknownLoc(), coreName,
551 FunctionType::get(rewriter.getContext(), {}, {}));
552
553 rewriter.cloneRegionBefore(op.getBody(), coreFunc.getBody(),
554 coreFunc.getBody().begin(), mapper);
555
556 // Set saturation and rounding modes at core entry for AIE2/AIE2p, but
557 // only if the core body contains aievec.srs or bf16 aievec.matmul ops.
558 // Skip for cores with only lock/stream ops to avoid breaking existing
559 // test SSA naming.
560 bool hasSRS = false;
561 bool hasIntegerSRS = false;
562 bool hasBF16Matmul = false;
563 coreFunc.walk([&](Operation *childOp) {
564 StringRef opName = childOp->getName().getStringRef();
565 if (opName == "aievec.srs") {
566 hasSRS = true;
567 // Check if this is an integer SRS (e.g., i32→i8) vs float SRS
568 // (e.g., f32→bf16). Integer SRS needs positive_inf rounding to
569 // match C++ kernel behavior; float SRS works better with floor.
570 if (childOp->getNumResults() > 0) {
571 auto resultType = childOp->getResult(0).getType();
572 if (auto vecType = dyn_cast<VectorType>(resultType)) {
573 if (vecType.getElementType().isInteger())
574 hasIntegerSRS = true;
575 }
576 }
577 }
578 // Detect bf16 matmul ops — these need conv_even rounding to avoid
579 // systematic negative bias from floor rounding in BFP16 arithmetic.
580 if (opName == "aievec.matmul" || opName == "aievec.matmul_aie2p") {
581 if (childOp->getNumOperands() > 0) {
582 auto lhsType = childOp->getOperand(0).getType();
583 if (auto vecType = dyn_cast<VectorType>(lhsType))
584 if (vecType.getElementType().isBF16())
585 hasBF16Matmul = true;
586 }
587 }
588 });
589 if (hasSRS || hasBF16Matmul) {
590 auto device = op->getParentOfType<DeviceOp>();
591 if (device) {
592 AIEArch arch = device.getTargetModel().getTargetArch();
593 if (arch == AIEArch::AIE2 || arch == AIEArch::AIE2p) {
594 std::string ctrlRegFuncName = (arch == AIEArch::AIE2p)
595 ? "llvm.aie2p.set.ctrl.reg"
596 : "llvm.aie2.set.ctrl.reg";
597 auto ctrlRegFunc = module.lookupSymbol<func::FuncOp>(ctrlRegFuncName);
598 if (ctrlRegFunc) {
599 Block &entryBlock = coreFunc.getBody().front();
600 rewriter.setInsertionPointToStart(&entryBlock);
601 Location loc = op.getLoc();
602 // Rounding register index differs between AIE2 and AIE2P:
603 // AIE2: crRnd=6
604 // AIE2P: crRnd=1
605 // Saturation register uses AIE2 index (9) for both architectures.
606 // On AIE2P, index 9 maps to crPackSize (no-op for saturation),
607 // preserving the pre-existing behavior. The AIE2P crSat fix
608 // (index 0) requires updating downstream tests and is tracked
609 // separately.
610 int satRegIdx = 9;
611 int rndRegIdx = (arch == AIEArch::AIE2p) ? 1 : 6;
612 // saturation_mode::saturate = 1
613 auto cSatIdx = arith::ConstantOp::create(
614 rewriter, loc, rewriter.getI32IntegerAttr(satRegIdx));
615 auto c1 = arith::ConstantOp::create(rewriter, loc,
616 rewriter.getI32IntegerAttr(1));
617 func::CallOp::create(rewriter, loc, ctrlRegFunc,
618 ValueRange{cSatIdx, c1});
619 // Rounding mode:
620 // - conv_even (12) for bf16 matmul: eliminates systematic
621 // negative bias from floor rounding in BFP16 arithmetic,
622 // matching ::aie::set_rounding(aie::rounding_mode::conv_even)
623 // used in external C++ matmul kernels.
624 // - positive_inf (9) for integer SRS (shift-round-saturate on
625 // integer data, e.g., i32→i8).
626 // - floor (0) for float-only SRS (f32→bf16 truncation).
627 int roundingMode = hasBF16Matmul ? 12 : hasIntegerSRS ? 9 : 0;
628 auto cRndIdx = arith::ConstantOp::create(
629 rewriter, loc, rewriter.getI32IntegerAttr(rndRegIdx));
630 auto cRoundingMode = arith::ConstantOp::create(
631 rewriter, loc, rewriter.getI32IntegerAttr(roundingMode));
632 func::CallOp::create(rewriter, loc, ctrlRegFunc,
633 ValueRange{cRndIdx, cRoundingMode});
634 }
635 }
636 }
637 }
638
639 // Rewrite the AIE.end() op
640 coreFunc.getBody().walk([&](Operation *childOp) {
641 rewriter.setInsertionPointAfter(childOp);
642
643 if (isa<EndOp>(childOp)) {
644 func::ReturnOp::create(rewriter, rewriter.getUnknownLoc(),
645 ValueRange({}));
646 rewriter.eraseOp(childOp);
647 }
648 });
649
650 rewriter.eraseOp(op);
651 return success();
652 }
653};
654
655// Move all the ops with OpTy inside device, to just before the device.
656template <typename OpTy>
657void outlineOps(DeviceOp device) {
658 SmallVector<OpTy, 16> ops;
659 for (const auto &op : device.getOps<OpTy>())
660 ops.push_back(op);
661
662 for (const auto &op : ops)
663 op->moveBefore(device);
664}
665
// Lower AIE.event to the per-architecture event intrinsic
// (llvm.aie.event<val>, llvm.aie2.event or llvm.aie2p.event).
668 using OpConversionPattern::OpConversionPattern;
669 ModuleOp &module;
670
671 AIEEventOpToStdLowering(MLIRContext *context, ModuleOp &m,
672 PatternBenefit benefit = 1)
673 : OpConversionPattern(context, benefit), module(m) {}
674
675 LogicalResult
676 matchAndRewrite(EventOp op, OpAdaptor adaptor,
677 ConversionPatternRewriter &rewriter) const override {
678 auto device = op->getParentOfType<DeviceOp>();
679 std::string funcName;
680 SmallVector<Value, 1> args;
681 switch (device.getTargetModel().getTargetArch()) {
682 case AIEArch::AIE1:
683 funcName = "llvm.aie.event" + std::to_string(op.getVal());
684 break;
685 case AIEArch::AIE2:
686 funcName = "llvm.aie2.event";
687 args.push_back(arith::ConstantOp::create(
688 rewriter, op.getLoc(), rewriter.getI32Type(),
689 rewriter.getI32IntegerAttr(op.getVal())));
690 break;
691 case AIEArch::AIE2p:
692 funcName = "llvm.aie2p.event";
693 args.push_back(arith::ConstantOp::create(
694 rewriter, op.getLoc(), rewriter.getI32Type(),
695 rewriter.getI32IntegerAttr(op.getVal())));
696 break;
697 default:
698 return op->emitOpError("Unsupported AIEArch for EventOp lowering");
699 }
700 auto eventFunc = module.lookupSymbol<func::FuncOp>(funcName);
701 if (!eventFunc)
702 return op.emitOpError("Could not find the intrinsic function ")
703 << funcName;
704 func::CallOp::create(rewriter, rewriter.getUnknownLoc(), eventFunc, args);
705 rewriter.eraseOp(op);
706 return success();
707 }
708};
709
711 : xilinx::AIE::impl::AIECoreToStandardBase<AIECoreToStandardPass> {
713 AIECoreToStandardPass(const AIECoreToStandardOptions &options) {
714 deviceName = options.deviceName;
715 tileCol = options.tileCol;
716 tileRow = options.tileRow;
717 }
718
719 void runOnOperation() override {
720
721 ModuleOp m = getOperation();
722 OpBuilder builder = OpBuilder::atBlockEnd(m.getBody());
723
724 DeviceOp deviceOp = DeviceOp::getForSymbolInModuleOrError(m, deviceName);
725 if (!deviceOp) {
726 return signalPassFailure();
727 }
728 const auto &targetModel = deviceOp.getTargetModel();
729
730 // Copy data layout attribute from DeviceOp to ModuleOp if present
731 if (auto dlAttr = deviceOp->getAttr(DLTIDialect::kDataLayoutAttrName)) {
732 m->setAttr(DLTIDialect::kDataLayoutAttrName, dlAttr);
733 }
734
735 // Ensure that we don't have an incorrect target triple. This may override
736 // some bogus target triple in the original mlir.
737 m->setAttr(LLVM::LLVMDialect::getTargetTripleAttrName(),
738 builder.getStringAttr(
739 getArchIntrinsicString(targetModel.getTargetArch())));
740
741 DenseMap<Operation *, SmallVector<BufferOp, 4>> tileToBuffers;
742
743 // Populate intrinsic functions
744 // Intrinsic information:
745 // peano/llvm-project/llvm/lib/Target/AIE/AIEInstrInfo.td Also take a look
746 // at the tests: peano/llvm-project/llvm/test/CodeGen/AIE
747 builder.setInsertionPointToStart(m.getBody());
748 declareAIEIntrinsics(targetModel.getTargetArch(), builder);
749
750 IRMapping mapper;
751 ConversionTarget target(getContext());
752 target.addLegalDialect<func::FuncDialect>();
753 target.addLegalDialect<cf::ControlFlowDialect>();
754 target.addLegalDialect<memref::MemRefDialect>();
755 target.addLegalDialect<VectorDialect>();
756 target.addLegalDialect<aievec::AIEVecDialect>();
757 target.addLegalDialect<arith::ArithDialect>();
758 target.addLegalDialect<ub::UBDialect>();
759 target.addLegalDialect<math::MathDialect>();
760 target.addLegalDialect<index::IndexDialect>();
761 target.addLegalDialect<ptr::PtrDialect>();
762 target.addLegalOp<func::FuncOp, ModuleOp, UnrealizedConversionCastOp>();
763
764 RewritePatternSet patterns(&getContext());
768 AIEEventOpToStdLowering>(m.getContext(), m);
769
770 patterns.add<AIEBufferToStandard>(m.getContext(), m, /*benefit*/ 1, tileCol,
771 tileRow);
772 if (failed(applyPartialConversion(deviceOp, target, std::move(patterns))))
773 return signalPassFailure();
774
775 RewritePatternSet outlinePatterns(&getContext());
776 outlinePatterns.add<AIECoreToStandardFunc>(m.getContext(), m, mapper,
777 tileToBuffers, /*benefit*/ 1,
778 tileCol, tileRow);
779 if (failed(applyPartialConversion(deviceOp, target,
780 std::move(outlinePatterns))))
781 return signalPassFailure();
782
783 // Move all the func.func ops and memref.globals from the device to the
784 // module
785 outlineOps<memref::GlobalOp>(deviceOp);
786 outlineOps<func::FuncOp>(deviceOp);
787
788 RewritePatternSet removepatterns(&getContext());
789 removepatterns.add<
795 m.getContext(), m);
796
797 if (failed(applyPartialConversion(m, target, std::move(removepatterns))))
798 return signalPassFailure();
799 }
800};
801
802std::unique_ptr<OperationPass<ModuleOp>> AIE::createAIECoreToStandardPass() {
803 return std::make_unique<AIECoreToStandardPass>();
804}
805
806std::unique_ptr<OperationPass<ModuleOp>>
807AIE::createAIECoreToStandardPass(const AIECoreToStandardOptions &options) {
808 return std::make_unique<AIECoreToStandardPass>(options);
809}
std::vector< IntrinsicDecl > IntrinsicDecls
void outlineOps(DeviceOp device)
std::tuple< const char *, std::vector< Type >, std::vector< Type > > IntrinsicDecl
Include the generated interface declarations.
std::unique_ptr< mlir::OperationPass< mlir::ModuleOp > > createAIECoreToStandardPass()
AIEArch
Definition Passes.h:21
LogicalResult matchAndRewrite(BufferOp buffer, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
AIEBufferToStandard(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1, int tileCol=-1, int tileRow=-1)
ModuleOp &IRMapping & mapper
LogicalResult matchAndRewrite(CoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
AIECoreToStandardFunc(MLIRContext *context, ModuleOp &m, IRMapping &mapper, DenseMap< Operation *, SmallVector< BufferOp, 4 > > &tileToBuffers, PatternBenefit benefit=1, int tileCol=1, int tileRow=1)
DenseMap< Operation *, SmallVector< BufferOp, 4 > > & tileToBuffers
AIECoreToStandardPass()=default
AIECoreToStandardPass(const AIECoreToStandardOptions &options)
ModuleOp & AIEDebugOpToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(DebugOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEEventOpToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(EventOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(GetCascadeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEGetCascadeToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(GetStreamOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEGetStreamToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
ModuleOp & AIEOpRemoval(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
typename MyAIEOp::Adaptor OpAdaptor
LogicalResult matchAndRewrite(MyAIEOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(PutCascadeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEPutCascadeToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(PutStreamOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ModuleOp & AIEPutStreamToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
ModuleOp & AIEUseLockToStdLowering(MLIRContext *context, ModuleOp &m, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(UseLockOp useLock, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override