MLIR-AIE
VectorToAIEVecConversions.cpp
Go to the documentation of this file.
1//===-VectorToAIEVecConversions.cpp - Vector to AIEVec convs. ---*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2023, Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10// This file contains conversions from the Vector dialect into the AIEVec
11// dialect. Conversions assume that the Vector dialect has been rectricted
12// to ops that can be translated to a sequence of valid AIEVec ops.
13//===----------------------------------------------------------------------===//
14
19
21#include "mlir/Dialect/Affine/IR/AffineOps.h"
22#include "mlir/Dialect/EmitC/IR/EmitC.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/Math/IR/Math.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SCF/IR/SCF.h"
27#include "mlir/Dialect/UB/IR/UBOps.h"
28#include "mlir/IR/PatternMatch.h"
29#include "mlir/IR/SymbolTable.h"
30#include "mlir/IR/TypeUtilities.h"
31#include "mlir/Pass/PassManager.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "mlir/Transforms/Passes.h"
34#include "llvm/ADT/SmallSet.h"
35#include <bitset>
36#include <optional>
37#include <tuple>
38
39#define DEBUG_TYPE "lower-vector-to-aievec"
40
41using namespace llvm;
42using namespace mlir;
43using namespace arith;
44using namespace vector;
45using namespace xilinx;
46using namespace xilinx::aievec;
47
48//===----------------------------------------------------------------------===//
49// Utility functions
50//===----------------------------------------------------------------------===//
51
52static bool isNarrowingOp(Operation *op) {
53 if (isa<arith::TruncFOp>(op) || isa<arith::TruncIOp>(op))
54 return true;
55
56 if (auto srsOp = dyn_cast<aievec::SRSOp>(op)) {
57 auto *srsOpSrcOp = srsOp.getSource().getDefiningOp();
58 if (isa<aievec::UPSOp>(srsOpSrcOp) || isa<aievec::CastOp>(srsOpSrcOp))
59 return true;
60 }
61 return false;
62}
63
64// Check if a TruncIOp is part of a shrsi+[clamp]+trunc chain that can be
65// lowered to a compound SRS pattern. Only returns true for chains that the
66// ShiftClampTruncToSRSPattern rewrite will actually match:
67// - shrsi → trunci (no clamp)
68// - shrsi → maxsi → minsi → trunci (full clamp pair)
69// - shrsi → minsi → maxsi → trunci (reversed clamp pair)
70// Does NOT return true for a single min or max without a matching pair,
71// to avoid marking trunci illegal when no rewrite can handle it.
72static bool isSRSCompoundCandidate(arith::TruncIOp trunciOp) {
73 Value source = trunciOp.getIn();
74
75 // Case 1: direct shrsi → trunci
76 if (source.getDefiningOp<arith::ShRSIOp>())
77 return true;
78
79 // Case 2: minsi(maxsi(shrsi(...))) → trunci
80 if (auto minsiOp = source.getDefiningOp<arith::MinSIOp>()) {
81 if (auto maxsiOp = minsiOp.getLhs().getDefiningOp<arith::MaxSIOp>()) {
82 if (maxsiOp.getLhs().getDefiningOp<arith::ShRSIOp>())
83 return true;
84 }
85 }
86
87 // Case 3: maxsi(minsi(shrsi(...))) → trunci (reversed order)
88 if (auto maxsiOp = source.getDefiningOp<arith::MaxSIOp>()) {
89 if (auto minsiOp = maxsiOp.getLhs().getDefiningOp<arith::MinSIOp>()) {
90 if (minsiOp.getLhs().getDefiningOp<arith::ShRSIOp>())
91 return true;
92 }
93 }
94
95 // Case 4: minsi(maxsi(x)) → trunci (clamp-only, no shrsi)
96 // Used for saturating clamp after skip-add (e.g., skip_scale=0).
97 if (auto minsiOp = source.getDefiningOp<arith::MinSIOp>()) {
98 if (minsiOp.getLhs().getDefiningOp<arith::MaxSIOp>())
99 return true;
100 }
101
102 // Case 5: maxsi(minsi(x)) → trunci (reversed clamp-only, no shrsi)
103 if (auto maxsiOp = source.getDefiningOp<arith::MaxSIOp>()) {
104 if (maxsiOp.getLhs().getDefiningOp<arith::MinSIOp>())
105 return true;
106 }
107
108 return false;
109}
110
111// Check if a ShRSIOp's result feeds into a trunci (possibly through
112// clamp ops), meaning the compound pattern will consume it.
113static bool shrsiUsedByCompoundSRS(arith::ShRSIOp rsOp) {
114 for (Operation *user : rsOp->getUsers()) {
115 // Direct: shrsi → trunci (validate full chain via isSRSCompoundCandidate)
116 if (auto truncOp = dyn_cast<arith::TruncIOp>(user))
117 if (isSRSCompoundCandidate(truncOp))
118 return true;
119 // Through clamp: shrsi → maxsi → minsi → trunci
120 // or: shrsi → minsi → maxsi → trunci
121 if (isa<arith::MaxSIOp, arith::MinSIOp>(user)) {
122 for (Operation *user2 : user->getUsers()) {
123 if (auto truncOp2 = dyn_cast<arith::TruncIOp>(user2))
124 if (isSRSCompoundCandidate(truncOp2))
125 return true;
126 if (isa<arith::MaxSIOp, arith::MinSIOp>(user2)) {
127 for (Operation *user3 : user2->getUsers()) {
128 if (auto truncOp3 = dyn_cast<arith::TruncIOp>(user3))
129 if (isSRSCompoundCandidate(truncOp3))
130 return true;
131 }
132 }
133 }
134 }
135 }
136 return false;
137}
138
139// Check if a scalar maxsi/minsi is sandwiched in a compound SRS chain
140// (between shrsi and trunci). Used to keep the op legal so the scalar
141// compound SRS pattern can consume the entire chain.
142static bool scalarClampInCompoundSRS(Operation *op) {
143 if (!isa<arith::MaxSIOp, arith::MinSIOp>(op))
144 return false;
145 // Only apply to scalar types
146 if (isa<VectorType>(op->getResult(0).getType()))
147 return false;
148 for (Operation *user : op->getUsers()) {
149 if (auto truncOp = dyn_cast<arith::TruncIOp>(user)) {
150 if (isSRSCompoundCandidate(truncOp))
151 return true;
152 }
153 if (isa<arith::MaxSIOp, arith::MinSIOp>(user)) {
154 for (Operation *user2 : user->getUsers()) {
155 if (auto truncOp2 = dyn_cast<arith::TruncIOp>(user2)) {
156 if (isSRSCompoundCandidate(truncOp2))
157 return true;
158 }
159 }
160 }
161 }
162 return false;
163}
164
165// Given a Value, if it is defined by a widening op (arith:ExtSIOp,
166// arith::ExtUIOp, arith::ExtFOp, aievec::UPSOp + aievec::SRSOp,
167// aievec::UPSOp + aievec::CastOp), return the source of the widening op.
168static std::optional<Value> getSourceOfWideningOp(Value src) {
169 if (auto extSIOp = src.getDefiningOp<arith::ExtSIOp>())
170 return extSIOp.getIn();
171 if (auto extUIOp = src.getDefiningOp<arith::ExtUIOp>())
172 return extUIOp.getIn();
173 if (auto extFOp = src.getDefiningOp<arith::ExtFOp>())
174 return extFOp.getIn();
175 if (auto srsOp = src.getDefiningOp<aievec::SRSOp>()) {
176 // Conversion through AIE intrinsics takes two steps:
177 // 1) Load to accumulator: aievec.ups
178 // 2) Move from accumulator: aievec.srs
179 auto srsSource = srsOp.getSource();
180 if (srsSource)
181 if (auto upsOp = srsSource.getDefiningOp<aievec::UPSOp>())
182 return upsOp.getSource();
183 }
184 if (auto castOp = src.getDefiningOp<aievec::CastOp>()) {
185 // Conversion through AIE intrinsics can also take the following two steps:
186 // 1) Load to accumulator: aievec.ups
187 // 2) Move from accumulator: aievec.cast
188 auto castSource = castOp.getSource();
189 if (castSource)
190 if (auto upsOp = castSource.getDefiningOp<aievec::UPSOp>())
191 return upsOp.getSource();
192 }
193 return std::optional<Value>();
194}
195
196// Given a Value, if it is defined by a narrowing op (arith::TruncFOp,
197// arith::TruncIOp), return the source of the narrowing op.
198static std::optional<Value> getSourceOfNarrowingOp(Value src) {
199 if (auto truncFOp = src.getDefiningOp<arith::TruncFOp>())
200 return truncFOp.getIn();
201 if (auto truncIOp = src.getDefiningOp<arith::TruncIOp>())
202 return truncIOp.getIn();
203 return std::optional<Value>();
204}
205
206//===----------------------------------------------------------------------===//
207// Type conversion utilities with narrowing/widening optimization awareness
208//===----------------------------------------------------------------------===//
209
// Smart widen a value to target type. If the value was previously narrowed
// from the target type, reuse the original source to avoid truncf->extf
// chains.
// NOTE(review): the fallback emits arith.extf, so this helper assumes `val`
// has a float element type (e.g. bf16 → f32) — confirm at call sites before
// reusing it for integers.
static Value widenValueWithNarrowingCheck(Value val, Type targetType,
                                          Location loc,
                                          ConversionPatternRewriter &rewriter) {
  // If `val` is itself the result of a trunc from exactly `targetType`,
  // simply reuse the wide source and skip the round trip.
  if (auto narrowedSrc = getSourceOfNarrowingOp(val)) {
    if (narrowedSrc->getType() == targetType)
      return *narrowedSrc; // Reuse the original value (skip truncf->extf)
  }

  // Already the target type: nothing to widen.
  if (val.getType() == targetType)
    return val;

  // Otherwise materialize the widening cast.
  return arith::ExtFOp::create(rewriter, loc, targetType, val);
}
227
228// Result structure for smart narrowing operation
230 Value narrowedValue; // The narrowed value (or original if optimized)
231 bool skipNarrowing; // True if we should skip creating truncf
232 Operation *wideningUser; // The widening op to replace (if skipNarrowing)
233};
234
235// Smart narrow a value to target type. If the result will be immediately
236// widened back, skip both truncf and extf operations.
237static NarrowingResult
238narrowValueWithWideningCheck(Operation *srcOp, Value val, Type targetType,
239 Location loc,
240 ConversionPatternRewriter &rewriter) {
241 NarrowingResult result;
242 result.narrowedValue = val;
243 result.skipNarrowing = false;
244 result.wideningUser = nullptr;
245
246 // Check if srcOp will be immediately widened back
247 if (srcOp->hasOneUse()) {
248 Operation *user = *srcOp->getUsers().begin();
249 if (auto extfOp = dyn_cast<arith::ExtFOp>(user)) {
250 // The result will be widened - skip both truncf and extf
251 result.skipNarrowing = true;
252 result.wideningUser = extfOp;
253 return result;
254 }
255 }
256
257 // Normal case: create the narrowing op
258 result.narrowedValue =
259 arith::TruncFOp::create(rewriter, loc, targetType, val);
260 return result;
261}
262
// High-level helper to perform a binary operation on bf16 values in f32
// precision. This function handles:
// 1. Smart widening of operands (reuses f32 source if narrowed from f32)
// 2. Executing the operation in f32
// 3. Smart narrowing back to bf16 (skips truncf->extf if result is widened)
// NOTE(review): the narrow target is `lhs.getType()`, so lhs and srcOp's
// result are assumed to share the same (bf16) type — confirm at call sites.
static void
performBF16BinaryOpInF32(Value lhs, Value rhs, Operation *srcOp, Location loc,
                         ConversionPatternRewriter &rewriter,
                         std::function<Value(Value, Value)> opBuilder) {
  Type f32Type = rewriter.getF32Type();

  // Smart widen both operands (reuse f32 source if narrowed from f32)
  Value lhsF32 = widenValueWithNarrowingCheck(lhs, f32Type, loc, rewriter);
  Value rhsF32 = widenValueWithNarrowingCheck(rhs, f32Type, loc, rewriter);

  // Perform operation in f32 via the caller-provided builder.
  Value resultF32 = opBuilder(lhsF32, rhsF32);

  // Smart narrow back to bf16 (skip if result will be widened)
  auto narrowResult = narrowValueWithWideningCheck(
      srcOp, resultF32, lhs.getType(), loc, rewriter);

  if (narrowResult.skipNarrowing) {
    // The only user of srcOp is an extf back to f32: hand it the f32 result
    // directly and drop srcOp, eliding the truncf/extf pair entirely.
    rewriter.replaceOp(narrowResult.wideningUser, resultF32);
    rewriter.eraseOp(srcOp);
  } else {
    rewriter.replaceOp(srcOp, narrowResult.narrowedValue);
  }
}
293
// Given the LHS and RHS of an `arith::AddIOp`, if one of them is defined by an
// `arith::MulIOp`, return a tuple with the `lhs`, `rhs`, and `acc` of the MAC
// operation that can replace them. Also recognizes a multiply that was
// already lowered to aievec.aie1.mul followed by aievec.srs. Returns an empty
// optional when neither add operand originates from a multiply.
static std::optional<std::tuple<Value, Value, Value>>
extractMACOperandsFromAddOperands(Value addLhs, Value addRhs) {
  auto *lhsDefOp = addLhs.getDefiningOp();
  auto *rhsDefOp = addRhs.getDefiningOp();
  arith::MulIOp mulOp = nullptr;
  Value acc;
  // Prefer the LHS as the multiply; the other add operand becomes the
  // accumulator.
  if (lhsDefOp) {
    mulOp = dyn_cast<arith::MulIOp>(lhsDefOp);
    acc = addRhs;
  }
  if (!mulOp && rhsDefOp) {
    mulOp = dyn_cast<arith::MulIOp>(rhsDefOp);
    acc = addLhs;
  }
  if (mulOp)
    return std::make_tuple(mulOp.getLhs(), mulOp.getRhs(), acc);

  // If the MulIOp has been already translated to aievec::aie1::MulOp:
  // look through the aievec.srs that moved the product out of the
  // accumulator to find the original multiply operands.
  auto lhsSrsOp = addLhs.getDefiningOp<aievec::SRSOp>();
  auto rhsSrsOp = addRhs.getDefiningOp<aievec::SRSOp>();
  aievec::aie1::MulOp aieMulOp = nullptr;
  if (lhsSrsOp) {
    aieMulOp = lhsSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>();
    acc = addRhs;
  }
  if (!aieMulOp && rhsSrsOp) {
    aieMulOp = rhsSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>();
    acc = addLhs;
  }
  if (aieMulOp)
    return std::make_tuple(aieMulOp.getLhs(), aieMulOp.getRhs(), acc);
  return {};
}
330
331// Given the LHS and RHS of an `arith::AddFOp`, if one of them is defined by an
332// `arith::MulFOp`, return a tuple with the `lhs`, `rhs`, and `acc` of the FMA
333// operation that can replace them.
334static std::optional<std::tuple<Value, Value, Value>>
335extractFMACOperandsFromAddOperands(Value addLhs, Value addRhs) {
336 auto *lhsDefOp = addLhs.getDefiningOp();
337 auto *rhsDefOp = addRhs.getDefiningOp();
338 arith::MulFOp mulOp = nullptr;
339 Value acc;
340 if (lhsDefOp) {
341 mulOp = dyn_cast<arith::MulFOp>(lhsDefOp);
342 acc = addRhs;
343 }
344 if (!mulOp && rhsDefOp) {
345 mulOp = dyn_cast<arith::MulFOp>(rhsDefOp);
346 acc = addLhs;
347 }
348 if (mulOp)
349 return std::make_tuple(mulOp.getLhs(), mulOp.getRhs(), acc);
350 return {};
351}
352
// Convert a input value to a target vector type. This function can insert
// multiple aievec ops depending on the combination of input and output vector
// types. Returns std::nullopt if no conversion sequence is known for the
// given source/target type pair.
static std::optional<Value>
convertValueToTargetTypeAIE2(ConversionPatternRewriter &rewriter, Location loc,
                             Value inputVal, VectorType tgtType) {
  auto srcType = cast<VectorType>(inputVal.getType());
  auto srcElemType = srcType.getElementType();
  unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth();
  unsigned srcLaneSize = getVectorLaneSize(srcType);

  auto tgtElemType = tgtType.getElementType();
  unsigned tgtBitWidth = tgtElemType.getIntOrFloatBitWidth();
  unsigned tgtLaneSize = getVectorLaneSize(tgtType);

  // Trivial case: nothing to convert.
  if (srcType == tgtType)
    return inputVal;

  if ((srcElemType == tgtElemType) && (srcLaneSize != tgtLaneSize)) {
    // Same element type, more lanes: pad the upper lanes with zeros by
    // concatenating the input with a zero vector of the source shape.
    // TODO: relax the condition below?
    if ((srcLaneSize == 16 && tgtLaneSize == 32 &&
         isa<FloatType>(srcElemType)) ||
        (srcLaneSize == 32 && tgtLaneSize == 64 &&
         isa<IntegerType>(srcElemType))) {
      // Materialize the zero half by broadcasting scalar zero to the wide
      // type and extracting a source-shaped slice of it.
      auto zeroConstOp = arith::ConstantOp::create(
          rewriter, loc, srcType.getElementType(),
          rewriter.getZeroAttr(srcType.getElementType()));
      auto broadcastZeroOp = aievec::BroadcastScalarOp::create(
          rewriter, loc, tgtType, zeroConstOp->getResult(0));
      auto extOp = aievec::ExtOp::create(rewriter, loc, srcType,
                                         broadcastZeroOp->getResult(0), 0);

      SmallVector<Value> inputSources = {inputVal, extOp->getResult(0)};
      auto concatOp =
          aievec::ConcatOp::create(rewriter, loc, tgtType, inputSources);

      return concatOp.getResult();
    }
  } else if ((srcElemType != tgtElemType) && (srcLaneSize == tgtLaneSize) &&
             isa<IntegerType>(srcElemType) && isa<IntegerType>(tgtElemType)) {
    // Same lane count, wider integer element type: widen through the AIE2
    // accumulator.
    if (srcBitWidth == 16 && tgtBitWidth == 32 && srcLaneSize == 16) {
      // Case 1: vector<16xi16> to vector<16xi32> conversion by aievec.ups +
      // aievec.cast
      auto accType = getVectorOpDestType(srcType, /*AIE2 =*/true);
      auto upsOp = aievec::UPSOp::create(rewriter, loc, accType, inputVal);
      auto castOp = aievec::CastOp::create(
          rewriter, loc, tgtType, upsOp.getResult(), /*isResAcc*/ false);
      return castOp.getResult();
    }

    if (srcBitWidth == 8 && tgtBitWidth == 32 && srcLaneSize == 16) {
      // Case 2: vector<16xi8> to vector<16xi32> conversion by aievec.concat +
      // aievec.ups + aievec.cast + aievec.ext
      // The input is doubled to 32 lanes first (concat with itself); the
      // redundant upper half is discarded again by the final aievec.ext.
      auto concatOutType = createVectorType(32, srcElemType);
      auto concatOp =
          aievec::ConcatOp::create(rewriter, loc, concatOutType,
                                   SmallVector<Value>({inputVal, inputVal}));
      auto accType = getVectorOpDestType(concatOutType, /*AIE2 =*/true);
      auto upsOp =
          aievec::UPSOp::create(rewriter, loc, accType, concatOp.getResult());
      auto castType = createVectorType(32, tgtElemType);
      auto castOp = aievec::CastOp::create(
          rewriter, loc, castType, upsOp.getResult(), /*isResAcc*/ false);
      auto extOp =
          aievec::ExtOp::create(rewriter, loc, tgtType, castOp.getResult(), 0);
      return extOp.getResult();
    }

    if (srcBitWidth == 8 && tgtBitWidth == 16 && srcLaneSize == 32) {
      // Case 3: vector<32xi8> to vector<32xi16> conversion by aievec.unpack
      auto unpackOp =
          aievec::UnpackOp::create(rewriter, loc, tgtType, inputVal);
      return unpackOp.getResult();
    }
  }

  // No known lowering for this combination of types.
  return std::nullopt;
}
431
// Return the list of attributes that configure an `aievec.select` op to
// perform a rotation of the input vector by `rotation` number of elements.
// The attribute values depend on the vector type of the select operation.
// NOTE(review): `width` stays 0 for non-integer element types, which hits
// the fatal-error default below — presumably callers only pass integer
// vectors; confirm.
static SmallVector<NamedAttribute>
buildAttributeListForRotationSelectOp(PatternRewriter &rewriter, VectorType vTy,
                                      int64_t rotation) {
  unsigned width = 0;
  auto elemTy = vTy.getElementType();
  if (auto intTy = dyn_cast<IntegerType>(elemTy))
    width = intTy.getWidth();
  // Pre-built attribute values shared by the cases below.
  StringAttr attr0 = rewriter.getStringAttr("0");
  StringAttr attr0x06040200 = rewriter.getStringAttr("0x06040200");
  StringAttr attr0x0e0c0a08 = rewriter.getStringAttr("0x0e0c0a08");
  StringAttr attr0x2103 = rewriter.getStringAttr("0x2103");
  StringAttr attr0x3210 = rewriter.getStringAttr("0x3210");
  StringAttr selectAttrName = rewriter.getStringAttr("select");
  StringAttr xoffsetsAttrName = rewriter.getStringAttr("xoffsets");
  StringAttr xoffsetsHiAttrName = rewriter.getStringAttr("xoffsets_hi");
  StringAttr xsquareAttrName = rewriter.getStringAttr("xsquare");
  StringAttr xstartAttrName = rewriter.getStringAttr("xstart");
  StringAttr yoffsetsAttrName = rewriter.getStringAttr("yoffsets");
  StringAttr yoffsetsHiAttrName = rewriter.getStringAttr("yoffsets_hi");
  StringAttr ysquareAttrName = rewriter.getStringAttr("ysquare");
  StringAttr ystartAttrName = rewriter.getStringAttr("ystart");

  switch (width) {
  case 16: {
    if (rotation % 2) {
      // Odd rotation for 16-bit lanes: interleave lanes from two start
      // positions (one above, one below the rotation point).
      int64_t xstart = rotation + 1;
      int64_t ystart = rotation - 1;
      return SmallVector<NamedAttribute, 9>(
          {{selectAttrName, rewriter.getStringAttr("0x11111111")},
           {xoffsetsAttrName, attr0x06040200},
           {xoffsetsHiAttrName, attr0x0e0c0a08},
           {xsquareAttrName, attr0x2103},
           {xstartAttrName, rewriter.getStringAttr(std::to_string(xstart))},
           {yoffsetsAttrName, rewriter.getStringAttr("0x0503010f")},
           {yoffsetsHiAttrName, rewriter.getStringAttr("0x0d0b0907")},
           {ysquareAttrName, attr0x2103},
           {ystartAttrName, rewriter.getStringAttr(std::to_string(ystart))}});
    }
    // Even rotation: a plain shifted read from the x side is sufficient.
    return SmallVector<NamedAttribute, 9>(
        {{selectAttrName, attr0},
         {xoffsetsAttrName, attr0x06040200},
         {xoffsetsHiAttrName, attr0x0e0c0a08},
         {xsquareAttrName, attr0x3210},
         {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))},
         {yoffsetsAttrName, attr0},
         {yoffsetsHiAttrName, attr0},
         {ysquareAttrName, attr0},
         {ystartAttrName, attr0}});
  }
  case 32:
    // 32-bit lanes: single offset table, start index equals the rotation.
    return SmallVector<NamedAttribute, 7>(
        {{selectAttrName, attr0},
         {xoffsetsAttrName, rewriter.getStringAttr("0x76543210")},
         {xsquareAttrName, attr0x3210},
         {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))},
         {yoffsetsAttrName, attr0},
         {ysquareAttrName, attr0},
         {ystartAttrName, attr0}});
  default:
    llvm::report_fatal_error("Unexpected width!");
  }

  return {};
}
499
namespace xilinx::aievec {

// Build the attribute list that configures an aievec.aie1.fma op as a
// "splat" multiply: each output lane multiplies data lanes by the scalar
// z[bcastPos] (and, for 16-bit, also z[bcastPos+step]). The attribute values
// depend on the element width of the FMA's LHS vector.
SmallVector<NamedAttribute>
buildFMAOpSplatAttrForElemTy(aievec::aie1::FMAOp fmaOp, int64_t bcastPos,
                             int64_t step = 1) {
  unsigned width = 0;
  auto elemTy = fmaOp.getLhs().getType().getElementType();
  if (auto intTy = dyn_cast<IntegerType>(elemTy))
    width = intTy.getWidth();
  auto *ctx = fmaOp.getContext();
  switch (width) {
  case 16:
    // NOTE: For 16-bit lanes the pattern computed for every i in [0,16) is:
    //   acc[i] = x[i] * z[bcastPos] + x[i+16] * z[bcastPos+step]
    // i.e. the two 16-lane halves of x are combined against two broadcast
    // scalars taken from z.
    return SmallVector<NamedAttribute, 11>(
        {{fmaOp.getXstartAttrName(), StringAttr::get(ctx, "0")},
         {fmaOp.getXoffsetsAttrName(), StringAttr::get(ctx, "0x73727170")},
         {fmaOp.getXoffsetsHiAttrName(), StringAttr::get(ctx, "0x77767574")},
         {fmaOp.getXstepAttrName(), fmaOp.getXstepAttr()},
         {fmaOp.getXsquareAttrName(), StringAttr::get(ctx, "0x3120")},
         {fmaOp.getZstartAttrName(),
          StringAttr::get(ctx, std::to_string(bcastPos))},
         {fmaOp.getZoffsetsAttrName(), StringAttr::get(ctx, "0")},
         {fmaOp.getZoffsetsHiAttrName(), StringAttr::get(ctx, "0")},
         {fmaOp.getZstepAttrName(), StringAttr::get(ctx, std::to_string(step))},
         {fmaOp.getZsquareAttrName(), fmaOp.getZsquareAttr()},
         {fmaOp.getFmsubAttrName(), fmaOp.getFmsubAttr()}});
  case 32:
    // 32-bit lanes: identity x offsets, z broadcast from bcastPos; the
    // remaining attributes are carried over from the existing op unchanged.
    return SmallVector<NamedAttribute, 11>(
        {{fmaOp.getXstartAttrName(), StringAttr::get(ctx, "0")},
         {fmaOp.getXoffsetsAttrName(), StringAttr::get(ctx, "0x76543210")},
         {fmaOp.getXoffsetsHiAttrName(), fmaOp.getXoffsetsHiAttr()},
         {fmaOp.getXstepAttrName(), fmaOp.getXstepAttr()},
         {fmaOp.getXsquareAttrName(), fmaOp.getXsquareAttr()},
         {fmaOp.getZstartAttrName(),
          StringAttr::get(ctx, std::to_string(bcastPos))},
         {fmaOp.getZoffsetsAttrName(), StringAttr::get(ctx, "0x00000000")},
         {fmaOp.getZoffsetsHiAttrName(), fmaOp.getZoffsetsHiAttr()},
         {fmaOp.getZstepAttrName(), fmaOp.getZstepAttr()},
         {fmaOp.getZsquareAttrName(), fmaOp.getZsquareAttr()},
         {fmaOp.getFmsubAttrName(), fmaOp.getFmsubAttr()}});
  default:
    llvm::report_fatal_error("Unexpected width!");
  }

  return {};
}

} // namespace xilinx::aievec
564
565template <typename SrcOpTy, typename AIEv2ElemOp>
566static LogicalResult genAddElemAIE2(ConversionPatternRewriter &rewriter,
567 Value lval, Value rval, VectorType srcType,
568 SrcOpTy srcOp) {
569 auto lCastOp = aievec::CastOp::create(rewriter, srcOp.getLoc(), srcType, lval,
570 /*isResAcc*/ true);
571 auto rCastOp = aievec::CastOp::create(rewriter, srcOp.getLoc(), srcType, rval,
572 /*isResAcc*/ true);
573 auto elemOp = AIEv2ElemOp::create(
574 rewriter, srcOp.getLoc(), lCastOp->getResult(0).getType(),
575 lCastOp->getResult(0), rCastOp->getResult(0));
576 rewriter.replaceOpWithNewOp<aievec::CastOp>(
577 srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false);
578 return success();
579}
580
// Map a floating-point comparison predicate to the integer predicate used
// when lowering to aievec.cmp: ordered comparisons become signed predicates,
// unordered ones become unsigned predicates, and (in)equality collapses to
// plain eq/ne regardless of ordering.
// NOTE(review): NaN-ordering semantics are dropped by this mapping —
// presumably acceptable for the element types this lowering supports.
static arith::CmpIPredicate
convertToIntegerPredicate(arith::CmpFPredicate pred) {
  switch (pred) {
  case CmpFPredicate::UEQ:
  case CmpFPredicate::OEQ:
    return CmpIPredicate::eq;
  case CmpFPredicate::UGT:
    return CmpIPredicate::ugt;
  case CmpFPredicate::OGT:
    return CmpIPredicate::sgt;
  case CmpFPredicate::UGE:
    return CmpIPredicate::uge;
  case CmpFPredicate::OGE:
    return CmpIPredicate::sge;
  case CmpFPredicate::ULT:
    return CmpIPredicate::ult;
  case CmpFPredicate::OLT:
    return CmpIPredicate::slt;
  case CmpFPredicate::ULE:
    return CmpIPredicate::ule;
  case CmpFPredicate::OLE:
    return CmpIPredicate::sle;
  case CmpFPredicate::UNE:
  case CmpFPredicate::ONE:
    return CmpIPredicate::ne;
  default:
    // Remaining predicates (AlwaysTrue/AlwaysFalse, ORD, UNO) have no
    // aievec.cmp equivalent.
    llvm::report_fatal_error("Unexpected predicate!");
  }
}
610
// Integer predicates pass through unchanged. This overload exists so
// templated comparison-lowering code can call convertToIntegerPredicate
// uniformly for both CmpIOp and CmpFOp predicates.
static arith::CmpIPredicate
convertToIntegerPredicate(arith::CmpIPredicate pred) {
  return pred;
}
615
616static aievec::CmpOp createCmpOpAIE2(ConversionPatternRewriter &rewriter,
617 CmpIPredicate pred, Location loc,
618 Type type, Value lhs, Value rhs) {
619 switch (pred) {
620 case CmpIPredicate::eq:
621 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs, "eq");
622 case CmpIPredicate::ne:
623 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs, "ne");
624 case CmpIPredicate::slt:
625 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs, "slt");
626 case CmpIPredicate::ult:
627 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs, "ult");
628 case CmpIPredicate::sle:
629 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs, "sle");
630 case CmpIPredicate::ule:
631 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs, "ule");
632 case CmpIPredicate::sgt:
633 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs, "sgt");
634 case CmpIPredicate::ugt:
635 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs, "ugt");
636 case CmpIPredicate::sge:
637 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs, "sge");
638 case CmpIPredicate::uge:
639 return aievec::CmpOp::create(rewriter, loc, type, lhs, rhs, "uge");
640 }
641 return nullptr;
642}
643
// Generate a tree reduction for vector.reduction: starting at `shiftIndex`
// lanes and halving each step, shift the vector down and combine it with
// itself using DstOpTy (e.g. an elementwise add/min/max), then extract lane 0
// as the scalar result.
template <typename DstOpTy>
static aievec::ExtElemOp
generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter,
                                vector::ReductionOp srcOp, int shiftIndex,
                                Value curValue) {
  assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
         "shiftIndex must be power of 2");

  Location loc = srcOp.getLoc();
  auto vType = dyn_cast<VectorType>(curValue.getType());
  Type scalarType = vType.getElementType();
  Type vecType = curValue.getType();
  DstOpTy curOp = nullptr;
  unsigned elWidth = scalarType.getIntOrFloatBitWidth();

  for (int id = shiftIndex; id > 0; id /= 2) {
    // aievec.shift takes its shift amount in bytes, hence id * elWidth / 8.
    auto constOp = arith::ConstantOp::create(
        rewriter, loc, rewriter.getI32IntegerAttr(id * elWidth / 8));

    auto shiftBytesOp = aievec::ShiftOp::create(
        rewriter, loc, vecType, curValue, curValue, constOp.getResult());

    // Combine the vector with its shifted self; each step folds pairs of
    // partial results, halving the distance between them.
    curOp = DstOpTy::create(rewriter, loc, vecType, curValue,
                            shiftBytesOp.getResult());

    curValue = curOp.getResult();
  }

  // The power-of-two assert above guarantees the loop executed at least
  // once, so curOp is non-null here. Lane 0 holds the reduced value.
  auto zeroConstOp =
      arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
  return aievec::ExtElemOp::create(rewriter, loc, scalarType, curOp,
                                   zeroConstOp.getResult());
}
677
// Helper to pad a v16bf16 vector to v32bf16 by concatenating with a splat of
// the given infinity value. Used by min/max reduction patterns, where the
// padding must be the reduction's identity element (+inf for min, -inf for
// max) so the extra lanes do not affect the result.
// Returns {paddedVector, newLaneSize}.
static std::pair<Value, unsigned>
padV16ToV32WithInfinity(ConversionPatternRewriter &rewriter, Location loc,
                        Value inputVec, Type scalarType, bool negativeInf) {
  VectorType v32bf16Type = createVectorType(32, scalarType);
  VectorType v16bf16Type = createVectorType(16, scalarType);

  // Create a scalar infinity constant (sign chosen by `negativeInf`).
  auto infAttr = rewriter.getFloatAttr(
      scalarType,
      APFloat::getInf(cast<FloatType>(scalarType).getFloatSemantics(),
                      negativeInf));
  auto splatInf = arith::ConstantOp::create(rewriter, loc, infAttr).getResult();

  // Broadcast to v32bf16, then extract upper half (which is also infinity).
  auto infVec =
      aievec::BroadcastScalarOp::create(rewriter, loc, v32bf16Type, splatInf);
  auto infUpperHalf =
      aievec::ExtOp::create(rewriter, loc, v16bf16Type, infVec, 1);

  // Concatenate input with infinity padding: lanes 0-15 are the input,
  // lanes 16-31 are infinity.
  Value paddedVec =
      aievec::ConcatOp::create(rewriter, loc, v32bf16Type,
                               ValueRange{inputVec, infUpperHalf.getResult()});

  return {paddedVec, 32};
}
707
708// Helper to pad a v16bf16 vector to v32bf16 by concatenating with zeros.
709// Used by elementwise min/max/cmp/sel patterns for 256-bit bf16 support.
710static Value padV16ToV32WithZeros(ConversionPatternRewriter &rewriter,
711 Location loc, Value inputVec,
712 Type scalarType) {
713 VectorType v32Type = createVectorType(32, scalarType);
714 VectorType v16Type = createVectorType(16, scalarType);
715 auto zeroAttr = rewriter.getZeroAttr(v16Type);
716 auto zeroVec = arith::ConstantOp::create(rewriter, loc, zeroAttr);
717 return aievec::ConcatOp::create(rewriter, loc, v32Type,
718 ValueRange{inputVec, zeroVec.getResult()});
719}
720
721static func::FuncOp getOrInsertFuncDecl(ConversionPatternRewriter &rewriter,
722 Operation *parentSymbolTableOp,
723 StringRef funcName, TypeRange inTypes,
724 TypeRange outTypes) {
725
726 mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
727 rewriter.setInsertionPointToStart(
728 &parentSymbolTableOp->getRegions().front().getBlocks().front());
729 SymbolTable st = SymbolTable(parentSymbolTableOp);
730 func::FuncOp fnOpLookup = st.lookup<func::FuncOp>(funcName);
731 func::FuncOp fnOp;
732 // if the function is already declared, use the existing function, don't
733 // declare multiple times
734 if (fnOpLookup != NULL) {
735 fnOp = fnOpLookup;
736 } else {
737 StringAttr t1 = rewriter.getStringAttr("sym_visibility");
738 StringAttr t2 = rewriter.getStringAttr("private");
739 NamedAttribute funcAccess = NamedAttribute(t1, t2);
740 FunctionType fnType =
741 mlir::FunctionType::get(rewriter.getContext(), inTypes, outTypes);
742 fnOp = func::FuncOp::create(rewriter, parentSymbolTableOp->getLoc(),
743 funcName, fnType, funcAccess);
744 }
745 return fnOp;
746}
747
748//===----------------------------------------------------------------------===//
749// Wide vector splitting utility
750//===----------------------------------------------------------------------===//
751
752// Utility function to split a wide vector operation (e.g., v32bf16) into two
753// half-width operations (e.g., v16bf16) and concatenate the results.
754// This pattern is common for AIE2 when hardware only supports 16-lane
755// operations but we need to process 32-lane vectors.
756//
757// Template parameters:
758// SrcOpTy - The source operation type (e.g., arith::MulFOp, math::ExpOp)
759//
760// Parameters:
761// srcOp - The source operation to replace
762// wideInputs - The wide input values (e.g., lhs and rhs for binary ops)
763// halfType - The half-width vector type (e.g., vector<16xbf16>)
764// wideType - The wide vector type (e.g., vector<32xbf16>)
765// rewriter - The pattern rewriter
766// processHalves - Callback that processes the lower and upper halves
767// and returns a pair of half-width results to be concatenated
768//
769// The callback signature is:
770// std::pair<Value, Value>(ArrayRef<std::pair<Value, Value>> halfInputs,
771// Location loc, ConversionPatternRewriter &rewriter)
772// where halfInputs[i] is {lowerHalf, upperHalf} for each wideInput
773template <typename SrcOpTy, typename Func>
774static void splitWideVectorOp(SrcOpTy srcOp, ArrayRef<Value> wideInputs,
775 VectorType halfType, VectorType wideType,
776 ConversionPatternRewriter &rewriter,
777 Func &&processHalves) {
778
779 Location loc = srcOp.getLoc();
780
781 // Extract lower and upper halves for each wide input
782 SmallVector<std::pair<Value, Value>> halfInputs;
783 halfInputs.reserve(wideInputs.size());
784 for (Value wideInput : wideInputs) {
785 auto lowerHalf =
786 aievec::ExtOp::create(rewriter, loc, halfType, wideInput, 0);
787 auto upperHalf =
788 aievec::ExtOp::create(rewriter, loc, halfType, wideInput, 1);
789 halfInputs.emplace_back(lowerHalf.getResult(), upperHalf.getResult());
790 }
791
792 // Process halves using the callback
793 auto [lowResult, highResult] = processHalves(halfInputs, loc, rewriter);
794
795 // Concatenate results
796 SmallVector<Value> concatSources = {lowResult, highResult};
797 rewriter.replaceOpWithNewOp<aievec::ConcatOp>(srcOp, wideType, concatSources);
798}
799
800// Simplified version for unary operations
801template <typename SrcOpTy>
802static void splitWideUnaryVectorOp(
803 SrcOpTy srcOp, Value wideInput, VectorType halfType, VectorType wideType,
804 ConversionPatternRewriter &rewriter,
805 std::function<Value(Value, Location, ConversionPatternRewriter &)>
806 processHalf) {
807
808 splitWideVectorOp<SrcOpTy>(
809 srcOp, {wideInput}, halfType, wideType, rewriter,
810 [&processHalf](ArrayRef<std::pair<Value, Value>> halfInputs, Location loc,
811 ConversionPatternRewriter &rewriter) {
812 auto [lowerHalf, upperHalf] = halfInputs[0];
813 Value lowResult = processHalf(lowerHalf, loc, rewriter);
814 Value highResult = processHalf(upperHalf, loc, rewriter);
815 return std::make_pair(lowResult, highResult);
816 });
817}
818
819// Simplified version for binary operations
820template <typename SrcOpTy>
821static void splitWideBinaryVectorOp(
822 SrcOpTy srcOp, Value lhs, Value rhs, VectorType halfType,
823 VectorType wideType, ConversionPatternRewriter &rewriter,
824 std::function<Value(Value, Value, Location, ConversionPatternRewriter &)>
825 processHalf) {
826
827 splitWideVectorOp<SrcOpTy>(
828 srcOp, {lhs, rhs}, halfType, wideType, rewriter,
829 [&processHalf](ArrayRef<std::pair<Value, Value>> halfInputs, Location loc,
830 ConversionPatternRewriter &rewriter) {
831 auto [lhsLow, lhsHigh] = halfInputs[0];
832 auto [rhsLow, rhsHigh] = halfInputs[1];
833 Value lowResult = processHalf(lhsLow, rhsLow, loc, rewriter);
834 Value highResult = processHalf(lhsHigh, rhsHigh, loc, rewriter);
835 return std::make_pair(lowResult, highResult);
836 });
837}
838
839//===----------------------------------------------------------------------===//
840// Math operation matching utilities
841//===----------------------------------------------------------------------===//
842
843// Check if math.exp op matches AIE2 LUT-based exp constraints
844static bool matchExpOpForAIE2LUT(math::ExpOp::Adaptor adaptor) {
845 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
846
847 if (!srcType)
848 return false;
849
850 Type scalarType = srcType.getElementType();
851 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
852 unsigned laneSize = getVectorLaneSize(srcType);
853 // AIE2 LUT-based exp: supports v16bf16 and v32bf16
854 return isa<FloatType>(scalarType) && (laneSize == 16 || laneSize == 32) &&
855 elWidth == 16;
856}
857
858// Check if math.exp op matches AIE2P exp constraints
859static bool matchExpOpForAIE2P(math::ExpOp::Adaptor adaptor) {
860 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
861
862 if (!srcType)
863 return false;
864
865 Type scalarType = srcType.getElementType();
866 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
867 unsigned laneSize = getVectorLaneSize(srcType);
868 // AIE2P exp: supports v16bf16 and v32bf16
869 return scalarType.isBF16() && (laneSize == 16 || laneSize == 32) &&
870 elWidth == 16;
871}
872
873//===----------------------------------------------------------------------===//
874// Rewrite patterns
875//===----------------------------------------------------------------------===//
876
// This pattern folds `vector.extract` + `vector.broadcast` into
// `aievec.broadcast` for AIE2, broadcasting one lane of an existing vector
// instead of materializing the extracted scalar first.
    : OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Only fire when the broadcast source comes from a vector.extract.
    auto extOp = adaptor.getSource().getDefiningOp<vector::ExtractOp>();

    if (!extOp)
      return failure();

    auto src = extOp.getSource();
    // NOTE(review): pos[0] assumes a non-empty (1-D) static position; a 0-d
    // extract would index an empty array - confirm the pass preconditions.
    auto pos = extOp.getStaticPosition();
    int64_t posVal = pos[0];
    auto srcVecType = cast<VectorType>(src.getType());
    auto resultType = cast<VectorType>(bcastOp.getResult().getType());
    if (srcVecType != resultType) {
      // The source may only be exactly twice as wide as the result; select
      // the half that holds the requested lane and rebase the lane index.
      if (srcVecType.getNumElements() != 2 * resultType.getNumElements())
        return failure();
      auto half = static_cast<int8_t>(posVal / resultType.getNumElements());
      posVal -= half * resultType.getNumElements();
      src = aievec::ExtOp::create(rewriter, extOp.getLoc(), resultType, src,
                                  rewriter.getI8IntegerAttr(half))
                .getResult();
    }

    unsigned elWidth = resultType.getElementType().getIntOrFloatBitWidth();

    if (unsigned laneSize = getVectorLaneSize(resultType);
        laneSize * elWidth == 512) {
      // Common use case for the broadcast_elem intrinsic
      rewriter.replaceOpWithNewOp<aievec::BroadcastOp>(bcastOp, resultType, src,
                                                       posVal);
    } else if (laneSize * elWidth == 256) {
      // e.g. need v16bf16 due to the subsequent v16accfloat operation.
      // Widen to 512 bits by self-concatenation, broadcast there, then keep
      // the lower half.
      VectorType aievecBcastType =
          createVectorType(512 / elWidth, resultType.getElementType());
      auto concatOp =
          aievec::ConcatOp::create(rewriter, bcastOp.getLoc(), aievecBcastType,
                                   SmallVector<Value>({src, src}));
      auto aieBcastOp = aievec::BroadcastOp::create(
          rewriter, bcastOp.getLoc(), aievecBcastType, concatOp.getResult(),
          posVal);
      rewriter.replaceOpWithNewOp<aievec::ExtOp>(bcastOp, resultType,
                                                 aieBcastOp.getResult(), 0);
    } else if (laneSize * elWidth == 1024) {
      // e.g. need v32int32 due to the subsequent v32acc32 operation.
      // Broadcast within a 512-bit slice of the source, then duplicate the
      // broadcast result to fill the 1024-bit vector.
      VectorType aievecBcastType =
          createVectorType(512 / elWidth, resultType.getElementType());
      auto half = static_cast<int8_t>(posVal / resultType.getNumElements());
      posVal -= half * resultType.getNumElements();
      auto extOp =
          aievec::ExtOp::create(rewriter, bcastOp.getLoc(), aievecBcastType,
                                src, rewriter.getI8IntegerAttr(half));
      auto aieBcastOp = aievec::BroadcastOp::create(rewriter, bcastOp.getLoc(),
                                                    aievecBcastType,
                                                    extOp.getResult(), posVal);
      rewriter.replaceOpWithNewOp<aievec::ConcatOp>(
          bcastOp, resultType,
          SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
    } else {
      return failure();
    }

    return success();
  }
};
948
949struct ConvertSplatToAIEBroadcast : OpConversionPattern<vector::BroadcastOp> {
950 using OpConversionPattern::OpConversionPattern;
951
952 LogicalResult
953 matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor,
954 ConversionPatternRewriter &rewriter) const override {
955
956 if (adaptor.getSource().getDefiningOp<vector::ExtractOp>())
957 return failure();
958
959 auto resultType = cast<VectorType>(bcastOp.getResult().getType());
960 auto flatResultType = getFlattenedVectorType(resultType);
961 Type scalarType = resultType.getElementType();
962 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
963 unsigned laneSize = getVectorLaneSize(resultType);
964 auto src = bcastOp.getSource();
965
966 if (laneSize * elWidth == 512) {
967 Value newOp = aievec::BroadcastScalarOp::create(
968 rewriter, bcastOp.getLoc(), flatResultType, src);
969 if (resultType != flatResultType)
970 newOp = vector::ShapeCastOp::create(rewriter, bcastOp.getLoc(),
971 resultType, newOp);
972 rewriter.replaceOp(bcastOp, newOp);
973 return success();
974 }
975
976 if (laneSize * elWidth == 256) {
977 VectorType vecType = createVectorType(512 / elWidth, scalarType);
978 auto aieBcastOp = aievec::BroadcastScalarOp::create(
979 rewriter, bcastOp.getLoc(), vecType, src);
980 Value newOp =
981 aievec::ExtOp::create(rewriter, bcastOp.getLoc(), flatResultType,
982 aieBcastOp.getResult(), 0);
983 if (resultType != flatResultType)
984 newOp = vector::ShapeCastOp::create(rewriter, bcastOp.getLoc(),
985 resultType, newOp);
986 rewriter.replaceOp(bcastOp, newOp);
987 return success();
988 }
989
990 if (laneSize * elWidth == 1024) {
991 VectorType vecType = createVectorType(512 / elWidth, scalarType);
992 auto aieBcastOp = aievec::BroadcastScalarOp::create(
993 rewriter, bcastOp.getLoc(), vecType, src);
994 Value newOp = aievec::ConcatOp::create(
995 rewriter, bcastOp.getLoc(), flatResultType,
996 SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
997 if (resultType != flatResultType)
998 newOp = vector::ShapeCastOp::create(rewriter, bcastOp.getLoc(),
999 resultType, newOp);
1000 rewriter.replaceOp(bcastOp, newOp);
1001 return success();
1002 }
1003
1004 return failure();
1005 }
1006};
1007
// This pattern replaces `arith.muli`+`arith.addi` on vectors with
// `aievec.mac_elem`. This pattern works for AIE2.
    : OpConversionPattern<arith::AddIOp> {
  using OpConversionPattern::OpConversionPattern;

                                          unsigned shiftParam = 0)

  LogicalResult
  matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Verify it's a vector operation
    auto resultType = dyn_cast<VectorType>(addOp.getType());
    if (!resultType)
      return failure();

    // Verify it can be replaced by a MAC; on success the helper yields the
    // multiply's operands plus the accumulator operand.
    auto res =
        extractMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
    if (!res)
      return failure();
    auto [lhs, rhs, acc] = *res;

    // Verify the vector type is supported by AIE2: only 32x16-bit or
    // 16x32-bit integer vectors.
    unsigned resultElWidth =
        resultType.getElementType().getIntOrFloatBitWidth();
    unsigned laneSize = getVectorLaneSize(resultType);

    if ((laneSize != 32 || resultElWidth != 16) &&
        (laneSize != 16 || resultElWidth != 32))
      return failure();

    // Lower as ups(acc) -> mac_elem(lhs, rhs, acc) -> srs back down to the
    // original element type.
    Type accType = getVectorOpDestType(cast<VectorType>(acc.getType()),
                                       /*AIE2 =*/true);
    auto upsOp = aievec::UPSOp::create(rewriter, addOp.getLoc(), accType, acc,
                                       shiftParam);
    auto fmaElemOp = aievec::FMAElemOp::create(
        rewriter, addOp.getLoc(), accType, lhs, rhs, upsOp.getResult(),
        /*fmsub=*/false);

    auto shiftParamOp = arith::ConstantOp::create(
        rewriter, addOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
    rewriter.replaceOpWithNewOp<aievec::SRSOp>(
        addOp, resultType, fmaElemOp.getResult(), shiftParamOp.getResult());

    return success();
  }

  // Shift applied by the UPS/SRS ops when moving between vector registers
  // and accumulators.
  unsigned shiftParam;
};
1060
1061// Lower a single 16-lane bf16 FMA half: UPS -> FMAElem -> SRS.
1062static Value lowerBF16FMAHalf(Value lhs, Value rhs, Value acc,
1063 unsigned shiftParam, Location loc,
1064 ConversionPatternRewriter &rewriter) {
1065 auto f32AccType = VectorType::get({16}, rewriter.getF32Type());
1066 auto upsOp =
1067 aievec::UPSOp::create(rewriter, loc, f32AccType, acc, shiftParam);
1068 auto fmaElemOp = aievec::FMAElemOp::create(rewriter, loc, f32AccType, lhs,
1069 rhs, upsOp.getResult(),
1070 /*fmsub=*/false);
1071 auto shiftParamOp = arith::ConstantOp::create(
1072 rewriter, loc, rewriter.getI32IntegerAttr(shiftParam));
1073 auto srsOp =
1074 aievec::SRSOp::create(rewriter, loc, cast<VectorType>(lhs.getType()),
1075 fmaElemOp.getResult(), shiftParamOp.getResult());
1076 return srsOp.getResult();
1077}
1078
// Convert `vector.fma` to `aievec.mac_elem`. Supported operand types:
// `vector<16xf32>`, `vector<16xbf16>`, and `vector<32xbf16>` (split into two
// 16-lane FMAs). In the case of vectors with
// `f32` elemental type, this pattern will try to match `bf16` to `f32`
// widening ops in the `lhs` and `rhs` operands, or fail otherwise.
// TODO: When sign extensions are not found, a conversion from `f32` to `bf16`
// TODO: can be inserted to emulate `f32` fma with `bf16` logic.
    : OpConversionPattern<vector::FMAOp> {
  using OpConversionPattern::OpConversionPattern;

                                              unsigned shiftParam = 0)

  LogicalResult
  matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Verify the vector type is supported by AIE2
    auto resVecTy = cast<VectorType>(fmaOp.getType());
    auto resElemTy = resVecTy.getElementType();
    unsigned numElems = getVectorLaneSize(resVecTy);

    // Only support f32 with 16 lanes; bf16 with 16 or 32 lanes.
    if ((!resElemTy.isF32() && !resElemTy.isBF16()) ||
        (numElems != 16 && !(resElemTy.isBF16() && numElems == 32)))
      return rewriter.notifyMatchFailure(
          fmaOp, "Unsupported operand types in vector.fma lowering.");

    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();
    Value acc = adaptor.getAcc();

    // Handle vector<32xbf16> by splitting into two vector<16xbf16> FMAs
    if (numElems == 32 && resElemTy.isBF16()) {
      VectorType halfType = createVectorType(16, resElemTy);
      // Copy the member so the lambda captures a plain value, not `this`.
      unsigned localShiftParam = shiftParam;

      splitWideVectorOp<vector::FMAOp>(
          fmaOp, {lhs, rhs, acc}, halfType, resVecTy, rewriter,
          [localShiftParam](ArrayRef<std::pair<Value, Value>> halfInputs,
                            Location loc, ConversionPatternRewriter &rewriter) {
            auto [lhsLow, lhsHigh] = halfInputs[0];
            auto [rhsLow, rhsHigh] = halfInputs[1];
            auto [accLow, accHigh] = halfInputs[2];

            Value lowResult = lowerBF16FMAHalf(lhsLow, rhsLow, accLow,
                                               localShiftParam, loc, rewriter);
            Value highResult = lowerBF16FMAHalf(lhsHigh, rhsHigh, accHigh,
                                                localShiftParam, loc, rewriter);
            return std::make_pair(lowResult, highResult);
          });
      return success();
    }

    // bf16 result: widen the accumulator to a 16-lane f32 accumulator.
    if (resElemTy.isBF16())
      acc = aievec::UPSOp::create(rewriter, fmaOp.getLoc(),
                                  VectorType::get({16}, rewriter.getF32Type()),
                                  acc, shiftParam);
    else {
      // f32 result: the multiplicands must come from bf16->f32 widening ops;
      // peel them so the mac_elem consumes the original bf16 sources.
      lhs = getSourceOfWideningOp(lhs).value_or(nullptr);
      rhs = getSourceOfWideningOp(rhs).value_or(nullptr);
      if (!lhs || !rhs)
        return rewriter.notifyMatchFailure(
            fmaOp, "vector.fma operands are f32, and they don't come from "
                   "arith.extf on bf16; can't lower to aievec.");
      if (!cast<VectorType>(lhs.getType()).getElementType().isBF16() ||
          !cast<VectorType>(rhs.getType()).getElementType().isBF16())
        return rewriter.notifyMatchFailure(
            fmaOp, "vector.fma operands come from arith.extf, but the source "
                   "of the widening op is not bf16; can't lower to aievec.");
    }
    Value newOp =
        aievec::FMAElemOp::create(rewriter, fmaOp.getLoc(), acc.getType(), lhs,
                                  rhs, acc, /*fmsub=*/false);

    // For bf16 results the accumulator must be srs-ed back down to bf16.
    if (resElemTy.isBF16()) {
      auto shiftParamOp = arith::ConstantOp::create(
          rewriter, fmaOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
      newOp = aievec::SRSOp::create(rewriter, fmaOp.getLoc(), resVecTy, newOp,
                                    shiftParamOp);
    }

    rewriter.replaceOp(fmaOp, newOp);

    return success();
  }

  // Shift used by the UPS/SRS conversions between vectors and accumulators.
  unsigned shiftParam;
};
1169
// This pattern fuses `arith.mulf` + `arith.addf` on bf16 vectors into
// `aievec.mac_elem` (float FMA). This pattern works for AIE2.
    : OpConversionPattern<arith::AddFOp> {
  using OpConversionPattern::OpConversionPattern;

                                            unsigned shiftParam = 0)

  LogicalResult
  matchAndRewrite(arith::AddFOp addOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Verify it's a vector operation
    auto resultType = dyn_cast<VectorType>(addOp.getType());
    if (!resultType)
      return failure();

    // Only handle bf16 element type
    auto elemType = resultType.getElementType();
    if (!elemType.isBF16())
      return failure();

    // Verify it can be replaced by an FMA; on success the helper yields the
    // multiply's operands plus the accumulator operand.
    auto res =
        extractFMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
    if (!res)
      return failure();
    auto [lhs, rhs, acc] = *res;

    unsigned laneSize = getVectorLaneSize(resultType);

    // Handle vector<32xbf16> by splitting into two vector<16xbf16> FMAs
    if (laneSize == 32) {
      VectorType halfType = createVectorType(16, elemType);
      // Copy the member so the lambda captures a plain value, not `this`.
      unsigned localShiftParam = shiftParam;

      splitWideVectorOp<arith::AddFOp>(
          addOp, {lhs, rhs, acc}, halfType, resultType, rewriter,
          [localShiftParam](ArrayRef<std::pair<Value, Value>> halfInputs,
                            Location loc, ConversionPatternRewriter &rewriter) {
            auto [lhsLow, lhsHigh] = halfInputs[0];
            auto [rhsLow, rhsHigh] = halfInputs[1];
            auto [accLow, accHigh] = halfInputs[2];

            Value lowResult = lowerBF16FMAHalf(lhsLow, rhsLow, accLow,
                                               localShiftParam, loc, rewriter);
            Value highResult = lowerBF16FMAHalf(lhsHigh, rhsHigh, accHigh,
                                                localShiftParam, loc, rewriter);
            return std::make_pair(lowResult, highResult);
          });
      return success();
    }

    // Handle vector<16xbf16>
    if (laneSize != 16)
      return failure();

    // Lower as ups(acc) -> mac_elem -> srs back to bf16.
    auto f32AccType = VectorType::get({16}, rewriter.getF32Type());
    auto upsOp = aievec::UPSOp::create(rewriter, addOp.getLoc(), f32AccType,
                                       acc, shiftParam);
    auto fmaElemOp = aievec::FMAElemOp::create(
        rewriter, addOp.getLoc(), f32AccType, lhs, rhs, upsOp.getResult(),
        /*fmsub=*/false);

    auto shiftParamOp = arith::ConstantOp::create(
        rewriter, addOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
    rewriter.replaceOpWithNewOp<aievec::SRSOp>(
        addOp, resultType, fmaElemOp.getResult(), shiftParamOp.getResult());

    return success();
  }

  // Shift used by the UPS/SRS conversions between vectors and accumulators.
  unsigned shiftParam;
};
1245
// This pattern replaces `arith.mulf` on vectors with
// `aievec.mul_elem`. This pattern works for AIE2.
    : OpConversionPattern<arith::MulFOp> {
  using OpConversionPattern::OpConversionPattern;

                                          unsigned shiftParam = 0)

  LogicalResult
  matchAndRewrite(arith::MulFOp mulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Verify it's a vector operation
    auto resultType = dyn_cast<VectorType>(mulOp.getType());
    if (!resultType)
      return failure();

    // Skip standalone mul conversion when this MulFOp feeds into an AddFOp as
    // its sole user. ConvertMulAddFToAIEVecFMAElemOpPattern will fuse the
    // multiply and add into a single aievec.mac_elem.
    // Only defer to FMA for bf16 where the FMA pattern is registered.
    auto isAddOp = [&](Operation *op) { return isa<arith::AddFOp>(op); };
    if (resultType.getElementType().isBF16() && mulOp->hasOneUse() &&
        llvm::any_of(mulOp->getUsers(), isAddOp))
      return failure();

    unsigned resultElWidth =
        resultType.getElementType().getIntOrFloatBitWidth();

    unsigned laneSize = getVectorLaneSize(resultType);

    // Handle vector<32xbf16> by splitting into two vector<16xbf16> operations
    if (laneSize == 32 && resultElWidth == 16) {
      VectorType halfType = createVectorType(16, resultType.getElementType());
      // Copy the member so the lambda captures a plain value, not `this`.
      unsigned localShiftParam = shiftParam;

      splitWideBinaryVectorOp<arith::MulFOp>(
          mulOp, adaptor.getLhs(), adaptor.getRhs(), halfType, resultType,
          rewriter,
          [localShiftParam](Value lhsHalf, Value rhsHalf, Location loc,
                            ConversionPatternRewriter &rewriter) -> Value {
            // Per half: mul_elem into the accumulator type, then srs back
            // down to the half-width vector type.
            Type accType = getVectorOpDestType(
                cast<VectorType>(lhsHalf.getType()), /*AIE2 =*/true);
            auto mulElemOp = aievec::MulElemOp::create(rewriter, loc, accType,
                                                       lhsHalf, rhsHalf);
            auto shiftParamOp = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI32IntegerAttr(localShiftParam));
            auto srsOp = aievec::SRSOp::create(
                rewriter, loc, cast<VectorType>(lhsHalf.getType()),
                mulElemOp.getResult(), shiftParamOp.getResult());
            return srsOp.getResult();
          });
      return success();
    }

    // bfloat16 and float type (laneSize == 16)
    if (laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32))
      return failure();

    // Decide the accType for aievec.mul_elem based on mulOp's lhs & rhs.
    // Peel widening ops so the accumulator is chosen from the narrow sources.
    auto lval = adaptor.getLhs();
    auto rval = adaptor.getRhs();
    lval = getSourceOfWideningOp(lval).value_or(lval);
    rval = getSourceOfWideningOp(rval).value_or(rval);
    auto lSrcType = cast<VectorType>(lval.getType());
    auto rSrcType = cast<VectorType>(rval.getType());
    unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
    Type accType = getVectorOpDestType(lSrcType, /*AIE2 =*/true);
    if (rBitWidth > lBitWidth) {
      accType = getVectorOpDestType(rSrcType, /*AIE2 =*/true);
    }
    // Only support the same lhs/rhs type at the moment
    if (lSrcType != rSrcType) {
      return failure();
    }

    // Prepare lhr/rhs for the aievec.mul_elem op
    unsigned bitWidth = (rBitWidth > lBitWidth) ? rBitWidth : lBitWidth;
    Type srcElemType = (rBitWidth > lBitWidth) ? rSrcType.getElementType()
                                               : lSrcType.getElementType();
    // Lane count of the mul_elem intrinsic for the given element type/width.
    unsigned numLanes = 0;
    if (isa<FloatType>(srcElemType) && (bitWidth == 16 || bitWidth == 32)) {
      numLanes = 16;
    } else if (isa<IntegerType>(srcElemType) &&
               (bitWidth == 8 || bitWidth == 16)) {
      numLanes = 32;
    } else if (isa<IntegerType>(srcElemType) && (bitWidth == 32)) {
      numLanes = 16;
    } else {
      return failure();
    }
    VectorType targetInputType = createVectorType(numLanes, srcElemType);
    // NOTE(review): .value() asserts if the optional is empty, so the
    // null-check below can only catch a null Value payload - confirm whether
    // convertValueToTargetTypeAIE2 can return an empty optional here.
    if (targetInputType != lSrcType) {
      lval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), lval,
                                          targetInputType)
                 .value();
    }
    if (targetInputType != rSrcType) {
      rval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), rval,
                                          targetInputType)
                 .value();
    }
    if (!lval || !rval)
      return failure();

    // Create an aievec.mul_elem op
    auto mulElemOp = aievec::MulElemOp::create(rewriter, mulOp.getLoc(),
                                               accType, lval, rval);

    // Create an aievec.cast or an aievec.srs op
    auto mulElemResultType = mulElemOp.getType();
    auto mulElemResultElWidth =
        mulElemResultType.getElementType().getIntOrFloatBitWidth();

    if (mulElemResultElWidth == resultElWidth) {
      // Same width: a plain accumulator-to-vector cast suffices.
      rewriter.replaceOpWithNewOp<aievec::CastOp>(
          mulOp, resultType, mulElemOp.getResult(), /*isResAcc*/ false);
    } else if (mulElemResultElWidth > resultElWidth) {
      // Wider accumulator: shift-round-saturate down to the result width.
      auto shiftParamOp = arith::ConstantOp::create(
          rewriter, mulOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
      rewriter.replaceOpWithNewOp<aievec::SRSOp>(
          mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
    } else {
      return failure();
    }

    return success();
  }

  // Shift used when srs-ing the accumulator down to the result type.
  unsigned shiftParam;
};
1379
// This pattern replaces `arith.muli` on vectors with
// `aievec.mul_elem`. This pattern works for AIE2.
    : OpConversionPattern<arith::MulIOp> {
  using OpConversionPattern::OpConversionPattern;

                                          unsigned shiftParam = 0)

  LogicalResult
  matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Verify it's a vector operation
    auto resultType = dyn_cast<VectorType>(mulOp.getType());
    if (!resultType)
      return failure();

    // FIXME: Verify it is not a part of MAC
    // Defer to the MAC-fusion pattern when the multiply's sole user is an
    // add.
    auto isAddOp = [&](Operation *op) { return isa<arith::AddIOp>(op); };
    if (mulOp->hasOneUse() && llvm::any_of(mulOp->getUsers(), isAddOp))
      return failure();

    // Verify the vector type is supported by AIE2
    unsigned resultElWidth =
        resultType.getElementType().getIntOrFloatBitWidth();
    unsigned laneSize = getVectorLaneSize(resultType);

    if ((laneSize != 32 || (resultElWidth != 16 && resultElWidth != 8)) &&
        ((laneSize != 16 && laneSize != 32) || resultElWidth != 32))
      return failure();

    // Decide the accType for aievec.mul_elem based on mulOp's lhs & rhs.
    // Peel widening ops so the accumulator is chosen from the narrow sources.
    auto lval = adaptor.getLhs();
    auto rval = adaptor.getRhs();

    lval = getSourceOfWideningOp(lval).value_or(lval);
    rval = getSourceOfWideningOp(rval).value_or(rval);

    auto lSrcType = cast<VectorType>(lval.getType());
    auto rSrcType = cast<VectorType>(rval.getType());
    unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
    unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
    Type accType = getVectorOpDestType(lSrcType, /*AIE2 =*/true);
    if (rBitWidth > lBitWidth) {
      accType = getVectorOpDestType(rSrcType, /*AIE2 =*/true);
    }

    // Prepare lhr/rhs for the aievec.mul_elem op
    unsigned bitWidth = (rBitWidth > lBitWidth) ? rBitWidth : lBitWidth;
    Type srcElemType = (rBitWidth > lBitWidth) ? rSrcType.getElementType()
                                               : lSrcType.getElementType();
    // Lane count of the mul_elem intrinsic for the given element type/width.
    unsigned numLanes = 0;
    if (isa<FloatType>(srcElemType) && (bitWidth == 16 || bitWidth == 32)) {
      numLanes = 16;
    } else if (isa<IntegerType>(srcElemType) &&
               (bitWidth == 8 || bitWidth == 16)) {
      numLanes = 32;
    } else if (isa<IntegerType>(srcElemType) && (bitWidth == 32)) {
      numLanes = 16;
    } else {
      return failure();
    }
    VectorType targetInputType = createVectorType(numLanes, srcElemType);
    // NOTE(review): .value() asserts if the optional is empty, so the
    // null-check below can only catch a null Value payload - confirm whether
    // convertValueToTargetTypeAIE2 can return an empty optional here.
    if (targetInputType != lSrcType) {
      lval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), lval,
                                          targetInputType)
                 .value();
    }
    if (targetInputType != rSrcType) {
      rval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), rval,
                                          targetInputType)
                 .value();
    }
    if (!lval || !rval)
      return failure();

    // Create an aievec.mul_elem op
    auto mulElemOp = aievec::MulElemOp::create(rewriter, mulOp.getLoc(),
                                               accType, lval, rval);

    // Create an aievec.cast or an aievec.srs op
    auto mulElemResultType = mulElemOp.getType();
    auto mulElemResultElWidth =
        mulElemResultType.getElementType().getIntOrFloatBitWidth();

    if (mulElemResultElWidth == resultElWidth) {
      // Same width: a plain accumulator-to-vector cast suffices.
      rewriter.replaceOpWithNewOp<aievec::CastOp>(
          mulOp, resultType, mulElemOp.getResult(), /*isResAcc*/ false);
    } else if (mulElemResultElWidth > resultElWidth) {
      // Wider accumulator: shift-round-saturate down to the result width.
      auto shiftParamOp = arith::ConstantOp::create(
          rewriter, mulOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
      rewriter.replaceOpWithNewOp<aievec::SRSOp>(
          mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
    } else {
      return failure();
    }

    return success();
  }

  // Shift used when srs-ing the accumulator down to the result type.
  unsigned shiftParam;
};
1483
1484// This pattern folds an extract + broadcast feeding into an
1485// `aievec::aie1::FMAOp` into the op, using the shuffle attributes.
1486struct FoldSplatToFMAOp : OpConversionPattern<aievec::aie1::FMAOp> {
1487 using OpConversionPattern::OpConversionPattern;
1488
1489 LogicalResult
1490 matchAndRewrite(aievec::aie1::FMAOp fmaOp, OpAdaptor adaptor,
1491 ConversionPatternRewriter &rewriter) const override {
1492 auto concatOp =
1493 dyn_cast<aievec::ConcatOp>(adaptor.getLhs().getDefiningOp());
1494 if (!concatOp)
1495 return failure();
1496 vector::BroadcastOp bcastOp = nullptr;
1497 auto *concatDefOp = concatOp.getSources()[0].getDefiningOp();
1498 if (concatDefOp)
1499 bcastOp = dyn_cast<vector::BroadcastOp>(concatDefOp);
1500 Value lhs = adaptor.getRhs();
1501 if (!bcastOp) {
1502 bcastOp = dyn_cast<vector::BroadcastOp>(adaptor.getRhs().getDefiningOp());
1503 if (!bcastOp)
1504 return failure();
1505 lhs = concatOp.getSources()[0];
1506 }
1507 auto extOp =
1508 dyn_cast<vector::ExtractOp>(bcastOp.getSource().getDefiningOp());
1509 if (!extOp)
1510 return failure();
1511
1512 auto rhs = extOp.getSource();
1513 auto concatVecType = cast<VectorType>(concatOp.getResult().getType());
1514 auto zvec =
1515 arith::ConstantOp::create(rewriter, concatOp.getLoc(), lhs.getType(),
1516 rewriter.getZeroAttr(lhs.getType()));
1517 auto lhsX2 =
1518 aievec::ConcatOp::create(rewriter, concatOp.getLoc(), concatVecType,
1519 SmallVector<Value, 2>({lhs, zvec}))
1520 .getResult();
1521 // XXX: We assume a 1D vector
1522 auto pos = extOp.getStaticPosition();
1523 int64_t zstart = pos[0];
1524 auto fmaOpAttr = buildFMAOpSplatAttrForElemTy(fmaOp, zstart);
1525 rewriter.replaceOpWithNewOp<aievec::aie1::FMAOp>(
1526 fmaOp, TypeRange({fmaOp.getResult().getType()}),
1527 ValueRange({lhsX2, rhs, adaptor.getAcc()}), fmaOpAttr);
1528
1529 return success();
1530 }
1531};
1532
// Fuses an aievec.aie1 add with a preceding multiply into an aie1 FMA:
// ups(acc) -> fma(lhs, rhs, acc) -> srs back to the vector type.
    : OpConversionPattern<aievec::aie1::AddOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(aievec::aie1::AddOp addOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vecType = cast<VectorType>(addOp.getType());

    // Verify it can be replaced by a MAC; on success the helper yields the
    // multiply's operands plus the accumulator operand.
    auto res =
        extractMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
    if (!res)
      return failure();
    auto [lhs, rhs, acc] = *res;

    // The aie1 FMA takes an lhs register whose innermost dimension is twice
    // as wide as the result; build it by duplicating lhs.
    SmallVector<int64_t, 4> concatVecShape(vecType.getShape().begin(),
                                           vecType.getShape().end());
    concatVecShape[vecType.getRank() - 1] *= 2;
    auto concatVecType =
        VectorType::get(concatVecShape, vecType.getElementType());
    Type accType = getVectorOpDestType(cast<VectorType>(acc.getType()),
                                       /*AIE2 =*/false);
    auto lhsX2 =
        aievec::ConcatOp::create(rewriter, addOp.getLoc(), concatVecType,
                                 SmallVector<Value, 2>(2, lhs))
            .getResult();
    auto upsOp = aievec::UPSOp::create(rewriter, addOp.getLoc(), accType, acc);
    // All shuffle attributes are left empty (defaults).
    auto fmaOp = aievec::aie1::FMAOp::create(
        rewriter, addOp.getLoc(), accType, lhsX2, rhs, upsOp.getResult(),
        /*xstart=*/"", /*xoffsets=*/"", /*xoffsets_hi=*/"", /*xstep=*/"",
        /*xsquare=*/"", /*zstart=*/"", /*zoffsets=*/"", /*zoffsets_hi=*/"",
        /*zstep=*/"", /*zsquare=*/"", /*fmsub=*/false);
    auto shiftParamOp = arith::ConstantOp::create(
        rewriter, addOp.getLoc(), rewriter.getI32IntegerAttr(0));
    rewriter.replaceOpWithNewOp<aievec::SRSOp>(
        addOp, vecType, fmaOp.getResult(), shiftParamOp.getResult());
    return success();
  }
};
1572
// This pattern replaces `vector.transfer_read` with `aievec.upd`. Right now,
// it performs a naïve direct translation. This needs to be expanded to
// support more complex scenarios.
    : OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LowerVectorTransferReadToAIEUPD(MLIRContext *context, int64_t minVectorSize,
                                  int64_t maxVectorSize, int64_t alignment,
                                  int64_t maxLoadSize)

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Masked loads
    if (readOp.getMask())
      return readOp.emitError() << "AIE doesn't support masked loads.";

    // Non-contiguous loads
    AffineMap map = readOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return failure();

    // Splats are handled by the broadcast patterns, not here.
    if (map.isConstant())
      return failure();

    // Misaligned accesses
    auto vType = readOp.getVectorType();
        .value_or(0) != 0)
      return failure();

    // Invalid vector size.
    // We can handle cases where the vector size is:
    // 1) the minimum vector size
    // 2) a square multiple of the alignment size and up to the maximum
    // vector size.
    int64_t vSize = vType.getNumElements() * vType.getElementTypeBitWidth();
    if (vSize > maxVectorSize ||
        (vSize % vectorAlignment && vSize != minVectorSize))
      return failure();
    // We can deal with linked update instructions when the vector size is
    // exactly twice the load size. This could change in future architectures
    if (vSize > maxLoadSize && vSize != maxLoadSize * 2)
      return failure();
    // The alignment multiplicity must be a power of two (single set bit).
    int64_t multiplicity = vSize / vectorAlignment;
    if ((vSize > minVectorSize) && std::bitset<8>(multiplicity).count() != 1)
      return failure();

    auto updOp = xilinx::aievec::UPDOp::create(
        rewriter, readOp.getLoc(), vType, adaptor.getBase(),
        adaptor.getIndices(), 0, 0, TypedValue<VectorType>(nullptr));
    if (vSize > maxLoadSize) {
      // Emit the linked second half of the update, chained to the first.
      updOp = xilinx::aievec::UPDOp::create(
          rewriter, readOp.getLoc(), vType, adaptor.getBase(),
          adaptor.getIndices(), maxLoadSize, 1, updOp.getResult());
    }
    rewriter.replaceOp(readOp, updOp.getResult());

    return success();
  }

};
1641
1642// XXX: Notice that this template doesn't verify that the vector element type
1643// XXX: is supported by the target architecture.
1644template <typename SrcOpTy, typename DstOpTy>
1647 using OpAdaptor = typename SrcOpTy::Adaptor;
1648
1649 LogicalResult
1650 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
1651 ConversionPatternRewriter &rewriter) const override {
1652 rewriter.replaceOpWithNewOp<DstOpTy>(
1653 srcOp, srcOp.getResult().getType(), adaptor.getLhs(), adaptor.getRhs(),
1654 /*xstart=*/"", /*xoffsets=*/"", /*xoffsets_hi=*/"", /*xsquare=*/"",
1655 /*zstart=*/"", /*zoffsets=*/"", /*zoffsets_hi=*/"", /*zsquare=*/"");
1656 return success();
1657 }
1658};
1659
1661 using OpConversionPattern::OpConversionPattern;
1662
1663 LogicalResult
1664 matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor,
1665 ConversionPatternRewriter &rewriter) const override {
1666 auto resType = addOp.getType();
1667 if (!isa<VectorType>(resType))
1668 return failure();
1669
1670 auto lhs = adaptor.getLhs();
1671 auto rhs = adaptor.getRhs();
1672 auto *lhsDefOp = lhs.getDefiningOp();
1673 auto *rhsDefOp = rhs.getDefiningOp();
1674 if ((isa_and_nonnull<arith::MulIOp>(lhsDefOp)) ||
1675 (isa_and_nonnull<arith::MulIOp>(rhsDefOp)))
1676 return failure();
1677
1678 rewriter.replaceOpWithNewOp<aievec::aie1::AddOp>(
1679 addOp, resType, lhs, rhs,
1680 /*xstart=*/"", /*xoffsets=*/"", /*xoffsets_hi=*/"", /*xsquare=*/"",
1681 /*zstart=*/"", /*zoffsets=*/"", /*zoffsets_hi=*/"", /*zsquare=*/"");
1682 return success();
1683 }
1684};
1685
1694
1696 using OpConversionPattern::OpConversionPattern;
1697 LogicalResult
1698 matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor,
1699 ConversionPatternRewriter &rewriter) const override {
1700 auto resTy = dyn_cast<VectorType>(mulOp.getType());
1701 if (!resTy)
1702 return failure();
1703 auto accTy = getVectorOpDestType(resTy, /*AIE2 =*/false);
1704 auto newMulOp = aievec::aie1::MulOp::create(
1705 rewriter, mulOp.getLoc(), accTy, adaptor.getLhs(), adaptor.getRhs());
1706 auto shiftParamOp = arith::ConstantOp::create(
1707 rewriter, mulOp.getLoc(), rewriter.getI32IntegerAttr(0));
1708 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1709 mulOp, resTy, newMulOp.getResult(), shiftParamOp.getResult());
1710 return success();
1711 }
1712};
1713
1714template <typename SrcOpTy, typename DstOpTy>
1716 : OpConversionPattern<SrcOpTy> {
1718 using OpAdaptor = typename SrcOpTy::Adaptor;
1719
  // Lower an elementwise vector add/sub (SrcOpTy) to the AIE2 AddElem/SubElem
  // op (DstOpTy). Depending on element type/width and on whether operands are
  // produced by widening ops, operands are routed through the accumulator
  // register file (aievec.ups / aievec.cast with isResAcc) and the result is
  // brought back with aievec.srs or aievec.cast. Wide 32-lane float vectors
  // are split into two 16-lane halves.
  LogicalResult
  matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = dyn_cast<VectorType>(srcOp.getType());
    if (!resultType)
      return failure();

    // A set recording the vector lane size and element width we are supporting
    // for AIE2.
    llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
    laneSizeElWidthPairSet.insert({64, 8});
    laneSizeElWidthPairSet.insert({32, 16});
    laneSizeElWidthPairSet.insert({16, 32});
    laneSizeElWidthPairSet.insert({32, 32});

    auto lhs = adaptor.getLhs();
    auto rhs = adaptor.getRhs();
    auto lhsDefOp = lhs.getDefiningOp();
    auto rhsDefOp = rhs.getDefiningOp();
    // Check if this is part of a MAC/FMA pattern (mul + add).
    // We only skip conversion if BOTH operands could potentially be part of an
    // FMA pattern (i.e., neither is a constant). Constants can never be the
    // multiply result in an FMA, so we should allow conversion in those cases.
    bool lhsIsMul = lhsDefOp && (isa<arith::MulIOp>(lhsDefOp) ||
                                 isa<arith::MulFOp>(lhsDefOp));
    bool rhsIsMul = rhsDefOp && (isa<arith::MulIOp>(rhsDefOp) ||
                                 isa<arith::MulFOp>(rhsDefOp));
    bool lhsIsConst = lhsDefOp && isa<arith::ConstantOp>(lhsDefOp);
    bool rhsIsConst = rhsDefOp && isa<arith::ConstantOp>(rhsDefOp);

    // Defer to FMA/MAC patterns when a multiply feeds into add, UNLESS the
    // element type is f32 (where no FMA pattern exists). For bf16 and integer
    // types, FMA/MAC patterns handle the fusion.
    if (!resultType.getElementType().isF32() &&
        ((lhsIsMul && !rhsIsConst) || (rhsIsMul && !lhsIsConst)))
      return failure();

    Type scalarType = resultType.getElementType();
    unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
    unsigned laneSize = getVectorLaneSize(resultType);

    // Integer cases
    if (isa<IntegerType>(scalarType)) {
      if (!laneSizeElWidthPairSet.count(
              std::make_pair(laneSize, resultElWidth)))
        return failure();

      // If the ops are defined without extension ops and with supported data
      // type, the arith::AddI or arith::SubI can be directly replaced with
      // aievec::AddElem or aievec::SubElem.
      if (!lhsDefOp && !rhsDefOp) {
        if (laneSize * resultElWidth == 512) {
          rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs,
                                               rhs);
          return success();
        }
        return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs, resultType,
                                                srcOp);
      }

      // If element width is 32, we need to consider sign extension cases
      if (resultElWidth == 32) {
        // Peel back widening ops (if any) to operate at the narrow width.
        auto lhsExt = getSourceOfWideningOp(lhs).value_or(nullptr);
        auto rhsExt = getSourceOfWideningOp(rhs).value_or(nullptr);

        if (!lhsExt && !rhsExt) {
          if (laneSize * resultElWidth == 512) {
            rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs,
                                                 rhs);
            return success();
          }
          return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
                                                  resultType, srcOp);
        }

        // Both operands widened: UPS both narrow sources into the
        // accumulator, do the elementwise op there, cast back down.
        if (lhsExt && rhsExt) {
          auto lval = lhsExt;
          auto rval = rhsExt;
          VectorType lSrcType = cast<VectorType>(lval.getType());

          Type accType = getVectorOpDestType(lSrcType, /*AIE2 =*/true);
          auto lUpsOp =
              aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, lval);
          auto rUpsOp =
              aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, rval);
          auto elemOp = DstOpTy::create(
              rewriter, srcOp.getLoc(), lUpsOp->getResult(0).getType(),
              lUpsOp->getResult(0), rUpsOp->getResult(0));
          rewriter.replaceOpWithNewOp<aievec::CastOp>(
              srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false);
          return success();
        }

        // Exactly one operand widened.
        if (!lhsExt || !rhsExt) {
          auto lval = lhsExt ? lhsExt : lhs;
          auto rval = rhsExt ? rhsExt : rhs;
          auto extVal = lhsExt ? lval : rval;
          VectorType vType = cast<VectorType>(extVal.getType());
          unsigned bitWidth = vType.getElementType().getIntOrFloatBitWidth();

          if (bitWidth != 8 && bitWidth != 16) {
            return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
                                                    resultType, srcOp);
          }

          if (bitWidth * laneSize != 256) {
            return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
                                                    resultType, srcOp);
          }

          Type accType = nullptr;

          // i8 source: UPS the narrow (widened) side; cast the already-wide
          // side into the accumulator; preserve operand order for sub.
          if (bitWidth == 8) {
            accType = getVectorOpDestType(vType, /*AIE2 =*/true);
            Value valToUps = lhsExt ? lval : rval;
            Value valToCast = lhsExt ? rval : lval;
            auto upsOp = aievec::UPSOp::create(rewriter, srcOp.getLoc(),
                                               accType, valToUps);
            auto castOp =
                aievec::CastOp::create(rewriter, srcOp.getLoc(), resultType,
                                       valToCast, /*isResAcc*/ true);
            Value lhsToElemOp =
                lhsExt ? upsOp->getResult(0) : castOp->getResult(0);
            Value rhsToElemOp =
                lhsExt ? castOp->getResult(0) : upsOp->getResult(0);
            auto elemOp = DstOpTy::create(rewriter, srcOp.getLoc(),
                                          upsOp->getResult(0).getType(),
                                          lhsToElemOp, rhsToElemOp);
            rewriter.replaceOpWithNewOp<aievec::CastOp>(
                srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false);
            return success();
          }

          // i16 source: UPS both sides into the accumulator for the result
          // type, then SRS (shift 0) the sum back down.
          if (bitWidth == 16) {
            accType = getVectorOpDestType(resultType, /*AIE2 =*/true);
            auto lUpsOp =
                aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, lval);
            auto rUpsOp =
                aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, rval);

            auto elemOp = DstOpTy::create(
                rewriter, srcOp.getLoc(), lUpsOp->getResult(0).getType(),
                lUpsOp->getResult(0), rUpsOp->getResult(0));

            auto shiftParamOp = arith::ConstantOp::create(
                rewriter, srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
            rewriter.replaceOpWithNewOp<aievec::SRSOp>(
                srcOp, srcOp.getType(), elemOp.getResult(),
                shiftParamOp.getResult());
            return success();
          }
        }
      } else {
        // 8/16-bit integer elements with defining ops: direct replacement.
        rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs, rhs);
        return success();
      }
    }
    // Float types
    else {
      if (laneSize != 16 && laneSize != 32)
        return failure();

      // v32f32: split into two v16f32 ops
      if (laneSize == 32 && resultElWidth == 32) {
        VectorType halfType = createVectorType(16, scalarType);

        splitWideBinaryVectorOp<SrcOpTy>(
            srcOp, lhs, rhs, halfType, resultType, rewriter,
            [](Value lhsHalf, Value rhsHalf, Location loc,
               ConversionPatternRewriter &rewriter) -> Value {
              VectorType halfVecType = cast<VectorType>(lhsHalf.getType());
              // For f32, use cast to acc, add_elem/sub_elem, cast back
              auto lCastOp = aievec::CastOp::create(rewriter, loc, halfVecType,
                                                    lhsHalf, /*isResAcc*/ true);
              auto rCastOp = aievec::CastOp::create(rewriter, loc, halfVecType,
                                                    rhsHalf, /*isResAcc*/ true);
              auto elemOp = DstOpTy::create(
                  rewriter, loc, lCastOp->getResult(0).getType(),
                  lCastOp->getResult(0), rCastOp->getResult(0));
              auto resCastOp = aievec::CastOp::create(
                  rewriter, loc, halfVecType, elemOp.getResult(),
                  /*isResAcc*/ false);
              return resCastOp.getResult();
            });
        return success();
      }

      // v32bf16: split into two v16bf16 ops
      if (laneSize == 32 && resultElWidth == 16) {
        VectorType halfType = createVectorType(16, scalarType);

        splitWideBinaryVectorOp<SrcOpTy>(
            srcOp, lhs, rhs, halfType, resultType, rewriter,
            [](Value lhsHalf, Value rhsHalf, Location loc,
               ConversionPatternRewriter &rewriter) -> Value {
              VectorType halfVecType = cast<VectorType>(lhsHalf.getType());
              Type accType = getVectorOpDestType(halfVecType, /*AIE2 =*/true);
              auto lUpsOp =
                  aievec::UPSOp::create(rewriter, loc, accType, lhsHalf);
              auto rUpsOp =
                  aievec::UPSOp::create(rewriter, loc, accType, rhsHalf);
              auto elemOp =
                  DstOpTy::create(rewriter, loc, lUpsOp->getResult(0).getType(),
                                  lUpsOp->getResult(0), rUpsOp->getResult(0));
              auto shiftParamOp = arith::ConstantOp::create(
                  rewriter, loc, rewriter.getI32IntegerAttr(0));
              auto srsOp = aievec::SRSOp::create(rewriter, loc, halfVecType,
                                                 elemOp.getResult(),
                                                 shiftParamOp.getResult());
              return srsOp.getResult();
            });
        return success();
      }

      // Now we know laneSize == 16 for remaining float cases
      // v16float or v16bf16 with extension op case
      if (resultElWidth == 32) {
        if (!lhsDefOp && !rhsDefOp) {
          return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
                                                  resultType, srcOp);
        }

        auto lhsExt = getSourceOfWideningOp(lhs).value_or(nullptr);
        auto rhsExt = getSourceOfWideningOp(rhs).value_or(nullptr);
        // v16float
        if (!lhsExt && !rhsExt) {
          return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
                                                  resultType, srcOp);
        }

        // v16bf16 with two extension ops
        if (lhsExt && rhsExt) {
          auto lval = lhsExt;
          auto rval = rhsExt;
          VectorType vType = cast<VectorType>(lval.getType());

          Type accType = getVectorOpDestType(vType, /*AIE2 =*/true);
          auto lUpsOp =
              aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, lval);
          auto rUpsOp =
              aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, rval);
          auto elemOp = DstOpTy::create(
              rewriter, srcOp.getLoc(), lUpsOp->getResult(0).getType(),
              lUpsOp->getResult(0), rUpsOp->getResult(0));
          rewriter.replaceOpWithNewOp<aievec::CastOp>(srcOp, srcOp.getType(),
                                                      elemOp.getResult());
          return success();
        }

        // v16bf16 with one extension op
        if (!lhsExt || !rhsExt) {
          auto lval = lhsExt ? lhsExt : lhs;
          auto rval = rhsExt ? rhsExt : rhs;
          auto extVal = lhsExt ? lval : rval;
          VectorType vType = cast<VectorType>(extVal.getType());
          Type accType = getVectorOpDestType(vType, /*AIE2 =*/true);

          // UPS the widened (narrow bf16) side; cast the already-f32 side
          // into the accumulator; keep lhs/rhs order for sub correctness.
          aievec::UPSOp upsOp;
          aievec::CastOp castOp;
          if (lhsExt) {
            upsOp =
                aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, lval);
            castOp = aievec::CastOp::create(rewriter, srcOp.getLoc(),
                                            resultType, rval,
                                            /*isResAcc*/ true);
          } else {
            upsOp =
                aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, rval);
            castOp = aievec::CastOp::create(rewriter, srcOp.getLoc(),
                                            resultType, lval,
                                            /*isResAcc*/ true);
          }

          auto elemOp = DstOpTy::create(
              rewriter, srcOp.getLoc(), upsOp->getResult(0).getType(),
              upsOp->getResult(0), castOp->getResult(0));

          rewriter.replaceOpWithNewOp<aievec::CastOp>(
              srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false);

          return success();
        }
      }

      // v16bfloat16
      Type accType = getVectorOpDestType(resultType, /*AIE2 =*/true);
      auto lUpsOp =
          aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, lhs);
      auto rUpsOp =
          aievec::UPSOp::create(rewriter, srcOp.getLoc(), accType, rhs);
      auto elemOp = DstOpTy::create(rewriter, srcOp.getLoc(),
                                    lUpsOp->getResult(0).getType(),
                                    lUpsOp->getResult(0), rUpsOp->getResult(0));
      auto shiftParamOp = arith::ConstantOp::create(
          rewriter, srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
      rewriter.replaceOpWithNewOp<aievec::SRSOp>(
          srcOp, srcOp.getType(), elemOp.getResult(), shiftParamOp.getResult());

      return success();
    }

    return failure();
  }
