MLIR-AIE
VectorToAIEVecConversions.cpp
Go to the documentation of this file.
1//===-VectorToAIEVecConversions.cpp - Vector to AIEVec convs. ---*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2023, Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
10// This file contains conversions from the Vector dialect into the AIEVec
11// dialect. Conversions assume that the Vector dialect has been restricted
12// to ops that can be translated to a sequence of valid AIEVec ops.
13//===----------------------------------------------------------------------===//
14
19
21#include "mlir/Dialect/Affine/IR/AffineOps.h"
22#include "mlir/Dialect/EmitC/IR/EmitC.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/Math/IR/Math.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SCF/IR/SCF.h"
27#include "mlir/IR/PatternMatch.h"
28#include "mlir/IR/SymbolTable.h"
29#include "mlir/IR/TypeUtilities.h"
30#include "mlir/Pass/PassManager.h"
31#include "mlir/Transforms/DialectConversion.h"
32#include "mlir/Transforms/Passes.h"
33#include "llvm/ADT/SmallSet.h"
34#include <bitset>
35#include <optional>
36#include <tuple>
37
38#define DEBUG_TYPE "lower-vector-to-aievec"
39
40using namespace llvm;
41using namespace mlir;
42using namespace arith;
43using namespace vector;
44using namespace xilinx;
45using namespace xilinx::aievec;
46
47//===----------------------------------------------------------------------===//
48// Utility functions
49//===----------------------------------------------------------------------===//
50
51static bool isNarrowingOp(Operation *op) {
52 if (isa<arith::TruncFOp>(op) || isa<arith::TruncIOp>(op))
53 return true;
54
55 if (auto srsOp = dyn_cast<aievec::SRSOp>(op)) {
56 auto *srsOpSrcOp = srsOp.getSource().getDefiningOp();
57 if (isa<aievec::UPSOp>(srsOpSrcOp) || isa<aievec::CastOp>(srsOpSrcOp))
58 return true;
59 }
60 return false;
61}
62
63// Given a Value, if it is defined by a widening op (arith:ExtSIOp,
64// arith::ExtUIOp, arith::ExtFOp, aievec::UPSOp + aievec::SRSOp,
65// aievec::UPSOp + aievec::CastOp), return the source of the widening op.
66static std::optional<Value> getSourceOfWideningOp(Value src) {
67 if (auto extSIOp = src.getDefiningOp<arith::ExtSIOp>())
68 return extSIOp.getIn();
69 if (auto extUIOp = src.getDefiningOp<arith::ExtUIOp>())
70 return extUIOp.getIn();
71 if (auto extFOp = src.getDefiningOp<arith::ExtFOp>())
72 return extFOp.getIn();
73 if (auto srsOp = src.getDefiningOp<aievec::SRSOp>()) {
74 // Conversion through AIE intrinsics takes two steps:
75 // 1) Load to accumulator: aievec.ups
76 // 2) Move from accumulator: aievec.srs
77 auto srsSource = srsOp.getSource();
78 if (srsSource)
79 if (auto upsOp = srsSource.getDefiningOp<aievec::UPSOp>())
80 return upsOp.getSource();
81 }
82 if (auto castOp = src.getDefiningOp<aievec::CastOp>()) {
83 // Conversion through AIE intrinsics can also take the following two steps:
84 // 1) Load to accumulator: aievec.ups
85 // 2) Move from accumulator: aievec.cast
86 auto castSource = castOp.getSource();
87 if (castSource)
88 if (auto upsOp = castSource.getDefiningOp<aievec::UPSOp>())
89 return upsOp.getSource();
90 }
91 return std::optional<Value>();
92}
93
94// Given the LHS and RHS of an `arith::AddIOp`, if one of them is defined by an
95// `arith::MulIOp`, return a tuple with the `lhs`, `rhs`, and `acc` of the MAC
96// operation that can replace them.
97static std::optional<std::tuple<Value, Value, Value>>
98extractMACOperandsFromAddOperands(Value addLhs, Value addRhs) {
99 auto *lhsDefOp = addLhs.getDefiningOp();
100 auto *rhsDefOp = addRhs.getDefiningOp();
101 arith::MulIOp mulOp = nullptr;
102 Value acc;
103 if (lhsDefOp) {
104 mulOp = dyn_cast<arith::MulIOp>(lhsDefOp);
105 acc = addRhs;
106 }
107 if (!mulOp && rhsDefOp) {
108 mulOp = dyn_cast<arith::MulIOp>(rhsDefOp);
109 acc = addLhs;
110 }
111 if (mulOp)
112 return std::make_tuple(mulOp.getLhs(), mulOp.getRhs(), acc);
113
114 // If the MulIOp has been already translated to aievec::aie1::MulOp:
115 auto lhsSrsOp = addLhs.getDefiningOp<aievec::SRSOp>();
116 auto rhsSrsOp = addRhs.getDefiningOp<aievec::SRSOp>();
117 aievec::aie1::MulOp aieMulOp = nullptr;
118 if (lhsSrsOp) {
119 aieMulOp = lhsSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>();
120 acc = addRhs;
121 }
122 if (!aieMulOp && rhsSrsOp) {
123 aieMulOp = rhsSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>();
124 acc = addLhs;
125 }
126 if (aieMulOp)
127 return std::make_tuple(aieMulOp.getLhs(), aieMulOp.getRhs(), acc);
128 return {};
129}
130
131// Convert a input value to a target vector type. This function can insert
132// multiple aievec ops depending on the combination of input and output vector
133// types.
134static std::optional<Value>
135convertValueToTargetTypeAIE2(ConversionPatternRewriter &rewriter, Location loc,
136 Value inputVal, VectorType tgtType) {
137 auto srcType = cast<VectorType>(inputVal.getType());
138 auto srcElemType = srcType.getElementType();
139 unsigned srcBitWidth = srcElemType.getIntOrFloatBitWidth();
140 unsigned srcLaneSize = getVectorLaneSize(srcType);
141
142 auto tgtElemType = tgtType.getElementType();
143 unsigned tgtBitWidth = tgtElemType.getIntOrFloatBitWidth();
144 unsigned tgtLaneSize = getVectorLaneSize(tgtType);
145
146 if (srcType == tgtType)
147 return inputVal;
148
149 if ((srcElemType == tgtElemType) && (srcLaneSize != tgtLaneSize)) {
150 // TODO: relax the condition below?
151 if ((srcLaneSize == 16 && tgtLaneSize == 32 &&
152 isa<FloatType>(srcElemType)) ||
153 (srcLaneSize == 32 && tgtLaneSize == 64 &&
154 isa<IntegerType>(srcElemType))) {
155 auto zeroConstOp = rewriter.create<arith::ConstantOp>(
156 loc, srcType.getElementType(),
157 rewriter.getZeroAttr(srcType.getElementType()));
158 auto broadcastZeroOp = rewriter.create<aievec::BroadcastScalarOp>(
159 loc, tgtType, zeroConstOp->getResult(0));
160 auto extOp = rewriter.create<aievec::ExtOp>(
161 loc, srcType, broadcastZeroOp->getResult(0), 0);
162
163 SmallVector<Value> inputSources = {inputVal, extOp->getResult(0)};
164 auto concatOp =
165 rewriter.create<aievec::ConcatOp>(loc, tgtType, inputSources);
166
167 return concatOp.getResult();
168 }
169 } else if ((srcElemType != tgtElemType) && (srcLaneSize == tgtLaneSize) &&
170 isa<IntegerType>(srcElemType) && isa<IntegerType>(tgtElemType)) {
171 if (srcBitWidth == 16 && tgtBitWidth == 32 && srcLaneSize == 16) {
172 // Case 1: vector<16xi16> to vector<16xi32> conversion by aievec.ups +
173 // aievec.cast
174 auto accType = getVectorOpDestType(srcType, /*AIE2 =*/true);
175 auto upsOp = rewriter.create<aievec::UPSOp>(loc, accType, inputVal);
176 auto castOp = rewriter.create<aievec::CastOp>(
177 loc, tgtType, upsOp.getResult(), /*isResAcc*/ false);
178 return castOp.getResult();
179 }
180
181 if (srcBitWidth == 8 && tgtBitWidth == 32 && srcLaneSize == 16) {
182 // Case 2: vector<16xi8> to vector<16xi32> conversion by aievec.concat +
183 // aievec.ups + aievec.cast + aievec.ext
184 auto concatOutType = createVectorType(32, srcElemType);
185 auto concatOp = rewriter.create<aievec::ConcatOp>(
186 loc, concatOutType, SmallVector<Value>({inputVal, inputVal}));
187 auto accType = getVectorOpDestType(concatOutType, /*AIE2 =*/true);
188 auto upsOp =
189 rewriter.create<aievec::UPSOp>(loc, accType, concatOp.getResult());
190 auto castType = createVectorType(32, tgtElemType);
191 auto castOp = rewriter.create<aievec::CastOp>(
192 loc, castType, upsOp.getResult(), /*isResAcc*/ false);
193 auto extOp =
194 rewriter.create<aievec::ExtOp>(loc, tgtType, castOp.getResult(), 0);
195 return extOp.getResult();
196 }
197
198 if (srcBitWidth == 8 && tgtBitWidth == 16 && srcLaneSize == 32) {
199 // Case 3: vector<32xi8> to vector<32xi16> conversion by aievec.unpack
200 auto unpackOp = rewriter.create<aievec::UnpackOp>(loc, tgtType, inputVal);
201 return unpackOp.getResult();
202 }
203 }
204
205 return std::nullopt;
206}
207
208// Return the list of attributes that configure an `aievec.select` op to
209// perform a rotation of the input vector by `rotation` number of elements.
210// The attribute values depend on the vector type of the select operation.
211static SmallVector<NamedAttribute>
212buildAttributeListForRotationSelectOp(PatternRewriter &rewriter, VectorType vTy,
213 int64_t rotation) {
214 unsigned width = 0;
215 auto elemTy = vTy.getElementType();
216 if (auto intTy = dyn_cast<IntegerType>(elemTy))
217 width = intTy.getWidth();
218 StringAttr attr0 = rewriter.getStringAttr("0");
219 StringAttr attr0x06040200 = rewriter.getStringAttr("0x06040200");
220 StringAttr attr0x0e0c0a08 = rewriter.getStringAttr("0x0e0c0a08");
221 StringAttr attr0x2103 = rewriter.getStringAttr("0x2103");
222 StringAttr attr0x3210 = rewriter.getStringAttr("0x3210");
223 StringAttr selectAttrName = rewriter.getStringAttr("select");
224 StringAttr xoffsetsAttrName = rewriter.getStringAttr("xoffsets");
225 StringAttr xoffsetsHiAttrName = rewriter.getStringAttr("xoffsets_hi");
226 StringAttr xsquareAttrName = rewriter.getStringAttr("xsquare");
227 StringAttr xstartAttrName = rewriter.getStringAttr("xstart");
228 StringAttr yoffsetsAttrName = rewriter.getStringAttr("yoffsets");
229 StringAttr yoffsetsHiAttrName = rewriter.getStringAttr("yoffsets_hi");
230 StringAttr ysquareAttrName = rewriter.getStringAttr("ysquare");
231 StringAttr ystartAttrName = rewriter.getStringAttr("ystart");
232
233 switch (width) {
234 case 16: {
235 if (rotation % 2) {
236 int64_t xstart = rotation + 1;
237 int64_t ystart = rotation - 1;
238 return SmallVector<NamedAttribute, 9>(
239 {{selectAttrName, rewriter.getStringAttr("0x11111111")},
240 {xoffsetsAttrName, attr0x06040200},
241 {xoffsetsHiAttrName, attr0x0e0c0a08},
242 {xsquareAttrName, attr0x2103},
243 {xstartAttrName, rewriter.getStringAttr(std::to_string(xstart))},
244 {yoffsetsAttrName, rewriter.getStringAttr("0x0503010f")},
245 {yoffsetsHiAttrName, rewriter.getStringAttr("0x0d0b0907")},
246 {ysquareAttrName, attr0x2103},
247 {ystartAttrName, rewriter.getStringAttr(std::to_string(ystart))}});
248 }
249 return SmallVector<NamedAttribute, 9>(
250 {{selectAttrName, attr0},
251 {xoffsetsAttrName, attr0x06040200},
252 {xoffsetsHiAttrName, attr0x0e0c0a08},
253 {xsquareAttrName, attr0x3210},
254 {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))},
255 {yoffsetsAttrName, attr0},
256 {yoffsetsHiAttrName, attr0},
257 {ysquareAttrName, attr0},
258 {ystartAttrName, attr0}});
259 }
260 case 32:
261 return SmallVector<NamedAttribute, 7>(
262 {{selectAttrName, attr0},
263 {xoffsetsAttrName, rewriter.getStringAttr("0x76543210")},
264 {xsquareAttrName, attr0x3210},
265 {xstartAttrName, rewriter.getStringAttr(std::to_string(rotation))},
266 {yoffsetsAttrName, attr0},
267 {ysquareAttrName, attr0},
268 {ystartAttrName, attr0}});
269 default:
270 llvm::report_fatal_error("Unexpected width!");
271 }
272
273 return {};
274}
275
276namespace xilinx::aievec {
277
278SmallVector<NamedAttribute>
279buildFMAOpSplatAttrForElemTy(aievec::aie1::FMAOp fmaOp, int64_t bcastPos,
280 int64_t step = 1) {
281 unsigned width = 0;
282 auto elemTy = fmaOp.getLhs().getType().getElementType();
283 if (auto intTy = dyn_cast<IntegerType>(elemTy))
284 width = intTy.getWidth();
285 auto *ctx = fmaOp.getContext();
286 switch (width) {
287 case 16:
288 // NOTE: The pattern is:
289 // acc[0] = x[0] * z[bcastPos] + x[16] * z[bcastPos+step]
290 // acc[1] = x[1] * z[bcastPos] + x[17] * z[bcastPos+step]
291 // acc[2] = x[2] * z[bcastPos] + x[18] * z[bcastPos+step]
292 // acc[3] = x[3] * z[bcastPos] + x[19] * z[bcastPos+step]
293 // acc[4] = x[4] * z[bcastPos] + x[20] * z[bcastPos+step]
294 // acc[5] = x[5] * z[bcastPos] + x[21] * z[bcastPos+step]
295 // acc[6] = x[6] * z[bcastPos] + x[22] * z[bcastPos+step]
296 // acc[7] = x[7] * z[bcastPos] + x[23] * z[bcastPos+step]
297 // acc[8] = x[8] * z[bcastPos] + x[24] * z[bcastPos+step]
298 // acc[9] = x[9] * z[bcastPos] + x[25] * z[bcastPos+step]
299 // acc[10] = x[10] * z[bcastPos] + x[26] * z[bcastPos+step]
300 // acc[11] = x[11] * z[bcastPos] + x[27] * z[bcastPos+step]
301 // acc[12] = x[12] * z[bcastPos] + x[28] * z[bcastPos+step]
302 // acc[13] = x[13] * z[bcastPos] + x[29] * z[bcastPos+step]
303 // acc[14] = x[14] * z[bcastPos] + x[30] * z[bcastPos+step]
304 // acc[15] = x[15] * z[bcastPos] + x[31] * z[bcastPos+step]
305 return SmallVector<NamedAttribute, 11>(
306 {{fmaOp.getXstartAttrName(), StringAttr::get(ctx, "0")},
307 {fmaOp.getXoffsetsAttrName(), StringAttr::get(ctx, "0x73727170")},
308 {fmaOp.getXoffsetsHiAttrName(), StringAttr::get(ctx, "0x77767574")},
309 {fmaOp.getXstepAttrName(), fmaOp.getXstepAttr()},
310 {fmaOp.getXsquareAttrName(), StringAttr::get(ctx, "0x3120")},
311 {fmaOp.getZstartAttrName(),
312 StringAttr::get(ctx, std::to_string(bcastPos))},
313 {fmaOp.getZoffsetsAttrName(), StringAttr::get(ctx, "0")},
314 {fmaOp.getZoffsetsHiAttrName(), StringAttr::get(ctx, "0")},
315 {fmaOp.getZstepAttrName(), StringAttr::get(ctx, std::to_string(step))},
316 {fmaOp.getZsquareAttrName(), fmaOp.getZsquareAttr()},
317 {fmaOp.getFmsubAttrName(), fmaOp.getFmsubAttr()}});
318 case 32:
319 return SmallVector<NamedAttribute, 11>(
320 {{fmaOp.getXstartAttrName(), StringAttr::get(ctx, "0")},
321 {fmaOp.getXoffsetsAttrName(), StringAttr::get(ctx, "0x76543210")},
322 {fmaOp.getXoffsetsHiAttrName(), fmaOp.getXoffsetsHiAttr()},
323 {fmaOp.getXstepAttrName(), fmaOp.getXstepAttr()},
324 {fmaOp.getXsquareAttrName(), fmaOp.getXsquareAttr()},
325 {fmaOp.getZstartAttrName(),
326 StringAttr::get(ctx, std::to_string(bcastPos))},
327 {fmaOp.getZoffsetsAttrName(), StringAttr::get(ctx, "0x00000000")},
328 {fmaOp.getZoffsetsHiAttrName(), fmaOp.getZoffsetsHiAttr()},
329 {fmaOp.getZstepAttrName(), fmaOp.getZstepAttr()},
330 {fmaOp.getZsquareAttrName(), fmaOp.getZsquareAttr()},
331 {fmaOp.getFmsubAttrName(), fmaOp.getFmsubAttr()}});
332 default:
333 llvm::report_fatal_error("Unexpected width!");
334 }
335
336 return {};
337}
338
339} // namespace xilinx::aievec
340
341template <typename SrcOpTy, typename AIEv2ElemOp>
342static LogicalResult genAddElemAIE2(ConversionPatternRewriter &rewriter,
343 Value lval, Value rval, VectorType srcType,
344 SrcOpTy srcOp) {
345 auto lCastOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), srcType, lval,
346 /*isResAcc*/ true);
347 auto rCastOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), srcType, rval,
348 /*isResAcc*/ true);
349 auto elemOp = rewriter.create<AIEv2ElemOp>(
350 srcOp.getLoc(), lCastOp->getResult(0).getType(), lCastOp->getResult(0),
351 rCastOp->getResult(0));
352 rewriter.replaceOpWithNewOp<aievec::CastOp>(
353 srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false);
354 return success();
355}
356
357static arith::CmpIPredicate
358convertToIntegerPredicate(arith::CmpFPredicate pred) {
359 switch (pred) {
360 case CmpFPredicate::UEQ:
361 case CmpFPredicate::OEQ:
362 return CmpIPredicate::eq;
363 case CmpFPredicate::UGT:
364 return CmpIPredicate::ugt;
365 case CmpFPredicate::OGT:
366 return CmpIPredicate::sgt;
367 case CmpFPredicate::UGE:
368 return CmpIPredicate::uge;
369 case CmpFPredicate::OGE:
370 return CmpIPredicate::sge;
371 case CmpFPredicate::ULT:
372 return CmpIPredicate::ult;
373 case CmpFPredicate::OLT:
374 return CmpIPredicate::slt;
375 case CmpFPredicate::ULE:
376 return CmpIPredicate::ule;
377 case CmpFPredicate::OLE:
378 return CmpIPredicate::sle;
379 case CmpFPredicate::UNE:
380 case CmpFPredicate::ONE:
381 return CmpIPredicate::ne;
382 default:
383 llvm::report_fatal_error("Unexpected predicate!");
384 }
385}
386
387static arith::CmpIPredicate
388convertToIntegerPredicate(arith::CmpIPredicate pred) {
389 return pred;
390}
391
392static aievec::CmpOp createCmpOpAIE2(ConversionPatternRewriter &rewriter,
393 CmpIPredicate pred, Location loc,
394 Type type, Value lhs, Value rhs) {
395 switch (pred) {
396 case CmpIPredicate::eq:
397 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "eq");
398 case CmpIPredicate::ne:
399 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "ne");
400 case CmpIPredicate::slt:
401 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "slt");
402 case CmpIPredicate::ult:
403 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "ult");
404 case CmpIPredicate::sle:
405 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "sle");
406 case CmpIPredicate::ule:
407 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "ule");
408 case CmpIPredicate::sgt:
409 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "sgt");
410 case CmpIPredicate::ugt:
411 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "ugt");
412 case CmpIPredicate::sge:
413 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "sge");
414 case CmpIPredicate::uge:
415 return rewriter.create<aievec::CmpOp>(loc, type, lhs, rhs, "uge");
416 }
417 return nullptr;
418}
419
420template <typename DstOpTy>
421static void generateAIEVecOpsForReductionOp(ConversionPatternRewriter &rewriter,
422 vector::ReductionOp srcOp,
423 int shiftIndex, Value curValue) {
424 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
425 "shiftIndex must be power of 2");
426
427 Location loc = srcOp.getLoc();
428 auto vType = dyn_cast<VectorType>(curValue.getType());
429 Type scalarType = vType.getElementType();
430 Type vecType = curValue.getType();
431 DstOpTy curOp = nullptr;
432 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
433
434 for (int id = shiftIndex; id > 0; id /= 2) {
435 auto constOp = rewriter.create<arith::ConstantOp>(
436 loc, rewriter.getI32IntegerAttr(id * elWidth / 8));
437
438 auto shiftBytesOp = rewriter.create<aievec::ShiftOp>(
439 loc, vecType, curValue, curValue, constOp.getResult());
440
441 curOp = rewriter.create<DstOpTy>(loc, vecType, curValue,
442 shiftBytesOp.getResult());
443
444 curValue = curOp.getResult();
445 }
446
447 auto zeroConstOp =
448 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
449 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, curOp,
450 zeroConstOp.getResult());
451}
452
453static func::FuncOp getOrInsertFuncDecl(ConversionPatternRewriter &rewriter,
454 mlir::ModuleOp parentModuleOp,
455 StringRef funcName, TypeRange inTypes,
456 TypeRange outTypes) {
457
458 mlir::OpBuilder::InsertionGuard insertGuard(rewriter);
459 rewriter.setInsertionPointToStart(
460 &parentModuleOp.getRegion().getBlocks().front());
461 SymbolTable st = SymbolTable(parentModuleOp);
462 func::FuncOp fnOpLookup = st.lookup<func::FuncOp>(funcName);
463 func::FuncOp fnOp;
464 // if the function is already declared, use the existing function, don't
465 // declare multiple times
466 if (fnOpLookup != NULL) {
467 fnOp = fnOpLookup;
468 } else {
469 StringAttr t1 = rewriter.getStringAttr("sym_visibility");
470 StringAttr t2 = rewriter.getStringAttr("private");
471 NamedAttribute funcAccess = NamedAttribute(t1, t2);
472 FunctionType fnType =
473 mlir::FunctionType::get(rewriter.getContext(), inTypes, outTypes);
474 fnOp = rewriter.create<func::FuncOp>(parentModuleOp.getLoc(), funcName,
475 fnType, funcAccess);
476 }
477 return fnOp;
478}
479
480static bool matchExpOpForLUT(math::ExpOp::Adaptor adaptor) {
481 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
482
483 if (!srcType)
484 return false;
485
486 Type scalarType = srcType.getElementType();
487 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
488 unsigned laneSize = getVectorLaneSize(srcType);
489 return isa<FloatType>(scalarType) && laneSize == 16 && elWidth == 16;
490}
491
492//===----------------------------------------------------------------------===//
493// Rewrite patterns
494//===----------------------------------------------------------------------===//
495
496// This pattern folds a `vector.extract` followed by a `vector.splat` into
497// an `aievec.broadcast` for AIE2.
499 : OpConversionPattern<vector::SplatOp> {
500 using OpConversionPattern::OpConversionPattern;
501
502 LogicalResult
503 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
504 ConversionPatternRewriter &rewriter) const override {
505
506 auto extOp = adaptor.getInput().getDefiningOp<vector::ExtractOp>();
507
508 if (!extOp)
509 return failure();
510
511 auto src = extOp.getVector();
512 auto pos = extOp.getStaticPosition();
513 int64_t posVal = pos[0];
514 auto srcVecType = cast<VectorType>(src.getType());
515 auto resultType = cast<VectorType>(splatOp.getResult().getType());
516 if (srcVecType != resultType) {
517 if (srcVecType.getNumElements() != 2 * resultType.getNumElements())
518 return failure();
519 auto half = static_cast<int8_t>(posVal / resultType.getNumElements());
520 posVal -= half * resultType.getNumElements();
521 src = rewriter
522 .create<aievec::ExtOp>(extOp.getLoc(), resultType, src,
523 rewriter.getI8IntegerAttr(half))
524 .getResult();
525 }
526
527 unsigned elWidth = resultType.getElementType().getIntOrFloatBitWidth();
528
529 if (unsigned laneSize = getVectorLaneSize(resultType);
530 laneSize * elWidth == 512) {
531 // Common use case for the broadcast_elem intrinsic
532 rewriter.replaceOpWithNewOp<aievec::BroadcastOp>(splatOp, resultType, src,
533 posVal);
534 } else if (laneSize * elWidth == 256) {
535 // e.g. need v16bf16 due to the subsequent v16accfloat operation
536 VectorType aievecBcastType =
537 createVectorType(512 / elWidth, resultType.getElementType());
538 auto concatOp = rewriter.create<aievec::ConcatOp>(
539 splatOp.getLoc(), aievecBcastType, SmallVector<Value>({src, src}));
540 auto aieBcastOp = rewriter.create<aievec::BroadcastOp>(
541 splatOp.getLoc(), aievecBcastType, concatOp.getResult(), posVal);
542 rewriter.replaceOpWithNewOp<aievec::ExtOp>(splatOp, resultType,
543 aieBcastOp.getResult(), 0);
544 } else if (laneSize * elWidth == 1024) {
545 // e.g. need v32int32 due to the subsequent v32acc32 operation
546 VectorType aievecBcastType =
547 createVectorType(512 / elWidth, resultType.getElementType());
548 auto half = static_cast<int8_t>(posVal / resultType.getNumElements());
549 posVal -= half * resultType.getNumElements();
550 auto extOp =
551 rewriter.create<aievec::ExtOp>(splatOp.getLoc(), aievecBcastType, src,
552 rewriter.getI8IntegerAttr(half));
553 auto aieBcastOp = rewriter.create<aievec::BroadcastOp>(
554 splatOp.getLoc(), aievecBcastType, extOp.getResult(), posVal);
555 rewriter.replaceOpWithNewOp<aievec::ConcatOp>(
556 splatOp, resultType,
557 SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
558 } else {
559 return failure();
560 }
561
562 return success();
563 }
564};
565
567 using OpConversionPattern::OpConversionPattern;
568
569 LogicalResult
570 matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
571 ConversionPatternRewriter &rewriter) const override {
572
573 if (adaptor.getInput().getDefiningOp<vector::ExtractOp>())
574 return failure();
575
576 auto resultType = cast<VectorType>(splatOp.getResult().getType());
577 auto flatResultType = getFlattenedVectorType(resultType);
578 Type scalarType = resultType.getElementType();
579 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
580 unsigned laneSize = getVectorLaneSize(resultType);
581 auto src = splatOp.getInput();
582
583 if (laneSize * elWidth == 512) {
584 Value newOp = rewriter.create<aievec::BroadcastScalarOp>(
585 splatOp.getLoc(), flatResultType, src);
586 if (resultType != flatResultType)
587 newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
588 resultType, newOp);
589 rewriter.replaceOp(splatOp, newOp);
590 return success();
591 }
592
593 if (laneSize * elWidth == 256) {
594 VectorType vecType = createVectorType(512 / elWidth, scalarType);
595 auto aieBcastOp = rewriter.create<aievec::BroadcastScalarOp>(
596 splatOp.getLoc(), vecType, src);
597 Value newOp = rewriter.create<aievec::ExtOp>(
598 splatOp.getLoc(), flatResultType, aieBcastOp.getResult(), 0);
599 if (resultType != flatResultType)
600 newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
601 resultType, newOp);
602 rewriter.replaceOp(splatOp, newOp);
603 return success();
604 }
605
606 if (laneSize * elWidth == 1024) {
607 VectorType vecType = createVectorType(512 / elWidth, scalarType);
608 auto aieBcastOp = rewriter.create<aievec::BroadcastScalarOp>(
609 splatOp.getLoc(), vecType, src);
610 Value newOp = rewriter.create<aievec::ConcatOp>(
611 splatOp.getLoc(), flatResultType,
612 SmallVector<Value>({aieBcastOp.getResult(), aieBcastOp.getResult()}));
613 if (resultType != flatResultType)
614 newOp = rewriter.create<vector::ShapeCastOp>(splatOp.getLoc(),
615 resultType, newOp);
616 rewriter.replaceOp(splatOp, newOp);
617 return success();
618 }
619
620 return failure();
621 }
622};
623
624// This pattern replaces `arith.muli`+`arith.addi` on vectors with
625// `aievec.mac_elem`. This pattern works for AIE2.
627 : OpConversionPattern<arith::AddIOp> {
628 using OpConversionPattern::OpConversionPattern;
629
631 unsigned shiftParam = 0)
633
634 LogicalResult
635 matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor,
636 ConversionPatternRewriter &rewriter) const override {
637 // Verify it's a vector operation
638 auto resultType = dyn_cast<VectorType>(addOp.getType());
639 if (!resultType)
640 return failure();
641
642 // Verify it can be replaced by a MAC
643 auto res =
644 extractMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
645 if (!res)
646 return failure();
647 auto [lhs, rhs, acc] = *res;
648
649 // Verify the vector type is supported by AIE2
650 unsigned resultElWidth =
651 resultType.getElementType().getIntOrFloatBitWidth();
652 unsigned laneSize = getVectorLaneSize(resultType);
653
654 if ((laneSize != 32 || resultElWidth != 16) &&
655 (laneSize != 16 || resultElWidth != 32))
656 return failure();
657
658 Type accType = getVectorOpDestType(cast<VectorType>(acc.getType()),
659 /*AIE2 =*/true);
660 auto upsOp = rewriter.create<aievec::UPSOp>(addOp.getLoc(), accType, acc,
661 shiftParam);
662 auto fmaElemOp = rewriter.create<aievec::FMAElemOp>(
663 addOp.getLoc(), accType, lhs, rhs, upsOp.getResult(),
664 /*fmsub=*/false);
665
666 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
667 addOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
668 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
669 addOp, resultType, fmaElemOp.getResult(), shiftParamOp.getResult());
670
671 return success();
672 }
673
674 unsigned shiftParam;
675};
676
677// Convert `vector.fma` to `aievec.mac_elem`. Only `vector<16xf32>` and
678// `vector<16xbf16>` operand types are supported. In the case of vectors with
679// `f32` elemental type, this pattern will try to match `bf16` to `f32`
680// widening ops in the `lhs` and `rhs` operands, or fail otherwise.
681// TODO: When sign extensions are not found, a conversion from `f32` to `bf16`
682// TODO: can be inserted to emulate `f32` fma with `bf16` logic.
684 : OpConversionPattern<vector::FMAOp> {
685 using OpConversionPattern::OpConversionPattern;
686
688 unsigned shiftParam = 0)
690
691 LogicalResult
692 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor,
693 ConversionPatternRewriter &rewriter) const override {
694 // Verify the vector type is supported by AIE2
695 auto resVecTy = cast<VectorType>(fmaOp.getType());
696 auto resElemTy = resVecTy.getElementType();
697 unsigned numElems = getVectorLaneSize(resVecTy);
698
699 if (numElems != 16 || (!resElemTy.isF32() && !resElemTy.isBF16()))
700 return rewriter.notifyMatchFailure(
701 fmaOp, "Unsupported operand types in vector.fma lowering.");
702
703 Value lhs = adaptor.getLhs();
704 Value rhs = adaptor.getRhs();
705 Value acc = adaptor.getAcc();
706 if (resElemTy.isBF16())
707 acc = rewriter.create<aievec::UPSOp>(
708 fmaOp.getLoc(), VectorType::get({16}, rewriter.getF32Type()), acc,
709 shiftParam);
710 else {
711 lhs = getSourceOfWideningOp(lhs).value_or(nullptr);
712 rhs = getSourceOfWideningOp(rhs).value_or(nullptr);
713 if (!lhs || !rhs)
714 return rewriter.notifyMatchFailure(
715 fmaOp, "vector.fma operands are f32, and they don't come from "
716 "arith.extf on bf16; can't lower to aievec.");
717 if (!cast<VectorType>(lhs.getType()).getElementType().isBF16() ||
718 !cast<VectorType>(rhs.getType()).getElementType().isBF16())
719 return rewriter.notifyMatchFailure(
720 fmaOp, "vector.fma operands come from arith.extf, but the source "
721 "of the widening op is not bf16; can't lower to aievec.");
722 }
723 Value newOp = rewriter.create<aievec::FMAElemOp>(
724 fmaOp.getLoc(), acc.getType(), lhs, rhs, acc, /*fmsub=*/false);
725
726 if (resElemTy.isBF16()) {
727 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
728 fmaOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
729 newOp = rewriter.create<aievec::SRSOp>(fmaOp.getLoc(), resVecTy, newOp,
730 shiftParamOp);
731 }
732
733 rewriter.replaceOp(fmaOp, newOp);
734
735 return success();
736 }
737
738 unsigned shiftParam;
739};
740
741// This pattern replaces `arith.mulf` on vectors with
742// `aievec.mul_elem`. This pattern works for AIE2.
744 : OpConversionPattern<arith::MulFOp> {
745 using OpConversionPattern::OpConversionPattern;
746
748 unsigned shiftParam = 0)
750
751 LogicalResult
752 matchAndRewrite(arith::MulFOp mulOp, OpAdaptor adaptor,
753 ConversionPatternRewriter &rewriter) const override {
754 // Verify it's a vector operation
755 auto resultType = dyn_cast<VectorType>(mulOp.getType());
756 if (!resultType)
757 return failure();
758
759 // FIXME: Verify it is not a part of FMA
760 auto isAddOp = [&](Operation *op) { return isa<arith::AddFOp>(op); };
761 if (mulOp->hasOneUse() && llvm::any_of(mulOp->getUsers(), isAddOp))
762 return failure();
763
764 unsigned resultElWidth =
765 resultType.getElementType().getIntOrFloatBitWidth();
766
767 unsigned laneSize = getVectorLaneSize(resultType);
768
769 // bfloat16 and float type
770 if (laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32))
771 return failure();
772
773 // Decide the accType for aievec.mul_elem based on mulOp's lhs & rhs
774 auto lval = adaptor.getLhs();
775 auto rval = adaptor.getRhs();
776 lval = getSourceOfWideningOp(lval).value_or(lval);
777 rval = getSourceOfWideningOp(rval).value_or(rval);
778 auto lSrcType = cast<VectorType>(lval.getType());
779 auto rSrcType = cast<VectorType>(rval.getType());
780 unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
781 unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
782 Type accType = getVectorOpDestType(lSrcType, /*AIE2 =*/true);
783 if (rBitWidth > lBitWidth) {
784 accType = getVectorOpDestType(rSrcType, /*AIE2 =*/true);
785 }
786 // Only support the same lhs/rhs type at the moment
787 if (lSrcType != rSrcType) {
788 return failure();
789 }
790
791 // Prepare lhr/rhs for the aievec.mul_elem op
792 unsigned bitWidth = (rBitWidth > lBitWidth) ? rBitWidth : lBitWidth;
793 Type srcElemType = (rBitWidth > lBitWidth) ? rSrcType.getElementType()
794 : lSrcType.getElementType();
795 unsigned numLanes = 0;
796 if (isa<FloatType>(srcElemType) && (bitWidth == 16 || bitWidth == 32)) {
797 numLanes = 16;
798 } else if (isa<IntegerType>(srcElemType) &&
799 (bitWidth == 8 || bitWidth == 16)) {
800 numLanes = 32;
801 } else if (isa<IntegerType>(srcElemType) && (bitWidth == 32)) {
802 numLanes = 16;
803 } else {
804 return failure();
805 }
806 VectorType targetInputType = createVectorType(numLanes, srcElemType);
807 if (targetInputType != lSrcType) {
808 lval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), lval,
809 targetInputType)
810 .value();
811 }
812 if (targetInputType != rSrcType) {
813 rval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), rval,
814 targetInputType)
815 .value();
816 }
817 if (!lval || !rval)
818 return failure();
819
820 // Create an aievec.mul_elem op
821 auto mulElemOp =
822 rewriter.create<aievec::MulElemOp>(mulOp.getLoc(), accType, lval, rval);
823
824 // Create an aievec.cast or an aievec.srs op
825 auto mulElemResultType = mulElemOp.getType();
826 auto mulElemResultElWidth =
827 mulElemResultType.getElementType().getIntOrFloatBitWidth();
828
829 if (mulElemResultElWidth == resultElWidth) {
830 rewriter.replaceOpWithNewOp<aievec::CastOp>(
831 mulOp, resultType, mulElemOp.getResult(), /*isResAcc*/ false);
832 } else if (mulElemResultElWidth > resultElWidth) {
833 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
834 mulOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
835 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
836 mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
837 } else {
838 return failure();
839 }
840
841 return success();
842 }
843
844 unsigned shiftParam;
845};
846
847// This pattern replaces `arith.muli` on vectors with
848// `aievec.mul_elem`. This pattern works for AIE2.
850 : OpConversionPattern<arith::MulIOp> {
851 using OpConversionPattern::OpConversionPattern;
852
854 unsigned shiftParam = 0)
856
857 LogicalResult
858 matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor,
859 ConversionPatternRewriter &rewriter) const override {
860 // Verify it's a vector operation
861 auto resultType = dyn_cast<VectorType>(mulOp.getType());
862 if (!resultType)
863 return failure();
864
865 // FIXME: Verify it is not a part of MAC
866 auto isAddOp = [&](Operation *op) { return isa<arith::AddIOp>(op); };
867 if (mulOp->hasOneUse() && llvm::any_of(mulOp->getUsers(), isAddOp))
868 return failure();
869
870 // Verify the vector type is supported by AIE2
871 unsigned resultElWidth =
872 resultType.getElementType().getIntOrFloatBitWidth();
873 unsigned laneSize = getVectorLaneSize(resultType);
874
875 if ((laneSize != 32 || (resultElWidth != 16 && resultElWidth != 8)) &&
876 ((laneSize != 16 && laneSize != 32) || resultElWidth != 32))
877 return failure();
878
879 // Decide the accType for aievec.mul_elem based on mulOp's lhs & rhs
880 auto lval = adaptor.getLhs();
881 auto rval = adaptor.getRhs();
882
883 lval = getSourceOfWideningOp(lval).value_or(lval);
884 rval = getSourceOfWideningOp(rval).value_or(rval);
885
886 auto lSrcType = cast<VectorType>(lval.getType());
887 auto rSrcType = cast<VectorType>(rval.getType());
888 unsigned lBitWidth = lSrcType.getElementType().getIntOrFloatBitWidth();
889 unsigned rBitWidth = rSrcType.getElementType().getIntOrFloatBitWidth();
890 Type accType = getVectorOpDestType(lSrcType, /*AIE2 =*/true);
891 if (rBitWidth > lBitWidth) {
892 accType = getVectorOpDestType(rSrcType, /*AIE2 =*/true);
893 }
894
895 // Prepare lhr/rhs for the aievec.mul_elem op
896 unsigned bitWidth = (rBitWidth > lBitWidth) ? rBitWidth : lBitWidth;
897 Type srcElemType = (rBitWidth > lBitWidth) ? rSrcType.getElementType()
898 : lSrcType.getElementType();
899 unsigned numLanes = 0;
900 if (isa<FloatType>(srcElemType) && (bitWidth == 16 || bitWidth == 32)) {
901 numLanes = 16;
902 } else if (isa<IntegerType>(srcElemType) &&
903 (bitWidth == 8 || bitWidth == 16)) {
904 numLanes = 32;
905 } else if (isa<IntegerType>(srcElemType) && (bitWidth == 32)) {
906 numLanes = 16;
907 } else {
908 return failure();
909 }
910 VectorType targetInputType = createVectorType(numLanes, srcElemType);
911 if (targetInputType != lSrcType) {
912 lval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), lval,
913 targetInputType)
914 .value();
915 }
916 if (targetInputType != rSrcType) {
917 rval = convertValueToTargetTypeAIE2(rewriter, mulOp.getLoc(), rval,
918 targetInputType)
919 .value();
920 }
921 if (!lval || !rval)
922 return failure();
923
924 // Create an aievec.mul_elem op
925 auto mulElemOp =
926 rewriter.create<aievec::MulElemOp>(mulOp.getLoc(), accType, lval, rval);
927
928 // Create an aievec.cast or an aievec.srs op
929 auto mulElemResultType = mulElemOp.getType();
930 auto mulElemResultElWidth =
931 mulElemResultType.getElementType().getIntOrFloatBitWidth();
932
933 if (mulElemResultElWidth == resultElWidth) {
934 rewriter.replaceOpWithNewOp<aievec::CastOp>(
935 mulOp, resultType, mulElemOp.getResult(), /*isResAcc*/ false);
936 } else if (mulElemResultElWidth > resultElWidth) {
937 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
938 mulOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
939 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
940 mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
941 } else {
942 return failure();
943 }
944
945 return success();
946 }
947
948 unsigned shiftParam;
949};
950
951// This pattern folds an extract + broadcast feeding into an
952// `aievec::aie1::FMAOp` into the op, using the shuffle attributes.
953struct FoldSplatToFMAOp : OpConversionPattern<aievec::aie1::FMAOp> {
954 using OpConversionPattern::OpConversionPattern;
955
956 LogicalResult
957 matchAndRewrite(aievec::aie1::FMAOp fmaOp, OpAdaptor adaptor,
958 ConversionPatternRewriter &rewriter) const override {
959 auto concatOp =
960 dyn_cast<aievec::ConcatOp>(adaptor.getLhs().getDefiningOp());
961 if (!concatOp)
962 return failure();
963 vector::SplatOp splatOp = nullptr;
964 auto *concatDefOp = concatOp.getSources()[0].getDefiningOp();
965 if (concatDefOp)
966 splatOp = dyn_cast<vector::SplatOp>(concatDefOp);
967 Value lhs = adaptor.getRhs();
968 if (!splatOp) {
969 splatOp = dyn_cast<vector::SplatOp>(adaptor.getRhs().getDefiningOp());
970 if (!splatOp)
971 return failure();
972 lhs = concatOp.getSources()[0];
973 }
974 auto extOp =
975 dyn_cast<vector::ExtractOp>(splatOp.getInput().getDefiningOp());
976 if (!extOp)
977 return failure();
978
979 auto rhs = extOp.getVector();
980 auto concatVecType = cast<VectorType>(concatOp.getResult().getType());
981 auto zvec = rewriter.create<arith::ConstantOp>(
982 concatOp.getLoc(), lhs.getType(), rewriter.getZeroAttr(lhs.getType()));
983 auto lhsX2 =
984 rewriter
985 .create<aievec::ConcatOp>(concatOp.getLoc(), concatVecType,
986 SmallVector<Value, 2>({lhs, zvec}))
987 .getResult();
988 // XXX: We assume a 1D vector
989 auto pos = extOp.getStaticPosition();
990 int64_t zstart = pos[0];
991 auto fmaOpAttr = buildFMAOpSplatAttrForElemTy(fmaOp, zstart);
992 rewriter.replaceOpWithNewOp<aievec::aie1::FMAOp>(
993 fmaOp, TypeRange({fmaOp.getResult().getType()}),
994 ValueRange({lhsX2, rhs, adaptor.getAcc()}), fmaOpAttr);
995
996 return success();
997 }
998};
999
1001 : OpConversionPattern<aievec::aie1::AddOp> {
1002 using OpConversionPattern::OpConversionPattern;
1003
1004 LogicalResult
1005 matchAndRewrite(aievec::aie1::AddOp addOp, OpAdaptor adaptor,
1006 ConversionPatternRewriter &rewriter) const override {
1007 auto vecType = cast<VectorType>(addOp.getType());
1008
1009 auto res =
1010 extractMACOperandsFromAddOperands(adaptor.getLhs(), adaptor.getRhs());
1011 if (!res)
1012 return failure();
1013 auto [lhs, rhs, acc] = *res;
1014
1015 SmallVector<int64_t, 4> concatVecShape(vecType.getShape().begin(),
1016 vecType.getShape().end());
1017 concatVecShape[vecType.getRank() - 1] *= 2;
1018 auto concatVecType =
1019 VectorType::get(concatVecShape, vecType.getElementType());
1020 Type accType = getVectorOpDestType(cast<VectorType>(acc.getType()),
1021 /*AIE2 =*/false);
1022 auto lhsX2 = rewriter
1023 .create<aievec::ConcatOp>(addOp.getLoc(), concatVecType,
1024 SmallVector<Value, 2>(2, lhs))
1025 .getResult();
1026 auto upsOp = rewriter.create<aievec::UPSOp>(addOp.getLoc(), accType, acc);
1027 auto fmaOp = rewriter.create<aievec::aie1::FMAOp>(
1028 addOp.getLoc(), accType, lhsX2, rhs, upsOp.getResult(),
1029 /*xstart=*/"", /*xoffsets=*/"", /*xoffsets_hi=*/"", /*xstep=*/"",
1030 /*xsquare=*/"", /*zstart=*/"", /*zoffsets=*/"", /*zoffsets_hi=*/"",
1031 /*zstep=*/"", /*zsquare=*/"", /*fmsub=*/false);
1032 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1033 addOp.getLoc(), rewriter.getI32IntegerAttr(0));
1034 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1035 addOp, vecType, fmaOp.getResult(), shiftParamOp.getResult());
1036 return success();
1037 }
1038};
1039
1040// This pattern replaces `vector.transfer_read` with `aievec.upd`. Right now,
1041// it performs a naïve direct translation. This needs to be expanded to
1042// support more complex scenarios.
1044 : OpConversionPattern<vector::TransferReadOp> {
1045 using OpConversionPattern::OpConversionPattern;
1046
1047 LowerVectorTransferReadToAIEUPD(MLIRContext *context, int64_t minVectorSize,
1048 int64_t maxVectorSize, int64_t alignment,
1049 int64_t maxLoadSize)
1053
// Replaces a `vector::TransferReadOp` with one `aievec::UPDOp`, or with two
// chained `aievec::UPDOp`s when the vector is larger than `maxLoadSize`.
// Masked reads emit an error; reads that are not minor-identity, are splats,
// are misaligned, or have an unsupported size fail to match.
// NOTE(review): `vectorAlignment`, `minVectorSize`, `maxVectorSize` and
// `maxLoadSize` are presumably pattern members set by the constructor; their
// declarations are not visible in this listing - confirm.
1054 LogicalResult
1055 matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
1056 ConversionPatternRewriter &rewriter) const override {
1057 // Masked loads
1058 if (readOp.getMask())
1059 return readOp.emitError() << "AIE doesn't support masked loads.";
1060
1061 // Non-contiguous loads
1062 AffineMap map = readOp.getPermutationMap();
1063 if (!map.isMinorIdentity())
1064 return failure();
1065
1066 // Splats
1067 if (map.isConstant())
1068 return failure();
1069
1070 // Misaligned accesses
1071 auto vType = readOp.getVectorType();
// NOTE(review): the line that opens this condition is missing from this
// listing; only its tail (a non-zero value, defaulting to 0, rejects the
// read) is visible below.
1073 .value_or(0) != 0)
1074 return failure();
1075
1076 // Invalid vector size.
1077 // We can handle cases where the vector size (in bits) is:
1078 // 1) the minimum vector size
1079 // 2) a power-of-two multiple of the alignment size, up to the maximum
1080 // vector size.
1081 int64_t vSize = vType.getNumElements() * vType.getElementTypeBitWidth();
1082 if (vSize > maxVectorSize ||
1083 (vSize % vectorAlignment && vSize != minVectorSize))
1084 return failure();
1085 // We can deal with linked update instructions when the vector size is
1086 // exactly twice the load size. This could change in future architectures
1087 if (vSize > maxLoadSize && vSize != maxLoadSize * 2)
1088 return failure();
1089 int64_t multiplicity = vSize / vectorAlignment;
// Exactly one set bit in the low 8 bits of `multiplicity` means it is a
// power of two.
1090 if ((vSize > minVectorSize) && std::bitset<8>(multiplicity).count() != 1)
1091 return failure();
1092
// First update: integer arguments 0, 0 and no input vector.
1093 auto updOp = rewriter.create<xilinx::aievec::UPDOp>(
1094 readOp.getLoc(), vType, adaptor.getSource(), adaptor.getIndices(), 0, 0,
1095 TypedValue<VectorType>(nullptr));
// Second, linked update for vectors larger than one load: it takes
// `maxLoadSize` and 1 as integer arguments and the first update's result as
// its input vector.
1096 if (vSize > maxLoadSize) {
1097 updOp = rewriter.create<xilinx::aievec::UPDOp>(
1098 readOp.getLoc(), vType, adaptor.getSource(), adaptor.getIndices(),
1099 maxLoadSize, 1, updOp.getResult());
1100 }
1101 rewriter.replaceOp(readOp, updOp.getResult());
1102
1103 return success();
1104 }
1105
1107};
1108
1109// XXX: Notice that this template doesn't verify that the vector element type
1110// XXX: is supported by the target architecture.
1111template <typename SrcOpTy, typename DstOpTy>
1114 using OpAdaptor = typename SrcOpTy::Adaptor;
1115
1116 LogicalResult
1117 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
1118 ConversionPatternRewriter &rewriter) const override {
1119 rewriter.replaceOpWithNewOp<DstOpTy>(
1120 srcOp, srcOp.getResult().getType(), adaptor.getLhs(), adaptor.getRhs(),
1121 /*xstart=*/"", /*xoffsets=*/"", /*xoffsets_hi=*/"", /*xsquare=*/"",
1122 /*zstart=*/"", /*zoffsets=*/"", /*zoffsets_hi=*/"", /*zsquare=*/"");
1123 return success();
1124 }
1125};
1126
1128 using OpConversionPattern::OpConversionPattern;
1129
1130 LogicalResult
1131 matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor,
1132 ConversionPatternRewriter &rewriter) const override {
1133 auto resType = addOp.getType();
1134 if (!isa<VectorType>(resType))
1135 return failure();
1136
1137 auto lhs = adaptor.getLhs();
1138 auto rhs = adaptor.getRhs();
1139 auto *lhsDefOp = lhs.getDefiningOp();
1140 auto *rhsDefOp = rhs.getDefiningOp();
1141 if ((isa_and_nonnull<arith::MulIOp>(lhsDefOp)) ||
1142 (isa_and_nonnull<arith::MulIOp>(rhsDefOp)))
1143 return failure();
1144
1145 rewriter.replaceOpWithNewOp<aievec::aie1::AddOp>(
1146 addOp, resType, lhs, rhs,
1147 /*xstart=*/"", /*xoffsets=*/"", /*xoffsets_hi=*/"", /*xsquare=*/"",
1148 /*zstart=*/"", /*zoffsets=*/"", /*zoffsets_hi=*/"", /*zsquare=*/"");
1149 return success();
1150 }
1151};
1152
1161
1163 using OpConversionPattern::OpConversionPattern;
1164 LogicalResult
1165 matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor,
1166 ConversionPatternRewriter &rewriter) const override {
1167 auto resTy = dyn_cast<VectorType>(mulOp.getType());
1168 if (!resTy)
1169 return failure();
1170 auto accTy = getVectorOpDestType(resTy, /*AIE2 =*/false);
1171 auto newMulOp = rewriter.create<aievec::aie1::MulOp>(
1172 mulOp.getLoc(), accTy, adaptor.getLhs(), adaptor.getRhs());
1173 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1174 mulOp.getLoc(), rewriter.getI32IntegerAttr(0));
1175 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1176 mulOp, resTy, newMulOp.getResult(), shiftParamOp.getResult());
1177 return success();
1178 }
1179};
1180
1181template <typename SrcOpTy, typename DstOpTy>
1183 : OpConversionPattern<SrcOpTy> {
1185 using OpAdaptor = typename SrcOpTy::Adaptor;
1186
1187 LogicalResult
1188 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
1189 ConversionPatternRewriter &rewriter) const override {
1190 VectorType resultType = dyn_cast<VectorType>(srcOp.getType());
1191 if (!resultType)
1192 return failure();
1193
1194 // A set recording the vector lane size and element width we are supporting
1195 // for AIE2.
1196 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
1197 laneSizeElWidthPairSet.insert({64, 8});
1198 laneSizeElWidthPairSet.insert({32, 16});
1199 laneSizeElWidthPairSet.insert({16, 32});
1200 laneSizeElWidthPairSet.insert({32, 32});
1201
1202 auto lhs = adaptor.getLhs();
1203 auto rhs = adaptor.getRhs();
1204 auto lhsDefOp = lhs.getDefiningOp();
1205 auto rhsDefOp = rhs.getDefiningOp();
1206 if ((lhsDefOp && isa<arith::MulIOp>(lhsDefOp)) ||
1207 (rhsDefOp && isa<arith::MulIOp>(rhsDefOp)) ||
1208 (lhsDefOp && isa<arith::MulFOp>(lhsDefOp)) ||
1209 (rhsDefOp && isa<arith::MulFOp>(rhsDefOp)))
1210 return failure();
1211
1212 Type scalarType = resultType.getElementType();
1213 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
1214 unsigned laneSize = getVectorLaneSize(resultType);
1215
1216 // Integer cases
1217 if (isa<IntegerType>(scalarType)) {
1218 if (!laneSizeElWidthPairSet.count(
1219 std::make_pair(laneSize, resultElWidth)))
1220 return failure();
1221
1222 // If the ops are defined without extension ops and with supported data
1223 // type, the arith::AddI or arith::SubI can be directly replaced with
1224 // aievec::AddElem or aievec::SubElem.
1225 if (!lhsDefOp && !rhsDefOp) {
1226 if (laneSize * resultElWidth == 512) {
1227 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs,
1228 rhs);
1229 return success();
1230 }
1231 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs, resultType,
1232 srcOp);
1233 }
1234
1235 // If element width is 32, we need to consider sign extension cases
1236 if (resultElWidth == 32) {
1237 auto lhsExt = getSourceOfWideningOp(lhs).value_or(nullptr);
1238 auto rhsExt = getSourceOfWideningOp(rhs).value_or(nullptr);
1239
1240 if (!lhsExt && !rhsExt) {
1241 if (laneSize * resultElWidth == 512) {
1242 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs,
1243 rhs);
1244 return success();
1245 }
1246 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1247 resultType, srcOp);
1248 }
1249
1250 if (lhsExt && rhsExt) {
1251 auto lval = lhsExt;
1252 auto rval = rhsExt;
1253 VectorType lSrcType = cast<VectorType>(lval.getType());
1254
1255 Type accType = getVectorOpDestType(lSrcType, /*AIE2 =*/true);
1256 auto lUpsOp =
1257 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1258 auto rUpsOp =
1259 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1260 auto elemOp = rewriter.create<DstOpTy>(
1261 srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1262 lUpsOp->getResult(0), rUpsOp->getResult(0));
1263 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1264 srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false);
1265 return success();
1266 }
1267
1268 if (!lhsExt || !rhsExt) {
1269 auto lval = lhsExt ? lhsExt : lhs;
1270 auto rval = rhsExt ? rhsExt : rhs;
1271 auto extVal = lhsExt ? lval : rval;
1272 VectorType vType = cast<VectorType>(extVal.getType());
1273 unsigned bitWidth = vType.getElementType().getIntOrFloatBitWidth();
1274
1275 if (bitWidth != 8 && bitWidth != 16) {
1276 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1277 resultType, srcOp);
1278 }
1279
1280 if (bitWidth * laneSize != 256) {
1281 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1282 resultType, srcOp);
1283 }
1284
1285 Type accType = nullptr;
1286
1287 if (bitWidth == 8) {
1288 accType = getVectorOpDestType(vType, /*AIE2 =*/true);
1289 Value valToUps = lhsExt ? lval : rval;
1290 Value valToCast = lhsExt ? rval : lval;
1291 auto upsOp = rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType,
1292 valToUps);
1293 auto castOp = rewriter.create<aievec::CastOp>(
1294 srcOp.getLoc(), resultType, valToCast, /*isResAcc*/ true);
1295 Value lhsToElemOp =
1296 lhsExt ? upsOp->getResult(0) : castOp->getResult(0);
1297 Value rhsToElemOp =
1298 lhsExt ? castOp->getResult(0) : upsOp->getResult(0);
1299 auto elemOp = rewriter.create<DstOpTy>(
1300 srcOp.getLoc(), upsOp->getResult(0).getType(), lhsToElemOp,
1301 rhsToElemOp);
1302 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1303 srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false);
1304 return success();
1305 }
1306
1307 if (bitWidth == 16) {
1308 accType = getVectorOpDestType(resultType, /*AIE2 =*/true);
1309 auto lUpsOp =
1310 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1311 auto rUpsOp =
1312 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1313
1314 auto elemOp = rewriter.create<DstOpTy>(
1315 srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1316 lUpsOp->getResult(0), rUpsOp->getResult(0));
1317
1318 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1319 srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
1320 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1321 srcOp, srcOp.getType(), elemOp.getResult(),
1322 shiftParamOp.getResult());
1323 return success();
1324 }
1325 }
1326 } else {
1327 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(), lhs, rhs);
1328 return success();
1329 }
1330 }
1331 // Float types
1332 else {
1333 if (laneSize != 16)
1334 return failure();
1335
1336 // v16float or v16bf16 with extension op case
1337 if (resultElWidth == 32) {
1338 if (!lhsDefOp && !rhsDefOp) {
1339 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1340 resultType, srcOp);
1341 }
1342
1343 auto lhsExt = getSourceOfWideningOp(lhs).value_or(nullptr);
1344 auto rhsExt = getSourceOfWideningOp(rhs).value_or(nullptr);
1345 // v16float
1346 if (!lhsExt && !rhsExt) {
1347 return genAddElemAIE2<SrcOpTy, DstOpTy>(rewriter, lhs, rhs,
1348 resultType, srcOp);
1349 }
1350
1351 // v16bf16 with two extension ops
1352 if (lhsExt && rhsExt) {
1353 auto lval = lhsExt;
1354 auto rval = rhsExt;
1355 VectorType vType = cast<VectorType>(lval.getType());
1356
1357 Type accType = getVectorOpDestType(vType, /*AIE2 =*/true);
1358 auto lUpsOp =
1359 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1360 auto rUpsOp =
1361 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1362 auto elemOp = rewriter.create<DstOpTy>(
1363 srcOp.getLoc(), lUpsOp->getResult(0).getType(),
1364 lUpsOp->getResult(0), rUpsOp->getResult(0));
1365 rewriter.replaceOpWithNewOp<aievec::CastOp>(srcOp, srcOp.getType(),
1366 elemOp.getResult());
1367 return success();
1368 }
1369
1370 // v16bf16 with one extension op
1371 if (!lhsExt || !rhsExt) {
1372 auto lval = lhsExt ? lhsExt : lhs;
1373 auto rval = rhsExt ? rhsExt : rhs;
1374 auto extVal = lhsExt ? lval : rval;
1375 VectorType vType = cast<VectorType>(extVal.getType());
1376 Type accType = getVectorOpDestType(vType, /*AIE2 =*/true);
1377
1378 aievec::UPSOp upsOp;
1379 aievec::CastOp castOp;
1380 if (lhsExt) {
1381 upsOp =
1382 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lval);
1383 castOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), resultType,
1384 rval,
1385 /*isResAcc*/ true);
1386 } else {
1387 upsOp =
1388 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rval);
1389 castOp = rewriter.create<aievec::CastOp>(srcOp.getLoc(), resultType,
1390 lval,
1391 /*isResAcc*/ true);
1392 }
1393
1394 auto elemOp = rewriter.create<DstOpTy>(
1395 srcOp.getLoc(), upsOp->getResult(0).getType(),
1396 upsOp->getResult(0), castOp->getResult(0));
1397
1398 rewriter.replaceOpWithNewOp<aievec::CastOp>(
1399 srcOp, srcOp.getType(), elemOp.getResult(), /*isResAcc*/ false);
1400
1401 return success();
1402 }
1403 }
1404
1405 // v16bfloat16
1406 Type accType = getVectorOpDestType(resultType, /*AIE2 =*/true);
1407 auto lUpsOp =
1408 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, lhs);
1409 auto rUpsOp =
1410 rewriter.create<aievec::UPSOp>(srcOp.getLoc(), accType, rhs);
1411 auto elemOp = rewriter.create<DstOpTy>(
1412 srcOp.getLoc(), lUpsOp->getResult(0).getType(), lUpsOp->getResult(0),
1413 rUpsOp->getResult(0));
1414 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1415 srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
1416 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1417 srcOp, srcOp.getType(), elemOp.getResult(), shiftParamOp.getResult());
1418
1419 return success();
1420 }
1421
1422 return failure();
1423 }
1424};
1425
1428 aievec::AddElemOp>;
1431 aievec::SubElemOp>;
1434 aievec::AddElemOp>;
1437 aievec::SubElemOp>;
1438
1439template <typename SrcOpTy, typename DstOpTy>
1442 using OpAdaptor = typename SrcOpTy::Adaptor;
1443
1444 LogicalResult
1445 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
1446 ConversionPatternRewriter &rewriter) const override {
1447 VectorType resultType = dyn_cast<VectorType>(srcOp.getType());
1448 if (!resultType)
1449 return failure();
1450
1451 // A set recording the element width we are supporting for AIE2.
1452 llvm::SmallSet<unsigned, 16> elWidthSet;
1453 elWidthSet.insert(8);
1454 elWidthSet.insert(16);
1455 elWidthSet.insert(32);
1456
1457 Type scalarType = resultType.getElementType();
1458 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
1459 unsigned laneSize = getVectorLaneSize(resultType);
1460
1461 if (!elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512)
1462 return failure();
1463
1464 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getType(),
1465 adaptor.getLhs(), adaptor.getRhs());
1466 return success();
1467 }
1468};
1469
1478
1479template <typename SrcOpTy, typename CmpTy>
1482 using OpAdaptor = typename SrcOpTy::Adaptor;
1483
1484 LogicalResult
1485 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
1486 ConversionPatternRewriter &rewriter) const override {
1487 VectorType lhsType = dyn_cast<VectorType>(srcOp.getLhs().getType());
1488 if (!lhsType)
1489 return failure();
1490
1491 llvm::SmallSet<unsigned, 16> elWidthSet;
1492 elWidthSet.insert(8);
1493 elWidthSet.insert(16);
1494 elWidthSet.insert(32);
1495
1496 Type scalarType = lhsType.getElementType();
1497 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1498 unsigned laneSize = getVectorLaneSize(lhsType);
1499
1500 if (!elWidthSet.count(elWidth) || laneSize * elWidth != 512)
1501 return failure();
1502
1503 // Unsigned int and unsigned long long are acceptable type.
1504 Type type =
1505 mlir::IntegerType::get(srcOp.getContext(), laneSize <= 32 ? 32 : 64,
1506 mlir::IntegerType::Unsigned);
1507
1508 Location loc = srcOp.getLoc();
1509 Value lhs = srcOp.getLhs();
1510 Value rhs = srcOp.getRhs();
1511 CmpTy pred = srcOp.getPredicate();
1512
1513 arith::CmpIPredicate ipred = convertToIntegerPredicate(pred);
1514
1515 aievec::CmpOp aieCmpOp =
1516 createCmpOpAIE2(rewriter, ipred, loc, type, lhs, rhs);
1517
1518 if (!aieCmpOp)
1519 return failure();
1520
1521 VectorType resultType = dyn_cast<VectorType>(srcOp.getResult().getType());
1522 // Convert vector i1 type to unsigned interger type by built-in unrealized
1523 // conversion cast op.
1524 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
1525 srcOp, resultType, aieCmpOp.getResult());
1526
1527 return success();
1528 }
1529};
1530
1535
1537 using OpConversionPattern::OpConversionPattern;
1538
1539 LogicalResult
1540 matchAndRewrite(arith::SelectOp srcOp, OpAdaptor adaptor,
1541 ConversionPatternRewriter &rewriter) const override {
1542 auto resultType = dyn_cast<VectorType>(srcOp.getType());
1543 if (!resultType)
1544 return failure();
1545
1546 llvm::SmallSet<unsigned, 16> elWidthSet;
1547 elWidthSet.insert(8);
1548 elWidthSet.insert(16);
1549 elWidthSet.insert(32);
1550
1551 Type scalarType = resultType.getElementType();
1552 unsigned resultElWidth = scalarType.getIntOrFloatBitWidth();
1553 unsigned laneSize = getVectorLaneSize(resultType);
1554
1555 if (!elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512)
1556 return failure();
1557
1558 Type type =
1559 mlir::IntegerType::get(srcOp.getContext(), laneSize <= 32 ? 32 : 64,
1560 mlir::IntegerType::Unsigned);
1561
1562 auto convertOp = rewriter.create<UnrealizedConversionCastOp>(
1563 srcOp.getLoc(), type, adaptor.getCondition());
1564
1565 rewriter.replaceOpWithNewOp<aievec::SelOp>(
1566 srcOp, srcOp.getResult().getType(), srcOp.getTrueValue(),
1567 srcOp.getFalseValue(), convertOp.getResult(0));
1568
1569 return success();
1570 }
1571};
1572
1573struct LowerVectorReductionMinOp : OpConversionPattern<vector::ReductionOp> {
1574 using OpConversionPattern::OpConversionPattern;
1575
1576 LogicalResult
1577 matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
1578 ConversionPatternRewriter &rewriter) const override {
1579 if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::MINSI &&
1580 kind != vector::CombiningKind::MINUI &&
1581 kind != vector::CombiningKind::MINIMUMF)
1582 return failure();
1583
1584 auto vType = cast<VectorType>(srcOp.getVector().getType());
1585 Type scalarType = vType.getElementType();
1586 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1587 unsigned laneSize = getVectorLaneSize(vType);
1588
1589 if (laneSize * elWidth != 512)
1590 return failure();
1591
1592 int shiftIndex = laneSize / 2;
1593 generateAIEVecOpsForReductionOp<aievec::MinOp>(rewriter, srcOp, shiftIndex,
1594 srcOp.getVector());
1595 return success();
1596 }
1597};
1598
1599struct LowerVectorReductionMaxOp : OpConversionPattern<vector::ReductionOp> {
1600 using OpConversionPattern::OpConversionPattern;
1601
1602 LogicalResult
1603 matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
1604 ConversionPatternRewriter &rewriter) const override {
1605 if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::MAXSI &&
1606 kind != vector::CombiningKind::MAXUI &&
1607 kind != vector::CombiningKind::MAXIMUMF)
1608 return failure();
1609
1610 auto vType = cast<VectorType>(srcOp.getVector().getType());
1611 Type scalarType = vType.getElementType();
1612 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1613 unsigned laneSize = getVectorLaneSize(vType);
1614
1615 if (laneSize * elWidth != 512)
1616 return failure();
1617
1618 int shiftIndex = laneSize / 2;
1619 generateAIEVecOpsForReductionOp<aievec::MaxOp>(rewriter, srcOp, shiftIndex,
1620 srcOp.getVector());
1621 return success();
1622 }
1623};
1624
1626 using OpConversionPattern::OpConversionPattern;
1627
1628 LogicalResult
1629 matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
1630 ConversionPatternRewriter &rewriter) const override {
1631 if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
1632 return failure();
1633
1634 auto vType = cast<VectorType>(srcOp.getVector().getType());
1635 Type scalarType = vType.getElementType();
1636 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1637 unsigned laneSize = getVectorLaneSize(vType);
1638 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
1639 laneSizeElWidthPairSet.insert({64, 8});
1640 laneSizeElWidthPairSet.insert({32, 16});
1641 laneSizeElWidthPairSet.insert({32, 32});
1642 laneSizeElWidthPairSet.insert({16, 32});
1643
1644 if (!isa<IntegerType>(scalarType) ||
1645 !laneSizeElWidthPairSet.count(std::make_pair(laneSize, elWidth)))
1646 return failure();
1647
1648 int shiftIndex = laneSize / 2;
1649 if (laneSize == 32 && elWidth == 32) {
1650 Location loc = srcOp.getLoc();
1651 VectorType vecType = createVectorType(laneSize / 2, scalarType);
1652
1653 auto lExtOp =
1654 rewriter.create<aievec::ExtOp>(loc, vecType, srcOp.getVector(), 0);
1655 auto rExtOp =
1656 rewriter.create<aievec::ExtOp>(loc, vecType, srcOp.getVector(), 1);
1657 auto addElemOp = rewriter.create<aievec::AddElemOp>(
1658 loc, lExtOp.getResult().getType(), lExtOp.getResult(),
1659 rExtOp.getResult());
1660 shiftIndex /= 2;
1661 generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
1662 rewriter, srcOp, shiftIndex, addElemOp.getResult());
1663 } else
1664 generateAIEVecOpsForReductionOp<aievec::AddElemOp>(
1665 rewriter, srcOp, shiftIndex, srcOp.getVector());
1666
1667 return success();
1668 }
1669};
1670
1672 : OpConversionPattern<vector::ReductionOp> {
1673 using OpConversionPattern::OpConversionPattern;
1674
1675 LogicalResult
1676 matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
1677 ConversionPatternRewriter &rewriter) const override {
1678 if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
1679 return failure();
1680
1681 auto vType = cast<VectorType>(srcOp.getVector().getType());
1682 Type scalarType = vType.getElementType();
1683 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1684 unsigned laneSize = getVectorLaneSize(vType);
1685
1686 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 32)
1687 return failure();
1688
1689 int shiftIndex = laneSize / 2;
1690 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
1691 "shiftIndex must be power of 2");
1692
1693 Location loc = srcOp.getLoc();
1694 Value curValue = srcOp.getVector();
1695 aievec::CastOp curOp = nullptr;
1696
1697 for (int id = shiftIndex; id > 0; id /= 2) {
1698 auto constOp = rewriter.create<arith::ConstantOp>(
1699 loc, rewriter.getI32IntegerAttr(id * elWidth / 8));
1700
1701 auto shiftBytesOp = rewriter.create<aievec::ShiftOp>(
1702 loc, vType, curValue, curValue, constOp.getResult());
1703
1704 auto lCastOp = rewriter.create<aievec::CastOp>(loc, vType, curValue,
1705 /*isResAcc*/ true);
1706 auto rCastOp =
1707 rewriter.create<aievec::CastOp>(loc, vType, shiftBytesOp.getResult(),
1708 /*isResAcc*/ true);
1709 auto elemOp = rewriter.create<aievec::AddElemOp>(
1710 loc, lCastOp.getResult().getType(), lCastOp.getResult(),
1711 rCastOp.getResult());
1712 curOp = rewriter.create<aievec::CastOp>(loc, vType, elemOp.getResult(),
1713 /*isResAcc*/ false);
1714 curValue = curOp.getResult();
1715 }
1716
1717 auto zeroConstOp =
1718 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1719 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, curOp,
1720 zeroConstOp.getResult());
1721 return success();
1722 }
1723};
1724
1726 : OpConversionPattern<vector::ReductionOp> {
1727 using OpConversionPattern::OpConversionPattern;
1728
1729 LogicalResult
1730 matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor,
1731 ConversionPatternRewriter &rewriter) const override {
1732 if (auto kind = srcOp.getKind(); kind != vector::CombiningKind::ADD)
1733 return failure();
1734
1735 auto vType = cast<VectorType>(srcOp.getVector().getType());
1736 Type scalarType = vType.getElementType();
1737 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
1738 unsigned laneSize = getVectorLaneSize(vType);
1739
1740 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 16)
1741 return failure();
1742
1743 int shiftIndex = laneSize / 2;
1744 assert(shiftIndex > 0 && (shiftIndex & (shiftIndex - 1)) == 0 &&
1745 "shiftIndex must be power of 2");
1746
1747 Value curValue = srcOp.getVector();
1748 Location loc = srcOp.getLoc();
1749 Type accType = getVectorOpDestType(vType, /*AIE2 =*/true);
1750 unsigned accWidth =
1751 dyn_cast<VectorType>(accType).getElementType().getIntOrFloatBitWidth();
1752
1753 auto upsOp =
1754 rewriter.create<aievec::UPSOp>(loc, accType, srcOp.getVector());
1755
1756 curValue = upsOp.getResult();
1757
1758 VectorType vecType = createVectorType(2 * laneSize, scalarType);
1759 aievec::AddElemOp curOp = nullptr;
1760
1761 for (int id = shiftIndex; id > 0; id /= 2) {
1762 auto constOp = rewriter.create<arith::ConstantOp>(
1763 loc, rewriter.getI32IntegerAttr(id * accWidth / 8));
1764 auto shiftBytesOp = rewriter.create<aievec::ShiftOp>(
1765 loc, accType, curValue, curValue, constOp, true);
1766 curOp = rewriter.create<aievec::AddElemOp>(loc, accType, curValue,
1767 shiftBytesOp.getResult());
1768 curValue = curOp.getResult();
1769 }
1770
1771 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1772 srcOp.getLoc(), rewriter.getI32IntegerAttr(0));
1773 auto srsOp = rewriter.create<aievec::SRSOp>(loc, vType, curOp.getResult(),
1774 shiftParamOp.getResult());
1775 SmallVector<Value> concatSources = {srsOp.getResult(), srsOp.getResult()};
1776 auto concatOp =
1777 rewriter.create<aievec::ConcatOp>(loc, vecType, concatSources);
1778
1779 auto zeroConstOp =
1780 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1781 rewriter.replaceOpWithNewOp<aievec::ExtElemOp>(srcOp, scalarType, concatOp,
1782 zeroConstOp.getResult());
1783 return success();
1784 }
1785};
1786
1787// Convert a `vector.extract_strided_slice` op on 1D vectors into an
1788// `aievec.select` + `aievec.ext` op.
1790 : OpConversionPattern<vector::ExtractStridedSliceOp> {
1791 using OpConversionPattern::OpConversionPattern;
1792
1793 LogicalResult
1794 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
1795 ConversionPatternRewriter &rewriter) const override {
1796 auto vType = extractOp.getSourceVectorType();
1797 if (vType.getRank() != 1)
1798 return failure();
1799
1800 int64_t stride = cast<IntegerAttr>(adaptor.getStrides()[0]).getInt();
1801 if (stride != 1)
1802 return failure();
1803
1804 // AIE doesn't support select operations on i8
1805 if (getElementSizeInBits(vType) == 8)
1806 return extractOp.emitError()
1807 << "AIEv1 doesn't support select ops on int8 types";
1808
1809 // We only accept the case where we are extracting a slice half the size of
1810 // the input vector.
1811 int64_t size = cast<IntegerAttr>(adaptor.getSizes()[0]).getInt();
1812 if (vType.getNumElements() != 2 * size)
1813 return failure();
1814
1815 int64_t offset = cast<IntegerAttr>(adaptor.getOffsets()[0]).getInt();
1816 auto selectOp = rewriter.create<aievec::aie1::SelectOp>(
1817 extractOp.getLoc(), vType, adaptor.getVector(),
1818 buildAttributeListForRotationSelectOp(rewriter, vType, offset));
1819 rewriter.replaceOpWithNewOp<aievec::aie1::ExtOp>(
1820 extractOp, extractOp.getType(), selectOp.getResult(),
1821 rewriter.getI8IntegerAttr(0));
1822 return success();
1823 }
1824};
1825
1826// Convert a `vector.extract_strided_slice` op on 1D vectors into an
1827// `aievec.shift` op.
1829 : OpConversionPattern<vector::ExtractStridedSliceOp> {
1830 using OpConversionPattern::OpConversionPattern;
1831
1832 LogicalResult
1833 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
1834 ConversionPatternRewriter &rewriter) const override {
1835 auto vType = cast<VectorType>(adaptor.getVector().getType());
1836 if (vType.getRank() != 1)
1837 return failure();
1838
1839 int64_t stride = cast<IntegerAttr>(adaptor.getStrides()[0]).getInt();
1840 if (stride != 1)
1841 return failure();
1842
1843 // We only accept the case where we are extracting a slice half the size of
1844 // the input vector.
1845 int64_t size = cast<IntegerAttr>(adaptor.getSizes()[0]).getInt();
1846 if (vType.getNumElements() != 2 * size)
1847 return failure();
1848
1849 auto shortVecType = cast<VectorType>(extractOp.getResult().getType());
1850 auto bottomHalf = rewriter
1851 .create<aievec::ExtOp>(
1852 extractOp.getLoc(), shortVecType,
1853 adaptor.getVector(), rewriter.getI8IntegerAttr(0))
1854 .getResult();
1855 auto topHalf = rewriter
1856 .create<aievec::ExtOp>(extractOp.getLoc(), shortVecType,
1857 adaptor.getVector(),
1858 rewriter.getI8IntegerAttr(1))
1859 .getResult();
1860 int64_t offset = cast<IntegerAttr>(adaptor.getOffsets()[0]).getInt();
1861 int32_t shiftBytes = offset * getElementSizeInBits(vType) / 8;
1862 auto shiftBytesConstOp = rewriter.create<arith::ConstantOp>(
1863 extractOp.getLoc(), rewriter.getIntegerType(32),
1864 rewriter.getI32IntegerAttr(shiftBytes));
1865 rewriter.replaceOpWithNewOp<aievec::ShiftOp>(
1866 extractOp, shortVecType, bottomHalf, topHalf, shiftBytesConstOp);
1867
1868 return success();
1869 }
1870};
1871
1872// Replaces a short UPD op with a wide one followed by an ext op of the bottom
1873// half.
1875 using OpConversionPattern::OpConversionPattern;
1876
// Constructs the pattern for `context` by forwarding it to the
// OpConversionPattern base-class constructor.
1877 ExpandUPDToUPDAndExtPattern(MLIRContext *context)
1878 : OpConversionPattern(context) {}
1879
1880 LogicalResult
1881 matchAndRewrite(aievec::UPDOp updOp, OpAdaptor adaptor,
1882 ConversionPatternRewriter &rewriter) const override {
1883 // Verify that we haven't already expanded this one
1884 if (updOp->hasOneUse() && isa<aievec::ExtOp>(*updOp->getUsers().begin()))
1885 return failure();
1886
1887 auto vecType = cast<VectorType>(updOp.getType());
1888 SmallVector<int64_t, 4> vecShape(vecType.getShape().begin(),
1889 vecType.getShape().end());
1890 vecShape[vecType.getRank() - 1] *= 2;
1891 auto longVecType = VectorType::get(vecShape, vecType.getElementType());
1892 auto newUpdOp = rewriter.create<aievec::UPDOp>(
1893 updOp.getLoc(), longVecType, adaptor.getSource(), adaptor.getIndices(),
1894 adaptor.getOffset(), adaptor.getIndex(), adaptor.getVector());
1895 rewriter.replaceOpWithNewOp<aievec::ExtOp>(
1896 updOp, vecType, newUpdOp.getResult(), rewriter.getI8IntegerAttr(0));
1897
1898 return success();
1899 }
1900};
1901
1902// Replaces a wide UPD op followed by an ext op of the bottom half with a short
1903// UPD op.
1905 using OpConversionPattern::OpConversionPattern;
1906
// Constructs the pattern for `context` by forwarding it to the
// OpConversionPattern base-class constructor.
1907 FuseExtIntoUPDPattern(MLIRContext *context) : OpConversionPattern(context) {}
1908
1909 LogicalResult
1910 matchAndRewrite(aievec::ExtOp extOp, OpAdaptor adaptor,
1911 ConversionPatternRewriter &rewriter) const override {
1912 // Verify we are extracting the lower half...
1913 if (extOp.getIndex() != 0)
1914 return failure();
1915 // ...of a UPDOp
1916 auto updOp = dyn_cast<aievec::UPDOp>(extOp.getSource().getDefiningOp());
1917 if (!updOp)
1918 return failure();
1919
1920 // Verify that this is a direct upd -> ext pattern
1921 if (!updOp->hasOneUse())
1922 return failure();
1923
1924 rewriter.replaceOpWithNewOp<aievec::UPDOp>(
1925 extOp, extOp.getType(), updOp.getSource(), updOp.getIndices(),
1926 updOp.getOffset(), updOp.getIndex(), updOp.getVector());
1927
1928 return success();
1929 }
1930};
1931
1933 using OpConversionPattern::OpConversionPattern;
1934
1935 LogicalResult
1936 matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor,
1937 ConversionPatternRewriter &rewriter) const override {
1938
1939 if (!matchExpOpForLUT(adaptor))
1940 return failure();
1941
1942 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
1943 StringRef funcName = "getExpBf16";
1944 auto moduleOp = expOp->getParentOfType<mlir::ModuleOp>();
1945
1946 VectorType v16bf16Ty = mlir::VectorType::get({16}, rewriter.getBF16Type());
1947 VectorType v8i64Ty = mlir::VectorType::get({8}, rewriter.getI64Type());
1948 func::FuncOp fnOp = getOrInsertFuncDecl(
1949 rewriter, moduleOp, funcName, TypeRange{v16bf16Ty}, TypeRange{v8i64Ty});
1950
1951 SmallVector<Value> expOperands = {adaptor.getOperand()};
1952
1953 Type accTypeNative = getVectorOpDestType(srcType, /*AIE2 =*/true);
1954 auto callOp =
1955 rewriter.create<func::CallOp>(expOp.getLoc(), fnOp, expOperands);
1956 auto resCastOp = rewriter.create<vector::BitCastOp>(
1957 expOp.getLoc(), accTypeNative, callOp.getResults());
1958 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
1959 expOp.getLoc(), rewriter.getI32IntegerAttr(0));
1960 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
1961 expOp, srcType, resCastOp.getResult(), shiftParamOp.getResult());
1962
1963 return success();
1964 }
1965};
1966// Lower ExpOp to function call
1968 using OpConversionPattern::OpConversionPattern;
1969
1970 LogicalResult
1971 matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor,
1972 ConversionPatternRewriter &rewriter) const override {
1973 if (!matchExpOpForLUT(adaptor))
1974 return failure();
1975 auto srcType = dyn_cast<VectorType>(adaptor.getOperand().getType());
1976 StringRef includeName = "lut_based_ops.h";
1977 auto moduleOp = expOp->getParentOfType<mlir::ModuleOp>();
1978 rewriter.setInsertionPointToStart(
1979 &moduleOp.getRegion().getBlocks().front());
1980 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
1981
1982 rewriter.setInsertionPoint(expOp);
1983
1984 auto v16bf16OpaqueTy =
1985 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
1986 auto opaquedOperand =
1987 rewriter
1988 .create<UnrealizedConversionCastOp>(expOp.getLoc(), v16bf16OpaqueTy,
1989 adaptor.getOperand())
1990 .getResult(0);
1991 SmallVector<Value> expOperands = {opaquedOperand};
1992
1993 Type accTypeNative = getVectorOpDestType(srcType, /*AIE2 =*/true);
1994 Type v16accf32OpaqueTy =
1995 emitc::OpaqueType::get(rewriter.getContext(), "v16accfloat");
1996 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
1997 expOp.getLoc(), TypeRange{v16accf32OpaqueTy}, "getExpBf16", nullptr,
1998 nullptr, expOperands);
1999 auto resCastOp = rewriter.create<UnrealizedConversionCastOp>(
2000 expOp.getLoc(), accTypeNative, callOp.getResults());
2001 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2002 expOp.getLoc(), rewriter.getI32IntegerAttr(0));
2003 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2004 expOp, srcType, resCastOp.getResult(0), shiftParamOp.getResult());
2005
2006 return success();
2007 }
2008};
2009
2010// Lower the inverse of a float to a function call
2011// Convert the pattern-
2012// %cst = arith.constant 1.000000e+00 : f32
2013// %0 = arith.divf %cst, %arg1 : f32
2014// %1 = arith.truncf %0 : f32 to bf16
2015// to -
2016// %0 = emitc.call "getInvBf16"(%0) : f32 -> bf16;
2018 using OpConversionPattern::OpConversionPattern;
2019
2020 LogicalResult
2021 matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor,
2022 ConversionPatternRewriter &rewriter) const override {
2023 Type srcType = adaptor.getLhs().getType();
2024 if (!divOp->hasOneUse() || isa<VectorType>(srcType) ||
2025 !isa<FloatType>(srcType))
2026 return failure();
2027
2028 if (!isNarrowingOp(*divOp->getUsers().begin()))
2029 return failure();
2030
2031 auto fType = cast<FloatType>(srcType);
2032 if (fType.getWidth() != 32)
2033 return failure();
2034
2035 auto constOp = dyn_cast<arith::ConstantOp>(divOp.getLhs().getDefiningOp());
2036 if (!constOp ||
2037 cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
2038 1.0f)
2039 return failure();
2040
2041 StringRef includeName = "lut_based_ops.h";
2042 auto moduleOp = divOp->getParentOfType<mlir::ModuleOp>();
2043 rewriter.setInsertionPointToStart(
2044 &moduleOp.getRegion().getBlocks().front());
2045 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2046
2047 auto truncOp = cast<arith::TruncFOp>(*divOp->getUsers().begin());
2048
2049 rewriter.setInsertionPoint(truncOp);
2050 Type bf16OpaqueTy =
2051 emitc::OpaqueType::get(rewriter.getContext(), "bfloat16");
2052 SmallVector<Value> invOperands = {adaptor.getRhs()};
2053 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2054 truncOp.getLoc(), bf16OpaqueTy, "getInvBf16", nullptr, nullptr,
2055 invOperands);
2056 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2057 truncOp, TypeRange{truncOp.getResult().getType()}, callOp.getResults());
2058 rewriter.eraseOp(divOp);
2059
2060 return success();
2061 }
2062};
2063
2064// Convert math.tanh to a function call to compute tanh(x) by look up tables
2066 using OpConversionPattern::OpConversionPattern;
2067
2068 LogicalResult
2069 matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor,
2070 ConversionPatternRewriter &rewriter) const override {
2071 auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
2072 if (!srcType)
2073 return failure();
2074
2075 Type scalarType = srcType.getElementType();
2076 if (!isa<FloatType>(scalarType))
2077 return failure();
2078
2079 unsigned laneSize = getVectorLaneSize(srcType);
2080 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2081 if (elWidth != 16 || laneSize != 16)
2082 return failure();
2083
2084 StringRef includeName = "lut_based_ops.h";
2085 auto moduleOp = tanhOp->getParentOfType<mlir::ModuleOp>();
2086 rewriter.setInsertionPointToStart(
2087 &moduleOp.getRegion().getBlocks().front());
2088 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2089
2090 rewriter.setInsertionPoint(tanhOp);
2091 Type v16bf16OpaqueTy =
2092 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2093 auto opaquedOperand =
2094 rewriter
2095 .create<UnrealizedConversionCastOp>(
2096 tanhOp.getLoc(), v16bf16OpaqueTy, adaptor.getOperand())
2097 .getResult(0);
2098 SmallVector<Value> tanhOperands = {opaquedOperand};
2099 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2100 tanhOp.getLoc(), v16bf16OpaqueTy, "getTanhBf16", nullptr, nullptr,
2101 tanhOperands);
2102 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2103 tanhOp, TypeRange{tanhOp.getResult().getType()}, callOp.getResults());
2104
2105 return success();
2106 }
2107};
2108
2109// Convert math.sqrt to a function call to compute sqrt(x) for v16bfloat16 and
2110// v32bfloat16 types
2112 using OpConversionPattern::OpConversionPattern;
2113
2114 LogicalResult
2115 matchAndRewrite(math::SqrtOp sqrtOp, OpAdaptor adaptor,
2116 ConversionPatternRewriter &rewriter) const override {
2117 auto srcType = dyn_cast<VectorType>(sqrtOp.getOperand().getType());
2118 if (!srcType)
2119 return failure();
2120
2121 Type scalarType = srcType.getElementType();
2122 if (!isa<FloatType>(scalarType))
2123 return failure();
2124
2125 unsigned laneSize = getVectorLaneSize(srcType);
2126 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2127 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2128 return failure();
2129
2130 StringRef includeName = "vec_math.h";
2131 auto moduleOp = sqrtOp->getParentOfType<mlir::ModuleOp>();
2132 rewriter.setInsertionPointToStart(
2133 &moduleOp.getRegion().getBlocks().front());
2134 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2135
2136 rewriter.setInsertionPoint(sqrtOp);
2137 Type vLNbf16OpaqueTy;
2138 if (laneSize == 16)
2139 vLNbf16OpaqueTy =
2140 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2141 else
2142 vLNbf16OpaqueTy =
2143 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
2144 auto opaquedOperand =
2145 rewriter
2146 .create<UnrealizedConversionCastOp>(
2147 sqrtOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
2148 .getResult(0);
2149 SmallVector<Value> sqrtOperands = {opaquedOperand};
2150 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2151 sqrtOp.getLoc(), TypeRange{vLNbf16OpaqueTy}, "getSqrtBf16", nullptr,
2152 nullptr, sqrtOperands);
2153 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2154 sqrtOp, TypeRange{sqrtOp.getResult().getType()}, callOp.getResults());
2155
2156 return success();
2157 }
2158};
2159
2160// Convert math.rsqrt to a function call to compute 1.0f / sqrt(x) for
2161// v16bfloat16 and v32bfloat16 types
2163 using OpConversionPattern::OpConversionPattern;
2164
2165 LogicalResult
2166 matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor,
2167 ConversionPatternRewriter &rewriter) const override {
2168 auto srcType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
2169 if (!srcType)
2170 return failure();
2171
2172 Type scalarType = srcType.getElementType();
2173 if (!isa<FloatType>(scalarType))
2174 return failure();
2175
2176 unsigned laneSize = getVectorLaneSize(srcType);
2177 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2178 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2179 return failure();
2180
2181 StringRef includeName = "vec_math.h";
2182 auto moduleOp = rsqrtOp->getParentOfType<mlir::ModuleOp>();
2183 rewriter.setInsertionPointToStart(
2184 &moduleOp.getRegion().getBlocks().front());
2185 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2186
2187 rewriter.setInsertionPoint(rsqrtOp);
2188 Type vLNbf16OpaqueTy;
2189 if (laneSize == 16)
2190 vLNbf16OpaqueTy =
2191 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2192 else
2193 vLNbf16OpaqueTy =
2194 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
2195 auto opaquedOperand =
2196 rewriter
2197 .create<UnrealizedConversionCastOp>(
2198 rsqrtOp.getLoc(), vLNbf16OpaqueTy, adaptor.getOperand())
2199 .getResult(0);
2200 SmallVector<Value> rsqrtOperands = {opaquedOperand};
2201 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2202 rsqrtOp.getLoc(), TypeRange{vLNbf16OpaqueTy}, "getRsqrtBf16", nullptr,
2203 nullptr, rsqrtOperands);
2204 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2205 rsqrtOp, TypeRange{rsqrtOp.getResult().getType()}, callOp.getResults());
2206
2207 return success();
2208 }
2209};
2210
2211// Convert math.erf to a function call to compute erf(x) for v16bfloat16 and
2212// v32bfloat16 types
2214 using OpConversionPattern::OpConversionPattern;
2215
2216 LogicalResult
2217 matchAndRewrite(math::ErfOp erfOp, OpAdaptor adaptor,
2218 ConversionPatternRewriter &rewriter) const override {
2219 auto srcType = dyn_cast<VectorType>(erfOp.getOperand().getType());
2220 if (!srcType)
2221 return failure();
2222
2223 Type scalarType = srcType.getElementType();
2224 if (!isa<FloatType>(scalarType))
2225 return failure();
2226
2227 unsigned laneSize = getVectorLaneSize(srcType);
2228 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2229 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2230 return failure();
2231
2232 StringRef includeName = "vec_math.h";
2233 auto moduleOp = erfOp->getParentOfType<mlir::ModuleOp>();
2234 rewriter.setInsertionPointToStart(
2235 &moduleOp.getRegion().getBlocks().front());
2236 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2237
2238 rewriter.setInsertionPoint(erfOp);
2239 Type vLNbf16OpaqueTy;
2240 if (laneSize == 16)
2241 vLNbf16OpaqueTy =
2242 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2243 else
2244 vLNbf16OpaqueTy =
2245 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
2246 auto opaquedOperand =
2247 rewriter
2248 .create<UnrealizedConversionCastOp>(erfOp.getLoc(), vLNbf16OpaqueTy,
2249 adaptor.getOperand())
2250 .getResult(0);
2251 SmallVector<Value> erfOperands = {opaquedOperand};
2252 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2253 erfOp.getLoc(), TypeRange{vLNbf16OpaqueTy}, "getErfBf16", nullptr,
2254 nullptr, erfOperands);
2255 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2256 erfOp, TypeRange{erfOp.getResult().getType()}, callOp.getResults());
2257
2258 return success();
2259 }
2260};
2261
2262// Convert math.absf and math.absi to a function call to compute abs(x) for
2263// v16bfloat16, v32bfloat16, v16float, v16int32, v32int16 and v64int8 types
2264template <typename SrcOpTy>
2267 using OpAdaptor = typename SrcOpTy::Adaptor;
2268
2269 LogicalResult
2270 matchAndRewrite(SrcOpTy absOp, OpAdaptor adaptor,
2271 ConversionPatternRewriter &rewriter) const override {
2272 auto vecTy = dyn_cast<VectorType>(absOp.getOperand().getType());
2273 if (!vecTy)
2274 return failure();
2275
2276 Type elemTy = vecTy.getElementType();
2277
2278 unsigned laneSize = getVectorLaneSize(vecTy);
2279 unsigned elWidth = elemTy.getIntOrFloatBitWidth();
2280
2281 StringRef includeName = "vec_math.h";
2282 auto moduleOp = absOp->template getParentOfType<mlir::ModuleOp>();
2283 rewriter.setInsertionPointToStart(
2284 &moduleOp.getRegion().getBlocks().front());
2285 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2286
2287 rewriter.setInsertionPoint(absOp);
2288 std::ostringstream typeName;
2289 typeName << "v" << laneSize;
2290 if (isa<FloatType>(elemTy)) {
2291 if (elWidth == 16)
2292 typeName << "bfloat16";
2293 else
2294 typeName << "float";
2295 } else
2296 typeName << "int" << elWidth;
2297 Type vecOpaqueTy =
2298 emitc::OpaqueType::get(rewriter.getContext(), typeName.str());
2299 auto opaquedOperand =
2300 rewriter
2301 .create<UnrealizedConversionCastOp>(absOp.getLoc(), vecOpaqueTy,
2302 adaptor.getOperand())
2303 .getResult(0);
2304 SmallVector<Value> absOperands = {opaquedOperand};
2305 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2306 absOp.getLoc(), TypeRange{vecOpaqueTy}, "getAbs", nullptr, nullptr,
2307 absOperands);
2308 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2309 absOp, TypeRange{absOp.getResult().getType()}, callOp.getResults());
2310
2311 return success();
2312 }
2313};
2314
2317
2318template <typename SrcOpTy>
2321 using OpAdaptor = typename SrcOpTy::Adaptor;
2322
2323 LogicalResult
2324 matchAndRewrite(SrcOpTy extOp, OpAdaptor adaptor,
2325 ConversionPatternRewriter &rewriter) const override {
2326 VectorType srcType = dyn_cast<VectorType>(extOp.getIn().getType());
2327 VectorType dstType = dyn_cast<VectorType>(extOp.getOut().getType());
2328
2329 auto accType = getVectorOpDestType(srcType, /*AIE2 =*/true);
2330 auto upsOp =
2331 rewriter.create<aievec::UPSOp>(extOp.getLoc(), accType, extOp.getIn());
2332
2333 if (dstType.getElementType().getIntOrFloatBitWidth() == 16) {
2334 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2335 extOp.getLoc(), rewriter.getI32IntegerAttr(0));
2336 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2337 extOp, dstType, upsOp.getResult(), shiftParamOp.getResult());
2338 } else
2339 rewriter.replaceOpWithNewOp<aievec::CastOp>(
2340 extOp, dstType, upsOp.getResult(), /*isResAcc*/ false);
2341
2342 return success();
2343 }
2344};
2345
2348
2349template <typename SrcOpTy>
2352 using OpAdaptor = typename SrcOpTy::Adaptor;
2353
2354 LogicalResult
2355 matchAndRewrite(SrcOpTy truncOp, OpAdaptor adaptor,
2356 ConversionPatternRewriter &rewriter) const override {
2357 VectorType srcType = dyn_cast<VectorType>(truncOp.getIn().getType());
2358 VectorType dstType = dyn_cast<VectorType>(truncOp.getOut().getType());
2359 Type scalarType = srcType.getElementType();
2360 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2361
2362 unsigned laneSize = getVectorLaneSize(srcType);
2363 auto accType = isa<IntegerType>(scalarType) && (elWidth == 32)
2364 ? createVectorType(laneSize, scalarType)
2365 : getVectorOpDestType(srcType, /*AIE2 =*/true);
2366
2367 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2368 truncOp.getLoc(), rewriter.getI32IntegerAttr(0));
2369 if (elWidth == 16) {
2370 auto upsOp = rewriter.create<aievec::UPSOp>(truncOp.getLoc(), accType,
2371 truncOp.getIn());
2372 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2373 truncOp, dstType, upsOp.getResult(), shiftParamOp.getResult());
2374 } else {
2375 auto castOp = rewriter.create<aievec::CastOp>(truncOp.getLoc(), accType,
2376 truncOp.getIn(), true);
2377 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2378 truncOp, dstType, castOp.getResult(), shiftParamOp.getResult());
2379 }
2380
2381 return success();
2382 }
2383};
2384
2387
2388// If `op` is the last operation in the sequence:
2389// %0 = unrealized_conversion_cast <%IN> : <native type>, !emitc.opaque_type
2390// %1 = emitc.call_opaque <funcName>, %0...
2391// %2 = unrealized_conversion_cast %1 : !emitc.opaque_type, <native type>
2392// return the value <%IN>.
2393static std::optional<Value>
2394getUnOpaquedOperandOfEmitCOpaqueCallOp(Operation *op, StringRef funcName) {
2395 auto uccOp = dyn_cast<UnrealizedConversionCastOp>(op);
2396 if (!uccOp)
2397 return {};
2398
2399 auto inVal = uccOp.getInputs()[0];
2400 if (!isa<emitc::OpaqueType>(inVal.getType()))
2401 return {};
2402
2403 auto callOp = inVal.getDefiningOp<emitc::CallOpaqueOp>();
2404 if (callOp.getCallee() != funcName)
2405 return {};
2406
2407 auto callOperandsUccOp =
2408 callOp.getOperands()[0].getDefiningOp<UnrealizedConversionCastOp>();
2409 if (!callOperandsUccOp)
2410 return {};
2411
2412 return callOperandsUccOp.getInputs()[0];
2413}
2414
2415// Check there is an operation chain like-
2416//
2417// %cst_0 = arith.constant dense<1.000000e+00> : vector<16xbf16>
2418// %cst_1 = arith.constant 0.000000e+00 : bf16
2419// %0 = vector.transfer_read %arg0[%arg2], %cst_1 : memref<1024xbf16>,
2420// vector<16xbf16>
2421// %1 = arith.negf %0 : vector<16xbf16>
2422// %2 = math.exp %1 : vector<16xbf16>
2423// %3 = arith.addf %2, %cst_0 : vector<16xbf16>
2424// %4 = arith.divf %cst_0, %3 : vector<16xbf16>
2425//
2426// so that this operation chain can be converted to a function call to compute
2427// sigmoid value for v16bfloat16 and v32bfloat16 types
2428template <typename DivFOpTy>
2429static bool hasSigmoidComputationChain(DivFOpTy divfOp, arith::NegFOp &negOp) {
2430 auto constOp = dyn_cast<arith::ConstantOp>(divfOp.getLhs().getDefiningOp());
2431 if (!constOp)
2432 return false;
2433
2434 auto cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
2435 if (!cstDense)
2436 return false;
2437
2438 if (cstDense.template getSplatValue<APFloat>().convertToFloat() != 1.0f)
2439 return false;
2440
2441 Operation *addLvalOp;
2442 Operation *addRvalOp;
2443 // divfOp's rval could be an arith::AddFOp or the pattern like-
2444 // %1 = aievec.ups %a
2445 // %2 = aievec.ups %b;
2446 // %3 = aievec.add_elem %1, %2
2447 // %4 = aievec.srs %3;
2448 auto addOp = dyn_cast<arith::AddFOp>(divfOp.getRhs().getDefiningOp());
2449 if (!addOp) {
2450 auto srsOp = dyn_cast<aievec::SRSOp>(divfOp.getRhs().getDefiningOp());
2451 if (!srsOp)
2452 return false;
2453
2454 auto addElemOp =
2455 dyn_cast<aievec::AddElemOp>(srsOp.getSource().getDefiningOp());
2456 if (!addElemOp)
2457 return false;
2458
2459 auto lUpsOp = dyn_cast<aievec::UPSOp>(addElemOp.getLhs().getDefiningOp());
2460 auto rUpsOp = dyn_cast<aievec::UPSOp>(addElemOp.getRhs().getDefiningOp());
2461 if (!lUpsOp || !rUpsOp)
2462 return false;
2463
2464 addLvalOp = lUpsOp.getSource().getDefiningOp();
2465 addRvalOp = rUpsOp.getSource().getDefiningOp();
2466 // One of add operation's operand is a constant op and another operand could
2467 // be arith::ExpOp or the combination of emitc.call and aievec.srs
2468 auto addDefOp = isa<arith::ConstantOp>(addLvalOp)
2469 ? dyn_cast<aievec::SRSOp>(addRvalOp)
2470 : dyn_cast<aievec::SRSOp>(addLvalOp);
2471 if (!addDefOp)
2472 addLvalOp = isa<arith::ConstantOp>(addLvalOp)
2473 ? dyn_cast<math::ExpOp>(addRvalOp)
2474 : dyn_cast<math::ExpOp>(addLvalOp);
2475 else
2476 addLvalOp = addDefOp.getSource().getDefiningOp();
2477
2478 addRvalOp = isa<arith::ConstantOp>(addLvalOp)
2479 ? lUpsOp.getSource().getDefiningOp()
2480 : rUpsOp.getSource().getDefiningOp();
2481 } else {
2482 addLvalOp = addOp.getLhs().getDefiningOp();
2483 addRvalOp = addOp.getRhs().getDefiningOp();
2484 }
2485
2486 if (!addLvalOp || !addRvalOp)
2487 return false;
2488
2489 auto addLvalExpOp = dyn_cast<math::ExpOp>(addLvalOp);
2490 auto addRvalExpOp = dyn_cast<math::ExpOp>(addRvalOp);
2491 auto addLvalExpOpIn =
2492 getUnOpaquedOperandOfEmitCOpaqueCallOp(addLvalOp, "getExpBf16")
2493 .value_or(nullptr);
2494 auto addRvalExpOpIn =
2495 getUnOpaquedOperandOfEmitCOpaqueCallOp(addRvalOp, "getExpBf16")
2496 .value_or(nullptr);
2497 if (!addLvalExpOpIn && addLvalExpOp)
2498 addLvalExpOpIn = addLvalExpOp.getOperand();
2499 if (!addRvalExpOpIn && addRvalExpOp)
2500 addRvalExpOpIn = addRvalExpOp.getOperand();
2501
2502 if (!((addLvalExpOpIn && isa<arith::ConstantOp>(addRvalOp)) ||
2503 (addRvalExpOpIn && isa<arith::ConstantOp>(addLvalOp))))
2504 return false;
2505
2506 constOp = isa<arith::ConstantOp>(addLvalOp)
2507 ? cast<arith::ConstantOp>(addLvalOp)
2508 : cast<arith::ConstantOp>(addRvalOp);
2509
2510 cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
2511 if (!cstDense)
2512 return false;
2513 if (cstDense.template getSplatValue<APFloat>().convertToFloat() != 1.0f)
2514 return false;
2515
2516 auto expOperand = addLvalExpOpIn ? addLvalExpOpIn : addRvalExpOpIn;
2517
2518 negOp = expOperand.getDefiningOp<arith::NegFOp>();
2519
2520 return negOp != nullptr;
2521}
2522
2523// Convert the operation chain like-
2524//
2525// %cst_0 = arith.constant dense<1.000000e+00> : vector<16xbf16>
2526// %cst_1 = arith.constant 0.000000e+00 : bf16
2527// %0 = vector.transfer_read %arg0[%arg2], %cst_1 : memref<1024xbf16>,
2528// vector<16xbf16>
2529// %1 = arith.negf %0 : vector<16xbf16>
2530// %2 = math.exp %1 :vector<16xbf16>
2531// %3 = arith.addf %2, %cst_0 : vector<16xbf16>
2532// %4 = arith.divf %cst_0, %3 : vector<16xbf16>
2533//
2534// to a function call to compute sigmoid value for v16bfloat16 and
2535// v32bfloat16 types
2537 using OpConversionPattern::OpConversionPattern;
2538
2539 LogicalResult
2540 matchAndRewrite(arith::DivFOp divfOp, OpAdaptor adaptor,
2541 ConversionPatternRewriter &rewriter) const override {
2542 auto srcType = dyn_cast<VectorType>(adaptor.getLhs().getType());
2543 if (!srcType)
2544 return failure();
2545
2546 Type scalarType = srcType.getElementType();
2547 if (!isa<FloatType>(scalarType))
2548 return failure();
2549
2550 unsigned laneSize = getVectorLaneSize(srcType);
2551 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2552 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2553 return failure();
2554
2555 arith::NegFOp negOp = nullptr;
2556 if (!hasSigmoidComputationChain(adaptor, negOp))
2557 return failure();
2558
2559 StringRef includeName = "vec_math.h";
2560 auto moduleOp = divfOp->getParentOfType<mlir::ModuleOp>();
2561 rewriter.setInsertionPointToStart(
2562 &moduleOp.getRegion().getBlocks().front());
2563 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2564
2565 rewriter.setInsertionPoint(divfOp);
2566 Type vecOpaqueTy;
2567 if (laneSize == 16)
2568 vecOpaqueTy =
2569 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2570 else
2571 vecOpaqueTy =
2572 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
2573 auto opaquedOperand =
2574 rewriter
2575 .create<UnrealizedConversionCastOp>(divfOp.getLoc(), vecOpaqueTy,
2576 negOp.getOperand())
2577 .getResult(0);
2578 SmallVector<Value> sigmoidOperands = {opaquedOperand};
2579 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2580 divfOp.getLoc(), TypeRange{vecOpaqueTy}, "getSigmoidBf16", nullptr,
2581 nullptr, sigmoidOperands);
2582 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2583 divfOp, TypeRange{adaptor.getLhs().getType()}, callOp.getResults());
2584
2585 return success();
2586 }
2587};
2588
2589 // Convert math.ceil to a function call to compute ceil(x) for v16bfloat16 and v32bfloat16 types
2591 using OpConversionPattern::OpConversionPattern;
2592
2593 LogicalResult
2594 matchAndRewrite(math::CeilOp ceilOp, OpAdaptor adaptor,
2595 ConversionPatternRewriter &rewriter) const override {
2596 auto srcType = dyn_cast<VectorType>(ceilOp.getOperand().getType());
2597 if (!srcType)
2598 return failure();
2599
2600 Type scalarType = srcType.getElementType();
2601 if (!isa<FloatType>(scalarType))
2602 return failure();
2603
2604 unsigned laneSize = getVectorLaneSize(srcType);
2605 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2606 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2607 return failure();
2608
2609 StringRef includeName = "vec_math.h";
2610 auto moduleOp = ceilOp->getParentOfType<mlir::ModuleOp>();
2611 rewriter.setInsertionPointToStart(
2612 &moduleOp.getRegion().getBlocks().front());
2613 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2614
2615 rewriter.setInsertionPoint(ceilOp);
2616 Type vecOpaqueTy;
2617 if (laneSize == 16)
2618 vecOpaqueTy =
2619 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2620 else
2621 vecOpaqueTy =
2622 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
2623 auto opaquedOperand =
2624 rewriter
2625 .create<UnrealizedConversionCastOp>(ceilOp.getLoc(), vecOpaqueTy,
2626 adaptor.getOperand())
2627 .getResult(0);
2628 SmallVector<Value> ceilOperands = {opaquedOperand};
2629 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2630 ceilOp.getLoc(), TypeRange{vecOpaqueTy}, "getCeilBf16", nullptr,
2631 nullptr, ceilOperands);
2632 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2633 ceilOp, TypeRange{ceilOp.getResult().getType()}, callOp.getResults());
2634
2635 return success();
2636 }
2637};
2638
2639// Convert math.floor to a function call to compute floor(x) for v16bfloat16 and v32bfloat16
2641 using OpConversionPattern::OpConversionPattern;
2642
2643 LogicalResult
2644 matchAndRewrite(math::FloorOp floorOp, OpAdaptor adaptor,
2645 ConversionPatternRewriter &rewriter) const override {
2646 auto srcType = dyn_cast<VectorType>(floorOp.getOperand().getType());
2647 if (!srcType)
2648 return failure();
2649
2650 Type scalarType = srcType.getElementType();
2651 if (!isa<FloatType>(scalarType))
2652 return failure();
2653
2654 unsigned laneSize = getVectorLaneSize(srcType);
2655 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2656 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
2657 return failure();
2658
2659 StringRef includeName = "vec_math.h";
2660 auto moduleOp = floorOp->getParentOfType<mlir::ModuleOp>();
2661 rewriter.setInsertionPointToStart(
2662 &moduleOp.getRegion().getBlocks().front());
2663 rewriter.create<emitc::IncludeOp>(moduleOp.getLoc(), includeName, false);
2664
2665 rewriter.setInsertionPoint(floorOp);
2666 Type vecOpaqueTy;
2667 if (laneSize == 16)
2668 vecOpaqueTy =
2669 emitc::OpaqueType::get(rewriter.getContext(), "v16bfloat16");
2670 else
2671 vecOpaqueTy =
2672 emitc::OpaqueType::get(rewriter.getContext(), "v32bfloat16");
2673 auto opaquedOperand =
2674 rewriter
2675 .create<UnrealizedConversionCastOp>(floorOp.getLoc(), vecOpaqueTy,
2676 adaptor.getOperand())
2677 .getResult(0);
2678 SmallVector<Value> floorOperands = {opaquedOperand};
2679 auto callOp = rewriter.create<emitc::CallOpaqueOp>(
2680 floorOp.getLoc(), TypeRange{vecOpaqueTy}, "getFloorBf16", nullptr,
2681 nullptr, floorOperands);
2682 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
2683 floorOp, TypeRange{floorOp.getResult().getType()}, callOp.getResults());
2684
2685 return success();
2686 }
2687};
2688
2689// Convert arith.negf to aievec.neg to negate the vector for v16bfloat16 and
2690// v16float types.
2692 using OpConversionPattern::OpConversionPattern;
2693
2694 LogicalResult
2695 matchAndRewrite(arith::NegFOp negOp, OpAdaptor adaptor,
2696 ConversionPatternRewriter &rewriter) const override {
2697 auto srcType = dyn_cast<VectorType>(negOp.getOperand().getType());
2698 if (!srcType)
2699 return failure();
2700
2701 Type scalarType = srcType.getElementType();
2702 if (!isa<FloatType>(scalarType))
2703 return failure();
2704
2705 if (unsigned laneSize = getVectorLaneSize(srcType); laneSize != 16)
2706 return failure();
2707
2708 Location loc = negOp.getLoc();
2709 auto accType = getVectorOpDestType(srcType, /*AIE2 =*/true);
2710
2711 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2712 if (elWidth == 16) {
2713 auto upsOp =
2714 rewriter.create<aievec::UPSOp>(loc, accType, adaptor.getOperand());
2715 auto aieNegOp =
2716 rewriter.create<aievec::NegOp>(loc, accType, upsOp.getResult());
2717 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
2718 negOp.getLoc(), rewriter.getI32IntegerAttr(0));
2719 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2720 negOp, srcType, aieNegOp.getResult(), shiftParamOp.getResult());
2721 } else {
2722 auto castOp = rewriter.create<aievec::CastOp>(
2723 loc, accType, adaptor.getOperand(), /*isResAcc*/ true);
2724 auto aieNegOp =
2725 rewriter.create<aievec::NegOp>(loc, accType, castOp.getResult());
2726 rewriter.replaceOpWithNewOp<aievec::CastOp>(
2727 negOp, srcType, aieNegOp.getResult(), /*isResAcc*/ false);
2728 }
2729
2730 return success();
2731 }
2732};
2733
2734// Check whether the value of constant operation is int type and the dense value
2735// is -1.
2736static bool hasConstNegOneValue(arith::ConstantOp constOp, unsigned elWidth) {
2737 if (!constOp)
2738 return false;
2739
2740 auto cstDense = dyn_cast<DenseIntElementsAttr>(constOp.getValue());
2741 if (!cstDense)
2742 return false;
2743
2744 if (elWidth == 32)
2745 return cstDense.getSplatValue<int32_t>() == -1;
2746 if (elWidth == 16)
2747 return cstDense.getSplatValue<int16_t>() == -1;
2748 if (elWidth == 8)
2749 return cstDense.getSplatValue<int8_t>() == -1;
2750 return false;
2751}
2752
2753// Convert arith.xori to aievec.bxor to compute bitwise xor of two vectors for
2754// integer types
2756 using OpConversionPattern::OpConversionPattern;
2757
2758 LogicalResult
2759 matchAndRewrite(arith::XOrIOp xorOp, OpAdaptor adaptor,
2760 ConversionPatternRewriter &rewriter) const override {
2761 auto srcType = dyn_cast<VectorType>(xorOp.getLhs().getType());
2762 if (!srcType)
2763 return failure();
2764
2765 Type scalarType = srcType.getElementType();
2766 if (!isa<IntegerType>(scalarType))
2767 return failure();
2768
2769 unsigned laneSize = getVectorLaneSize(srcType);
2770 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2771 if (laneSize * elWidth != 512)
2772 return failure();
2773
2774 auto lhsConstOp =
2775 dyn_cast<arith::ConstantOp>(xorOp.getLhs().getDefiningOp());
2776 auto rhsConstOp =
2777 dyn_cast<arith::ConstantOp>(xorOp.getRhs().getDefiningOp());
2778
2779 // If one of operands in xorOp is a constant -1, xorOp will be replaced with
2780 // aievec::BnegOp.
2781 if ((lhsConstOp && hasConstNegOneValue(lhsConstOp, elWidth)) ||
2782 (rhsConstOp && hasConstNegOneValue(rhsConstOp, elWidth))) {
2783 Value val = hasConstNegOneValue(lhsConstOp, elWidth) ? adaptor.getRhs()
2784 : adaptor.getLhs();
2785 rewriter.replaceOpWithNewOp<aievec::BnegOp>(xorOp, srcType, val);
2786 } else
2787 rewriter.replaceOpWithNewOp<aievec::BxorOp>(
2788 xorOp, srcType, adaptor.getLhs(), adaptor.getRhs());
2789
2790 return success();
2791 }
2792};
2793
2794template <typename SrcOpTy, typename DstOpTy>
2797 using OpAdaptor = typename SrcOpTy::Adaptor;
2798
2799 LogicalResult
2800 matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor,
2801 ConversionPatternRewriter &rewriter) const override {
2802 VectorType srcType = dyn_cast<VectorType>(srcOp.getLhs().getType());
2803 if (!srcType)
2804 return failure();
2805
2806 Type scalarType = srcType.getElementType();
2807 if (!isa<IntegerType>(scalarType))
2808 return failure();
2809
2810 unsigned laneSize = getVectorLaneSize(srcType);
2811 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2812 if (laneSize * elWidth != 512)
2813 return failure();
2814
2815 rewriter.replaceOpWithNewOp<DstOpTy>(srcOp, srcOp.getResult().getType(),
2816 adaptor.getLhs(), adaptor.getRhs());
2817
2818 return success();
2819 }
2820};
2821
2826
2827// Convert arith.shrsi to a combination of aievec.ups and aievec.srs to compute
2828// arithmetic right shift for integer types. Currently, only support the shift
2829// value with a broadcast vector.
2831 : OpConversionPattern<arith::ShRSIOp> {
2832 using OpConversionPattern::OpConversionPattern;
2833
2834 LogicalResult
2835 matchAndRewrite(arith::ShRSIOp rsOp, OpAdaptor adaptor,
2836 ConversionPatternRewriter &rewriter) const override {
2837 auto srcType = dyn_cast<VectorType>(adaptor.getLhs().getType());
2838 if (!srcType)
2839 return failure();
2840
2841 Type scalarType = srcType.getElementType();
2842 unsigned laneSize = getVectorLaneSize(srcType);
2843 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
2844 if (laneSize * elWidth != 512)
2845 return failure();
2846
2847 auto bcastOp =
2848 dyn_cast<aievec::BroadcastOp>(adaptor.getRhs().getDefiningOp());
2849 if (!bcastOp)
2850 return failure();
2851
2852 auto constOp = rewriter.create<arith::ConstantOp>(
2853 bcastOp.getLoc(), rewriter.getI32IntegerAttr(bcastOp.getIdx()));
2854 auto extElemOp = rewriter.create<aievec::ExtElemOp>(
2855 bcastOp.getLoc(), scalarType, bcastOp, constOp.getResult());
2856 Location loc = rsOp.getLoc();
2857
2858 // The vector with v64int8 type can be divided into two v32int8 vectors and
2859 // be processed individually and be concatenated at the end.
2860 if (elWidth == 8) {
2861 VectorType halfSrcType = createVectorType(laneSize / 2, scalarType);
2862 auto rsOpLow =
2863 rewriter.create<aievec::ExtOp>(loc, halfSrcType, adaptor.getLhs(), 0);
2864 auto rsOpHigh =
2865 rewriter.create<aievec::ExtOp>(loc, halfSrcType, adaptor.getLhs(), 1);
2866 Type accType = getVectorOpDestType(halfSrcType, /*AIE2 =*/true);
2867 auto upsOpLow =
2868 rewriter.create<aievec::UPSOp>(loc, accType, rsOpLow.getResult());
2869 auto srsOpLow = rewriter.create<aievec::SRSOp>(
2870 loc, halfSrcType, upsOpLow.getResult(), extElemOp.getResult());
2871 auto upsOpHigh =
2872 rewriter.create<aievec::UPSOp>(loc, accType, rsOpHigh.getResult());
2873 auto srsOpHigh = rewriter.create<aievec::SRSOp>(
2874 loc, halfSrcType, upsOpHigh.getResult(), extElemOp.getResult());
2875 SmallVector<Value> inputSources = {srsOpLow.getResult(),
2876 srsOpHigh.getResult()};
2877 rewriter.replaceOpWithNewOp<aievec::ConcatOp>(rsOp, srcType,
2878 inputSources);
2879 } else {
2880 Type accType = getVectorOpDestType(srcType, /*AIE2 =*/true);
2881 auto upsOp =
2882 rewriter.create<aievec::UPSOp>(loc, accType, adaptor.getLhs());
2883 rewriter.replaceOpWithNewOp<aievec::SRSOp>(
2884 rsOp, srcType, upsOp.getResult(), extElemOp.getResult());
2885 }
2886
2887 return success();
2888 }
2889};
2890
2891// Convert a `vector.contract` op to an `aievec.matmul` op for AIE2
2893 : OpConversionPattern<vector::ContractionOp> {
2894 using OpConversionPattern::OpConversionPattern;
2895
2899
2900 Value reshapeLeadingUnitDims(OpBuilder &b, Value v) const {
2901 auto vecTy = dyn_cast<VectorType>(v.getType());
2902 if (!vecTy)
2903 return v;
2904 auto vecShape = vecTy.getShape();
2905
2906 size_t numLeadUnitDims = 0;
2907 while (numLeadUnitDims < vecShape.size() && vecShape[numLeadUnitDims] == 1)
2908 numLeadUnitDims++;
2909
2910 if (!numLeadUnitDims)
2911 return v;
2912
2913 SmallVector<int64_t> newShape(vecShape.begin() + numLeadUnitDims,
2914 vecShape.end());
2915 auto newVecTy = VectorType::get(newShape, vecTy.getElementType());
2916 return b.create<vector::ShapeCastOp>(v.getLoc(), newVecTy, v).getResult();
2917 }
2918
2919 LogicalResult
2920 matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor,
2921 ConversionPatternRewriter &rewriter) const override {
2922 auto lhs = reshapeLeadingUnitDims(rewriter, adaptor.getLhs());
2923 auto rhs = reshapeLeadingUnitDims(rewriter, adaptor.getRhs());
2924 auto acc = reshapeLeadingUnitDims(rewriter, adaptor.getAcc());
2925 bool bReshapedAcc = (acc != adaptor.getAcc());
2926
2927 if (matMoveToAcc)
2928 acc = rewriter.create<aievec::CastOp>(contractOp.getLoc(), acc.getType(),
2929 acc, true);
2930
2931 auto matmulOp = rewriter.create<aievec::MatMulOp>(
2932 contractOp.getLoc(), acc.getType(), lhs, rhs, acc);
2933 {
2934 // Replace diagnostics handler to silence errors when verifying the
2935 // validity of the `aievec.matmul` ops being generated.
2936 ScopedDiagnosticHandler diagHandler(
2937 contractOp.getContext(), [](Diagnostic &) { return success(); });
2938 if (failed(matmulOp.verifyInvariants())) {
2939 rewriter.eraseOp(matmulOp);
2940 // There is a possibility that, when the linalg op is converted to
2941 // contractions, lower precisions operands are cast to the target
2942 // precission outside the contraction. For those cases, we check.
2943 lhs = adaptor.getLhs();
2944 auto wideLhsValue = getSourceOfWideningOp(lhs).value_or(nullptr);
2945 if (wideLhsValue)
2946 lhs = reshapeLeadingUnitDims(rewriter, wideLhsValue);
2947
2948 rhs = adaptor.getRhs();
2949 auto wideRhsValue = getSourceOfWideningOp(rhs).value_or(nullptr);
2950 if (wideRhsValue)
2951 rhs = reshapeLeadingUnitDims(rewriter, wideRhsValue);
2952
2953 matmulOp = rewriter.create<aievec::MatMulOp>(
2954 contractOp.getLoc(), acc.getType(), lhs, rhs, acc);
2955 if (failed(matmulOp.verifyInvariants()))
2956 return failure();
2957 }
2958 }
2959
2960 Value result = matmulOp.getResult();
2961 if (matMoveToAcc)
2962 result = rewriter.create<aievec::CastOp>(contractOp.getLoc(),
2963 acc.getType(), matmulOp, false);
2964 if (bReshapedAcc)
2965 result = rewriter.create<vector::ShapeCastOp>(
2966 contractOp.getLoc(), adaptor.getAcc().getType(), result);
2967 rewriter.replaceOp(contractOp, result);
2968
2969 return success();
2970 }
2971
2973};
2974
2975// Convert a `vector.transpose` op to an `aievec.shuffle` op for AIE2.
2977 : OpConversionPattern<vector::TransposeOp> {
2978 using OpConversionPattern::OpConversionPattern;
2979 LogicalResult
2980 matchAndRewrite(vector::TransposeOp transpOp, OpAdaptor adaptor,
2981 ConversionPatternRewriter &rewriter) const override {
2982 auto resTy = transpOp.getResultVectorType();
2983 auto resShape = resTy.getShape();
2984 auto elemTyBitWidth = resTy.getElementTypeBitWidth();
2985 auto vBitWidth = std::accumulate(resShape.begin(), resShape.end(),
2986 elemTyBitWidth, std::multiplies<>());
2987 if (vBitWidth != 512)
2988 return failure();
2989
2990 if (elemTyBitWidth != 8 && elemTyBitWidth != 16 && elemTyBitWidth != 32)
2991 return failure();
2992
2993 // Verify leading dimensions are all 1.
2994 for (int64_t i = 0; i < static_cast<int64_t>(resShape.size() - 2); ++i)
2995 if (resShape[i] != 1)
2996 return failure();
2997
2998 // Only permutation of the 2 innermost dimensions are supported.
2999 ArrayRef<int64_t> perm = transpOp.getPermutation();
3000 for (int64_t i = 0; i < static_cast<int64_t>(perm.size() - 2); ++i)
3001 if (perm[i] != i)
3002 return failure();
3003 if (perm.back() != static_cast<int64_t>(perm.size() - 2))
3004 return failure();
3005
3006 auto shuffleMode = aievec::ShuffleMode::T32_4X4;
3007 if (elemTyBitWidth == 8) {
3008 switch (resShape.back()) {
3009 case 4:
3010 shuffleMode = aievec::ShuffleMode::T8_4X16;
3011 break;
3012 case 8:
3013 shuffleMode = aievec::ShuffleMode::T8_8X8;
3014 break;
3015 case 16:
3016 shuffleMode = aievec::ShuffleMode::T8_16X4;
3017 break;
3018 default:
3019 return failure();
3020 }
3021 } else if (elemTyBitWidth == 16) {
3022 switch (resShape.back()) {
3023 case 2:
3024 shuffleMode = aievec::ShuffleMode::T16_2X16;
3025 break;
3026 case 4:
3027 shuffleMode = aievec::ShuffleMode::T16_4X8;
3028 break;
3029 case 8:
3030 shuffleMode = aievec::ShuffleMode::T16_8X4;
3031 break;
3032 case 16:
3033 shuffleMode = aievec::ShuffleMode::T16_16X2;
3034 break;
3035 default:
3036 return failure();
3037 }
3038 } else if (resShape.back() != 4)
3039 return failure();
3040
3041 auto flatVecTy =
3042 VectorType::get({512 / elemTyBitWidth}, resTy.getElementType());
3043 auto loc = transpOp.getLoc();
3044 auto flatInput = rewriter.create<vector::ShapeCastOp>(loc, flatVecTy,
3045 adaptor.getVector());
3046 auto shuffOp = rewriter.create<aievec::ShuffleOp>(loc, flatVecTy, flatInput,
3047 nullptr, shuffleMode);
3048 rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resTy, shuffOp);
3049
3050 return success();
3051 }
3052};
3053
3054//===----------------------------------------------------------------------===//
3055// Pattern collection
3056//===----------------------------------------------------------------------===//
3057
// Add a list of conversion patterns to `patterns`, each constructed with the
// pattern set's MLIRContext. The list shown here starts with
// LowerExtFOpPattern and ends with LowerTruncIOpPattern. `backend` is not
// referenced in the lines shown.
3058static void populateAIEVecCommonConversionPatterns(RewritePatternSet &patterns,
3059 TargetBackend backend) {
3060 // clang-format off
3061 patterns.add<LowerExtFOpPattern,
3064 LowerTruncIOpPattern>(patterns.getContext());
3065 // clang-format on
3066}
3067
// Add the AIEv1 conversion patterns to `patterns`.
// LowerVectorTransferReadToAIEUPD is constructed with the context plus the
// four integers 128, 512, 128, 256 (NOTE(review): their meaning is defined by
// that pattern's constructor, which is not visible here). The remaining
// patterns, from LowerVectorAddIOpToAIEVecAddOp to
// LowerVectorExtractStridedSliceOpAIEv1Pattern, take only the context.
// `backend` is not referenced in the lines shown.
3068static void populateAIEVecV1ConversionPatterns(RewritePatternSet &patterns,
3069 TargetBackend backend) {
3070 patterns.add<LowerVectorTransferReadToAIEUPD>(patterns.getContext(), 128, 512,
3071 128, 256);
3072 // clang-format off
3073 patterns.add<LowerVectorAddIOpToAIEVecAddOp,
3081 LowerVectorExtractStridedSliceOpAIEv1Pattern>(patterns.getContext());
3082 // clang-format on
3083}
3084
// Add the AIEv2 conversion patterns to `patterns`.
// For the CPP backend two groups are added: one constructed with the context
// plus the integers 128, 1024, 256, 1024, and one constructed with the
// context only. For the LLVMIR backend a different context-only group is
// added. After that, patterns are added regardless of backend; the last
// group additionally receives the flag `backend == TargetBackend::CPP`.
3085static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns,
3086 TargetBackend backend) {
3087 // clang-format off
3088 // TODO: Reorder these alphabetically
3089 if (backend == TargetBackend::CPP) {
3090 patterns.add<
3092 >(patterns.getContext(), 128, 1024, 256, 1024);
3093 patterns.add<
3099 >(patterns.getContext());
3100 } else if (backend == TargetBackend::LLVMIR){
3101 patterns.add<
3103 >(patterns.getContext());
3104 }
3105 patterns.add<
3141 >(patterns.getContext());
3143 >(patterns.getContext(), backend == TargetBackend::CPP);
3144 // clang-format on
3145}
3146
3147//===----------------------------------------------------------------------===//
3148// Legalizations
3149//===----------------------------------------------------------------------===//
3150
3151// TODO: Review the validity of these legalizations beyond basic cases.
3152
3153static bool isInSigmoidOperationChain(math::ExpOp expOp) {
3154 if (!expOp.getOperand().getDefiningOp<arith::NegFOp>())
3155 return false;
3156
3157 arith::AddFOp addOp = nullptr;
3158 for (Operation *user : expOp->getUsers()) {
3159 addOp = dyn_cast<arith::AddFOp>(user);
3160 if (addOp)
3161 break;
3162 }
3163
3164 if (!addOp)
3165 return false;
3166
3167 auto *addLvalOp = addOp.getLhs().getDefiningOp();
3168 auto *addRvalOp = addOp.getRhs().getDefiningOp();
3169 if (!((isa<math::ExpOp>(addLvalOp) && isa<arith::ConstantOp>(addRvalOp)) ||
3170 (isa<math::ExpOp>(addRvalOp) && isa<arith::ConstantOp>(addLvalOp))))
3171 return false;
3172
3173 auto constOp = isa<arith::ConstantOp>(addLvalOp)
3174 ? cast<arith::ConstantOp>(addLvalOp)
3175 : cast<arith::ConstantOp>(addRvalOp);
3176
3177 auto cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3178 if (!cstDense)
3179 return false;
3180
3181 if (cstDense.getSplatValue<APFloat>().convertToFloat() != 1.0f)
3182 return false;
3183
3184 arith::DivFOp divOp = nullptr;
3185 for (Operation *user : addOp->getUsers()) {
3186 divOp = dyn_cast<arith::DivFOp>(user);
3187 if (divOp)
3188 break;
3189 }
3190
3191 if (!divOp)
3192 return false;
3193
3194 constOp = dyn_cast<arith::ConstantOp>(divOp.getLhs().getDefiningOp());
3195 if (!constOp)
3196 return false;
3197 cstDense = dyn_cast<DenseFPElementsAttr>(constOp.getValue());
3198 if (!cstDense)
3199 return false;
3200 if (cstDense.getSplatValue<APFloat>().convertToFloat() != 1.0f)
3201 return false;
3202
3203 return true;
3204}
3205
3206static void configureAIEVecCommonLegalizations(ConversionTarget &target,
3207 TargetBackend backend) {
3208 target.addLegalDialect<xilinx::aievec::aie1::AIEVecAIE1Dialect,
3209 xilinx::aievec::AIEVecDialect, arith::ArithDialect,
3210 emitc::EmitCDialect, func::FuncDialect>();
3211 if (backend == TargetBackend::CPP) {
3212 target.addIllegalOp<vector::TransferReadOp>();
3213 }
3214 target.addIllegalOp<vector::ExtractStridedSliceOp>();
3215 target.addLegalOp<vector::BitCastOp>();
3216
3217 target.addDynamicallyLegalOp<arith::ExtFOp>([](arith::ExtFOp extfOp) {
3218 auto srcType = dyn_cast<VectorType>(extfOp.getIn().getType());
3219 auto dstType = dyn_cast<VectorType>(extfOp.getOut().getType());
3220 if (!srcType || !dstType)
3221 return true;
3222
3223 Type srcScalarType = srcType.getElementType();
3224 Type dstScalarType = dstType.getElementType();
3225 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
3226 return true;
3227
3228 unsigned srcLaneSize = getVectorLaneSize(srcType);
3229 unsigned dstLaneSize = getVectorLaneSize(dstType);
3230 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3231 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3232 return srcElWidth != 16 || srcLaneSize != 16 || dstElWidth != 32 ||
3233 dstLaneSize != 16;
3234 });
3235
3236 target.addDynamicallyLegalOp<arith::ExtSIOp>([](arith::ExtSIOp extsiOp) {
3237 auto srcType = dyn_cast<VectorType>(extsiOp.getIn().getType());
3238 auto dstType = dyn_cast<VectorType>(extsiOp.getOut().getType());
3239 if (!srcType || !dstType)
3240 return true;
3241
3242 Type srcScalarType = srcType.getElementType();
3243 Type dstScalarType = dstType.getElementType();
3244 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
3245 return true;
3246
3247 unsigned srcLaneSize = getVectorLaneSize(srcType);
3248 unsigned dstLaneSize = getVectorLaneSize(dstType);
3249 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3250 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3251 return srcLaneSize != 32 || (dstElWidth <= srcElWidth) ||
3252 (dstLaneSize != srcLaneSize);
3253 });
3254
3255 target.addDynamicallyLegalOp<arith::TruncFOp>([](arith::TruncFOp truncfOp) {
3256 auto srcType = dyn_cast<VectorType>(truncfOp.getIn().getType());
3257 auto dstType = dyn_cast<VectorType>(truncfOp.getOut().getType());
3258 if (!srcType || !dstType)
3259 return true;
3260
3261 Type srcScalarType = srcType.getElementType();
3262 Type dstScalarType = dstType.getElementType();
3263 if (!isa<FloatType>(srcScalarType) || !isa<FloatType>(dstScalarType))
3264 return true;
3265
3266 unsigned srcLaneSize = getVectorLaneSize(srcType);
3267 unsigned dstLaneSize = getVectorLaneSize(dstType);
3268 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3269 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3270 return srcElWidth != 32 || srcLaneSize != 16 || dstElWidth != 16 ||
3271 dstLaneSize != 16;
3272 });
3273
3274 target.addDynamicallyLegalOp<arith::TruncIOp>([](arith::TruncIOp trunciOp) {
3275 auto srcType = dyn_cast<VectorType>(trunciOp.getIn().getType());
3276 auto dstType = dyn_cast<VectorType>(trunciOp.getOut().getType());
3277 if (!srcType || !dstType)
3278 return true;
3279
3280 Type srcScalarType = srcType.getElementType();
3281 Type dstScalarType = dstType.getElementType();
3282 if (!isa<IntegerType>(srcScalarType) || !isa<IntegerType>(dstScalarType))
3283 return true;
3284
3285 unsigned srcLaneSize = getVectorLaneSize(srcType);
3286 unsigned dstLaneSize = getVectorLaneSize(dstType);
3287 unsigned srcElWidth = srcScalarType.getIntOrFloatBitWidth();
3288 unsigned dstElWidth = dstScalarType.getIntOrFloatBitWidth();
3289
3290 return srcLaneSize != 32 || (dstElWidth >= srcElWidth) ||
3291 (dstLaneSize != srcLaneSize);
3292 });
3293
3294 target.addDynamicallyLegalOp<math::ExpOp>([](math::ExpOp expOp) {
3295 auto srcType = dyn_cast<VectorType>(expOp.getOperand().getType());
3296 if (!srcType)
3297 return true;
3298
3299 Type scalarType = srcType.getElementType();
3300 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3301 unsigned laneSize = getVectorLaneSize(srcType);
3302 if (!isa<FloatType>(scalarType) || laneSize != 16 || elWidth != 16)
3303 return true;
3304 if (expOp->hasOneUse() && isInSigmoidOperationChain(expOp))
3305 return true;
3306
3307 return false;
3308 });
3309
3310 target.addDynamicallyLegalOp<math::TanhOp>([](math::TanhOp tanhOp) {
3311 auto srcType = dyn_cast<VectorType>(tanhOp.getOperand().getType());
3312 if (!srcType)
3313 return true;
3314
3315 Type scalarType = srcType.getElementType();
3316 if (!isa<FloatType>(scalarType))
3317 return true;
3318
3319 unsigned laneSize = getVectorLaneSize(srcType);
3320 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3321 return elWidth != 16 || laneSize != 16;
3322 });
3323
3324 target.addDynamicallyLegalOp<math::SqrtOp>([](math::SqrtOp sqrtOp) {
3325 auto srcType = dyn_cast<VectorType>(sqrtOp.getOperand().getType());
3326 if (!srcType)
3327 return true;
3328
3329 Type scalarType = srcType.getElementType();
3330 if (!isa<FloatType>(scalarType))
3331 return true;
3332
3333 unsigned laneSize = getVectorLaneSize(srcType);
3334 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3335 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3336 });
3337
3338 target.addDynamicallyLegalOp<math::RsqrtOp>([](math::RsqrtOp rsqrtOp) {
3339 auto srcType = dyn_cast<VectorType>(rsqrtOp.getOperand().getType());
3340 Type scalarType = srcType.getElementType();
3341 if (!srcType || !isa<FloatType>(scalarType))
3342 return true;
3343
3344 unsigned laneSize = getVectorLaneSize(srcType);
3345 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3346 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3347 });
3348
3349 target.addDynamicallyLegalOp<math::ErfOp>([](math::ErfOp erfOp) {
3350 auto srcType = dyn_cast<VectorType>(erfOp.getOperand().getType());
3351 if (!srcType)
3352 return true;
3353
3354 Type scalarType = srcType.getElementType();
3355 if (!isa<FloatType>(scalarType))
3356 return true;
3357
3358 unsigned laneSize = getVectorLaneSize(srcType);
3359 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3360 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3361 });
3362
3363 target.addDynamicallyLegalOp<math::AbsFOp>([](math::AbsFOp absfOp) {
3364 auto srcType = dyn_cast<VectorType>(absfOp.getOperand().getType());
3365 if (!srcType)
3366 return true;
3367
3368 Type scalarType = srcType.getElementType();
3369 unsigned laneSize = getVectorLaneSize(srcType);
3370 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3371 return elWidth * laneSize != 512 && elWidth * laneSize != 256;
3372 });
3373
3374 target.addDynamicallyLegalOp<math::AbsIOp>([](math::AbsIOp absiOp) {
3375 auto srcType = dyn_cast<VectorType>(absiOp.getOperand().getType());
3376 if (!srcType)
3377 return true;
3378
3379 Type scalarType = srcType.getElementType();
3380 unsigned laneSize = getVectorLaneSize(srcType);
3381 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3382 return elWidth * laneSize != 512 && elWidth * laneSize != 256;
3383 });
3384
3385 target.addDynamicallyLegalOp<arith::DivFOp>([](arith::DivFOp divfOp) {
3386 if (auto srcType = dyn_cast<VectorType>(divfOp.getLhs().getType());
3387 !srcType) {
3388 Type scalarType = divfOp.getLhs().getType();
3389 if (!divfOp->hasOneUse() || !isa<FloatType>(scalarType))
3390 return true;
3391 if (!isNarrowingOp(*divfOp->getUsers().begin()))
3392 return true;
3393
3394 auto fType = cast<FloatType>(scalarType);
3395 if (fType.getWidth() != 32)
3396 return true;
3397
3398 auto constOp =
3399 dyn_cast<arith::ConstantOp>(divfOp.getLhs().getDefiningOp());
3400 if (!constOp ||
3401 cast<FloatAttr>(constOp.getValue()).getValue().convertToDouble() !=
3402 1.0f)
3403 return true;
3404 } else {
3405 Type scalarType = srcType.getElementType();
3406 if (!isa<FloatType>(scalarType))
3407 return true;
3408
3409 unsigned laneSize = getVectorLaneSize(srcType);
3410 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3411
3412 if (elWidth != 16 || (laneSize != 16 && laneSize != 32))
3413 return true;
3414
3415 arith::NegFOp negOp = nullptr;
3416 if (!hasSigmoidComputationChain(divfOp, negOp))
3417 return true;
3418 }
3419
3420 return false;
3421 });
3422
3423 target.addDynamicallyLegalOp<math::CeilOp>([](math::CeilOp ceilOp) {
3424 auto srcType = dyn_cast<VectorType>(ceilOp.getOperand().getType());
3425 if (!srcType)
3426 return true;
3427 Type scalarType = srcType.getElementType();
3428 if (!isa<FloatType>(scalarType))
3429 return true;
3430
3431 unsigned laneSize = getVectorLaneSize(srcType);
3432 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3433 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3434 });
3435
3436 target.addDynamicallyLegalOp<math::FloorOp>([](math::FloorOp floorOp) {
3437 auto srcType = dyn_cast<VectorType>(floorOp.getOperand().getType());
3438 if (!srcType)
3439 return true;
3440 Type scalarType = srcType.getElementType();
3441 if (!isa<FloatType>(scalarType))
3442 return true;
3443
3444 unsigned laneSize = getVectorLaneSize(srcType);
3445 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3446 return elWidth != 16 || (laneSize != 16 && laneSize != 32);
3447 });
3448
3449 target.addDynamicallyLegalOp<arith::NegFOp>([](arith::NegFOp negOp) {
3450 auto srcType = dyn_cast<VectorType>(negOp.getOperand().getType());
3451 if (!srcType)
3452 return true;
3453 if (Type scalarType = srcType.getElementType(); !isa<FloatType>(scalarType))
3454 return true;
3455
3456 unsigned laneSize = getVectorLaneSize(srcType);
3457 return laneSize != 16;
3458 });
3459
3460 target.addDynamicallyLegalOp<arith::XOrIOp>([](arith::XOrIOp xorOp) {
3461 auto srcType = dyn_cast<VectorType>(xorOp.getLhs().getType());
3462 if (!srcType)
3463 return true;
3464 Type scalarType = srcType.getElementType();
3465 if (!isa<IntegerType>(scalarType))
3466 return true;
3467
3468 unsigned laneSize = getVectorLaneSize(srcType);
3469 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3470
3471 return laneSize * elWidth != 512;
3472 });
3473
3474 target.addDynamicallyLegalOp<arith::OrIOp>([](arith::OrIOp orOp) {
3475 auto srcType = dyn_cast<VectorType>(orOp.getLhs().getType());
3476 if (!srcType)
3477 return true;
3478 Type scalarType = srcType.getElementType();
3479 if (!isa<IntegerType>(scalarType))
3480 return true;
3481
3482 unsigned laneSize = getVectorLaneSize(srcType);
3483 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3484
3485 return laneSize * elWidth != 512;
3486 });
3487
3488 target.addDynamicallyLegalOp<arith::ShRSIOp>([](arith::ShRSIOp rsOp) {
3489 auto srcType = dyn_cast<VectorType>(rsOp.getLhs().getType());
3490 if (!srcType)
3491 return true;
3492 Type scalarType = srcType.getElementType();
3493
3494 unsigned laneSize = getVectorLaneSize(srcType);
3495 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3496
3497 return laneSize * elWidth != 512;
3498 });
3499
3500 target.addDynamicallyLegalOp<arith::AndIOp>([](arith::AndIOp andOp) {
3501 auto srcType = dyn_cast<VectorType>(andOp.getLhs().getType());
3502 if (!srcType)
3503 return true;
3504 Type scalarType = srcType.getElementType();
3505 if (!isa<IntegerType>(scalarType))
3506 return true;
3507
3508 unsigned laneSize = getVectorLaneSize(srcType);
3509 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3510
3511 return laneSize * elWidth != 512;
3512 });
3513
3514 if (backend == TargetBackend::CPP) {
3515 target.addDynamicallyLegalOp<arith::AddIOp>(
3516 [](arith::AddIOp op) { return !isa<VectorType>(op.getType()); });
3517 }
3518 target.addDynamicallyLegalOp<arith::AddFOp>(
3519 [](arith::AddFOp op) { return !isa<VectorType>(op.getType()); });
3520 target.addDynamicallyLegalOp<arith::SubIOp>(
3521 [](arith::SubIOp op) { return !isa<VectorType>(op.getType()); });
3522 target.addDynamicallyLegalOp<arith::SubFOp>(
3523 [](arith::SubFOp op) { return !isa<VectorType>(op.getType()); });
3524}
3525
3526static void configureAIEVecV1Legalizations(ConversionTarget &target,
3527 TargetBackend backend) {
3528 target.addDynamicallyLegalOp<arith::MulIOp>(
3529 [](arith::MulIOp op) { return !isa<VectorType>(op.getType()); });
3530 target.addDynamicallyLegalOp<arith::MulFOp>(
3531 [](arith::MulFOp op) { return !isa<VectorType>(op.getType()); });
3532 target.addDynamicallyLegalOp<aievec::aie1::FMAOp>(
3533 [](xilinx::aievec::aie1::FMAOp op) {
3534 auto *lhsDefOp = op.getLhs().getDefiningOp();
3535 aievec::ConcatOp concatOp = nullptr;
3536 if (lhsDefOp)
3537 concatOp = dyn_cast<aievec::ConcatOp>(op.getLhs().getDefiningOp());
3538 if (!concatOp)
3539 return true;
3540
3541 vector::SplatOp srcSplat = nullptr;
3542 if (auto *lhsOp = concatOp.getSources()[0].getDefiningOp())
3543 srcSplat = dyn_cast<vector::SplatOp>(lhsOp);
3544 if (!srcSplat) {
3545 auto *rhsOp = op.getRhs().getDefiningOp();
3546 if (!rhsOp)
3547 return true;
3548 srcSplat = dyn_cast<vector::SplatOp>(rhsOp);
3549 }
3550
3551 if (srcSplat)
3552 if (auto *srcOp = srcSplat.getInput().getDefiningOp())
3553 return !isa<vector::ExtractOp>(srcOp);
3554
3555 return true;
3556 });
3557
3558 target.addDynamicallyLegalOp<aievec::aie1::AddOp>([](aievec::aie1::AddOp op) {
3559 auto lSrsOp = op.getLhs().getDefiningOp<aievec::SRSOp>();
3560 auto rSrsOp = op.getRhs().getDefiningOp<aievec::SRSOp>();
3561 return (!lSrsOp ||
3562 !lSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>()) &&
3563 (!rSrsOp ||
3564 !rSrsOp.getSource().getDefiningOp<aievec::aie1::MulOp>());
3565 });
3566 target.addLegalDialect<memref::MemRefDialect>();
3567}
3568
3569static void configureAIEVecV2Legalizations(ConversionTarget &target,
3570 TargetBackend backend) {
3571 target.addLegalOp<UnrealizedConversionCastOp>();
3572 target.addLegalOp<vector::ShapeCastOp>();
3573
3574 // A set recording the vector lane size and element width supported
3575 llvm::SmallSet<std::pair<unsigned, unsigned>, 16> laneSizeElWidthPairSet;
3576 laneSizeElWidthPairSet.insert({64, 8});
3577 laneSizeElWidthPairSet.insert({32, 16});
3578 laneSizeElWidthPairSet.insert({16, 32});
3579 laneSizeElWidthPairSet.insert({32, 32});
3580
3581 // A set recording the element width supported
3582 llvm::SmallSet<unsigned, 16> elWidthSet;
3583 elWidthSet.insert(8);
3584 elWidthSet.insert(16);
3585 elWidthSet.insert(32);
3586
3587 if (backend == TargetBackend::CPP) {
3588 target.addDynamicallyLegalOp<arith::AddIOp>([=](arith::AddIOp op) {
3589 auto resultType = dyn_cast<VectorType>(op.getType());
3590 if (!resultType)
3591 return true;
3592
3593 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3594 unsigned laneSize = getVectorLaneSize(resultType);
3595
3596 return !laneSizeElWidthPairSet.count(
3597 std::make_pair(laneSize, resultElWidth));
3598 });
3599 }
3600
3601 target.addDynamicallyLegalOp<arith::SubIOp>([=](arith::SubIOp op) {
3602 auto resultType = dyn_cast<VectorType>(op.getType());
3603 if (!resultType)
3604 return true;
3605 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3606 unsigned laneSize = getVectorLaneSize(resultType);
3607
3608 return !laneSizeElWidthPairSet.count(
3609 std::make_pair(laneSize, resultElWidth));
3610 });
3611
3612 target.addDynamicallyLegalOp<arith::AddFOp>([](arith::AddFOp op) {
3613 auto resultType = dyn_cast<VectorType>(op.getType());
3614 if (!resultType)
3615 return true;
3616
3617 unsigned laneSize = getVectorLaneSize(resultType);
3618 return laneSize != 16;
3619 });
3620
3621 target.addDynamicallyLegalOp<arith::SubFOp>([](arith::SubFOp op) {
3622 auto resultType = dyn_cast<VectorType>(op.getType());
3623 if (!resultType)
3624 return true;
3625
3626 unsigned laneSize = getVectorLaneSize(resultType);
3627 return laneSize != 16;
3628 });
3629
3630 target.addDynamicallyLegalOp<arith::MulIOp>([](arith::MulIOp op) {
3631 auto resultType = dyn_cast<VectorType>(op.getType());
3632 if (!resultType)
3633 return true;
3634 auto isAddOp = [&](Operation *op) { return isa<arith::AddIOp>(op); };
3635 // Verify it is not a part of MAC
3636 if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp))
3637 return true;
3638
3639 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3640 unsigned laneSize = getVectorLaneSize(resultType);
3641
3642 return (laneSize != 32 || (resultElWidth != 16 && resultElWidth != 8)) &&
3643 ((laneSize != 16 && laneSize != 32) || resultElWidth != 32);
3644 });
3645
3646 target.addDynamicallyLegalOp<arith::MulFOp>([](arith::MulFOp op) {
3647 auto resultType = dyn_cast<VectorType>(op.getType());
3648 if (!resultType)
3649 return true;
3650
3651 auto isAddOp = [&](Operation *op) { return isa<arith::AddFOp>(op); };
3652 // Verify it is not a part of FMA
3653 if (op->hasOneUse() && llvm::any_of(op->getUsers(), isAddOp))
3654 return true;
3655
3656 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3657 unsigned laneSize = getVectorLaneSize(resultType);
3658
3659 return laneSize != 16 || (resultElWidth != 16 && resultElWidth != 32);
3660 });
3661
3662 target.addDynamicallyLegalOp<arith::MinSIOp>([=](arith::MinSIOp op) {
3663 auto resultType = dyn_cast<VectorType>(op.getType());
3664 if (!resultType)
3665 return true;
3666
3667 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3668 unsigned laneSize = getVectorLaneSize(resultType);
3669
3670 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3671 });
3672
3673 target.addDynamicallyLegalOp<arith::MaxSIOp>([=](arith::MaxSIOp op) {
3674 auto resultType = dyn_cast<VectorType>(op.getType());
3675 if (!resultType)
3676 return true;
3677
3678 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3679 unsigned laneSize = getVectorLaneSize(resultType);
3680
3681 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3682 });
3683
3684 target.addDynamicallyLegalOp<arith::MinimumFOp>([=](arith::MinimumFOp op) {
3685 auto resultType = dyn_cast<VectorType>(op.getType());
3686 if (!resultType)
3687 return true;
3688
3689 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3690 unsigned laneSize = getVectorLaneSize(resultType);
3691
3692 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3693 });
3694
3695 target.addDynamicallyLegalOp<arith::MaximumFOp>([=](arith::MaximumFOp op) {
3696 auto resultType = dyn_cast<VectorType>(op.getType());
3697 if (!resultType)
3698 return true;
3699
3700 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3701 unsigned laneSize = getVectorLaneSize(resultType);
3702
3703 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3704 });
3705
3706 target.addDynamicallyLegalOp<arith::CmpIOp>([=](arith::CmpIOp op) {
3707 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
3708 if (!lhsType)
3709 return true;
3710
3711 auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth();
3712 unsigned laneSize = getVectorLaneSize(lhsType);
3713
3714 return !elWidthSet.count(lhsElWidth) || laneSize * lhsElWidth != 512;
3715 });
3716
3717 target.addDynamicallyLegalOp<arith::CmpFOp>([=](arith::CmpFOp op) {
3718 auto lhsType = dyn_cast<VectorType>(op.getLhs().getType());
3719 if (!lhsType)
3720 return true;
3721
3722 auto lhsElWidth = lhsType.getElementType().getIntOrFloatBitWidth();
3723 unsigned laneSize = getVectorLaneSize(lhsType);
3724
3725 return !elWidthSet.count(lhsElWidth) || laneSize * lhsElWidth != 512;
3726 });
3727
3728 target.addDynamicallyLegalOp<arith::SelectOp>([=](arith::SelectOp op) {
3729 auto resultType = dyn_cast<VectorType>(op.getType());
3730 if (!resultType)
3731 return true;
3732
3733 auto resultElWidth = resultType.getElementType().getIntOrFloatBitWidth();
3734 unsigned laneSize = getVectorLaneSize(resultType);
3735
3736 return !elWidthSet.count(resultElWidth) || laneSize * resultElWidth != 512;
3737 });
3738
3739 target.addDynamicallyLegalOp<vector::ReductionOp>(
3740 [=](vector::ReductionOp op) {
3741 if (auto kind = op.getKind(); kind != vector::CombiningKind::ADD &&
3742 kind != vector::CombiningKind::MINSI &&
3743 kind != vector::CombiningKind::MINUI &&
3744 kind != vector::CombiningKind::MINIMUMF &&
3745 kind != vector::CombiningKind::MAXSI &&
3746 kind != vector::CombiningKind::MAXUI &&
3747 kind != vector::CombiningKind::MAXIMUMF)
3748 return true;
3749
3750 auto vType = dyn_cast<VectorType>(op.getVector().getType());
3751 if (!vType)
3752 return true;
3753
3754 llvm::SmallSet<std::pair<unsigned, signed>, 16> laneSizeElWidthPairSet;
3755 laneSizeElWidthPairSet.insert({64, 8});
3756 laneSizeElWidthPairSet.insert({32, 16});
3757 laneSizeElWidthPairSet.insert({32, 32});
3758 laneSizeElWidthPairSet.insert({16, 32});
3759
3760 Type scalarType = vType.getElementType();
3761 unsigned elWidth = scalarType.getIntOrFloatBitWidth();
3762 unsigned laneSize = getVectorLaneSize(vType);
3763
3764 if (isa<IntegerType>(scalarType) &&
3765 !laneSizeElWidthPairSet.count(std::make_pair(laneSize, elWidth)))
3766 return true;
3767
3768 if (isa<FloatType>(scalarType) && laneSize != 16 && laneSize != 32)
3769 return true;
3770
3771 return false;
3772 });
3773
3774 target.addIllegalOp<vector::ContractionOp, vector::TransposeOp,
3775 vector::FMAOp>();
3776}
3777
3778//===----------------------------------------------------------------------===//
3779// Lowering passes
3780//===----------------------------------------------------------------------===//
3781
3782/// Lower incoming vector operations into their corresponding AIE vector
3783/// intrinsics.
3784struct LowerVectorToAIEVec : PassWrapper<LowerVectorToAIEVec, OperationPass<>> {
3785 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerVectorToAIEVec)
3786
3789
3795
3796 // In case we want to register this pass as a standalone pass for test
3797 // purposes.
3798 StringRef getArgument() const final { return "test-lower-vector-to-aievec"; }
3799 StringRef getDescription() const final {
3800 return "Lower vector operations to AIE vector intrinsics";
3801 }
3802 void getDependentDialects(DialectRegistry &registry) const override {
3803 registry
3804 .insert<affine::AffineDialect, xilinx::aievec::aie1::AIEVecAIE1Dialect,
3805 xilinx::aievec::AIEVecDialect, arith::ArithDialect,
3806 memref::MemRefDialect, scf::SCFDialect, vector::VectorDialect,
3807 emitc::EmitCDialect>();
3808 }
3809
3810 Option<std::string> aieTarget{
3811 *this, "aie-target",
3812 llvm::cl::desc("Select AIE version: \"aie\" or \"aie2\". This will "
3813 "determine the vector size and available operations."),
3814 llvm::cl::init("aie")};
3815
3816 Option<std::string> targetBackend{
3817 *this, "target-backend",
3818 llvm::cl::desc("Select translation backend: \"cpp\" or \"llvmir\". This "
3819 "will determine the aievec operations used to convert "
3820 "from vector dialect."),
3821 llvm::cl::init("cpp")};
3822
3823 void runOnOperation() override {
3824 auto *op = getOperation();
3825 MLIRContext *context = &getContext();
3826 RewritePatternSet patterns(context);
3827 ConversionTarget target(*context);
3828 auto aieVersion = AIEArch::AIE;
3829 if (!aieTarget.empty()) {
3830 std::string target = aieTarget;
3831 if (target == "aieml" || target == "aie2")
3832 aieVersion = AIEArch::AIE2;
3833 else if (target != "aie") {
3834 op->emitError() << "unknown AIE target '" << aieTarget << "'";
3835 return signalPassFailure();
3836 }
3837 }
3838
3839 TargetBackend backend = TargetBackend::CPP;
3840 if (!targetBackend.empty()) {
3841 std::string backendStr = targetBackend;
3842 if (backendStr == "llvmir") {
3843 backend = TargetBackend::LLVMIR;
3844 if (aieVersion == AIEArch::AIE) {
3845 op->emitError() << "targetting LLVM IR is not supported for AIEv1";
3846 signalPassFailure();
3847 return;
3848 }
3849 } else if (backendStr != "cpp") {
3850 op->emitError() << "unknown target backend'" << targetBackend << "'";
3851 signalPassFailure();
3852 return;
3853 }
3854 }
3855
3856 populateAIEVecCommonConversionPatterns(patterns, backend);
3857 configureAIEVecCommonLegalizations(target, backend);
3858 if (aieVersion == AIEArch::AIE) {
3859 populateAIEVecV1ConversionPatterns(patterns, backend);
3860 configureAIEVecV1Legalizations(target, backend);
3861 } else {
3862 populateAIEVecV2ConversionPatterns(patterns, backend);
3863 configureAIEVecV2Legalizations(target, backend);
3864 }
3865
3866 if (failed(applyPartialConversion(op, target, std::move(patterns))))
3867 return signalPassFailure();
3868 }
3869};
3870
3871static std::unique_ptr<Pass>
3872createLowerVectorToAIEVec(const LowerVectorToAIEVecOptions &options) {
3873 return std::make_unique<LowerVectorToAIEVec>(options);
3874}
3875
3876//===---------------------------------------------------------------------------
3877// Custom canonicalization passes
3878//===---------------------------------------------------------------------------
3879
3880// This pass widens UPD ops to twice the width followed by an ext op of the
3881// bottom half. This can be used together with SimplifyUPDOpsPass to find
3882// additional common subexpressions with UPDs generated from unaligned
3883// `transfer_read` ops.
3884struct ExtendUPDOpsPass : PassWrapper<ExtendUPDOpsPass, OperationPass<>> {
3885
3886 void runOnOperation() override {
3887 MLIRContext *context = &getContext();
3888 RewritePatternSet patterns(context);
3889 ConversionTarget target(*context);
3890 patterns.add<ExpandUPDToUPDAndExtPattern>(patterns.getContext());
3891 target.addLegalDialect<aievec::AIEVecDialect>();
3892 target.addDynamicallyLegalOp<aievec::UPDOp>([](aievec::UPDOp op) {
3893 return op.getVector() ||
3894 (op->hasOneUse() && isa<aievec::UPDOp>(*op->getUsers().begin())) ||
3895 llvm::all_of(op->getUsers(),
3896 [](Operation *op) { return isa<aievec::ExtOp>(op); });
3897 });
3898
3899 if (auto *op = getOperation();
3900 failed(applyPartialConversion(op, target, std::move(patterns)))) {
3901 return signalPassFailure();
3902 }
3903 }
3904};
3905
3906// This pass replaces wide UPD ops that are only used by a single ext op of the
3907// bottom half. This pass undos the work of ExtendUPDOpsPass.
3908// TODO: This pass can be extended to work with wide UPD ops that are used by
3909// TODO: a single ext op of the top half, which might be a good opportunity to
3910// TODO: further optimize wide UPDs.
3911struct SimplifyUPDOpsPass : PassWrapper<SimplifyUPDOpsPass, OperationPass<>> {
3912
3913 void runOnOperation() override {
3914 MLIRContext *context = &getContext();
3915 RewritePatternSet patterns(context);
3916 ConversionTarget target(*context);
3917 patterns.add<FuseExtIntoUPDPattern>(patterns.getContext());
3918 target.addLegalDialect<aievec::AIEVecDialect>();
3919 target.addDynamicallyLegalOp<aievec::ExtOp>([](aievec::ExtOp op) {
3920 auto *defOp = op.getSource().getDefiningOp();
3921 return !defOp || !isa<aievec::UPDOp>(defOp) || !defOp->hasOneUse() ||
3922 op.getIndex() != 0;
3923 });
3924
3925 if (auto *op = getOperation();
3926 failed(applyPartialConversion(op, target, std::move(patterns)))) {
3927 return signalPassFailure();
3928 }
3929 }
3930};
3931
3932//============================================================================//
3933//=============== Main Vector2AIEVec Pipeline Configuration ==================//
3934//============================================================================//
3935
3937 OpPassManager &pm, const LowerVectorToAIEVecOptions &options) {
3938 // Add lowering from `Vector` to `AIEVec`
3939 pm.addPass(createLowerVectorToAIEVec(options));
3940 pm.addPass(createCanonicalizerPass());
3941
3942 // Simplify UPD ops
3943 pm.addPass(std::make_unique<ExtendUPDOpsPass>());
3944 pm.addPass(createCSEPass());
3945 pm.addPass(std::make_unique<SimplifyUPDOpsPass>());
3946 pm.addPass(createCanonicalizerPass());
3947}
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaximumFOp, aievec::MaxOp > LowerVectorMaximumFOpToAIEVecMaxOp
ComputeBandAndBorOpPattern< arith::OrIOp, aievec::BorOp > ComputeBorOpPattern
ComputeBandAndBorOpPattern< arith::AndIOp, aievec::BandOp > ComputeBandOpPattern
OneToOneVectorOpToAIEVecOpPattern< arith::SubFOp, aievec::aie1::SubOp > LowerVectorSubFOpToAIEVecSubOp
ComputeAbsOpPattern< math::AbsIOp > ComputeAbsIOpPattern
LowerTruncOpPattern< arith::TruncFOp > LowerTruncFOpPattern
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::AddFOp, aievec::AddElemOp > LowerVectorAddFOpToAIEVecAddElemOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MaxSIOp, aievec::MaxOp > LowerVectorMaxSIOpToAIEVecMaxOp
LowerExtOpPattern< arith::ExtFOp > LowerExtFOpPattern
LowerVectorCmpOpToAIEVecCmpOp< arith::CmpFOp, CmpFPredicate > LowerVectorCmpFOpToAIEVecCmpOp
OneToOneVectorOpToAIEVecOpPattern< arith::SubIOp, aievec::aie1::SubOp > LowerVectorSubIOpToAIEVecSubOp
ComputeAbsOpPattern< math::AbsFOp > ComputeAbsFOpPattern
LowerVectorCmpOpToAIEVecCmpOp< arith::CmpIOp, CmpIPredicate > LowerVectorCmpIOpToAIEVecCmpOp
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::SubFOp, aievec::SubElemOp > LowerVectorSubFOpToAIEVecSubElemOp
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MinSIOp, aievec::MinOp > LowerVectorMinSIOpToAIEVecMinOp
OneToOneVectorOpToAIEVecOpPattern< arith::AddFOp, aievec::aie1::AddOp > LowerVectorAddFOpToAIEVecAddOp
OneToOneVectorOpToAIEVecOpPattern< arith::MulFOp, aievec::aie1::MulOp > LowerVectorMulFOpToAIEVecMulOp
LowerExtOpPattern< arith::ExtSIOp > LowerExtSIOpPattern
LowerVectorMinMaxOpToAIEVecMinMaxOp< arith::MinimumFOp, aievec::MinOp > LowerVectorMinimumFOpToAIEVecMinOp
LowerVectorAddOrSubOpToAIEVecAddElemOrSubElemOp< arith::AddIOp, aievec::AddElemOp > LowerVectorAddIOpToAIEVecAddElemOp
PathEndPoint src
mlir::VectorType getFlattenedVectorType(mlir::VectorType vecTy)
unsigned getVectorLaneSize(mlir::VectorType type)
Definition AIEVecUtils.h:55
SmallVector< NamedAttribute > buildFMAOpSplatAttrForElemTy(aievec::aie1::FMAOp fmaOp, int64_t bcastPos, int64_t step=1)
std::optional< int64_t > getTransferReadAlignmentOffset(TransferReadLikeOp readOp, mlir::VectorType vType, int64_t alignment)
mlir::VectorType createVectorType(unsigned lanes, mlir::Type elementType)
Definition AIEVecUtils.h:42
int32_t getElementSizeInBits(mlir::VectorType type)
Definition AIEVecUtils.h:49
void buildLowerVectorToAIEVec(mlir::OpPassManager &pm, const LowerVectorToAIEVecOptions &options)
mlir::VectorType getVectorOpDestType(mlir::VectorType type, bool AIE2)
Definition AIEVecUtils.h:80
TargetBackend
Definition Passes.h:27
LogicalResult matchAndRewrite(SrcOpTy absOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::XOrIOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::CeilOp ceilOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ErfOp erfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::ExpOp expOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::FloorOp floorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::NegFOp negOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::RsqrtOp rsqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::DivFOp divfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::ShRSIOp rsOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::SqrtOp sqrtOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(math::TanhOp tanhOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulAddToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(aievec::aie1::AddOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::MulFOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulFToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertMulIToAIEVecMulElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
LogicalResult matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
ConvertVectorFMAOpToAIEVecFMAElemOpPattern(MLIRContext *context, unsigned shiftParam=0)
ExpandUPDToUPDAndExtPattern(MLIRContext *context)
LogicalResult matchAndRewrite(aievec::UPDOp updOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::aie1::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(aievec::ExtOp extOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
FuseExtIntoUPDPattern(MLIRContext *context)
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(SrcOpTy extOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy truncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
typename SrcOpTy::Adaptor OpAdaptor
LogicalResult matchAndRewrite(arith::AddIOp addOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LowerVectorContractionOpToAIEVecMatMulPattern(MLIRContext *context, bool matMoveToAcc=true)
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::MulIOp mulOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(vector::ReductionOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(arith::SelectOp srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Lower incoming vector operations into their corresponding AIE vector intrinsics.
void getDependentDialects(DialectRegistry &registry) const override
LowerVectorToAIEVec(const LowerVectorToAIEVecOptions &options)
StringRef getDescription() const final
Option< std::string > targetBackend
StringRef getArgument() const final
LogicalResult matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LowerVectorTransferReadToAIEUPD(MLIRContext *context, int64_t minVectorSize, int64_t maxVectorSize, int64_t alignment, int64_t maxLoadSize)
LogicalResult matchAndRewrite(vector::TransposeOp transpOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(SrcOpTy srcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override
Options for the "lower-vector-to-aievec" pipeline.
Definition Passes.h:57
PassOptions::Option< std::string > aieTarget
Definition Passes.h:58
PassOptions::Option< std::string > targetBackend
Definition Passes.h:63