MLIR-AIE
AIEVecOps.cpp
Go to the documentation of this file.
1//===---- AIEVecOps.cpp - MLIR AIE Vector Dialect Operations ----*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2022-2024 Advanced Micro Devices, Inc. or its affiliates
8//
9//===----------------------------------------------------------------------===//
10// This file implements AIE vector op printing, parsing, and verification.
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/OpDefinition.h"
16#include "mlir/IR/TypeUtilities.h"
17#include "mlir/Transforms/FoldUtils.h"
18#include "llvm/ADT/TypeSwitch.h"
19
22
23using namespace llvm;
24using namespace mlir;
25using namespace xilinx;
26using namespace xilinx::aievec;
27
28#include "aie/Dialect/AIEVec/IR/AIEVecEnums.cpp.inc"
29#include "aie/Dialect/AIEVec/IR/AIEVecOpsDialect.cpp.inc"
30
31//===----------------------------------------------------------------------===//
32// AIEVecDialect
33//===----------------------------------------------------------------------===//
34
// Register everything the AIEVec dialect provides: its types, then its
// attributes, then its operations.
void AIEVecDialect::initialize() {
  // registerTypes() is defined outside this section of the file.
  registerTypes();
  // GET_ATTRDEF_LIST makes the generated .inc file expand to the list of
  // attribute classes, which forms the template argument list here.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "aie/Dialect/AIEVec/IR/AIEVecAttributes.cpp.inc"
      >();
  // Same pattern for operations: GET_OP_LIST expands to the op class list.
  addOperations<
#define GET_OP_LIST
#include "aie/Dialect/AIEVec/IR/AIEVecOps.cpp.inc"
      >();
}
46
47//===----------------------------------------------------------------------===//
48// UPDOp
49//===----------------------------------------------------------------------===//
50
51// Print out UPD op.
52void UPDOp::print(OpAsmPrinter &p) {
53 // Print the source memref
54 p << " " << getSource() << "[" << getIndices() << "]";
55 // Now print the optional vector that links upd idx=1 with idx=0
56 if (getVector())
57 p << ", " << getVector();
58
59 // Print the attributes, but don't print the operand segment sizes
60 SmallVector<StringRef, 3> elidedAttrs;
61 elidedAttrs.push_back(UPDOp::getOperandSegmentSizeAttr());
62 p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
63
64 // And now print the types
65 p << " : " << getSource().getType() << ", " << getResult().getType();
66}
67
68// Verify UPD op.
69LogicalResult UPDOp::verify() {
70 // Verify the types: source is memref, and result is vector
71 MemRefType sourceType = llvm::dyn_cast<MemRefType>(getSource().getType());
72 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
73 if (!sourceType)
74 return emitError("requires memref type");
75 if (!resultType)
76 return emitError("requires vector type");
77 if (getIndices().empty())
78 return emitError("upd source cannot come from scalar value");
79
80 // If this UPD op is linked to another UPD op, then verify that the linked
81 // vector and the result vector match.
82 if (getVector()) {
83 Type vecType = llvm::dyn_cast<VectorType>(getVector().getType());
84 if (vecType != resultType)
85 return emitError("result types of linked UPD ops do not match");
86 }
87 return success();
88}
89
90// Parse UPD op.
91ParseResult UPDOp::parse(OpAsmParser &parser, OperationState &result) {
92 auto &builder = parser.getBuilder();
93 llvm::SMLoc typesLoc;
94 SmallVector<Type, 2> types;
95 OpAsmParser::UnresolvedOperand source, vector;
96 SmallVector<OpAsmParser::UnresolvedOperand, 8> indices;
97
98 // Parse the source, indices, and optional vector
99 if (parser.parseOperand(source) ||
100 parser.parseOperandList(indices, OpAsmParser::Delimiter::Square))
101 return failure();
102 ParseResult hasVector = parser.parseOptionalComma();
103 if (hasVector.succeeded() && parser.parseOperand(vector))
104 return failure();
105
106 // Parse all the attributes and types
107 if (parser.parseOptionalAttrDict(result.attributes) ||
108 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
109 return failure();
110
111 if (result.attributes.getAttrs().size() != 2)
112 return parser.emitError(typesLoc, "requires two attributes");
113
114 // Assert that there are two types (memref source and vector result)
115 if (types.size() != 2)
116 return parser.emitError(typesLoc, "requires two types");
117
118 // Some verification
119 auto memrefType = llvm::dyn_cast<MemRefType>(types[0]);
120 if (!memrefType)
121 return parser.emitError(typesLoc, "requires memref type");
122 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
123 if (!vectorType)
124 return parser.emitError(typesLoc, "requires vector type");
125 auto indicesType = builder.getIndexType();
126
127 // Populate the source and indices in result
128 if (parser.resolveOperand(source, memrefType, result.operands) ||
129 parser.resolveOperands(indices, indicesType, result.operands))
130 return failure();
131 // Populate optional vector in result
132 if (hasVector.succeeded())
133 if (parser.resolveOperand(vector, vectorType, result.operands))
134 return failure();
135
136 // Populate operand size attribute in result
137 result.addAttribute(UPDOp::getOperandSegmentSizeAttr(),
138 builder.getDenseI32ArrayAttr(
139 {1, static_cast<int32_t>(indices.size()),
140 static_cast<int32_t>(hasVector.succeeded())}));
141
142 return parser.addTypeToList(vectorType, result.types);
143}
144
145//===----------------------------------------------------------------------===//
146// CastOp
147//===----------------------------------------------------------------------===//
148
149// Print out Cast op.
150void CastOp::print(OpAsmPrinter &p) {
151 // Print the source accumulator
152 p << " " << getSource();
153
154 // Print the attributes
155 p.printOptionalAttrDict((*this)->getAttrs());
156
157 // And now print the types
158 p << " : " << getSource().getType() << ", " << getResult().getType();
159}
160
161// Verify Cast op.
162LogicalResult CastOp::verify() {
163 // Verify the types
164 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
165 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
166 if (!sourceType)
167 return emitError("requires source vector type");
168 if (!resultType)
169 return emitError("requires result vector type");
170
171 if (sourceType.getElementType().getIntOrFloatBitWidth() !=
172 resultType.getElementType().getIntOrFloatBitWidth()) {
173 return emitError("the bitwidth of resource and result should be equal");
174 }
175
176 return success();
177}
178
179// Parse Cast op.
180ParseResult CastOp::parse(OpAsmParser &parser, OperationState &result) {
181 llvm::SMLoc typesLoc;
182 SmallVector<Type, 2> types;
183 OpAsmParser::UnresolvedOperand source;
184
185 // Parse the source vector
186 if (parser.parseOperand(source))
187 return failure();
188
189 // Parse all the attributes and types
190 if (parser.parseOptionalAttrDict(result.attributes) ||
191 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
192 return failure();
193
194 if (result.attributes.getAttrs().size() != 1)
195 return parser.emitError(typesLoc, "requires one attribute");
196
197 // Assert that there are two types (source and result)
198 if (types.size() != 2)
199 return parser.emitError(typesLoc, "requires two types");
200
201 // Some verification of types
202 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
203 if (!sourceType)
204 return parser.emitError(typesLoc, "requires vector type");
205 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
206 if (!vectorType)
207 return parser.emitError(typesLoc, "requires vector type");
208
209 // Populate the source in result
210 if (parser.resolveOperand(source, sourceType, result.operands))
211 return failure();
212
213 return parser.addTypeToList(vectorType, result.types);
214}
215
216// Cast fold method. It will fold with a preceding Cast operation.
217OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
218 auto srcCastOp = getSource().getDefiningOp<aievec::CastOp>();
219 if (!srcCastOp)
220 return nullptr;
221
222 if (srcCastOp.getIsResAcc() == getIsResAcc())
223 return srcCastOp.getResult();
224
225 return srcCastOp.getSource();
226}
227
228//===----------------------------------------------------------------------===//
229// SRSOp
230//===----------------------------------------------------------------------===//
231
232// SRS fold method. It will fold with a preceding UPS operation.
233OpFoldResult SRSOp::fold(FoldAdaptor adaptor) {
234 auto srcDefOp = getSource().getDefiningOp();
235 if (!srcDefOp)
236 return nullptr;
237
238 auto upsOp = dyn_cast<UPSOp>(srcDefOp);
239 if (!upsOp)
240 return nullptr;
241
242 auto shiftDefOp = getShift().getDefiningOp();
243 if (!shiftDefOp)
244 return nullptr;
245
246 auto constOp = dyn_cast<arith::ConstantOp>(shiftDefOp);
247 if (!constOp)
248 return nullptr;
249
250 // Only fold srs(ups(x), 0) → x. Non-zero shift performs actual
251 // shift-round-saturate and must not be folded away.
252 auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue());
253 if (!intAttr || intAttr.getInt() != 0)
254 return nullptr;
255
256 if (upsOp.getSource().getType() != getResult().getType())
257 return nullptr;
258
259 return upsOp.getSource();
260}
261
262// Print out SRS op.
263void SRSOp::print(OpAsmPrinter &p) {
264 // Print the source accumulator
265 p << " " << getSource() << ", ";
266
267 // Print the shift
268 p << getShift();
269
270 // Always print the attribute dict, but elide the sign attribute when it
271 // has the default value (1 = signed) to avoid breaking existing tests.
272 if (getSign() == 1) {
273 p.printOptionalAttrDict((*this)->getAttrs(), {"sign"});
274 } else {
275 p.printOptionalAttrDict((*this)->getAttrs());
276 }
277
278 // And now print the types
279 p << " : " << getSource().getType() << ", " << getShift().getType() << ", "
280 << getResult().getType();
281}
282
283// Verify SRS op.
284LogicalResult SRSOp::verify() {
285 // Verify the types
286 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
287 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
288 if (!sourceType)
289 return emitError("requires accumulator type");
290 if (!resultType)
291 return emitError("requires vector type");
292
293 // The number of lanes of source accumulator and result vector must match
294 unsigned accLanes = getVectorLaneSize(sourceType);
295 unsigned vecLanes = getVectorLaneSize(resultType);
296 if (accLanes != vecLanes)
297 return emitError("The number of lanes in result vector "
298 "and source accumulator must match");
299
300 // The datatype of accumulator must have greater width
301 Type stype = resultType.getElementType();
302 Type atype = sourceType.getElementType();
303 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
304 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
305
306 if (isa<IntegerType>(atype) && stypeWidth >= atypeWidth)
307 return emitError("the element type of source accumulator must be "
308 "wider than that of the result vector");
309 else if (isa<FloatType>(atype) && stypeWidth != 16 &&
310 stypeWidth != atypeWidth)
311 return emitError("the element type of source accumulator must be "
312 "same as the result vector");
313
314 return success();
315}
316
317// Parse SRS op.
318ParseResult SRSOp::parse(OpAsmParser &parser, OperationState &result) {
319 llvm::SMLoc typesLoc;
320 SmallVector<Type, 3> types;
321 OpAsmParser::UnresolvedOperand source, shift;
322
323 // Parse the source accumulator
324 if (parser.parseOperand(source) || parser.parseComma() ||
325 parser.parseOperand(shift))
326 return failure();
327
328 // Parse optional attributes (e.g., {sign = 0 : i32})
329 if (parser.parseOptionalAttrDict(result.attributes))
330 return failure();
331
332 // Parse types
333 if (parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
334 return failure();
335
336 // Assert that there are two types (accumulator source and vector result)
337 if (types.size() != 3)
338 return parser.emitError(typesLoc, "requires three types");
339
340 // Some verification of types
341 VectorType accType = llvm::dyn_cast<VectorType>(types[0]);
342 if (!accType)
343 return parser.emitError(typesLoc, "requires vector type");
344
345 IntegerType shiftType = llvm::dyn_cast<IntegerType>(types[1]);
346 if (!shiftType)
347 return parser.emitError(typesLoc, "requires integer type");
348
349 VectorType vectorType = llvm::dyn_cast<VectorType>(types[2]);
350 if (!vectorType)
351 return parser.emitError(typesLoc, "requires vector type");
352
353 // Populate the source in result
354 if (parser.resolveOperand(source, accType, result.operands) ||
355 parser.resolveOperand(shift, shiftType, result.operands))
356 return failure();
357
358 return parser.addTypeToList(vectorType, result.types);
359}
360
361//===----------------------------------------------------------------------===//
362// UPSOp
363//===----------------------------------------------------------------------===//
364
365// UPS fold method. It will fold with a preceding SRS operation.
366OpFoldResult UPSOp::fold(FoldAdaptor adaptor) {
367 // TODO: Both UPS and SRS have an aditional parameter (shift) that's being
368 // TODO: ignored here. Somebody should take a careful look at it.
369 // TODO: In next llvm version: auto srsDefOp =
370 // adaptor.getSource().getDefiningOp();
371 auto srcDefOp = getSource().getDefiningOp();
372 if (!srcDefOp)
373 return nullptr;
374 auto srsOp = llvm::dyn_cast<SRSOp>(srcDefOp);
375 if (!srsOp)
376 return nullptr;
377 return srsOp.getSource();
378}
379
380// Print out UPS op.
381void UPSOp::print(OpAsmPrinter &p) {
382 // Print the source vector
383 p << " " << getSource();
384
385 // Print the attributes
386 p.printOptionalAttrDict((*this)->getAttrs());
387
388 // And now print the types
389 p << " : " << getSource().getType() << ", " << getResult().getType();
390}
391
392// Verify UPS op.
393LogicalResult UPSOp::verify() {
394 // Verify the types
395 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
396 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
397 if (!sourceType)
398 return emitError("requires vector type");
399 if (!resultType)
400 return emitError("requires vector type");
401
402 // The number of lanes must match
403 unsigned vecLanes = getVectorLaneSize(sourceType);
404 unsigned accLanes = getVectorLaneSize(resultType);
405 if (vecLanes != accLanes)
406 return emitError("The number of lanes in source vector "
407 "and result accumulator must match");
408
409 // The datatype of accumulator must always be greater width
410 Type stype = sourceType.getElementType();
411 Type atype = resultType.getElementType();
412 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
413 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
414
415 if (stypeWidth >= atypeWidth)
416 return emitError("the element type of result accumulator "
417 "must be wider than that of the source vector");
418
419 return success();
420}
421
422// Parse UPS op.
423ParseResult UPSOp::parse(OpAsmParser &parser, OperationState &result) {
424 llvm::SMLoc typesLoc;
425 SmallVector<Type, 2> types;
426 OpAsmParser::UnresolvedOperand source;
427
428 // Parse the source vector
429 if (parser.parseOperand(source))
430 return failure();
431
432 // Parse all the attributes and types
433 if (parser.parseOptionalAttrDict(result.attributes) ||
434 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
435 return failure();
436
437 if (result.attributes.getAttrs().size() != 1)
438 return parser.emitError(typesLoc, "requires one attribute");
439
440 // Assert that there are two types (source vector and accumulator result)
441 if (types.size() != 2)
442 return parser.emitError(typesLoc, "requires two types");
443
444 // Some verification
445 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
446 if (!vectorType)
447 return parser.emitError(typesLoc, "requires vector type");
448 VectorType accType = llvm::dyn_cast<VectorType>(types[1]);
449 if (!accType)
450 return parser.emitError(typesLoc, "requires vector type");
451
452 // Populate the source in result
453 if (parser.resolveOperand(source, vectorType, result.operands))
454 return failure();
455
456 return parser.addTypeToList(accType, result.types);
457}
458
459//===----------------------------------------------------------------------===//
460// BroadcastOp
461//===----------------------------------------------------------------------===//
462
463// Print out Broadcast op.
464void BroadcastOp::print(OpAsmPrinter &p) {
465 // Print the source vector
466 p << " " << getSource();
467
468 // Print the attributes
469 p.printOptionalAttrDict((*this)->getAttrs());
470
471 // And now print the types
472 p << " : " << getSource().getType() << ", " << getResult().getType();
473}
474
475// Verify Broadcast op.
476LogicalResult BroadcastOp::verify() {
477 // Verify the types
478 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
479 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
480
481 if (!sourceType)
482 return emitError("requires vector type");
483 if (!resultType)
484 return emitError("requires vector type");
485
486 if (sourceType != resultType) {
487 return emitError("The vector type of source vector "
488 "and result vector must match");
489 }
490 // The number of lanes must match
491 unsigned sourceLanes = getVectorLaneSize(sourceType);
492 unsigned resultLanes = getVectorLaneSize(resultType);
493 if (sourceLanes != resultLanes)
494 return emitError("The number of lanes in source vector "
495 "and result vector must match");
496
497 // The element type of vectors must always be the same
498 Type stype = sourceType.getElementType();
499 Type rtype = resultType.getElementType();
500
501 if (stype != rtype) {
502 return emitError("the element type of result vector "
503 "must be the same as source vector");
504 }
505
506 return success();
507}
508
509// Parse Broadcast op.
510ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
511 llvm::SMLoc typesLoc;
512 SmallVector<Type, 2> types;
513 OpAsmParser::UnresolvedOperand source;
514
515 // Parse the source vector
516 if (parser.parseOperand(source))
517 return failure();
518
519 // Parse all the attributes and types
520 if (parser.parseOptionalAttrDict(result.attributes) ||
521 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
522 return failure();
523
524 if (result.attributes.getAttrs().size() != 1)
525 return parser.emitError(typesLoc, "requires one attribute");
526
527 // Assert that there are two types (source vector and result vector)
528 if (types.size() != 2)
529 return parser.emitError(typesLoc, "requires two types");
530
531 // Some verification
532 VectorType vecType = llvm::dyn_cast<VectorType>(types[0]);
533 if (!vecType)
534 return parser.emitError(typesLoc, "requires vector type");
535
536 VectorType resType = llvm::dyn_cast<VectorType>(types[1]);
537 if (!resType)
538 return parser.emitError(typesLoc, "requires vector type");
539
540 // Populate the source in result
541 if (parser.resolveOperand(source, vecType, result.operands))
542 return failure();
543
544 return parser.addTypeToList(resType, result.types);
545}
546
547//===----------------------------------------------------------------------===//
548// BroadcastScalarOp
549//===----------------------------------------------------------------------===//
550
551// Print out BroadcastScalar op.
552void BroadcastScalarOp::print(OpAsmPrinter &p) {
553 // Print the source vector
554 p << " " << getSource();
555
556 // And now print the types
557 p << " : " << getSource().getType() << ", " << getResult().getType();
558}
559
560// Verify BroadcastScalar op.
561LogicalResult BroadcastScalarOp::verify() {
562 // Verify the types
563 Type sourceType = getSource().getType();
564 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
565
566 if (!resultType)
567 return emitError("requires vector type");
568
569 if (!isa<IntegerType, FloatType>(sourceType))
570 return emitError("requires source type to be integer or float");
571
572 Type resultElemType = resultType.getElementType();
573 if (sourceType != resultElemType) {
574 return emitError("the element type of result vector must be the same as "
575 "the source type");
576 }
577
578 return success();
579}
580
581// Parse BroadcastScalar op.
582ParseResult BroadcastScalarOp::parse(OpAsmParser &parser,
583 OperationState &result) {
584 llvm::SMLoc typesLoc;
585 SmallVector<Type, 2> types;
586 OpAsmParser::UnresolvedOperand source;
587
588 // Parse the source vector
589 if (parser.parseOperand(source))
590 return failure();
591
592 // Parse all the attributes and types
593 if (parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
594 return failure();
595
596 if (!result.attributes.getAttrs().empty())
597 return parser.emitError(typesLoc, "do not require attributes");
598
599 // Assert that there is two type (source and result vector)
600 if (types.size() != 2)
601 return parser.emitError(typesLoc, "requires two types");
602
603 // Some verification
604 VectorType resType = llvm::dyn_cast<VectorType>(types[1]);
605 if (!resType)
606 return parser.emitError(typesLoc, "requires vector type");
607
608 if (parser.resolveOperand(source, types[0], result.operands))
609 return failure();
610
611 return parser.addTypeToList(resType, result.types);
612}
613
614//===----------------------------------------------------------------------===//
615// MulElemOp and FMAElemOp
616//===----------------------------------------------------------------------===//
617
618// MulElemOp and FMAElemOp are structurally similar, except that FMAElem op
619// has a few extra fields (accumulator, bool flag to indicate if it is fmsub,
620// etc.). We create some specializations to print those fields specifically for
621// FMAElemOp and MulElemOp.
622
623// Print the accumulator
624template <typename T>
625void printAccumulator(OpAsmPrinter &p, T op);
626template <>
627inline void printAccumulator(OpAsmPrinter &p, aievec::FMAElemOp op) {
628 p << ", " << op.getAcc();
629}
630template <>
631inline void printAccumulator(OpAsmPrinter &p, aievec::MulElemOp op) {}
632
633// Mark fmsub indicator as elided if the FMAElem op is not fmsub
634template <typename T>
635void elideFMSubAttr(T op, SmallVector<StringRef, 4> &elidedAttrs);
636template <>
637inline void elideFMSubAttr(aievec::FMAElemOp op,
638 SmallVector<StringRef, 4> &elidedAttrs) {
639 if (!op.getFmsub())
640 elidedAttrs.push_back(op.getSubAttrName());
641}
642
643template <>
644inline void elideFMSubAttr(aievec::MulElemOp op,
645 SmallVector<StringRef, 4> &elidedAttrs) {}
646
647// Print out MulElem and FMAElem op.
648template <typename T>
649static void printMulFMAElemOp(OpAsmPrinter &p, T op) {
650 // Print the left operand
651 p << " " << op.getLhs();
652 // Print the right operand
653 p << ", " << op.getRhs();
654 // For fma op, print the accumulator
655 printAccumulator(p, op);
656
657 // Print the attributes, but don't print attributes that are empty strings
658 SmallVector<StringRef, 4> elidedAttrs;
659 for (int idx = 0; idx < 2; ++idx) {
660 elideFMSubAttr(op, elidedAttrs);
661 }
662 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
663
664 // And now print the types
665 p << " : " << op.getLhs().getType() << ", " << op.getRhs().getType();
666 p << ", " << op.getResult().getType();
667}
668
669void MulElemOp::print(OpAsmPrinter &p) {
670 printMulFMAElemOp<aievec::MulElemOp>(p, *this);
671}
672
673void aievec::FMAElemOp::print(OpAsmPrinter &p) {
674 printMulFMAElemOp<aievec::FMAElemOp>(p, *this);
675}
676
677// Verify MulElem and FMAElem op.
678template <typename T>
679LogicalResult verifyMulFMAElemOp(T op) {
680 // Verify the types
681 auto lhsType = llvm::dyn_cast<VectorType>(op.getLhs().getType());
682 auto rhsType = llvm::dyn_cast<VectorType>(op.getRhs().getType());
683
684 if (!lhsType || !rhsType)
685 return op.emitError("requires vector type");
686
687 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
688
689 if (!resultType)
690 return op.emitError("requires vector type");
691
692 // Additional checks for FMAElem op
693 // Get the width of the underlying scalars of all the vectors
694 Type ltype = lhsType.getElementType();
695 Type rtype = rhsType.getElementType();
696 Type atype = resultType.getElementType();
697 unsigned ltypeWidth = ltype.getIntOrFloatBitWidth();
698 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
699 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
700
701 // Checks on the number of lanes
702 unsigned rhsLanes = getVectorLaneSize(rhsType);
703 unsigned lhsLanes = getVectorLaneSize(lhsType);
704
705 // lane size must match
706 if (lhsLanes != rhsLanes) {
707 return op.emitError("The number of lanes in lhs operand "
708 "must be the same as rhs operand");
709 }
710
711 // lhs and rhs vector's element type must match
712 if (ltype != rtype)
713 return op.emitError("The element type of lhs and rhs "
714 "operand vectors must match");
715
716 // The integer datatype of accumulator must always be greater width
717 if (isa<IntegerType>(atype)) {
718 if (!isa<IntegerType>(ltype))
719 return op.emitError("Integer result must have integer operands");
720
721 if (ltypeWidth >= atypeWidth || rtypeWidth >= atypeWidth)
722 return op.emitError("the element type of accumulator must have "
723 "wider width than that of the operand vectors");
724 } else if (isa<FloatType>(atype)) {
725 if (!isa<FloatType>(ltype))
726 return op.emitError("Floating point result must have "
727 "floating point operands");
728 }
729
730 return success();
731}
732
733LogicalResult aievec::MulElemOp::verify() {
734 return verifyMulFMAElemOp<aievec::MulElemOp>(*this);
735}
736
737LogicalResult aievec::FMAElemOp::verify() {
738 return verifyMulFMAElemOp<aievec::FMAElemOp>(*this);
739}
740
741// Parse MulElem and FMAElem op.
742ParseResult parseMulFMAElemOp(OpAsmParser &parser, OperationState &result,
743 bool isFMAElemOp = true) {
744 llvm::SMLoc typesLoc;
745 SmallVector<Type, 3> types;
746 OpAsmParser::UnresolvedOperand lhs, rhs, acc;
747
748 // Parse the lhs and rhs
749 if (parser.parseOperand(lhs) || parser.parseComma() ||
750 parser.parseOperand(rhs))
751 return failure();
752
753 // Parse the acc for FMA op
754 if (isFMAElemOp) {
755 if (parser.parseComma() || parser.parseOperand(acc))
756 return failure();
757 }
758
759 // Parse all the attributes and types
760 if (parser.parseOptionalAttrDict(result.attributes) ||
761 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
762 return failure();
763
764 // Assert that there are three types: lhs, rhs, and acc
765 if (types.size() != 3)
766 return parser.emitError(typesLoc, "requires three types");
767
768 // Some verification
769 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
770 if (!lhsType)
771 return parser.emitError(typesLoc, "requires vector type");
772 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
773 if (!rhsType)
774 return parser.emitError(typesLoc, "requires vector type");
775
776 // Int ops use the accumulator while float ops use normal vector registers
777 VectorType accType = llvm::dyn_cast<VectorType>(types[2]);
778 if (!accType)
779 return parser.emitError(typesLoc, "requires vector type");
780
781 // Populate the lhs and rhs operands, and result
782 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
783 parser.resolveOperand(rhs, rhsType, result.operands))
784 return failure();
785
786 // Populate acc operand for FMA op
787 if (isFMAElemOp) {
788 if (parser.resolveOperand(acc, accType, result.operands))
789 return failure();
790 }
791
792 return parser.addTypeToList(accType, result.types);
793}
794
795ParseResult MulElemOp::parse(OpAsmParser &parser, OperationState &result) {
796 return parseMulFMAElemOp(parser, result, false);
797}
798
799ParseResult FMAElemOp::parse(OpAsmParser &parser, OperationState &result) {
800 return parseMulFMAElemOp(parser, result, true);
801}
802
803//===----------------------------------------------------------------------===//
804// ConcatOp
805//===----------------------------------------------------------------------===//
806
807// Print out Concat op.
808void ConcatOp::print(OpAsmPrinter &p) {
809 // Print the source vectors
810 assert(!getSources().empty() && "concat source empty");
811 p << " " << getSources();
812
813 // Print the attributes
814 p.printOptionalAttrDict((*this)->getAttrs());
815
816 // And now print the types
817 p << " : " << getSources().getTypes().front() << ", "
818 << getResult().getType();
819}
820
821// Verify Concat op.
822LogicalResult ConcatOp::verify() {
823 // Must be concatenating at least two sources
824 if (getSources().size() < 2)
825 return emitError("Must concatenate at least two vectors");
826
827 // Verify the types
828 VectorType sourceType =
829 llvm::dyn_cast<VectorType>(getSources().getTypes().front());
830 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
831 if (!sourceType || !resultType)
832 return emitError("requires vector type");
833
834 SmallVector<Value, 8> srcs(getSources().begin(), getSources().end());
835 // All the sources must have the same type
836 for (auto source : srcs) {
837 VectorType type = llvm::dyn_cast<VectorType>(source.getType());
838 if (!type)
839 return emitError("requires vector type");
840 if (type != sourceType)
841 return emitError("All sources must have same type");
842 }
843
844 // The lanes in concatenated type must be the sum of lanes of source vector
845 unsigned totalLanes = 0;
846 for (auto source : srcs) {
847 VectorType type = llvm::dyn_cast<VectorType>(source.getType());
848 totalLanes += getVectorLaneSize(type);
849 }
850
851 if (totalLanes != getVectorLaneSize(resultType))
852 return emitError("mismatch between vector lanes "
853 "and sum of source lanes");
854
855 return success();
856}
857
858// Parse Concat op.
859ParseResult ConcatOp::parse(OpAsmParser &parser, OperationState &result) {
860 llvm::SMLoc typesLoc;
861 SmallVector<Type, 2> types;
862 SmallVector<OpAsmParser::UnresolvedOperand, 8> sources;
863
864 // Parse the source vectors
865 if (parser.parseOperandList(sources))
866 return failure();
867
868 // Parse all the attributes and types
869 if (parser.parseOptionalAttrDict(result.attributes) ||
870 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
871 return failure();
872
873 // Currently there are no attributes in concat op
874 if (!result.attributes.getAttrs().empty())
875 return parser.emitError(typesLoc, "expects no attribute");
876
877 // Assert that there are two types (type of all sources, and result)
878 if (types.size() != 2)
879 return parser.emitError(typesLoc, "requires two types");
880
881 // Some verification
882 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
883 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
884 if (!sourceType || !resultType)
885 return parser.emitError(typesLoc, "requires vector type");
886
887 // Populate the source vectors in result
888 if (parser.resolveOperands(sources, sourceType, result.operands))
889 return failure();
890
891 return parser.addTypeToList(resultType, result.types);
892}
893
894LogicalResult
895ConcatOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
896 ConcatOp::Adaptor adaptor,
897 SmallVectorImpl<Type> &inferredReturnTypes) {
898 SmallVector<Value, 8> srcs(adaptor.getSources().begin(),
899 adaptor.getSources().end());
900 unsigned totalLength = 0;
901 for (auto source : srcs) {
902 VectorType type = llvm::dyn_cast<VectorType>(source.getType());
903 assert(type.getRank() == 1 &&
904 "only rank 1 vectors currently supported by concat");
905 totalLength += type.getDimSize(0);
906 }
907 inferredReturnTypes.push_back(VectorType::get(
908 {totalLength},
909 llvm::dyn_cast<VectorType>(srcs[0].getType()).getElementType()));
910 return success();
911}
912
913//===----------------------------------------------------------------------===//
914// ExtOp
915//===----------------------------------------------------------------------===//
916
917// Print out Ext op.
918void ExtOp::print(OpAsmPrinter &p) {
919 // Print the source vector
920 p << " " << getSource();
921
922 // Print the attributes
923 p.printOptionalAttrDict((*this)->getAttrs());
924
925 // And now print the types
926 p << " : " << getSource().getType() << ", " << getResult().getType();
927}
928
929// Verify Ext op.
930LogicalResult ExtOp::verify() {
931 // Verify the types
932 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
933 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
934 if (!sourceType || !resultType)
935 return emitError("requires vector type");
936
937 // Check the number of lanes
938 unsigned sourceLanes = getVectorLaneSize(sourceType);
939 unsigned resultLanes = getVectorLaneSize(resultType);
940 // Source lanes must be greater than result lanes
941 if (sourceLanes / resultLanes <= 1)
942 return emitError("lanes in source vector must be at least "
943 "twice that of result vector");
944 // Source lanes must be a multiple of result lanes
945 if (sourceLanes % resultLanes != 0)
946 return emitError("lanes in result vector must be a multiple "
947 "of source vector lanes");
948
949 // Verify validity of index
950 unsigned factor = sourceLanes / resultLanes;
951 if (static_cast<unsigned>(getIndex()) >= factor)
952 return emitError("index out of bounds");
953
954 // The datatype of source and result must match
955 Type stype = sourceType.getElementType();
956 Type rtype = resultType.getElementType();
957 if (stype != rtype)
958 return emitError("source and result element type must be same");
959
960 return success();
961}
962
963// Parse Ext op.
964ParseResult ExtOp::parse(OpAsmParser &parser, OperationState &result) {
965 llvm::SMLoc typesLoc;
966 SmallVector<Type, 2> types;
967 OpAsmParser::UnresolvedOperand source;
968
969 // Parse the source vector
970 if (parser.parseOperand(source))
971 return failure();
972
973 // Parse all the attributes and types
974 if (parser.parseOptionalAttrDict(result.attributes) ||
975 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
976 return failure();
977
978 if (result.attributes.getAttrs().size() != 1)
979 return parser.emitError(typesLoc, "requires one attribute");
980
981 // Assert that there are two types (source and result)
982 if (types.size() != 2)
983 return parser.emitError(typesLoc, "requires two types");
984
985 // Some verification
986 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
987 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
988 if (!sourceType || !resultType)
989 return parser.emitError(typesLoc, "requires vector type");
990
991 // Populate the source in result
992 if (parser.resolveOperand(source, sourceType, result.operands))
993 return failure();
994
995 return parser.addTypeToList(resultType, result.types);
996}
997
998//===----------------------------------------------------------------------===//
999// PackOp and UnpackOp
1000//===----------------------------------------------------------------------===//
1001
1002// Print out Pack and Unpack op.
1003template <typename T>
1004static void printPackUnpackOp(OpAsmPrinter &p, T op) {
1005 // Print the source vector
1006 p << " " << op.getSource();
1007
1008 // Print the attributes
1009 p.printOptionalAttrDict(op->getAttrs());
1010
1011 // And now print the types
1012 p << " : " << op.getSource().getType() << ", " << op.getResult().getType();
1013}
1014
1015void PackOp::print(OpAsmPrinter &p) { printPackUnpackOp<PackOp>(p, *this); }
1016
1017void UnpackOp::print(OpAsmPrinter &p) { printPackUnpackOp<UnpackOp>(p, *this); }
1018
1019// Verify Pack and Unpack op.
1020template <typename T>
1021LogicalResult verifyPackUnpackOp(T op) {
1022 // Verify the types
1023 auto sourceType = llvm::dyn_cast<VectorType>(op.getSource().getType());
1024 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
1025 if (!sourceType || !resultType)
1026 return op.emitError("requires vector type");
1027
1028 // The number of lanes must match
1029 unsigned sourceLanes = getVectorLaneSize(sourceType);
1030 unsigned resultLanes = getVectorLaneSize(resultType);
1031 if (sourceLanes != resultLanes)
1032 return op.emitError("The number of lanes in input and "
1033 "output vector must match");
1034
1035 Type stype = sourceType.getElementType();
1036 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
1037 Type rtype = resultType.getElementType();
1038 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
1039
1040 if (isa<PackOp>(op)) {
1041 // The datatype of source must be i16, and datatype of result must be i8
1042 if (stypeWidth != 16)
1043 return op.emitError("input must be an int16 vector");
1044 if (rtypeWidth != 8)
1045 return op.emitError("output must be an int8 vector");
1046 } else {
1047 if (stypeWidth != 8)
1048 return op.emitError("input must be an int8 vector");
1049 if (rtypeWidth != 16)
1050 return op.emitError("output must be an int16 vector");
1051 }
1052
1053 return success();
1054}
1055
1056LogicalResult PackOp::verify() { return verifyPackUnpackOp<PackOp>(*this); }
1057
1058LogicalResult UnpackOp::verify() { return verifyPackUnpackOp<UnpackOp>(*this); }
1059
1060// Parse Pack and Unpack op.
1061ParseResult parsePackUnpackOp(OpAsmParser &parser, OperationState &result) {
1062 llvm::SMLoc typesLoc;
1063 SmallVector<Type, 2> types;
1064 OpAsmParser::UnresolvedOperand source;
1065
1066 // Parse the source vector
1067 if (parser.parseOperand(source))
1068 return failure();
1069
1070 // Parse all the attributes and types
1071 if (parser.parseOptionalAttrDict(result.attributes) ||
1072 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1073 return failure();
1074
1075 // Currently there are no attributes in pack/unpack op
1076 if (!result.attributes.getAttrs().empty())
1077 return parser.emitError(typesLoc, "expects no attributes");
1078
1079 // Assert that there are two types (source and result vectors)
1080 if (types.size() != 2)
1081 return parser.emitError(typesLoc, "requires two types");
1082
1083 // Some verification
1084 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
1085 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
1086 if (!sourceType || !resultType)
1087 return parser.emitError(typesLoc, "requires vector type");
1088
1089 // Populate the source in result
1090 if (parser.resolveOperand(source, sourceType, result.operands))
1091 return failure();
1092
1093 return parser.addTypeToList(resultType, result.types);
1094}
1095
1096ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1097 return parsePackUnpackOp(parser, result);
1098}
1099
1100ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1101 return parsePackUnpackOp(parser, result);
1102}
1103
1104//===----------------------------------------------------------------------===//
1105// ExtElemOp
1106//===----------------------------------------------------------------------===//
1107
1108// Verify Extract Element op.
1109LogicalResult ExtElemOp::verify() {
1110 // Verify the types
1111 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
1112
1113 if (!sourceType)
1114 return emitError("source requires vector type");
1115
1116 // The element type of vectors must always be the same
1117 Type stype = sourceType.getElementType();
1118 Type rtype = getResult().getType();
1119
1120 if (stype != rtype) {
1121 return emitError("the type of result must be the same as the element "
1122 "type of source vector");
1123 }
1124
1125 return success();
1126}
1127
1128//===----------------------------------------------------------------------===//
1129// ShiftOp
1130//===----------------------------------------------------------------------===//
1131
1132// Print out Shift op.
1133void ShiftOp::print(OpAsmPrinter &p) {
1134 // Print the lhs and rhs vectors
1135 p << " " << getLhs() << ", " << getRhs();
1136
1137 // Print shift
1138 p << ", " << getShift();
1139
1140 // Print the attributes
1141 p.printOptionalAttrDict((*this)->getAttrs());
1142
1143 // And now print the types
1144 p << " : " << getLhs().getType() << ", " << getLhs().getType() << ", "
1145 << getShift().getType() << ", " << getResult().getType();
1146}
1147
1148// Verify Shift op.
1149LogicalResult ShiftOp::verify() {
1150 // Verify the types
1151 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
1152 if (!resultType)
1153 return emitError("requires vector type");
1154
1155 // lhs, rhs and result must have the same type
1156 VectorType lhsType = llvm::dyn_cast<VectorType>(getLhs().getType());
1157 VectorType rhsType = llvm::dyn_cast<VectorType>(getRhs().getType());
1158
1159 if (!lhsType || !rhsType)
1160 return emitError("requires vector type");
1161 if (lhsType != resultType || rhsType != resultType)
1162 return emitError("All vectors must have same type");
1163
1164 if (!isa<IntegerType>(getShift().getType()))
1165 return emitError("requires integer type");
1166
1167 return success();
1168}
1169
1170// Parse Shift op.
1171ParseResult ShiftOp::parse(OpAsmParser &parser, OperationState &result) {
1172 llvm::SMLoc typesLoc;
1173 SmallVector<Type, 4> types;
1174 OpAsmParser::UnresolvedOperand lhs, rhs, shift;
1175
1176 // Parse the source vectors
1177 if (parser.parseOperand(lhs) || parser.parseComma() ||
1178 parser.parseOperand(rhs) || parser.parseComma() ||
1179 parser.parseOperand(shift))
1180 return failure();
1181
1182 // Parse all the attributes and types
1183 if (parser.parseOptionalAttrDict(result.attributes) ||
1184 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1185 return failure();
1186
1187 if (result.attributes.getAttrs().size() != 1)
1188 return parser.emitError(typesLoc, "expects one attribute");
1189
1190 // Assert that there are two types (source and result vectors)
1191 if (types.size() != 4)
1192 return parser.emitError(typesLoc, "requires four types");
1193
1194 // Some verification
1195 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
1196 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
1197 IntegerType shiftType = llvm::dyn_cast<IntegerType>(types[2]);
1198 VectorType resultType = llvm::dyn_cast<VectorType>(types[3]);
1199 if (!lhsType || !rhsType || !resultType)
1200 return parser.emitError(typesLoc, "requires vector type");
1201
1202 if (!shiftType)
1203 return parser.emitError(typesLoc, "requires integer type");
1204
1205 // Populate the lhs vector, rhs vectors and shift in result
1206 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
1207 parser.resolveOperand(rhs, rhsType, result.operands) ||
1208 parser.resolveOperand(shift, shiftType, result.operands))
1209 return failure();
1210
1211 return parser.addTypeToList(resultType, result.types);
1212}
1213
1214//===----------------------------------------------------------------------===//
1215// ShuffleOp
1216//===----------------------------------------------------------------------===//
1217
1218// This verification function makes sure that the shuffle mode supports the
1219// number and type of operands provided.
1220LogicalResult ShuffleOp::verify() {
1221 unsigned modeBitWidth;
1222 bool requireRhs = true;
1223 auto mode = getMode();
1224 switch (mode) {
1225 case ShuffleMode::T8_8X8: // 35
1226 case ShuffleMode::T8_16X4: // 36
1227 case ShuffleMode::T8_4X16: // 37
1228 case ShuffleMode::T8_8X4: // 46
1229 case ShuffleMode::T8_4X8: // 47
1230 requireRhs = false;
1231 LLVM_FALLTHROUGH;
1232 case ShuffleMode::T8_64X2_LO: // 0
1233 case ShuffleMode::T8_64X2_HI: // 1
1234 case ShuffleMode::T8_2X64_LO: // 20
1235 case ShuffleMode::T8_2X64_HI: // 21
1236 modeBitWidth = 8u;
1237 break;
1238 case ShuffleMode::T16_8X4: // 28
1239 case ShuffleMode::T16_4X8: // 29
1240 case ShuffleMode::T16_1X2_flip: // 38
1241 case ShuffleMode::T16_4X4: // 39
1242 case ShuffleMode::T16_4X2: // 40
1243 case ShuffleMode::T16_2X4: // 41
1244 case ShuffleMode::T16_8X2: // 42
1245 case ShuffleMode::T16_2X8: // 43
1246 case ShuffleMode::T16_16X2: // 44
1247 case ShuffleMode::T16_2X16: // 45
1248 requireRhs = false;
1249 LLVM_FALLTHROUGH;
1250 case ShuffleMode::T16_32X2_LO: // 2
1251 case ShuffleMode::T16_32X2_HI: // 3
1252 case ShuffleMode::T16_2X32_LO: // 18
1253 case ShuffleMode::T16_2X32_HI: // 19
1254 case ShuffleMode::T16_16X4_LO: // 24
1255 case ShuffleMode::T16_16X4_HI: // 25
1256 case ShuffleMode::T16_4X16_LO: // 26
1257 case ShuffleMode::T16_4X16_HI: // 27
1258 modeBitWidth = 16u;
1259 break;
1260 case ShuffleMode::T32_4X4: // 34
1261 requireRhs = false;
1262 LLVM_FALLTHROUGH;
1263 case ShuffleMode::T32_16X2_LO: // 4
1264 case ShuffleMode::T32_16X2_HI: // 5
1265 case ShuffleMode::T32_2X16_LO: // 16
1266 case ShuffleMode::T32_2X16_HI: // 17
1267 case ShuffleMode::T32_8X4_LO: // 30
1268 case ShuffleMode::T32_8X4_HI: // 31
1269 case ShuffleMode::T32_4X8_LO: // 32
1270 case ShuffleMode::T32_4X8_HI: // 33
1271 modeBitWidth = 32u;
1272 break;
1273 case ShuffleMode::T64_8X2_LO: // 6
1274 case ShuffleMode::T64_8X2_HI: // 7
1275 case ShuffleMode::T64_2X8_LO: // 14
1276 case ShuffleMode::T64_2X8_HI: // 15
1277 modeBitWidth = 64u;
1278 break;
1279 case ShuffleMode::T128_4X2_LO: // 8
1280 case ShuffleMode::T128_4X2_HI: // 9
1281 case ShuffleMode::T128_2X4_LO: // 12
1282 case ShuffleMode::T128_2X4_HI: // 13
1283 modeBitWidth = 128u;
1284 break;
1285 case ShuffleMode::T256_2X2_LO: // 10
1286 case ShuffleMode::T256_2X2_HI: // 11
1287 modeBitWidth = 256u;
1288 break;
1289 case ShuffleMode::T512_1X2_LO: // 22
1290 case ShuffleMode::T512_1X2_HI: // 23
1291 modeBitWidth = 512u;
1292 break;
1293 }
1294
1295 // Verify number of operands
1296 if (requireRhs && !getRhs())
1297 return emitError() << "shuffle mode '" << stringifyEnum(mode)
1298 << "' requires a second operand";
1299
1300 if (!requireRhs && getRhs())
1301 return emitError() << "shuffle mode '" << stringifyEnum(mode)
1302 << "' does not admit a second operand";
1303
1304 // Verify vector element type
1305 auto elemBitWidth =
1306 cast<VectorType>(getLhs().getType()).getElementTypeBitWidth();
1307 if (modeBitWidth != elemBitWidth)
1308 return emitError() << "shuffle mode '" << stringifyEnum(mode)
1309 << "' requires vectors of " << modeBitWidth
1310 << "-bit elements";
1311
1312 return success();
1313}
1314
1315// Print out Shuffle op.
1316void LegacyShuffleOp::print(OpAsmPrinter &p) {
1317 // Print the source vector
1318 p << " " << getSource();
1319
1320 // Print the attributes
1321 p.printOptionalAttrDict((*this)->getAttrs());
1322
1323 // And now print the types
1324 p << " : " << getSource().getType() << ", " << getResult().getType();
1325}
1326
1327// Verify Shuffle op.
1328LogicalResult LegacyShuffleOp::verify() {
1329 // Verify the types
1330 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
1331 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
1332 if (!sourceType || !resultType)
1333 return emitError("requires vector type");
1334
1335 // The number of lanes must match
1336 unsigned sourceLanes = getVectorLaneSize(sourceType);
1337 unsigned resultLanes = getVectorLaneSize(resultType);
1338 if (sourceLanes != resultLanes)
1339 return emitError("The number of lanes in input and "
1340 "output vector must match");
1341
1342 Type stype = sourceType.getElementType();
1343 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
1344 Type rtype = resultType.getElementType();
1345 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
1346
1347 if (stypeWidth != rtypeWidth)
1348 return emitError("The type width in input and "
1349 "output must match");
1350
1351 return success();
1352}
1353
1354// Parse Shuffle op.
1355ParseResult LegacyShuffleOp::parse(OpAsmParser &parser,
1356 OperationState &result) {
1357 llvm::SMLoc typesLoc;
1358 SmallVector<Type, 2> types;
1359 OpAsmParser::UnresolvedOperand source;
1360
1361 // Parse the source vector
1362 if (parser.parseOperand(source))
1363 return failure();
1364
1365 // Parse all the attributes and types
1366 if (parser.parseOptionalAttrDict(result.attributes) ||
1367 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1368 return failure();
1369
1370 // Currently there is one attribute in shuffle op
1371 if (result.attributes.getAttrs().size() != 1)
1372 return parser.emitError(typesLoc, "expects one attribute");
1373
1374 // Some verification
1375 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
1376 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
1377 if (!sourceType || !resultType)
1378 return parser.emitError(typesLoc, "requires vector type");
1379
1380 // Populate the source vectors in result
1381 if (parser.resolveOperand(source, sourceType, result.operands))
1382 return failure();
1383
1384 return parser.addTypeToList(resultType, result.types);
1385}
1386
1387//===----------------------------------------------------------------------===//
1388// MulConvOp and FMAConvOp
1389//===----------------------------------------------------------------------===//
1390
// MulConvOp and FMAConvOp are structurally similar, except that FMAConvOp
// has a few extra fields (an accumulator operand, a bool flag indicating
// whether it is an fmsub, etc.). The template specializations below print
// those fields specifically for FMAConvOp and MulConvOp.
1395
1396// Print the accumulator
1397template <typename T>
1398void printAccumulator(OpAsmPrinter &p, T op);
1399template <>
1400inline void printAccumulator(OpAsmPrinter &p, aievec::FMAConvOp op) {
1401 p << ", " << op.getAcc();
1402}
1403template <>
1404inline void printAccumulator(OpAsmPrinter &p, aievec::MulConvOp op) {}
1405
1406// Mark fmsub indicator as elided if the FMAElem op is not fmsub
1407template <typename T>
1408void elideFMSubAttr(T op, SmallVector<StringRef, 4> &elidedAttrs);
1409template <>
1410inline void elideFMSubAttr(FMAConvOp op,
1411 SmallVector<StringRef, 4> &elidedAttrs) {
1412 if (!op.getFmsub())
1413 elidedAttrs.push_back(op.getSubAttrName());
1414}
1415
1416template <>
1417inline void elideFMSubAttr(MulConvOp op,
1418 SmallVector<StringRef, 4> &elidedAttrs) {}
1419
1420// Print out MulConv and FMAConv op.
1421template <typename T>
1422static void printMulFMAConvOp(OpAsmPrinter &p, T op) {
1423 // Print the left operand
1424 p << " " << op.getLhs();
1425 // Print the right operand
1426 p << ", " << op.getRhs();
1427 // For fma op, print the accumulator
1428 printAccumulator(p, op);
1429
1430 // Print the attributes, but don't print attributes that are empty strings
1431 SmallVector<StringRef, 4> elidedAttrs;
1432 for (int idx = 0; idx < 2; ++idx) {
1433 elideFMSubAttr(op, elidedAttrs);
1434 }
1435 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
1436
1437 // And now print the types
1438 p << " : " << op.getLhs().getType() << ", " << op.getRhs().getType();
1439 p << ", " << op.getResult().getType();
1440}
1441
1442void MulConvOp::print(OpAsmPrinter &p) {
1443 printMulFMAConvOp<aievec::MulConvOp>(p, *this);
1444}
1445
1446void aievec::FMAConvOp::print(OpAsmPrinter &p) {
1447 printMulFMAConvOp<aievec::FMAConvOp>(p, *this);
1448}
1449
1450// Verify MulConv and FMAConv op.
1451template <typename T>
1452LogicalResult verifyMulFMAConvOp(T op) {
1453 // Verify the types
1454 auto lhsType = llvm::dyn_cast<VectorType>(op.getLhs().getType());
1455 auto rhsType = llvm::dyn_cast<VectorType>(op.getRhs().getType());
1456
1457 if (!lhsType || !rhsType)
1458 return op.emitError("requires vector type");
1459
1460 unsigned M = op.getM();
1461 unsigned N = op.getN();
1462
1463 if (M <= 0 || N <= 0 || 2 * M < M + N - 1)
1464 return op.emitError(
1465 "M and N should be larger than 0 and 2*M should be no less than M+N-1");
1466
1467 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
1468
1469 if (!resultType)
1470 return op.emitError("requires vector type");
1471
1472 // Additional checks for FMAElem op
1473 // Get the width of the underlying scalars of all the vectors
1474 Type ltype = lhsType.getElementType();
1475 Type rtype = rhsType.getElementType();
1476 Type atype = resultType.getElementType();
1477
1478 // lhs and rhs vector's element type must match
1479 if (ltype != rtype)
1480 return op.emitError("The element type of lhs and rhs "
1481 "operand vectors must match");
1482
1483 if (!isa<IntegerType>(ltype) || !isa<IntegerType>(rtype) ||
1484 !isa<IntegerType>(atype)) {
1485 return op.emitError("requires integer type");
1486 }
1487
1488 unsigned ltypeWidth = ltype.getIntOrFloatBitWidth();
1489 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
1490 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
1491
1492 // Checks on the number of lanes
1493 unsigned accLanes = getVectorLaneSize(resultType);
1494 unsigned rhsLanes = getVectorLaneSize(rhsType);
1495 unsigned lhsLanes = getVectorLaneSize(lhsType);
1496
1497 // lane size must match
1498 if (accLanes != M || accLanes != (rhsLanes / 2) || lhsLanes != rhsLanes) {
1499 return op.emitError(
1500 "The number of lanes in accumulator "
1501 "must be the same as M and the half as lhs and rhs operand");
1502 }
1503
1504 // The integer datatype of accumulator must always be greater width
1505 if (ltypeWidth >= atypeWidth || rtypeWidth >= atypeWidth)
1506 return op.emitError("the element type of accumulator must have "
1507 "wider width than that of the operand vectors");
1508
1509 return success();
1510}
1511
1512LogicalResult aievec::MulConvOp::verify() {
1513 return verifyMulFMAConvOp<aievec::MulConvOp>(*this);
1514}
1515
1516LogicalResult aievec::FMAConvOp::verify() {
1517 return verifyMulFMAConvOp<aievec::FMAConvOp>(*this);
1518}
1519
1520// Parse MulConv and FMAConv op.
1521ParseResult parseMulFMAConvOp(OpAsmParser &parser, OperationState &result,
1522 bool isFMAConvOp = true) {
1523 llvm::SMLoc typesLoc;
1524 SmallVector<Type, 3> types;
1525 OpAsmParser::UnresolvedOperand lhs, rhs, acc;
1526
1527 // Parse the lhs and rhs
1528 if (parser.parseOperand(lhs) || parser.parseComma() ||
1529 parser.parseOperand(rhs))
1530 return failure();
1531
1532 // Parse the acc for FMA op
1533 if (isFMAConvOp) {
1534 if (parser.parseComma() || parser.parseOperand(acc))
1535 return failure();
1536 }
1537
1538 // Parse all the attributes and types
1539 if (parser.parseOptionalAttrDict(result.attributes) ||
1540 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1541 return failure();
1542
1543 // Assert that there are three types: lhs, rhs, and acc
1544 if (types.size() != 3)
1545 return parser.emitError(typesLoc, "requires three types");
1546
1547 // Some verification
1548 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
1549 if (!lhsType)
1550 return parser.emitError(typesLoc, "requires vector type");
1551 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
1552 if (!rhsType)
1553 return parser.emitError(typesLoc, "requires vector type");
1554
1555 // Int ops use the accumulator
1556 VectorType accType = llvm::dyn_cast<VectorType>(types[2]);
1557 if (!accType)
1558 return parser.emitError(typesLoc, "requires vector type");
1559
1560 // Populate the lhs and rhs operands, and result
1561 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
1562 parser.resolveOperand(rhs, rhsType, result.operands))
1563 return failure();
1564
1565 // Populate acc operand for FMA op
1566 if (isFMAConvOp) {
1567 if (parser.resolveOperand(acc, accType, result.operands))
1568 return failure();
1569 }
1570
1571 return parser.addTypeToList(accType, result.types);
1572}
1573
1574ParseResult MulConvOp::parse(OpAsmParser &parser, OperationState &result) {
1575 return parseMulFMAConvOp(parser, result, false);
1576}
1577
1578ParseResult FMAConvOp::parse(OpAsmParser &parser, OperationState &result) {
1579 return parseMulFMAConvOp(parser, result, true);
1580}
1581
1582#define GET_ATTRDEF_CLASSES
1583#include "aie/Dialect/AIEVec/IR/AIEVecAttributes.cpp.inc"
1584
1585#define GET_OP_CLASSES
1586#include "aie/Dialect/AIEVec/IR/AIEVecOps.cpp.inc"
ParseResult parsePackUnpackOp(OpAsmParser &parser, OperationState &result)
void printAccumulator(OpAsmPrinter &p, T op)
ParseResult parseMulFMAElemOp(OpAsmParser &parser, OperationState &result, bool isFMAElemOp=true)
ParseResult parseMulFMAConvOp(OpAsmParser &parser, OperationState &result, bool isFMAConvOp=true)
void elideFMSubAttr(T op, SmallVector< StringRef, 4 > &elidedAttrs)
LogicalResult verifyMulFMAElemOp(T op)
LogicalResult verifyMulFMAConvOp(T op)
LogicalResult verifyPackUnpackOp(T op)
bool empty(const std::string &s)
Definition cxxopts.hpp:275
std::vector< Port > srcs
unsigned getVectorLaneSize(mlir::VectorType type)
Definition AIEVecUtils.h:55