2023};
2024
2027 aievec::AddElemOp>;
2030 aievec::SubElemOp>;
2033 aievec::AddElemOp>;
2036 aievec::SubElemOp>;
2037
2038template <typename SrcOpTy, typename DstOpTy>
2041 using OpAdaptor = typename SrcOpTy::Adaptor;
2042
2043 LogicalResult
2044 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
2045 ConversionPatternRewriter &rewriter) const override {
2046 VectorType resultType = dyn_cast<VectorType>(srcOp.getType());
2047 if (!resultType)
2048 return failure();
2049
2050 // A set recording the element width we are supporting for AIE2.
2051 llvm::SmallSet<unsigned, 16> elWidthSet;
2052 elWidthSet.insert(8);
2053 elWidthSet.insert(16);
2054 elWidthSet.insert(32);
2055
2056 Type scalarType = resultType.getElementType();
2057 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
2058 unsigned laneSize = getVectorLaneSize(resultType);
2059
2060 unsigned totalBits = laneSize * resultElWidth;
2061 if (!elWidthSet.count(resultElWidth) ||
2062 (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16)))
2063 return failure();
2064
2065 if (totalBits == 256 && resultElWidth == 16) {
2066 // Pad v16bf16 to v32bf16, apply max/min, extract lower half
2067 Location loc = srcOp.getLoc();
2068 VectorType wideType = createVectorType(32, scalarType);
2069 Value lhsPad =
2070 padV16ToV32WithZeros(rewriter, loc, adaptor.getLhs(), scalarType);
2071 Value rhsPad =
2072 padV16ToV32WithZeros(rewriter, loc, adaptor.getRhs(), scalarType);
2073 auto wideOp = DstOpTy::create(rewriter, loc, wideType, lhsPad, rhsPad);
2074 rewriter.replaceOpWithNewOp<aievec::ExtOp>(srcOp, resultType,
2075 wideOp.getResult(), 0);
2076 return success();
2077 }
2078
2079 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(),
2080 adaptor.getLhs(), adaptor.getRhs());
2081 return success();
2082 }
2083};
2084
2089// Promote scalar arith.maxsi/arith.minsi to vector aievec.max/aievec.min
2090// to avoid the AIE2 G_SELECT legalizer crash on scalar i32 select.
2091template <typename SrcOpTy, typename DstOpTy>
2094 using OpAdaptor = typename SrcOpTy::Adaptor;
2095
2096 LogicalResult
2097 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
2098 ConversionPatternRewriter &rewriter) const override {
2099 // Only match scalar integer types (reject vectors)
2100 Type resultType = srcOp.getType();
2101 if (isa<VectorType>(resultType))
2102 return failure();
2103
2104 auto intType = dyn_cast<IntegerType>(resultType);
2105 if (!intType)
2106 return failure();
2107
2108 unsigned elWidth = intType.getWidth();
2109 if (elWidth != 8 && elWidth != 16 && elWidth != 32)
2110 return failure();
2111
2112 unsigned numLanes = 512 / elWidth;
2113 VectorType vecType = createVectorType(numLanes, intType);
2114 Location loc = srcOp.getLoc();
2115
2116 // Broadcast both scalars to 512-bit vectors
2117 auto lhsBcast = aievec::BroadcastScalarOp::create(rewriter, loc, vecType,
2118 adaptor.getLhs());
2119 auto rhsBcast = aievec::BroadcastScalarOp::create(rewriter, loc, vecType,
2120 adaptor.getRhs());
2121
2122 // Apply vector min/max
2123 auto vecOp = DstOpTy::create(rewriter, loc, vecType, lhsBcast.getResult(),
2124 rhsBcast.getResult());
2125
2126 // Extract element 0 back to scalar
2127 auto zeroIdx =
2128 arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
2129 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(
2130 srcOp, intType, vecOp.getResult(), zeroIdx.getResult());
2131 return success();
2132 }
2133};
2134
2139
2146
2147template <typename SrcOpTy, typename CmpTy>
2150 using OpAdaptor = typename SrcOpTy::Adaptor;
2151
2152 LogicalResult
2153 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
2154 ConversionPatternRewriter &rewriter) const override {
2155 VectorType lhsType = dyn_cast<VectorType>(srcOp.getLhs().getType());
2156 if (!lhsType)
2157 return failure();
2158
2159 llvm::SmallSet<unsigned, 16> elWidthSet;
2160 elWidthSet.insert(8);
2161 elWidthSet.insert(16);
2162 elWidthSet.insert(32);
2163
2164 Type scalarType = lhsType.getElementType();
2165 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2166 unsigned laneSize = getVectorLaneSize(lhsType);
2167
2168 unsigned totalBits = laneSize * elWidth;
2169 if (!elWidthSet.count(elWidth) ||
2170 (totalBits != 512 && !(totalBits == 256 && elWidth == 16)))
2171 return failure();
2172
2173 Location loc = srcOp.getLoc();
2174 Value lhs = srcOp.getLhs();
2175 Value rhs = srcOp.getRhs();
2176 unsigned effectiveLaneSize = laneSize;
2177
2178 if (totalBits == 256 && elWidth == 16) {
2179 lhs = padV16ToV32WithZeros(rewriter, loc, lhs, scalarType);
2180 rhs = padV16ToV32WithZeros(rewriter, loc, rhs, scalarType);
2181 effectiveLaneSize = 32;
2182 }
2183
2184 // Unsigned int and unsigned long long are acceptable type.
2185 Type type = mlir::IntegerType::get(srcOp.getContext(),
2186 effectiveLaneSize <= 32 ? 32 : 64,
2187 mlir::IntegerType::Unsigned);
2188
2189 CmpTy pred = srcOp.getPredicate();
2190
2191 arith::CmpIPredicate ipred = convertToIntegerPredicate(pred);
2192
2193 aievec::CmpOp aieCmpOp =
2194 createCmpOpAIE2(rewriter, ipred, loc, type, lhs, rhs);
2195
2196 if (!aieCmpOp)
2197 return failure();
2198
2199 VectorType resultType = dyn_cast<VectorType>(srcOp.getResult().getType());
2200 // Convert vector i1 type to unsigned interger type by built-in unrealized
2201 // conversion cast op.
2202 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2203 srcOp, resultType, aieCmpOp.getResult());
2204
2205 return success();
2206 }
2207};
2208
2213
2215 using OpConversionPattern::OpConversionPattern;
2216
  // Lower a vector arith.select to aievec.sel. The i1 condition vector is
  // bridged to the integer bitmask aievec.sel expects via
  // unrealized_conversion_cast; 256-bit 16-bit-element vectors are
  // zero-padded to 512 bits first.
  LogicalResult
  matchAndRewrite(arith::SelectOp srcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<VectorType>(srcOp.getType());
    if (!resultType)
      return failure();

    // Supported element widths on AIE2.
    llvm::SmallSet<unsigned, 16> elWidthSet;
    elWidthSet.insert(8);
    elWidthSet.insert(16);
    elWidthSet.insert(32);

    Type scalarType = resultType.getElementType();
    unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
    unsigned laneSize = getVectorLaneSize(resultType);

    // Accept full 512-bit vectors, or 256-bit vectors of 16-bit elements.
    unsigned totalBits = laneSize * resultElWidth;
    if (!elWidthSet.count(resultElWidth) ||
        (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16)))
      return failure();

    if (totalBits == 256 && resultElWidth == 16) {
      // Pad trueValue and falseValue to v32, do sel, extract lower half
      Location loc = srcOp.getLoc();
      VectorType wideType = createVectorType(32, scalarType);

      // aievec.sel: bit=0 selects lhs, bit=1 selects rhs.
      // aievec.cmp: bit=1 where predicate is true.
      // arith.select: condition=1 returns true_value.
      // So false_value goes to lhs (bit=0) and true_value goes to rhs (bit=1).
      Value falsePad = padV16ToV32WithZeros(rewriter, loc,
                                            srcOp.getFalseValue(), scalarType);
      Value truePad =
          padV16ToV32WithZeros(rewriter, loc, srcOp.getTrueValue(), scalarType);

      // The condition bitmask from cmp was produced at 32 lanes (since cmp
      // was also padded). Use 32-lane integer type for the cast.
      Type type = mlir::IntegerType::get(srcOp.getContext(), 32,
                                         mlir::IntegerType::Unsigned);
      auto convertOp = UnrealizedConversionCastOp::create(
          rewriter, loc, type, adaptor.getCondition());

      auto wideSelOp = aievec::SelOp::create(rewriter, loc, wideType, falsePad,
                                             truePad, convertOp.getResult(0));

      // Only the lower 16 lanes are meaningful; extract them back.
      rewriter.replaceOpWithNewOp<aievec::ExtOp>(srcOp, resultType,
                                                 wideSelOp.getResult(), 0);
      return success();
    }

    // Full-width path: mask type is u32 for <=32 lanes, u64 for 64 lanes.
    Type type =
        mlir::IntegerType::get(srcOp.getContext(), laneSize <= 32 ? 32 : 64,
                               mlir::IntegerType::Unsigned);

    auto convertOp = UnrealizedConversionCastOp::create(
        rewriter, srcOp.getLoc(), type, adaptor.getCondition());

    // aievec.sel: bit=0 selects lhs, bit=1 selects rhs.
    // aievec.cmp: bit=1 where predicate is true.
    // arith.select: condition=1 returns true_value.
    // So false_value goes to lhs (bit=0) and true_value goes to rhs (bit=1).
    rewriter.replaceOpWithNewOp<aievec::SelOp>(
        srcOp, srcOp.getResult().getType(), srcOp.getFalseValue(),
        srcOp.getTrueValue(), convertOp.getResult(0));

    return success();
  }
