MLIR-AIE
AIEVecOps.cpp
Go to the documentation of this file.
1//===---- AIEVecOps.cpp - MLIR AIE Vector Dialect Operations ----*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2022-2024 Advanced Micro Devices, Inc. or its affiliates
8//
9//===----------------------------------------------------------------------===//
// This file implements AIE vector op printing, parsing, and verification.
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/OpDefinition.h"
16#include "mlir/IR/TypeUtilities.h"
17#include "mlir/Transforms/FoldUtils.h"
18#include "llvm/ADT/TypeSwitch.h"
19
22
23using namespace llvm;
24using namespace mlir;
25using namespace xilinx;
26using namespace xilinx::aievec;
27
28#include "aie/Dialect/AIEVec/IR/AIEVecEnums.cpp.inc"
29#include "aie/Dialect/AIEVec/IR/AIEVecOpsDialect.cpp.inc"
30
31//===----------------------------------------------------------------------===//
32// AIEVecDialect
33//===----------------------------------------------------------------------===//
34
// Register the dialect's types, attributes, and operations with the context.
// registerTypes() is defined elsewhere (presumably TableGen-generated).
// The attribute and operation class lists come from TableGen: defining
// GET_ATTRDEF_LIST / GET_OP_LIST before including the generated .cpp.inc file
// makes that include expand into the template argument list of
// addAttributes<...>() / addOperations<...>() (standard MLIR dialect pattern).
void AIEVecDialect::initialize() {
  registerTypes();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "aie/Dialect/AIEVec/IR/AIEVecAttributes.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "aie/Dialect/AIEVec/IR/AIEVecOps.cpp.inc"
      >();
}
46
47//===----------------------------------------------------------------------===//
48// UPDOp
49//===----------------------------------------------------------------------===//
50
51// Print out UPD op.
52void UPDOp::print(OpAsmPrinter &p) {
53 // Print the source memref
54 p << " " << getSource() << "[" << getIndices() << "]";
55 // Now print the optional vector that links upd idx=1 with idx=0
56 if (getVector())
57 p << ", " << getVector();
58
59 // Print the attributes, but don't print the operand segment sizes
60 SmallVector<StringRef, 3> elidedAttrs;
61 elidedAttrs.push_back(UPDOp::getOperandSegmentSizeAttr());
62 p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
63
64 // And now print the types
65 p << " : " << getSource().getType() << ", " << getResult().getType();
66}
67
68// Verify UPD op.
69LogicalResult UPDOp::verify() {
70 // Verify the types: source is memref, and result is vector
71 MemRefType sourceType = llvm::dyn_cast<MemRefType>(getSource().getType());
72 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
73 if (!sourceType)
74 return emitError("requires memref type");
75 if (!resultType)
76 return emitError("requires vector type");
77 if (getIndices().empty())
78 return emitError("upd source cannot come from scalar value");
79
80 // If this UPD op is linked to another UPD op, then verify that the linked
81 // vector and the result vector match.
82 if (getVector()) {
83 Type vecType = llvm::dyn_cast<VectorType>(getVector().getType());
84 if (vecType != resultType)
85 return emitError("result types of linked UPD ops do not match");
86 }
87 return success();
88}
89
90// Parse UPD op.
91ParseResult UPDOp::parse(OpAsmParser &parser, OperationState &result) {
92 auto &builder = parser.getBuilder();
93 llvm::SMLoc typesLoc;
94 SmallVector<Type, 2> types;
95 OpAsmParser::UnresolvedOperand source, vector;
96 SmallVector<OpAsmParser::UnresolvedOperand, 8> indices;
97
98 // Parse the source, indices, and optional vector
99 if (parser.parseOperand(source) ||
100 parser.parseOperandList(indices, OpAsmParser::Delimiter::Square))
101 return failure();
102 ParseResult hasVector = parser.parseOptionalComma();
103 if (hasVector.succeeded() && parser.parseOperand(vector))
104 return failure();
105
106 // Parse all the attributes and types
107 if (parser.parseOptionalAttrDict(result.attributes) ||
108 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
109 return failure();
110
111 if (result.attributes.getAttrs().size() != 2)
112 return parser.emitError(typesLoc, "requires two attributes");
113
114 // Assert that there are two types (memref source and vector result)
115 if (types.size() != 2)
116 return parser.emitError(typesLoc, "requires two types");
117
118 // Some verification
119 auto memrefType = llvm::dyn_cast<MemRefType>(types[0]);
120 if (!memrefType)
121 return parser.emitError(typesLoc, "requires memref type");
122 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
123 if (!vectorType)
124 return parser.emitError(typesLoc, "requires vector type");
125 auto indicesType = builder.getIndexType();
126
127 // Populate the source and indices in result
128 if (parser.resolveOperand(source, memrefType, result.operands) ||
129 parser.resolveOperands(indices, indicesType, result.operands))
130 return failure();
131 // Populate optional vector in result
132 if (hasVector.succeeded())
133 if (parser.resolveOperand(vector, vectorType, result.operands))
134 return failure();
135
136 // Populate operand size attribute in result
137 result.addAttribute(UPDOp::getOperandSegmentSizeAttr(),
138 builder.getDenseI32ArrayAttr(
139 {1, static_cast<int32_t>(indices.size()),
140 static_cast<int32_t>(hasVector.succeeded())}));
141
142 return parser.addTypeToList(vectorType, result.types);
143}
144
145//===----------------------------------------------------------------------===//
146// CastOp
147//===----------------------------------------------------------------------===//
148
149// Print out Cast op.
150void CastOp::print(OpAsmPrinter &p) {
151 // Print the source accumulator
152 p << " " << getSource();
153
154 // Print the attributes
155 p.printOptionalAttrDict((*this)->getAttrs());
156
157 // And now print the types
158 p << " : " << getSource().getType() << ", " << getResult().getType();
159}
160
161// Verify Cast op.
162LogicalResult CastOp::verify() {
163 // Verify the types
164 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
165 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
166 if (!sourceType)
167 return emitError("requires source vector type");
168 if (!resultType)
169 return emitError("requires result vector type");
170
171 if (sourceType.getElementType().getIntOrFloatBitWidth() !=
172 resultType.getElementType().getIntOrFloatBitWidth()) {
173 return emitError("the bitwidth of resource and result should be equal");
174 }
175
176 return success();
177}
178
179// Parse Cast op.
180ParseResult CastOp::parse(OpAsmParser &parser, OperationState &result) {
181 llvm::SMLoc typesLoc;
182 SmallVector<Type, 2> types;
183 OpAsmParser::UnresolvedOperand source;
184
185 // Parse the source vector
186 if (parser.parseOperand(source))
187 return failure();
188
189 // Parse all the attributes and types
190 if (parser.parseOptionalAttrDict(result.attributes) ||
191 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
192 return failure();
193
194 if (result.attributes.getAttrs().size() != 1)
195 return parser.emitError(typesLoc, "requires one attribute");
196
197 // Assert that there are two types (source and result)
198 if (types.size() != 2)
199 return parser.emitError(typesLoc, "requires two types");
200
201 // Some verification of types
202 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
203 if (!sourceType)
204 return parser.emitError(typesLoc, "requires vector type");
205 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
206 if (!vectorType)
207 return parser.emitError(typesLoc, "requires vector type");
208
209 // Populate the source in result
210 if (parser.resolveOperand(source, sourceType, result.operands))
211 return failure();
212
213 return parser.addTypeToList(vectorType, result.types);
214}
215
216// Cast fold method. It will fold with a preceding Cast operation.
217OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
218 auto srcCastOp = getSource().getDefiningOp<aievec::CastOp>();
219 if (!srcCastOp)
220 return nullptr;
221
222 if (srcCastOp.getIsResAcc() == getIsResAcc())
223 return srcCastOp.getResult();
224
225 return srcCastOp.getSource();
226}
227
228//===----------------------------------------------------------------------===//
229// SRSOp
230//===----------------------------------------------------------------------===//
231
232// SRS fold method. It will fold with a preceding UPS operation.
233OpFoldResult SRSOp::fold(FoldAdaptor adaptor) {
234 auto srcDefOp = getSource().getDefiningOp();
235 if (!srcDefOp)
236 return nullptr;
237
238 auto upsOp = dyn_cast<UPSOp>(srcDefOp);
239 if (!upsOp)
240 return nullptr;
241
242 auto shiftDefOp = getShift().getDefiningOp();
243 if (!shiftDefOp)
244 return nullptr;
245
246 auto constOp = dyn_cast<arith::ConstantOp>(shiftDefOp);
247 if (!constOp)
248 return nullptr;
249
250 if (upsOp.getSource().getType() != getResult().getType())
251 return nullptr;
252
253 return upsOp.getSource();
254}
255
256// Print out SRS op.
257void SRSOp::print(OpAsmPrinter &p) {
258 // Print the source accumulator
259 p << " " << getSource() << ", ";
260
261 // Print the shift
262 p << getShift();
263
264 // And now print the types
265 p << " : " << getSource().getType() << ", " << getShift().getType() << ", "
266 << getResult().getType();
267}
268
269// Verify SRS op.
270LogicalResult SRSOp::verify() {
271 // Verify the types
272 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
273 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
274 if (!sourceType)
275 return emitError("requires accumulator type");
276 if (!resultType)
277 return emitError("requires vector type");
278
279 // The number of lanes of source accumulator and result vector must match
280 unsigned accLanes = getVectorLaneSize(sourceType);
281 unsigned vecLanes = getVectorLaneSize(resultType);
282 if (accLanes != vecLanes)
283 return emitError("The number of lanes in result vector "
284 "and source accumulator must match");
285
286 // The datatype of accumulator must have greater width
287 Type stype = resultType.getElementType();
288 Type atype = sourceType.getElementType();
289 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
290 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
291
292 if (isa<IntegerType>(atype) && stypeWidth >= atypeWidth)
293 return emitError("the element type of source accumulator must be "
294 "wider than that of the result vector");
295 else if (isa<FloatType>(atype) && stypeWidth != 16 &&
296 stypeWidth != atypeWidth)
297 return emitError("the element type of source accumulator must be "
298 "same as the result vector");
299
300 return success();
301}
302
303// Parse SRS op.
304ParseResult SRSOp::parse(OpAsmParser &parser, OperationState &result) {
305 llvm::SMLoc typesLoc;
306 SmallVector<Type, 3> types;
307 OpAsmParser::UnresolvedOperand source, shift;
308
309 // Parse the source accumulator
310 if (parser.parseOperand(source) || parser.parseComma() ||
311 parser.parseOperand(shift))
312 return failure();
313
314 // Parse types
315 if (parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
316 return failure();
317
318 // Assert that there are two types (accumulator source and vector result)
319 if (types.size() != 3)
320 return parser.emitError(typesLoc, "requires three types");
321
322 // Some verification of types
323 VectorType accType = llvm::dyn_cast<VectorType>(types[0]);
324 if (!accType)
325 return parser.emitError(typesLoc, "requires vector type");
326
327 IntegerType shiftType = llvm::dyn_cast<IntegerType>(types[1]);
328 if (!shiftType)
329 return parser.emitError(typesLoc, "requires integer type");
330
331 VectorType vectorType = llvm::dyn_cast<VectorType>(types[2]);
332 if (!vectorType)
333 return parser.emitError(typesLoc, "requires vector type");
334
335 // Populate the source in result
336 if (parser.resolveOperand(source, accType, result.operands) ||
337 parser.resolveOperand(shift, shiftType, result.operands))
338 return failure();
339
340 return parser.addTypeToList(vectorType, result.types);
341}
342
343//===----------------------------------------------------------------------===//
344// UPSOp
345//===----------------------------------------------------------------------===//
346
347// UPS fold method. It will fold with a preceding SRS operation.
348OpFoldResult UPSOp::fold(FoldAdaptor adaptor) {
349 // TODO: Both UPS and SRS have an aditional parameter (shift) that's being
350 // TODO: ignored here. Somebody should take a careful look at it.
351 // TODO: In next llvm version: auto srsDefOp =
352 // adaptor.getSource().getDefiningOp();
353 auto srcDefOp = getSource().getDefiningOp();
354 if (!srcDefOp)
355 return nullptr;
356 auto srsOp = llvm::dyn_cast<SRSOp>(srcDefOp);
357 if (!srsOp)
358 return nullptr;
359 return srsOp.getSource();
360}
361
362// Print out UPS op.
363void UPSOp::print(OpAsmPrinter &p) {
364 // Print the source vector
365 p << " " << getSource();
366
367 // Print the attributes
368 p.printOptionalAttrDict((*this)->getAttrs());
369
370 // And now print the types
371 p << " : " << getSource().getType() << ", " << getResult().getType();
372}
373
374// Verify UPS op.
375LogicalResult UPSOp::verify() {
376 // Verify the types
377 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
378 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
379 if (!sourceType)
380 return emitError("requires vector type");
381 if (!resultType)
382 return emitError("requires vector type");
383
384 // The number of lanes must match
385 unsigned vecLanes = getVectorLaneSize(sourceType);
386 unsigned accLanes = getVectorLaneSize(resultType);
387 if (vecLanes != accLanes)
388 return emitError("The number of lanes in source vector "
389 "and result accumulator must match");
390
391 // The datatype of accumulator must always be greater width
392 Type stype = sourceType.getElementType();
393 Type atype = resultType.getElementType();
394 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
395 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
396
397 if (stypeWidth >= atypeWidth)
398 return emitError("the element type of result accumulator "
399 "must be wider than that of the source vector");
400
401 return success();
402}
403
404// Parse UPS op.
405ParseResult UPSOp::parse(OpAsmParser &parser, OperationState &result) {
406 llvm::SMLoc typesLoc;
407 SmallVector<Type, 2> types;
408 OpAsmParser::UnresolvedOperand source;
409
410 // Parse the source vector
411 if (parser.parseOperand(source))
412 return failure();
413
414 // Parse all the attributes and types
415 if (parser.parseOptionalAttrDict(result.attributes) ||
416 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
417 return failure();
418
419 if (result.attributes.getAttrs().size() != 1)
420 return parser.emitError(typesLoc, "requires one attribute");
421
422 // Assert that there are two types (source vector and accumulator result)
423 if (types.size() != 2)
424 return parser.emitError(typesLoc, "requires two types");
425
426 // Some verification
427 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
428 if (!vectorType)
429 return parser.emitError(typesLoc, "requires vector type");
430 VectorType accType = llvm::dyn_cast<VectorType>(types[1]);
431 if (!accType)
432 return parser.emitError(typesLoc, "requires vector type");
433
434 // Populate the source in result
435 if (parser.resolveOperand(source, vectorType, result.operands))
436 return failure();
437
438 return parser.addTypeToList(accType, result.types);
439}
440
441//===----------------------------------------------------------------------===//
442// BroadcastOp
443//===----------------------------------------------------------------------===//
444
445// Print out Broadcast op.
446void BroadcastOp::print(OpAsmPrinter &p) {
447 // Print the source vector
448 p << " " << getSource();
449
450 // Print the attributes
451 p.printOptionalAttrDict((*this)->getAttrs());
452
453 // And now print the types
454 p << " : " << getSource().getType() << ", " << getResult().getType();
455}
456
457// Verify Broadcast op.
458LogicalResult BroadcastOp::verify() {
459 // Verify the types
460 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
461 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
462
463 if (!sourceType)
464 return emitError("requires vector type");
465 if (!resultType)
466 return emitError("requires vector type");
467
468 if (sourceType != resultType) {
469 return emitError("The vector type of source vector "
470 "and result vector must match");
471 }
472 // The number of lanes must match
473 unsigned sourceLanes = getVectorLaneSize(sourceType);
474 unsigned resultLanes = getVectorLaneSize(resultType);
475 if (sourceLanes != resultLanes)
476 return emitError("The number of lanes in source vector "
477 "and result vector must match");
478
479 // The element type of vectors must always be the same
480 Type stype = sourceType.getElementType();
481 Type rtype = resultType.getElementType();
482
483 if (stype != rtype) {
484 return emitError("the element type of result vector "
485 "must be the same as source vector");
486 }
487
488 return success();
489}
490
491// Parse Broadcast op.
492ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
493 llvm::SMLoc typesLoc;
494 SmallVector<Type, 2> types;
495 OpAsmParser::UnresolvedOperand source;
496
497 // Parse the source vector
498 if (parser.parseOperand(source))
499 return failure();
500
501 // Parse all the attributes and types
502 if (parser.parseOptionalAttrDict(result.attributes) ||
503 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
504 return failure();
505
506 if (result.attributes.getAttrs().size() != 1)
507 return parser.emitError(typesLoc, "requires one attribute");
508
509 // Assert that there are two types (source vector and result vector)
510 if (types.size() != 2)
511 return parser.emitError(typesLoc, "requires two types");
512
513 // Some verification
514 VectorType vecType = llvm::dyn_cast<VectorType>(types[0]);
515 if (!vecType)
516 return parser.emitError(typesLoc, "requires vector type");
517
518 VectorType resType = llvm::dyn_cast<VectorType>(types[1]);
519 if (!resType)
520 return parser.emitError(typesLoc, "requires vector type");
521
522 // Populate the source in result
523 if (parser.resolveOperand(source, vecType, result.operands))
524 return failure();
525
526 return parser.addTypeToList(resType, result.types);
527}
528
529//===----------------------------------------------------------------------===//
530// BroadcastScalarOp
531//===----------------------------------------------------------------------===//
532
533// Print out BroadcastScalar op.
534void BroadcastScalarOp::print(OpAsmPrinter &p) {
535 // Print the source vector
536 p << " " << getSource();
537
538 // And now print the types
539 p << " : " << getSource().getType() << ", " << getResult().getType();
540}
541
542// Verify BroadcastScalar op.
543LogicalResult BroadcastScalarOp::verify() {
544 // Verify the types
545 Type sourceType = getSource().getType();
546 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
547
548 if (!resultType)
549 return emitError("requires vector type");
550
551 if (!isa<IntegerType, FloatType>(sourceType))
552 return emitError("requires source type to be integer or float");
553
554 Type resultElemType = resultType.getElementType();
555 if (sourceType != resultElemType) {
556 return emitError("the element type of result vector must be the same as "
557 "the source type");
558 }
559
560 return success();
561}
562
563// Parse BroadcastScalar op.
564ParseResult BroadcastScalarOp::parse(OpAsmParser &parser,
565 OperationState &result) {
566 llvm::SMLoc typesLoc;
567 SmallVector<Type, 2> types;
568 OpAsmParser::UnresolvedOperand source;
569
570 // Parse the source vector
571 if (parser.parseOperand(source))
572 return failure();
573
574 // Parse all the attributes and types
575 if (parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
576 return failure();
577
578 if (!result.attributes.getAttrs().empty())
579 return parser.emitError(typesLoc, "do not require attributes");
580
581 // Assert that there is two type (source and result vector)
582 if (types.size() != 2)
583 return parser.emitError(typesLoc, "requires two types");
584
585 // Some verification
586 VectorType resType = llvm::dyn_cast<VectorType>(types[1]);
587 if (!resType)
588 return parser.emitError(typesLoc, "requires vector type");
589
590 if (parser.resolveOperand(source, types[0], result.operands))
591 return failure();
592
593 return parser.addTypeToList(resType, result.types);
594}
595
596//===----------------------------------------------------------------------===//
597// MulElemOp and FMAElemOp
598//===----------------------------------------------------------------------===//
599
// MulElemOp and FMAElemOp are structurally similar, except that FMAElemOp
// has a few extra fields (an accumulator operand, a bool flag indicating
// whether it is an fmsub, etc.). The template specializations below print
// those fields for FMAElemOp and MulElemOp respectively.
604
605// Print the accumulator
606template <typename T>
607void printAccumulator(OpAsmPrinter &p, T op);
608template <>
609inline void printAccumulator(OpAsmPrinter &p, aievec::FMAElemOp op) {
610 p << ", " << op.getAcc();
611}
612template <>
613inline void printAccumulator(OpAsmPrinter &p, aievec::MulElemOp op) {}
614
615// Mark fmsub indicator as elided if the FMAElem op is not fmsub
616template <typename T>
617void elideFMSubAttr(T op, SmallVector<StringRef, 4> &elidedAttrs);
618template <>
619inline void elideFMSubAttr(aievec::FMAElemOp op,
620 SmallVector<StringRef, 4> &elidedAttrs) {
621 if (!op.getFmsub())
622 elidedAttrs.push_back(op.getSubAttrName());
623}
624
625template <>
626inline void elideFMSubAttr(aievec::MulElemOp op,
627 SmallVector<StringRef, 4> &elidedAttrs) {}
628
629// Print out MulElem and FMAElem op.
630template <typename T>
631static void printMulFMAElemOp(OpAsmPrinter &p, T op) {
632 // Print the left operand
633 p << " " << op.getLhs();
634 // Print the right operand
635 p << ", " << op.getRhs();
636 // For fma op, print the accumulator
637 printAccumulator(p, op);
638
639 // Print the attributes, but don't print attributes that are empty strings
640 SmallVector<StringRef, 4> elidedAttrs;
641 for (int idx = 0; idx < 2; ++idx) {
642 elideFMSubAttr(op, elidedAttrs);
643 }
644 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
645
646 // And now print the types
647 p << " : " << op.getLhs().getType() << ", " << op.getRhs().getType();
648 p << ", " << op.getResult().getType();
649}
650
651void MulElemOp::print(OpAsmPrinter &p) {
652 printMulFMAElemOp<aievec::MulElemOp>(p, *this);
653}
654
655void aievec::FMAElemOp::print(OpAsmPrinter &p) {
656 printMulFMAElemOp<aievec::FMAElemOp>(p, *this);
657}
658
659// Verify MulElem and FMAElem op.
660template <typename T>
661LogicalResult verifyMulFMAElemOp(T op) {
662 // Verify the types
663 auto lhsType = llvm::dyn_cast<VectorType>(op.getLhs().getType());
664 auto rhsType = llvm::dyn_cast<VectorType>(op.getRhs().getType());
665
666 if (!lhsType || !rhsType)
667 return op.emitError("requires vector type");
668
669 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
670
671 if (!resultType)
672 return op.emitError("requires vector type");
673
674 // Additional checks for FMAElem op
675 // Get the width of the underlying scalars of all the vectors
676 Type ltype = lhsType.getElementType();
677 Type rtype = rhsType.getElementType();
678 Type atype = resultType.getElementType();
679 unsigned ltypeWidth = ltype.getIntOrFloatBitWidth();
680 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
681 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
682
683 // Checks on the number of lanes
684 unsigned rhsLanes = getVectorLaneSize(rhsType);
685 unsigned lhsLanes = getVectorLaneSize(lhsType);
686
687 // lane size must match
688 if (lhsLanes != rhsLanes) {
689 return op.emitError("The number of lanes in lhs operand "
690 "must be the same as rhs operand");
691 }
692
693 // lhs and rhs vector's element type must match
694 if (ltype != rtype)
695 return op.emitError("The element type of lhs and rhs "
696 "operand vectors must match");
697
698 // The integer datatype of accumulator must always be greater width
699 if (isa<IntegerType>(atype)) {
700 if (!isa<IntegerType>(ltype))
701 return op.emitError("Integer result must have integer operands");
702
703 if (ltypeWidth >= atypeWidth || rtypeWidth >= atypeWidth)
704 return op.emitError("the element type of accumulator must have "
705 "wider width than that of the operand vectors");
706 } else if (isa<FloatType>(atype)) {
707 if (!isa<FloatType>(ltype))
708 return op.emitError("Floating point result must have "
709 "floating point operands");
710 }
711
712 return success();
713}
714
715LogicalResult aievec::MulElemOp::verify() {
716 return verifyMulFMAElemOp<aievec::MulElemOp>(*this);
717}
718
719LogicalResult aievec::FMAElemOp::verify() {
720 return verifyMulFMAElemOp<aievec::FMAElemOp>(*this);
721}
722
723// Parse MulElem and FMAElem op.
724ParseResult parseMulFMAElemOp(OpAsmParser &parser, OperationState &result,
725 bool isFMAElemOp = true) {
726 llvm::SMLoc typesLoc;
727 SmallVector<Type, 3> types;
728 OpAsmParser::UnresolvedOperand lhs, rhs, acc;
729
730 // Parse the lhs and rhs
731 if (parser.parseOperand(lhs) || parser.parseComma() ||
732 parser.parseOperand(rhs))
733 return failure();
734
735 // Parse the acc for FMA op
736 if (isFMAElemOp) {
737 if (parser.parseComma() || parser.parseOperand(acc))
738 return failure();
739 }
740
741 // Parse all the attributes and types
742 if (parser.parseOptionalAttrDict(result.attributes) ||
743 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
744 return failure();
745
746 // Assert that there are three types: lhs, rhs, and acc
747 if (types.size() != 3)
748 return parser.emitError(typesLoc, "requires three types");
749
750 // Some verification
751 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
752 if (!lhsType)
753 return parser.emitError(typesLoc, "requires vector type");
754 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
755 if (!rhsType)
756 return parser.emitError(typesLoc, "requires vector type");
757
758 // Int ops use the accumulator while float ops use normal vector registers
759 VectorType accType = llvm::dyn_cast<VectorType>(types[2]);
760 if (!accType)
761 return parser.emitError(typesLoc, "requires vector type");
762
763 // Populate the lhs and rhs operands, and result
764 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
765 parser.resolveOperand(rhs, rhsType, result.operands))
766 return failure();
767
768 // Populate acc operand for FMA op
769 if (isFMAElemOp) {
770 if (parser.resolveOperand(acc, accType, result.operands))
771 return failure();
772 }
773
774 return parser.addTypeToList(accType, result.types);
775}
776
777ParseResult MulElemOp::parse(OpAsmParser &parser, OperationState &result) {
778 return parseMulFMAElemOp(parser, result, false);
779}
780
781ParseResult FMAElemOp::parse(OpAsmParser &parser, OperationState &result) {
782 return parseMulFMAElemOp(parser, result, true);
783}
784
785//===----------------------------------------------------------------------===//
786// ConcatOp
787//===----------------------------------------------------------------------===//
788
789// Print out Concat op.
790void ConcatOp::print(OpAsmPrinter &p) {
791 // Print the source vectors
792 assert(!getSources().empty() && "concat source empty");
793 p << " " << getSources();
794
795 // Print the attributes
796 p.printOptionalAttrDict((*this)->getAttrs());
797
798 // And now print the types
799 p << " : " << getSources().getTypes().front() << ", "
800 << getResult().getType();
801}
802
803// Verify Concat op.
804LogicalResult ConcatOp::verify() {
805 // Must be concatenating at least two sources
806 if (getSources().size() < 2)
807 return emitError("Must concatenate at least two vectors");
808
809 // Verify the types
810 VectorType sourceType =
811 llvm::dyn_cast<VectorType>(getSources().getTypes().front());
812 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
813 if (!sourceType || !resultType)
814 return emitError("requires vector type");
815
816 SmallVector<Value, 8> srcs(getSources().begin(), getSources().end());
817 // All the sources must have the same type
818 for (auto source : srcs) {
819 VectorType type = llvm::dyn_cast<VectorType>(source.getType());
820 if (!type)
821 return emitError("requires vector type");
822 if (type != sourceType)
823 return emitError("All sources must have same type");
824 }
825
826 // The lanes in concatenated type must be the sum of lanes of source vector
827 unsigned totalLanes = 0;
828 for (auto source : srcs) {
829 VectorType type = llvm::dyn_cast<VectorType>(source.getType());
830 totalLanes += getVectorLaneSize(type);
831 }
832
833 if (totalLanes != getVectorLaneSize(resultType))
834 return emitError("mismatch between vector lanes "
835 "and sum of source lanes");
836
837 return success();
838}
839
840// Parse Concat op.
841ParseResult ConcatOp::parse(OpAsmParser &parser, OperationState &result) {
842 llvm::SMLoc typesLoc;
843 SmallVector<Type, 2> types;
844 SmallVector<OpAsmParser::UnresolvedOperand, 8> sources;
845
846 // Parse the source vectors
847 if (parser.parseOperandList(sources))
848 return failure();
849
850 // Parse all the attributes and types
851 if (parser.parseOptionalAttrDict(result.attributes) ||
852 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
853 return failure();
854
855 // Currently there are no attributes in concat op
856 if (!result.attributes.getAttrs().empty())
857 return parser.emitError(typesLoc, "expects no attribute");
858
859 // Assert that there are two types (type of all sources, and result)
860 if (types.size() != 2)
861 return parser.emitError(typesLoc, "requires two types");
862
863 // Some verification
864 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
865 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
866 if (!sourceType || !resultType)
867 return parser.emitError(typesLoc, "requires vector type");
868
869 // Populate the source vectors in result
870 if (parser.resolveOperands(sources, sourceType, result.operands))
871 return failure();
872
873 return parser.addTypeToList(resultType, result.types);
874}
875
876LogicalResult
877ConcatOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
878 ConcatOp::Adaptor adaptor,
879 SmallVectorImpl<Type> &inferredReturnTypes) {
880 SmallVector<Value, 8> srcs(adaptor.getSources().begin(),
881 adaptor.getSources().end());
882 unsigned totalLength = 0;
883 for (auto source : srcs) {
884 VectorType type = llvm::dyn_cast<VectorType>(source.getType());
885 assert(type.getRank() == 1 &&
886 "only rank 1 vectors currently supported by concat");
887 totalLength += type.getDimSize(0);
888 }
889 inferredReturnTypes.push_back(VectorType::get(
890 {totalLength},
891 llvm::dyn_cast<VectorType>(srcs[0].getType()).getElementType()));
892 return success();
893}
894
895//===----------------------------------------------------------------------===//
896// ExtOp
897//===----------------------------------------------------------------------===//
898
899// Print out Ext op.
900void ExtOp::print(OpAsmPrinter &p) {
901 // Print the source vector
902 p << " " << getSource();
903
904 // Print the attributes
905 p.printOptionalAttrDict((*this)->getAttrs());
906
907 // And now print the types
908 p << " : " << getSource().getType() << ", " << getResult().getType();
909}
910
911// Verify Ext op.
912LogicalResult ExtOp::verify() {
913 // Verify the types
914 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
915 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
916 if (!sourceType || !resultType)
917 return emitError("requires vector type");
918
919 // Check the number of lanes
920 unsigned sourceLanes = getVectorLaneSize(sourceType);
921 unsigned resultLanes = getVectorLaneSize(resultType);
922 // Source lanes must be greater than result lanes
923 if (sourceLanes / resultLanes <= 1)
924 return emitError("lanes in source vector must be at least "
925 "twice that of result vector");
926 // Source lanes must be a multiple of result lanes
927 if (sourceLanes % resultLanes != 0)
928 return emitError("lanes in result vector must be a multiple "
929 "of source vector lanes");
930
931 // Verify validity of index
932 unsigned factor = sourceLanes / resultLanes;
933 if (static_cast<unsigned>(getIndex()) >= factor)
934 return emitError("index out of bounds");
935
936 // The datatype of source and result must match
937 Type stype = sourceType.getElementType();
938 Type rtype = resultType.getElementType();
939 if (stype != rtype)
940 return emitError("source and result element type must be same");
941
942 return success();
943}
944
945// Parse Ext op.
946ParseResult ExtOp::parse(OpAsmParser &parser, OperationState &result) {
947 llvm::SMLoc typesLoc;
948 SmallVector<Type, 2> types;
949 OpAsmParser::UnresolvedOperand source;
950
951 // Parse the source vector
952 if (parser.parseOperand(source))
953 return failure();
954
955 // Parse all the attributes and types
956 if (parser.parseOptionalAttrDict(result.attributes) ||
957 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
958 return failure();
959
960 if (result.attributes.getAttrs().size() != 1)
961 return parser.emitError(typesLoc, "requires one attribute");
962
963 // Assert that there are two types (source and result)
964 if (types.size() != 2)
965 return parser.emitError(typesLoc, "requires two types");
966
967 // Some verification
968 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
969 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
970 if (!sourceType || !resultType)
971 return parser.emitError(typesLoc, "requires vector type");
972
973 // Populate the source in result
974 if (parser.resolveOperand(source, sourceType, result.operands))
975 return failure();
976
977 return parser.addTypeToList(resultType, result.types);
978}
979
980//===----------------------------------------------------------------------===//
981// PackOp and UnpackOp
982//===----------------------------------------------------------------------===//
983
984// Print out Pack and Unpack op.
985template <typename T>
986static void printPackUnpackOp(OpAsmPrinter &p, T op) {
987 // Print the source vector
988 p << " " << op.getSource();
989
990 // Print the attributes
991 p.printOptionalAttrDict(op->getAttrs());
992
993 // And now print the types
994 p << " : " << op.getSource().getType() << ", " << op.getResult().getType();
995}
996
997void PackOp::print(OpAsmPrinter &p) { printPackUnpackOp<PackOp>(p, *this); }
998
999void UnpackOp::print(OpAsmPrinter &p) { printPackUnpackOp<UnpackOp>(p, *this); }
1000
1001// Verify Pack and Unpack op.
1002template <typename T>
1003LogicalResult verifyPackUnpackOp(T op) {
1004 // Verify the types
1005 auto sourceType = llvm::dyn_cast<VectorType>(op.getSource().getType());
1006 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
1007 if (!sourceType || !resultType)
1008 return op.emitError("requires vector type");
1009
1010 // The number of lanes must match
1011 unsigned sourceLanes = getVectorLaneSize(sourceType);
1012 unsigned resultLanes = getVectorLaneSize(resultType);
1013 if (sourceLanes != resultLanes)
1014 return op.emitError("The number of lanes in input and "
1015 "output vector must match");
1016
1017 Type stype = sourceType.getElementType();
1018 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
1019 Type rtype = resultType.getElementType();
1020 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
1021
1022 if (isa<PackOp>(op)) {
1023 // The datatype of source must be i16, and datatype of result must be i8
1024 if (stypeWidth != 16)
1025 return op.emitError("input must be an int16 vector");
1026 if (rtypeWidth != 8)
1027 return op.emitError("output must be an int8 vector");
1028 } else {
1029 if (stypeWidth != 8)
1030 return op.emitError("input must be an int8 vector");
1031 if (rtypeWidth != 16)
1032 return op.emitError("output must be an int16 vector");
1033 }
1034
1035 return success();
1036}
1037
1038LogicalResult PackOp::verify() { return verifyPackUnpackOp<PackOp>(*this); }
1039
1040LogicalResult UnpackOp::verify() { return verifyPackUnpackOp<UnpackOp>(*this); }
1041
1042// Parse Pack and Unpack op.
1043ParseResult parsePackUnpackOp(OpAsmParser &parser, OperationState &result) {
1044 llvm::SMLoc typesLoc;
1045 SmallVector<Type, 2> types;
1046 OpAsmParser::UnresolvedOperand source;
1047
1048 // Parse the source vector
1049 if (parser.parseOperand(source))
1050 return failure();
1051
1052 // Parse all the attributes and types
1053 if (parser.parseOptionalAttrDict(result.attributes) ||
1054 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1055 return failure();
1056
1057 // Currently there are no attributes in pack/unpack op
1058 if (!result.attributes.getAttrs().empty())
1059 return parser.emitError(typesLoc, "expects no attributes");
1060
1061 // Assert that there are two types (source and result vectors)
1062 if (types.size() != 2)
1063 return parser.emitError(typesLoc, "requires two types");
1064
1065 // Some verification
1066 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
1067 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
1068 if (!sourceType || !resultType)
1069 return parser.emitError(typesLoc, "requires vector type");
1070
1071 // Populate the source in result
1072 if (parser.resolveOperand(source, sourceType, result.operands))
1073 return failure();
1074
1075 return parser.addTypeToList(resultType, result.types);
1076}
1077
1078ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1079 return parsePackUnpackOp(parser, result);
1080}
1081
1082ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1083 return parsePackUnpackOp(parser, result);
1084}
1085
1086//===----------------------------------------------------------------------===//
1087// ExtElemOp
1088//===----------------------------------------------------------------------===//
1089
1090// Verify Extract Element op.
1091LogicalResult ExtElemOp::verify() {
1092 // Verify the types
1093 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
1094
1095 if (!sourceType)
1096 return emitError("source requires vector type");
1097
1098 // The element type of vectors must always be the same
1099 Type stype = sourceType.getElementType();
1100 Type rtype = getResult().getType();
1101
1102 if (stype != rtype) {
1103 return emitError("the type of result must be the same as the element "
1104 "type of source vector");
1105 }
1106
1107 return success();
1108}
1109
1110//===----------------------------------------------------------------------===//
1111// ShiftOp
1112//===----------------------------------------------------------------------===//
1113
1114// Print out Shift op.
1115void ShiftOp::print(OpAsmPrinter &p) {
1116 // Print the lhs and rhs vectors
1117 p << " " << getLhs() << ", " << getRhs();
1118
1119 // Print shift
1120 p << ", " << getShift();
1121
1122 // Print the attributes
1123 p.printOptionalAttrDict((*this)->getAttrs());
1124
1125 // And now print the types
1126 p << " : " << getLhs().getType() << ", " << getLhs().getType() << ", "
1127 << getShift().getType() << ", " << getResult().getType();
1128}
1129
1130// Verify Shift op.
1131LogicalResult ShiftOp::verify() {
1132 // Verify the types
1133 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
1134 if (!resultType)
1135 return emitError("requires vector type");
1136
1137 // lhs, rhs and result must have the same type
1138 VectorType lhsType = llvm::dyn_cast<VectorType>(getLhs().getType());
1139 VectorType rhsType = llvm::dyn_cast<VectorType>(getRhs().getType());
1140
1141 if (!lhsType || !rhsType)
1142 return emitError("requires vector type");
1143 if (lhsType != resultType || rhsType != resultType)
1144 return emitError("All vectors must have same type");
1145
1146 if (!isa<IntegerType>(getShift().getType()))
1147 return emitError("requires integer type");
1148
1149 return success();
1150}
1151
1152// Parse Shift op.
1153ParseResult ShiftOp::parse(OpAsmParser &parser, OperationState &result) {
1154 llvm::SMLoc typesLoc;
1155 SmallVector<Type, 4> types;
1156 OpAsmParser::UnresolvedOperand lhs, rhs, shift;
1157
1158 // Parse the source vectors
1159 if (parser.parseOperand(lhs) || parser.parseComma() ||
1160 parser.parseOperand(rhs) || parser.parseComma() ||
1161 parser.parseOperand(shift))
1162 return failure();
1163
1164 // Parse all the attributes and types
1165 if (parser.parseOptionalAttrDict(result.attributes) ||
1166 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1167 return failure();
1168
1169 if (result.attributes.getAttrs().size() != 1)
1170 return parser.emitError(typesLoc, "expects one attribute");
1171
1172 // Assert that there are two types (source and result vectors)
1173 if (types.size() != 4)
1174 return parser.emitError(typesLoc, "requires four types");
1175
1176 // Some verification
1177 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
1178 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
1179 IntegerType shiftType = llvm::dyn_cast<IntegerType>(types[2]);
1180 VectorType resultType = llvm::dyn_cast<VectorType>(types[3]);
1181 if (!lhsType || !rhsType || !resultType)
1182 return parser.emitError(typesLoc, "requires vector type");
1183
1184 if (!shiftType)
1185 return parser.emitError(typesLoc, "requires integer type");
1186
1187 // Populate the lhs vector, rhs vectors and shift in result
1188 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
1189 parser.resolveOperand(rhs, rhsType, result.operands) ||
1190 parser.resolveOperand(shift, shiftType, result.operands))
1191 return failure();
1192
1193 return parser.addTypeToList(resultType, result.types);
1194}
1195
1196//===----------------------------------------------------------------------===//
1197// ShuffleOp
1198//===----------------------------------------------------------------------===//
1199
1200// This verification function makes sure that the shuffle mode supports the
1201// number and type of operands provided.
1202LogicalResult ShuffleOp::verify() {
1203 unsigned modeBitWidth;
1204 bool requireRhs = true;
1205 auto mode = getMode();
1206 switch (mode) {
1207 case ShuffleMode::T8_8X8: // 35
1208 case ShuffleMode::T8_16X4: // 36
1209 case ShuffleMode::T8_4X16: // 37
1210 case ShuffleMode::T8_8X4: // 46
1211 case ShuffleMode::T8_4X8: // 47
1212 requireRhs = false;
1213 LLVM_FALLTHROUGH;
1214 case ShuffleMode::T8_64X2_LO: // 0
1215 case ShuffleMode::T8_64X2_HI: // 1
1216 case ShuffleMode::T8_2X64_LO: // 20
1217 case ShuffleMode::T8_2X64_HI: // 21
1218 modeBitWidth = 8u;
1219 break;
1220 case ShuffleMode::T16_8X4: // 28
1221 case ShuffleMode::T16_4X8: // 29
1222 case ShuffleMode::T16_1X2_flip: // 38
1223 case ShuffleMode::T16_4X4: // 39
1224 case ShuffleMode::T16_4X2: // 40
1225 case ShuffleMode::T16_2X4: // 41
1226 case ShuffleMode::T16_8X2: // 42
1227 case ShuffleMode::T16_2X8: // 43
1228 case ShuffleMode::T16_16X2: // 44
1229 case ShuffleMode::T16_2X16: // 45
1230 requireRhs = false;
1231 LLVM_FALLTHROUGH;
1232 case ShuffleMode::T16_32X2_LO: // 2
1233 case ShuffleMode::T16_32X2_HI: // 3
1234 case ShuffleMode::T16_2X32_LO: // 18
1235 case ShuffleMode::T16_2X32_HI: // 19
1236 case ShuffleMode::T16_16X4_LO: // 24
1237 case ShuffleMode::T16_16X4_HI: // 25
1238 case ShuffleMode::T16_4X16_LO: // 26
1239 case ShuffleMode::T16_4X16_HI: // 27
1240 modeBitWidth = 16u;
1241 break;
1242 case ShuffleMode::T32_4X4: // 34
1243 requireRhs = false;
1244 LLVM_FALLTHROUGH;
1245 case ShuffleMode::T32_16X2_LO: // 4
1246 case ShuffleMode::T32_16X2_HI: // 5
1247 case ShuffleMode::T32_2X16_LO: // 16
1248 case ShuffleMode::T32_2X16_HI: // 17
1249 case ShuffleMode::T32_8X4_LO: // 30
1250 case ShuffleMode::T32_8X4_HI: // 31
1251 case ShuffleMode::T32_4X8_LO: // 32
1252 case ShuffleMode::T32_4X8_HI: // 33
1253 modeBitWidth = 32u;
1254 break;
1255 case ShuffleMode::T64_8X2_LO: // 6
1256 case ShuffleMode::T64_8X2_HI: // 7
1257 case ShuffleMode::T64_2X8_LO: // 14
1258 case ShuffleMode::T64_2X8_HI: // 15
1259 modeBitWidth = 64u;
1260 break;
1261 case ShuffleMode::T128_4X2_LO: // 8
1262 case ShuffleMode::T128_4X2_HI: // 9
1263 case ShuffleMode::T128_2X4_LO: // 12
1264 case ShuffleMode::T128_2X4_HI: // 13
1265 modeBitWidth = 128u;
1266 break;
1267 case ShuffleMode::T256_2X2_LO: // 10
1268 case ShuffleMode::T256_2X2_HI: // 11
1269 modeBitWidth = 256u;
1270 break;
1271 case ShuffleMode::T512_1X2_LO: // 22
1272 case ShuffleMode::T512_1X2_HI: // 23
1273 modeBitWidth = 512u;
1274 break;
1275 }
1276
1277 // Verify number of operands
1278 if (requireRhs && !getRhs())
1279 return emitError() << "shuffle mode '" << stringifyEnum(mode)
1280 << "' requires a second operand";
1281
1282 if (!requireRhs && getRhs())
1283 return emitError() << "shuffle mode '" << stringifyEnum(mode)
1284 << "' does not admit a second operand";
1285
1286 // Verify vector element type
1287 auto elemBitWidth =
1288 cast<VectorType>(getLhs().getType()).getElementTypeBitWidth();
1289 if (modeBitWidth != elemBitWidth)
1290 return emitError() << "shuffle mode '" << stringifyEnum(mode)
1291 << "' requires vectors of " << modeBitWidth
1292 << "-bit elements";
1293
1294 return success();
1295}
1296
1297// Print out Shuffle op.
1298void LegacyShuffleOp::print(OpAsmPrinter &p) {
1299 // Print the source vector
1300 p << " " << getSource();
1301
1302 // Print the attributes
1303 p.printOptionalAttrDict((*this)->getAttrs());
1304
1305 // And now print the types
1306 p << " : " << getSource().getType() << ", " << getResult().getType();
1307}
1308
1309// Verify Shuffle op.
1310LogicalResult LegacyShuffleOp::verify() {
1311 // Verify the types
1312 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
1313 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
1314 if (!sourceType || !resultType)
1315 return emitError("requires vector type");
1316
1317 // The number of lanes must match
1318 unsigned sourceLanes = getVectorLaneSize(sourceType);
1319 unsigned resultLanes = getVectorLaneSize(resultType);
1320 if (sourceLanes != resultLanes)
1321 return emitError("The number of lanes in input and "
1322 "output vector must match");
1323
1324 Type stype = sourceType.getElementType();
1325 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
1326 Type rtype = resultType.getElementType();
1327 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
1328
1329 if (stypeWidth != rtypeWidth)
1330 return emitError("The type width in input and "
1331 "output must match");
1332
1333 return success();
1334}
1335
1336// Parse Shuffle op.
1337ParseResult LegacyShuffleOp::parse(OpAsmParser &parser,
1338 OperationState &result) {
1339 llvm::SMLoc typesLoc;
1340 SmallVector<Type, 2> types;
1341 OpAsmParser::UnresolvedOperand source;
1342
1343 // Parse the source vector
1344 if (parser.parseOperand(source))
1345 return failure();
1346
1347 // Parse all the attributes and types
1348 if (parser.parseOptionalAttrDict(result.attributes) ||
1349 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1350 return failure();
1351
1352 // Currently there is one attribute in shuffle op
1353 if (result.attributes.getAttrs().size() != 1)
1354 return parser.emitError(typesLoc, "expects one attribute");
1355
1356 // Some verification
1357 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
1358 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
1359 if (!sourceType || !resultType)
1360 return parser.emitError(typesLoc, "requires vector type");
1361
1362 // Populate the source vectors in result
1363 if (parser.resolveOperand(source, sourceType, result.operands))
1364 return failure();
1365
1366 return parser.addTypeToList(resultType, result.types);
1367}
1368
1369//===----------------------------------------------------------------------===//
1370// MulConvOp and FMAConvOp
1371//===----------------------------------------------------------------------===//
1372
// MulConvOp and FMAConvOp are structurally similar, except that FMAConvOp
// has a few extra fields (an accumulator operand, a bool flag indicating
// whether it is an fmsub, etc.). We create template specializations to print
// those fields specifically for FMAConvOp and MulConvOp.
1377
1378// Print the accumulator
1379template <typename T>
1380void printAccumulator(OpAsmPrinter &p, T op);
1381template <>
1382inline void printAccumulator(OpAsmPrinter &p, aievec::FMAConvOp op) {
1383 p << ", " << op.getAcc();
1384}
1385template <>
1386inline void printAccumulator(OpAsmPrinter &p, aievec::MulConvOp op) {}
1387
1388// Mark fmsub indicator as elided if the FMAElem op is not fmsub
1389template <typename T>
1390void elideFMSubAttr(T op, SmallVector<StringRef, 4> &elidedAttrs);
1391template <>
1392inline void elideFMSubAttr(FMAConvOp op,
1393 SmallVector<StringRef, 4> &elidedAttrs) {
1394 if (!op.getFmsub())
1395 elidedAttrs.push_back(op.getSubAttrName());
1396}
1397
1398template <>
1399inline void elideFMSubAttr(MulConvOp op,
1400 SmallVector<StringRef, 4> &elidedAttrs) {}
1401
1402// Print out MulConv and FMAConv op.
1403template <typename T>
1404static void printMulFMAConvOp(OpAsmPrinter &p, T op) {
1405 // Print the left operand
1406 p << " " << op.getLhs();
1407 // Print the right operand
1408 p << ", " << op.getRhs();
1409 // For fma op, print the accumulator
1410 printAccumulator(p, op);
1411
1412 // Print the attributes, but don't print attributes that are empty strings
1413 SmallVector<StringRef, 4> elidedAttrs;
1414 for (int idx = 0; idx < 2; ++idx) {
1415 elideFMSubAttr(op, elidedAttrs);
1416 }
1417 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
1418
1419 // And now print the types
1420 p << " : " << op.getLhs().getType() << ", " << op.getRhs().getType();
1421 p << ", " << op.getResult().getType();
1422}
1423
1424void MulConvOp::print(OpAsmPrinter &p) {
1425 printMulFMAConvOp<aievec::MulConvOp>(p, *this);
1426}
1427
1428void aievec::FMAConvOp::print(OpAsmPrinter &p) {
1429 printMulFMAConvOp<aievec::FMAConvOp>(p, *this);
1430}
1431
1432// Verify MulConv and FMAConv op.
1433template <typename T>
1434LogicalResult verifyMulFMAConvOp(T op) {
1435 // Verify the types
1436 auto lhsType = llvm::dyn_cast<VectorType>(op.getLhs().getType());
1437 auto rhsType = llvm::dyn_cast<VectorType>(op.getRhs().getType());
1438
1439 if (!lhsType || !rhsType)
1440 return op.emitError("requires vector type");
1441
1442 unsigned M = op.getM();
1443 unsigned N = op.getN();
1444
1445 if (M <= 0 || N <= 0 || 2 * M < M + N - 1)
1446 return op.emitError(
1447 "M and N should be larger than 0 and 2*M should be no less than M+N-1");
1448
1449 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
1450
1451 if (!resultType)
1452 return op.emitError("requires vector type");
1453
1454 // Additional checks for FMAElem op
1455 // Get the width of the underlying scalars of all the vectors
1456 Type ltype = lhsType.getElementType();
1457 Type rtype = rhsType.getElementType();
1458 Type atype = resultType.getElementType();
1459
1460 // lhs and rhs vector's element type must match
1461 if (ltype != rtype)
1462 return op.emitError("The element type of lhs and rhs "
1463 "operand vectors must match");
1464
1465 if (!isa<IntegerType>(ltype) || !isa<IntegerType>(rtype) ||
1466 !isa<IntegerType>(atype)) {
1467 return op.emitError("requires integer type");
1468 }
1469
1470 unsigned ltypeWidth = ltype.getIntOrFloatBitWidth();
1471 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
1472 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
1473
1474 // Checks on the number of lanes
1475 unsigned accLanes = getVectorLaneSize(resultType);
1476 unsigned rhsLanes = getVectorLaneSize(rhsType);
1477 unsigned lhsLanes = getVectorLaneSize(lhsType);
1478
1479 // lane size must match
1480 if (accLanes != M || accLanes != (rhsLanes / 2) || lhsLanes != rhsLanes) {
1481 return op.emitError(
1482 "The number of lanes in accumulator "
1483 "must be the same as M and the half as lhs and rhs operand");
1484 }
1485
1486 // The integer datatype of accumulator must always be greater width
1487 if (ltypeWidth >= atypeWidth || rtypeWidth >= atypeWidth)
1488 return op.emitError("the element type of accumulator must have "
1489 "wider width than that of the operand vectors");
1490
1491 return success();
1492}
1493
1494LogicalResult aievec::MulConvOp::verify() {
1495 return verifyMulFMAConvOp<aievec::MulConvOp>(*this);
1496}
1497
1498LogicalResult aievec::FMAConvOp::verify() {
1499 return verifyMulFMAConvOp<aievec::FMAConvOp>(*this);
1500}
1501
1502// Parse MulConv and FMAConv op.
1503ParseResult parseMulFMAConvOp(OpAsmParser &parser, OperationState &result,
1504 bool isFMAConvOp = true) {
1505 llvm::SMLoc typesLoc;
1506 SmallVector<Type, 3> types;
1507 OpAsmParser::UnresolvedOperand lhs, rhs, acc;
1508
1509 // Parse the lhs and rhs
1510 if (parser.parseOperand(lhs) || parser.parseComma() ||
1511 parser.parseOperand(rhs))
1512 return failure();
1513
1514 // Parse the acc for FMA op
1515 if (isFMAConvOp) {
1516 if (parser.parseComma() || parser.parseOperand(acc))
1517 return failure();
1518 }
1519
1520 // Parse all the attributes and types
1521 if (parser.parseOptionalAttrDict(result.attributes) ||
1522 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1523 return failure();
1524
1525 // Assert that there are three types: lhs, rhs, and acc
1526 if (types.size() != 3)
1527 return parser.emitError(typesLoc, "requires three types");
1528
1529 // Some verification
1530 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
1531 if (!lhsType)
1532 return parser.emitError(typesLoc, "requires vector type");
1533 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
1534 if (!rhsType)
1535 return parser.emitError(typesLoc, "requires vector type");
1536
1537 // Int ops use the accumulator
1538 VectorType accType = llvm::dyn_cast<VectorType>(types[2]);
1539 if (!accType)
1540 return parser.emitError(typesLoc, "requires vector type");
1541
1542 // Populate the lhs and rhs operands, and result
1543 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
1544 parser.resolveOperand(rhs, rhsType, result.operands))
1545 return failure();
1546
1547 // Populate acc operand for FMA op
1548 if (isFMAConvOp) {
1549 if (parser.resolveOperand(acc, accType, result.operands))
1550 return failure();
1551 }
1552
1553 return parser.addTypeToList(accType, result.types);
1554}
1555
1556ParseResult MulConvOp::parse(OpAsmParser &parser, OperationState &result) {
1557 return parseMulFMAConvOp(parser, result, false);
1558}
1559
1560ParseResult FMAConvOp::parse(OpAsmParser &parser, OperationState &result) {
1561 return parseMulFMAConvOp(parser, result, true);
1562}
1563
1564#define GET_ATTRDEF_CLASSES
1565#include "aie/Dialect/AIEVec/IR/AIEVecAttributes.cpp.inc"
1566
1567#define GET_OP_CLASSES
1568#include "aie/Dialect/AIEVec/IR/AIEVecOps.cpp.inc"
ParseResult parsePackUnpackOp(OpAsmParser &parser, OperationState &result)
void printAccumulator(OpAsmPrinter &p, T op)
ParseResult parseMulFMAElemOp(OpAsmParser &parser, OperationState &result, bool isFMAElemOp=true)
ParseResult parseMulFMAConvOp(OpAsmParser &parser, OperationState &result, bool isFMAConvOp=true)
void elideFMSubAttr(T op, SmallVector< StringRef, 4 > &elidedAttrs)
LogicalResult verifyMulFMAElemOp(T op)
LogicalResult verifyMulFMAConvOp(T op)
LogicalResult verifyPackUnpackOp(T op)
std::vector< Port > srcs
unsigned getVectorLaneSize(mlir::VectorType type)
Definition AIEVecUtils.h:55