MLIR-AIE
VectorToAIEVecConversions.cpp
Go to the documentation of this file.
1//===-VectorToAIEVecConversions.cpp - Vector to AIEVec convs. ---*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2023, Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10// This file contains conversions from the Vector dialect into the AIEVec
// dialect. Conversions assume that the Vector dialect has been restricted
12// to ops that can be translated to a sequence of valid AIEVec ops.
13//===----------------------------------------------------------------------===//
14
19
21#include "mlir/Dialect/Affine/IR/AffineOps.h"
22#include "mlir/Dialect/EmitC/IR/EmitC.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/Math/IR/Math.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SCF/IR/SCF.h"
27#include "mlir/Dialect/UB/IR/UBOps.h"
28#include "mlir/IR/PatternMatch.h"
29#include "mlir/IR/SymbolTable.h"
30#include "mlir/IR/TypeUtilities.h"
31#include "mlir/Pass/PassManager.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "mlir/Transforms/Passes.h"
34#include "llvm/ADT/SmallSet.h"
35#include <bitset>
36#include <optional>
37#include <tuple>
38
39#define DEBUG_TYPE "lower-vector-to-aievec"
40
41using namespace llvm;
42using namespace mlir;
43using namespace arith;
44using namespace vector;
45using namespace xilinx;
46using namespace xilinx::aievec;
47
48//===----------------------------------------------------------------------===//
49// Utility functions
50//===----------------------------------------------------------------------===//
51
52static bool isNarrowingOp(Operation *op) {
53 if (isa<arith::TruncFOp>(op) || isa<arith::TruncIOp>(op))
54 return true;
55
56 if (auto srsOp = dyn_cast<aievec::SRSOp>(op)) {
57 auto *srsOpSrcOp = srsOp.getSource().getDefiningOp();
58 if (isa<aievec::UPSOp>(srsOpSrcOp) || isa<aievec::CastOp>(srsOpSrcOp))
59 return true;
60 }
61 return false;
62}
63
64// Given a Value, if it is defined by a widening op (arith:ExtSIOp,
65// arith::ExtUIOp, arith::ExtFOp, aievec::UPSOp + aievec::SRSOp,
66// aievec::UPSOp + aievec::CastOp), return the source of the widening op.
67static std::optional<Value> getSourceOfWideningOp(Value src) {
68 if (auto extSIOp = src.getDefiningOp<arith::ExtSIOp>())
69 return extSIOp.getIn();
70 if (auto extUIOp = src.getDefiningOp<arith::ExtUIOp>())
71 return extUIOp.getIn();
72 if (auto extFOp = src.getDefiningOp<arith::ExtFOp>())
73 return extFOp.getIn();
74 if (auto srsOp = src.getDefiningOp<aievec::SRSOp>()) {
75 // Conversion through AIE intrinsics takes two steps:
76 // 1) Load to accumulator: aievec.ups
77 // 2) Move from accumulator: aievec.srs
78 auto srsSource = srsOp.getSource();
79 if (srsSource)
80 if (auto upsOp = srsSource.getDefiningOp<aievec::UPSOp>())
81 return upsOp.getSource();
82 }
83 if (auto castOp = src.getDefiningOp<aievec::CastOp>()) {
84 // Conversion through AIE intrinsics can also take the following two steps:
85 // 1) Load to accumulator: aievec.ups
86 // 2) Move from accumulator: aievec.cast
87 auto castSource = castOp.getSource();
88 if (castSource)
89 if (auto upsOp = castSource.getDefiningOp<aievec::UPSOp>())
90 return upsOp.getSource();
91 }
92 return std::optional<Value>();
93}
94
95// Given the LHS and RHS of an `arith::AddIOp`, if one of them is defined by an
96// `arith::MulIOp`, return a tuple with the `lhs`, `rhs`, and `acc` of the MAC
97// operation that can replace them.
98static std::optional<std::tuple<Value, Value, Value>>
99extractMACOperandsFromAddOperands(Value addLhs, Value addRhs) {
100 auto *lhsDefOp = addLhs.getDefiningOp();
101 auto *rhsDefOp = addRhs.getDefiningOp();
102 arith::MulIOp mulOp = nullptr;
103 Value acc;
104 if (lhsDefOp) {
105 mulOp = dyn_cast<arith::MulIOp>(lhsDefOp);
106 acc = addRhs;
107 }
108 if (!mulOp && rhsDefOp) {
109 mulOp = dyn_cast<arith::MulIOp>(rhsDefOp);
110 acc = addLhs;
111 }
112 if (mulOp)
113 return std::make_tuple(mulOp.getLhs(), mulOp.getRhs(), acc);
114
115 // If the MulIOp has been already translated to aievec::aie1::MulOp:
116 auto lhsSrsOp = addLhs.getDefiningOp<aievec::SRSOp>();
117 auto rhsSrsOp = addRhs.getDefiningOp<aievec::SRSOp>();
118 aievec::aie1::MulOp aieMulOp = nullptr;
119 if (lhsSrsOp) {
120 aieMulOp = lhsSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>();
121 acc = addRhs;
122 }
123 if (!aieMulOp && rhsSrsOp) {
124 aieMulOp = rhsSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>();
125 acc = addLhs;
126 }
127 if (aieMulOp)
128 return std::make_tuple(aieMulOp.getLhs(), aieMulOp.getRhs(), acc);
129 return {};
130}
131
132// Convert a input value to a target vector type. This function can insert
133// multiple aievec ops depending on the combination of input and output vector
134// types.
135static std::optional<Value>
136convertValueToTargetTypeAIE2(ConversionPatternRewriter &rewriter, Location loc,
137 Value inputVal, VectorType tgtType) {
138 auto srcType = cast<VectorType>(inputVal.getType());
139 auto srcElemType = srcType.getElementType();
140 unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth();
141 unsigned srcLaneSize = getVectorLaneSize(srcType);
142
143 auto tgtElemType = tgtType.getElementType();
144 unsigned tgtBitWidth = tgtElemType.getIntOrFloatBitWidth();
145 unsigned tgtLaneSize = getVectorLaneSize(tgtType);
146
147 if (srcType == tgtType)
148 return inputVal;
149
150 if ((srcElemType == tgtElemType) && (srcLaneSize != tgtLaneSize)) {
151 // TODO: relax the condition below?
152 if ((srcLaneSize == 16 && tgtLaneSize == 32 &&
153 isa<FloatType>(srcElemType)) ||
154 (srcLaneSize == 32 && tgtLaneSize == 64 &&
155 isa<IntegerType>(srcElemType))) {
156 auto zeroConstOp = rewriter.create<arith::ConstantOp>(
157 loc, srcType.getElementType(),
158 rewriter.getZeroAttr(srcType.getElementType()));
159 auto broadcastZeroOp = rewriter.create<aievec::BroadcastScalarOp>(
160 loc, tgtType, zeroConstOp->getResult(0));
161 auto extOp = rewriter.create<aievec::ExtOp>(
162 loc, srcType, broadcastZeroOp->getResult(0), 0);
163
164 SmallVector<Value> inputSources = {inputVal, extOp->getResult(0)};
165 auto concatOp =
166 rewriter.create<aievec::ConcatOp>(loc, tgtType, inputSources);
167
168 return concatOp.getResult();
169 }
170 } else if ((srcElemType != tgtElemType) && (srcLaneSize == tgtLaneSize) &&
171 isa<IntegerType>(srcElemType) && isa<IntegerType>(tgtElemType)) {
172 if (srcBitWidth == 16 && tgtBitWidth == 32 && srcLaneSize == 16) {
173 // Case 1: vector<16xi16> to vector<16xi32> conversion by aievec.ups +
174 // aievec.cast
175 auto accType = getVectorOpDestType(srcType, /*AIE2 =*/true);
176 auto upsOp = rewriter.create<aievec::UPSOp>(loc, accType, inputVal);
177 auto castOp = rewriter.create<aievec::CastOp>(
178 loc, tgtType, upsOp.getResult(), /*isResAcc*/ false);
179 return castOp.getResult();
180 }
181
182 if (srcBitWidth == 8 && tgtBitWidth == 32 && srcLaneSize == 16) {
183 // Case 2: vector<16xi8> to vector<16xi32> conversion by aievec.concat +
184 // aievec.ups + aievec.cast + aievec.ext
185 auto concatOutType = createVectorType(32, srcElemType);
186 auto concatOp = rewriter.create<aievec::ConcatOp>(
187 loc, concatOutType, SmallVector<Value>({inputVal, inputVal}));
188 auto accType = getVectorOpDestType(concatOutType, /*AIE2 =*/true);
189 auto upsOp =
190 rewriter.create<aievec::UPSOp>(loc, accType, concatOp.getResult());
191 auto castType = createVectorType(32, tgtElemType);
192 auto castOp = rewriter.create<aievec::CastOp>(
193 loc, castType, upsOp.getResult(), /*isResAcc*/ false);
194 auto extOp =
195 rewriter.create<aievec::ExtOp>(loc, tgtType, castOp.getResult(), 0);
196 return extOp.getResult();
197 }
198
199 if (srcBitWidth == 8 && tgtBitWidth == 16 && srcLaneSize == 32) {
200 // Case 3: vector<32xi8> to vector<32xi16> conversion by aievec.unpack
201 auto unpackOp = rewriter.create<aievec::UnpackOp>(loc, tgtType, inputVal);
202 return unpackOp.getResult();
203 }
204 }
205
206 return std::nullopt;
207}
208
209// Return the list of attributes that configure an `aievec.select` op to
210// perform a rotation of the input vector by `rotation` number of elements.
211// The attribute values depend on the vector type of the select operation.
212static SmallVector<NamedAttribute>
213buildAttributeListForRotationSelectOp(PatternRewriter &rewriter, VectorType vTy,
214 int64_t rotation) {
215 unsigned width = 0;
216 auto elemTy = vTy.getElementType();
217 if (auto intTy = dyn_cast<IntegerType>(elemTy))
218 width = intTy.getWidth();
219 StringAttr attr0 = rewriter.getStringAttr("0");
220 StringAttr attr0x06040200 = rewriter.getStringAttr("0x06040200");
221 StringAttr attr0x0e0c0a08 = rewriter.getStringAttr("0x0e0c0a08");
222 StringAttr attr0x2103 = rewriter.getStringAttr("0x2103");
223 StringAttr attr0x3210 = rewriter.getStringAttr("0x3210");
224 StringAttr selectAttrName = rewriter.getStringAttr("select");
225 StringAttr xoffsetsAttrName = rewriter.getStringAttr("xoffsets");
226 StringAttr xoffsetsHiAttrName = rewriter.getStringAttr("xoffsets_hi");
227 StringAttr xsquareAttrName = rewriter.getStringAttr("xsquare");
228 StringAttr xstartAttrName = rewriter.getStringAttr("xstart");
229 StringAttr yoffsetsAttrName = rewriter.getStringAttr("yoffsets");
230 StringAttr yoffsetsHiAttrName = rewriter.getStringAttr("yoffsets_hi");
231 StringAttr ysquareAttrName = rewriter.getStringAttr("ysquare");
232 StringAttr ystartAttrName = rewriter.getStringAttr("ystart");
233
234 switch (width) {
235 case 16: {
236 if (rotation % 2) {
237 int64_t xstart = rotation + 1;
238 int64_t ystart = rotation - 1;
239 return SmallVector<NamedAttribute, 9>(
240 {{selectAttrName, rewriter.getStringAttr("0x11111111")},
241 {xoffsetsAttrName, attr0x06040200},
242 {xoffsetsHiAttrName, attr0x0e0c0a08},
243 {xsquareAttrName, attr0x2103},
244 {xstartAttrName, rewriter.getStringAttr(std::to_string(xstart))},
245 {yoffsetsAttrName, rewriter.getStringAttr("0x0503010f")},
246 {yoffsetsHiAttrName, rewriter.getStringAttr("0x0d0b0907")},
247 {ysquareAttrName, attr0x2103},
248 {ystartAttrName, rewriter.getStringAttr(std::to_string(ystart))}});
249 }
250 return SmallVector<NamedAttribute, 9>(
251 {{selectAttrName, attr0},
252 {xoffsetsAttrName, attr0x06040200},
253 {xoffsetsHiAttrName, attr0x0e0c0a08},
254 {xsquareAttrName, attr0x3210},
255 {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))},
256 {yoffsetsAttrName, attr0},
257 {yoffsetsHiAttrName, attr0},
258 {ysquareAttrName, attr0},
259 {ystartAttrName, attr0}});
260 }
261 case 32:
262 return SmallVector<NamedAttribute, 7>(
263 {{selectAttrName, attr0},
264 {xoffsetsAttrName, rewriter.getStringAttr("0x76543210")},
265 {xsquareAttrName, attr0x3210},
266 {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))},
267 {yoffsetsAttrName, attr0},
268 {ysquareAttrName, attr0},
269 {ystartAttrName, attr0}});
270 default:
271 llvm::report_fatal_error("Unexpected width!");
272 }
273
274 return {};
275}
276
277namespace xilinx::aievec {
278
279SmallVector<NamedAttribute>
280buildFMAOpSplatAttrForElemTy(aievec::aie1::FMAOp fmaOp, int64_t bcastPos,
281 int64_t step = 1) {
282 unsigned width = 0;
283 auto elemTy = fmaOp.getLhs().getType().getElementType();
284 if (auto intTy = dyn_cast<IntegerType>(elemTy))
285 width = intTy.getWidth();
286 auto *ctx = fmaOp.getContext();
287 switch (width) {
288 case 16:
289 // NOTE: The pattern is:
290 // acc[0] = x[0] * z[bcastPos] + x[16] * z[bcastPos+step]
291 // acc[1] = x[1] * z[bcastPos] + x[17] * z[bcastPos+step]
292 // acc[2] = x[2] * z[bcastPos] + x[18] * z[bcastPos+step]
293 // acc[3] = x[3] * z[bcastPos] + x[19] * z[bcastPos+step]
294 // acc[4] = x[4] * z[bcastPos] + x[20] * z[bcastPos+step]
295 // acc[5] = x[5] * z[bcastPos] + x[21] * z[bcastPos+step]
296 // acc[6] = x[6] * z[bcastPos] + x[22] * z[bcastPos+step]
297 // acc[7] = x[7] * z[bcastPos] + x[23] * z[bcastPos+step]
298 // acc[8] = x[8] * z[bcastPos] + x[24] * z[bcastPos+step]
299 // acc[9] = x[9] * z[bcastPos] + x[25] * z[bcastPos+step]
300 // acc[10] = x[10] * z[bcastPos] + x[26] * z[bcastPos+step]
301 // acc[11] = x[11] * z[bcastPos] + x[27] * z[bcastPos+step]
302 // acc[12] = x[12] * z[bcastPos] + x[28] * z[bcastPos+step]
303 // acc[13] = x[13] * z[bcastPos] + x[29] * z[bcastPos+step]
304 // acc[14] = x[14] * z[bcastPos] + x[30] * z[bcastPos+step]
305 // acc[15] = x[15] * z[bcastPos] + x[31] * z[bcastPos+step]
306 return SmallVector<NamedAttribute, 11>(
307 {{fmaOp.getXstartAttrName(), StringAttr::get(ctx, "0")},
308 {fmaOp.getXoffsetsAttrName(), StringAttr::get(ctx, "0x73727170")},
309 {fmaOp.getXoffsetsHiAttrName(), StringAttr::get(ctx, "0x77767574")},
310 {fmaOp.getXstepAttrName(), fmaOp.getXstepAttr()},
311 {fmaOp.getXsquareAttrName(), StringAttr::get(ctx, "0x3120")},
312 {fmaOp.getZstartAttrName(),
313 StringAttr::get(ctx, std::to_string(bcastPos))},
314 {fmaOp.getZoffsetsAttrName(), StringAttr::get(ctx, "0")},
315 {fmaOp.getZoffsetsHiAttrName(), StringAttr::get(ctx, "0")},
316 {fmaOp.getZstepAttrName(), StringAttr::get(ctx, std::to_string(step))},
317 {fmaOp.getZsquareAttrName(), fmaOp.getZsquareAttr()},
318 {fmaOp.getFmsubAttrName(), fmaOp.getFmsubAttr()}});
319 case 32:
320 return SmallVector<NamedAttribute, 11>(
321 {{fmaOp.getXstartAttrName(), StringAttr::get(ctx, "0")},
322 {fmaOp.getXoffsetsAttrName(), StringAttr::get(ctx, "0x76543210")},
323 {fmaOp.getXoffsetsHiAttrName(), fmaOp.getXoffsetsHiAttr()},
324 {fmaOp.getXstepAttrName(), fmaOp.getXstepAttr()},
325 {fmaOp.getXsquareAttrName(), fmaOp.getXsquareAttr()},
326 {fmaOp.getZstartAttrName(),
327 StringAttr::get(ctx, std::to_string(bcastPos))},
328 {fmaOp.getZoffsetsAttrName(), StringAttr::get(ctx, "0x00000000")},
329 {fmaOp.getZoffsetsHiAttrName(), fmaOp.getZoffsetsHiAttr()},
330 {fmaOp.getZstepAttrName(), fmaOp.getZstepAttr()},
331 {fmaOp.getZsquareAttrName(), fmaOp.getZsquareAttr()},
332 {fmaOp.getFmsubAttrName(), fmaOp.getFmsubAttr()}});
333 default:
334 llvm::report_fatal_error("Unexpected width!");
335 }
336
337 return {};
338}
339
340} // namespace xilinx::aievec
341
342template <typename SrcOpTy, typename AIEv2ElemOp>
343static LogicalResult genAddElemAIE2(ConversionPatternRewriter &rewriter,
344 Value lval, Value rval, VectorType srcType,
345 SrcOpTy srcOp) {
346 auto lCastOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), srcType, lval,
347 /*isResAcc*/ true);
348 auto rCastOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), srcType, rval,
349 /*isResAcc*/ true);
350 auto elemOp = rewriter.create<AIEv2ElemOp>(
351 srcOp.getLoc(), lCastOp->getResult(0).getType(), lCastOp->getResult(0),
352 rCastOp->getResult(0));
353 rewriter.replaceOpWithNewOp<aievec::CastOp>(
354 srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false);
355 return success();
356}
357
358static arith::CmpIPredicate
359convertToIntegerPredicate(arith::CmpFPredicate pred) {
360 switch (pred) {
361 case CmpFPredicate::UEQ:
362 case CmpFPredicate::OEQ:
363 return CmpIPredicate::eq;
364 case CmpFPredicate::UGT:
365 return CmpIPredicate::ugt;
366 case CmpFPredicate::OGT:
367 return CmpIPredicate::sgt;
368 case CmpFPredicate::UGE:
369 return CmpIPredicate::uge;
370 case CmpFPredicate::OGE:
371 return CmpIPredicate::sge;
372 case CmpFPredicate::ULT:
373 return CmpIPredicate::ult;
374 case CmpFPredicate::OLT:
375 return CmpIPredicate::slt;
376 case CmpFPredicate::ULE:
377 return CmpIPredicate::ule;
378 case CmpFPredicate::OLE:
379 return CmpIPredicate::sle;
380 case CmpFPredicate::UNE:
381 case CmpFPredicate::ONE:
382 return CmpIPredicate::ne;
383 default:
384 llvm::report_fatal_error("Unexpected predicate!");
385 }
386}
387
388static arith::CmpIPredicate
389convertToIntegerPredicate(arith::CmpIPredicate pred) {
390 return pred;
391}
392
393static aievec::CmpOp createCmpOpAIE2(ConversionPatternRewriter &rewriter,
394 CmpIPredicate pred, Location loc,
395 Type type, Value lhs, Value rhs) {
396 switch (pred) {
397 case CmpIPredicate::eq:
398 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "eq");
399 case CmpIPredicate::ne:
400 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "ne");
401 case CmpIPredicate::slt:
402 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "slt");
403 case CmpIPredicate::ult:
404 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "ult");
405 case CmpIPredicate::sle:
406 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "sle");
407 case CmpIPredicate::ule:
408 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "ule");
409 case CmpIPredicate::sgt:
410 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "sgt");
411 case CmpIPredicate::ugt:
412 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "ugt");
413 case CmpIPredicate::sge:
414 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "sge");
415 case CmpIPredicate::uge:
416 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "uge");
417 }
418 return nullptr;
419}
420
421template <typename DstOpTy>
422static void generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter,
423 vector::ReductionOp srcOp,
424 int shiftIndex, Value curValue) {
425 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
426 "shiftIndex must be power of 2");
427
428 Location loc = srcOp.getLoc();
429 auto vType = dyn_cast<VectorType>(curValue.getType());
430 Type scalarType = vType.getElementType();
431 Type vecType = curValue.getType();
432 DstOpTy curOp = nullptr;
433 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
434
435 for (int id = shiftIndex; id > 0; id /= 2) {
436 auto constOp = rewriter.create<arith::ConstantOp>(
437 loc, rewriter.getI32IntegerAttr(id * elWidth / 8));
438
439 auto shiftBytesOp = rewriter.create<aievec::ShiftOp>(
440 loc, vecType, curValue, curValue, constOp.getResult());
441
442 curOp = rewriter.create<DstOpTy>(loc, vecType, curValue,
443 shiftBytesOp.getResult());
444
445 curValue = curOp.getResult();
446 }
447
448 auto zeroConstOp =
449 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
450 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, curOp,
451 zeroConstOp.getResult());
452}
453
454static func::FuncOp getOrInsertFuncDecl(ConversionPatternRewriter &rewriter,
455 mlir::ModuleOp parentModuleOp,
456 StringRef funcName, TypeRange inTypes,
457 TypeRange outTypes) {
458
459 mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
460 rewriter.setInsertionPointToStart(
461 &parentModuleOp.getRegion().getBlocks().front());
462 SymbolTable st = SymbolTable(parentModuleOp);
463 func::FuncOp fnOpLookup = st.lookup<func::FuncOp>(funcName);
464 func::FuncOp fnOp;
465 // if the function is already declared, use the existing function, don't
466 // declare multiple times
467 if (fnOpLookup != NULL) {
468 fnOp = fnOpLookup;
469 } else {
470 StringAttr t1 = rewriter.getStringAttr("sym_visibility");
471 StringAttr t2 = rewriter.getStringAttr("private");
472 NamedAttribute funcAccess = NamedAttribute(t1, t2);
473 FunctionType fnType =
474 mlir::FunctionType::get(rewriter.getContext(), inTypes, outTypes);
475 fnOp = rewriter.create<func::FuncOp>(parentModuleOp.getLoc(), funcName,
476 fnType, funcAccess);
477 }
478 return fnOp;
479}
480
481static bool matchExpOpForLUT(math::ExpOp::Adaptor adaptor) {
482 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
483
484 if (!srcType)
485 return false;
486
487 Type scalarType = srcType.getElementType();
488 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
489 unsigned laneSize = getVectorLaneSize(srcType);
490 return isa<FloatType>(scalarType) && laneSize == 16 && elWidth == 16;
491}
492
493//===----------------------------------------------------------------------===//
494// Rewrite patterns
495//===----------------------------------------------------------------------===//
496
// This pattern folds `vector.extract` and `vector.splat` into
498// `aievec.broadcast` for AIE2
500 : OpConversionPattern<vector::SplatOp> {
501 using OpConversionPattern::OpConversionPattern;
502
503 LogicalResult
504 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
505 ConversionPatternRewriter &rewriter) const override {
506
507 auto extOp = adaptor.getInput().getDefiningOp<vector::ExtractOp>();
508
509 if (!extOp)
510 return failure();
511
512 auto src = extOp.getVector();
513 auto pos = extOp.getStaticPosition();
514 int64_t posVal = pos[0];
515 auto srcVecType = cast<VectorType>(src.getType());
516 auto resultType = cast<VectorType>(splatOp.getResult().getType());
517 if (srcVecType != resultType) {
518 if (srcVecType.getNumElements() != 2 * resultType.getNumElements())
519 return failure();
520 auto half = static_cast<int8_t>(posVal / resultType.getNumElements());
521 posVal -= half * resultType.getNumElements();
522 src = rewriter
523 .create<aievec::ExtOp>(extOp.getLoc(), resultType, src,
524 rewriter.getI8IntegerAttr(half))
525 .getResult();
526 }
527
528 unsigned elWidth = resultType.getElementType().getIntOrFloatBitWidth();
529
530 if (unsigned laneSize = getVectorLaneSize(resultType);
531 laneSize * elWidth == 512) {
532 // Common use case for the broadcast_elem intrinsic
533 rewriter.replaceOpWithNewOp<aievec::BroadcastOp>(splatOp, resultType, src,
534 posVal);
535 } else if (laneSize * elWidth == 256) {
536 // e.g. need v16bf16 due to the subsequent v16accfloat operation
537 VectorType aievecBcastType =
538 createVectorType(512 / elWidth, resultType.getElementType());
539 auto concatOp = rewriter.create<aievec::ConcatOp>(
540 splatOp.getLoc(), aievecBcastType, SmallVector<Value>({src, src}));
541 auto aieBcastOp = rewriter.create<aievec::BroadcastOp>(
542 splatOp.getLoc(), aievecBcastType, concatOp.getResult(), posVal);
543 rewriter.replaceOpWithNewOp<aievec::ExtOp>(splatOp, resultType,
544 aieBcastOp.getResult(), 0);
545 } else if (laneSize * elWidth == 1024) {
546 // e.g. need v32int32 due to the subsequent v32acc32 operation
547 VectorType aievecBcastType =
548 createVectorType(512 / elWidth, resultType.getElementType());
549 auto half = static_cast<int8_t>(posVal / resultType.getNumElements());
550 posVal -= half * resultType.getNumElements();
551 auto extOp =
552 rewriter.create<aievec::ExtOp>(splatOp.getLoc(), aievecBcastType, src,
553 rewriter.getI8IntegerAttr(half));
554 auto aieBcastOp = rewriter.create<aievec::BroadcastOp>(
555 splatOp.getLoc(), aievecBcastType, extOp.getResult(), posVal);
556 rewriter.replaceOpWithNewOp<aievec::ConcatOp>(
557 splatOp, resultType,
558 SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
559 } else {
560 return failure();
561 }
562
563 return success();
564 }
565};
566
568 using OpConversionPattern::OpConversionPattern;
569
570 LogicalResult
571 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
572 ConversionPatternRewriter &rewriter) const override {
573
574 if (adaptor.getInput().getDefiningOp<vector::ExtractOp>())
575 return failure();
576
577 auto resultType = cast<VectorType>(splatOp.getResult().getType());
578 auto flatResultType = getFlattenedVectorType(resultType);
579 Type scalarType = resultType.getElementType();
580 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
581 unsigned laneSize = getVectorLaneSize(resultType);
582 auto src = splatOp.getInput();
583
584 if (laneSize * elWidth == 512) {
585 Value newOp = rewriter.create<aievec::BroadcastScalarOp>(
586 splatOp.getLoc(), flatResultType, src);
587 if (resultType != flatResultType)
588 newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
589 resultType, newOp);
590 rewriter.replaceOp(splatOp, newOp);
591 return success();
592 }
593
594 if (laneSize * elWidth == 256) {
595 VectorType vecType = createVectorType(512 / elWidth, scalarType);
596 auto aieBcastOp = rewriter.create<aievec::BroadcastScalarOp>(
597 splatOp.getLoc(), vecType, src);
598 Value newOp = rewriter.create<aievec::ExtOp>(
599 splatOp.getLoc(), flatResultType, aieBcastOp.getResult(), 0);
600 if (resultType != flatResultType)
601 newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
602 resultType, newOp);
603 rewriter.replaceOp(splatOp, newOp);
604 return success();
605 }
606
607 if (laneSize * elWidth == 1024) {
608 VectorType vecType = createVectorType(512 / elWidth, scalarType);
609 auto aieBcastOp = rewriter.create<aievec::BroadcastScalarOp>(
610 splatOp.getLoc(), vecType, src);
611 Value newOp = rewriter.create<aievec::ConcatOp>(
612 splatOp.getLoc(), flatResultType,
613 SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
614 if (resultType != flatResultType)
615 newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
616 resultType, newOp);
617 rewriter.replaceOp(splatOp, newOp);
618 return success();
619 }
620
621 return failure();
622 }
623};
624
625// This pattern replaces `arith.muli`+`arith.addi` on vectors with
626// `aievec.mac_elem`. This pattern works for AIE2.
628 : OpConversionPattern<arith::AddIOp> {
629 using OpConversionPattern::OpConversionPattern;
630
632 unsigned shiftParam = 0)
634
635 LogicalResult
636 matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor,
637 ConversionPatternRewriter &rewriter) const override {
638 // Verify it's a vector operation
639 auto resultType = dyn_cast<VectorType>(addOp.getType());
640 if (!resultType)
641 return failure();
642
643 // Verify it can be replaced by a MAC
644 auto res =
645 extractMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
646 if (!res)
647 return failure();
648 auto [lhs, rhs, acc] = *res;
649
650 // Verify the vector type is supported by AIE2
651 unsigned resultElWidth =
652 resultType.getElementType().getIntOrFloatBitWidth();
653 unsigned laneSize = getVectorLaneSize(resultType);
654
655 if ((laneSize != 32 || resultElWidth != 16) &&
656 (laneSize != 16 || resultElWidth != 32))
657 return failure();
658
659 Type accType = getVectorOpDestType(cast<VectorType>(acc.getType()),
660 /*AIE2 =*/true);
661 auto upsOp = rewriter.create<aievec::UPSOp>(addOp.getLoc(), accType, acc,
662 shiftParam);
663 auto fmaElemOp = rewriter.create<aievec::FMAElemOp>(
664 addOp.getLoc(), accType, lhs, rhs, upsOp.getResult(),
665 /*fmsub=*/false);
666
667 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
668 addOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
669 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
670 addOp, resultType, fmaElemOp.getResult(), shiftParamOp.getResult());
671
672 return success();
673 }
674
675 unsigned shiftParam;
676};
677
678// Convert `vector.fma` to `aievec.mac_elem`. Only `vector<16xf32>` and
679// `vector<16xbf16>` operand types are supported. In the case of vectors with
680// `f32` elemental type, this pattern will try to match `bf16` to `f32`
681// widening ops in the `lhs` and `rhs` operands, or fail otherwise.
// TODO: When widening ops (`arith.extf`) are not found, a conversion from `f32`
// TODO: to `bf16` can be inserted to emulate `f32` fma with `bf16` logic.
685 : OpConversionPattern<vector::FMAOp> {
686 using OpConversionPattern::OpConversionPattern;
687
689 unsigned shiftParam = 0)
691
692 LogicalResult
693 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
694 ConversionPatternRewriter &rewriter) const override {
695 // Verify the vector type is supported by AIE2
696 auto resVecTy = cast<VectorType>(fmaOp.getType());
697 auto resElemTy = resVecTy.getElementType();
698 unsigned numElems = getVectorLaneSize(resVecTy);
699
700 if (numElems != 16 || (!resElemTy.isF32() && !resElemTy.isBF16()))
701 return rewriter.notifyMatchFailure(
702 fmaOp, "Unsupported operand types in vector.fma lowering.");
703
704 Value lhs = adaptor.getLhs();
705 Value rhs = adaptor.getRhs();
706 Value acc = adaptor.getAcc();
707 if (resElemTy.isBF16())
708 acc = rewriter.create<aievec::UPSOp>(
709 fmaOp.getLoc(), VectorType::get({16}, rewriter.getF32Type()), acc,
710 shiftParam);
711 else {
712 lhs = getSourceOfWideningOp(lhs).value_or(nullptr);
713 rhs = getSourceOfWideningOp(rhs).value_or(nullptr);
714 if (!lhs || !rhs)
715 return rewriter.notifyMatchFailure(
716 fmaOp, "vector.fma operands are f32, and they don't come from "
717 "arith.extf on bf16; can't lower to aievec.");
718 if (!cast<VectorType>(lhs.getType()).getElementType().isBF16() ||
719 !cast<VectorType>(rhs.getType()).getElementType().isBF16())
720 return rewriter.notifyMatchFailure(
721 fmaOp, "vector.fma operands come from arith.extf, but the source "
722 "of the widening op is not bf16; can't lower to aievec.");
723 }
724 Value newOp = rewriter.create<aievec::FMAElemOp>(
725 fmaOp.getLoc(), acc.getType(), lhs, rhs, acc, /*fmsub=*/false);
726
727 if (resElemTy.isBF16()) {
728 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
729 fmaOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
730 newOp = rewriter.create<aievec::SRSOp>(fmaOp.getLoc(), resVecTy, newOp,
731 shiftParamOp);
732 }
733
734 rewriter.replaceOp(fmaOp, newOp);
735
736 return success();
737 }
738
739 unsigned shiftParam;
740};
741
742// This pattern replaces `arith.mulf` on vectors with
743// `aievec.mul_elem`. This pattern works for AIE2.
745 : OpConversionPattern<arith::MulFOp> {
746 using OpConversionPattern::OpConversionPattern;
747
749 unsigned shiftParam = 0)
751
752 LogicalResult
753 matchAndRewrite(arith::MulFOp mulOp, OpAdaptor adaptor,
754 ConversionPatternRewriter &rewriter) const override {
755 // Verify it's a vector operation
756 auto resultType = dyn_cast<VectorType>(mulOp.getType());
757 if (!resultType)
758 return failure();
759
760 // FIXME: Verify it is not a part of FMA
761 auto isAddOp = [&](Operation *op) { return isa<arith::AddFOp>(op); };
762 if (mulOp->hasOneUse() && llvm::any_of(mulOp->getUsers(), isAddOp))
763 return failure();
764
765 unsigned resultElWidth =
766 resultType.getElementType().getIntOrFloatBitWidth();
767
768 unsigned laneSize = getVectorLaneSize(resultType);
769
770 // bfloat16 and float type
771 if (laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32))
772 return failure();
773
774 // Decide the accType for aievec.mul_elem based on mulOp's lhs & rhs
775 auto lval = adaptor.getLhs();
776 auto rval = adaptor.getRhs();
777 lval = getSourceOfWideningOp(lval).value_or(lval);
778 rval = getSourceOfWideningOp(rval).value_or(rval);
779 auto lSrcType = cast<VectorType>(lval.getType());
780 auto rSrcType = cast<VectorType>(rval.getType());
781 unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
782 unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
783 Type accType = getVectorOpDestType(lSrcType, /*AIE2 =*/true);
784 if (rBitWidth > lBitWidth) {
785 accType = getVectorOpDestType(rSrcType, /*AIE2 =*/true);
786 }
787 // Only support the same lhs/rhs type at the moment
788 if (lSrcType != rSrcType) {
789 return failure();
790 }
791
792 // Prepare lhr/rhs for the aievec.mul_elem op
793 unsigned bitWidth = (rBitWidth > lBitWidth) ? rBitWidth : lBitWidth;
794 Type srcElemType = (rBitWidth > lBitWidth) ? rSrcType.getElementType()
795 : lSrcType.getElementType();
796 unsigned numLanes = 0;
797 if (isa<FloatType>(srcElemType) && (bitWidth == 16 || bitWidth == 32)) {
798 numLanes = 16;
799 } else if (isa<IntegerType>(srcElemType) &&
800 (bitWidth == 8 || bitWidth == 16)) {
801 numLanes = 32;
802 } else if (isa<IntegerType>(srcElemType) && (bitWidth == 32)) {
803 numLanes = 16;
804 } else {
805 return failure();
806 }
807 VectorType targetInputType = createVectorType(numLanes, srcElemType);
808 if (targetInputType != lSrcType) {
809 lval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), lval,
810 targetInputType)
811 .value();
812 }
813 if (targetInputType != rSrcType) {
814 rval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), rval,
815 targetInputType)
816 .value();
817 }
818 if (!lval || !rval)
819 return failure();
820
821 // Create an aievec.mul_elem op
822 auto mulElemOp =
823 rewriter.create<aievec::MulElemOp>(mulOp.getLoc(), accType, lval, rval);
824
825 // Create an aievec.cast or an aievec.srs op
826 auto mulElemResultType = mulElemOp.getType();
827 auto mulElemResultElWidth =
828 mulElemResultType.getElementType().getIntOrFloatBitWidth();
829
830 if (mulElemResultElWidth == resultElWidth) {
831 rewriter.replaceOpWithNewOp<aievec::CastOp>(
832 mulOp, resultType, mulElemOp.getResult(), /*isResAcc*/ false);
833 } else if (mulElemResultElWidth > resultElWidth) {
834 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
835 mulOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
836 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
837 mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
838 } else {
839 return failure();
840 }
841
842 return success();
843 }
844
845 unsigned shiftParam;
846};
847
848// This pattern replaces `arith.muli` on vectors with
849// `aievec.mul_elem`. This pattern works for AIE2.
851 : OpConversionPattern<arith::MulIOp> {
852 using OpConversionPattern::OpConversionPattern;
853
855 unsigned shiftParam = 0)
857
858 LogicalResult
859 matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor,
860 ConversionPatternRewriter &rewriter) const override {
861 // Verify it's a vector operation
862 auto resultType = dyn_cast<VectorType>(mulOp.getType());
863 if (!resultType)
864 return failure();
865
866 // FIXME: Verify it is not a part of MAC
867 auto isAddOp = [&](Operation *op) { return isa<arith::AddIOp>(op); };
868 if (mulOp->hasOneUse() && llvm::any_of(mulOp->getUsers(), isAddOp))
869 return failure();
870
871 // Verify the vector type is supported by AIE2
872 unsigned resultElWidth =
873 resultType.getElementType().getIntOrFloatBitWidth();
874 unsigned laneSize = getVectorLaneSize(resultType);
875
876 if ((laneSize != 32 || (resultElWidth != 16 && resultElWidth != 8)) &&
877 ((laneSize != 16 && laneSize != 32) || resultElWidth != 32))
878 return failure();
879
880 // Decide the accType for aievec.mul_elem based on mulOp's lhs & rhs
881 auto lval = adaptor.getLhs();
882 auto rval = adaptor.getRhs();
883
884 lval = getSourceOfWideningOp(lval).value_or(lval);
885 rval = getSourceOfWideningOp(rval).value_or(rval);
886
887 auto lSrcType = cast<VectorType>(lval.getType());
888 auto rSrcType = cast<VectorType>(rval.getType());
889 unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
890 unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
891 Type accType = getVectorOpDestType(lSrcType, /*AIE2 =*/true);
892 if (rBitWidth > lBitWidth) {
893 accType = getVectorOpDestType(rSrcType, /*AIE2 =*/true);
894 }
895
896 // Prepare lhr/rhs for the aievec.mul_elem op
897 unsigned bitWidth = (rBitWidth > lBitWidth) ? rBitWidth : lBitWidth;
898 Type srcElemType = (rBitWidth > lBitWidth) ? rSrcType.getElementType()
899 : lSrcType.getElementType();
900 unsigned numLanes = 0;
901 if (isa<FloatType>(srcElemType) && (bitWidth == 16 || bitWidth == 32)) {
902 numLanes = 16;
903 } else if (isa<IntegerType>(srcElemType) &&
904 (bitWidth == 8 || bitWidth == 16)) {
905 numLanes = 32;
906 } else if (isa<IntegerType>(srcElemType) && (bitWidth == 32)) {
907 numLanes = 16;
908 } else {
909 return failure();
910 }
911 VectorType targetInputType = createVectorType(numLanes, srcElemType);
912 if (targetInputType != lSrcType) {
913 lval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), lval,
914 targetInputType)
915 .value();
916 }
917 if (targetInputType != rSrcType) {
918 rval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), rval,
919 targetInputType)
920 .value();
921 }
922 if (!lval || !rval)
923 return failure();
924
925 // Create an aievec.mul_elem op
926 auto mulElemOp =
927 rewriter.create<aievec::MulElemOp>(mulOp.getLoc(), accType, lval, rval);
928
929 // Create an aievec.cast or an aievec.srs op
930 auto mulElemResultType = mulElemOp.getType();
931 auto mulElemResultElWidth =
932 mulElemResultType.getElementType().getIntOrFloatBitWidth();
933
934 if (mulElemResultElWidth == resultElWidth) {
935 rewriter.replaceOpWithNewOp<aievec::CastOp>(
936 mulOp, resultType, mulElemOp.getResult(), /*isResAcc*/ false);
937 } else if (mulElemResultElWidth > resultElWidth) {
938 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
939 mulOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
940 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
941 mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
942 } else {
943 return failure();
944 }
945
946 return success();
947 }
948
949 unsigned shiftParam;
950};
951
952// This pattern folds an extract + broadcast feeding into an
953// `aievec::aie1::FMAOp` into the op, using the shuffle attributes.
954struct FoldSplatToFMAOp : OpConversionPattern<aievec::aie1::FMAOp> {
955 using OpConversionPattern::OpConversionPattern;
956
957 LogicalResult
958 matchAndRewrite(aievec::aie1::FMAOp fmaOp, OpAdaptor adaptor,
959 ConversionPatternRewriter &rewriter) const override {
960 auto concatOp =
961 dyn_cast<aievec::ConcatOp>(adaptor.getLhs().getDefiningOp());
962 if (!concatOp)
963 return failure();
964 vector::SplatOp splatOp = nullptr;
965 auto *concatDefOp = concatOp.getSources()[0].getDefiningOp();
966 if (concatDefOp)
967 splatOp = dyn_cast<vector::SplatOp>(concatDefOp);
968 Value lhs = adaptor.getRhs();
969 if (!splatOp) {
970 splatOp = dyn_cast<vector::SplatOp>(adaptor.getRhs().getDefiningOp());
971 if (!splatOp)
972 return failure();
973 lhs = concatOp.getSources()[0];
974 }
975 auto extOp =
976 dyn_cast<vector::ExtractOp>(splatOp.getInput().getDefiningOp());
977 if (!extOp)
978 return failure();
979
980 auto rhs = extOp.getVector();
981 auto concatVecType = cast<VectorType>(concatOp.getResult().getType());
982 auto zvec = rewriter.create<arith::ConstantOp>(
983 concatOp.getLoc(), lhs.getType(), rewriter.getZeroAttr(lhs.getType()));
984 auto lhsX2 =
985 rewriter
986 .create<aievec::ConcatOp>(concatOp.getLoc(), concatVecType,
987 SmallVector<Value, 2>({lhs, zvec}))
988 .getResult();
989 // XXX: We assume a 1D vector
990 auto pos = extOp.getStaticPosition();
991 int64_t zstart = pos[0];
992 auto fmaOpAttr = buildFMAOpSplatAttrForElemTy(fmaOp, zstart);
993 rewriter.replaceOpWithNewOp<aievec::aie1::FMAOp>(
994 fmaOp, TypeRange({fmaOp.getResult().getType()}),
995 ValueRange({lhsX2, rhs, adaptor.getAcc()}), fmaOpAttr);
996
997 return success();
998 }
999};
1000
1002 : OpConversionPattern<aievec::aie1::AddOp> {
1003 using OpConversionPattern::OpConversionPattern;
1004
1005 LogicalResult
1006 matchAndRewrite(aievec::aie1::AddOp addOp, OpAdaptor adaptor,
1007 ConversionPatternRewriter &rewriter) const override {
1008 auto vecType = cast<VectorType>(addOp.getType());
1009
1010 auto res =
1011 extractMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
1012 if (!res)
1013 return failure();
1014 auto [lhs, rhs, acc] = *res;
1015
1016 SmallVector<int64_t, 4> concatVecShape(vecType.getShape().begin(),
1017 vecType.getShape().end());
1018 concatVecShape[vecType.getRank() - 1] *= 2;
1019 auto concatVecType =
1020 VectorType::get(concatVecShape, vecType.getElementType());
1021 Type accType = getVectorOpDestType(cast<VectorType>(acc.getType()),
1022 /*AIE2 =*/false);
1023 auto lhsX2 = rewriter
1024 .create<aievec::ConcatOp>(addOp.getLoc(), concatVecType,
1025 SmallVector<Value, 2>(2, lhs))
1026 .getResult();
1027 auto upsOp = rewriter.create<aievec::UPSOp>(addOp.getLoc(), accType, acc);
1028 auto fmaOp = rewriter.create<aievec::aie1::FMAOp>(
1029 addOp.getLoc(), accType, lhsX2, rhs, upsOp.getResult(),
1030 /*xstart=*/"", /*xoffsets=*/"", /*xoffsets_hi=*/"", /*xstep=*/"",
1031 /*xsquare=*/"", /*zstart=*/"", /*zoffsets=*/"", /*zoffsets_hi=*/"",
1032 /*zstep=*/"", /*zsquare=*/"", /*fmsub=*/false);
1033 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1034 addOp.getLoc(), rewriter.getI32IntegerAttr(0));
1035 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1036 addOp, vecType, fmaOp.getResult(), shiftParamOp.getResult());
1037 return success();
1038 }
1039};
1040
1041// This pattern replaces `vector.transfer_read` with `aievec.upd`. Right now,
1042// it performs a naïve direct translation. This needs to be expanded to
1043// support more complex scenarios.
1045 : OpConversionPattern<vector::TransferReadOp> {
1046 using OpConversionPattern::OpConversionPattern;
1047
1048 LowerVectorTransferReadToAIEUPD(MLIRContext *context, int64_t minVectorSize,
1049 int64_t maxVectorSize, int64_t alignment,
1050 int64_t maxLoadSize)
1054
1055 LogicalResult
1056 matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
1057 ConversionPatternRewriter &rewriter) const override {
1058 // Masked loads
1059 if (readOp.getMask())
1060 return readOp.emitError() << "AIE doesn't support masked loads.";
1061
1062 // Non-contiguous loads
1063 AffineMap map = readOp.getPermutationMap();
1064 if (!map.isMinorIdentity())
1065 return failure();
1066
1067 // Splats
1068 if (map.isConstant())
1069 return failure();
1070
1071 // Misaligned accesses
1072 auto vType = readOp.getVectorType();
1074 .value_or(0) != 0)
1075 return failure();
1076
1077 // Invalid vector size.
1078 // We can handle cases where the vector size is:
1079 // 1) the minimum vector size
1080 // 2) a square multiple of the alignment size and up to the maximum
1081 // vector size.
1082 int64_t vSize = vType.getNumElements() * vType.getElementTypeBitWidth();
1083 if (vSize > maxVectorSize ||
1084 (vSize % vectorAlignment && vSize != minVectorSize))
1085 return failure();
1086 // We can deal with linked update instructions when the vector size is
1087 // exactly twice the load size. This could change in future architectures
1088 if (vSize > maxLoadSize && vSize != maxLoadSize * 2)
1089 return failure();
1090 int64_t multiplicity = vSize / vectorAlignment;
1091 if ((vSize > minVectorSize) && std::bitset<8>(multiplicity).count() != 1)
1092 return failure();
1093
1094 auto updOp = rewriter.create<xilinx::aievec::UPDOp>(
1095 readOp.getLoc(), vType, adaptor.getBase(), adaptor.getIndices(), 0, 0,
1096 TypedValue<VectorType>(nullptr));
1097 if (vSize > maxLoadSize) {
1098 updOp = rewriter.create<xilinx::aievec::UPDOp>(
1099 readOp.getLoc(), vType, adaptor.getBase(), adaptor.getIndices(),
1100 maxLoadSize, 1, updOp.getResult());
1101 }
1102 rewriter.replaceOp(readOp, updOp.getResult());
1103
1104 return success();
1105 }
1106
1108};
1109
1110// XXX: Notice that this template doesn't verify that the vector element type
1111// XXX: is supported by the target architecture.
1112template <typename SrcOpTy, typename DstOpTy>
1115 using OpAdaptor = typename SrcOpTy::Adaptor;
1116
1117 LogicalResult
1118 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
1119 ConversionPatternRewriter &rewriter) const override {
1120 rewriter.replaceOpWithNewOp<DstOpTy>(
1121 srcOp, srcOp.getResult().getType(), adaptor.getLhs(), adaptor.getRhs(),
1122 /*xstart=*/"", /*xoffsets=*/"", /*xoffsets_hi=*/"", /*xsquare=*/"",
1123 /*zstart=*/"", /*zoffsets=*/"", /*zoffsets_hi=*/"", /*zsquare=*/"");
1124 return success();
1125 }
1126};
1127
1129 using OpConversionPattern::OpConversionPattern;
1130
1131 LogicalResult
1132 matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor,
1133 ConversionPatternRewriter &rewriter) const override {
1134 auto resType = addOp.getType();
1135 if (!isa<VectorType>(resType))
1136 return failure();
1137
1138 auto lhs = adaptor.getLhs();
1139 auto rhs = adaptor.getRhs();
1140 auto *lhsDefOp = lhs.getDefiningOp();
1141 auto *rhsDefOp = rhs.getDefiningOp();
1142 if ((isa_and_nonnull<arith::MulIOp>(lhsDefOp)) ||
1143 (isa_and_nonnull<arith::MulIOp>(rhsDefOp)))
1144 return failure();
1145
1146 rewriter.replaceOpWithNewOp<aievec::aie1::AddOp>(
1147 addOp, resType, lhs, rhs,
1148 /*xstart=*/"", /*xoffsets=*/"", /*xoffsets_hi=*/"", /*xsquare=*/"",
1149 /*zstart=*/"", /*zoffsets=*/"", /*zoffsets_hi=*/"", /*zsquare=*/"");
1150 return success();
1151 }
1152};
1153
1162
1164 using OpConversionPattern::OpConversionPattern;
1165 LogicalResult
1166 matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor,
1167 ConversionPatternRewriter &rewriter) const override {
1168 auto resTy = dyn_cast<VectorType>(mulOp.getType());
1169 if (!resTy)
1170 return failure();
1171 auto accTy = getVectorOpDestType(resTy, /*AIE2 =*/false);
1172 auto newMulOp = rewriter.create<aievec::aie1::MulOp>(
1173 mulOp.getLoc(), accTy, adaptor.getLhs(), adaptor.getRhs());
1174 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1175 mulOp.getLoc(), rewriter.getI32IntegerAttr(0));
1176 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1177 mulOp, resTy, newMulOp.getResult(), shiftParamOp.getResult());
1178 return success();
1179 }
1180};
1181
1182template <typename SrcOpTy, typename DstOpTy>
1184 : OpConversionPattern<SrcOpTy> {
1186 using OpAdaptor = typename SrcOpTy::Adaptor;
1187
1188 LogicalResult
1189 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
1190 ConversionPatternRewriter &rewriter) const override {
1191 VectorType resultType = dyn_cast<VectorType>(srcOp.getType());
1192 if (!resultType)
1193 return failure();
1194
1195 // A set recording the vector lane size and element width we are supporting
1196 // for AIE2.
1197 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
1198 laneSizeElWidthPairSet.insert({64, 8});
1199 laneSizeElWidthPairSet.insert({32, 16});
1200 laneSizeElWidthPairSet.insert({16, 32});
1201 laneSizeElWidthPairSet.insert({32, 32});
1202
1203 auto lhs = adaptor.getLhs();
1204 auto rhs = adaptor.getRhs();
1205 auto lhsDefOp = lhs.getDefiningOp();
1206 auto rhsDefOp = rhs.getDefiningOp();
1207 if ((lhsDefOp && isa<arith::MulIOp>(lhsDefOp)) ||
1208 (rhsDefOp && isa<arith::MulIOp>(rhsDefOp)) ||
1209 (lhsDefOp && isa<arith::MulFOp>(lhsDefOp)) ||
1210 (rhsDefOp && isa<arith::MulFOp>(rhsDefOp)))
1211 return failure();
1212
1213 Type scalarType = resultType.getElementType();
1214 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
1215 unsigned laneSize = getVectorLaneSize(resultType);
1216
1217 // Integer cases
1218 if (isa<IntegerType>(scalarType)) {
1219 if (!laneSizeElWidthPairSet.count(
1220 std::make_pair(laneSize, resultElWidth)))
1221 return failure();
1222
1223 // If the ops are defined without extension ops and with supported data
1224 // type, the arith::AddI or arith::SubI can be directly replaced with
1225 // aievec::AddElem or aievec::SubElem.
1226 if (!lhsDefOp && !rhsDefOp) {
1227 if (laneSize * resultElWidth == 512) {
1228 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs,
1229 rhs);
1230 return success();
1231 }
1232 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs, resultType,
1233 srcOp);
1234 }
1235
1236 // If element width is 32, we need to consider sign extension cases
1237 if (resultElWidth == 32) {
1238 auto lhsExt = getSourceOfWideningOp(lhs).value_or(nullptr);
1239 auto rhsExt = getSourceOfWideningOp(rhs).value_or(nullptr);
1240
1241 if (!lhsExt && !rhsExt) {
1242 if (laneSize * resultElWidth == 512) {
1243 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs,
1244 rhs);
1245 return success();
1246 }
1247 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1248 resultType, srcOp);
1249 }
1250
1251 if (lhsExt && rhsExt) {
1252 auto lval = lhsExt;
1253 auto rval = rhsExt;
1254 VectorType lSrcType = cast<VectorType>(lval.getType());
1255
1256 Type accType = getVectorOpDestType(lSrcType, /*AIE2 =*/true);
1257 auto lUpsOp =
1258 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1259 auto rUpsOp =
1260 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1261 auto elemOp = rewriter.create<DstOpTy>(
1262 srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1263 lUpsOp->getResult(0), rUpsOp->getResult(0));
1264 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1265 srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false);
1266 return success();
1267 }
1268
1269 if (!lhsExt || !rhsExt) {
1270 auto lval = lhsExt ? lhsExt : lhs;
1271 auto rval = rhsExt ? rhsExt : rhs;
1272 auto extVal = lhsExt ? lval : rval;
1273 VectorType vType = cast<VectorType>(extVal.getType());
1274 unsigned bitWidth = vType.getElementType().getIntOrFloatBitWidth();
1275
1276 if (bitWidth != 8 && bitWidth != 16) {
1277 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1278 resultType, srcOp);
1279 }
1280
1281 if (bitWidth * laneSize != 256) {
1282 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1283 resultType, srcOp);
1284 }
1285
1286 Type accType = nullptr;
1287
1288 if (bitWidth == 8) {
1289 accType = getVectorOpDestType(vType, /*AIE2 =*/true);
1290 Value valToUps = lhsExt ? lval : rval;
1291 Value valToCast = lhsExt ? rval : lval;
1292 auto upsOp = rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType,
1293 valToUps);
1294 auto castOp = rewriter.create<aievec::CastOp>(
1295 srcOp.getLoc(), resultType, valToCast, /*isResAcc*/ true);
1296 Value lhsToElemOp =
1297 lhsExt ? upsOp->getResult(0) : castOp->getResult(0);
1298 Value rhsToElemOp =
1299 lhsExt ? castOp->getResult(0) : upsOp->getResult(0);
1300 auto elemOp = rewriter.create<DstOpTy>(
1301 srcOp.getLoc(), upsOp->getResult(0).getType(), lhsToElemOp,
1302 rhsToElemOp);
1303 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1304 srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false);
1305 return success();
1306 }
1307
1308 if (bitWidth == 16) {
1309 accType = getVectorOpDestType(resultType, /*AIE2 =*/true);
1310 auto lUpsOp =
1311 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1312 auto rUpsOp =
1313 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1314
1315 auto elemOp = rewriter.create<DstOpTy>(
1316 srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1317 lUpsOp->getResult(0), rUpsOp->getResult(0));
1318
1319 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1320 srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
1321 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1322 srcOp, srcOp.getType(), elemOp.getResult(),
1323 shiftParamOp.getResult());
1324 return success();
1325 }
1326 }
1327 } else {
1328 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs, rhs);
1329 return success();
1330 }
1331 }
1332 // Float types
1333 else {
1334 if (laneSize != 16)
1335 return failure();
1336
1337 // v16float or v16bf16 with extension op case
1338 if (resultElWidth == 32) {
1339 if (!lhsDefOp && !rhsDefOp) {
1340 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1341 resultType, srcOp);
1342 }
1343
1344 auto lhsExt = getSourceOfWideningOp(lhs).value_or(nullptr);
1345 auto rhsExt = getSourceOfWideningOp(rhs).value_or(nullptr);
1346 // v16float
1347 if (!lhsExt && !rhsExt) {
1348 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1349 resultType, srcOp);
1350 }
1351
1352 // v16bf16 with two extension ops
1353 if (lhsExt && rhsExt) {
1354 auto lval = lhsExt;
1355 auto rval = rhsExt;
1356 VectorType vType = cast<VectorType>(lval.getType());
1357
1358 Type accType = getVectorOpDestType(vType, /*AIE2 =*/true);
1359 auto lUpsOp =
1360 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1361 auto rUpsOp =
1362 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1363 auto elemOp = rewriter.create<DstOpTy>(
1364 srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1365 lUpsOp->getResult(0), rUpsOp->getResult(0));
1366 rewriter.replaceOpWithNewOp<aievec::CastOp>(srcOp, srcOp.getType(),
1367 elemOp.getResult());
1368 return success();
1369 }
1370
1371 // v16bf16 with one extension op
1372 if (!lhsExt || !rhsExt) {
1373 auto lval = lhsExt ? lhsExt : lhs;
1374 auto rval = rhsExt ? rhsExt : rhs;
1375 auto extVal = lhsExt ? lval : rval;
1376 VectorType vType = cast<VectorType>(extVal.getType());
1377 Type accType = getVectorOpDestType(vType, /*AIE2 =*/true);
1378
1379 aievec::UPSOp upsOp;
1380 aievec::CastOp castOp;
1381 if (lhsExt) {
1382 upsOp =
1383 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1384 castOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), resultType,
1385 rval,
1386 /*isResAcc*/ true);
1387 } else {
1388 upsOp =
1389 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1390 castOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), resultType,
1391 lval,
1392 /*isResAcc*/ true);
1393 }
1394
1395 auto elemOp = rewriter.create<DstOpTy>(
1396 srcOp.getLoc(), upsOp->getResult(0).getType(),
1397 upsOp->getResult(0), castOp->getResult(0));
1398
1399 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1400 srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false);
1401
1402 return success();
1403 }
1404 }
1405
1406 // v16bfloat16
1407 Type accType = getVectorOpDestType(resultType, /*AIE2 =*/true);
1408 auto lUpsOp =
1409 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lhs);
1410 auto rUpsOp =
1411 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rhs);
1412 auto elemOp = rewriter.create<DstOpTy>(
1413 srcOp.getLoc(), lUpsOp->getResult(0).getType(), lUpsOp->getResult(0),
1414 rUpsOp->getResult(0));
1415 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1416 srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
1417 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1418 srcOp, srcOp.getType(), elemOp.getResult(), shiftParamOp.getResult());
1419
1420 return success();
1421 }
1422
1423 return failure();
1424 }
1425};
1426
1429 aievec::AddElemOp>;
1432 aievec::SubElemOp>;
1435 aievec::AddElemOp>;
1438 aievec::SubElemOp>;
1439
1440template <typename SrcOpTy, typename DstOpTy>
1443 using OpAdaptor = typename SrcOpTy::Adaptor;
1444
1445 LogicalResult
1446 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
1447 ConversionPatternRewriter &rewriter) const override {
1448 VectorType resultType = dyn_cast<VectorType>(srcOp.getType());
1449 if (!resultType)
1450 return failure();
1451
1452 // A set recording the element width we are supporting for AIE2.
1453 llvm::SmallSet<unsigned, 16> elWidthSet;
1454 elWidthSet.insert(8);
1455 elWidthSet.insert(16);
1456 elWidthSet.insert(32);
1457
1458 Type scalarType = resultType.getElementType();
1459 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
1460 unsigned laneSize = getVectorLaneSize(resultType);
1461
1462 if (!elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512)
1463 return failure();
1464
1465 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(),
1466 adaptor.getLhs(), adaptor.getRhs());
1467 return success();
1468 }
1469};
1470
1479
1480template <typename SrcOpTy, typename CmpTy>
1483 using OpAdaptor = typename SrcOpTy::Adaptor;
1484
1485 LogicalResult
1486 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
1487 ConversionPatternRewriter &rewriter) const override {
1488 VectorType lhsType = dyn_cast<VectorType>(srcOp.getLhs().getType());
1489 if (!lhsType)
1490 return failure();
1491
1492 llvm::SmallSet<unsigned, 16> elWidthSet;
1493 elWidthSet.insert(8);
1494 elWidthSet.insert(16);
1495 elWidthSet.insert(32);
1496
1497 Type scalarType = lhsType.getElementType();
1498 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1499 unsigned laneSize = getVectorLaneSize(lhsType);
1500
1501 if (!elWidthSet.count(elWidth) || laneSize * elWidth != 512)
1502 return failure();
1503
1504 // Unsigned int and unsigned long long are acceptable type.
1505 Type type =
1506 mlir::IntegerType::get(srcOp.getContext(), laneSize <= 32 ? 32 : 64,
1507 mlir::IntegerType::Unsigned);
1508
1509 Location loc = srcOp.getLoc();
1510 Value lhs = srcOp.getLhs();
1511 Value rhs = srcOp.getRhs();
1512 CmpTy pred = srcOp.getPredicate();
1513
1514 arith::CmpIPredicate ipred = convertToIntegerPredicate(pred);
1515
1516 aievec::CmpOp aieCmpOp =
1517 createCmpOpAIE2(rewriter, ipred, loc, type, lhs, rhs);
1518
1519 if (!aieCmpOp)
1520 return failure();
1521
1522 VectorType resultType = dyn_cast<VectorType>(srcOp.getResult().getType());
1523 // Convert vector i1 type to unsigned interger type by built-in unrealized
1524 // conversion cast op.
1525 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
1526 srcOp, resultType, aieCmpOp.getResult());
1527
1528 return success();
1529 }
1530};
1531
1536
1538 using OpConversionPattern::OpConversionPattern;
1539
1540 LogicalResult
1541 matchAndRewrite(arith::SelectOp srcOp, OpAdaptor adaptor,
1542 ConversionPatternRewriter &rewriter) const override {
1543 auto resultType = dyn_cast<VectorType>(srcOp.getType());
1544 if (!resultType)
1545 return failure();
1546
1547 llvm::SmallSet<unsigned, 16> elWidthSet;
1548 elWidthSet.insert(8);
1549 elWidthSet.insert(16);
1550 elWidthSet.insert(32);
1551
1552 Type scalarType = resultType.getElementType();
1553 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
1554 unsigned laneSize = getVectorLaneSize(resultType);
1555
1556 if (!elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512)
1557 return failure();
1558
1559 Type type =
1560 mlir::IntegerType::get(srcOp.getContext(), laneSize <= 32 ? 32 : 64,
1561 mlir::IntegerType::Unsigned);
1562
1563 auto convertOp = rewriter.create<UnrealizedConversionCastOp>(
1564 srcOp.getLoc(), type, adaptor.getCondition());
1565
1566 rewriter.replaceOpWithNewOp<aievec::SelOp>(
1567 srcOp, srcOp.getResult().getType(), srcOp.getTrueValue(),
1568 srcOp.getFalseValue(), convertOp.getResult(0));
1569
1570 return success();
1571 }
1572};
1573
1574struct LowerVectorReductionMinOp : OpConversionPattern<vector::ReductionOp> {
1575 using OpConversionPattern::OpConversionPattern;
1576
1577 LogicalResult
1578 matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
1579 ConversionPatternRewriter &rewriter) const override {
1580 if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::MINSI &&
1581 kind != vector::CombiningKind::MINUI &&
1582 kind != vector::CombiningKind::MINIMUMF)
1583 return failure();
1584
1585 auto vType = cast<VectorType>(srcOp.getVector().getType());
1586 Type scalarType = vType.getElementType();
1587 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1588 unsigned laneSize = getVectorLaneSize(vType);
1589
1590 if (laneSize * elWidth != 512)
1591 return failure();
1592
1593 int shiftIndex = laneSize / 2;
1594 generateAIEVecOpsForReductionOp<aievec::MinOp>(rewriter, srcOp, shiftIndex,
1595 srcOp.getVector());
1596 return success();
1597 }
1598};
1599
1600struct LowerVectorReductionMaxOp : OpConversionPattern<vector::ReductionOp> {
1601 using OpConversionPattern::OpConversionPattern;
1602
1603 LogicalResult
1604 matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
1605 ConversionPatternRewriter &rewriter) const override {
1606 if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::MAXSI &&
1607 kind != vector::CombiningKind::MAXUI &&
1608 kind != vector::CombiningKind::MAXIMUMF)
1609 return failure();
1610
1611 auto vType = cast<VectorType>(srcOp.getVector().getType());
1612 Type scalarType = vType.getElementType();
1613 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1614 unsigned laneSize = getVectorLaneSize(vType);
1615
1616 if (laneSize * elWidth != 512)
1617 return failure();
1618
1619 int shiftIndex = laneSize / 2;
1620 generateAIEVecOpsForReductionOp<aievec::MaxOp>(rewriter, srcOp, shiftIndex,
1621 srcOp.getVector());
1622 return success();
1623 }
1624};
1625
1627 using OpConversionPattern::OpConversionPattern;
1628
1629 LogicalResult
1630 matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
1631 ConversionPatternRewriter &rewriter) const override {
1632 if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
1633 return failure();
1634
1635 auto vType = cast<VectorType>(srcOp.getVector().getType());
1636 Type scalarType = vType.getElementType();
1637 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1638 unsigned laneSize = getVectorLaneSize(vType);
1639 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
1640 laneSizeElWidthPairSet.insert({64, 8});
1641 laneSizeElWidthPairSet.insert({32, 16});
1642 laneSizeElWidthPairSet.insert({32, 32});
1643 laneSizeElWidthPairSet.insert({16, 32});
1644
1645 if (!isa<IntegerType>(scalarType) ||
1646 !laneSizeElWidthPairSet.count(std::make_pair(laneSize, elWidth)))
1647 return failure();
1648
1649 int shiftIndex = laneSize / 2;
1650 if (laneSize == 32 && elWidth == 32) {
1651 Location loc = srcOp.getLoc();
1652 VectorType vecType = createVectorType(laneSize / 2, scalarType);
1653
1654 auto lExtOp =
1655 rewriter.create<aievec::ExtOp>(loc, vecType, srcOp.getVector(), 0);
1656 auto rExtOp =
1657 rewriter.create<aievec::ExtOp>(loc, vecType, srcOp.getVector(), 1);
1658 auto addElemOp = rewriter.create<aievec::AddElemOp>(
1659 loc, lExtOp.getResult().getType(), lExtOp.getResult(),
1660 rExtOp.getResult());
1661 shiftIndex /= 2;
1662 generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
1663 rewriter, srcOp, shiftIndex, addElemOp.getResult());
1664 } else
1665 generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
1666 rewriter, srcOp, shiftIndex, srcOp.getVector());
1667
1668 return success();
1669 }
1670};
1671
1673 : OpConversionPattern<vector::ReductionOp> {
1674 using OpConversionPattern::OpConversionPattern;
1675
1676 LogicalResult
1677 matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
1678 ConversionPatternRewriter &rewriter) const override {
1679 if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
1680 return failure();
1681
1682 auto vType = cast<VectorType>(srcOp.getVector().getType());
1683 Type scalarType = vType.getElementType();
1684 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1685 unsigned laneSize = getVectorLaneSize(vType);
1686
1687 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 32)
1688 return failure();
1689
1690 int shiftIndex = laneSize / 2;
1691 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
1692 "shiftIndex must be power of 2");
1693
1694 Location loc = srcOp.getLoc();
1695 Value curValue = srcOp.getVector();
1696 aievec::CastOp curOp = nullptr;
1697
1698 for (int id = shiftIndex; id > 0; id /= 2) {
1699 auto constOp = rewriter.create<arith::ConstantOp>(
1700 loc, rewriter.getI32IntegerAttr(id * elWidth / 8));
1701
1702 auto shiftBytesOp = rewriter.create<aievec::ShiftOp>(
1703 loc, vType, curValue, curValue, constOp.getResult());
1704
1705 auto lCastOp = rewriter.create<aievec::CastOp>(loc, vType, curValue,
1706 /*isResAcc*/ true);
1707 auto rCastOp =
1708 rewriter.create<aievec::CastOp>(loc, vType, shiftBytesOp.getResult(),
1709 /*isResAcc*/ true);
1710 auto elemOp = rewriter.create<aievec::AddElemOp>(
1711 loc, lCastOp.getResult().getType(), lCastOp.getResult(),
1712 rCastOp.getResult());
1713 curOp = rewriter.create<aievec::CastOp>(loc, vType, elemOp.getResult(),
1714 /*isResAcc*/ false);
1715 curValue = curOp.getResult();
1716 }
1717
1718 auto zeroConstOp =
1719 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1720 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, curOp,
1721 zeroConstOp.getResult());
1722 return success();
1723 }
1724};
1725
1727 : OpConversionPattern<vector::ReductionOp> {
1728 using OpConversionPattern::OpConversionPattern;
1729
1730 LogicalResult
1731 matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
1732 ConversionPatternRewriter &rewriter) const override {
1733 if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
1734 return failure();
1735
1736 auto vType = cast<VectorType>(srcOp.getVector().getType());
1737 Type scalarType = vType.getElementType();
1738 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1739 unsigned laneSize = getVectorLaneSize(vType);
1740
1741 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 16)
1742 return failure();
1743
1744 int shiftIndex = laneSize / 2;
1745 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
1746 "shiftIndex must be power of 2");
1747
1748 Value curValue = srcOp.getVector();
1749 Location loc = srcOp.getLoc();
1750 Type accType = getVectorOpDestType(vType, /*AIE2 =*/true);
1751 unsigned accWidth =
1752 dyn_cast<VectorType>(accType).getElementType().getIntOrFloatBitWidth();
1753
1754 auto upsOp =
1755 rewriter.create<aievec::UPSOp>(loc, accType, srcOp.getVector());
1756
1757 curValue = upsOp.getResult();
1758
1759 VectorType vecType = createVectorType(2 * laneSize, scalarType);
1760 aievec::AddElemOp curOp = nullptr;
1761
1762 for (int id = shiftIndex; id > 0; id /= 2) {
1763 auto constOp = rewriter.create<arith::ConstantOp>(
1764 loc, rewriter.getI32IntegerAttr(id * accWidth / 8));
1765 auto shiftBytesOp = rewriter.create<aievec::ShiftOp>(
1766 loc, accType, curValue, curValue, constOp, true);
1767 curOp = rewriter.create<aievec::AddElemOp>(loc, accType, curValue,
1768 shiftBytesOp.getResult());
1769 curValue = curOp.getResult();
1770 }
1771
1772 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1773 srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
1774 auto srsOp = rewriter.create<aievec::SRSOp>(loc, vType, curOp.getResult(),
1775 shiftParamOp.getResult());
1776 SmallVector<Value> concatSources = {srsOp.getResult(), srsOp.getResult()};
1777 auto concatOp =
1778 rewriter.create<aievec::ConcatOp>(loc, vecType, concatSources);
1779
1780 auto zeroConstOp =
1781 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1782 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, concatOp,
1783 zeroConstOp.getResult());
1784 return success();
1785 }
1786};
1787
1788// Convert a `vector.extract_strided_slice` op on 1D vectors into an
1789// `aievec.select` + `aievec.ext` op.
1791 : OpConversionPattern<vector::ExtractStridedSliceOp> {
1792 using OpConversionPattern::OpConversionPattern;
1793
1794 LogicalResult
1795 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
1796 ConversionPatternRewriter &rewriter) const override {
1797 auto vType = extractOp.getSourceVectorType();
1798 if (vType.getRank() != 1)
1799 return failure();
1800
1801 int64_t stride = cast<IntegerAttr>(adaptor.getStrides()[0]).getInt();
1802 if (stride != 1)
1803 return failure();
1804
1805 // AIE doesn't support select operations on i8
1806 if (getElementSizeInBits(vType) == 8)
1807 return extractOp.emitError()
1808 << "AIEv1 doesn't support select ops on int8 types";
1809
1810 // We only accept the case where we are extracting a slice half the size of
1811 // the input vector.
1812 int64_t size = cast<IntegerAttr>(adaptor.getSizes()[0]).getInt();
1813 if (vType.getNumElements() != 2 * size)
1814 return failure();
1815
1816 int64_t offset = cast<IntegerAttr>(adaptor.getOffsets()[0]).getInt();
1817 auto selectOp = rewriter.create<aievec::aie1::SelectOp>(
1818 extractOp.getLoc(), vType, adaptor.getVector(),
1819 buildAttributeListForRotationSelectOp(rewriter, vType, offset));
1820 rewriter.replaceOpWithNewOp<aievec::aie1::ExtOp>(
1821 extractOp, extractOp.getType(), selectOp.getResult(),
1822 rewriter.getI8IntegerAttr(0));
1823 return success();
1824 }
1825};
1826
1827// Convert a `vector.extract_strided_slice` op on 1D vectors into an
1828// `aievec.shift` op.
1830 : OpConversionPattern<vector::ExtractStridedSliceOp> {
1831 using OpConversionPattern::OpConversionPattern;
1832
1833 LogicalResult
1834 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
1835 ConversionPatternRewriter &rewriter) const override {
1836 auto vType = cast<VectorType>(adaptor.getVector().getType());
1837 if (vType.getRank() != 1)
1838 return failure();
1839
1840 int64_t stride = cast<IntegerAttr>(adaptor.getStrides()[0]).getInt();
1841 if (stride != 1)
1842 return failure();
1843
1844 // We only accept the case where we are extracting a slice half the size of
1845 // the input vector.
1846 int64_t size = cast<IntegerAttr>(adaptor.getSizes()[0]).getInt();
1847 if (vType.getNumElements() != 2 * size)
1848 return failure();
1849
1850 auto shortVecType = cast<VectorType>(extractOp.getResult().getType());
1851 auto bottomHalf = rewriter
1852 .create<aievec::ExtOp>(
1853 extractOp.getLoc(), shortVecType,
1854 adaptor.getVector(), rewriter.getI8IntegerAttr(0))
1855 .getResult();
1856 auto topHalf = rewriter
1857 .create<aievec::ExtOp>(extractOp.getLoc(), shortVecType,
1858 adaptor.getVector(),
1859 rewriter.getI8IntegerAttr(1))
1860 .getResult();
1861 int64_t offset = cast<IntegerAttr>(adaptor.getOffsets()[0]).getInt();
1862 int32_t shiftBytes = offset * getElementSizeInBits(vType) / 8;
1863 auto shiftBytesConstOp = rewriter.create<arith::ConstantOp>(
1864 extractOp.getLoc(), rewriter.getIntegerType(32),
1865 rewriter.getI32IntegerAttr(shiftBytes));
1866 rewriter.replaceOpWithNewOp<aievec::ShiftOp>(
1867 extractOp, shortVecType, bottomHalf, topHalf, shiftBytesConstOp);
1868
1869 return success();
1870 }
1871};
1872
1873// Replaces a short UPD op with a wide one followed by an ext op of the bottom
1874// half.
1876 using OpConversionPattern::OpConversionPattern;
1877
1878 ExpandUPDToUPDAndExtPattern(MLIRContext *context)
1879 : OpConversionPattern(context) {}
1880
1881 LogicalResult
1882 matchAndRewrite(aievec::UPDOp updOp, OpAdaptor adaptor,
1883 ConversionPatternRewriter &rewriter) const override {
1884 // Verify that we haven't already expanded this one
1885 if (updOp->hasOneUse() && isa<aievec::ExtOp>(*updOp->getUsers().begin()))
1886 return failure();
1887
1888 auto vecType = cast<VectorType>(updOp.getType());
1889 SmallVector<int64_t, 4> vecShape(vecType.getShape().begin(),
1890 vecType.getShape().end());
1891 vecShape[vecType.getRank() - 1] *= 2;
1892 auto longVecType = VectorType::get(vecShape, vecType.getElementType());
1893 auto newUpdOp = rewriter.create<aievec::UPDOp>(
1894 updOp.getLoc(), longVecType, adaptor.getSource(), adaptor.getIndices(),
1895 adaptor.getOffset(), adaptor.getIndex(), adaptor.getVector());
1896 rewriter.replaceOpWithNewOp<aievec::ExtOp>(
1897 updOp, vecType, newUpdOp.getResult(), rewriter.getI8IntegerAttr(0));
1898
1899 return success();
1900 }
1901};
1902
1903// Replaces a wide UPD op followed by an ext op of the bottom half with a short
1904// UPD op.
1906 using OpConversionPattern::OpConversionPattern;
1907
1908 FuseExtIntoUPDPattern(MLIRContext *context) : OpConversionPattern(context) {}
1909
1910 LogicalResult
1911 matchAndRewrite(aievec::ExtOp extOp, OpAdaptor adaptor,
1912 ConversionPatternRewriter &rewriter) const override {
1913 // Verify we are extracting the lower half...
1914 if (extOp.getIndex() != 0)
1915 return failure();
1916 // ...of a UPDOp
1917 auto updOp = dyn_cast<aievec::UPDOp>(extOp.getSource().getDefiningOp());
1918 if (!updOp)
1919 return failure();
1920
1921 // Verify that this is a direct upd -> ext pattern
1922 if (!updOp->hasOneUse())
1923 return failure();
1924
1925 rewriter.replaceOpWithNewOp<aievec::UPDOp>(
1926 extOp, extOp.getType(), updOp.getSource(), updOp.getIndices(),
1927 updOp.getOffset(), updOp.getIndex(), updOp.getVector());
1928
1929 return success();
1930 }
1931};
1932
1934 using OpConversionPattern::OpConversionPattern;
1935
1936 LogicalResult
1937 matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor,
1938 ConversionPatternRewriter &rewriter) const override {
1939
1940 if (!matchExpOpForLUT(adaptor))
1941 return failure();
1942
1943 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
1944 StringRef funcName = "getExpBf16";
1945 auto moduleOp = expOp->getParentOfType<mlir::ModuleOp>();
1946
1947 VectorType v16bf16Ty = mlir::VectorType::get({16}, rewriter.getBF16Type());
1948 VectorType v8i64Ty = mlir::VectorType::get({8}, rewriter.getI64Type());
1949 func::FuncOp fnOp = getOrInsertFuncDecl(
1950 rewriter, moduleOp, funcName, TypeRange{v16bf16Ty}, TypeRange{v8i64Ty});
1951
1952 SmallVector<Value> expOperands = {adaptor.getOperand()};
1953
1954 Type accTypeNative = getVectorOpDestType(srcType, /*AIE2 =*/true);
1955 auto callOp =
1956 rewriter.create<func::CallOp>(expOp.getLoc(), fnOp, expOperands);
1957 auto resCastOp = rewriter.create<vector::BitCastOp>(
1958 expOp.getLoc(), accTypeNative, callOp.getResults());
1959 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1960 expOp.getLoc(), rewriter.getI32IntegerAttr(0));
1961 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1962 expOp, srcType, resCastOp.getResult(), shiftParamOp.getResult());
1963
1964 return success();
1965 }
1966};
1967// Lower ExpOp to function call
1969 using OpConversionPattern::OpConversionPattern;
1970
1971 LogicalResult
1972 matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor,
1973 ConversionPatternRewriter &rewriter) const override {
1974 if (!matchExpOpForLUT(adaptor))
1975 return failure();
1976 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
1977 StringRef includeName = "lut_based_ops.h";
1978 auto moduleOp = expOp->getParentOfType<mlir::ModuleOp>();
1979 rewriter.setInsertionPointToStart(
1980 &moduleOp.getRegion().getBlocks().front());
1981 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
1982
1983 rewriter.setInsertionPoint(expOp);
1984
1985 auto v16bf16OpaqueTy =
1986 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
1987 auto opaquedOperand =
1988 rewriter
1989 .create<UnrealizedConversionCastOp>(expOp.getLoc(), v16bf16OpaqueTy,
1990 adaptor.getOperand())
1991 .getResult(0);
1992 SmallVector<Value> expOperands = {opaquedOperand};
1993
1994 Type accTypeNative = getVectorOpDestType(srcType, /*AIE2 =*/true);
1995 Type v16accf32OpaqueTy =
1996 emitc::OpaqueType::get(rewriter.getContext(), "v16accfloat");
1997 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
1998 expOp.getLoc(), TypeRange{v16accf32OpaqueTy}, "getExpBf16", nullptr,
1999 nullptr, expOperands);
2000 auto resCastOp = rewriter.create<UnrealizedConversionCastOp>(
2001 expOp.getLoc(), accTypeNative, callOp.getResults());
2002 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2003 expOp.getLoc(), rewriter.getI32IntegerAttr(0));
2004 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2005 expOp, srcType, resCastOp.getResult(0), shiftParamOp.getResult());
2006
2007 return success();
2008 }
2009};
2010
2011// Lower the inverse of a float to a function call
2012// Convert the pattern-
2013// %cst = arith.constant 1.000000e+00 : f32
2014// %0 = arith.divf %cst, %arg1 : f32
2015// %1 = arith.truncf %0 : f32 to bf16
2016// to -
// %1 = emitc.call "getInvBf16"(%arg1) : f32 -> bf16;
2019 using OpConversionPattern::OpConversionPattern;
2020
2021 LogicalResult
2022 matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor,
2023 ConversionPatternRewriter &rewriter) const override {
2024 Type srcType = adaptor.getLhs().getType();
2025 if (!divOp->hasOneUse() || isa<VectorType>(srcType) ||
2026 !isa<FloatType>(srcType))
2027 return failure();
2028
2029 if (!isNarrowingOp(*divOp->getUsers().begin()))
2030 return failure();
2031
2032 auto fType = cast<FloatType>(srcType);
2033 if (fType.getWidth() != 32)
2034 return failure();
2035
2036 auto constOp = dyn_cast<arith::ConstantOp>(divOp.getLhs().getDefiningOp());
2037 if (!constOp ||
2038 cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
2039 1.0f)
2040 return failure();
2041
2042 StringRef includeName = "lut_based_ops.h";
2043 auto moduleOp = divOp->getParentOfType<mlir::ModuleOp>();
2044 rewriter.setInsertionPointToStart(
2045 &moduleOp.getRegion().getBlocks().front());
2046 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2047
2048 auto truncOp = cast<arith::TruncFOp>(*divOp->getUsers().begin());
2049
2050 rewriter.setInsertionPoint(truncOp);
2051 Type bf16OpaqueTy =
2052 emitc::OpaqueType::get(rewriter.getContext(), "bfloat16");
2053 SmallVector<Value> invOperands = {adaptor.getRhs()};
2054 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2055 truncOp.getLoc(), bf16OpaqueTy, "getInvBf16", nullptr, nullptr,
2056 invOperands);
2057 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2058 truncOp, TypeRange{truncOp.getResult().getType()}, callOp.getResults());
2059 rewriter.eraseOp(divOp);
2060
2061 return success();
2062 }
2063};
2064
2065// Convert math.tanh to a function call to compute tanh(x) by look up tables
2067 using OpConversionPattern::OpConversionPattern;
2068
2069 LogicalResult
2070 matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor,
2071 ConversionPatternRewriter &rewriter) const override {
2072 auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
2073 if (!srcType)
2074 return failure();
2075
2076 Type scalarType = srcType.getElementType();
2077 if (!isa<FloatType>(scalarType))
2078 return failure();
2079
2080 unsigned laneSize = getVectorLaneSize(srcType);
2081 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2082 if (elWidth != 16 || laneSize != 16)
2083 return failure();
2084
2085 StringRef includeName = "lut_based_ops.h";
2086 auto moduleOp = tanhOp->getParentOfType<mlir::ModuleOp>();
2087 rewriter.setInsertionPointToStart(
2088 &moduleOp.getRegion().getBlocks().front());
2089 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2090
2091 rewriter.setInsertionPoint(tanhOp);
2092 Type v16bf16OpaqueTy =
2093 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2094 auto opaquedOperand =
2095 rewriter
2096 .create<UnrealizedConversionCastOp>(
2097 tanhOp.getLoc(), v16bf16OpaqueTy, adaptor.getOperand())
2098 .getResult(0);
2099 SmallVector<Value> tanhOperands = {opaquedOperand};
2100 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2101 tanhOp.getLoc(), v16bf16OpaqueTy, "getTanhBf16", nullptr, nullptr,
2102 tanhOperands);
2103 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2104 tanhOp, TypeRange{tanhOp.getResult().getType()}, callOp.getResults());
2105
2106 return success();
2107 }
2108};
2109
2110// Convert math.sqrt to a function call to compute sqrt(x) for v16bfloat16 and
2111// v32bfloat16 types
2113 using OpConversionPattern::OpConversionPattern;
2114
2115 LogicalResult
2116 matchAndRewrite(math::SqrtOp sqrtOp, OpAdaptor adaptor,
2117 ConversionPatternRewriter &rewriter) const override {
2118 auto srcType = dyn_cast<VectorType>(sqrtOp.getOperand().getType());
2119 if (!srcType)
2120 return failure();
2121
2122 Type scalarType = srcType.getElementType();
2123 if (!isa<FloatType>(scalarType))
2124 return failure();
2125
2126 unsigned laneSize = getVectorLaneSize(srcType);
2127 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2128 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2129 return failure();
2130
2131 StringRef includeName = "vec_math.h";
2132 auto moduleOp = sqrtOp->getParentOfType<mlir::ModuleOp>();
2133 rewriter.setInsertionPointToStart(
2134 &moduleOp.getRegion().getBlocks().front());
2135 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2136
2137 rewriter.setInsertionPoint(sqrtOp);
2138 Type vLNbf16OpaqueTy;
2139 if (laneSize == 16)
2140 vLNbf16OpaqueTy =
2141 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2142 else
2143 vLNbf16OpaqueTy =
2144 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
2145 auto opaquedOperand =
2146 rewriter
2147 .create<UnrealizedConversionCastOp>(
2148 sqrtOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
2149 .getResult(0);
2150 SmallVector<Value> sqrtOperands = {opaquedOperand};
2151 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2152 sqrtOp.getLoc(), TypeRange{vLNbf16OpaqueTy}, "getSqrtBf16", nullptr,
2153 nullptr, sqrtOperands);
2154 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2155 sqrtOp, TypeRange{sqrtOp.getResult().getType()}, callOp.getResults());
2156
2157 return success();
2158 }
2159};
2160
2161// Convert math.rsqrt to a function call to compute 1.0f / sqrt(x) for
2162// v16bfloat16 and v32bfloat16 types
2164 using OpConversionPattern::OpConversionPattern;
2165
2166 LogicalResult
2167 matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor,
2168 ConversionPatternRewriter &rewriter) const override {
2169 auto srcType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
2170 if (!srcType)
2171 return failure();
2172
2173 Type scalarType = srcType.getElementType();
2174 if (!isa<FloatType>(scalarType))
2175 return failure();
2176
2177 unsigned laneSize = getVectorLaneSize(srcType);
2178 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2179 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2180 return failure();
2181
2182 StringRef includeName = "vec_math.h";
2183 auto moduleOp = rsqrtOp->getParentOfType<mlir::ModuleOp>();
2184 rewriter.setInsertionPointToStart(
2185 &moduleOp.getRegion().getBlocks().front());
2186 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2187
2188 rewriter.setInsertionPoint(rsqrtOp);
2189 Type vLNbf16OpaqueTy;
2190 if (laneSize == 16)
2191 vLNbf16OpaqueTy =
2192 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2193 else
2194 vLNbf16OpaqueTy =
2195 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
2196 auto opaquedOperand =
2197 rewriter
2198 .create<UnrealizedConversionCastOp>(
2199 rsqrtOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
2200 .getResult(0);
2201 SmallVector<Value> rsqrtOperands = {opaquedOperand};
2202 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2203 rsqrtOp.getLoc(), TypeRange{vLNbf16OpaqueTy}, "getRsqrtBf16", nullptr,
2204 nullptr, rsqrtOperands);
2205 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2206 rsqrtOp, TypeRange{rsqrtOp.getResult().getType()}, callOp.getResults());
2207
2208 return success();
2209 }
2210};
2211
2212// Convert math.erf to a function call to compute erf(x) for v16bfloat16 and
2213// v32bfloat16 types
2215 using OpConversionPattern::OpConversionPattern;
2216
2217 LogicalResult
2218 matchAndRewrite(math::ErfOp erfOp, OpAdaptor adaptor,
2219 ConversionPatternRewriter &rewriter) const override {
2220 auto srcType = dyn_cast<VectorType>(erfOp.getOperand().getType());
2221 if (!srcType)
2222 return failure();
2223
2224 Type scalarType = srcType.getElementType();
2225 if (!isa<FloatType>(scalarType))
2226 return failure();
2227
2228 unsigned laneSize = getVectorLaneSize(srcType);
2229 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2230 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2231 return failure();
2232
2233 StringRef includeName = "vec_math.h";
2234 auto moduleOp = erfOp->getParentOfType<mlir::ModuleOp>();
2235 rewriter.setInsertionPointToStart(
2236 &moduleOp.getRegion().getBlocks().front());
2237 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2238
2239 rewriter.setInsertionPoint(erfOp);
2240 Type vLNbf16OpaqueTy;
2241 if (laneSize == 16)
2242 vLNbf16OpaqueTy =
2243 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2244 else
2245 vLNbf16OpaqueTy =
2246 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
2247 auto opaquedOperand =
2248 rewriter
2249 .create<UnrealizedConversionCastOp>(erfOp.getLoc(), vLNbf16OpaqueTy,
2250 adaptor.getOperand())
2251 .getResult(0);
2252 SmallVector<Value> erfOperands = {opaquedOperand};
2253 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2254 erfOp.getLoc(), TypeRange{vLNbf16OpaqueTy}, "getErfBf16", nullptr,
2255 nullptr, erfOperands);
2256 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2257 erfOp, TypeRange{erfOp.getResult().getType()}, callOp.getResults());
2258
2259 return success();
2260 }
2261};
2262
2263// Convert math.absf and math.absi to a function call to compute abs(x) for
2264// v16bfloat16, v32bfloat16, v16float, v16int32, v32int16 and v64int8 types
2265template <typename SrcOpTy>
2268 using OpAdaptor = typename SrcOpTy::Adaptor;
2269
2270 LogicalResult
2271 matchAndRewrite(SrcOpTy absOp, OpAdaptor adaptor,
2272 ConversionPatternRewriter &rewriter) const override {
2273 auto vecTy = dyn_cast<VectorType>(absOp.getOperand().getType());
2274 if (!vecTy)
2275 return failure();
2276
2277 Type elemTy = vecTy.getElementType();
2278
2279 unsigned laneSize = getVectorLaneSize(vecTy);
2280 unsigned elWidth = elemTy.getIntOrFloatBitWidth();
2281
2282 StringRef includeName = "vec_math.h";
2283 auto moduleOp = absOp->template getParentOfType<mlir::ModuleOp>();
2284 rewriter.setInsertionPointToStart(
2285 &moduleOp.getRegion().getBlocks().front());
2286 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2287
2288 rewriter.setInsertionPoint(absOp);
2289 std::ostringstream typeName;
2290 typeName << "v" << laneSize;
2291 if (isa<FloatType>(elemTy)) {
2292 if (elWidth == 16)
2293 typeName << "bfloat16";
2294 else
2295 typeName << "float";
2296 } else
2297 typeName << "int" << elWidth;
2298 Type vecOpaqueTy =
2299 emitc::OpaqueType::get(rewriter.getContext(), typeName.str());
2300 auto opaquedOperand =
2301 rewriter
2302 .create<UnrealizedConversionCastOp>(absOp.getLoc(), vecOpaqueTy,
2303 adaptor.getOperand())
2304 .getResult(0);
2305 SmallVector<Value> absOperands = {opaquedOperand};
2306 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2307 absOp.getLoc(), TypeRange{vecOpaqueTy}, "getAbs", nullptr, nullptr,
2308 absOperands);
2309 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2310 absOp, TypeRange{absOp.getResult().getType()}, callOp.getResults());
2311
2312 return success();
2313 }
2314};
2315
2318
2319template <typename SrcOpTy>
2322 using OpAdaptor = typename SrcOpTy::Adaptor;
2323
2324 LogicalResult
2325 matchAndRewrite(SrcOpTy extOp, OpAdaptor adaptor,
2326 ConversionPatternRewriter &rewriter) const override {
2327 VectorType srcType = dyn_cast<VectorType>(extOp.getIn().getType());
2328 VectorType dstType = dyn_cast<VectorType>(extOp.getOut().getType());
2329
2330 auto accType = getVectorOpDestType(srcType, /*AIE2 =*/true);
2331 auto upsOp =
2332 rewriter.create<aievec::UPSOp>(extOp.getLoc(), accType, extOp.getIn());
2333
2334 if (dstType.getElementType().getIntOrFloatBitWidth() == 16) {
2335 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2336 extOp.getLoc(), rewriter.getI32IntegerAttr(0));
2337 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2338 extOp, dstType, upsOp.getResult(), shiftParamOp.getResult());
2339 } else
2340 rewriter.replaceOpWithNewOp<aievec::CastOp>(
2341 extOp, dstType, upsOp.getResult(), /*isResAcc*/ false);
2342
2343 return success();
2344 }
2345};
2346
2349
2350template <typename SrcOpTy>
2353 using OpAdaptor = typename SrcOpTy::Adaptor;
2354
2355 LogicalResult
2356 matchAndRewrite(SrcOpTy truncOp, OpAdaptor adaptor,
2357 ConversionPatternRewriter &rewriter) const override {
2358 VectorType srcType = dyn_cast<VectorType>(truncOp.getIn().getType());
2359 VectorType dstType = dyn_cast<VectorType>(truncOp.getOut().getType());
2360 Type scalarType = srcType.getElementType();
2361 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2362
2363 unsigned laneSize = getVectorLaneSize(srcType);
2364 auto accType = isa<IntegerType>(scalarType) && (elWidth == 32)
2365 ? createVectorType(laneSize, scalarType)
2366 : getVectorOpDestType(srcType, /*AIE2 =*/true);
2367
2368 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2369 truncOp.getLoc(), rewriter.getI32IntegerAttr(0));
2370 if (elWidth == 16) {
2371 auto upsOp = rewriter.create<aievec::UPSOp>(truncOp.getLoc(), accType,
2372 truncOp.getIn());
2373 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2374 truncOp, dstType, upsOp.getResult(), shiftParamOp.getResult());
2375 } else {
2376 auto castOp = rewriter.create<aievec::CastOp>(truncOp.getLoc(), accType,
2377 truncOp.getIn(), true);
2378 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2379 truncOp, dstType, castOp.getResult(), shiftParamOp.getResult());
2380 }
2381
2382 return success();
2383 }
2384};
2385
2388
2389// If `op` is the last operation in the sequence:
2390// %0 = unrealized_conversion_cast <%IN> : <native type>, !emitc.opaque_type
2391// %1 = emitc.call_opaque <funcName>, %0...
2392// %2 = unrealized_conversion_cast %1 : !emitc.opaque_type, <native type>
2393// return the value <%IN>.
2394static std::optional<Value>
2395getUnOpaquedOperandOfEmitCOpaqueCallOp(Operation *op, StringRef funcName) {
2396 auto uccOp = dyn_cast<UnrealizedConversionCastOp>(op);
2397 if (!uccOp)
2398 return {};
2399
2400 auto inVal = uccOp.getInputs()[0];
2401 if (!isa<emitc::OpaqueType>(inVal.getType()))
2402 return {};
2403
2404 auto callOp = inVal.getDefiningOp<emitc::CallOpaqueOp>();
2405 if (callOp.getCallee() != funcName)
2406 return {};
2407
2408 auto callOperandsUccOp =
2409 callOp.getOperands()[0].getDefiningOp<UnrealizedConversionCastOp>();
2410 if (!callOperandsUccOp)
2411 return {};
2412
2413 return callOperandsUccOp.getInputs()[0];
2414}
2415
2416// Check there is an operation chain like-
2417//
2418// %cst_0 = arith.constant dense<1.000000e+00> : vector<16xbf16>
2419// %cst_1 = arith.constant 0.000000e+00 : bf16
2420// %0 = vector.transfer_read %arg0[%arg2], %cst_1 : memref<1024xbf16>,
2421// vector<16xbf16>
2422// %1 = arith.negf %0 : vector<16xbf16>
2423// %2 = math.exp %1 : vector<16xbf16>
2424// %3 = arith.addf %2, %cst_0 : vector<16xbf16>
2425// %4 = arith.divf %cst_0, %3 : vector<16xbf16>
2426//
2427// so that this operation chain can be converted to a function call to compute
2428// sigmoid value for v16bfloat16 and v32bfloat16 types
2429template <typename DivFOpTy>
2430static bool hasSigmoidComputationChain(DivFOpTy divfOp, arith::NegFOp &negOp) {
2431 auto constOp = dyn_cast<arith::ConstantOp>(divfOp.getLhs().getDefiningOp());
2432 if (!constOp)
2433 return false;
2434
2435 auto cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
2436 if (!cstDense)
2437 return false;
2438
2439 if (cstDense.template getSplatValue<APFloat>().convertToFloat() != 1.0f)
2440 return false;
2441
2442 Operation *addLvalOp;
2443 Operation *addRvalOp;
2444 // divfOp's rval could be an arith::AddFOp or the pattern like-
2445 // %1 = aievec.ups %a
2446 // %2 = aievec.ups %b;
2447 // %3 = aievec.add_elem %1, %2
2448 // %4 = aievec.srs %3;
2449 auto addOp = dyn_cast<arith::AddFOp>(divfOp.getRhs().getDefiningOp());
2450 if (!addOp) {
2451 auto srsOp = dyn_cast<aievec::SRSOp>(divfOp.getRhs().getDefiningOp());
2452 if (!srsOp)
2453 return false;
2454
2455 auto addElemOp =
2456 dyn_cast<aievec::AddElemOp>(srsOp.getSource().getDefiningOp());
2457 if (!addElemOp)
2458 return false;
2459
2460 auto lUpsOp = dyn_cast<aievec::UPSOp>(addElemOp.getLhs().getDefiningOp());
2461 auto rUpsOp = dyn_cast<aievec::UPSOp>(addElemOp.getRhs().getDefiningOp());
2462 if (!lUpsOp || !rUpsOp)
2463 return false;
2464
2465 addLvalOp = lUpsOp.getSource().getDefiningOp();
2466 addRvalOp = rUpsOp.getSource().getDefiningOp();
2467 // One of add operation's operand is a constant op and another operand could
2468 // be arith::ExpOp or the combination of emitc.call and aievec.srs
2469 auto addDefOp = isa<arith::ConstantOp>(addLvalOp)
2470 ? dyn_cast<aievec::SRSOp>(addRvalOp)
2471 : dyn_cast<aievec::SRSOp>(addLvalOp);
2472 if (!addDefOp)
2473 addLvalOp = isa<arith::ConstantOp>(addLvalOp)
2474 ? dyn_cast<math::ExpOp>(addRvalOp)
2475 : dyn_cast<math::ExpOp>(addLvalOp);
2476 else
2477 addLvalOp = addDefOp.getSource().getDefiningOp();
2478
2479 addRvalOp = isa<arith::ConstantOp>(addLvalOp)
2480 ? lUpsOp.getSource().getDefiningOp()
2481 : rUpsOp.getSource().getDefiningOp();
2482 } else {
2483 addLvalOp = addOp.getLhs().getDefiningOp();
2484 addRvalOp = addOp.getRhs().getDefiningOp();
2485 }
2486
2487 if (!addLvalOp || !addRvalOp)
2488 return false;
2489
2490 auto addLvalExpOp = dyn_cast<math::ExpOp>(addLvalOp);
2491 auto addRvalExpOp = dyn_cast<math::ExpOp>(addRvalOp);
2492 auto addLvalExpOpIn =
2493 getUnOpaquedOperandOfEmitCOpaqueCallOp(addLvalOp, "getExpBf16")
2494 .value_or(nullptr);
2495 auto addRvalExpOpIn =
2496 getUnOpaquedOperandOfEmitCOpaqueCallOp(addRvalOp, "getExpBf16")
2497 .value_or(nullptr);
2498 if (!addLvalExpOpIn && addLvalExpOp)
2499 addLvalExpOpIn = addLvalExpOp.getOperand();
2500 if (!addRvalExpOpIn && addRvalExpOp)
2501 addRvalExpOpIn = addRvalExpOp.getOperand();
2502
2503 if (!((addLvalExpOpIn && isa<arith::ConstantOp>(addRvalOp)) ||
2504 (addRvalExpOpIn && isa<arith::ConstantOp>(addLvalOp))))
2505 return false;
2506
2507 constOp = isa<arith::ConstantOp>(addLvalOp)
2508 ? cast<arith::ConstantOp>(addLvalOp)
2509 : cast<arith::ConstantOp>(addRvalOp);
2510
2511 cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
2512 if (!cstDense)
2513 return false;
2514 if (cstDense.template getSplatValue<APFloat>().convertToFloat() != 1.0f)
2515 return false;
2516
2517 auto expOperand = addLvalExpOpIn ? addLvalExpOpIn : addRvalExpOpIn;
2518
2519 negOp = expOperand.getDefiningOp<arith::NegFOp>();
2520
2521 return negOp != nullptr;
2522}
2523
2524// Convert the operation chain like-
2525//
2526// %cst_0 = arith.constant dense<1.000000e+00> : vector<16xbf16>
2527// %cst_1 = arith.constant 0.000000e+00 : bf16
2528// %0 = vector.transfer_read %arg0[%arg2], %cst_1 : memref<1024xbf16>,
2529// vector<16xbf16>
2530// %1 = arith.negf %0 : vector<16xbf16>
2531// %2 = math.exp %1 :vector<16xbf16>
2532// %3 = arith.addf %2, %cst_0 : vector<16xbf16>
2533// %4 = arith.divf %cst_0, %3 : vector<16xbf16>
2534//
2535// to a function call to compute sigmoid value for v16bfloat16 and
2536// v32bfloat16 types
2538 using OpConversionPattern::OpConversionPattern;
2539
2540 LogicalResult
2541 matchAndRewrite(arith::DivFOp divfOp, OpAdaptor adaptor,
2542 ConversionPatternRewriter &rewriter) const override {
2543 auto srcType = dyn_cast<VectorType>(adaptor.getLhs().getType());
2544 if (!srcType)
2545 return failure();
2546
2547 Type scalarType = srcType.getElementType();
2548 if (!isa<FloatType>(scalarType))
2549 return failure();
2550
2551 unsigned laneSize = getVectorLaneSize(srcType);
2552 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2553 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2554 return failure();
2555
2556 arith::NegFOp negOp = nullptr;
2557 if (!hasSigmoidComputationChain(adaptor, negOp))
2558 return failure();
2559
2560 StringRef includeName = "vec_math.h";
2561 auto moduleOp = divfOp->getParentOfType<mlir::ModuleOp>();
2562 rewriter.setInsertionPointToStart(
2563 &moduleOp.getRegion().getBlocks().front());
2564 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2565
2566 rewriter.setInsertionPoint(divfOp);
2567 Type vecOpaqueTy;
2568 if (laneSize == 16)
2569 vecOpaqueTy =
2570 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2571 else
2572 vecOpaqueTy =
2573 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
2574 auto opaquedOperand =
2575 rewriter
2576 .create<UnrealizedConversionCastOp>(divfOp.getLoc(), vecOpaqueTy,
2577 negOp.getOperand())
2578 .getResult(0);
2579 SmallVector<Value> sigmoidOperands = {opaquedOperand};
2580 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2581 divfOp.getLoc(), TypeRange{vecOpaqueTy}, "getSigmoidBf16", nullptr,
2582 nullptr, sigmoidOperands);
2583 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2584 divfOp, TypeRange{adaptor.getLhs().getType()}, callOp.getResults());
2585
2586 return success();
2587 }
2588};
2589
// Convert math.ceil to a function call to compute ceil(x) for v16bfloat16 and
// v32bfloat16 types
2592 using OpConversionPattern::OpConversionPattern;
2593
2594 LogicalResult
2595 matchAndRewrite(math::CeilOp ceilOp, OpAdaptor adaptor,
2596 ConversionPatternRewriter &rewriter) const override {
2597 auto srcType = dyn_cast<VectorType>(ceilOp.getOperand().getType());
2598 if (!srcType)
2599 return failure();
2600
2601 Type scalarType = srcType.getElementType();
2602 if (!isa<FloatType>(scalarType))
2603 return failure();
2604
2605 unsigned laneSize = getVectorLaneSize(srcType);
2606 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2607 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2608 return failure();
2609
2610 StringRef includeName = "vec_math.h";
2611 auto moduleOp = ceilOp->getParentOfType<mlir::ModuleOp>();
2612 rewriter.setInsertionPointToStart(
2613 &moduleOp.getRegion().getBlocks().front());
2614 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2615
2616 rewriter.setInsertionPoint(ceilOp);
2617 Type vecOpaqueTy;
2618 if (laneSize == 16)
2619 vecOpaqueTy =
2620 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2621 else
2622 vecOpaqueTy =
2623 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
2624 auto opaquedOperand =
2625 rewriter
2626 .create<UnrealizedConversionCastOp>(ceilOp.getLoc(), vecOpaqueTy,
2627 adaptor.getOperand())
2628 .getResult(0);
2629 SmallVector<Value> ceilOperands = {opaquedOperand};
2630 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2631 ceilOp.getLoc(), TypeRange{vecOpaqueTy}, "getCeilBf16", nullptr,
2632 nullptr, ceilOperands);
2633 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2634 ceilOp, TypeRange{ceilOp.getResult().getType()}, callOp.getResults());
2635
2636 return success();
2637 }
2638};
2639
2640 // Convert math.floor to a function call to compute floor(x) for v16bfloat16 and v32bfloat16
2642 using OpConversionPattern::OpConversionPattern;
2643
2644 LogicalResult
2645 matchAndRewrite(math::FloorOp floorOp, OpAdaptor adaptor,
2646 ConversionPatternRewriter &rewriter) const override {
2647 auto srcType = dyn_cast<VectorType>(floorOp.getOperand().getType());
2648 if (!srcType)
2649 return failure();
2650
2651 Type scalarType = srcType.getElementType();
2652 if (!isa<FloatType>(scalarType))
2653 return failure();
2654
2655 unsigned laneSize = getVectorLaneSize(srcType);
2656 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2657 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2658 return failure();
2659
2660 StringRef includeName = "vec_math.h";
2661 auto moduleOp = floorOp->getParentOfType<mlir::ModuleOp>();
2662 rewriter.setInsertionPointToStart(
2663 &moduleOp.getRegion().getBlocks().front());
2664 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2665
2666 rewriter.setInsertionPoint(floorOp);
2667 Type vecOpaqueTy;
2668 if (laneSize == 16)
2669 vecOpaqueTy =
2670 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2671 else
2672 vecOpaqueTy =
2673 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
2674 auto opaquedOperand =
2675 rewriter
2676 .create<UnrealizedConversionCastOp>(floorOp.getLoc(), vecOpaqueTy,
2677 adaptor.getOperand())
2678 .getResult(0);
2679 SmallVector<Value> floorOperands = {opaquedOperand};
2680 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2681 floorOp.getLoc(), TypeRange{vecOpaqueTy}, "getFloorBf16", nullptr,
2682 nullptr, floorOperands);
2683 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2684 floorOp, TypeRange{floorOp.getResult().getType()}, callOp.getResults());
2685
2686 return success();
2687 }
2688};
2689
2690// Convert arith.negf to aievec.neg to negate the vector for v16bfloat16 and
2691// v16float types.
2693 using OpConversionPattern::OpConversionPattern;
2694
2695 LogicalResult
2696 matchAndRewrite(arith::NegFOp negOp, OpAdaptor adaptor,
2697 ConversionPatternRewriter &rewriter) const override {
2698 auto srcType = dyn_cast<VectorType>(negOp.getOperand().getType());
2699 if (!srcType)
2700 return failure();
2701
2702 Type scalarType = srcType.getElementType();
2703 if (!isa<FloatType>(scalarType))
2704 return failure();
2705
2706 if (unsigned laneSize = getVectorLaneSize(srcType); laneSize != 16)
2707 return failure();
2708
2709 Location loc = negOp.getLoc();
2710 auto accType = getVectorOpDestType(srcType, /*AIE2 =*/true);
2711
2712 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2713 if (elWidth == 16) {
2714 auto upsOp =
2715 rewriter.create<aievec::UPSOp>(loc, accType, adaptor.getOperand());
2716 auto aieNegOp =
2717 rewriter.create<aievec::NegOp>(loc, accType, upsOp.getResult());
2718 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2719 negOp.getLoc(), rewriter.getI32IntegerAttr(0));
2720 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2721 negOp, srcType, aieNegOp.getResult(), shiftParamOp.getResult());
2722 } else {
2723 auto castOp = rewriter.create<aievec::CastOp>(
2724 loc, accType, adaptor.getOperand(), /*isResAcc*/ true);
2725 auto aieNegOp =
2726 rewriter.create<aievec::NegOp>(loc, accType, castOp.getResult());
2727 rewriter.replaceOpWithNewOp<aievec::CastOp>(
2728 negOp, srcType, aieNegOp.getResult(), /*isResAcc*/ false);
2729 }
2730
2731 return success();
2732 }
2733};
2734
2735// Check whether the value of constant operation is int type and the dense value
2736// is -1.
2737static bool hasConstNegOneValue(arith::ConstantOp constOp, unsigned elWidth) {
2738 if (!constOp)
2739 return false;
2740
2741 auto cstDense = dyn_cast<DenseIntElementsAttr>(constOp.getValue());
2742 if (!cstDense)
2743 return false;
2744
2745 if (elWidth == 32)
2746 return cstDense.getSplatValue<int32_t>() == -1;
2747 if (elWidth == 16)
2748 return cstDense.getSplatValue<int16_t>() == -1;
2749 if (elWidth == 8)
2750 return cstDense.getSplatValue<int8_t>() == -1;
2751 return false;
2752}
2753
2754// Convert arith.xori to aievec.bxor to compute bitwise xor of two vectors for
2755// integer types
2757 using OpConversionPattern::OpConversionPattern;
2758
2759 LogicalResult
2760 matchAndRewrite(arith::XOrIOp xorOp, OpAdaptor adaptor,
2761 ConversionPatternRewriter &rewriter) const override {
2762 auto srcType = dyn_cast<VectorType>(xorOp.getLhs().getType());
2763 if (!srcType)
2764 return failure();
2765
2766 Type scalarType = srcType.getElementType();
2767 if (!isa<IntegerType>(scalarType))
2768 return failure();
2769
2770 unsigned laneSize = getVectorLaneSize(srcType);
2771 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2772 if (laneSize * elWidth != 512)
2773 return failure();
2774
2775 auto lhsConstOp =
2776 dyn_cast<arith::ConstantOp>(xorOp.getLhs().getDefiningOp());
2777 auto rhsConstOp =
2778 dyn_cast<arith::ConstantOp>(xorOp.getRhs().getDefiningOp());
2779
2780 // If one of operands in xorOp is a constant -1, xorOp will be replaced with
2781 // aievec::BnegOp.
2782 if ((lhsConstOp && hasConstNegOneValue(lhsConstOp, elWidth)) ||
2783 (rhsConstOp && hasConstNegOneValue(rhsConstOp, elWidth))) {
2784 Value val = hasConstNegOneValue(lhsConstOp, elWidth) ? adaptor.getRhs()
2785 : adaptor.getLhs();
2786 rewriter.replaceOpWithNewOp<aievec::BnegOp>(xorOp, srcType, val);
2787 } else
2788 rewriter.replaceOpWithNewOp<aievec::BxorOp>(
2789 xorOp, srcType, adaptor.getLhs(), adaptor.getRhs());
2790
2791 return success();
2792 }
2793};
2794
2795template <typename SrcOpTy, typename DstOpTy>
2798 using OpAdaptor = typename SrcOpTy::Adaptor;
2799
2800 LogicalResult
2801 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
2802 ConversionPatternRewriter &rewriter) const override {
2803 VectorType srcType = dyn_cast<VectorType>(srcOp.getLhs().getType());
2804 if (!srcType)
2805 return failure();
2806
2807 Type scalarType = srcType.getElementType();
2808 if (!isa<IntegerType>(scalarType))
2809 return failure();
2810
2811 unsigned laneSize = getVectorLaneSize(srcType);
2812 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2813 if (laneSize * elWidth != 512)
2814 return failure();
2815
2816 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getResult().getType(),
2817 adaptor.getLhs(), adaptor.getRhs());
2818
2819 return success();
2820 }
2821};
2822
2827
2828 // Convert arith.shrsi to a combination of aievec.ups and aievec.srs to compute
2829 // arithmetic right shifts for integer types. Currently, only shift amounts
2830 // given by a broadcast vector are supported.
2832 : OpConversionPattern<arith::ShRSIOp> {
2833 using OpConversionPattern::OpConversionPattern;
2834
2835 LogicalResult
2836 matchAndRewrite(arith::ShRSIOp rsOp, OpAdaptor adaptor,
2837 ConversionPatternRewriter &rewriter) const override {
2838 auto srcType = dyn_cast<VectorType>(adaptor.getLhs().getType());
2839 if (!srcType)
2840 return failure();
2841
2842 Type scalarType = srcType.getElementType();
2843 unsigned laneSize = getVectorLaneSize(srcType);
2844 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2845 if (laneSize * elWidth != 512)
2846 return failure();
2847
2848 auto bcastOp =
2849 dyn_cast<aievec::BroadcastOp>(adaptor.getRhs().getDefiningOp());
2850 if (!bcastOp)
2851 return failure();
2852
2853 auto constOp = rewriter.create<arith::ConstantOp>(
2854 bcastOp.getLoc(), rewriter.getI32IntegerAttr(bcastOp.getIdx()));
2855 auto extElemOp = rewriter.create<aievec::ExtElemOp>(
2856 bcastOp.getLoc(), scalarType, bcastOp, constOp.getResult());
2857 Location loc = rsOp.getLoc();
2858
2859 // The vector with v64int8 type can be divided into two v32int8 vectors and
2860 // be processed individually and be concatenated at the end.
2861 if (elWidth == 8) {
2862 VectorType halfSrcType = createVectorType(laneSize / 2, scalarType);
2863 auto rsOpLow =
2864 rewriter.create<aievec::ExtOp>(loc, halfSrcType, adaptor.getLhs(), 0);
2865 auto rsOpHigh =
2866 rewriter.create<aievec::ExtOp>(loc, halfSrcType, adaptor.getLhs(), 1);
2867 Type accType = getVectorOpDestType(halfSrcType, /*AIE2 =*/true);
2868 auto upsOpLow =
2869 rewriter.create<aievec::UPSOp>(loc, accType, rsOpLow.getResult());
2870 auto srsOpLow = rewriter.create<aievec::SRSOp>(
2871 loc, halfSrcType, upsOpLow.getResult(), extElemOp.getResult());
2872 auto upsOpHigh =
2873 rewriter.create<aievec::UPSOp>(loc, accType, rsOpHigh.getResult());
2874 auto srsOpHigh = rewriter.create<aievec::SRSOp>(
2875 loc, halfSrcType, upsOpHigh.getResult(), extElemOp.getResult());
2876 SmallVector<Value> inputSources = {srsOpLow.getResult(),
2877 srsOpHigh.getResult()};
2878 rewriter.replaceOpWithNewOp<aievec::ConcatOp>(rsOp, srcType,
2879 inputSources);
2880 } else {
2881 Type accType = getVectorOpDestType(srcType, /*AIE2 =*/true);
2882 auto upsOp =
2883 rewriter.create<aievec::UPSOp>(loc, accType, adaptor.getLhs());
2884 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2885 rsOp, srcType, upsOp.getResult(), extElemOp.getResult());
2886 }
2887
2888 return success();
2889 }
2890};
2891
2892// Convert a `vector.contract` op to an `aievec.matmul` op for AIE2
2894 : OpConversionPattern<vector::ContractionOp> {
2895 using OpConversionPattern::OpConversionPattern;
2896
2900
2901 Value reshapeLeadingUnitDims(OpBuilder &b, Value v) const {
2902 auto vecTy = dyn_cast<VectorType>(v.getType());
2903 if (!vecTy)
2904 return v;
2905 auto vecShape = vecTy.getShape();
2906
2907 size_t numLeadUnitDims = 0;
2908 while (numLeadUnitDims < vecShape.size() && vecShape[numLeadUnitDims] == 1)
2909 numLeadUnitDims++;
2910
2911 if (!numLeadUnitDims)
2912 return v;
2913
2914 SmallVector<int64_t> newShape(vecShape.begin() + numLeadUnitDims,
2915 vecShape.end());
2916 auto newVecTy = VectorType::get(newShape, vecTy.getElementType());
2917 return b.create<vector::ShapeCastOp>(v.getLoc(), newVecTy, v).getResult();
2918 }
2919
2920 LogicalResult
2921 matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor,
2922 ConversionPatternRewriter &rewriter) const override {
2923 auto lhs = reshapeLeadingUnitDims(rewriter, adaptor.getLhs());
2924 auto rhs = reshapeLeadingUnitDims(rewriter, adaptor.getRhs());
2925 auto acc = reshapeLeadingUnitDims(rewriter, adaptor.getAcc());
2926 bool bReshapedAcc = (acc != adaptor.getAcc());
2927
2928 if (matMoveToAcc)
2929 acc = rewriter.create<aievec::CastOp>(contractOp.getLoc(), acc.getType(),
2930 acc, true);
2931
2932 auto matmulOp = rewriter.create<aievec::MatMulOp>(
2933 contractOp.getLoc(), acc.getType(), lhs, rhs, acc);
2934 {
2935 // Replace diagnostics handler to silence errors when verifying the
2936 // validity of the `aievec.matmul` ops being generated.
2937 ScopedDiagnosticHandler diagHandler(
2938 contractOp.getContext(), [](Diagnostic &) { return success(); });
2939 if (failed(matmulOp.verifyInvariants())) {
2940 rewriter.eraseOp(matmulOp);
2941 // There is a possibility that, when the linalg op is converted to
2942 // contractions, lower precisions operands are cast to the target
2943 // precission outside the contraction. For those cases, we check.
2944 lhs = adaptor.getLhs();
2945 auto wideLhsValue = getSourceOfWideningOp(lhs).value_or(nullptr);
2946 if (wideLhsValue)
2947 lhs = reshapeLeadingUnitDims(rewriter, wideLhsValue);
2948
2949 rhs = adaptor.getRhs();
2950 auto wideRhsValue = getSourceOfWideningOp(rhs).value_or(nullptr);
2951 if (wideRhsValue)
2952 rhs = reshapeLeadingUnitDims(rewriter, wideRhsValue);
2953
2954 matmulOp = rewriter.create<aievec::MatMulOp>(
2955 contractOp.getLoc(), acc.getType(), lhs, rhs, acc);
2956 if (failed(matmulOp.verifyInvariants()))
2957 return failure();
2958 }
2959 }
2960
2961 Value result = matmulOp.getResult();
2962 if (matMoveToAcc)
2963 result = rewriter.create<aievec::CastOp>(contractOp.getLoc(),
2964 acc.getType(), matmulOp, false);
2965 if (bReshapedAcc)
2966 result = rewriter.create<vector::ShapeCastOp>(
2967 contractOp.getLoc(), adaptor.getAcc().getType(), result);
2968 rewriter.replaceOp(contractOp, result);
2969
2970 return success();
2971 }
2972
2974};
2975
2976// Convert a `vector.transpose` op to an `aievec.shuffle` op for AIE2.
2978 : OpConversionPattern<vector::TransposeOp> {
2979 using OpConversionPattern::OpConversionPattern;
2980 LogicalResult
2981 matchAndRewrite(vector::TransposeOp transpOp, OpAdaptor adaptor,
2982 ConversionPatternRewriter &rewriter) const override {
2983 auto resTy = transpOp.getResultVectorType();
2984 auto resShape = resTy.getShape();
2985 auto elemTyBitWidth = resTy.getElementTypeBitWidth();
2986 auto vBitWidth = std::accumulate(resShape.begin(), resShape.end(),
2987 elemTyBitWidth, std::multiplies<>());
2988 if (vBitWidth != 512)
2989 return failure();
2990
2991 if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32)
2992 return failure();
2993
2994 // Verify leading dimensions are all 1.
2995 for (int64_t i = 0; i < static_cast<int64_t>(resShape.size() - 2); ++i)
2996 if (resShape[i] != 1)
2997 return failure();
2998
2999 // Only permutation of the 2 innermost dimensions are supported.
3000 ArrayRef<int64_t> perm = transpOp.getPermutation();
3001 for (int64_t i = 0; i < static_cast<int64_t>(perm.size() - 2); ++i)
3002 if (perm[i] != i)
3003 return failure();
3004 if (perm.back() != static_cast<int64_t>(perm.size() - 2))
3005 return failure();
3006
3007 auto shuffleMode = aievec::ShuffleMode::T32_4X4;
3008 if (elemTyBitWidth == 8) {
3009 switch (resShape.back()) {
3010 case 4:
3011 shuffleMode = aievec::ShuffleMode::T8_4X16;
3012 break;
3013 case 8:
3014 shuffleMode = aievec::ShuffleMode::T8_8X8;
3015 break;
3016 case 16:
3017 shuffleMode = aievec::ShuffleMode::T8_16X4;
3018 break;
3019 default:
3020 return failure();
3021 }
3022 } else if (elemTyBitWidth == 16) {
3023 switch (resShape.back()) {
3024 case 2:
3025 shuffleMode = aievec::ShuffleMode::T16_2X16;
3026 break;
3027 case 4:
3028 shuffleMode = aievec::ShuffleMode::T16_4X8;
3029 break;
3030 case 8:
3031 shuffleMode = aievec::ShuffleMode::T16_8X4;
3032 break;
3033 case 16:
3034 shuffleMode = aievec::ShuffleMode::T16_16X2;
3035 break;
3036 default:
3037 return failure();
3038 }
3039 } else if (resShape.back() != 4)
3040 return failure();
3041
3042 auto flatVecTy =
3043 VectorType::get({512 / elemTyBitWidth}, resTy.getElementType());
3044 auto loc = transpOp.getLoc();
3045 auto flatInput = rewriter.create<vector::ShapeCastOp>(loc, flatVecTy,
3046 adaptor.getVector());
3047 auto shuffOp = rewriter.create<aievec::ShuffleOp>(loc, flatVecTy, flatInput,
3048 nullptr, shuffleMode);
3049 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resTy, shuffOp);
3050
3051 return success();
3052 }
3053};
3054
3055//===----------------------------------------------------------------------===//
3056// Pattern collection
3057//===----------------------------------------------------------------------===//
3058
// Register the conversion patterns shared by every AIE target, from
// LowerExtFOpPattern through LowerTruncIOpPattern. NOTE(review): `backend`
// is not consulted in the lines visible here.
3059 static void populateAIEVecCommonConversionPatterns(RewritePatternSet &patterns,
3060 TargetBackend backend) {
3061 // clang-format off
3062 patterns.add<LowerExtFOpPattern,
3065 LowerTruncIOpPattern>(patterns.getContext());
3066 // clang-format on
3067}
3068
// Register the AIE1 (V1) conversion patterns. The transfer_read lowering is
// built with the extra constructor arguments 128, 512, 128, 256. Their
// meaning is defined by LowerVectorTransferReadToAIEUPD, which is not visible
// here. NOTE(review): `backend` is not consulted in the visible lines.
3069 static void populateAIEVecV1ConversionPatterns(RewritePatternSet &patterns,
3070 TargetBackend backend) {
3071 patterns.add<LowerVectorTransferReadToAIEUPD>(patterns.getContext(), 128, 512,
3072 128, 256);
3073 // clang-format off
3074 patterns.add<LowerVectorAddIOpToAIEVecAddOp,
3082 LowerVectorExtractStridedSliceOpAIEv1Pattern>(patterns.getContext());
3083 // clang-format on
3084}
3085
// Register the AIE2 (V2) conversion patterns. Some pattern sets are added
// only for the CPP backend (one of them with the extra constructor arguments
// 128, 1024, 256, 1024) or only for the LLVMIR backend. The remaining sets
// are added for every backend, the last one receiving
// `backend == TargetBackend::CPP` as a constructor argument.
3086 static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
3087 TargetBackend backend) {
3088 // clang-format off
3089 // TODO: Reorder these alphabetically
3090 if (backend == TargetBackend::CPP) {
3091 patterns.add<
3093 >(patterns.getContext(), 128, 1024, 256, 1024);
3094 patterns.add<
3100 >(patterns.getContext());
3101 } else if (backend == TargetBackend::LLVMIR){
3102 patterns.add<
3104 >(patterns.getContext());
3105 }
3106 patterns.add<
3142 >(patterns.getContext());
3144 >(patterns.getContext(), backend == TargetBackend::CPP);
3145 // clang-format on
3146}
3147
3148//===----------------------------------------------------------------------===//
3149// Legalizations
3150//===----------------------------------------------------------------------===//
3151
3152// TODO: Review the validity of these legalizations beyond basic cases.
3153
3154static bool isInSigmoidOperationChain(math::ExpOp expOp) {
3155 if (!expOp.getOperand().getDefiningOp<arith::NegFOp>())
3156 return false;
3157
3158 arith::AddFOp addOp = nullptr;
3159 for (Operation *user : expOp->getUsers()) {
3160 addOp = dyn_cast<arith::AddFOp>(user);
3161 if (addOp)
3162 break;
3163 }
3164
3165 if (!addOp)
3166 return false;
3167
3168 auto *addLvalOp = addOp.getLhs().getDefiningOp();
3169 auto *addRvalOp = addOp.getRhs().getDefiningOp();
3170 if (!((isa<math::ExpOp>(addLvalOp) && isa<arith::ConstantOp>(addRvalOp)) ||
3171 (isa<math::ExpOp>(addRvalOp) && isa<arith::ConstantOp>(addLvalOp))))
3172 return false;
3173
3174 auto constOp = isa<arith::ConstantOp>(addLvalOp)
3175 ? cast<arith::ConstantOp>(addLvalOp)
3176 : cast<arith::ConstantOp>(addRvalOp);
3177
3178 auto cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3179 if (!cstDense)
3180 return false;
3181
3182 if (cstDense.getSplatValue<APFloat>().convertToFloat() != 1.0f)
3183 return false;
3184
3185 arith::DivFOp divOp = nullptr;
3186 for (Operation *user : addOp->getUsers()) {
3187 divOp = dyn_cast<arith::DivFOp>(user);
3188 if (divOp)
3189 break;
3190 }
3191
3192 if (!divOp)
3193 return false;
3194
3195 constOp = dyn_cast<arith::ConstantOp>(divOp.getLhs().getDefiningOp());
3196 if (!constOp)
3197 return false;
3198 cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3199 if (!cstDense)
3200 return false;
3201 if (cstDense.getSplatValue<APFloat>().convertToFloat() != 1.0f)
3202 return false;
3203
3204 return true;
3205}
3206
3207static void configureAIEVecCommonLegalizations(ConversionTarget &target,
3208 TargetBackend backend) {
3209 target
3210 .addLegalDialect<xilinx::aievec::aie1::AIEVecAIE1Dialect,
3211 xilinx::aievec::AIEVecDialect, arith::ArithDialect,
3212 ub::UBDialect, emitc::EmitCDialect, func::FuncDialect>();
3213 if (backend == TargetBackend::CPP) {
3214 target.addIllegalOp<vector::TransferReadOp>();
3215 }
3216 target.addIllegalOp<vector::ExtractStridedSliceOp>();
3217 target.addLegalOp<vector::BitCastOp>();
3218
3219 target.addDynamicallyLegalOp<arith::ExtFOp>([](arith::ExtFOp extfOp) {
3220 auto srcType = dyn_cast<VectorType>(extfOp.getIn().getType());
3221 auto dstType = dyn_cast<VectorType>(extfOp.getOut().getType());
3222 if (!srcType || !dstType)
3223 return true;
3224
3225 Type srcScalarType = srcType.getElementType();
3226 Type dstScalarType = dstType.getElementType();
3227 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
3228 return true;
3229
3230 unsigned srcLaneSize = getVectorLaneSize(srcType);
3231 unsigned dstLaneSize = getVectorLaneSize(dstType);
3232 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3233 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3234 return srcElWidth != 16 || srcLaneSize != 16 || dstElWidth != 32 ||
3235 dstLaneSize != 16;
3236 });
3237
3238 target.addDynamicallyLegalOp<arith::ExtSIOp>([](arith::ExtSIOp extsiOp) {
3239 auto srcType = dyn_cast<VectorType>(extsiOp.getIn().getType());
3240 auto dstType = dyn_cast<VectorType>(extsiOp.getOut().getType());
3241 if (!srcType || !dstType)
3242 return true;
3243
3244 Type srcScalarType = srcType.getElementType();
3245 Type dstScalarType = dstType.getElementType();
3246 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
3247 return true;
3248
3249 unsigned srcLaneSize = getVectorLaneSize(srcType);
3250 unsigned dstLaneSize = getVectorLaneSize(dstType);
3251 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3252 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3253 return srcLaneSize != 32 || (dstElWidth <= srcElWidth) ||
3254 (dstLaneSize != srcLaneSize);
3255 });
3256
3257 target.addDynamicallyLegalOp<arith::TruncFOp>([](arith::TruncFOp truncfOp) {
3258 auto srcType = dyn_cast<VectorType>(truncfOp.getIn().getType());
3259 auto dstType = dyn_cast<VectorType>(truncfOp.getOut().getType());
3260 if (!srcType || !dstType)
3261 return true;
3262
3263 Type srcScalarType = srcType.getElementType();
3264 Type dstScalarType = dstType.getElementType();
3265 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
3266 return true;
3267
3268 unsigned srcLaneSize = getVectorLaneSize(srcType);
3269 unsigned dstLaneSize = getVectorLaneSize(dstType);
3270 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3271 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3272 return srcElWidth != 32 || srcLaneSize != 16 || dstElWidth != 16 ||
3273 dstLaneSize != 16;
3274 });
3275
3276 target.addDynamicallyLegalOp<arith::TruncIOp>([](arith::TruncIOp trunciOp) {
3277 auto srcType = dyn_cast<VectorType>(trunciOp.getIn().getType());
3278 auto dstType = dyn_cast<VectorType>(trunciOp.getOut().getType());
3279 if (!srcType || !dstType)
3280 return true;
3281
3282 Type srcScalarType = srcType.getElementType();
3283 Type dstScalarType = dstType.getElementType();
3284 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
3285 return true;
3286
3287 unsigned srcLaneSize = getVectorLaneSize(srcType);
3288 unsigned dstLaneSize = getVectorLaneSize(dstType);
3289 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3290 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3291
3292 return srcLaneSize != 32 || (dstElWidth >= srcElWidth) ||
3293 (dstLaneSize != srcLaneSize);
3294 });
3295
3296 target.addDynamicallyLegalOp<math::ExpOp>([](math::ExpOp expOp) {
3297 auto srcType = dyn_cast<VectorType>(expOp.getOperand().getType());
3298 if (!srcType)
3299 return true;
3300
3301 Type scalarType = srcType.getElementType();
3302 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3303 unsigned laneSize = getVectorLaneSize(srcType);
3304 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 16)
3305 return true;
3306 if (expOp->hasOneUse() && isInSigmoidOperationChain(expOp))
3307 return true;
3308
3309 return false;
3310 });
3311
3312 target.addDynamicallyLegalOp<math::TanhOp>([](math::TanhOp tanhOp) {
3313 auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
3314 if (!srcType)
3315 return true;
3316
3317 Type scalarType = srcType.getElementType();
3318 if (!isa<FloatType>(scalarType))
3319 return true;
3320
3321 unsigned laneSize = getVectorLaneSize(srcType);
3322 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3323 return elWidth != 16 || laneSize != 16;
3324 });
3325
3326 target.addDynamicallyLegalOp<math::SqrtOp>([](math::SqrtOp sqrtOp) {
3327 auto srcType = dyn_cast<VectorType>(sqrtOp.getOperand().getType());
3328 if (!srcType)
3329 return true;
3330
3331 Type scalarType = srcType.getElementType();
3332 if (!isa<FloatType>(scalarType))
3333 return true;
3334
3335 unsigned laneSize = getVectorLaneSize(srcType);
3336 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3337 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3338 });
3339
3340 target.addDynamicallyLegalOp<math::RsqrtOp>([](math::RsqrtOp rsqrtOp) {
3341 auto srcType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
3342 Type scalarType = srcType.getElementType();
3343 if (!srcType || !isa<FloatType>(scalarType))
3344 return true;
3345
3346 unsigned laneSize = getVectorLaneSize(srcType);
3347 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3348 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3349 });
3350
3351 target.addDynamicallyLegalOp<math::ErfOp>([](math::ErfOp erfOp) {
3352 auto srcType = dyn_cast<VectorType>(erfOp.getOperand().getType());
3353 if (!srcType)
3354 return true;
3355
3356 Type scalarType = srcType.getElementType();
3357 if (!isa<FloatType>(scalarType))
3358 return true;
3359
3360 unsigned laneSize = getVectorLaneSize(srcType);
3361 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3362 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3363 });
3364
3365 target.addDynamicallyLegalOp<math::AbsFOp>([](math::AbsFOp absfOp) {
3366 auto srcType = dyn_cast<VectorType>(absfOp.getOperand().getType());
3367 if (!srcType)
3368 return true;
3369
3370 Type scalarType = srcType.getElementType();
3371 unsigned laneSize = getVectorLaneSize(srcType);
3372 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3373 return elWidth * laneSize != 512 && elWidth * laneSize != 256;
3374 });
3375
3376 target.addDynamicallyLegalOp<math::AbsIOp>([](math::AbsIOp absiOp) {
3377 auto srcType = dyn_cast<VectorType>(absiOp.getOperand().getType());
3378 if (!srcType)
3379 return true;
3380
3381 Type scalarType = srcType.getElementType();
3382 unsigned laneSize = getVectorLaneSize(srcType);
3383 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3384 return elWidth * laneSize != 512 && elWidth * laneSize != 256;
3385 });
3386
3387 target.addDynamicallyLegalOp<arith::DivFOp>([](arith::DivFOp divfOp) {
3388 if (auto srcType = dyn_cast<VectorType>(divfOp.getLhs().getType());
3389 !srcType) {
3390 Type scalarType = divfOp.getLhs().getType();
3391 if (!divfOp->hasOneUse() || !isa<FloatType>(scalarType))
3392 return true;
3393 if (!isNarrowingOp(*divfOp->getUsers().begin()))
3394 return true;
3395
3396 auto fType = cast<FloatType>(scalarType);
3397 if (fType.getWidth() != 32)
3398 return true;
3399
3400 auto constOp =
3401 dyn_cast<arith::ConstantOp>(divfOp.getLhs().getDefiningOp());
3402 if (!constOp ||
3403 cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
3404 1.0f)
3405 return true;
3406 } else {
3407 Type scalarType = srcType.getElementType();
3408 if (!isa<FloatType>(scalarType))
3409 return true;
3410
3411 unsigned laneSize = getVectorLaneSize(srcType);
3412 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3413
3414 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3415 return true;
3416
3417 arith::NegFOp negOp = nullptr;
3418 if (!hasSigmoidComputationChain(divfOp, negOp))
3419 return true;
3420 }
3421
3422 return false;
3423 });
3424
3425 target.addDynamicallyLegalOp<math::CeilOp>([](math::CeilOp ceilOp) {
3426 auto srcType = dyn_cast<VectorType>(ceilOp.getOperand().getType());
3427 if (!srcType)
3428 return true;
3429 Type scalarType = srcType.getElementType();
3430 if (!isa<FloatType>(scalarType))
3431 return true;
3432
3433 unsigned laneSize = getVectorLaneSize(srcType);
3434 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3435 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3436 });
3437
3438 target.addDynamicallyLegalOp<math::FloorOp>([](math::FloorOp floorOp) {
3439 auto srcType = dyn_cast<VectorType>(floorOp.getOperand().getType());
3440 if (!srcType)
3441 return true;
3442 Type scalarType = srcType.getElementType();
3443 if (!isa<FloatType>(scalarType))
3444 return true;
3445
3446 unsigned laneSize = getVectorLaneSize(srcType);
3447 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3448 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3449 });
3450
3451 target.addDynamicallyLegalOp<arith::NegFOp>([](arith::NegFOp negOp) {
3452 auto srcType = dyn_cast<VectorType>(negOp.getOperand().getType());
3453 if (!srcType)
3454 return true;
3455 if (Type scalarType = srcType.getElementType(); !isa<FloatType>(scalarType))
3456 return true;
3457
3458 unsigned laneSize = getVectorLaneSize(srcType);
3459 return laneSize != 16;
3460 });
3461
3462 target.addDynamicallyLegalOp<arith::XOrIOp>([](arith::XOrIOp xorOp) {
3463 auto srcType = dyn_cast<VectorType>(xorOp.getLhs().getType());
3464 if (!srcType)
3465 return true;
3466 Type scalarType = srcType.getElementType();
3467 if (!isa<IntegerType>(scalarType))
3468 return true;
3469
3470 unsigned laneSize = getVectorLaneSize(srcType);
3471 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3472
3473 return laneSize * elWidth != 512;
3474 });
3475
3476 target.addDynamicallyLegalOp<arith::OrIOp>([](arith::OrIOp orOp) {
3477 auto srcType = dyn_cast<VectorType>(orOp.getLhs().getType());
3478 if (!srcType)
3479 return true;
3480 Type scalarType = srcType.getElementType();
3481 if (!isa<IntegerType>(scalarType))
3482 return true;
3483
3484 unsigned laneSize = getVectorLaneSize(srcType);
3485 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3486
3487 return laneSize * elWidth != 512;
3488 });
3489
3490 target.addDynamicallyLegalOp<arith::ShRSIOp>([](arith::ShRSIOp rsOp) {
3491 auto srcType = dyn_cast<VectorType>(rsOp.getLhs().getType());
3492 if (!srcType)
3493 return true;
3494 Type scalarType = srcType.getElementType();
3495
3496 unsigned laneSize = getVectorLaneSize(srcType);
3497 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3498
3499 return laneSize * elWidth != 512;
3500 });
3501
3502 target.addDynamicallyLegalOp<arith::AndIOp>([](arith::AndIOp andOp) {
3503 auto srcType = dyn_cast<VectorType>(andOp.getLhs().getType());
3504 if (!srcType)
3505 return true;
3506 Type scalarType = srcType.getElementType();
3507 if (!isa<IntegerType>(scalarType))
3508 return true;
3509
3510 unsigned laneSize = getVectorLaneSize(srcType);
3511 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3512
3513 return laneSize * elWidth != 512;
3514 });
3515
3516 if (backend == TargetBackend::CPP) {
3517 target.addDynamicallyLegalOp<arith::AddIOp>(
3518 [](arith::AddIOp op) { return !isa<VectorType>(op.getType()); });
3519 }
3520 target.addDynamicallyLegalOp<arith::AddFOp>(
3521 [](arith::AddFOp op) { return !isa<VectorType>(op.getType()); });
3522 target.addDynamicallyLegalOp<arith::SubIOp>(
3523 [](arith::SubIOp op) { return !isa<VectorType>(op.getType()); });
3524 target.addDynamicallyLegalOp<arith::SubFOp>(
3525 [](arith::SubFOp op) { return !isa<VectorType>(op.getType()); });
3526}
3527
3528static void configureAIEVecV1Legalizations(ConversionTarget &target,
3529 TargetBackend backend) {
3530 target.addDynamicallyLegalOp<arith::MulIOp>(
3531 [](arith::MulIOp op) { return !isa<VectorType>(op.getType()); });
3532 target.addDynamicallyLegalOp<arith::MulFOp>(
3533 [](arith::MulFOp op) { return !isa<VectorType>(op.getType()); });
3534 target.addDynamicallyLegalOp<aievec::aie1::FMAOp>(
3535 [](xilinx::aievec::aie1::FMAOp op) {
3536 auto *lhsDefOp = op.getLhs().getDefiningOp();
3537 aievec::ConcatOp concatOp = nullptr;
3538 if (lhsDefOp)
3539 concatOp = dyn_cast<aievec::ConcatOp>(op.getLhs().getDefiningOp());
3540 if (!concatOp)
3541 return true;
3542
3543 vector::SplatOp srcSplat = nullptr;
3544 if (auto *lhsOp = concatOp.getSources()[0].getDefiningOp())
3545 srcSplat = dyn_cast<vector::SplatOp>(lhsOp);
3546 if (!srcSplat) {
3547 auto *rhsOp = op.getRhs().getDefiningOp();
3548 if (!rhsOp)
3549 return true;
3550 srcSplat = dyn_cast<vector::SplatOp>(rhsOp);
3551 }
3552
3553 if (srcSplat)
3554 if (auto *srcOp = srcSplat.getInput().getDefiningOp())
3555 return !isa<vector::ExtractOp>(srcOp);
3556
3557 return true;
3558 });
3559
3560 target.addDynamicallyLegalOp<aievec::aie1::AddOp>([](aievec::aie1::AddOp op) {
3561 auto lSrsOp = op.getLhs().getDefiningOp<aievec::SRSOp>();
3562 auto rSrsOp = op.getRhs().getDefiningOp<aievec::SRSOp>();
3563 return (!lSrsOp ||
3564 !lSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>()) &&
3565 (!rSrsOp ||
3566 !rSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>());
3567 });
3568 target.addLegalDialect<memref::MemRefDialect>();
3569}
3570
3571static void configureAIEVecV2Legalizations(ConversionTarget &target,
3572 TargetBackend backend) {
3573 target.addLegalOp<UnrealizedConversionCastOp>();
3574 target.addLegalOp<vector::ShapeCastOp>();
3575
3576 // A set recording the vector lane size and element width supported
3577 llvm::SmallSet<std::pair<unsigned, unsigned>, 16> laneSizeElWidthPairSet;
3578 laneSizeElWidthPairSet.insert({64, 8});
3579 laneSizeElWidthPairSet.insert({32, 16});
3580 laneSizeElWidthPairSet.insert({16, 32});
3581 laneSizeElWidthPairSet.insert({32, 32});
3582
3583 // A set recording the element width supported
3584 llvm::SmallSet<unsigned, 16> elWidthSet;
3585 elWidthSet.insert(8);
3586 elWidthSet.insert(16);
3587 elWidthSet.insert(32);
3588
3589 if (backend == TargetBackend::CPP) {
3590 target.addDynamicallyLegalOp<arith::AddIOp>([=](arith::AddIOp op) {
3591 auto resultType = dyn_cast<VectorType>(op.getType());
3592 if (!resultType)
3593 return true;
3594
3595 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3596 unsigned laneSize = getVectorLaneSize(resultType);
3597
3598 return !laneSizeElWidthPairSet.count(
3599 std::make_pair(laneSize, resultElWidth));
3600 });
3601 }
3602
3603 target.addDynamicallyLegalOp<arith::SubIOp>([=](arith::SubIOp op) {
3604 auto resultType = dyn_cast<VectorType>(op.getType());
3605 if (!resultType)
3606 return true;
3607 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3608 unsigned laneSize = getVectorLaneSize(resultType);
3609
3610 return !laneSizeElWidthPairSet.count(
3611 std::make_pair(laneSize, resultElWidth));
3612 });
3613
3614 target.addDynamicallyLegalOp<arith::AddFOp>([](arith::AddFOp op) {
3615 auto resultType = dyn_cast<VectorType>(op.getType());
3616 if (!resultType)
3617 return true;
3618
3619 unsigned laneSize = getVectorLaneSize(resultType);
3620 return laneSize != 16;
3621 });
3622
3623 target.addDynamicallyLegalOp<arith::SubFOp>([](arith::SubFOp op) {
3624 auto resultType = dyn_cast<VectorType>(op.getType());
3625 if (!resultType)
3626 return true;
3627
3628 unsigned laneSize = getVectorLaneSize(resultType);
3629 return laneSize != 16;
3630 });
3631
3632 target.addDynamicallyLegalOp<arith::MulIOp>([](arith::MulIOp op) {
3633 auto resultType = dyn_cast<VectorType>(op.getType());
3634 if (!resultType)
3635 return true;
3636 auto isAddOp = [&](Operation *op) { return isa<arith::AddIOp>(op); };
3637 // Verify it is not a part of MAC
3638 if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp))
3639 return true;
3640
3641 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3642 unsigned laneSize = getVectorLaneSize(resultType);
3643
3644 return (laneSize != 32 || (resultElWidth != 16 && resultElWidth != 8)) &&
3645 ((laneSize != 16 && laneSize != 32) || resultElWidth != 32);
3646 });
3647
3648 target.addDynamicallyLegalOp<arith::MulFOp>([](arith::MulFOp op) {
3649 auto resultType = dyn_cast<VectorType>(op.getType());
3650 if (!resultType)
3651 return true;
3652
3653 auto isAddOp = [&](Operation *op) { return isa<arith::AddFOp>(op); };
3654 // Verify it is not a part of FMA
3655 if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp))
3656 return true;
3657
3658 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3659 unsigned laneSize = getVectorLaneSize(resultType);
3660
3661 return laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32);
3662 });
3663
3664 target.addDynamicallyLegalOp<arith::MinSIOp>([=](arith::MinSIOp op) {
3665 auto resultType = dyn_cast<VectorType>(op.getType());
3666 if (!resultType)
3667 return true;
3668
3669 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3670 unsigned laneSize = getVectorLaneSize(resultType);
3671
3672 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3673 });
3674
3675 target.addDynamicallyLegalOp<arith::MaxSIOp>([=](arith::MaxSIOp op) {
3676 auto resultType = dyn_cast<VectorType>(op.getType());
3677 if (!resultType)
3678 return true;
3679
3680 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3681 unsigned laneSize = getVectorLaneSize(resultType);
3682
3683 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3684 });
3685
3686 target.addDynamicallyLegalOp<arith::MinimumFOp>([=](arith::MinimumFOp op) {
3687 auto resultType = dyn_cast<VectorType>(op.getType());
3688 if (!resultType)
3689 return true;
3690
3691 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3692 unsigned laneSize = getVectorLaneSize(resultType);
3693
3694 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3695 });
3696
3697 target.addDynamicallyLegalOp<arith::MaximumFOp>([=](arith::MaximumFOp op) {
3698 auto resultType = dyn_cast<VectorType>(op.getType());
3699 if (!resultType)
3700 return true;
3701
3702 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3703 unsigned laneSize = getVectorLaneSize(resultType);
3704
3705 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3706 });
3707
3708 target.addDynamicallyLegalOp<arith::CmpIOp>([=](arith::CmpIOp op) {
3709 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
3710 if (!lhsType)
3711 return true;
3712
3713 auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth();
3714 unsigned laneSize = getVectorLaneSize(lhsType);
3715
3716 return !elWidthSet.count(lhsElWidth) || laneSize * lhsElWidth != 512;
3717 });
3718
3719 target.addDynamicallyLegalOp<arith::CmpFOp>([=](arith::CmpFOp op) {
3720 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
3721 if (!lhsType)
3722 return true;
3723
3724 auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth();
3725 unsigned laneSize = getVectorLaneSize(lhsType);
3726
3727 return !elWidthSet.count(lhsElWidth) || laneSize * lhsElWidth != 512;
3728 });
3729
3730 target.addDynamicallyLegalOp<arith::SelectOp>([=](arith::SelectOp op) {
3731 auto resultType = dyn_cast<VectorType>(op.getType());
3732 if (!resultType)
3733 return true;
3734
3735 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3736 unsigned laneSize = getVectorLaneSize(resultType);
3737
3738 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3739 });
3740
3741 target.addDynamicallyLegalOp<vector::ReductionOp>(
3742 [=](vector::ReductionOp op) {
3743 if (auto kind = op.getKind(); kind != vector::CombiningKind::ADD &&
3744 kind != vector::CombiningKind::MINSI &&
3745 kind != vector::CombiningKind::MINUI &&
3746 kind != vector::CombiningKind::MINIMUMF &&
3747 kind != vector::CombiningKind::MAXSI &&
3748 kind != vector::CombiningKind::MAXUI &&
3749 kind != vector::CombiningKind::MAXIMUMF)
3750 return true;
3751
3752 auto vType = dyn_cast<VectorType>(op.getVector().getType());
3753 if (!vType)
3754 return true;
3755
3756 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
3757 laneSizeElWidthPairSet.insert({64, 8});
3758 laneSizeElWidthPairSet.insert({32, 16});
3759 laneSizeElWidthPairSet.insert({32, 32});
3760 laneSizeElWidthPairSet.insert({16, 32});
3761
3762 Type scalarType = vType.getElementType();
3763 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3764 unsigned laneSize = getVectorLaneSize(vType);
3765
3766 if (isa<IntegerType>(scalarType) &&
3767 !laneSizeElWidthPairSet.count(std::make_pair(laneSize, elWidth)))
3768 return true;
3769
3770 if (isa<FloatType>(scalarType) && laneSize != 16 && laneSize != 32)
3771 return true;
3772
3773 return false;
3774 });
3775
3776 target.addIllegalOp<vector::ContractionOp, vector::TransposeOp,
3777 vector::FMAOp>();
3778}
3779
3780//===----------------------------------------------------------------------===//
3781// Lowering passes
3782//===----------------------------------------------------------------------===//
3783
3784/// Lower incoming vector operations into their corresponding AIE vector
3785/// intrinsics.
3786struct LowerVectorToAIEVec : PassWrapper<LowerVectorToAIEVec, OperationPass<>> {
3787 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerVectorToAIEVec)
3788
3791
3797
3798 // In case we want to register this pass as a standalone pass for test
3799 // purposes.
3800 StringRef getArgument() const final { return "test-lower-vector-to-aievec"; }
3801 StringRef getDescription() const final {
3802 return "Lower vector operations to AIE vector intrinsics";
3803 }
3804 void getDependentDialects(DialectRegistry &registry) const override {
3805 registry
3806 .insert<affine::AffineDialect, xilinx::aievec::aie1::AIEVecAIE1Dialect,
3807 xilinx::aievec::AIEVecDialect, arith::ArithDialect,
3808 memref::MemRefDialect, scf::SCFDialect, vector::VectorDialect,
3809 emitc::EmitCDialect>();
3810 }
3811
3812 Option<std::string> aieTarget{
3813 *this, "aie-target",
3814 llvm::cl::desc(
3815 "Select AIE version: \"aie\", \"aie2\", or \"aie2p\". This will "
3816 "determine the vector size and available operations."),
3817 llvm::cl::init("aie")};
3818
3819 Option<std::string> targetBackend{
3820 *this, "target-backend",
3821 llvm::cl::desc("Select translation backend: \"cpp\" or \"llvmir\". This "
3822 "will determine the aievec operations used to convert "
3823 "from vector dialect."),
3824 llvm::cl::init("cpp")};
3825
3826 void runOnOperation() override {
3827 auto *op = getOperation();
3828 MLIRContext *context = &getContext();
3829 RewritePatternSet patterns(context);
3830 ConversionTarget target(*context);
3831 auto aieVersion = AIEArch::AIE;
3832 if (!aieTarget.empty()) {
3833 std::string targetStr = aieTarget;
3834 if (targetStr == "aieml" || targetStr == "aie2" || targetStr == "aie2p")
3835 aieVersion = AIEArch::AIE2;
3836 else if (targetStr != "aie") {
3837 op->emitError() << "unknown AIE target '" << aieTarget << "'";
3838 return signalPassFailure();
3839 }
3840 }
3841
3842 TargetBackend backend = TargetBackend::CPP;
3843 if (!targetBackend.empty()) {
3844 std::string backendStr = targetBackend;
3845 if (backendStr == "llvmir") {
3846 backend = TargetBackend::LLVMIR;
3847 if (aieVersion == AIEArch::AIE) {
3848 op->emitError() << "targetting LLVM IR is not supported for AIEv1";
3849 signalPassFailure();
3850 return;
3851 }
3852 } else if (backendStr != "cpp") {
3853 op->emitError() << "unknown target backend'" << targetBackend << "'";
3854 signalPassFailure();
3855 return;
3856 }
3857 }
3858
3859 populateAIEVecCommonConversionPatterns(patterns, backend);
3860 configureAIEVecCommonLegalizations(target, backend);
3861 if (aieVersion == AIEArch::AIE) {
3862 populateAIEVecV1ConversionPatterns(patterns, backend);
3863 configureAIEVecV1Legalizations(target, backend);
3864 } else {
3865 populateAIEVecV2ConversionPatterns(patterns, backend);
3866 configureAIEVecV2Legalizations(target, backend);
3867 }
3868
3869 if (failed(applyPartialConversion(op, target, std::move(patterns))))
3870 return signalPassFailure();
3871 }
3872};
3873
3874static std::unique_ptr<Pass>
3875createLowerVectorToAIEVec(const LowerVectorToAIEVecOptions &options) {
3876 return std::make_unique<LowerVectorToAIEVec>(options);
3877}
3878
3879//===---------------------------------------------------------------------------
3880// Custom canonicalization passes
3881//===---------------------------------------------------------------------------
3882
3883// This pass widens UPD ops to twice the width followed by an ext op of the
3884// bottom half. This can be used together with SimplifyUPDOpsPass to find
3885// additional common subexpressions with UPDs generated from unaligned
3886// `transfer_read` ops.
3887struct ExtendUPDOpsPass : PassWrapper<ExtendUPDOpsPass, OperationPass<>> {
3888
3889 void runOnOperation() override {
3890 MLIRContext *context = &getContext();
3891 RewritePatternSet patterns(context);
3892 ConversionTarget target(*context);
3893 patterns.add<ExpandUPDToUPDAndExtPattern>(patterns.getContext());
3894 target.addLegalDialect<aievec::AIEVecDialect>();
3895 target.addDynamicallyLegalOp<aievec::UPDOp>([](aievec::UPDOp op) {
3896 return op.getVector() ||
3897 (op->hasOneUse() && isa<aievec::UPDOp>(*op->getUsers().begin())) ||
3898 llvm::all_of(op->getUsers(),
3899 [](Operation *op) { return isa<aievec::ExtOp>(op); });
3900 });
3901
3902 if (auto *op = getOperation();
3903 failed(applyPartialConversion(op, target, std::move(patterns)))) {
3904 return signalPassFailure();
3905 }
3906 }
3907};
3908
3909// This pass replaces wide UPD ops that are only used by a single ext op of the
3910// bottom half. This pass undos the work of ExtendUPDOpsPass.
3911// TODO: This pass can be extended to work with wide UPD ops that are used by
3912// TODO: a single ext op of the top half, which might be a good opportunity to
3913// TODO: further optimize wide UPDs.
3914struct SimplifyUPDOpsPass : PassWrapper<SimplifyUPDOpsPass, OperationPass<>> {
3915
3916 void runOnOperation() override {
3917 MLIRContext *context = &getContext();
3918 RewritePatternSet patterns(context);
3919 ConversionTarget target(*context);
3920 patterns.add<FuseExtIntoUPDPattern>(patterns.getContext());
3921 target.addLegalDialect<aievec::AIEVecDialect>();
3922 target.addDynamicallyLegalOp<aievec::ExtOp>([](aievec::ExtOp op) {
3923 auto *defOp = op.getSource().getDefiningOp();
3924 return !defOp || !isa<aievec::UPDOp>(defOp) || !defOp->hasOneUse() ||
3925 op.getIndex() != 0;
3926 });
3927
3928 if (auto *op = getOperation();
3929 failed(applyPartialConversion(op, target, std::move(patterns)))) {
3930 return signalPassFailure();
3931 }
3932 }
3933};
3934
3935//============================================================================//
3936//=============== Main Vector2AIEVec Pipeline Configuration ==================//
3937//============================================================================//
3938
3940 OpPassManager &pm, const LowerVectorToAIEVecOptions &options) {
3941 // Add lowering from `Vector` to `AIEVec`
3942 pm.addPass(createLowerVectorToAIEVec(options));
3943 pm.addPass(createCanonicalizerPass());
3944
3945 // Simplify UPD ops
3946 pm.addPass(std::make_unique<ExtendUPDOpsPass>());
3947 pm.addPass(createCSEPass());
3948 pm.addPass(std::make_unique<SimplifyUPDOpsPass>());
3949 pm.addPass(createCanonicalizerPass());
3950}
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaximumFOp, aievec::MaxOp > LowerVectorMaximumFOpToAIEVecMaxOp
ComputeBandAndBorOpPattern< arith::OrIOp, aievec::BorOp > ComputeBorOpPattern
ComputeBandAndBorOpPattern< arith::AndIOp, aievec::BandOp > ComputeBandOpPattern
OneToOneVectorOpToAIEVecOpPattern< arith::SubFOp, aievec::aie1::SubOp > LowerVectorSubFOpToAIEVecSubOp
ComputeAbsOpPattern< math::AbsIOp > ComputeAbsIOpPattern
LowerTruncOpPattern< arith::TruncFOp > LowerTruncFOpPattern
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::AddFOp, aievec::AddElemOp > LowerVectorAddFOpToAIEVecAddElemOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaxSIOp, aievec::MaxOp > LowerVectorMaxSIOpToAIEVecMaxOp
LowerExtOpPattern< arith::ExtFOp > LowerExtFOpPattern
LowerVectorCmpOpToAIEVecCmpOp< arith::CmpFOp, CmpFPredicate > LowerVectorCmpFOpToAIEVecCmpOp
OneToOneVectorOpToAIEVecOpPattern< arith::SubIOp, aievec::aie1::SubOp > LowerVectorSubIOpToAIEVecSubOp
ComputeAbsOpPattern< math::AbsFOp > ComputeAbsFOpPattern
LowerVectorCmpOpToAIEVecCmpOp< arith::CmpIOp, CmpIPredicate > LowerVectorCmpIOpToAIEVecCmpOp
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::SubFOp, aievec::SubElemOp > LowerVectorSubFOpToAIEVecSubElemOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MinSIOp, aievec::MinOp > LowerVectorMinSIOpToAIEVecMinOp
OneToOneVectorOpToAIEVecOpPattern< arith::AddFOp, aievec::aie1::AddOp > LowerVectorAddFOpToAIEVecAddOp
OneToOneVectorOpToAIEVecOpPattern< arith::MulFOp, aievec::aie1::MulOp > LowerVectorMulFOpToAIEVecMulOp
LowerExtOpPattern< arith::ExtSIOp > LowerExtSIOpPattern
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MinimumFOp, aievec::MinOp > LowerVectorMinimumFOpToAIEVecMinOp
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::AddIOp, aievec::AddElemOp > LowerVectorAddIOpToAIEVecAddElemOp
PathEndPoint src
mlir::VectorType getFlattenedVectorType(mlir::VectorType vecTy)
unsigned getVectorLaneSize(mlir::VectorType type)
Definition AIEVecUtils.h:55
SmallVector< NamedAttribute > buildFMAOpSplatAttrForElemTy(aievec::aie1::FMAOp fmaOp, int64_t bcastPos, int64_t step=1)
std::optional< int64_t > getTransferReadAlignmentOffset(TransferReadLikeOp readOp, mlir::VectorType vType, int64_t alignment)
mlir::VectorType createVectorType(unsigned lanes, mlir::Type elementType)
Definition AIEVecUtils.h:42
int32_t getElementSizeInBits(mlir::VectorType type)
Definition AIEVecUtils.h:49
void buildLowerVectorToAIEVec(mlir::OpPassManager &pm, const LowerVectorToAIEVecOptions &options)
mlir::VectorType getVectorOpDestType(mlir::VectorType type, bool AIE2)
Definition AIEVecUtils.h:80
TargetBackend
Definition Passes.h:27
LogicalResult matchAndRewrite(SrcOpTy absOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::XOrIOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::CeilOp ceilOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ErfOp erfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::FloorOp floorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::NegFOp negOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::ShRSIOp rsOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::SqrtOp sqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulAddToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(aievec::aie1::AddOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::MulFOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulFToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulIToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertVectorFMAOpToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
ExpandUPDToUPDAndExtPattern(MLIRContext *context)
LogicalResult matchAndRewrite(aievec::UPDOp updOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::aie1::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtOp extOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
FuseExtIntoUPDPattern(MLIRContext *context)
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy extOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy truncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LowerVectorContractionOpToAIEVecMatMulPattern(MLIRContext *context, bool matMoveToAcc=true)
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::SelectOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Lower incoming vector operations into their corresponding AIE vector intrinsics.
void getDependentDialects(DialectRegistry &registry) const override
LowerVectorToAIEVec(const LowerVectorToAIEVecOptions &options)
StringRef getDescription() const final
Option< std::string > targetBackend
StringRef getArgument() const final
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LowerVectorTransferReadToAIEUPD(MLIRContext *context, int64_t minVectorSize, int64_t maxVectorSize, int64_t alignment, int64_t maxLoadSize)
LogicalResult matchAndRewrite(vector::TransposeOp transpOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Options for the "lower-vector-to-aievec" pipeline.
Definition Passes.h:57
PassOptions::Option< std::string > aieTarget
Definition Passes.h:58
PassOptions::Option< std::string > targetBackend
Definition Passes.h:63