2284};
2285
// Lowers vector.reduction with a MIN combining kind over 512-bit vectors
// (and 256-bit bf16 vectors, padded with +inf — the identity for min) to a
// log-tree of aievec.min ops, then combines with the optional accumulator.
struct LowerVectorReductionMinOp : OpConversionPattern<vector::ReductionOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::MINSI &&
                                     kind != vector::CombiningKind::MINUI &&
                                     kind != vector::CombiningKind::MINIMUMF &&
                                     kind != vector::CombiningKind::MINNUMF)
      return failure();

    auto vType = cast<VectorType>(srcOp.getVector().getType());
    Type scalarType = vType.getElementType();
    unsigned elWidth = scalarType.getIntOrFloatBitWidth();
    unsigned laneSize = getVectorLaneSize(vType);
    unsigned vectorSize = laneSize * elWidth;

    // Support 512-bit vectors directly, and 256-bit bf16 vectors by padding
    if (vectorSize != 512 && !(vectorSize == 256 && scalarType.isBF16()))
      return failure();

    Location loc = srcOp.getLoc();
    Value inputVec = srcOp.getVector();

    // For 256-bit bf16 (v16bf16), pad to 512-bit (v32bf16) with +inf
    // (+inf is the identity element for min; laneSize is updated to 32).
    if (vectorSize == 256) {
      std::tie(inputVec, laneSize) = padV16ToV32WithInfinity(
          rewriter, loc, srcOp.getVector(), scalarType, /*negativeInf=*/false);
    }

    int shiftIndex = laneSize / 2;
    auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::MinOp>(
        rewriter, srcOp, shiftIndex, inputVec);

    if (srcOp.getAcc()) {
      Value reduceResult = reduceResultOp.getResult();
      Value acc = srcOp.getAcc();

      // If accumulator is bf16, use the high-level helper for bf16->f32->bf16
      if (acc.getType().isBF16()) {
        // Define the min operation to be performed in f32
        auto minOpBuilder = [&](Value lhs, Value rhs) -> Value {
          auto cmpOp = arith::CmpFOp::create(
              rewriter, srcOp.getLoc(), arith::CmpFPredicate::OLT, lhs, rhs);
          return arith::SelectOp::create(rewriter, srcOp.getLoc(), cmpOp, lhs,
                                         rhs);
        };

        // Use helper to handle bf16->f32 conversion, perform min, and convert
        // back
        performBF16BinaryOpInF32(reduceResult, acc, srcOp, srcOp.getLoc(),
                                 rewriter, minOpBuilder);
      } else {
        // Non-bf16 path: perform min using cmpf and select
        // NOTE(review): MINSI/MINUI (integer) kinds are accepted above, but
        // arith.cmpf is only valid on float operands — confirm integer min
        // reductions with an accumulator cannot reach this path.
        auto cmpOp =
            arith::CmpFOp::create(rewriter, srcOp.getLoc(),
                                  arith::CmpFPredicate::OLT, reduceResult, acc);
        rewriter.replaceOpWithNewOp<arith::SelectOp>(srcOp, cmpOp, reduceResult,
                                                     acc);
      }
    } else {
      rewriter.replaceOp(srcOp, reduceResultOp);
    }
    return success();
  }
};
2353
// Lowers vector.reduction with a MAX combining kind over 512-bit vectors
// (and 256-bit bf16 vectors, padded with -inf — the identity for max) to a
// log-tree of aievec.max ops, then combines with the optional accumulator.
struct LowerVectorReductionMaxOp : OpConversionPattern<vector::ReductionOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::MAXSI &&
                                     kind != vector::CombiningKind::MAXUI &&
                                     kind != vector::CombiningKind::MAXIMUMF &&
                                     kind != vector::CombiningKind::MAXNUMF)
      return failure();

    auto vType = cast<VectorType>(srcOp.getVector().getType());
    Type scalarType = vType.getElementType();
    unsigned elWidth = scalarType.getIntOrFloatBitWidth();
    unsigned laneSize = getVectorLaneSize(vType);
    unsigned vectorSize = laneSize * elWidth;

    // Support 512-bit vectors directly, and 256-bit bf16 vectors by padding
    // Only bf16 is supported for the 256-bit padding path (not f16)
    if (vectorSize != 512 && !(vectorSize == 256 && scalarType.isBF16()))
      return failure();

    Location loc = srcOp.getLoc();
    Value inputVec = srcOp.getVector();

    // For 256-bit bf16 (v16bf16), pad to 512-bit (v32bf16) with -inf
    // (-inf is the identity element for max; laneSize is updated to 32).
    if (vectorSize == 256) {
      std::tie(inputVec, laneSize) = padV16ToV32WithInfinity(
          rewriter, loc, srcOp.getVector(), scalarType, /*negativeInf=*/true);
    }

    int shiftIndex = laneSize / 2;
    auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::MaxOp>(
        rewriter, srcOp, shiftIndex, inputVec);

    if (srcOp.getAcc()) {
      Value reduceResult = reduceResultOp.getResult();
      Value acc = srcOp.getAcc();

      // If accumulator is bf16, use the high-level helper for bf16->f32->bf16
      if (acc.getType().isBF16()) {
        // Define the max operation to be performed in f32
        auto maxOpBuilder = [&](Value lhs, Value rhs) -> Value {
          auto cmpOp = arith::CmpFOp::create(
              rewriter, srcOp.getLoc(), arith::CmpFPredicate::OGT, lhs, rhs);
          return arith::SelectOp::create(rewriter, srcOp.getLoc(), cmpOp, lhs,
                                         rhs);
        };

        // Use helper to handle bf16->f32 conversion, perform max, and convert
        // back
        performBF16BinaryOpInF32(reduceResult, acc, srcOp, srcOp.getLoc(),
                                 rewriter, maxOpBuilder);
      } else {
        // Non-bf16 path: perform max directly
        // NOTE(review): MAXSI/MAXUI (integer) kinds are accepted above, but
        // arith.cmpf is only valid on float operands — confirm integer max
        // reductions with an accumulator cannot reach this path.
        auto cmpOp =
            arith::CmpFOp::create(rewriter, srcOp.getLoc(),
                                  arith::CmpFPredicate::OGT, reduceResult, acc);
        rewriter.replaceOpWithNewOp<arith::SelectOp>(srcOp, cmpOp, reduceResult,
                                                     acc);
      }
    } else {
      rewriter.replaceOp(srcOp, reduceResultOp);
    }
    return success();
  }
};
2422
2424 using OpConversionPattern::OpConversionPattern;
2425
2426 LogicalResult
2427 matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
2428 ConversionPatternRewriter &rewriter) const override {
2429 if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
2430 return failure();
2431
2432 auto vType = cast<VectorType>(srcOp.getVector().getType());
2433 Type scalarType = vType.getElementType();
2434 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2435 unsigned laneSize = getVectorLaneSize(vType);
2436 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
2437 laneSizeElWidthPairSet.insert({64, 8});
2438 laneSizeElWidthPairSet.insert({32, 16});
2439 laneSizeElWidthPairSet.insert({32, 32});
2440 laneSizeElWidthPairSet.insert({16, 32});
2441
2442 if (!isa<IntegerType>(scalarType) ||
2443 !laneSizeElWidthPairSet.count(std::make_pair(laneSize, elWidth)))
2444 return failure();
2445
2446 int shiftIndex = laneSize / 2;
2447 if (laneSize == 32 && elWidth == 32) {
2448 Location loc = srcOp.getLoc();
2449 VectorType vecType = createVectorType(laneSize / 2, scalarType);
2450
2451 auto lExtOp =
2452 aievec::ExtOp::create(rewriter, loc, vecType, srcOp.getVector(), 0);
2453 auto rExtOp =
2454 aievec::ExtOp::create(rewriter, loc, vecType, srcOp.getVector(), 1);
2455 auto addElemOp =
2456 aievec::AddElemOp::create(rewriter, loc, lExtOp.getResult().getType(),
2457 lExtOp.getResult(), rExtOp.getResult());
2458 shiftIndex /= 2;
2459 auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
2460 rewriter, srcOp, shiftIndex, addElemOp.getResult());
2461 if (srcOp.getAcc())
2462 rewriter.replaceOpWithNewOp<arith::AddIOp>(
2463 srcOp, reduceResultOp.getResult(), srcOp.getAcc());
2464 else
2465 rewriter.replaceOp(srcOp, reduceResultOp);
2466 } else {
2467 auto reduceResultOp = generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
2468 rewriter, srcOp, shiftIndex, srcOp.getVector());
2469 if (srcOp.getAcc())
2470 rewriter.replaceOpWithNewOp<arith::AddIOp>(
2471 srcOp, reduceResultOp.getResult(), srcOp.getAcc());
2472 else
2473 rewriter.replaceOp(srcOp, reduceResultOp);
2474 }
2475
2476 return success();
2477 }
2478};
2479
2481 : OpConversionPattern<vector::ReductionOp> {
2482 using OpConversionPattern::OpConversionPattern;
2483
2484 LogicalResult
2485 matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
2486 ConversionPatternRewriter &rewriter) const override {
2487 if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
2488 return failure();
2489
2490 auto vType = cast<VectorType>(srcOp.getVector().getType());
2491 Type scalarType = vType.getElementType();
2492 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2493 unsigned laneSize = getVectorLaneSize(vType);
2494
2495 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 32)
2496 return failure();
2497
2498 int shiftIndex = laneSize / 2;
2499 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
2500 "shiftIndex must be power of 2");
2501
2502 Location loc = srcOp.getLoc();
2503 Value curValue = srcOp.getVector();
2504 aievec::CastOp curOp = nullptr;
2505
2506 for (int id = shiftIndex; id > 0; id /= 2) {
2507 auto constOp = arith::ConstantOp::create(
2508 rewriter, loc, rewriter.getI32IntegerAttr(id * elWidth / 8));
2509
2510 auto shiftBytesOp = aievec::ShiftOp::create(
2511 rewriter, loc, vType, curValue, curValue, constOp.getResult());
2512
2513 auto lCastOp = aievec::CastOp::create(rewriter, loc, vType, curValue,
2514 /*isResAcc*/ true);
2515 auto rCastOp =
2516 aievec::CastOp::create(rewriter, loc, vType, shiftBytesOp.getResult(),
2517 /*isResAcc*/ true);
2518 auto elemOp = aievec::AddElemOp::create(
2519 rewriter, loc, lCastOp.getResult().getType(), lCastOp.getResult(),
2520 rCastOp.getResult());
2521 curOp = aievec::CastOp::create(rewriter, loc, vType, elemOp.getResult(),
2522 /*isResAcc*/ false);
2523 curValue = curOp.getResult();
2524 }
2525
2526 auto zeroConstOp =
2527 arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
2528 auto reduceResultOp = aievec::ExtElemOp::create(
2529 rewriter, srcOp.getLoc(), scalarType, curOp, zeroConstOp.getResult());
2530
2531 if (srcOp.getAcc())
2532 rewriter.replaceOpWithNewOp<arith::AddFOp>(
2533 srcOp, reduceResultOp.getResult(), srcOp.getAcc());
2534 else
2535 rewriter.replaceOp(srcOp, reduceResultOp);
2536 return success();
2537 }
2538};
2539
2540// AIE2-specific bf16 ADD reduction - requires concat to v32bf16 before ext_elem
2541// due to aie2 ext_elem limitation
    : OpConversionPattern<vector::ReductionOp> {
  using OpConversionPattern::OpConversionPattern;

  // Lowers a bf16 `vector.reduction <add>` (16 or 32 lanes) for AIE2 using a
  // log-tree of shift-and-add steps in the wider accumulator type, then
  // extracts lane 0. AIE2's ext_elem only accepts v32bf16, hence the final
  // concat of the v16bf16 result with itself.
  LogicalResult
  matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Only ADD reductions are handled by this pattern.
    if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD) {
      return failure();
    }

    auto vType = cast<VectorType>(srcOp.getVector().getType());
    Type scalarType = vType.getElementType();
    unsigned elWidth = scalarType.getIntOrFloatBitWidth();
    unsigned laneSize = getVectorLaneSize(vType);

    // Support both lane=16 and lane=32 for bf16
    if (!isa<FloatType>(scalarType) || (laneSize != 16 && laneSize != 32) ||
        elWidth != 16) {
      return failure();
    }

    Location loc = srcOp.getLoc();
    Value curValue = srcOp.getVector();
    VectorType currentVType = vType; // Track current working vector type

    // For lane=32, split into two v16bf16 halves and add them
    if (laneSize == 32) {
      VectorType halfType = createVectorType(laneSize / 2, scalarType);
      auto lowerHalf =
          aievec::ExtOp::create(rewriter, loc, halfType, srcOp.getVector(), 0);
      auto upperHalf =
          aievec::ExtOp::create(rewriter, loc, halfType, srcOp.getVector(), 1);

      // Add the halves in the wider accumulator type: ups -> add_elem -> srs.
      Type accType = getVectorOpDestType(halfType, /*AIE2 =*/true);
      auto lUpsOp =
          aievec::UPSOp::create(rewriter, loc, accType, lowerHalf.getResult());
      auto rUpsOp =
          aievec::UPSOp::create(rewriter, loc, accType, upperHalf.getResult());
      auto addElemOp = aievec::AddElemOp::create(
          rewriter, loc, accType, lUpsOp.getResult(), rUpsOp.getResult());
      auto shiftParamOp = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32IntegerAttr(0));
      auto srsOp =
          aievec::SRSOp::create(rewriter, loc, halfType, addElemOp.getResult(),
                                shiftParamOp.getResult());
      curValue = srsOp.getResult();
      currentVType = halfType; // Update to v16bf16 after split
    }

    int shiftIndex = 8; // Always 8 since we work with v16bf16
    Type accType = getVectorOpDestType(cast<VectorType>(curValue.getType()),
                                       /*AIE2 =*/true);
    unsigned accWidth =
        dyn_cast<VectorType>(accType).getElementType().getIntOrFloatBitWidth();

    // Widen to the accumulator type before the shift/add tree.
    auto upsOp = aievec::UPSOp::create(rewriter, loc, accType, curValue);
    curValue = upsOp.getResult();

    aievec::AddElemOp curOp = nullptr;

    // Each iteration shifts the partial sums down by `id` elements (expressed
    // in bytes of the accumulator element) and adds; lane 0 accumulates the
    // full sum after log2(16) = 4 steps.
    for (int id = shiftIndex; id > 0; id /= 2) {
      auto constOp = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32IntegerAttr(id * accWidth / 8));
      auto shiftBytesOp = aievec::ShiftOp::create(
          rewriter, loc, accType, curValue, curValue, constOp, true);
      curOp = aievec::AddElemOp::create(rewriter, loc, accType, curValue,
                                        shiftBytesOp.getResult());
      curValue = curOp.getResult();
    }

    auto shiftParamOp = arith::ConstantOp::create(
        rewriter, srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
    // Use currentVType instead of vType to ensure lane count matches
    auto srsOp =
        aievec::SRSOp::create(rewriter, loc, currentVType, curOp.getResult(),
                              shiftParamOp.getResult());

    // AIE2 ext_elem requires v32bf16, so concat v16bf16 to v32bf16
    VectorType vecType = createVectorType(32, scalarType);
    SmallVector<Value> concatSources = {srsOp.getResult(), srsOp.getResult()};
    auto concatOp =
        aievec::ConcatOp::create(rewriter, loc, vecType, concatSources);

    auto zeroConstOp =
        arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
    auto reduceResultOp =
        aievec::ExtElemOp::create(rewriter, srcOp.getLoc(), scalarType,
                                  concatOp, zeroConstOp.getResult());

    // Fold the optional accumulator operand of vector.reduction with a
    // scalar add.
    if (srcOp.getAcc())
      rewriter.replaceOpWithNewOp<arith::AddFOp>(
          srcOp, reduceResultOp.getResult(), srcOp.getAcc());
    else
      rewriter.replaceOp(srcOp, reduceResultOp);
    return success();
  }
};
2641
2642// AIE2P-specific bf16 ADD reduction - can extract directly from v16bf16
    : OpConversionPattern<vector::ReductionOp> {
  using OpConversionPattern::OpConversionPattern;

  // Lowers a bf16 `vector.reduction <add>` (16 or 32 lanes) for AIE2P. The
  // log-tree reduction happens in the accumulator type; unlike the AIE2
  // variant, the final element can be extracted directly from the f32
  // accumulator and truncated back to bf16 (no concat needed).
  LogicalResult
  matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only ADD reductions are handled by this pattern.
    if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
      return failure();

    auto vType = cast<VectorType>(srcOp.getVector().getType());
    Type scalarType = vType.getElementType();
    unsigned elWidth = scalarType.getIntOrFloatBitWidth();
    unsigned laneSize = getVectorLaneSize(vType);

    // Support both lane=16 and lane=32 for bf16
    if (!isa<FloatType>(scalarType) || (laneSize != 16 && laneSize != 32) ||
        elWidth != 16)
      return failure();

    Location loc = srcOp.getLoc();
    int shiftIndex = laneSize / 2;
    Value inputToReduce = srcOp.getVector();

    // For lane=32, split into two v16bf16 halves, add them, then reduce
    if (laneSize == 32) {
      VectorType halfType = createVectorType(laneSize / 2, scalarType);

      // Extract lower and upper halves
      auto lowerHalf =
          aievec::ExtOp::create(rewriter, loc, halfType, srcOp.getVector(), 0);
      auto upperHalf =
          aievec::ExtOp::create(rewriter, loc, halfType, srcOp.getVector(), 1);

      // Add the two halves together (ups -> add_elem -> srs)
      Type accType = getVectorOpDestType(halfType, /*AIE2 =*/true);
      auto lUpsOp =
          aievec::UPSOp::create(rewriter, loc, accType, lowerHalf.getResult());
      auto rUpsOp =
          aievec::UPSOp::create(rewriter, loc, accType, upperHalf.getResult());
      auto addElemOp = aievec::AddElemOp::create(
          rewriter, loc, accType, lUpsOp.getResult(), rUpsOp.getResult());
      auto shiftParamOp = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32IntegerAttr(0));
      auto srsOp =
          aievec::SRSOp::create(rewriter, loc, halfType, addElemOp.getResult(),
                                shiftParamOp.getResult());

      inputToReduce = srsOp.getResult();
      shiftIndex = 8; // Remaining reduction runs on 16 lanes.
    }

    // Perform reduction using utility
    Type accType = getVectorOpDestType(
        cast<VectorType>(inputToReduce.getType()), /*AIE2 =*/true);
    unsigned accWidth =
        dyn_cast<VectorType>(accType).getElementType().getIntOrFloatBitWidth();

    auto upsOp = aievec::UPSOp::create(rewriter, loc, accType, inputToReduce);
    Value curValue = upsOp.getResult();

    aievec::AddElemOp curOp = nullptr;
    // Log-tree: shift partial sums down by `id` elements (in accumulator
    // bytes) and add, halving the live partial sums per step.
    for (int id = shiftIndex; id > 0; id /= 2) {
      auto constOp = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32IntegerAttr(id * accWidth / 8));
      auto shiftBytesOp = aievec::ShiftOp::create(
          rewriter, loc, accType, curValue, curValue, constOp, true);
      curOp = aievec::AddElemOp::create(rewriter, loc, accType, curValue,
                                        shiftBytesOp.getResult());
      curValue = curOp.getResult();
    }

    // Extract element 0 from the f32 accumulator
    // The loop has already fully reduced the vector to a single value in
    // element 0
    auto zeroConstOp =
        arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
    auto extractedF32 = aievec::ExtElemOp::create(
        rewriter, srcOp.getLoc(), rewriter.getF32Type(), curOp.getResult(),
        zeroConstOp.getResult());

    // Convert extracted f32 to bf16
    auto reduceResultBF16 = arith::TruncFOp::create(
        rewriter, srcOp.getLoc(), scalarType, extractedF32.getResult());

    // Fold the optional accumulator operand with a scalar add.
    if (srcOp.getAcc())
      rewriter.replaceOpWithNewOp<arith::AddFOp>(srcOp, reduceResultBF16,
                                                 srcOp.getAcc());
    else
      rewriter.replaceOp(srcOp, reduceResultBF16);
    return success();
  }
};
2736
2737// Convert a `vector.extract_strided_slice` op on 1D vectors into an
2738// `aievec.select` + `aievec.ext` op.
2740 : OpConversionPattern<vector::ExtractStridedSliceOp> {
2741 using OpConversionPattern::OpConversionPattern;
2742
2743 LogicalResult
2744 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
2745 ConversionPatternRewriter &rewriter) const override {
2746 auto vType = extractOp.getSourceVectorType();
2747 if (vType.getRank() != 1)
2748 return failure();
2749
2750 int64_t stride = cast<IntegerAttr>(adaptor.getStrides()[0]).getInt();
2751 if (stride != 1)
2752 return failure();
2753
2754 // AIE doesn't support select operations on i8
2755 if (getElementSizeInBits(vType) == 8)
2756 return extractOp.emitError()
2757 << "AIEv1 doesn't support select ops on int8 types";
2758
2759 // We only accept the case where we are extracting a slice half the size of
2760 // the input vector.
2761 int64_t size = cast<IntegerAttr>(adaptor.getSizes()[0]).getInt();
2762 if (vType.getNumElements() != 2 * size)
2763 return failure();
2764
2765 int64_t offset = cast<IntegerAttr>(adaptor.getOffsets()[0]).getInt();
2766 auto selectOp = aievec::aie1::SelectOp::create(
2767 rewriter, extractOp.getLoc(), vType, adaptor.getSource(),
2768 buildAttributeListForRotationSelectOp(rewriter, vType, offset));
2769 rewriter.replaceOpWithNewOp<aievec::aie1::ExtOp>(
2770 extractOp, extractOp.getType(), selectOp.getResult(),
2771 rewriter.getI8IntegerAttr(0));
2772 return success();
2773 }
2774};
2775
2776// Convert a `vector.extract_strided_slice` op on 1D vectors into an
2777// `aievec.shift` op.
2779 : OpConversionPattern<vector::ExtractStridedSliceOp> {
2780 using OpConversionPattern::OpConversionPattern;
2781
2782 LogicalResult
2783 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
2784 ConversionPatternRewriter &rewriter) const override {
2785 auto vType = cast<VectorType>(adaptor.getSource().getType());
2786 if (vType.getRank() != 1)
2787 return failure();
2788
2789 int64_t stride = cast<IntegerAttr>(adaptor.getStrides()[0]).getInt();
2790 if (stride != 1)
2791 return failure();
2792
2793 // We only accept the case where we are extracting a slice half the size of
2794 // the input vector.
2795 int64_t size = cast<IntegerAttr>(adaptor.getSizes()[0]).getInt();
2796 if (vType.getNumElements() != 2 * size)
2797 return failure();
2798
2799 auto shortVecType = cast<VectorType>(extractOp.getResult().getType());
2800 auto bottomHalf =
2801 aievec::ExtOp::create(rewriter, extractOp.getLoc(), shortVecType,
2802 adaptor.getSource(), rewriter.getI8IntegerAttr(0))
2803 .getResult();
2804 auto topHalf =
2805 aievec::ExtOp::create(rewriter, extractOp.getLoc(), shortVecType,
2806 adaptor.getSource(), rewriter.getI8IntegerAttr(1))
2807 .getResult();
2808 int64_t offset = cast<IntegerAttr>(adaptor.getOffsets()[0]).getInt();
2809 int32_t shiftBytes = offset * getElementSizeInBits(vType) / 8;
2810 auto shiftBytesConstOp = arith::ConstantOp::create(
2811 rewriter, extractOp.getLoc(), rewriter.getIntegerType(32),
2812 rewriter.getI32IntegerAttr(shiftBytes));
2813 rewriter.replaceOpWithNewOp<aievec::ShiftOp>(
2814 extractOp, shortVecType, bottomHalf, topHalf, shiftBytesConstOp);
2815
2816 return success();
2817 }
2818};
2819
2820// Replaces a short UPD op with a wide one followed by an ext op of the bottom
2821// half.
2823 using OpConversionPattern::OpConversionPattern;
2824
2825 ExpandUPDToUPDAndExtPattern(MLIRContext *context)
2826 : OpConversionPattern(context) {}
2827
2828 LogicalResult
2829 matchAndRewrite(aievec::UPDOp updOp, OpAdaptor adaptor,
2830 ConversionPatternRewriter &rewriter) const override {
2831 // Verify that we haven't already expanded this one
2832 if (updOp->hasOneUse() && isa<aievec::ExtOp>(*updOp->getUsers().begin()))
2833 return failure();
2834
2835 auto vecType = cast<VectorType>(updOp.getType());
2836 SmallVector<int64_t, 4> vecShape(vecType.getShape().begin(),
2837 vecType.getShape().end());
2838 vecShape[vecType.getRank() - 1] *= 2;
2839 auto longVecType = VectorType::get(vecShape, vecType.getElementType());
2840 auto newUpdOp = aievec::UPDOp::create(
2841 rewriter, updOp.getLoc(), longVecType, adaptor.getSource(),
2842 adaptor.getIndices(), adaptor.getOffset(), adaptor.getIndex(),
2843 adaptor.getVector());
2844 rewriter.replaceOpWithNewOp<aievec::ExtOp>(
2845 updOp, vecType, newUpdOp.getResult(), rewriter.getI8IntegerAttr(0));
2846
2847 return success();
2848 }
2849};
2850
2851// Replaces a wide UPD op followed by an ext op of the bottom half with a short
2852// UPD op.
2854 using OpConversionPattern::OpConversionPattern;
2855
2856 FuseExtIntoUPDPattern(MLIRContext *context) : OpConversionPattern(context) {}
2857
2858 LogicalResult
2859 matchAndRewrite(aievec::ExtOp extOp, OpAdaptor adaptor,
2860 ConversionPatternRewriter &rewriter) const override {
2861 // Verify we are extracting the lower half...
2862 if (extOp.getIndex() != 0)
2863 return failure();
2864 // ...of a UPDOp
2865 auto updOp = dyn_cast<aievec::UPDOp>(extOp.getSource().getDefiningOp());
2866 if (!updOp)
2867 return failure();
2868
2869 // Verify that this is a direct upd -> ext pattern
2870 if (!updOp->hasOneUse())
2871 return failure();
2872
2873 rewriter.replaceOpWithNewOp<aievec::UPDOp>(
2874 extOp, extOp.getType(), updOp.getSource(), updOp.getIndices(),
2875 updOp.getOffset(), updOp.getIndex(), updOp.getVector());
2876
2877 return success();
2878 }
2879};
2880
2881// Convert math.exp to aievec.exp for AIE2P (will be further lowered to exp2
2882// intrinsic)
2884 using OpConversionPattern::OpConversionPattern;
2885
2886 LogicalResult
2887 matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor,
2888 ConversionPatternRewriter &rewriter) const override {
2889 if (!matchExpOpForAIE2P(adaptor))
2890 return failure();
2891
2892 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
2893 rewriter.replaceOpWithNewOp<aievec::ExpOp>(expOp, srcType,
2894 adaptor.getOperand());
2895 return success();
2896 }
2897};
2898
2899// Convert math.tanh to aievec.tanh for AIE2P (will be further lowered to tanh
2900// intrinsic)
2902 : OpConversionPattern<math::TanhOp> {
2903 using OpConversionPattern::OpConversionPattern;
2904
2905 LogicalResult
2906 matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor,
2907 ConversionPatternRewriter &rewriter) const override {
2908 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
2909 if (!srcType)
2910 return failure();
2911
2912 Type scalarType = srcType.getElementType();
2913 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2914 unsigned laneSize = getVectorLaneSize(srcType);
2915 // AIE2P tanh: supports v16bf16 and v32bf16
2916 if (!scalarType.isBF16() || (laneSize != 16 && laneSize != 32) ||
2917 elWidth != 16)
2918 return failure();
2919
2920 rewriter.replaceOpWithNewOp<aievec::TanhOp>(tanhOp, srcType,
2921 adaptor.getOperand());
2922 return success();
2923 }
2924};
2925
  using OpConversionPattern::OpConversionPattern;

  // Lowers a bf16 vector `math.exp` to a call of the LUT-based runtime
  // function `getExpBf16` (LLVM backend). v32bf16 inputs are split into two
  // v16bf16 calls; each call's packed v8i64 result is bitcast to the
  // accumulator type and SRS'd back down to bf16.
  LogicalResult
  matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    if (!matchExpOpForAIE2LUT(adaptor))
      return failure();

    auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
    unsigned laneSize = getVectorLaneSize(srcType);
    Location loc = expOp.getLoc();
    StringRef funcName = "getExpBf16";

    // Declare (or find) getExpBf16 : (v16bf16) -> v8i64 in the enclosing
    // symbol table.
    VectorType v16bf16Ty = mlir::VectorType::get({16}, rewriter.getBF16Type());
    VectorType v8i64Ty = mlir::VectorType::get({8}, rewriter.getI64Type());
    func::FuncOp fnOp = getOrInsertFuncDecl(
        rewriter, expOp->getParentWithTrait<OpTrait::SymbolTable>(), funcName,
        TypeRange{v16bf16Ty}, TypeRange{v8i64Ty});

    // Handle v32bf16 by splitting into two v16bf16 operations
    if (laneSize == 32) {
      splitWideUnaryVectorOp<math::ExpOp>(
          expOp, adaptor.getOperand(), v16bf16Ty, srcType, rewriter,
          [&fnOp](Value halfInput, Location loc,
                  ConversionPatternRewriter &rewriter) -> Value {
            // Per-half lowering: call -> bitcast to accumulator -> srs.
            VectorType v16bf16Ty =
                mlir::VectorType::get({16}, rewriter.getBF16Type());
            auto callOp = func::CallOp::create(rewriter, loc, fnOp,
                                               SmallVector<Value>{halfInput});
            Type accType = getVectorOpDestType(v16bf16Ty, /*AIE2 =*/true);
            auto resCastOp = vector::BitCastOp::create(rewriter, loc, accType,
                                                       callOp.getResults());
            auto shiftParamOp = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI32IntegerAttr(0));
            auto srsOp = aievec::SRSOp::create(rewriter, loc, v16bf16Ty,
                                               resCastOp.getResult(),
                                               shiftParamOp.getResult());
            return srsOp.getResult();
          });
      return success();
    }

    // Handle v16bf16 directly
    SmallVector<Value> expOperands = {adaptor.getOperand()};

    Type accTypeNative = getVectorOpDestType(srcType, /*AIE2 =*/true);
    auto callOp = func::CallOp::create(rewriter, loc, fnOp, expOperands);
    auto resCastOp = vector::BitCastOp::create(rewriter, loc, accTypeNative,
                                               callOp.getResults());
    auto shiftParamOp =
        arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
    rewriter.replaceOpWithNewOp<aievec::SRSOp>(
        expOp, srcType, resCastOp.getResult(), shiftParamOp.getResult());

    return success();
  }
};
2985// Lower ExpOp to function call
  using OpConversionPattern::OpConversionPattern;

  // Lowers a bf16 vector `math.exp` to an opaque EmitC call of `getExpBf16`
  // (CPP backend). Inserts the `lut_based_ops.h` include at the start of the
  // module, bridges vector types to the opaque EmitC types with
  // unrealized_conversion_cast, and narrows the accfloat result via SRS.
  LogicalResult
  matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!matchExpOpForAIE2LUT(adaptor))
      return failure();
    auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
    StringRef includeName = "lut_based_ops.h";
    auto moduleOp = expOp->getParentOfType<mlir::ModuleOp>();
    // Emit the include at the very top of the module so the generated C++
    // sees the declaration of getExpBf16.
    rewriter.setInsertionPointToStart(
        &moduleOp.getRegion().getBlocks().front());
    emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName, false);

    // Restore the insertion point for the actual rewrite.
    rewriter.setInsertionPoint(expOp);

    auto v16bf16OpaqueTy =
        emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
    auto opaquedOperand =
        UnrealizedConversionCastOp::create(
            rewriter, expOp.getLoc(), v16bf16OpaqueTy, adaptor.getOperand())
            .getResult(0);
    SmallVector<Value> expOperands = {opaquedOperand};

    Type accTypeNative = getVectorOpDestType(srcType, /*AIE2 =*/true);
    Type v16accf32OpaqueTy =
        emitc::OpaqueType::get(rewriter.getContext(), "v16accfloat");
    auto callOp = emitc::CallOpaqueOp::create(
        rewriter, expOp.getLoc(), TypeRange{v16accf32OpaqueTy}, "getExpBf16",
        nullptr, nullptr, expOperands);
    auto resCastOp = UnrealizedConversionCastOp::create(
        rewriter, expOp.getLoc(), accTypeNative, callOp.getResults());
    auto shiftParamOp = arith::ConstantOp::create(
        rewriter, expOp.getLoc(), rewriter.getI32IntegerAttr(0));
    // getExpBf16 returns an accfloat accumulator; SRS narrows it back to
    // the source bf16 vector type.
    rewriter.replaceOpWithNewOp<aievec::SRSOp>(
        expOp, srcType, resCastOp.getResult(0), shiftParamOp.getResult());

    return success();
  }
};
3027
3028// Lower the inverse of a float to a function call (CPP backend)
3029// Convert the pattern-
3030// %cst = arith.constant 1.000000e+00 : f32
3031// %0 = arith.divf %cst, %arg1 : f32
3032// %1 = arith.truncf %0 : f32 to bf16
3033// to -
3034// %0 = emitc.call "getInvBf16"(%0) : f32 -> bf16;
3036 using OpConversionPattern::OpConversionPattern;
3037
3038 LogicalResult
3039 matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor,
3040 ConversionPatternRewriter &rewriter) const override {
3041 Type srcType = adaptor.getLhs().getType();
3042 if (!divOp->hasOneUse() || isa<VectorType>(srcType) ||
3043 !isa<FloatType>(srcType))
3044 return failure();
3045
3046 if (!isNarrowingOp(*divOp->getUsers().begin()))
3047 return failure();
3048
3049 auto fType = cast<FloatType>(srcType);
3050 if (fType.getWidth() != 32)
3051 return failure();
3052
3053 auto constOp = dyn_cast<arith::ConstantOp>(divOp.getLhs().getDefiningOp());
3054 if (!constOp ||
3055 cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
3056 1.0f)
3057 return failure();
3058
3059 StringRef includeName = "lut_based_ops.h";
3060 auto moduleOp = divOp->getParentOfType<mlir::ModuleOp>();
3061 rewriter.setInsertionPointToStart(
3062 &moduleOp.getRegion().getBlocks().front());
3063 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName, false);
3064
3065 auto truncOp = cast<arith::TruncFOp>(*divOp->getUsers().begin());
3066
3067 rewriter.setInsertionPoint(truncOp);
3068 Type bf16OpaqueTy =
3069 emitc::OpaqueType::get(rewriter.getContext(), "bfloat16");
3070 SmallVector<Value> invOperands = {adaptor.getRhs()};
3071 auto callOp = emitc::CallOpaqueOp::create(rewriter, truncOp.getLoc(),
3072 bf16OpaqueTy, "getInvBf16",
3073 nullptr, nullptr, invOperands);
3074 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3075 truncOp, TypeRange{truncOp.getResult().getType()}, callOp.getResults());
3076 rewriter.eraseOp(divOp);
3077
3078 return success();
3079 }
3080};
3081
3082// Lower the inverse of a float to aievec.inv (LLVMIR backend for AIE2P)
3083// Supports both scalar f32 and vector<Nxf32> types.
3084// Convert the pattern-
3085// %cst = arith.constant 1.000000e+00 : f32
3086// %0 = arith.divf %cst, %arg1 : f32
3087// to -
3088// %0 = aievec.inv %arg1 : f32
3089// Also supports:
3090// %cst = arith.constant dense<1.0> : vector<16xf32>
3091// %0 = arith.divf %cst, %arg1 : vector<16xf32>
3092// to -
3093// %0 = aievec.inv %arg1 : vector<16xf32>
3095 using OpConversionPattern::OpConversionPattern;
3096
3097 LogicalResult
3098 matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor,
3099 ConversionPatternRewriter &rewriter) const override {
3100 Type srcType = adaptor.getLhs().getType();
3101
3102 // Check if LHS is defined by an operation
3103 auto *defOp = divOp.getLhs().getDefiningOp();
3104 if (!defOp)
3105 return failure();
3106
3107 auto constOp = dyn_cast<arith::ConstantOp>(defOp);
3108 if (!constOp)
3109 return failure();
3110
3111 // Handle scalar f32 case
3112 if (auto fType = dyn_cast<FloatType>(srcType)) {
3113 if (fType.getWidth() != 32)
3114 return failure();
3115
3116 auto floatAttr = dyn_cast<FloatAttr>(constOp.getValue());
3117 if (!floatAttr || !floatAttr.getValue().isExactlyValue(1.0))
3118 return failure();
3119
3120 rewriter.replaceOpWithNewOp<aievec::InvOp>(divOp, srcType,
3121 adaptor.getRhs());
3122 return success();
3123 }
3124
3125 // Handle vector f32 case
3126 if (auto vecType = dyn_cast<VectorType>(srcType)) {
3127 auto elemType = vecType.getElementType();
3128 if (!elemType.isF32())
3129 return failure();
3130
3131 // Check for supported vector sizes (16 or 32 lanes)
3132 unsigned laneSize = getVectorLaneSize(vecType);
3133 if (laneSize != 16 && laneSize != 32)
3134 return failure();
3135
3136 // Check if it's a splat of 1.0
3137 auto denseAttr = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3138 if (!denseAttr || !denseAttr.isSplat())
3139 return failure();
3140
3141 if (!denseAttr.getSplatValue<APFloat>().isExactlyValue(1.0))
3142 return failure();
3143
3144 rewriter.replaceOpWithNewOp<aievec::InvOp>(divOp, vecType,
3145 adaptor.getRhs());
3146 return success();
3147 }
3148
3149 return failure();
3150 }
3151};
3152
3153// Convert math.tanh to a function call to compute tanh(x) by look up tables
3155 using OpConversionPattern::OpConversionPattern;
3156
3157 LogicalResult
3158 matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor,
3159 ConversionPatternRewriter &rewriter) const override {
3160 auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
3161 if (!srcType)
3162 return failure();
3163
3164 Type scalarType = srcType.getElementType();
3165 if (!isa<FloatType>(scalarType))
3166 return failure();
3167
3168 unsigned laneSize = getVectorLaneSize(srcType);
3169 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3170 if (elWidth != 16 || laneSize != 16)
3171 return failure();
3172
3173 StringRef includeName = "lut_based_ops.h";
3174 auto moduleOp = tanhOp->getParentOfType<mlir::ModuleOp>();
3175 rewriter.setInsertionPointToStart(
3176 &moduleOp.getRegion().getBlocks().front());
3177 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName, false);
3178
3179 rewriter.setInsertionPoint(tanhOp);
3180 Type v16bf16OpaqueTy =
3181 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
3182 auto opaquedOperand =
3183 UnrealizedConversionCastOp::create(
3184 rewriter, tanhOp.getLoc(), v16bf16OpaqueTy, adaptor.getOperand())
3185 .getResult(0);
3186 SmallVector<Value> tanhOperands = {opaquedOperand};
3187 auto callOp = emitc::CallOpaqueOp::create(rewriter, tanhOp.getLoc(),
3188 v16bf16OpaqueTy, "getTanhBf16",
3189 nullptr, nullptr, tanhOperands);
3190 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3191 tanhOp, TypeRange{tanhOp.getResult().getType()}, callOp.getResults());
3192
3193 return success();
3194 }
3195};
3196
3197// Convert math.sqrt to a function call to compute sqrt(x) for v16bfloat16 and
3198// v32bfloat16 types
3200 using OpConversionPattern::OpConversionPattern;
3201
3202 LogicalResult
3203 matchAndRewrite(math::SqrtOp sqrtOp, OpAdaptor adaptor,
3204 ConversionPatternRewriter &rewriter) const override {
3205 auto srcType = dyn_cast<VectorType>(sqrtOp.getOperand().getType());
3206 if (!srcType)
3207 return failure();
3208
3209 Type scalarType = srcType.getElementType();
3210 if (!isa<FloatType>(scalarType))
3211 return failure();
3212
3213 unsigned laneSize = getVectorLaneSize(srcType);
3214 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3215 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3216 return failure();
3217
3218 StringRef includeName = "vec_math.h";
3219 auto moduleOp = sqrtOp->getParentOfType<mlir::ModuleOp>();
3220 rewriter.setInsertionPointToStart(
3221 &moduleOp.getRegion().getBlocks().front());
3222 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName, false);
3223
3224 rewriter.setInsertionPoint(sqrtOp);
3225 Type vLNbf16OpaqueTy;
3226 if (laneSize == 16)
3227 vLNbf16OpaqueTy =
3228 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
3229 else
3230 vLNbf16OpaqueTy =
3231 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
3232 auto opaquedOperand =
3233 UnrealizedConversionCastOp::create(
3234 rewriter, sqrtOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
3235 .getResult(0);
3236 SmallVector<Value> sqrtOperands = {opaquedOperand};
3237 auto callOp = emitc::CallOpaqueOp::create(
3238 rewriter, sqrtOp.getLoc(), TypeRange{vLNbf16OpaqueTy}, "getSqrtBf16",
3239 nullptr, nullptr, sqrtOperands);
3240 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3241 sqrtOp, TypeRange{sqrtOp.getResult().getType()}, callOp.getResults());
3242
3243 return success();
3244 }
3245};
3246
3248 using OpConversionPattern::OpConversionPattern;
3249
3250 LogicalResult
3251 matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor,
3252 ConversionPatternRewriter &rewriter) const override {
3253 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
3254 if (!srcType)
3255 return failure();
3256
3257 Type scalarType = srcType.getElementType();
3258 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3259 unsigned laneSize = getVectorLaneSize(srcType);
3260
3261 // Only support v16bf16 for LLVM backend
3262 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 16)
3263 return failure();
3264
3265 StringRef funcName = "getRsqrtBf16";
3266
3267 VectorType v16bf16Ty = mlir::VectorType::get({16}, rewriter.getBF16Type());
3268 VectorType v8i64Ty = mlir::VectorType::get({8}, rewriter.getI64Type());
3269 func::FuncOp fnOp = getOrInsertFuncDecl(
3270 rewriter, rsqrtOp->getParentWithTrait<OpTrait::SymbolTable>(), funcName,
3271 TypeRange{v16bf16Ty}, TypeRange{v8i64Ty});
3272
3273 SmallVector<Value> rsqrtOperands = {adaptor.getOperand()};
3274
3275 Type accTypeNative = getVectorOpDestType(srcType, /*AIE2 =*/true);
3276 auto callOp =
3277 func::CallOp::create(rewriter, rsqrtOp.getLoc(), fnOp, rsqrtOperands);
3278 auto resCastOp = vector::BitCastOp::create(
3279 rewriter, rsqrtOp.getLoc(), accTypeNative, callOp.getResults());
3280 auto shiftParamOp = arith::ConstantOp::create(
3281 rewriter, rsqrtOp.getLoc(), rewriter.getI32IntegerAttr(0));
3282 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
3283 rsqrtOp, srcType, resCastOp.getResult(), shiftParamOp.getResult());
3284
3285 return success();
3286 }
3287};
3288
3289// Convert math.rsqrt to a function call to compute 1.0f / sqrt(x) for
3290// v16bfloat16 and v32bfloat16 types
3292 using OpConversionPattern::OpConversionPattern;
3293
3294 LogicalResult
3295 matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor,
3296 ConversionPatternRewriter &rewriter) const override {
3297 auto srcType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
3298 if (!srcType)
3299 return failure();
3300
3301 Type scalarType = srcType.getElementType();
3302 if (!isa<FloatType>(scalarType))
3303 return failure();
3304
3305 unsigned laneSize = getVectorLaneSize(srcType);
3306 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3307 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3308 return failure();
3309
3310 StringRef includeName = "vec_math.h";
3311 auto moduleOp = rsqrtOp->getParentOfType<mlir::ModuleOp>();
3312 rewriter.setInsertionPointToStart(
3313 &moduleOp.getRegion().getBlocks().front());
3314 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName, false);
3315
3316 rewriter.setInsertionPoint(rsqrtOp);
3317 Type vLNbf16OpaqueTy;
3318 if (laneSize == 16)
3319 vLNbf16OpaqueTy =
3320 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
3321 else
3322 vLNbf16OpaqueTy =
3323 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
3324 auto opaquedOperand =
3325 UnrealizedConversionCastOp::create(
3326 rewriter, rsqrtOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
3327 .getResult(0);
3328 SmallVector<Value> rsqrtOperands = {opaquedOperand};
3329 auto callOp = emitc::CallOpaqueOp::create(
3330 rewriter, rsqrtOp.getLoc(), TypeRange{vLNbf16OpaqueTy}, "getRsqrtBf16",
3331 nullptr, nullptr, rsqrtOperands);
3332 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3333 rsqrtOp, TypeRange{rsqrtOp.getResult().getType()}, callOp.getResults());
3334
3335 return success();
3336 }
3337};
3338
3339// Convert math.erf to a function call to compute erf(x) for v16bfloat16 and
3340// v32bfloat16 types
3342 using OpConversionPattern::OpConversionPattern;
3343
3344 LogicalResult
3345 matchAndRewrite(math::ErfOp erfOp, OpAdaptor adaptor,
3346 ConversionPatternRewriter &rewriter) const override {
3347 auto srcType = dyn_cast<VectorType>(erfOp.getOperand().getType());
3348 if (!srcType)
3349 return failure();
3350
3351 Type scalarType = srcType.getElementType();
3352 if (!isa<FloatType>(scalarType))
3353 return failure();
3354
3355 unsigned laneSize = getVectorLaneSize(srcType);
3356 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3357 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3358 return failure();
3359
3360 StringRef includeName = "vec_math.h";
3361 auto moduleOp = erfOp->getParentOfType<mlir::ModuleOp>();
3362 rewriter.setInsertionPointToStart(
3363 &moduleOp.getRegion().getBlocks().front());
3364 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName, false);
3365
3366 rewriter.setInsertionPoint(erfOp);
3367 Type vLNbf16OpaqueTy;
3368 if (laneSize == 16)
3369 vLNbf16OpaqueTy =
3370 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
3371 else
3372 vLNbf16OpaqueTy =
3373 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
3374 auto opaquedOperand =
3375 UnrealizedConversionCastOp::create(
3376 rewriter, erfOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
3377 .getResult(0);
3378 SmallVector<Value> erfOperands = {opaquedOperand};
3379 auto callOp = emitc::CallOpaqueOp::create(
3380 rewriter, erfOp.getLoc(), TypeRange{vLNbf16OpaqueTy}, "getErfBf16",
3381 nullptr, nullptr, erfOperands);
3382 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3383 erfOp, TypeRange{erfOp.getResult().getType()}, callOp.getResults());
3384
3385 return success();
3386 }
3387};
3388
3389// Convert math.absf and math.absi to a function call to compute abs(x) for
3390// v16bfloat16, v32bfloat16, v16float, v16int32, v32int16 and v64int8 types
3391template <typename SrcOpTy>
3394 using OpAdaptor = typename SrcOpTy::Adaptor;
3395
3396 LogicalResult
3397 matchAndRewrite(SrcOpTy absOp, OpAdaptor adaptor,
3398 ConversionPatternRewriter &rewriter) const override {
3399 auto vecTy = dyn_cast<VectorType>(absOp.getOperand().getType());
3400 if (!vecTy)
3401 return failure();
3402
3403 Type elemTy = vecTy.getElementType();
3404
3405 unsigned laneSize = getVectorLaneSize(vecTy);
3406 unsigned elWidth = elemTy.getIntOrFloatBitWidth();
3407
3408 StringRef includeName = "vec_math.h";
3409 auto moduleOp = absOp->template getParentOfType<mlir::ModuleOp>();
3410 rewriter.setInsertionPointToStart(
3411 &moduleOp.getRegion().getBlocks().front());
3412 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName, false);
3413
3414 rewriter.setInsertionPoint(absOp);
3415 std::ostringstream typeName;
3416 typeName << "v" << laneSize;
3417 if (isa<FloatType>(elemTy)) {
3418 if (elWidth == 16)
3419 typeName << "bfloat16";
3420 else
3421 typeName << "float";
3422 } else
3423 typeName << "int" << elWidth;
3424 Type vecOpaqueTy =
3425 emitc::OpaqueType::get(rewriter.getContext(), typeName.str());
3426 auto opaquedOperand =
3427 UnrealizedConversionCastOp::create(rewriter, absOp.getLoc(),
3428 vecOpaqueTy, adaptor.getOperand())
3429 .getResult(0);
3430 SmallVector<Value> absOperands = {opaquedOperand};
3431 auto callOp = emitc::CallOpaqueOp::create(rewriter, absOp.getLoc(),
3432 TypeRange{vecOpaqueTy}, "getAbs",
3433 nullptr, nullptr, absOperands);
3434 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3435 absOp, TypeRange{absOp.getResult().getType()}, callOp.getResults());
3436
3437 return success();
3438 }
3439};
3440
3443
3444template <typename SrcOpTy>
3447 using OpAdaptor = typename SrcOpTy::Adaptor;
3448
3449 LogicalResult
3450 matchAndRewrite(SrcOpTy extOp, OpAdaptor adaptor,
3451 ConversionPatternRewriter &rewriter) const override {
3452 VectorType srcType = dyn_cast<VectorType>(extOp.getIn().getType());
3453 VectorType dstType = dyn_cast<VectorType>(extOp.getOut().getType());
3454
3455 Type scalarType = dstType.getElementType();
3456 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3457 auto accType =
3458 isa<IntegerType>(scalarType) && (elWidth == 32 || elWidth == 64)
3459 ? dstType
3460 : getVectorOpDestType(srcType, /*AIE2 =*/true);
3461 auto upsOp =
3462 aievec::UPSOp::create(rewriter, extOp.getLoc(), accType, extOp.getIn());
3463
3464 if (dstType.getElementType().getIntOrFloatBitWidth() == 16) {
3465 auto shiftParamOp = arith::ConstantOp::create(
3466 rewriter, extOp.getLoc(), rewriter.getI32IntegerAttr(0));
3467 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
3468 extOp, dstType, upsOp.getResult(), shiftParamOp.getResult());
3469 } else
3470 rewriter.replaceOpWithNewOp<aievec::CastOp>(
3471 extOp, dstType, upsOp.getResult(), /*isResAcc*/ false);
3472
3473 return success();
3474 }
3475};
3476
3479
3480template <typename SrcOpTy>
3483 using OpAdaptor = typename SrcOpTy::Adaptor;
3484
3485 LogicalResult
3486 matchAndRewrite(SrcOpTy truncOp, OpAdaptor adaptor,
3487 ConversionPatternRewriter &rewriter) const override {
3488 VectorType srcType = dyn_cast<VectorType>(truncOp.getIn().getType());
3489 VectorType dstType = dyn_cast<VectorType>(truncOp.getOut().getType());
3490 Type scalarType = srcType.getElementType();
3491 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3492 auto accType =
3493 isa<IntegerType>(scalarType) && (elWidth == 32 || elWidth == 64)
3494 ? srcType
3495 : getVectorOpDestType(srcType, /*AIE2 =*/true);
3496
3497 auto shiftParamOp = arith::ConstantOp::create(
3498 rewriter, truncOp.getLoc(), rewriter.getI32IntegerAttr(0));
3499 if (elWidth == 16) {
3500 auto upsOp = aievec::UPSOp::create(rewriter, truncOp.getLoc(), accType,
3501 truncOp.getIn());
3502 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
3503 truncOp, dstType, upsOp.getResult(), shiftParamOp.getResult());
3504 } else {
3505 auto castOp = aievec::CastOp::create(rewriter, truncOp.getLoc(), accType,
3506 truncOp.getIn(), true);
3507 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
3508 truncOp, dstType, castOp.getResult(), shiftParamOp.getResult());
3509 }
3510
3511 return success();
3512 }
3513};
3514
3517
3518// If `op` is the last operation in the sequence:
3519// %0 = unrealized_conversion_cast <%IN> : <native type>, !emitc.opaque_type
3520// %1 = emitc.call_opaque <funcName>, %0...
3521// %2 = unrealized_conversion_cast %1 : !emitc.opaque_type, <native type>
3522// return the value <%IN>.
3523static std::optional<Value>
3524getUnOpaquedOperandOfEmitCOpaqueCallOp(Operation *op, StringRef funcName) {
3525 auto uccOp = dyn_cast<UnrealizedConversionCastOp>(op);
3526 if (!uccOp)
3527 return {};
3528
3529 auto inVal = uccOp.getInputs()[0];
3530 if (!isa<emitc::OpaqueType>(inVal.getType()))
3531 return {};
3532
3533 auto callOp = inVal.getDefiningOp<emitc::CallOpaqueOp>();
3534 if (callOp.getCallee() != funcName)
3535 return {};
3536
3537 auto callOperandsUccOp =
3538 callOp.getOperands()[0].getDefiningOp<UnrealizedConversionCastOp>();
3539 if (!callOperandsUccOp)
3540 return {};
3541
3542 return callOperandsUccOp.getInputs()[0];
3543}
3544
3545// Check there is an operation chain like-
3546//
3547// %cst_0 = arith.constant dense<1.000000e+00> : vector<16xbf16>
3548// %cst_1 = arith.constant 0.000000e+00 : bf16
3549// %0 = vector.transfer_read %arg0[%arg2], %cst_1 : memref<1024xbf16>,
3550// vector<16xbf16>
3551// %1 = arith.negf %0 : vector<16xbf16>
3552// %2 = math.exp %1 : vector<16xbf16>
3553// %3 = arith.addf %2, %cst_0 : vector<16xbf16>
3554// %4 = arith.divf %cst_0, %3 : vector<16xbf16>
3555//
3556// so that this operation chain can be converted to a function call to compute
3557// sigmoid value for v16bfloat16 and v32bfloat16 types
3558template <typename DivFOpTy>
3559static bool hasSigmoidComputationChain(DivFOpTy divfOp, arith::NegFOp &negOp) {
3560 auto *lhsDefOp = divfOp.getLhs().getDefiningOp();
3561 if (!lhsDefOp)
3562 return false;
3563 auto constOp = dyn_cast<arith::ConstantOp>(lhsDefOp);
3564 if (!constOp)
3565 return false;
3566
3567 auto cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3568 if (!cstDense)
3569 return false;
3570
3571 if (cstDense.template getSplatValue<APFloat>().convertToFloat() != 1.0f)
3572 return false;
3573
3574 Operation *addLvalOp;
3575 Operation *addRvalOp;
3576 // divfOp's rval could be an arith::AddFOp or the pattern like-
3577 // %1 = aievec.ups %a
3578 // %2 = aievec.ups %b;
3579 // %3 = aievec.add_elem %1, %2
3580 // %4 = aievec.srs %3;
3581 auto *rhsDefOp = divfOp.getRhs().getDefiningOp();
3582 if (!rhsDefOp)
3583 return false;
3584 auto addOp = dyn_cast<arith::AddFOp>(rhsDefOp);
3585 if (!addOp) {
3586 auto srsOp = dyn_cast<aievec::SRSOp>(rhsDefOp);
3587 if (!srsOp)
3588 return false;
3589
3590 auto addElemOp =
3591 dyn_cast<aievec::AddElemOp>(srsOp.getSource().getDefiningOp());
3592 if (!addElemOp)
3593 return false;
3594
3595 auto lUpsOp = dyn_cast<aievec::UPSOp>(addElemOp.getLhs().getDefiningOp());
3596 auto rUpsOp = dyn_cast<aievec::UPSOp>(addElemOp.getRhs().getDefiningOp());
3597 if (!lUpsOp || !rUpsOp)
3598 return false;
3599
3600 addLvalOp = lUpsOp.getSource().getDefiningOp();
3601 addRvalOp = rUpsOp.getSource().getDefiningOp();
3602 // One of add operation's operand is a constant op and another operand could
3603 // be arith::ExpOp or the combination of emitc.call and aievec.srs
3604 auto addDefOp = isa<arith::ConstantOp>(addLvalOp)
3605 ? dyn_cast<aievec::SRSOp>(addRvalOp)
3606 : dyn_cast<aievec::SRSOp>(addLvalOp);
3607 if (!addDefOp)
3608 addLvalOp = isa<arith::ConstantOp>(addLvalOp)
3609 ? dyn_cast<math::ExpOp>(addRvalOp)
3610 : dyn_cast<math::ExpOp>(addLvalOp);
3611 else
3612 addLvalOp = addDefOp.getSource().getDefiningOp();
3613
3614 addRvalOp = isa<arith::ConstantOp>(addLvalOp)
3615 ? lUpsOp.getSource().getDefiningOp()
3616 : rUpsOp.getSource().getDefiningOp();
3617 } else {
3618 addLvalOp = addOp.getLhs().getDefiningOp();
3619 addRvalOp = addOp.getRhs().getDefiningOp();
3620 }
3621
3622 if (!addLvalOp || !addRvalOp)
3623 return false;
3624
3625 auto addLvalExpOp = dyn_cast<math::ExpOp>(addLvalOp);
3626 auto addRvalExpOp = dyn_cast<math::ExpOp>(addRvalOp);
3627 auto addLvalExpOpIn =
3628 getUnOpaquedOperandOfEmitCOpaqueCallOp(addLvalOp, "getExpBf16")
3629 .value_or(nullptr);
3630 auto addRvalExpOpIn =
3631 getUnOpaquedOperandOfEmitCOpaqueCallOp(addRvalOp, "getExpBf16")
3632 .value_or(nullptr);
3633 if (!addLvalExpOpIn && addLvalExpOp)
3634 addLvalExpOpIn = addLvalExpOp.getOperand();
3635 if (!addRvalExpOpIn && addRvalExpOp)
3636 addRvalExpOpIn = addRvalExpOp.getOperand();
3637
3638 if (!((addLvalExpOpIn && isa<arith::ConstantOp>(addRvalOp)) ||
3639 (addRvalExpOpIn && isa<arith::ConstantOp>(addLvalOp))))
3640 return false;
3641
3642 constOp = isa<arith::ConstantOp>(addLvalOp)
3643 ? cast<arith::ConstantOp>(addLvalOp)
3644 : cast<arith::ConstantOp>(addRvalOp);
3645
3646 cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3647 if (!cstDense)
3648 return false;
3649 if (cstDense.template getSplatValue<APFloat>().convertToFloat() != 1.0f)
3650 return false;
3651
3652 auto expOperand = addLvalExpOpIn ? addLvalExpOpIn : addRvalExpOpIn;
3653
3654 negOp = expOperand.getDefiningOp<arith::NegFOp>();
3655
3656 return negOp != nullptr;
3657}
3658
3659// Convert the operation chain like-
3660//
3661// %cst_0 = arith.constant dense<1.000000e+00> : vector<16xbf16>
3662// %cst_1 = arith.constant 0.000000e+00 : bf16
3663// %0 = vector.transfer_read %arg0[%arg2], %cst_1 : memref<1024xbf16>,
3664// vector<16xbf16>
3665// %1 = arith.negf %0 : vector<16xbf16>
3666// %2 = math.exp %1 :vector<16xbf16>
3667// %3 = arith.addf %2, %cst_0 : vector<16xbf16>
3668// %4 = arith.divf %cst_0, %3 : vector<16xbf16>
3669//
3670// to a function call to compute sigmoid value for v16bfloat16 and
3671// v32bfloat16 types
3673 using OpConversionPattern::OpConversionPattern;
3674
3675 LogicalResult
3676 matchAndRewrite(arith::DivFOp divfOp, OpAdaptor adaptor,
3677 ConversionPatternRewriter &rewriter) const override {
3678 auto srcType = dyn_cast<VectorType>(adaptor.getLhs().getType());
3679 if (!srcType)
3680 return failure();
3681
3682 Type scalarType = srcType.getElementType();
3683 if (!isa<FloatType>(scalarType))
3684 return failure();
3685
3686 unsigned laneSize = getVectorLaneSize(srcType);
3687 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3688 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3689 return failure();
3690
3691 arith::NegFOp negOp = nullptr;
3692 if (!hasSigmoidComputationChain(adaptor, negOp))
3693 return failure();
3694
3695 StringRef includeName = "vec_math.h";
3696 auto moduleOp = divfOp->getParentOfType<mlir::ModuleOp>();
3697 rewriter.setInsertionPointToStart(
3698 &moduleOp.getRegion().getBlocks().front());
3699 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName, false);
3700
3701 rewriter.setInsertionPoint(divfOp);
3702 Type vecOpaqueTy;
3703 if (laneSize == 16)
3704 vecOpaqueTy =
3705 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
3706 else
3707 vecOpaqueTy =
3708 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
3709 auto opaquedOperand =
3710 UnrealizedConversionCastOp::create(rewriter, divfOp.getLoc(),
3711 vecOpaqueTy, negOp.getOperand())
3712 .getResult(0);
3713 SmallVector<Value> sigmoidOperands = {opaquedOperand};
3714 auto callOp = emitc::CallOpaqueOp::create(
3715 rewriter, divfOp.getLoc(), TypeRange{vecOpaqueTy}, "getSigmoidBf16",
3716 nullptr, nullptr, sigmoidOperands);
3717 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3718 divfOp, TypeRange{adaptor.getLhs().getType()}, callOp.getResults());
3719
3720 return success();
3721 }
3722};
3723
3724// Convert math.ceil to a function call to compute ceil(x) for v16bfloat16
3726 using OpConversionPattern::OpConversionPattern;
3727
3728 LogicalResult
3729 matchAndRewrite(math::CeilOp ceilOp, OpAdaptor adaptor,
3730 ConversionPatternRewriter &rewriter) const override {
3731 auto srcType = dyn_cast<VectorType>(ceilOp.getOperand().getType());
3732 if (!srcType)
3733 return failure();
3734
3735 Type scalarType = srcType.getElementType();
3736 if (!isa<FloatType>(scalarType))
3737 return failure();
3738
3739 unsigned laneSize = getVectorLaneSize(srcType);
3740 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3741 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3742 return failure();
3743
3744 StringRef includeName = "vec_math.h";
3745 auto moduleOp = ceilOp->getParentOfType<mlir::ModuleOp>();
3746 rewriter.setInsertionPointToStart(
3747 &moduleOp.getRegion().getBlocks().front());
3748 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName, false);
3749
3750 rewriter.setInsertionPoint(ceilOp);
3751 Type vecOpaqueTy;
3752 if (laneSize == 16)
3753 vecOpaqueTy =
3754 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
3755 else
3756 vecOpaqueTy =
3757 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
3758 auto opaquedOperand =
3759 UnrealizedConversionCastOp::create(rewriter, ceilOp.getLoc(),
3760 vecOpaqueTy, adaptor.getOperand())
3761 .getResult(0);
3762 SmallVector<Value> ceilOperands = {opaquedOperand};
3763 auto callOp = emitc::CallOpaqueOp::create(
3764 rewriter, ceilOp.getLoc(), TypeRange{vecOpaqueTy}, "getCeilBf16",
3765 nullptr, nullptr, ceilOperands);
3766 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3767 ceilOp, TypeRange{ceilOp.getResult().getType()}, callOp.getResults());
3768
3769 return success();
3770 }
3771};
3772
3773// Convert math.floor to a function call to compute floor(x) for v16bfloat16
3775 using OpConversionPattern::OpConversionPattern;
3776
3777 LogicalResult
3778 matchAndRewrite(math::FloorOp floorOp, OpAdaptor adaptor,
3779 ConversionPatternRewriter &rewriter) const override {
3780 auto srcType = dyn_cast<VectorType>(floorOp.getOperand().getType());
3781 if (!srcType)
3782 return failure();
3783
3784 Type scalarType = srcType.getElementType();
3785 if (!isa<FloatType>(scalarType))
3786 return failure();
3787
3788 unsigned laneSize = getVectorLaneSize(srcType);
3789 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3790 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3791 return failure();
3792
3793 StringRef includeName = "vec_math.h";
3794 auto moduleOp = floorOp->getParentOfType<mlir::ModuleOp>();
3795 rewriter.setInsertionPointToStart(
3796 &moduleOp.getRegion().getBlocks().front());
3797 emitc::IncludeOp::create(rewriter, moduleOp.getLoc(), includeName, false);
3798
3799 rewriter.setInsertionPoint(floorOp);
3800 Type vecOpaqueTy;
3801 if (laneSize == 16)
3802 vecOpaqueTy =
3803 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
3804 else
3805 vecOpaqueTy =
3806 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
3807 auto opaquedOperand =
3808 UnrealizedConversionCastOp::create(rewriter, floorOp.getLoc(),
3809 vecOpaqueTy, adaptor.getOperand())
3810 .getResult(0);
3811 SmallVector<Value> floorOperands = {opaquedOperand};
3812 auto callOp = emitc::CallOpaqueOp::create(
3813 rewriter, floorOp.getLoc(), TypeRange{vecOpaqueTy}, "getFloorBf16",
3814 nullptr, nullptr, floorOperands);
3815 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
3816 floorOp, TypeRange{floorOp.getResult().getType()}, callOp.getResults());
3817
3818 return success();
3819 }
3820};
3821
3822// Convert arith.negf to aievec.neg to negate the vector for v16bfloat16 and
3823// v16float types.
3825 using OpConversionPattern::OpConversionPattern;
3826
  LogicalResult
  matchAndRewrite(arith::NegFOp negOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only 16-lane float vectors are handled (v16bfloat16 / v16float per the
    // pattern's header comment).
    auto srcType = dyn_cast<VectorType>(negOp.getOperand().getType());
    if (!srcType)
      return failure();

    Type scalarType = srcType.getElementType();
    if (!isa<FloatType>(scalarType))
      return failure();

    if (unsigned laneSize = getVectorLaneSize(srcType); laneSize != 16)
      return failure();

    Location loc = negOp.getLoc();
    // aievec.neg operates on the accumulator type, so the input is moved
    // into an accumulator first.
    auto accType = getVectorOpDestType(srcType, /*AIE2 =*/true);

    unsigned elWidth = scalarType.getIntOrFloatBitWidth();
    if (elWidth == 16) {
      // 16-bit float path: widen into the accumulator (ups), negate there,
      // then shift-round-saturate back to the vector type with a zero shift.
      auto upsOp =
          aievec::UPSOp::create(rewriter, loc, accType, adaptor.getOperand());
      auto aieNegOp =
          aievec::NegOp::create(rewriter, loc, accType, upsOp.getResult());
      auto shiftParamOp = arith::ConstantOp::create(
          rewriter, negOp.getLoc(), rewriter.getI32IntegerAttr(0));
      rewriter.replaceOpWithNewOp<aievec::SRSOp>(
          negOp, srcType, aieNegOp.getResult(), shiftParamOp.getResult());
    } else {
      // Wider float path: reinterpret as an accumulator (cast with
      // isResAcc=true), negate, and cast back without any widening.
      auto castOp = aievec::CastOp::create(
          rewriter, loc, accType, adaptor.getOperand(), /*isResAcc*/ true);
      auto aieNegOp =
          aievec::NegOp::create(rewriter, loc, accType, castOp.getResult());
      rewriter.replaceOpWithNewOp<aievec::CastOp>(
          negOp, srcType, aieNegOp.getResult(), /*isResAcc*/ false);
    }

    return success();
  }
3865};
3866
3867// Check whether the value of constant operation is int type and the dense value
3868// is -1.
3869static bool hasConstNegOneValue(arith::ConstantOp constOp, unsigned elWidth) {
3870 if (!constOp)
3871 return false;
3872
3873 auto cstDense = dyn_cast<DenseIntElementsAttr>(constOp.getValue());
3874 if (!cstDense)
3875 return false;
3876
3877 if (elWidth == 32)
3878 return cstDense.getSplatValue<int32_t>() == -1;
3879 if (elWidth == 16)
3880 return cstDense.getSplatValue<int16_t>() == -1;
3881 if (elWidth == 8)
3882 return cstDense.getSplatValue<int8_t>() == -1;
3883 return false;
3884}
3885
3886// Convert arith.xori to aievec.bxor to compute bitwise xor of two vectors for
3887// integer types
3889 using OpConversionPattern::OpConversionPattern;
3890
3891 LogicalResult
3892 matchAndRewrite(arith::XOrIOp xorOp, OpAdaptor adaptor,
3893 ConversionPatternRewriter &rewriter) const override {
3894 auto srcType = dyn_cast<VectorType>(xorOp.getLhs().getType());
3895 if (!srcType)
3896 return failure();
3897
3898 Type scalarType = srcType.getElementType();
3899 if (!isa<IntegerType>(scalarType))
3900 return failure();
3901
3902 unsigned laneSize = getVectorLaneSize(srcType);
3903 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3904 if (laneSize * elWidth != 512)
3905 return failure();
3906
3907 auto lhsConstOp =
3908 dyn_cast<arith::ConstantOp>(xorOp.getLhs().getDefiningOp());
3909 auto rhsConstOp =
3910 dyn_cast<arith::ConstantOp>(xorOp.getRhs().getDefiningOp());
3911
3912 // If one of operands in xorOp is a constant -1, xorOp will be replaced with
3913 // aievec::BnegOp.
3914 if ((lhsConstOp && hasConstNegOneValue(lhsConstOp, elWidth)) ||
3915 (rhsConstOp && hasConstNegOneValue(rhsConstOp, elWidth))) {
3916 Value val = hasConstNegOneValue(lhsConstOp, elWidth) ? adaptor.getRhs()
3917 : adaptor.getLhs();
3918 rewriter.replaceOpWithNewOp<aievec::BnegOp>(xorOp, srcType, val);
3919 } else
3920 rewriter.replaceOpWithNewOp<aievec::BxorOp>(
3921 xorOp, srcType, adaptor.getLhs(), adaptor.getRhs());
3922
3923 return success();
3924 }
3925};
3926
3927template <typename SrcOpTy, typename DstOpTy>
3930 using OpAdaptor = typename SrcOpTy::Adaptor;
3931
3932 LogicalResult
3933 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
3934 ConversionPatternRewriter &rewriter) const override {
3935 VectorType srcType = dyn_cast<VectorType>(srcOp.getLhs().getType());
3936 if (!srcType)
3937 return failure();
3938
3939 Type scalarType = srcType.getElementType();
3940 if (!isa<IntegerType>(scalarType))
3941 return failure();
3942
3943 unsigned laneSize = getVectorLaneSize(srcType);
3944 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3945 if (laneSize * elWidth != 512)
3946 return failure();
3947
3948 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getResult().getType(),
3949 adaptor.getLhs(), adaptor.getRhs());
3950
3951 return success();
3952 }
3953};
3954
3959
3960// Convert arith.shrsi to a combination of aievec.ups and aievec.srs to compute
3961// arithmetic right shift for integer types. Currently, only support the shift
3962// value with a broadcast vector.
3964 : OpConversionPattern<arith::ShRSIOp> {
3965 using OpConversionPattern::OpConversionPattern;
3966
3967 LogicalResult
3968 matchAndRewrite(arith::ShRSIOp rsOp, OpAdaptor adaptor,
3969 ConversionPatternRewriter &rewriter) const override {
3970 auto srcType = dyn_cast<VectorType>(adaptor.getLhs().getType());
3971 if (!srcType)
3972 return failure();
3973
3974 Type scalarType = srcType.getElementType();
3975 unsigned laneSize = getVectorLaneSize(srcType);
3976 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3977 if (laneSize * elWidth != 512)
3978 return failure();
3979
3980 auto bcastOp =
3981 dyn_cast<aievec::BroadcastOp>(adaptor.getRhs().getDefiningOp());
3982 if (!bcastOp)
3983 return failure();
3984
3985 auto constOp =
3986 arith::ConstantOp::create(rewriter, bcastOp.getLoc(),
3987 rewriter.getI32IntegerAttr(bcastOp.getIdx()));
3988 auto extElemOp = aievec::ExtElemOp::create(
3989 rewriter, bcastOp.getLoc(), scalarType, bcastOp, constOp.getResult());
3990 Location loc = rsOp.getLoc();
3991
3992 // The vector with v64int8 type can be divided into two v32int8 vectors and
3993 // be processed individually and be concatenated at the end.
3994 if (elWidth == 8) {
3995 VectorType halfSrcType = createVectorType(laneSize / 2, scalarType);
3996 auto rsOpLow = aievec::ExtOp::create(rewriter, loc, halfSrcType,
3997 adaptor.getLhs(), 0);
3998 auto rsOpHigh = aievec::ExtOp::create(rewriter, loc, halfSrcType,
3999 adaptor.getLhs(), 1);
4000 Type accType = getVectorOpDestType(halfSrcType, /*AIE2 =*/true);
4001 auto upsOpLow =
4002 aievec::UPSOp::create(rewriter, loc, accType, rsOpLow.getResult());
4003 auto srsOpLow =
4004 aievec::SRSOp::create(rewriter, loc, halfSrcType,
4005 upsOpLow.getResult(), extElemOp.getResult());
4006 auto upsOpHigh =
4007 aievec::UPSOp::create(rewriter, loc, accType, rsOpHigh.getResult());
4008 auto srsOpHigh =
4009 aievec::SRSOp::create(rewriter, loc, halfSrcType,
4010 upsOpHigh.getResult(), extElemOp.getResult());
4011 SmallVector<Value> inputSources = {srsOpLow.getResult(),
4012 srsOpHigh.getResult()};
4013 rewriter.replaceOpWithNewOp<aievec::ConcatOp>(rsOp, srcType,
4014 inputSources);
4015 } else {
4016 Type accType = getVectorOpDestType(srcType, /*AIE2 =*/true);
4017 auto upsOp =
4018 aievec::UPSOp::create(rewriter, loc, accType, adaptor.getLhs());
4019 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
4020 rsOp, srcType, upsOp.getResult(), extElemOp.getResult());
4021 }
4022
4023 return success();
4024 }
4025};
4026
4027// Recognize the compound shift+clamp+truncate pattern and lower it to
4028// aievec.ups + aievec.srs. This maps the standard quantized neural network
4029// "shift-round-saturate + clamp" idiom to the AIE2 SRS hardware unit.
4030//
4031// Pattern A (with clamp):
4032// %shifted = arith.shrsi %wide, %shift_splat
4033// %clamped0 = arith.maxsi %shifted, %c_lo
4034// %clamped = arith.minsi %clamped0, %c_hi
4035// %result = arith.trunci %clamped
4036//
4037// Pattern B (without clamp):
4038// %shifted = arith.shrsi %wide, %shift_splat
4039// %result = arith.trunci %shifted
4040//
4041// Both lower to: aievec.ups + aievec.srs
4043 using OpConversionPattern::OpConversionPattern;
4044
  // Benefit 2 (instead of the default 1) so this compound pattern is tried
  // before simpler single-op patterns that could otherwise match the same
  // trunci first.
  ShiftClampTruncToSRSPattern(MLIRContext *context, PatternBenefit benefit = 2)
      : OpConversionPattern(context, benefit) {}
