MLIR-AIE
AIEVecAIE1Ops.cpp
Go to the documentation of this file.
1//===-- AIEVecAIE1Ops.cpp - MLIR AIE Vector Dialect Operations --*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2024 Advanced Micro Devices, Inc. or its affiliates
8//
9//===----------------------------------------------------------------------===//
// This file implements AIE1 vector op printing, parsing, and verification.
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/OpDefinition.h"
16#include "mlir/IR/TypeUtilities.h"
17#include "mlir/Transforms/FoldUtils.h"
18#include "llvm/ADT/TypeSwitch.h"
19
22
23using namespace llvm;
24using namespace mlir;
25
26#include "aie/Dialect/AIEVec/AIE1/IR/AIEVecAIE1OpsDialect.cpp.inc"
27
28namespace xilinx::aievec::aie1 {
29
30//===----------------------------------------------------------------------===//
31// AIEVecAIE1Dialect
32//===----------------------------------------------------------------------===//
33
34void AIEVecAIE1Dialect::initialize() {
35 addOperations<
36#define GET_OP_LIST
37#include "aie/Dialect/AIEVec/AIE1/IR/AIEVecAIE1Ops.cpp.inc"
38 >();
39}
40
41//===----------------------------------------------------------------------===//
42// AddOp and SubOp
43//===----------------------------------------------------------------------===//
44
45// Print out Add and Sub op.
46template <typename T>
47void printAddSubOp(OpAsmPrinter &p, T op) {
48 // Print the lhs operand
49 p << " " << op.getLhs();
50 // Print the rhs operand
51 p << ", " << op.getRhs();
52
53 // Print the attributes, but don't print attributes that are empty strings
54 SmallVector<StringRef, 10> elidedAttrs;
55 for (int idx = 0; idx < 2; ++idx) {
56 if (op.getStart(idx).empty())
57 elidedAttrs.push_back(op.getStartAttrName(idx));
58 if (op.getOffset(idx).empty())
59 elidedAttrs.push_back(op.getOffsetAttrName(idx));
60 if (op.getOffsetHi(idx).empty())
61 elidedAttrs.push_back(op.getOffsetHiAttrName(idx));
62 if (op.getSquare(idx).empty())
63 elidedAttrs.push_back(op.getSquareAttrName(idx));
64 }
65 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
66
67 // And now print the types
68 p << " : " << op.getLhs().getType() << ", " << op.getRhs().getType();
69 p << ", " << op.getResult().getType();
70}
71
72void AddOp::print(OpAsmPrinter &p) { printAddSubOp<AddOp>(p, *this); }
73
74void SubOp::print(OpAsmPrinter &p) { printAddSubOp<SubOp>(p, *this); }
75
76// Verify Add and Sub op.
77template <typename T>
78LogicalResult verifyAddSubOp(T op) {
79 // Verify the types
80 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
81 auto lhsType = llvm::dyn_cast<VectorType>(op.getLhs().getType());
82 auto rhsType = llvm::dyn_cast<VectorType>(op.getRhs().getType());
83
84 if (!lhsType || !rhsType || !resultType)
85 return op.emitError("requires vector type");
86
87 // All the vector types must match
88 if (lhsType != rhsType || rhsType != resultType)
89 return op.emitError("all vectors must be of same type");
90
91 return success();
92}
93
94LogicalResult AddOp::verify() { return verifyAddSubOp<AddOp>(*this); }
95
96LogicalResult SubOp::verify() { return verifyAddSubOp<SubOp>(*this); }
97
98// Parse Add and Sub op.
99ParseResult parseAddSubOp(OpAsmParser &parser, OperationState &result) {
100 llvm::SMLoc typesLoc;
101 SmallVector<Type, 3> types;
102 OpAsmParser::UnresolvedOperand lhs, rhs;
103
104 // Parse the lhs and rhs
105 if (parser.parseOperand(lhs) || parser.parseComma() ||
106 parser.parseOperand(rhs))
107 return failure();
108
109 // Parse all the attributes and types
110 if (parser.parseOptionalAttrDict(result.attributes) ||
111 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
112 return failure();
113
114 // Assert that there are three types: lhs, rhs, and result
115 if (types.size() != 3)
116 return parser.emitError(typesLoc, "requires three types");
117
118 // Some verification
119 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
120 if (!lhsType)
121 return parser.emitError(typesLoc, "requires vector type");
122 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
123 if (!rhsType)
124 return parser.emitError(typesLoc, "requires vector type");
125 VectorType resultType = llvm::dyn_cast<VectorType>(types[2]);
126 if (!resultType)
127 return parser.emitError(typesLoc, "requires vector type");
128
129 // Populate the lhs, rhs, and accumulator in the result
130 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
131 parser.resolveOperand(rhs, rhsType, result.operands))
132 return failure();
133
134 return parser.addTypeToList(resultType, result.types);
135}
136
137ParseResult AddOp::parse(OpAsmParser &parser, OperationState &result) {
138 return parseAddSubOp(parser, result);
139}
140
141ParseResult SubOp::parse(OpAsmParser &parser, OperationState &result) {
142 return parseAddSubOp(parser, result);
143}
144
145//===----------------------------------------------------------------------===//
146// MulOp and FMAOp
147//===----------------------------------------------------------------------===//
148
149// MulOp and FMAOp are structurally similar, except that FMA op has few extra
150// fields (accumulator, bool flag to indicate if it is fmsub, etc.). We create
151// some specializations to print those fields specifically for FMA op.
152
153// Print the accumulator
154template <typename T>
155void printAccumulator(OpAsmPrinter &p, T op);
156template <>
157inline void printAccumulator(OpAsmPrinter &p, FMAOp op) {
158 p << ", " << op.getAcc();
159}
160template <>
161inline void printAccumulator(OpAsmPrinter &p, MulOp op) {}
162
163// Mark fmsub indicator as elided if the FMA op is not fmsub
164template <typename T>
165void elideFMSubAttr(T op, SmallVector<StringRef, 10> &elidedAttrs);
166template <>
167inline void elideFMSubAttr(FMAOp op, SmallVector<StringRef, 10> &elidedAttrs) {
168 if (!op.getFmsub())
169 elidedAttrs.push_back(op.getSubAttrName());
170}
171template <>
172inline void elideFMSubAttr(MulOp, SmallVector<StringRef, 10> &elidedAttrs) {}
173
174// Print out Mul and FMA op.
175template <typename T>
176static void printMulFMAOp(OpAsmPrinter &p, T op) {
177 // Print the left operand
178 p << " " << op.getLhs();
179 // Print the right operand
180 p << ", " << op.getRhs();
181 // For fma op, print the accumulator
182 printAccumulator(p, op);
183
184 // Print the attributes, but don't print attributes that are empty strings
185 SmallVector<StringRef, 10> elidedAttrs;
186 for (int idx = 0; idx < 2; ++idx) {
187 if (op.getStart(idx).empty())
188 elidedAttrs.push_back(op.getStartAttrName(idx));
189 if (op.getOffset(idx).empty())
190 elidedAttrs.push_back(op.getOffsetAttrName(idx));
191 if (op.getOffsetHi(idx).empty())
192 elidedAttrs.push_back(op.getOffsetHiAttrName(idx));
193 if (op.getStep(idx).empty())
194 elidedAttrs.push_back(op.getStepAttrName(idx));
195 if (op.getSquare(idx).empty())
196 elidedAttrs.push_back(op.getSquareAttrName(idx));
197 elideFMSubAttr(op, elidedAttrs);
198 }
199 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
200
201 // And now print the types
202 p << " : " << op.getLhs().getType() << ", " << op.getRhs().getType();
203 p << ", " << op.getResult().getType();
204}
205
206void MulOp::print(OpAsmPrinter &p) { printMulFMAOp<MulOp>(p, *this); }
207
208void FMAOp::print(OpAsmPrinter &p) { printMulFMAOp<FMAOp>(p, *this); }
209
210// Verify Mul and FMA op.
211template <typename T>
212LogicalResult verifyMulFMAOp(T op) {
213 // Verify the types
214 auto lhsType = llvm::dyn_cast<VectorType>(op.getLhs().getType());
215 auto rhsType = llvm::dyn_cast<VectorType>(op.getRhs().getType());
216
217 if (!lhsType || !rhsType)
218 return op.emitError("requires vector type");
219
220 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
221 if (!resultType)
222 return op.emitError("requires vector type");
223
224 // Additional checks for FMA op
225 // Get the width of the underlying scalars of all the vectors
226 Type ltype = lhsType.getElementType();
227 Type rtype = rhsType.getElementType();
228 Type atype = resultType.getElementType();
229 unsigned ltypeWidth = ltype.getIntOrFloatBitWidth();
230 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
231 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
232
233 // Checks on the number of lanes
234 unsigned accLanes = getVectorLaneSize(resultType);
235 unsigned rhsLanes = getVectorLaneSize(rhsType);
236 unsigned lhsLanes = getVectorLaneSize(lhsType);
237
238 // If this is not a simple scheme, perform complex checks
239 if (accLanes != rhsLanes || accLanes != lhsLanes) {
240 if (rhsLanes != 256 / rtypeWidth)
241 return op.emitError("incorrect rhs operand vector lanes");
242 if (lhsLanes < 2 * rhsLanes)
243 return op.emitError("The number of lanes in lhs operand "
244 "must be at least twice that of rhs operand");
245 if (accLanes > rhsLanes)
246 return op.emitError("The number of lanes in accumulator "
247 "must be less than that of rhs operand");
248 }
249
250 // lhs and rhs vector's element type must match
251 if (ltype != rtype)
252 return op.emitError("The element type of lhs and rhs "
253 "operand vectors must match");
254
255 // The datatype of accumulator must always be greater width
256 if (isa<IntegerType>(atype)) {
257 if (!isa<IntegerType>(ltype))
258 return op.emitError("Integer result must have integer operands");
259
260 if (ltypeWidth >= atypeWidth || rtypeWidth >= atypeWidth)
261 return op.emitError("the element type of accumulator must have "
262 "wider width than that of the operand vectors");
263 } else if (isa<FloatType>(atype)) {
264 if (!isa<FloatType>(ltype))
265 return op.emitError("Floating point result must have "
266 "floating point operands");
267
268 if (ltypeWidth != atypeWidth || rtypeWidth != atypeWidth)
269 return op.emitError("the element type of accumulator must be "
270 "same width as the operand vectors");
271 }
272
273 return success();
274}
275
276LogicalResult MulOp::verify() { return verifyMulFMAOp<MulOp>(*this); }
277
278LogicalResult FMAOp::verify() { return verifyMulFMAOp<FMAOp>(*this); }
279
280// Parse Mul and FMA op.
281ParseResult parseMulFMAOp(OpAsmParser &parser, OperationState &result,
282 bool isFMAOp = true) {
283 llvm::SMLoc typesLoc;
284 SmallVector<Type, 3> types;
285 OpAsmParser::UnresolvedOperand lhs, rhs, acc;
286
287 // Parse the lhs and rhs
288 if (parser.parseOperand(lhs) || parser.parseComma() ||
289 parser.parseOperand(rhs))
290 return failure();
291
292 // Parse the acc for FMA op
293 if (isFMAOp) {
294 if (parser.parseComma() || parser.parseOperand(acc))
295 return failure();
296 }
297
298 // Parse all the attributes and types
299 if (parser.parseOptionalAttrDict(result.attributes) ||
300 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
301 return failure();
302
303 // Assert that there are three types: lhs, rhs, and acc
304 if (types.size() != 3)
305 return parser.emitError(typesLoc, "requires three types");
306
307 // Some verification
308 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
309 if (!lhsType)
310 return parser.emitError(typesLoc, "requires vector type");
311 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
312 if (!rhsType)
313 return parser.emitError(typesLoc, "requires vector type");
314
315 // Int ops use the accumulator while float ops use normal vector registers
316 VectorType accType = llvm::dyn_cast<VectorType>(types[2]);
317 if (!accType)
318 return parser.emitError(typesLoc, "requires vector type");
319
320 // Populate the lhs and rhs operands, and result
321 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
322 parser.resolveOperand(rhs, rhsType, result.operands))
323 return failure();
324
325 // Populate acc operand for FMA op
326 if (isFMAOp) {
327 if (parser.resolveOperand(acc, accType, result.operands))
328 return failure();
329 }
330
331 return parser.addTypeToList(accType, result.types);
332}
333
334ParseResult MulOp::parse(OpAsmParser &parser, OperationState &result) {
335 return parseMulFMAOp(parser, result, false);
336}
337
338ParseResult FMAOp::parse(OpAsmParser &parser, OperationState &result) {
339 return parseMulFMAOp(parser, result, true);
340}
341
342//===----------------------------------------------------------------------===//
343// SelectOp
344//===----------------------------------------------------------------------===//
345
346// Print out select op.
347void SelectOp::print(OpAsmPrinter &p) {
348 // Print the xbuff
349 p << " " << getXbuff();
350 // Print the start, offsets, etc. for xbuff
351 if (getYbuff())
352 p << ", " << getYbuff();
353
354 // Print the attributes, but don't print attributes that are empty strings
355 SmallVector<StringRef, 10> elidedAttrs;
356 for (int idx = 0; idx < 2; ++idx) {
357 if (getStart(idx).empty())
358 elidedAttrs.push_back(getStartAttrName(idx));
359 if (getOffset(idx).empty())
360 elidedAttrs.push_back(getOffsetAttrName(idx));
361 if (getOffsetHi(idx).empty())
362 elidedAttrs.push_back(getOffsetHiAttrName(idx));
363 if (getSquare(idx).empty())
364 elidedAttrs.push_back(getSquareAttrName(idx));
365 }
366 p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
367
368 // And now print the types
369 p << " : " << getXbuff().getType();
370 if (getYbuff())
371 p << ", " << getYbuff().getType();
372 p << ", " << getResult().getType();
373}
374
375// Verify select op.
376LogicalResult SelectOp::verify() {
377 // Verify the types
378 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
379 VectorType xbuffType = llvm::dyn_cast<VectorType>(getXbuff().getType());
380
381 if (!resultType || !xbuffType)
382 return emitError("requires vector type");
383
384 // The underlying scalar element type of all vectors must match
385 Type rtype = resultType.getElementType();
386 Type xtype = xbuffType.getElementType();
387 if (rtype != xtype)
388 return emitError("types of result and xbuff must match");
389
390 // If yuff is present, its vector type should be same as xbuff
391 if (getYbuff()) {
392 VectorType ybuffType = llvm::dyn_cast<VectorType>(getYbuff().getType());
393 if (xbuffType != ybuffType)
394 return emitError("types of xbuff and ybuff must match");
395 }
396
397 // Compare the lanes. xtype should have more lanes
398 unsigned sourceLanes = getVectorLaneSize(xbuffType);
399 unsigned resultLanes = getVectorLaneSize(resultType);
400 if (sourceLanes < resultLanes)
401 return emitError("xbuff cannot be smaller than result");
402
403 return success();
404}
405
406// Parse select op.
407ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
408 llvm::SMLoc typesLoc;
409 SmallVector<Type, 3> types;
410 OpAsmParser::UnresolvedOperand xbuff, ybuff;
411
412 // Parse xbuff
413 if (parser.parseOperand(xbuff))
414 return failure();
415
416 // Parse optional ybuff
417 ParseResult hasYbuff = parser.parseOptionalComma();
418 if (hasYbuff.succeeded() && parser.parseOperand(ybuff))
419 return failure();
420
421 // Parse all the attributes and types
422 if (parser.parseOptionalAttrDict(result.attributes) ||
423 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
424 return failure();
425
426 // Assert that there is at least two types
427 if (types.size() < 2)
428 return parser.emitError(typesLoc, "requires at least two type");
429
430 // Some verification
431 VectorType xbuffType = llvm::dyn_cast<VectorType>(types[0]);
432 if (!xbuffType)
433 return parser.emitError(typesLoc, "requires vector type");
434 VectorType ybuffType;
435 if (hasYbuff.succeeded()) {
436 ybuffType = llvm::dyn_cast<VectorType>(types[1]);
437 if (!ybuffType)
438 return parser.emitError(typesLoc, "requires vector type");
439 }
440 VectorType resultType = llvm::dyn_cast<VectorType>(types.back());
441 if (!resultType)
442 return parser.emitError(typesLoc, "requires vector type");
443
444 // Populate the xbuff
445 if (parser.resolveOperand(xbuff, xbuffType, result.operands))
446 return failure();
447 // Populate optional ybuff in result
448 if (hasYbuff.succeeded())
449 if (parser.resolveOperand(ybuff, ybuffType, result.operands))
450 return failure();
451
452 return parser.addTypeToList(resultType, result.types);
453}
454
455//===----------------------------------------------------------------------===//
456// ExtOp
457//===----------------------------------------------------------------------===//
458
459// Print out Ext op.
460void ExtOp::print(OpAsmPrinter &p) {
461 // Print the source vector
462 p << " " << getSource();
463
464 // Print the attributes
465 p.printOptionalAttrDict((*this)->getAttrs());
466
467 // And now print the types
468 p << " : " << getSource().getType() << ", " << getResult().getType();
469}
470
471// Verify Ext op.
472LogicalResult ExtOp::verify() {
473 // Verify the types
474 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
475 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
476 if (!sourceType || !resultType)
477 return emitError("requires vector type");
478
479 // Check the number of lanes
480 unsigned sourceLanes = getVectorLaneSize(sourceType);
481 unsigned resultLanes = getVectorLaneSize(resultType);
482 // Source lanes must be greater than result lanes
483 if (sourceLanes / resultLanes <= 1)
484 return emitError("lanes in source vector must be at least "
485 "twice that of result vector");
486 // Source lanes must be a multiple of result lanes
487 if (sourceLanes % resultLanes != 0)
488 return emitError("lanes in result vector must be a multiple "
489 "of source vector lanes");
490
491 // Verify validity of index
492 unsigned factor = sourceLanes / resultLanes;
493 if (static_cast<unsigned>(getIndex()) >= factor)
494 return emitError("index out of bounds");
495
496 // The datatype of source and result must match
497 Type stype = sourceType.getElementType();
498 Type rtype = resultType.getElementType();
499 if (stype != rtype)
500 return emitError("source and result element type must be same");
501
502 return success();
503}
504
505// Parse Ext op.
506ParseResult ExtOp::parse(OpAsmParser &parser, OperationState &result) {
507 llvm::SMLoc typesLoc;
508 SmallVector<Type, 2> types;
509 OpAsmParser::UnresolvedOperand source;
510
511 // Parse the source vector
512 if (parser.parseOperand(source))
513 return failure();
514
515 // Parse all the attributes and types
516 if (parser.parseOptionalAttrDict(result.attributes) ||
517 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
518 return failure();
519
520 if (result.attributes.getAttrs().size() != 1)
521 return parser.emitError(typesLoc, "requires one attribute");
522
523 // Assert that there are two types (source and result)
524 if (types.size() != 2)
525 return parser.emitError(typesLoc, "requires two types");
526
527 // Some verification
528 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
529 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
530 if (!sourceType || !resultType)
531 return parser.emitError(typesLoc, "requires vector type");
532
533 // Populate the source in result
534 if (parser.resolveOperand(source, sourceType, result.operands))
535 return failure();
536
537 return parser.addTypeToList(resultType, result.types);
538}
539
540} // namespace xilinx::aievec::aie1
541
542// #define GET_ATTRDEF_CLASSES
543// #include "aie/Dialect/AIEVec/IR/AIEVecAttributes.cpp.inc"
544
545#define GET_OP_CLASSES
546#include "aie/Dialect/AIEVec/AIE1/IR/AIEVecAIE1Ops.cpp.inc"
ParseResult parseAddSubOp(OpAsmParser &parser, OperationState &result)
LogicalResult verifyAddSubOp(T op)
ParseResult parseMulFMAOp(OpAsmParser &parser, OperationState &result, bool isFMAOp=true)
LogicalResult verifyMulFMAOp(T op)
void printAddSubOp(OpAsmPrinter &p, T op)
void elideFMSubAttr(T op, SmallVector< StringRef, 10 > &elidedAttrs)
void printAccumulator(OpAsmPrinter &p, T op)
unsigned getVectorLaneSize(mlir::VectorType type)
Definition AIEVecUtils.h:55