4047
4048 // Try to extract a scalar integer splat value from a Value.
4049 // Returns std::nullopt if the value is not a constant splat.
4050 static std::optional<int64_t> getConstantSplatValue(Value val) {
4051 auto defOp = val.getDefiningOp<arith::ConstantOp>();
4052 if (!defOp)
4053 return std::nullopt;
4054 auto denseAttr = dyn_cast<DenseIntElementsAttr>(defOp.getValue());
4055 if (!denseAttr || !denseAttr.isSplat())
4056 return std::nullopt;
4057 return denseAttr.getSplatValue<APInt>().getSExtValue();
4058 }
4059
4060 // Try to extract the shift amount from the right operand of shrsi.
4061 // Accepts either a constant splat vector or an aievec.broadcast of a
4062 // constant.
4063 static std::optional<Value>
4064 getShiftValue(Value rhs, ConversionPatternRewriter &rewriter, Location loc) {
4065 // Case 1: constant splat vector
4066 if (auto constOp = rhs.getDefiningOp<arith::ConstantOp>()) {
4067 auto denseAttr = dyn_cast<DenseIntElementsAttr>(constOp.getValue());
4068 if (denseAttr && denseAttr.isSplat()) {
4069 int64_t shiftVal = denseAttr.getSplatValue<APInt>().getSExtValue();
4070 return arith::ConstantOp::create(rewriter, loc,
4071 rewriter.getI32IntegerAttr(shiftVal))
4072 .getResult();
4073 }
4074 }
4075 // Case 2: aievec.broadcast
4076 if (auto bcastOp = dyn_cast<aievec::BroadcastOp>(rhs.getDefiningOp())) {
4077 auto constOp = arith::ConstantOp::create(
4078 rewriter, bcastOp.getLoc(),
4079 rewriter.getI32IntegerAttr(bcastOp.getIdx()));
4080 return aievec::ExtElemOp::create(rewriter, bcastOp.getLoc(),
4081 rewriter.getI32Type(), bcastOp,
4082 constOp.getResult())
4083 .getResult();
4084 }
4085 return std::nullopt;
4086 }
4087
  // Anchor on the vector trunci; the shrsi (and optional minsi/maxsi clamp)
  // feeding it is matched by walking the defining ops of its input.
  LogicalResult
  matchAndRewrite(arith::TruncIOp truncOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto dstType = dyn_cast<VectorType>(truncOp.getOut().getType());
    if (!dstType)
      return failure();

    Type dstScalarType = dstType.getElementType();
    if (!isa<IntegerType>(dstScalarType))
      return failure();

    // Walk backward through optional clamp chain
    Value source = adaptor.getIn();
    int32_t sign = 1; // default: signed

    // Check for minsi(maxsi(...), hi) or maxsi(minsi(...), lo) clamp pattern
    // NOTE(review): only the shape where the clamped value is the LHS and the
    // bound is the RHS is recognized; bounds appearing on the LHS are not
    // matched — confirm upstream passes normalize operand order.
    arith::MinSIOp minOp = nullptr;
    arith::MaxSIOp maxOp = nullptr;

    if (auto minsiOp = source.getDefiningOp<arith::MinSIOp>()) {
      if (auto maxsiOp = minsiOp.getLhs().getDefiningOp<arith::MaxSIOp>()) {
        minOp = minsiOp;
        maxOp = maxsiOp;
        source = maxOp.getLhs();
      }
    } else if (auto maxsiOp = source.getDefiningOp<arith::MaxSIOp>()) {
      if (auto minsiOp = maxsiOp.getLhs().getDefiningOp<arith::MinSIOp>()) {
        maxOp = maxsiOp;
        minOp = minsiOp;
        source = minOp.getLhs();
      }
    }

    // If we found a clamp, verify it's a valid saturation range
    if (minOp && maxOp) {
      auto loVal = getConstantSplatValue(maxOp.getRhs());
      auto hiVal = getConstantSplatValue(minOp.getRhs());
      if (!loVal || !hiVal)
        return failure();

      unsigned dstBits = dstScalarType.getIntOrFloatBitWidth();
      // Guard against UB from shifting into or past the sign bit.
      if (dstBits == 0 || dstBits > 63)
        return failure();
      uint64_t one = 1ULL;
      // Expected clamp bounds for unsigned ([0, 2^n-1]) and signed
      // ([-2^(n-1), 2^(n-1)-1]) saturation to the destination width.
      int64_t unsignedLo = 0;
      int64_t unsignedHi = static_cast<int64_t>((one << dstBits) - 1);
      int64_t signedLo = -static_cast<int64_t>(one << (dstBits - 1));
      int64_t signedHi = static_cast<int64_t>((one << (dstBits - 1)) - 1);

      if (*loVal == unsignedLo && *hiVal == unsignedHi) {
        sign = 0; // unsigned saturation
      } else if (*loVal == signedLo && *hiVal == signedHi) {
        sign = 1; // signed saturation
      } else {
        // Clamp range doesn't match standard saturation — don't match
        return failure();
      }
    }

    // Now source should be the shrsi result
    auto shrsiOp = source.getDefiningOp<arith::ShRSIOp>();
    if (!shrsiOp)
      return failure();

    auto srcType = dyn_cast<VectorType>(shrsiOp.getLhs().getType());
    if (!srcType)
      return failure();

    Type srcScalarType = srcType.getElementType();
    if (!isa<IntegerType>(srcScalarType))
      return failure();

    // Only a genuine narrowing (e.g. i32 -> i8) maps onto SRS.
    unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
    unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
    if (dstElWidth >= srcElWidth)
      return failure();

    Location loc = truncOp.getLoc();

    // Extract the shift amount
    auto shiftVal = getShiftValue(shrsiOp.getRhs(), rewriter, loc);
    if (!shiftVal)
      return failure();

    // Get the wide input (pre-shift)
    Value wideInput = shrsiOp.getLhs();

    // SRS operates on 16-lane multiples; odd lane counts are zero-padded
    // up front and the extra lanes stripped off the result afterwards.
    unsigned laneSize = getVectorLaneSize(srcType);
    bool needsPadding = (laneSize % 16 != 0);

    VectorType paddedSrcType = srcType;
    VectorType paddedDstType = dstType;
    unsigned paddedLanes = laneSize;

    if (needsPadding) {
      // Round up to nearest multiple of 16
      paddedLanes = ((laneSize + 15) / 16) * 16;
      paddedSrcType = createVectorType(paddedLanes, srcScalarType);
      paddedDstType = createVectorType(paddedLanes, dstScalarType);

      // Zero-pad the input using insert_strided_slice
      auto zeroAttr = rewriter.getZeroAttr(paddedSrcType);
      auto zeroPad =
          arith::ConstantOp::create(rewriter, loc, zeroAttr).getResult();
      SmallVector<int64_t> offsets(1, 0);
      SmallVector<int64_t> strides(1, 1);
      wideInput = vector::InsertStridedSliceOp::create(
                      rewriter, loc, wideInput, zeroPad, offsets, strides)
                      .getResult();
    }

    // Determine accumulator type and create the accumulator value.
    // For i16 source: use UPS to widen to accumulator type.
    // For i32/i64 source: use CastOp (marks as accumulator without widening),
    // matching the approach in LowerTruncOpPattern.
    Type accScalarType = paddedSrcType.getElementType();
    unsigned accElWidth = accScalarType.getIntOrFloatBitWidth();
    Value accValue;
    if (accElWidth == 16) {
      Type accType = getVectorOpDestType(paddedSrcType, /*AIE2=*/true);
      accValue =
          aievec::UPSOp::create(rewriter, loc, accType, wideInput).getResult();
    } else {
      // For i32/i64: CastOp with isResAcc=true marks as accumulator
      accValue = aievec::CastOp::create(rewriter, loc, paddedSrcType, wideInput,
                                        /*isResAcc=*/true)
                     .getResult();
    }
    // The single SRS performs shift, rounding and (signed/unsigned)
    // saturation in one step.
    auto srsOp = aievec::SRSOp::create(rewriter, loc, paddedDstType, accValue,
                                       *shiftVal, sign);

    Value result = srsOp.getResult();

    if (needsPadding) {
      // Extract original lanes from the padded result
      SmallVector<int64_t> offsets(1, 0);
      SmallVector<int64_t> sizes = {static_cast<int64_t>(laneSize)};
      SmallVector<int64_t> strides(1, 1);
      result = vector::ExtractStridedSliceOp::create(rewriter, loc, result,
                                                     offsets, sizes, strides)
                   .getResult();
    }

    rewriter.replaceOp(truncOp, result);

    // Erase the intermediate clamp/shift ops if they have no other uses.
    // These must be cleaned up because shrsi on 512-bit vectors is marked
    // illegal and would cause conversion failure if left as dead ops.
    SmallVector<Operation *, 3> opsToErase;
    if (minOp && minOp->use_empty())
      opsToErase.push_back(minOp);
    if (maxOp && maxOp->use_empty())
      opsToErase.push_back(maxOp);
    if (shrsiOp->use_empty())
      opsToErase.push_back(shrsiOp);
    for (Operation *op : opsToErase)
      rewriter.eraseOp(op);

    return success();
  }
4249};
4250
4251// Promote scalar shrsi + [clamp] + trunci chain to a vectorized SRS sequence.
4252// Anchored on scalar arith::TruncIOp. Fuses the entire chain into:
4253// broadcast_scalar -> cast(isResAcc) -> srs(narrowed, shift, sign) ->
4254// ext_elem
4255// This prevents scalar arith.trunci i32->i8 from reaching the AIE2 backend
4256// where it crashes (the backend only supports 32->16/20/32 truncations).
4258 using OpConversionPattern::OpConversionPattern;
4259
4260 // Try to extract a scalar integer constant value from a Value.
4261 static std::optional<int64_t> getScalarConstantValue(Value val) {
4262 auto defOp = val.getDefiningOp<arith::ConstantOp>();
4263 if (!defOp)
4264 return std::nullopt;
4265 auto intAttr = dyn_cast<IntegerAttr>(defOp.getValue());
4266 if (!intAttr)
4267 return std::nullopt;
4268 return intAttr.getInt();
4269 }
4270
  /// Anchor: a *scalar* arith.trunci i32 -> i8/i16. Walks backward through an
  /// optional min/max saturation clamp and an optional arith.shrsi, then
  /// rewrites the whole chain as:
  ///   broadcast_scalar -> [concat] -> cast(isResAcc) -> srs -> ext_elem
  /// Fails (leaving the op untouched) on any shape it does not recognize.
  LogicalResult
  matchAndRewrite(arith::TruncIOp truncOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only match scalar types (vector compound pattern handles vectors)
    Type dstType = truncOp.getOut().getType();
    if (isa<VectorType>(dstType))
      return failure();

    auto dstIntType = dyn_cast<IntegerType>(dstType);
    if (!dstIntType)
      return failure();

    unsigned dstBits = dstIntType.getWidth();
    if (dstBits != 8 && dstBits != 16)
      return failure();

    auto srcIntType = dyn_cast<IntegerType>(truncOp.getIn().getType());
    if (!srcIntType || srcIntType.getWidth() != 32)
      return failure();

    // Walk backward through optional clamp chain
    Value source = truncOp.getIn();
    int32_t sign = 1; // default: signed

    arith::MinSIOp minOp = nullptr;
    arith::MaxSIOp maxOp = nullptr;

    // The clamp may appear as min(max(x, lo), hi) or max(min(x, hi), lo);
    // record both ops and step `source` past them to the clamped value.
    if (auto minsiOp = source.getDefiningOp<arith::MinSIOp>()) {
      if (auto maxsiOp = minsiOp.getLhs().getDefiningOp<arith::MaxSIOp>()) {
        minOp = minsiOp;
        maxOp = maxsiOp;
        source = maxOp.getLhs();
      }
    } else if (auto maxsiOp = source.getDefiningOp<arith::MaxSIOp>()) {
      if (auto minsiOp = maxsiOp.getLhs().getDefiningOp<arith::MinSIOp>()) {
        maxOp = maxsiOp;
        minOp = minsiOp;
        source = minOp.getLhs();
      }
    }

    // If we found a clamp, verify it's a valid saturation range
    if (minOp && maxOp) {
      auto loVal = getScalarConstantValue(maxOp.getRhs());
      auto hiVal = getScalarConstantValue(minOp.getRhs());
      if (!loVal || !hiVal)
        return failure();

      // Defensive: dstBits is 8 or 16 at this point, but keep the guard so
      // the shifts below can never be undefined behavior.
      if (dstBits == 0 || dstBits > 63)
        return failure();
      uint64_t one = 1ULL;
      int64_t unsignedLo = 0;
      int64_t unsignedHi = static_cast<int64_t>((one << dstBits) - 1);
      int64_t signedLo = -static_cast<int64_t>(one << (dstBits - 1));
      int64_t signedHi = static_cast<int64_t>((one << (dstBits - 1)) - 1);

      // The clamp bounds must exactly match the unsigned or signed range of
      // the destination width; anything else is not an SRS saturation.
      if (*loVal == unsignedLo && *hiVal == unsignedHi) {
        sign = 0; // unsigned saturation
      } else if (*loVal == signedLo && *hiVal == signedHi) {
        sign = 1; // signed saturation
      } else {
        return failure();
      }
    }

    // Now source should be the shrsi result, or any i32 value for clamp-only
    Location loc = truncOp.getLoc();
    Value preShiftVal;
    Value shiftVal;
    arith::ShRSIOp shrsiOp = source.getDefiningOp<arith::ShRSIOp>();
    if (shrsiOp) {
      // Verify shrsi operand types
      if (!isa<IntegerType>(shrsiOp.getLhs().getType()))
        return failure();
      preShiftVal = shrsiOp.getLhs();
      shiftVal = shrsiOp.getRhs();
    } else {
      // No shrsi found: treat as identity shift (shift=0).
      // This handles clamp+trunci patterns without a preceding shift,
      // e.g., after skip-add with skip_scale=0.
      preShiftVal = source;
      shiftVal = arith::ConstantOp::create(rewriter, loc,
                                           rewriter.getI32IntegerAttr(0));
    }

    // Create 512-bit vector type: vector<16xi32>
    unsigned srcLanes = 512 / srcIntType.getWidth(); // 16 for i32
    VectorType bcastVecType = createVectorType(srcLanes, srcIntType);

    // Broadcast pre-shift value to 512-bit vector
    auto bcast = aievec::BroadcastScalarOp::create(rewriter, loc, bcastVecType,
                                                   preShiftVal);

    Value accValue;
    VectorType srsOutType;

    if (dstBits == 8) {
      // i32→i8: The SRS intrinsic I256V32Acc32Srs needs 1024-bit source
      // (vector<32xi32>, cast to vector<16xi64> internally) and produces
      // vector<32xi8> (256-bit). Concat two broadcast copies to get 1024 bits.
      unsigned accLanes = srcLanes * 2; // 32
      VectorType accVecType =
          createVectorType(accLanes, srcIntType); // vector<32xi32>
      auto concatSrc = aievec::ConcatOp::create(
          rewriter, loc, accVecType,
          SmallVector<Value>({bcast.getResult(), bcast.getResult()}));
      accValue =
          aievec::CastOp::create(rewriter, loc, accVecType,
                                 concatSrc.getResult(), /*isResAcc=*/true)
              .getResult();
      srsOutType = createVectorType(accLanes, dstIntType); // vector<32xi8>
    } else {
      // i32→i16: The SRS intrinsic I256V16Acc32Srs needs 512-bit source
      // (vector<16xi32>, cast to vector<8xi64> internally) and produces
      // vector<16xi16> (256-bit). 512-bit broadcast works directly.
      accValue = aievec::CastOp::create(rewriter, loc, bcastVecType,
                                       bcast.getResult(), /*isResAcc=*/true)
                     .getResult();
      srsOutType = createVectorType(srcLanes, dstIntType); // vector<16xi16>
    }

    // SRS: accumulator → narrowed output with shift and sign
    auto srsOp = aievec::SRSOp::create(rewriter, loc, srsOutType, accValue,
                                       shiftVal, sign);

    // ExtElem needs 512-bit source. SRS output is 256-bit, so concat to 512.
    unsigned extLanes = 512 / dstBits; // 64 for i8, 32 for i16
    VectorType extVecType = createVectorType(extLanes, dstIntType);
    auto concatForExt = aievec::ConcatOp::create(
        rewriter, loc, extVecType,
        SmallVector<Value>({srsOp.getResult(), srsOp.getResult()}));

    // Extract element 0 back to scalar
    auto zeroIdx =
        arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
    rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(
        truncOp, dstIntType, concatForExt.getResult(), zeroIdx.getResult());

    // Erase dead intermediate ops. The scalar shrsi is marked illegal by the
    // legalization config, so leaving it as a dead op would fail conversion.
    SmallVector<Operation *, 3> opsToErase;
    if (minOp && minOp->use_empty())
      opsToErase.push_back(minOp);
    if (maxOp && maxOp->use_empty())
      opsToErase.push_back(maxOp);
    if (shrsiOp && shrsiOp->use_empty())
      opsToErase.push_back(shrsiOp);
    for (Operation *op : opsToErase)
      rewriter.eraseOp(op);

    return success();
  }
4422};
4423
4424// Promote scalar arith.shrsi to vector aievec.ups + aievec.srs to prevent
4425// LLVM's SLP vectorizer from creating sub-512-bit vector shifts that the
4426// AIE2 backend cannot legalize (G_LSHR on <4 x s32>).
4428 using OpConversionPattern::OpConversionPattern;
4429
4430 LogicalResult
4431 matchAndRewrite(arith::ShRSIOp rsOp, OpAdaptor adaptor,
4432 ConversionPatternRewriter &rewriter) const override {
4433 // Only match scalar i32
4434 Type resultType = rsOp.getType();
4435 if (isa<VectorType>(resultType))
4436 return failure();
4437
4438 auto intType = dyn_cast<IntegerType>(resultType);
4439 if (!intType || intType.getWidth() != 32)
4440 return failure();
4441
4442 Location loc = rsOp.getLoc();
4443 VectorType vecType = createVectorType(16, intType); // vector<16xi32>
4444
4445 // Broadcast scalar value to 512-bit vector
4446 auto lhsBcast = aievec::BroadcastScalarOp::create(rewriter, loc, vecType,
4447 adaptor.getLhs());
4448
4449 // UPS: vector<16xi32> -> accumulator type (vector<16xi64>)
4450 Type accType = getVectorOpDestType(vecType, /*AIE2=*/true);
4451 auto upsOp =
4452 aievec::UPSOp::create(rewriter, loc, accType, lhsBcast.getResult());
4453
4454 // SRS: accumulator + i32 shift -> vector<16xi32>
4455 auto srsOp = aievec::SRSOp::create(rewriter, loc, vecType,
4456 upsOp.getResult(), adaptor.getRhs());
4457
4458 // Extract element 0 back to scalar
4459 auto zeroIdx =
4460 arith::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0));
4461 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(
4462 rsOp, intType, srsOp.getResult(), zeroIdx.getResult());
4463 return success();
4464 }
4465};
4466
4467// Convert a `vector.contract` op to an `aievec.matmul` op for AIE2 or
4468// `aievec.matmul_aie2p` for AIE2P
4469template <typename MatMulOpTy>
4471 : OpConversionPattern<vector::ContractionOp> {
4472 using OpConversionPattern::OpConversionPattern;
4473
4477
4478 Value reshapeLeadingUnitDims(OpBuilder &b, Value v) const {
4479 auto vecTy = dyn_cast<VectorType>(v.getType());
4480 if (!vecTy)
4481 return v;
4482 auto vecShape = vecTy.getShape();
4483
4484 size_t numLeadUnitDims = 0;
4485 while (numLeadUnitDims < vecShape.size() && vecShape[numLeadUnitDims] == 1)
4486 numLeadUnitDims++;
4487
4488 if (!numLeadUnitDims)
4489 return v;
4490
4491 SmallVector<int64_t> newShape(vecShape.begin() + numLeadUnitDims,
4492 vecShape.end());
4493 auto newVecTy = VectorType::get(newShape, vecTy.getElementType());
4494 return vector::ShapeCastOp::create(b, v.getLoc(), newVecTy, v).getResult();
4495 }
4496
  /// Lowers a vector.contract to an aievec matmul (MatMulOpTy). Operands are
  /// first stripped of leading unit dims; if the resulting matmul does not
  /// verify, a second attempt is made by peeling widening casts off lhs/rhs.
  LogicalResult
  matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Drop leading unit dims so operands take the shapes the matmul expects.
    auto lhs = reshapeLeadingUnitDims(rewriter, adaptor.getLhs());
    auto rhs = reshapeLeadingUnitDims(rewriter, adaptor.getRhs());
    auto acc = reshapeLeadingUnitDims(rewriter, adaptor.getAcc());
    bool bReshapedAcc = (acc != adaptor.getAcc());

    // NOTE(review): matMoveToAcc is a pattern member whose declaration is not
    // visible in this chunk — presumably set at pattern construction. When
    // set, the accumulator is cast into the accumulator domain (and back,
    // below) around the matmul.
    if (matMoveToAcc)
      acc = aievec::CastOp::create(rewriter, contractOp.getLoc(), acc.getType(),
                                   acc, true);

    auto matmulOp = MatMulOpTy::create(rewriter, contractOp.getLoc(),
                                       acc.getType(), lhs, rhs, acc);
    Value result;
    {
      // Replace diagnostics handler to silence errors when verifying the
      // validity of the matmul ops being generated.
      ScopedDiagnosticHandler diagHandler(
          contractOp.getContext(), [](Diagnostic &) { return success(); });
      if (failed(matmulOp.verifyInvariants())) {
        rewriter.eraseOp(matmulOp);
        // There is a possibility that, when the linalg op is converted to
        // contractions, lower precisions operands are cast to the target
        // precision outside the contraction. For those cases, we check.
        lhs = adaptor.getLhs();
        auto wideLhsValue = getSourceOfWideningOp(lhs).value_or(nullptr);
        if (wideLhsValue)
          lhs = reshapeLeadingUnitDims(rewriter, wideLhsValue);

        rhs = adaptor.getRhs();
        auto wideRhsValue = getSourceOfWideningOp(rhs).value_or(nullptr);
        if (wideRhsValue)
          rhs = reshapeLeadingUnitDims(rewriter, wideRhsValue);

        // Retry with the narrower (pre-widening) operands.
        matmulOp = MatMulOpTy::create(rewriter, contractOp.getLoc(),
                                      acc.getType(), lhs, rhs, acc);
        if (failed(matmulOp.verifyInvariants()))
          return failure();
      }
    }
    result = matmulOp.getResult();

    // Undo the accumulator move and the unit-dim reshape so the replacement
    // value matches the contraction's original result type.
    if (matMoveToAcc)
      result = aievec::CastOp::create(rewriter, contractOp.getLoc(),
                                      acc.getType(), result, false);
    if (bReshapedAcc)
      result = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(),
                                           adaptor.getAcc().getType(), result);
    rewriter.replaceOp(contractOp, result);

    return success();
  }
4550
4552};
4553
4558
4559// Convert a `vector.transpose` op to an `aievec.shuffle` op for AIE2.
4561 : OpConversionPattern<vector::TransposeOp> {
4562 using OpConversionPattern::OpConversionPattern;
  /// Lowers a vector.transpose of a 512-bit vector to a flatten ->
  /// aievec.shuffle -> reshape sequence. Only transposes of the two innermost
  /// dimensions (all other dims unit/identity) are supported, and the shuffle
  /// mode is picked from element width and the innermost result dimension.
  LogicalResult
  matchAndRewrite(vector::TransposeOp transpOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resTy = transpOp.getResultVectorType();
    auto resShape = resTy.getShape();
    auto elemTyBitWidth = resTy.getElementTypeBitWidth();
    // Total vector width in bits = product of dims * element width.
    auto vBitWidth = std::accumulate(resShape.begin(), resShape.end(),
                                     elemTyBitWidth, std::multiplies<>());
    if (vBitWidth != 512)
      return failure();

    if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32)
      return failure();

    // Verify leading dimensions are all 1.
    for (int64_t i = 0; i < static_cast<int64_t>(resShape.size() - 2); ++i)
      if (resShape[i] != 1)
        return failure();

    // Only permutation of the 2 innermost dimensions are supported.
    ArrayRef<int64_t> perm = transpOp.getPermutation();
    for (int64_t i = 0; i < static_cast<int64_t>(perm.size() - 2); ++i)
      if (perm[i] != i)
        return failure();
    if (perm.back() != static_cast<int64_t>(perm.size() - 2))
      return failure();

    // Pick the shuffle mode: TW_RxC where W is the element width and RxC the
    // 2-D tile shape. 32-bit elements only have the 4x4 mode (checked below).
    auto shuffleMode = aievec::ShuffleMode::T32_4X4;
    if (elemTyBitWidth == 8) {
      switch (resShape.back()) {
      case 4:
        shuffleMode = aievec::ShuffleMode::T8_4X16;
        break;
      case 8:
        shuffleMode = aievec::ShuffleMode::T8_8X8;
        break;
      case 16:
        shuffleMode = aievec::ShuffleMode::T8_16X4;
        break;
      default:
        return failure();
      }
    } else if (elemTyBitWidth == 16) {
      switch (resShape.back()) {
      case 2:
        shuffleMode = aievec::ShuffleMode::T16_2X16;
        break;
      case 4:
        shuffleMode = aievec::ShuffleMode::T16_4X8;
        break;
      case 8:
        shuffleMode = aievec::ShuffleMode::T16_8X4;
        break;
      case 16:
        shuffleMode = aievec::ShuffleMode::T16_16X2;
        break;
      default:
        return failure();
      }
    } else if (resShape.back() != 4)
      return failure();

    // Flatten the input to 1-D, shuffle, then reshape to the transposed type.
    auto flatVecTy =
        VectorType::get({512 / elemTyBitWidth}, resTy.getElementType());
    auto loc = transpOp.getLoc();
    auto flatInput = vector::ShapeCastOp::create(rewriter, loc, flatVecTy,
                                                 adaptor.getVector());
    auto shuffOp = aievec::ShuffleOp::create(rewriter, loc, flatVecTy,
                                             flatInput, nullptr, shuffleMode);
    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resTy, shuffOp);

    return success();
  }
4636};
4637
4638//===----------------------------------------------------------------------===//
4639// Pattern collection
4640//===----------------------------------------------------------------------===//
4641
// Registers conversion patterns shared by all AIEVec targets.
// NOTE(review): `backend` is accepted for signature symmetry with the other
// populate functions but is not consulted here.
static void populateAIEVecCommonConversionPatterns(RewritePatternSet &patterns,
                                                   TargetBackend backend) {
  // clang-format off
  patterns.add<LowerExtFOpPattern,
               LowerTruncIOpPattern>(patterns.getContext());
  // clang-format on
}
4651
// Registers the AIE1 (aievec.aie1) conversion patterns.
static void populateAIEVecV1ConversionPatterns(RewritePatternSet &patterns,
                                               TargetBackend backend) {
  // NOTE(review): the four integers are UPD configuration parameters —
  // confirm their meaning against LowerVectorTransferReadToAIEUPD's ctor.
  patterns.add<LowerVectorTransferReadToAIEUPD>(patterns.getContext(), 128, 512,
                                                128, 256);
  // clang-format off
  patterns.add<LowerVectorAddIOpToAIEVecAddOp,
               LowerVectorExtractStridedSliceOpAIEv1Pattern>(patterns.getContext());
  // clang-format on
}
4668
4669// Populate common conversion patterns for AIE2 and AIE2P
// Registers conversion patterns shared by the AIE2 and AIE2P targets; the
// pattern sets differ per backend (CPP vs. LLVMIR).
static void
populateAIEVecV2CommonConversionPatterns(RewritePatternSet &patterns,
                                         TargetBackend backend) {
  // clang-format off
  // TODO: Reorder these alphabetically
  if (backend == TargetBackend::CPP) {
    patterns.add<
    >(patterns.getContext(), 128, 1024, 256, 1024);
    patterns.add<
    >(patterns.getContext());
  } else if (backend == TargetBackend::LLVMIR){
    patterns.add<
    >(patterns.getContext());
  }
  // Add the compound shift+clamp+trunc→SRS pattern with higher benefit
  // so it takes priority over the individual shrsi and trunci patterns.
  patterns.add<ShiftClampTruncToSRSPattern>(patterns.getContext(),
                                            /*benefit=*/2);
  // Scalar version of compound SRS with even higher benefit.
  patterns.add<LowerScalarShiftClampTruncToSRS>(patterns.getContext(),
                                                /*benefit=*/3);
  patterns.add<
  >(patterns.getContext());
  // clang-format on
}
4742
// Registers AIE2-specific patterns on top of the common AIE2/AIE2P set.
static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
                                               TargetBackend backend) {
  populateAIEVecV2CommonConversionPatterns(patterns, backend);
      patterns.getContext(), backend == TargetBackend::CPP);
  patterns.add<LowerVectorReductionAddBfloat16OpAIE2>(patterns.getContext());
  // For AIE2 with LLVMIR backend, use LUT-based exp and rsqrt
  if (backend == TargetBackend::LLVMIR) {
        patterns.getContext());
  }
}
4755
4756// AIE2p-specific version of ConvertSplatToAIEBroadcast that supports direct
4757// 256-bit broadcasts without extract
4759 : OpConversionPattern<vector::BroadcastOp> {
4760 using OpConversionPattern::OpConversionPattern;
4761
4762 LogicalResult
4763 matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor,
4764 ConversionPatternRewriter &rewriter) const override {
4765
4766 if (adaptor.getSource().getDefiningOp<vector::ExtractOp>())
4767 return failure();
4768
4769 auto resultType = cast<VectorType>(bcastOp.getResult().getType());
4770 auto flatResultType = getFlattenedVectorType(resultType);
4771 Type scalarType = resultType.getElementType();
4772 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
4773 unsigned laneSize = getVectorLaneSize(resultType);
4774 auto src = bcastOp.getSource();
4775
4776 // AIE2p supports both 256-bit and 512-bit broadcast directly
4777 if (laneSize * elWidth == 512 || laneSize * elWidth == 256) {
4778 Value newOp = aievec::BroadcastScalarOp::create(
4779 rewriter, bcastOp.getLoc(), flatResultType, src);
4780 if (resultType != flatResultType)
4781 newOp = vector::ShapeCastOp::create(rewriter, bcastOp.getLoc(),
4782 resultType, newOp);
4783 rewriter.replaceOp(bcastOp, newOp);
4784 return success();
4785 }
4786
4787 if (laneSize * elWidth == 1024) {
4788 VectorType vecType = createVectorType(512 / elWidth, scalarType);
4789 auto aieBcastOp = aievec::BroadcastScalarOp::create(
4790 rewriter, bcastOp.getLoc(), vecType, src);
4791 Value newOp = aievec::ConcatOp::create(
4792 rewriter, bcastOp.getLoc(), flatResultType,
4793 SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
4794 if (resultType != flatResultType)
4795 newOp = vector::ShapeCastOp::create(rewriter, bcastOp.getLoc(),
4796 resultType, newOp);
4797 rewriter.replaceOp(bcastOp, newOp);
4798 return success();
4799 }
4800
4801 return failure();
4802 }
4803};
4804
// Registers AIE2P-specific patterns on top of the common AIE2/AIE2P set.
static void populateAIEVecV2PConversionPatterns(RewritePatternSet &patterns,
                                                TargetBackend backend) {
  populateAIEVecV2CommonConversionPatterns(patterns, backend);
      patterns.getContext(), backend == TargetBackend::CPP);
  // AIE2p-specific broadcast pattern that handles 256-bit directly
  patterns.add<ConvertSplatToAIEBroadcastAIE2p>(patterns.getContext());
  patterns.add<LowerVectorReductionAddBfloat16OpAIE2P>(patterns.getContext());
  // For AIE2P with LLVMIR backend, use aievec.exp and aievec.inv
  // math.rsqrt is kept legal and will be lowered in AIEVecToLLVM pass
  if (backend == TargetBackend::LLVMIR) {
        ConvertDivFToAIEVecInvOpPattern>(patterns.getContext());
    // Higher benefit to take priority over the AIE2 LUT-based tanh pattern
    // registered in the common patterns.
    patterns.add<ConvertMathTanhToAIEVecTanhOpPattern>(patterns.getContext(),
                                                       /*benefit=*/2);
  }
}
4824
4825//===----------------------------------------------------------------------===//
4826// Legalizations
4827//===----------------------------------------------------------------------===//
4828
4829// TODO: Review the validity of these legalizations beyond basic cases.
4830
4831static bool isInSigmoidOperationChain(math::ExpOp expOp) {
4832 if (!expOp.getOperand().getDefiningOp<arith::NegFOp>())
4833 return false;
4834
4835 arith::AddFOp addOp = nullptr;
4836 for (Operation *user : expOp->getUsers()) {
4837 addOp = dyn_cast<arith::AddFOp>(user);
4838 if (addOp)
4839 break;
4840 }
4841
4842 if (!addOp)
4843 return false;
4844
4845 auto *addLvalOp = addOp.getLhs().getDefiningOp();
4846 auto *addRvalOp = addOp.getRhs().getDefiningOp();
4847 if (!((isa<math::ExpOp>(addLvalOp) && isa<arith::ConstantOp>(addRvalOp)) ||
4848 (isa<math::ExpOp>(addRvalOp) && isa<arith::ConstantOp>(addLvalOp))))
4849 return false;
4850
4851 auto constOp = isa<arith::ConstantOp>(addLvalOp)
4852 ? cast<arith::ConstantOp>(addLvalOp)
4853 : cast<arith::ConstantOp>(addRvalOp);
4854
4855 auto cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
4856 if (!cstDense)
4857 return false;
4858
4859 if (cstDense.getSplatValue<APFloat>().convertToFloat() != 1.0f)
4860 return false;
4861
4862 arith::DivFOp divOp = nullptr;
4863 for (Operation *user : addOp->getUsers()) {
4864 divOp = dyn_cast<arith::DivFOp>(user);
4865 if (divOp)
4866 break;
4867 }
4868
4869 if (!divOp)
4870 return false;
4871
4872 constOp = dyn_cast<arith::ConstantOp>(divOp.getLhs().getDefiningOp());
4873 if (!constOp)
4874 return false;
4875 cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
4876 if (!cstDense)
4877 return false;
4878 if (cstDense.getSplatValue<APFloat>().convertToFloat() != 1.0f)
4879 return false;
4880
4881 return true;
4882}
4883
4884static void configureAIEVecCommonLegalizations(ConversionTarget &target,
4885 TargetBackend backend) {
4886 target
4887 .addLegalDialect<xilinx::aievec::aie1::AIEVecAIE1Dialect,
4888 xilinx::aievec::AIEVecDialect, arith::ArithDialect,
4889 ub::UBDialect, emitc::EmitCDialect, func::FuncDialect>();
4890 if (backend == TargetBackend::CPP) {
4891 target.addIllegalOp<vector::TransferReadOp>();
4892 }
4893 target.addIllegalOp<vector::ExtractStridedSliceOp>();
4894 target.addLegalOp<vector::BitCastOp>();
4895
4896 target.addDynamicallyLegalOp<arith::ExtFOp>([](arith::ExtFOp extfOp) {
4897 auto srcType = dyn_cast<VectorType>(extfOp.getIn().getType());
4898 auto dstType = dyn_cast<VectorType>(extfOp.getOut().getType());
4899 if (!srcType || !dstType)
4900 return true;
4901
4902 Type srcScalarType = srcType.getElementType();
4903 Type dstScalarType = dstType.getElementType();
4904 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
4905 return true;
4906
4907 unsigned srcLaneSize = getVectorLaneSize(srcType);
4908 unsigned dstLaneSize = getVectorLaneSize(dstType);
4909 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
4910 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
4911 return srcElWidth != 16 || srcLaneSize != 16 || dstElWidth != 32 ||
4912 dstLaneSize != 16;
4913 });
4914
4915 target.addDynamicallyLegalOp<arith::ExtSIOp>([](arith::ExtSIOp extsiOp) {
4916 auto srcType = dyn_cast<VectorType>(extsiOp.getIn().getType());
4917 auto dstType = dyn_cast<VectorType>(extsiOp.getOut().getType());
4918 if (!srcType || !dstType)
4919 return true;
4920
4921 Type srcScalarType = srcType.getElementType();
4922 Type dstScalarType = dstType.getElementType();
4923 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
4924 return true;
4925
4926 unsigned srcLaneSize = getVectorLaneSize(srcType);
4927 unsigned dstLaneSize = getVectorLaneSize(dstType);
4928 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
4929 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
4930 return srcLaneSize != 32 || (dstElWidth <= srcElWidth) ||
4931 (dstLaneSize != srcLaneSize);
4932 });
4933
4934 target.addDynamicallyLegalOp<arith::TruncFOp>([](arith::TruncFOp truncfOp) {
4935 auto srcType = dyn_cast<VectorType>(truncfOp.getIn().getType());
4936 auto dstType = dyn_cast<VectorType>(truncfOp.getOut().getType());
4937 if (!srcType || !dstType)
4938 return true;
4939
4940 Type srcScalarType = srcType.getElementType();
4941 Type dstScalarType = dstType.getElementType();
4942 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
4943 return true;
4944
4945 unsigned srcLaneSize = getVectorLaneSize(srcType);
4946 unsigned dstLaneSize = getVectorLaneSize(dstType);
4947 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
4948 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
4949 return srcElWidth != 32 || srcLaneSize != 16 || dstElWidth != 16 ||
4950 dstLaneSize != 16;
4951 });
4952
4953 target.addDynamicallyLegalOp<arith::TruncIOp>([](arith::TruncIOp trunciOp) {
4954 auto srcType = dyn_cast<VectorType>(trunciOp.getIn().getType());
4955 auto dstType = dyn_cast<VectorType>(trunciOp.getOut().getType());
4956 if (!srcType || !dstType) {
4957 // Scalar trunci: mark illegal if part of compound SRS chain
4958 // so the LowerScalarShiftClampTruncToSRS pattern can convert it.
4959 if (!srcType && !dstType && isSRSCompoundCandidate(trunciOp))
4960 return false;
4961 return true;
4962 }
4963
4964 Type srcScalarType = srcType.getElementType();
4965 Type dstScalarType = dstType.getElementType();
4966 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
4967 return true;
4968
4969 // Also mark vector trunci as illegal if it's part of a compound SRS chain
4970 if (isSRSCompoundCandidate(trunciOp))
4971 return false;
4972
4973 unsigned srcLaneSize = getVectorLaneSize(srcType);
4974 unsigned dstLaneSize = getVectorLaneSize(dstType);
4975 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
4976 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
4977
4978 return srcLaneSize != 32 || (dstElWidth >= srcElWidth) ||
4979 (dstLaneSize != srcLaneSize);
4980 });
4981
4982 target.addDynamicallyLegalOp<math::TanhOp>([](math::TanhOp tanhOp) {
4983 auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
4984 if (!srcType)
4985 return true;
4986
4987 Type scalarType = srcType.getElementType();
4988 if (!isa<FloatType>(scalarType))
4989 return true;
4990
4991 unsigned laneSize = getVectorLaneSize(srcType);
4992 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
4993 return elWidth != 16 || laneSize != 16;
4994 });
4995
4996 target.addDynamicallyLegalOp<math::SqrtOp>([](math::SqrtOp sqrtOp) {
4997 auto srcType = dyn_cast<VectorType>(sqrtOp.getOperand().getType());
4998 if (!srcType)
4999 return true;
5000
5001 Type scalarType = srcType.getElementType();
5002 if (!isa<FloatType>(scalarType))
5003 return true;
5004
5005 unsigned laneSize = getVectorLaneSize(srcType);
5006 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5007 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
5008 });
5009
5010 target.addDynamicallyLegalOp<math::ErfOp>([](math::ErfOp erfOp) {
5011 auto srcType = dyn_cast<VectorType>(erfOp.getOperand().getType());
5012 if (!srcType)
5013 return true;
5014
5015 Type scalarType = srcType.getElementType();
5016 if (!isa<FloatType>(scalarType))
5017 return true;
5018
5019 unsigned laneSize = getVectorLaneSize(srcType);
5020 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5021 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
5022 });
5023
5024 target.addDynamicallyLegalOp<math::AbsFOp>([](math::AbsFOp absfOp) {
5025 auto srcType = dyn_cast<VectorType>(absfOp.getOperand().getType());
5026 if (!srcType)
5027 return true;
5028
5029 Type scalarType = srcType.getElementType();
5030 unsigned laneSize = getVectorLaneSize(srcType);
5031 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5032 return elWidth * laneSize != 512 && elWidth * laneSize != 256;
5033 });
5034
5035 target.addDynamicallyLegalOp<math::AbsIOp>([](math::AbsIOp absiOp) {
5036 auto srcType = dyn_cast<VectorType>(absiOp.getOperand().getType());
5037 if (!srcType)
5038 return true;
5039
5040 Type scalarType = srcType.getElementType();
5041 unsigned laneSize = getVectorLaneSize(srcType);
5042 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5043 return elWidth * laneSize != 512 && elWidth * laneSize != 256;
5044 });
5045
5046 // CPP backend: Mark 1/x pattern as illegal for conversion to inv() via LUT
5047 // LLVMIR backend: Keep scalar divf legal (handled by downstream passes)
5048 if (backend == TargetBackend::CPP) {
5049 target.addDynamicallyLegalOp<arith::DivFOp>([](arith::DivFOp divfOp) {
5050 if (auto srcType = dyn_cast<VectorType>(divfOp.getLhs().getType());
5051 !srcType) {
5052 Type scalarType = divfOp.getLhs().getType();
5053 if (!divfOp->hasOneUse() || !isa<FloatType>(scalarType))
5054 return true;
5055 if (!isNarrowingOp(*divfOp->getUsers().begin()))
5056 return true;
5057
5058 auto fType = cast<FloatType>(scalarType);
5059 if (fType.getWidth() != 32)
5060 return true;
5061
5062 auto constOp =
5063 dyn_cast<arith::ConstantOp>(divfOp.getLhs().getDefiningOp());
5064 if (!constOp ||
5065 cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
5066 1.0f)
5067 return true;
5068 } else {
5069 Type scalarType = srcType.getElementType();
5070 if (!isa<FloatType>(scalarType))
5071 return true;
5072
5073 unsigned laneSize = getVectorLaneSize(srcType);
5074 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5075
5076 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
5077 return true;
5078
5079 arith::NegFOp negOp = nullptr;
5080 if (!hasSigmoidComputationChain(divfOp, negOp))
5081 return true;
5082 }
5083
5084 return false;
5085 });
5086 }
5087
5088 target.addDynamicallyLegalOp<math::CeilOp>([](math::CeilOp ceilOp) {
5089 auto srcType = dyn_cast<VectorType>(ceilOp.getOperand().getType());
5090 if (!srcType)
5091 return true;
5092 Type scalarType = srcType.getElementType();
5093 if (!isa<FloatType>(scalarType))
5094 return true;
5095
5096 unsigned laneSize = getVectorLaneSize(srcType);
5097 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5098 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
5099 });
5100
5101 target.addDynamicallyLegalOp<math::FloorOp>([](math::FloorOp floorOp) {
5102 auto srcType = dyn_cast<VectorType>(floorOp.getOperand().getType());
5103 if (!srcType)
5104 return true;
5105 Type scalarType = srcType.getElementType();
5106 if (!isa<FloatType>(scalarType))
5107 return true;
5108
5109 unsigned laneSize = getVectorLaneSize(srcType);
5110 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5111 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
5112 });
5113
5114 target.addDynamicallyLegalOp<arith::NegFOp>([](arith::NegFOp negOp) {
5115 auto srcType = dyn_cast<VectorType>(negOp.getOperand().getType());
5116 if (!srcType)
5117 return true;
5118 if (Type scalarType = srcType.getElementType(); !isa<FloatType>(scalarType))
5119 return true;
5120
5121 unsigned laneSize = getVectorLaneSize(srcType);
5122 return laneSize != 16;
5123 });
5124
5125 target.addDynamicallyLegalOp<arith::XOrIOp>([](arith::XOrIOp xorOp) {
5126 auto srcType = dyn_cast<VectorType>(xorOp.getLhs().getType());
5127 if (!srcType)
5128 return true;
5129 Type scalarType = srcType.getElementType();
5130 if (!isa<IntegerType>(scalarType))
5131 return true;
5132
5133 unsigned laneSize = getVectorLaneSize(srcType);
5134 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5135
5136 return laneSize * elWidth != 512;
5137 });
5138
5139 target.addDynamicallyLegalOp<arith::OrIOp>([](arith::OrIOp orOp) {
5140 auto srcType = dyn_cast<VectorType>(orOp.getLhs().getType());
5141 if (!srcType)
5142 return true;
5143 Type scalarType = srcType.getElementType();
5144 if (!isa<IntegerType>(scalarType))
5145 return true;
5146
5147 unsigned laneSize = getVectorLaneSize(srcType);
5148 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5149
5150 return laneSize * elWidth != 512;
5151 });
5152
5153 target.addDynamicallyLegalOp<arith::ShRSIOp>([](arith::ShRSIOp rsOp) {
5154 auto srcType = dyn_cast<VectorType>(rsOp.getLhs().getType());
5155 if (!srcType) {
5156 // Scalar i32 shrsi: illegal unless it feeds into a compound SRS chain
5157 // (the compound pattern consumes it via the trunci anchor)
5158 if (auto intType = dyn_cast<IntegerType>(rsOp.getLhs().getType()))
5159 if (intType.getWidth() == 32) {
5160 if (shrsiUsedByCompoundSRS(rsOp))
5161 return true; // legal — compound pattern will handle
5162 return false; // illegal — individual pattern promotes
5163 }
5164 return true;
5165 }
5166
5167 // If the shrsi feeds into a compound SRS pattern (shrsi+clamp+trunc),
5168 // keep it legal — the compound pattern will consume it via the trunci.
5169 if (shrsiUsedByCompoundSRS(rsOp))
5170 return true;
5171
5172 Type scalarType = srcType.getElementType();
5173 unsigned laneSize = getVectorLaneSize(srcType);
5174 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5175
5176 return laneSize * elWidth != 512;
5177 });
5178
5179 target.addDynamicallyLegalOp<arith::AndIOp>([](arith::AndIOp andOp) {
5180 auto srcType = dyn_cast<VectorType>(andOp.getLhs().getType());
5181 if (!srcType)
5182 return true;
5183 Type scalarType = srcType.getElementType();
5184 if (!isa<IntegerType>(scalarType))
5185 return true;
5186
5187 unsigned laneSize = getVectorLaneSize(srcType);
5188 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5189
5190 return laneSize * elWidth != 512;
5191 });
5192
5193 if (backend == TargetBackend::CPP) {
5194 target.addDynamicallyLegalOp<arith::AddIOp>(
5195 [](arith::AddIOp op) { return !isa<VectorType>(op.getType()); });
5196 }
5197 target.addDynamicallyLegalOp<arith::AddFOp>(
5198 [](arith::AddFOp op) { return !isa<VectorType>(op.getType()); });
5199 target.addDynamicallyLegalOp<arith::SubIOp>(
5200 [](arith::SubIOp op) { return !isa<VectorType>(op.getType()); });
5201 target.addDynamicallyLegalOp<arith::SubFOp>(
5202 [](arith::SubFOp op) { return !isa<VectorType>(op.getType()); });
5203}
5204
5205static void configureAIEVecV1Legalizations(ConversionTarget &target,
5206 TargetBackend backend) {
5207 target.addDynamicallyLegalOp<arith::MulIOp>(
5208 [](arith::MulIOp op) { return !isa<VectorType>(op.getType()); });
5209 target.addDynamicallyLegalOp<arith::MulFOp>(
5210 [](arith::MulFOp op) { return !isa<VectorType>(op.getType()); });
5211 target.addDynamicallyLegalOp<aievec::aie1::FMAOp>(
5212 [](xilinx::aievec::aie1::FMAOp op) {
5213 auto *lhsDefOp = op.getLhs().getDefiningOp();
5214 aievec::ConcatOp concatOp = nullptr;
5215 if (lhsDefOp)
5216 concatOp = dyn_cast<aievec::ConcatOp>(op.getLhs().getDefiningOp());
5217 if (!concatOp)
5218 return true;
5219
5220 vector::BroadcastOp srcBcast = nullptr;
5221 if (auto *lhsOp = concatOp.getSources()[0].getDefiningOp())
5222 srcBcast = dyn_cast<vector::BroadcastOp>(lhsOp);
5223 if (!srcBcast) {
5224 auto *rhsOp = op.getRhs().getDefiningOp();
5225 if (!rhsOp)
5226 return true;
5227 srcBcast = dyn_cast<vector::BroadcastOp>(rhsOp);
5228 }
5229
5230 if (srcBcast)
5231 if (auto *srcOp = srcBcast.getSource().getDefiningOp())
5232 return !isa<vector::ExtractOp>(srcOp);
5233
5234 return true;
5235 });
5236
5237 target.addDynamicallyLegalOp<aievec::aie1::AddOp>([](aievec::aie1::AddOp op) {
5238 auto lSrsOp = op.getLhs().getDefiningOp<aievec::SRSOp>();
5239 auto rSrsOp = op.getRhs().getDefiningOp<aievec::SRSOp>();
5240 return (!lSrsOp ||
5241 !lSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>()) &&
5242 (!rSrsOp ||
5243 !rSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>());
5244 });
5245 target.addLegalDialect<memref::MemRefDialect>();
5246}
5247
/// Configure AIE2P-specific op legality. This is registered *after*
/// configureAIEVecV2Legalizations (see runOnOperation), so each hook here
/// overrides the AIE2 hook for the same op. An op is returned `false`
/// (illegal) exactly when an AIE2P pattern is expected to lower it.
static void configureAIEVecV2PLegalizations(ConversionTarget &target,
                                            TargetBackend backend) {
  // AIE2P-specific legalization for rsqrt with LLVMIR backend
  // Vector bf16 rsqrt is illegal (no hardware support)
  // Scalar f32 and vector f32 rsqrt are legal (lowered in AIEVecToLLVM pass)
  if (backend == TargetBackend::LLVMIR) {
    target.addDynamicallyLegalOp<math::RsqrtOp>([](math::RsqrtOp rsqrtOp) {
      auto vecType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
      // Vector bf16 rsqrt is illegal
      if (vecType && vecType.getElementType().isBF16())
        return false;
      // Everything else is legal (scalar f32, vector f32)
      return true;
    });

    // AIE2P-specific legalization for exp with LLVMIR backend
    // v16bf16 and v32bf16 exp are illegal (uses hardware intrinsic)
    target.addDynamicallyLegalOp<math::ExpOp>([](math::ExpOp expOp) {
      auto srcType = dyn_cast<VectorType>(expOp.getOperand().getType());
      if (!srcType)
        return true;

      Type scalarType = srcType.getElementType();
      unsigned elWidth = scalarType.getIntOrFloatBitWidth();
      unsigned laneSize = getVectorLaneSize(srcType);
      // AIE2P LLVMIR: v16bf16 and v32bf16 are illegal (uses aievec.exp)
      if (!scalarType.isBF16() || (laneSize != 16 && laneSize != 32) ||
          elWidth != 16)
        return true;
      // Keep exp legal when it is the head of a sigmoid chain; the chain is
      // matched as a whole elsewhere.
      if (expOp->hasOneUse() && isInSigmoidOperationChain(expOp))
        return true;

      return false;
    });

    // AIE2P-specific legalization for tanh with LLVMIR backend
    // v16bf16 and v32bf16 tanh are illegal (uses hardware intrinsic)
    target.addDynamicallyLegalOp<math::TanhOp>([](math::TanhOp tanhOp) {
      auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
      if (!srcType)
        return true;

      Type scalarType = srcType.getElementType();
      unsigned elWidth = scalarType.getIntOrFloatBitWidth();
      unsigned laneSize = getVectorLaneSize(srcType);
      // AIE2P LLVMIR: v16bf16 and v32bf16 are illegal (uses aievec.tanh)
      if (!scalarType.isBF16() || (laneSize != 16 && laneSize != 32) ||
          elWidth != 16)
        return true;

      return false;
    });

    // AIE2P-specific legalization for divf 1.0/x pattern with LLVMIR backend
    // Scalar f32 or vector<Nxf32> divf with constant 1.0 LHS is illegal
    target.addDynamicallyLegalOp<arith::DivFOp>([](arith::DivFOp divfOp) {
      Type srcType = divfOp.getLhs().getType();

      // Check if LHS is defined by a constant operation
      auto constOp =
          dyn_cast_or_null<arith::ConstantOp>(divfOp.getLhs().getDefiningOp());
      if (!constOp)
        return true;

      // Scalar f32 case - check for exactly 1.0
      if (srcType.isF32()) {
        auto floatAttr = dyn_cast<FloatAttr>(constOp.getValue());
        if (floatAttr && floatAttr.getValue().isExactlyValue(1.0))
          return false; // illegal - will be converted to aievec.inv
        return true;
      }

      // Vector f32 case - check for splat of exactly 1.0
      if (auto vecType = dyn_cast<VectorType>(srcType)) {
        if (vecType.getElementType().isF32()) {
          unsigned laneSize = getVectorLaneSize(vecType);
          if (laneSize == 16 || laneSize == 32) {
            auto denseAttr = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
            if (denseAttr && denseAttr.isSplat() &&
                denseAttr.getSplatValue<APFloat>().isExactlyValue(1.0))
              return false; // illegal - will be converted to aievec.inv
          }
        }
      }

      return true;
    });
  }
  // For CPP backend, exp remains legal (uses LUT pattern from common patterns)

  // AIE2P-specific legalization: ExtFOp on vector is always illegal
  // (any float->float extension with 16-multiple lane counts is lowered).
  target.addDynamicallyLegalOp<arith::ExtFOp>([](arith::ExtFOp extfOp) {
    auto srcType = dyn_cast<VectorType>(extfOp.getIn().getType());
    auto dstType = dyn_cast<VectorType>(extfOp.getOut().getType());
    if (!srcType || !dstType)
      return true;

    Type srcScalarType = srcType.getElementType();
    Type dstScalarType = dstType.getElementType();
    if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
      return true;

    unsigned srcLaneSize = getVectorLaneSize(srcType);
    unsigned dstLaneSize = getVectorLaneSize(dstType);
    if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
      return false;

    return true;
  });

  // AIE2P-specific legalization: TruncFOp on vector is always illegal
  target.addDynamicallyLegalOp<arith::TruncFOp>([](arith::TruncFOp truncfOp) {
    auto srcType = dyn_cast<VectorType>(truncfOp.getIn().getType());
    auto dstType = dyn_cast<VectorType>(truncfOp.getOut().getType());
    if (!srcType || !dstType)
      return true;

    Type srcScalarType = srcType.getElementType();
    Type dstScalarType = dstType.getElementType();
    if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
      return true;

    unsigned srcLaneSize = getVectorLaneSize(srcType);
    unsigned dstLaneSize = getVectorLaneSize(dstType);
    if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
      return false;

    return true;
  });

  // AIE2P-specific legalization: ExtSIOp on vector is always illegal
  target.addDynamicallyLegalOp<arith::ExtSIOp>([](arith::ExtSIOp extsiOp) {
    auto srcType = dyn_cast<VectorType>(extsiOp.getIn().getType());
    auto dstType = dyn_cast<VectorType>(extsiOp.getOut().getType());
    if (!srcType || !dstType)
      return true;

    Type srcScalarType = srcType.getElementType();
    Type dstScalarType = dstType.getElementType();
    if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
      return true;

    unsigned srcLaneSize = getVectorLaneSize(srcType);
    unsigned dstLaneSize = getVectorLaneSize(dstType);
    if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
      return false;

    return true;
  });

  // AIE2P-specific legalization: TruncIOp on vector is always illegal
  target.addDynamicallyLegalOp<arith::TruncIOp>([](arith::TruncIOp trunciOp) {
    auto srcType = dyn_cast<VectorType>(trunciOp.getIn().getType());
    auto dstType = dyn_cast<VectorType>(trunciOp.getOut().getType());
    if (!srcType || !dstType) {
      // Scalar trunci: mark illegal if part of compound SRS chain
      if (!srcType && !dstType && isSRSCompoundCandidate(trunciOp))
        return false;
      return true;
    }
    Type srcScalarType = srcType.getElementType();
    Type dstScalarType = dstType.getElementType();
    if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
      return true;

    // Also mark as illegal if it's part of a shrsi+clamp+trunc SRS pattern,
    // even for sub-AIE-width vectors
    if (isSRSCompoundCandidate(trunciOp))
      return false;

    unsigned srcLaneSize = getVectorLaneSize(srcType);
    unsigned dstLaneSize = getVectorLaneSize(dstType);
    if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
      return false;

    return true;
  });

  // AIE2P-specific legalization: Override AddFOp to support laneSize==32 for
  // float types
  target.addDynamicallyLegalOp<arith::AddFOp>([](arith::AddFOp op) {
    auto resultType = dyn_cast<VectorType>(op.getType());
    if (!resultType)
      return true;

    Type scalarType = resultType.getElementType();
    unsigned laneSize = getVectorLaneSize(resultType);

    // For float types, support both laneSize==16 and laneSize==32
    if (isa<FloatType>(scalarType))
      return laneSize != 16 && laneSize != 32;

    // For other types, only laneSize==16 (same as AIE2)
    return laneSize != 16;
  });

  // AIE2P-specific legalization: Override SubFOp to support laneSize==32 for
  // float types
  target.addDynamicallyLegalOp<arith::SubFOp>([](arith::SubFOp op) {
    auto resultType = dyn_cast<VectorType>(op.getType());
    if (!resultType)
      return true;

    Type scalarType = resultType.getElementType();
    unsigned laneSize = getVectorLaneSize(resultType);

    // For float types, support both laneSize==16 and laneSize==32
    if (isa<FloatType>(scalarType))
      return laneSize != 16 && laneSize != 32;

    // For other types, only laneSize==16 (same as AIE2)
    return laneSize != 16;
  });
}
5462
5463static void configureAIEVecV2Legalizations(ConversionTarget &target,
5464 TargetBackend backend) {
5465 target.addLegalOp<UnrealizedConversionCastOp>();
5466 target.addLegalOp<vector::ShapeCastOp>();
5467
5468 // A set recording the vector lane size and element width supported
5469 llvm::SmallSet<std::pair<unsigned, unsigned>, 16> laneSizeElWidthPairSet;
5470 laneSizeElWidthPairSet.insert({64, 8});
5471 laneSizeElWidthPairSet.insert({32, 16});
5472 laneSizeElWidthPairSet.insert({16, 32});
5473 laneSizeElWidthPairSet.insert({32, 32});
5474
5475 // A set recording the element width supported
5476 llvm::SmallSet<unsigned, 16> elWidthSet;
5477 elWidthSet.insert(8);
5478 elWidthSet.insert(16);
5479 elWidthSet.insert(32);
5480
5481 if (backend == TargetBackend::CPP) {
5482 target.addDynamicallyLegalOp<arith::AddIOp>([=](arith::AddIOp op) {
5483 auto resultType = dyn_cast<VectorType>(op.getType());
5484 if (!resultType)
5485 return true;
5486
5487 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5488 unsigned laneSize = getVectorLaneSize(resultType);
5489
5490 return !laneSizeElWidthPairSet.count(
5491 std::make_pair(laneSize, resultElWidth));
5492 });
5493 }
5494
5495 target.addDynamicallyLegalOp<arith::SubIOp>([=](arith::SubIOp op) {
5496 auto resultType = dyn_cast<VectorType>(op.getType());
5497 if (!resultType)
5498 return true;
5499 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5500 unsigned laneSize = getVectorLaneSize(resultType);
5501
5502 return !laneSizeElWidthPairSet.count(
5503 std::make_pair(laneSize, resultElWidth));
5504 });
5505
5506 target.addDynamicallyLegalOp<arith::AddFOp>([](arith::AddFOp op) {
5507 auto resultType = dyn_cast<VectorType>(op.getType());
5508 if (!resultType)
5509 return true;
5510
5511 Type scalarType = resultType.getElementType();
5512 unsigned laneSize = getVectorLaneSize(resultType);
5513 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
5514
5515 // Support laneSize == 16 for f32/bf16
5516 if (laneSize == 16)
5517 return false; // illegal - will be converted
5518 // Support laneSize == 32 for bf16 (split into two v16bf16 ops)
5519 if (laneSize == 32 && resultElWidth == 16)
5520 return false; // illegal - will be split
5521 // Support laneSize == 32 for f32 (split into two v16f32 ops)
5522 if (laneSize == 32 && resultElWidth == 32)
5523 return false; // illegal - will be split into two v16f32 ops
5524
5525 return true; // legal - not supported
5526 });
5527
5528 target.addDynamicallyLegalOp<arith::SubFOp>([](arith::SubFOp op) {
5529 auto resultType = dyn_cast<VectorType>(op.getType());
5530 if (!resultType)
5531 return true;
5532
5533 Type scalarType = resultType.getElementType();
5534 unsigned laneSize = getVectorLaneSize(resultType);
5535 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
5536
5537 // Support laneSize == 16 for f32/bf16
5538 if (laneSize == 16)
5539 return false; // illegal - will be converted
5540 // Support laneSize == 32 for bf16 (split into two v16bf16 ops)
5541 if (laneSize == 32 && resultElWidth == 16)
5542 return false; // illegal - will be split
5543 // Support laneSize == 32 for f32 (split into two v16f32 ops)
5544 if (laneSize == 32 && resultElWidth == 32)
5545 return false; // illegal - will be split into two v16f32 ops
5546
5547 return true; // legal - not supported
5548 });
5549
5550 target.addDynamicallyLegalOp<arith::MulIOp>([](arith::MulIOp op) {
5551 auto resultType = dyn_cast<VectorType>(op.getType());
5552 if (!resultType)
5553 return true;
5554 auto isAddOp = [&](Operation *op) { return isa<arith::AddIOp>(op); };
5555 // Verify it is not a part of MAC
5556 if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp))
5557 return true;
5558
5559 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5560 unsigned laneSize = getVectorLaneSize(resultType);
5561
5562 return (laneSize != 32 || (resultElWidth != 16 && resultElWidth != 8)) &&
5563 ((laneSize != 16 && laneSize != 32) || resultElWidth != 32);
5564 });
5565
5566 target.addDynamicallyLegalOp<arith::MulFOp>([](arith::MulFOp op) {
5567 auto resultType = dyn_cast<VectorType>(op.getType());
5568 if (!resultType)
5569 return true;
5570
5571 auto isAddOp = [&](Operation *op) { return isa<arith::AddFOp>(op); };
5572 // Verify it is not a part of FMA
5573 if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp))
5574 return true;
5575
5576 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5577 unsigned laneSize = getVectorLaneSize(resultType);
5578
5579 // Support laneSize == 16 for bf16/f32, and laneSize == 32 for bf16 (split)
5580 if (laneSize == 16 && (resultElWidth == 16 || resultElWidth == 32))
5581 return false; // illegal - will be converted
5582 if (laneSize == 32 && resultElWidth == 16)
5583 return false; // illegal - will be split into two v16bf16 ops
5584
5585 return true; // legal - not supported
5586 });
5587
5588 target.addDynamicallyLegalOp<arith::MinSIOp>([=](arith::MinSIOp op) {
5589 auto resultType = dyn_cast<VectorType>(op.getType());
5590 if (!resultType) {
5591 // Scalar i8/i16/i32 minsi: illegal unless in compound SRS chain
5592 if (auto intType = dyn_cast<IntegerType>(op.getType())) {
5593 unsigned w = intType.getWidth();
5594 if (w == 8 || w == 16 || w == 32) {
5595 if (scalarClampInCompoundSRS(op))
5596 return true; // legal — compound pattern consumes
5597 return false; // illegal — individual pattern promotes
5598 }
5599 }
5600 return true;
5601 }
5602
5603 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5604 unsigned laneSize = getVectorLaneSize(resultType);
5605
5606 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
5607 });
5608
5609 target.addDynamicallyLegalOp<arith::MaxSIOp>([=](arith::MaxSIOp op) {
5610 auto resultType = dyn_cast<VectorType>(op.getType());
5611 if (!resultType) {
5612 // Scalar i8/i16/i32 maxsi: illegal unless in compound SRS chain
5613 if (auto intType = dyn_cast<IntegerType>(op.getType())) {
5614 unsigned w = intType.getWidth();
5615 if (w == 8 || w == 16 || w == 32) {
5616 if (scalarClampInCompoundSRS(op))
5617 return true; // legal — compound pattern consumes
5618 return false; // illegal — individual pattern promotes
5619 }
5620 }
5621 return true;
5622 }
5623
5624 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5625 unsigned laneSize = getVectorLaneSize(resultType);
5626
5627 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
5628 });
5629
5630 target.addDynamicallyLegalOp<arith::MinimumFOp>([=](arith::MinimumFOp op) {
5631 auto resultType = dyn_cast<VectorType>(op.getType());
5632 if (!resultType)
5633 return true;
5634
5635 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5636 unsigned laneSize = getVectorLaneSize(resultType);
5637 unsigned totalBits = laneSize * resultElWidth;
5638
5639 return !elWidthSet.count(resultElWidth) ||
5640 (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16));
5641 });
5642
5643 target.addDynamicallyLegalOp<arith::MaximumFOp>([=](arith::MaximumFOp op) {
5644 auto resultType = dyn_cast<VectorType>(op.getType());
5645 if (!resultType)
5646 return true;
5647
5648 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5649 unsigned laneSize = getVectorLaneSize(resultType);
5650 unsigned totalBits = laneSize * resultElWidth;
5651
5652 return !elWidthSet.count(resultElWidth) ||
5653 (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16));
5654 });
5655
5656 target.addDynamicallyLegalOp<arith::MaxNumFOp>([=](arith::MaxNumFOp op) {
5657 auto resultType = dyn_cast<VectorType>(op.getType());
5658 if (!resultType)
5659 return true;
5660
5661 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5662 unsigned laneSize = getVectorLaneSize(resultType);
5663 unsigned totalBits = laneSize * resultElWidth;
5664
5665 return !elWidthSet.count(resultElWidth) ||
5666 (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16));
5667 });
5668
5669 target.addDynamicallyLegalOp<arith::MinNumFOp>([=](arith::MinNumFOp op) {
5670 auto resultType = dyn_cast<VectorType>(op.getType());
5671 if (!resultType)
5672 return true;
5673
5674 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5675 unsigned laneSize = getVectorLaneSize(resultType);
5676 unsigned totalBits = laneSize * resultElWidth;
5677
5678 return !elWidthSet.count(resultElWidth) ||
5679 (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16));
5680 });
5681
5682 target.addDynamicallyLegalOp<arith::CmpIOp>([=](arith::CmpIOp op) {
5683 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
5684 if (!lhsType)
5685 return true;
5686
5687 auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth();
5688 unsigned laneSize = getVectorLaneSize(lhsType);
5689 unsigned totalBits = laneSize * lhsElWidth;
5690
5691 return !elWidthSet.count(lhsElWidth) ||
5692 (totalBits != 512 && !(totalBits == 256 && lhsElWidth == 16));
5693 });
5694
5695 target.addDynamicallyLegalOp<arith::CmpFOp>([=](arith::CmpFOp op) {
5696 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
5697 if (!lhsType)
5698 return true;
5699
5700 auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth();
5701 unsigned laneSize = getVectorLaneSize(lhsType);
5702 unsigned totalBits = laneSize * lhsElWidth;
5703
5704 return !elWidthSet.count(lhsElWidth) ||
5705 (totalBits != 512 && !(totalBits == 256 && lhsElWidth == 16));
5706 });
5707
5708 target.addDynamicallyLegalOp<arith::SelectOp>([=](arith::SelectOp op) {
5709 auto resultType = dyn_cast<VectorType>(op.getType());
5710 if (!resultType)
5711 return true;
5712
5713 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
5714 unsigned laneSize = getVectorLaneSize(resultType);
5715 unsigned totalBits = laneSize * resultElWidth;
5716
5717 return !elWidthSet.count(resultElWidth) ||
5718 (totalBits != 512 && !(totalBits == 256 && resultElWidth == 16));
5719 });
5720
5721 target.addDynamicallyLegalOp<vector::ReductionOp>(
5722 [=](vector::ReductionOp op) {
5723 if (auto kind = op.getKind(); kind != vector::CombiningKind::ADD &&
5724 kind != vector::CombiningKind::MINSI &&
5725 kind != vector::CombiningKind::MINUI &&
5726 kind != vector::CombiningKind::MINIMUMF &&
5727 kind != vector::CombiningKind::MINNUMF &&
5728 kind != vector::CombiningKind::MAXSI &&
5729 kind != vector::CombiningKind::MAXUI &&
5730 kind != vector::CombiningKind::MAXIMUMF &&
5731 kind != vector::CombiningKind::MAXNUMF)
5732 return true;
5733
5734 auto vType = dyn_cast<VectorType>(op.getVector().getType());
5735 if (!vType)
5736 return true;
5737
5738 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
5739 laneSizeElWidthPairSet.insert({64, 8});
5740 laneSizeElWidthPairSet.insert({32, 16});
5741 laneSizeElWidthPairSet.insert({32, 32});
5742 laneSizeElWidthPairSet.insert({16, 32});
5743
5744 Type scalarType = vType.getElementType();
5745 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5746 unsigned laneSize = getVectorLaneSize(vType);
5747
5748 if (isa<IntegerType>(scalarType) &&
5749 !laneSizeElWidthPairSet.count(std::make_pair(laneSize, elWidth)))
5750 return true;
5751
5752 if (isa<FloatType>(scalarType) && laneSize != 16 && laneSize != 32)
5753 return true;
5754
5755 return false;
5756 });
5757
5758 // AIE2-specific legalization: ExtFOp on vector is always illegal
5759 target.addDynamicallyLegalOp<arith::ExtFOp>([](arith::ExtFOp extfOp) {
5760 auto srcType = dyn_cast<VectorType>(extfOp.getIn().getType());
5761 auto dstType = dyn_cast<VectorType>(extfOp.getOut().getType());
5762 if (!srcType || !dstType)
5763 return true;
5764
5765 Type srcScalarType = srcType.getElementType();
5766 Type dstScalarType = dstType.getElementType();
5767 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
5768 return true;
5769
5770 unsigned srcLaneSize = getVectorLaneSize(srcType);
5771 unsigned dstLaneSize = getVectorLaneSize(dstType);
5772 if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
5773 return false;
5774
5775 return true;
5776 });
5777
5778 // AIE2-specific legalization: TruncFOp on vector is always illegal
5779 target.addDynamicallyLegalOp<arith::TruncFOp>([](arith::TruncFOp truncfOp) {
5780 auto srcType = dyn_cast<VectorType>(truncfOp.getIn().getType());
5781 auto dstType = dyn_cast<VectorType>(truncfOp.getOut().getType());
5782 if (!srcType || !dstType)
5783 return true;
5784
5785 Type srcScalarType = srcType.getElementType();
5786 Type dstScalarType = dstType.getElementType();
5787 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
5788 return true;
5789
5790 unsigned srcLaneSize = getVectorLaneSize(srcType);
5791 unsigned dstLaneSize = getVectorLaneSize(dstType);
5792 if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
5793 return false;
5794
5795 return true;
5796 });
5797
5798 // AIE2-specific legalization: ExtSIOp on vector is always illegal
5799 target.addDynamicallyLegalOp<arith::ExtSIOp>([](arith::ExtSIOp extsiOp) {
5800 auto srcType = dyn_cast<VectorType>(extsiOp.getIn().getType());
5801 auto dstType = dyn_cast<VectorType>(extsiOp.getOut().getType());
5802 if (!srcType || !dstType)
5803 return true;
5804
5805 Type srcScalarType = srcType.getElementType();
5806 Type dstScalarType = dstType.getElementType();
5807 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
5808 return true;
5809
5810 unsigned srcLaneSize = getVectorLaneSize(srcType);
5811 unsigned dstLaneSize = getVectorLaneSize(dstType);
5812 if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
5813 return false;
5814
5815 return true;
5816 });
5817
5818 // AIE2-specific legalization: TruncIOp on vector is always illegal
5819 target.addDynamicallyLegalOp<arith::TruncIOp>([](arith::TruncIOp trunciOp) {
5820 auto srcType = dyn_cast<VectorType>(trunciOp.getIn().getType());
5821 auto dstType = dyn_cast<VectorType>(trunciOp.getOut().getType());
5822 if (!srcType || !dstType) {
5823 // Scalar trunci: mark illegal if part of compound SRS chain
5824 if (!srcType && !dstType && isSRSCompoundCandidate(trunciOp))
5825 return false;
5826 return true;
5827 }
5828 Type srcScalarType = srcType.getElementType();
5829 Type dstScalarType = dstType.getElementType();
5830 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
5831 return true;
5832
5833 // Also mark as illegal if it's part of a shrsi+clamp+trunc SRS pattern,
5834 // even for sub-AIE-width vectors
5835 if (isSRSCompoundCandidate(trunciOp))
5836 return false;
5837
5838 unsigned srcLaneSize = getVectorLaneSize(srcType);
5839 unsigned dstLaneSize = getVectorLaneSize(dstType);
5840 if ((srcLaneSize % 16 == 0) && (dstLaneSize % 16 == 0))
5841 return false;
5842
5843 return true;
5844 });
5845
5846 target.addIllegalOp<vector::ContractionOp, vector::TransposeOp,
5847 vector::FMAOp>();
5848
5849 // AIE2-specific legalization: math.exp for v16bf16 and v32bf16 is illegal
5850 // (uses LUT)
5851 target.addDynamicallyLegalOp<math::ExpOp>([](math::ExpOp expOp) {
5852 auto srcType = dyn_cast<VectorType>(expOp.getOperand().getType());
5853 if (!srcType)
5854 return true;
5855
5856 Type scalarType = srcType.getElementType();
5857 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5858 unsigned laneSize = getVectorLaneSize(srcType);
5859 // AIE2: v16bf16 and v32bf16 are illegal (uses LUT-based lowering)
5860 if (!isa<FloatType>(scalarType) || (laneSize != 16 && laneSize != 32) ||
5861 elWidth != 16)
5862 return true;
5863 if (expOp->hasOneUse() && isInSigmoidOperationChain(expOp))
5864 return true;
5865
5866 return false;
5867 });
5868
5869 target.addDynamicallyLegalOp<math::RsqrtOp>([](math::RsqrtOp rsqrtOp) {
5870 auto srcType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
5871 if (!srcType)
5872 return true;
5873
5874 Type scalarType = srcType.getElementType();
5875 if (!isa<FloatType>(scalarType))
5876 return true;
5877
5878 unsigned laneSize = getVectorLaneSize(srcType);
5879 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
5880 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
5881 });
5882}
5883
5884//===----------------------------------------------------------------------===//
5885// Lowering passes
5886//===----------------------------------------------------------------------===//
5887
5888/// Lower incoming vector operations into their corresponding AIE vector
5889/// intrinsics.
struct LowerVectorToAIEVec : PassWrapper<LowerVectorToAIEVec, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerVectorToAIEVec)

  // NOTE(review): the constructors (including the one taking a
  // LowerVectorToAIEVecOptions, used by createLowerVectorToAIEVec below) are
  // elided in this view -- confirm against the full source.

  // In case we want to register this pass as a standalone pass for test
  // purposes.
  StringRef getArgument() const final { return "test-lower-vector-to-aievec"; }
  StringRef getDescription() const final {
    return "Lower vector operations to AIE vector intrinsics";
  }
  // Dialects the conversion may emit ops into must be pre-loaded.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry
        .insert<affine::AffineDialect, xilinx::aievec::aie1::AIEVecAIE1Dialect,
                xilinx::aievec::AIEVecDialect, arith::ArithDialect,
                memref::MemRefDialect, scf::SCFDialect, vector::VectorDialect,
                emitc::EmitCDialect>();
  }

  // Target AIE architecture; selects which pattern/legalization sets run.
  Option<std::string> aieTarget{
      *this, "aie-target",
      llvm::cl::desc(
          "Select AIE version: \"aie\", \"aie2\", or \"aie2p\". This will "
          "determine the vector size and available operations."),
      llvm::cl::init("aie")};

  // Translation backend; "cpp" is the default, "llvmir" needs AIE2 or later.
  Option<std::string> targetBackend{
      *this, "target-backend",
      llvm::cl::desc("Select translation backend: \"cpp\" or \"llvmir\". This "
                     "will determine the aievec operations used to convert "
                     "from vector dialect."),
      llvm::cl::init("cpp")};

  void runOnOperation() override {
    auto *op = getOperation();
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);
    // Parse the requested architecture; unknown strings fail the pass.
    // "aieml" is accepted as a legacy alias for "aie2".
    auto aieVersion = AIEArch::AIE;
    if (!aieTarget.empty()) {
      std::string targetStr = aieTarget;
      if (targetStr == "aieml" || targetStr == "aie2")
        aieVersion = AIEArch::AIE2;
      else if (targetStr == "aie2p")
        aieVersion = AIEArch::AIE2P;
      else if (targetStr != "aie") {
        op->emitError() << "unknown AIE target '" << aieTarget << "'";
        return signalPassFailure();
      }
    }

    // Parse the backend; LLVM IR output is rejected for AIEv1.
    TargetBackend backend = TargetBackend::CPP;
    if (!targetBackend.empty()) {
      std::string backendStr = targetBackend;
      if (backendStr == "llvmir") {
        backend = TargetBackend::LLVMIR;
        if (aieVersion == AIEArch::AIE) {
          op->emitError() << "targetting LLVM IR is not supported for AIEv1";
          signalPassFailure();
          return;
        }
      } else if (backendStr != "cpp") {
        op->emitError() << "unknown target backend '" << targetBackend << "'";
        signalPassFailure();
        return;
      }
    }

    // Common patterns/legalizations first, then architecture-specific ones.
    // AIE2P layers its legalizations on top of the AIE2 configuration.
    populateAIEVecCommonConversionPatterns(patterns, backend);
    configureAIEVecCommonLegalizations(target, backend);
    if (aieVersion == AIEArch::AIE) {
      populateAIEVecV1ConversionPatterns(patterns, backend);
      configureAIEVecV1Legalizations(target, backend);
    } else if (aieVersion == AIEArch::AIE2) {
      populateAIEVecV2ConversionPatterns(patterns, backend);
      configureAIEVecV2Legalizations(target, backend);
    } else if (aieVersion == AIEArch::AIE2P) {
      populateAIEVecV2PConversionPatterns(patterns, backend);
      configureAIEVecV2Legalizations(target, backend);
      configureAIEVecV2PLegalizations(target, backend);
    } else {
      llvm_unreachable("AIE version is misconfigured");
    }

    if (failed(applyPartialConversion(op, target, std::move(patterns))))
      return signalPassFailure();
  }
};
5985
5986static std::unique_ptr<Pass>
5987createLowerVectorToAIEVec(const LowerVectorToAIEVecOptions &options) {
5988 return std::make_unique<LowerVectorToAIEVec>(options);
5989}
5990
5991//===---------------------------------------------------------------------------
5992// Custom canonicalization passes
5993//===---------------------------------------------------------------------------
5994
5995// This pass widens UPD ops to twice the width followed by an ext op of the
5996// bottom half. This can be used together with SimplifyUPDOpsPass to find
5997// additional common subexpressions with UPDs generated from unaligned
5998// `transfer_read` ops.
5999struct ExtendUPDOpsPass : PassWrapper<ExtendUPDOpsPass, OperationPass<>> {
6000
6001 void runOnOperation() override {
6002 MLIRContext *context = &getContext();
6003 RewritePatternSet patterns(context);
6004 ConversionTarget target(*context);
6005 patterns.add<ExpandUPDToUPDAndExtPattern>(patterns.getContext());
6006 target.addLegalDialect<aievec::AIEVecDialect>();
6007 target.addDynamicallyLegalOp<aievec::UPDOp>([](aievec::UPDOp op) {
6008 return op.getVector() ||
6009 (op->hasOneUse() && isa<aievec::UPDOp>(*op->getUsers().begin())) ||
6010 llvm::all_of(op->getUsers(),
6011 [](Operation *op) { return isa<aievec::ExtOp>(op); });
6012 });
6013
6014 if (auto *op = getOperation();
6015 failed(applyPartialConversion(op, target, std::move(patterns)))) {
6016 return signalPassFailure();
6017 }
6018 }
6019};
6020
6021// This pass replaces wide UPD ops that are only used by a single ext op of the
6022// bottom half. This pass undos the work of ExtendUPDOpsPass.
6023// TODO: This pass can be extended to work with wide UPD ops that are used by
6024// TODO: a single ext op of the top half, which might be a good opportunity to
6025// TODO: further optimize wide UPDs.
6026struct SimplifyUPDOpsPass : PassWrapper<SimplifyUPDOpsPass, OperationPass<>> {
6027
6028 void runOnOperation() override {
6029 MLIRContext *context = &getContext();
6030 RewritePatternSet patterns(context);
6031 ConversionTarget target(*context);
6032 patterns.add<FuseExtIntoUPDPattern>(patterns.getContext());
6033 target.addLegalDialect<aievec::AIEVecDialect>();
6034 target.addDynamicallyLegalOp<aievec::ExtOp>([](aievec::ExtOp op) {
6035 auto *defOp = op.getSource().getDefiningOp();
6036 return !defOp || !isa<aievec::UPDOp>(defOp) || !defOp->hasOneUse() ||
6037 op.getIndex() != 0;
6038 });
6039
6040 if (auto *op = getOperation();
6041 failed(applyPartialConversion(op, target, std::move(patterns)))) {
6042 return signalPassFailure();
6043 }
6044 }
6045};
6046
6047//============================================================================//
6048//=============== Main Vector2AIEVec Pipeline Configuration ==================//
6049//============================================================================//
6050
    OpPassManager &pm, const LowerVectorToAIEVecOptions &options) {
  // Add lowering from `Vector` to `AIEVec`
  pm.addPass(createLowerVectorToAIEVec(options));
  // Clean up trivially-dead or foldable ops left behind by the conversion.
  pm.addPass(createCanonicalizerPass());

  // Simplify UPD ops: widen them, CSE the duplicates this exposes, then
  // shrink back any widened UPD that found no sharing.
  pm.addPass(std::make_unique<ExtendUPDOpsPass>());
  pm.addPass(createCSEPass());
  pm.addPass(std::make_unique<SimplifyUPDOpsPass>());
  pm.addPass(createCanonicalizerPass());
}
LowerScalarMinMaxToAIEVecMinMaxOp< arith::MaxSIOp, aievec::MaxOp > LowerScalarMaxSIOpToAIEVecMaxOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaximumFOp, aievec::MaxOp > LowerVectorMaximumFOpToAIEVecMaxOp
ComputeBandAndBorOpPattern< arith::OrIOp, aievec::BorOp > ComputeBorOpPattern
ComputeBandAndBorOpPattern< arith::AndIOp, aievec::BandOp > ComputeBandOpPattern
OneToOneVectorOpToAIEVecOpPattern< arith::SubFOp, aievec::aie1::SubOp > LowerVectorSubFOpToAIEVecSubOp
ComputeAbsOpPattern< math::AbsIOp > ComputeAbsIOpPattern
LowerTruncOpPattern< arith::TruncFOp > LowerTruncFOpPattern
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::AddFOp, aievec::AddElemOp > LowerVectorAddFOpToAIEVecAddElemOp
LowerScalarMinMaxToAIEVecMinMaxOp< arith::MinSIOp, aievec::MinOp > LowerScalarMinSIOpToAIEVecMinOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaxSIOp, aievec::MaxOp > LowerVectorMaxSIOpToAIEVecMaxOp
LowerExtOpPattern< arith::ExtFOp > LowerExtFOpPattern
LowerVectorCmpOpToAIEVecCmpOp< arith::CmpFOp, CmpFPredicate > LowerVectorCmpFOpToAIEVecCmpOp
OneToOneVectorOpToAIEVecOpPattern< arith::SubIOp, aievec::aie1::SubOp > LowerVectorSubIOpToAIEVecSubOp
ComputeAbsOpPattern< math::AbsFOp > ComputeAbsFOpPattern
LowerVectorCmpOpToAIEVecCmpOp< arith::CmpIOp, CmpIPredicate > LowerVectorCmpIOpToAIEVecCmpOp
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::SubFOp, aievec::SubElemOp > LowerVectorSubFOpToAIEVecSubElemOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MinSIOp, aievec::MinOp > LowerVectorMinSIOpToAIEVecMinOp
OneToOneVectorOpToAIEVecOpPattern< arith::AddFOp, aievec::aie1::AddOp > LowerVectorAddFOpToAIEVecAddOp
OneToOneVectorOpToAIEVecOpPattern< arith::MulFOp, aievec::aie1::MulOp > LowerVectorMulFOpToAIEVecMulOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaxNumFOp, aievec::MaxOp > LowerVectorMaxNumFFOpToAIEVecMaxOp
LowerExtOpPattern< arith::ExtSIOp > LowerExtSIOpPattern
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MinimumFOp, aievec::MinOp > LowerVectorMinimumFOpToAIEVecMinOp
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::AddIOp, aievec::AddElemOp > LowerVectorAddIOpToAIEVecAddElemOp
PathEndPoint src
mlir::VectorType getFlattenedVectorType(mlir::VectorType vecTy)
unsigned getVectorLaneSize(mlir::VectorType type)
Definition AIEVecUtils.h:55
SmallVector< NamedAttribute > buildFMAOpSplatAttrForElemTy(aievec::aie1::FMAOp fmaOp, int64_t bcastPos, int64_t step=1)
std::optional< int64_t > getTransferReadAlignmentOffset(TransferReadLikeOp readOp, mlir::VectorType vType, int64_t alignment)
mlir::VectorType createVectorType(unsigned lanes, mlir::Type elementType)
Definition AIEVecUtils.h:42
int32_t getElementSizeInBits(mlir::VectorType type)
Definition AIEVecUtils.h:49
void buildLowerVectorToAIEVec(mlir::OpPassManager &pm, const LowerVectorToAIEVecOptions &options)
mlir::VectorType getVectorOpDestType(mlir::VectorType type, bool AIE2)
Definition AIEVecUtils.h:80
TargetBackend
Definition Passes.h:27
LogicalResult matchAndRewrite(SrcOpTy absOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::XOrIOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::CeilOp ceilOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ErfOp erfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::FloorOp floorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::NegFOp negOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::ShRSIOp rsOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::SqrtOp sqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::AddFOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulAddFToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulAddToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(aievec::aie1::AddOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::MulFOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulFToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulIToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertVectorFMAOpToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
ExpandUPDToUPDAndExtPattern(MLIRContext *context)
LogicalResult matchAndRewrite(aievec::UPDOp updOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::aie1::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::BroadcastOp bcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtOp extOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
FuseExtIntoUPDPattern(MLIRContext *context)
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy extOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::ShRSIOp rsOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::TruncIOp truncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
static std::optional< int64_t > getScalarConstantValue(Value val)
LogicalResult matchAndRewrite(SrcOpTy truncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LowerVectorContractionOpToAIEVecMatMulPattern(MLIRContext *context, bool matMoveToAcc=true)
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::SelectOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Lower incoming vector operations into their corresponding AIE vector intrinsics.
void getDependentDialects(DialectRegistry &registry) const override
LowerVectorToAIEVec(const LowerVectorToAIEVecOptions &options)
StringRef getDescription() const final
Option< std::string > targetBackend
StringRef getArgument() const final
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LowerVectorTransferReadToAIEUPD(MLIRContext *context, int64_t minVectorSize, int64_t maxVectorSize, int64_t alignment, int64_t maxLoadSize)
LogicalResult matchAndRewrite(vector::TransposeOp transpOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ShiftClampTruncToSRSPattern(MLIRContext *context, PatternBenefit benefit=2)
static std::optional< Value > getShiftValue(Value rhs, ConversionPatternRewriter &rewriter, Location loc)
static std::optional< int64_t > getConstantSplatValue(Value val)
LogicalResult matchAndRewrite(arith::TruncIOp truncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Options for the "lower-vector-to-aievec" pipeline.
Definition Passes.h:64
PassOptions::Option< std::string > aieTarget
Definition Passes.h:65
PassOptions::Option< std::string > targetBackend
Definition Passes.h:70