13#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/OpDefinition.h"
16#include "mlir/IR/TypeUtilities.h"
17#include "mlir/Transforms/FoldUtils.h"
18#include "llvm/ADT/TypeSwitch.h"
28#include "aie/Dialect/AIEVec/IR/AIEVecEnums.cpp.inc"
29#include "aie/Dialect/AIEVec/IR/AIEVecOpsDialect.cpp.inc"
35void AIEVecDialect::initialize() {
38#define GET_ATTRDEF_LIST
39#include "aie/Dialect/AIEVec/IR/AIEVecAttributes.cpp.inc"
43#include "aie/Dialect/AIEVec/IR/AIEVecOps.cpp.inc"
52void UPDOp::print(OpAsmPrinter &p) {
54 p <<
" " << getSource() <<
"[" << getIndices() <<
"]";
57 p <<
", " << getVector();
60 SmallVector<StringRef, 3> elidedAttrs;
61 elidedAttrs.push_back(UPDOp::getOperandSegmentSizeAttr());
62 p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
65 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
69LogicalResult UPDOp::verify() {
71 MemRefType sourceType = llvm::dyn_cast<MemRefType>(getSource().getType());
72 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
74 return emitError(
"requires memref type");
76 return emitError(
"requires vector type");
77 if (getIndices().
empty())
78 return emitError(
"upd source cannot come from scalar value");
83 Type vecType = llvm::dyn_cast<VectorType>(getVector().getType());
84 if (vecType != resultType)
85 return emitError(
"result types of linked UPD ops do not match");
91ParseResult UPDOp::parse(OpAsmParser &parser, OperationState &result) {
92 auto &builder = parser.getBuilder();
94 SmallVector<Type, 2> types;
95 OpAsmParser::UnresolvedOperand source, vector;
96 SmallVector<OpAsmParser::UnresolvedOperand, 8> indices;
99 if (parser.parseOperand(source) ||
100 parser.parseOperandList(indices, OpAsmParser::Delimiter::Square))
102 ParseResult hasVector = parser.parseOptionalComma();
103 if (hasVector.succeeded() && parser.parseOperand(vector))
107 if (parser.parseOptionalAttrDict(result.attributes) ||
108 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
111 if (result.attributes.getAttrs().size() != 2)
112 return parser.emitError(typesLoc,
"requires two attributes");
115 if (types.size() != 2)
116 return parser.emitError(typesLoc,
"requires two types");
119 auto memrefType = llvm::dyn_cast<MemRefType>(types[0]);
121 return parser.emitError(typesLoc,
"requires memref type");
122 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
124 return parser.emitError(typesLoc,
"requires vector type");
125 auto indicesType = builder.getIndexType();
128 if (parser.resolveOperand(source, memrefType, result.operands) ||
129 parser.resolveOperands(indices, indicesType, result.operands))
132 if (hasVector.succeeded())
133 if (parser.resolveOperand(vector, vectorType, result.operands))
137 result.addAttribute(UPDOp::getOperandSegmentSizeAttr(),
138 builder.getDenseI32ArrayAttr(
139 {1, static_cast<int32_t>(indices.size()),
140 static_cast<int32_t>(hasVector.succeeded())}));
142 return parser.addTypeToList(vectorType, result.types);
150void CastOp::print(OpAsmPrinter &p) {
152 p <<
" " << getSource();
155 p.printOptionalAttrDict((*this)->getAttrs());
158 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
162LogicalResult CastOp::verify() {
164 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
165 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
167 return emitError(
"requires source vector type");
169 return emitError(
"requires result vector type");
171 if (sourceType.getElementType().getIntOrFloatBitWidth() !=
172 resultType.getElementType().getIntOrFloatBitWidth()) {
173 return emitError(
"the bitwidth of resource and result should be equal");
180ParseResult CastOp::parse(OpAsmParser &parser, OperationState &result) {
181 llvm::SMLoc typesLoc;
182 SmallVector<Type, 2> types;
183 OpAsmParser::UnresolvedOperand source;
186 if (parser.parseOperand(source))
190 if (parser.parseOptionalAttrDict(result.attributes) ||
191 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
194 if (result.attributes.getAttrs().size() != 1)
195 return parser.emitError(typesLoc,
"requires one attribute");
198 if (types.size() != 2)
199 return parser.emitError(typesLoc,
"requires two types");
202 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
204 return parser.emitError(typesLoc,
"requires vector type");
205 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
207 return parser.emitError(typesLoc,
"requires vector type");
210 if (parser.resolveOperand(source, sourceType, result.operands))
213 return parser.addTypeToList(vectorType, result.types);
217OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
218 auto srcCastOp = getSource().getDefiningOp<aievec::CastOp>();
222 if (srcCastOp.getIsResAcc() == getIsResAcc())
223 return srcCastOp.getResult();
225 return srcCastOp.getSource();
233OpFoldResult SRSOp::fold(FoldAdaptor adaptor) {
234 auto srcDefOp = getSource().getDefiningOp();
238 auto upsOp = dyn_cast<UPSOp>(srcDefOp);
242 auto shiftDefOp = getShift().getDefiningOp();
246 auto constOp = dyn_cast<arith::ConstantOp>(shiftDefOp);
252 auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue());
253 if (!intAttr || intAttr.getInt() != 0)
256 if (upsOp.getSource().getType() != getResult().getType())
259 return upsOp.getSource();
263void SRSOp::print(OpAsmPrinter &p) {
265 p <<
" " << getSource() <<
", ";
272 if (getSign() == 1) {
273 p.printOptionalAttrDict((*this)->getAttrs(), {
"sign"});
275 p.printOptionalAttrDict((*this)->getAttrs());
279 p <<
" : " << getSource().getType() <<
", " << getShift().getType() <<
", "
280 << getResult().getType();
284LogicalResult SRSOp::verify() {
286 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
287 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
289 return emitError(
"requires accumulator type");
291 return emitError(
"requires vector type");
296 if (accLanes != vecLanes)
297 return emitError(
"The number of lanes in result vector "
298 "and source accumulator must match");
301 Type stype = resultType.getElementType();
302 Type atype = sourceType.getElementType();
303 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
304 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
306 if (isa<IntegerType>(atype) && stypeWidth >= atypeWidth)
307 return emitError(
"the element type of source accumulator must be "
308 "wider than that of the result vector");
309 else if (isa<FloatType>(atype) && stypeWidth != 16 &&
310 stypeWidth != atypeWidth)
311 return emitError(
"the element type of source accumulator must be "
312 "same as the result vector");
318ParseResult SRSOp::parse(OpAsmParser &parser, OperationState &result) {
319 llvm::SMLoc typesLoc;
320 SmallVector<Type, 3> types;
321 OpAsmParser::UnresolvedOperand source, shift;
324 if (parser.parseOperand(source) || parser.parseComma() ||
325 parser.parseOperand(shift))
329 if (parser.parseOptionalAttrDict(result.attributes))
333 if (parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
337 if (types.size() != 3)
338 return parser.emitError(typesLoc,
"requires three types");
341 VectorType accType = llvm::dyn_cast<VectorType>(types[0]);
343 return parser.emitError(typesLoc,
"requires vector type");
345 IntegerType shiftType = llvm::dyn_cast<IntegerType>(types[1]);
347 return parser.emitError(typesLoc,
"requires integer type");
349 VectorType vectorType = llvm::dyn_cast<VectorType>(types[2]);
351 return parser.emitError(typesLoc,
"requires vector type");
354 if (parser.resolveOperand(source, accType, result.operands) ||
355 parser.resolveOperand(shift, shiftType, result.operands))
358 return parser.addTypeToList(vectorType, result.types);
366OpFoldResult UPSOp::fold(FoldAdaptor adaptor) {
371 auto srcDefOp = getSource().getDefiningOp();
374 auto srsOp = llvm::dyn_cast<SRSOp>(srcDefOp);
377 return srsOp.getSource();
381void UPSOp::print(OpAsmPrinter &p) {
383 p <<
" " << getSource();
386 p.printOptionalAttrDict((*this)->getAttrs());
389 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
393LogicalResult UPSOp::verify() {
395 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
396 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
398 return emitError(
"requires vector type");
400 return emitError(
"requires vector type");
405 if (vecLanes != accLanes)
406 return emitError(
"The number of lanes in source vector "
407 "and result accumulator must match");
410 Type stype = sourceType.getElementType();
411 Type atype = resultType.getElementType();
412 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
413 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
415 if (stypeWidth >= atypeWidth)
416 return emitError(
"the element type of result accumulator "
417 "must be wider than that of the source vector");
423ParseResult UPSOp::parse(OpAsmParser &parser, OperationState &result) {
424 llvm::SMLoc typesLoc;
425 SmallVector<Type, 2> types;
426 OpAsmParser::UnresolvedOperand source;
429 if (parser.parseOperand(source))
433 if (parser.parseOptionalAttrDict(result.attributes) ||
434 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
437 if (result.attributes.getAttrs().size() != 1)
438 return parser.emitError(typesLoc,
"requires one attribute");
441 if (types.size() != 2)
442 return parser.emitError(typesLoc,
"requires two types");
445 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
447 return parser.emitError(typesLoc,
"requires vector type");
448 VectorType accType = llvm::dyn_cast<VectorType>(types[1]);
450 return parser.emitError(typesLoc,
"requires vector type");
453 if (parser.resolveOperand(source, vectorType, result.operands))
456 return parser.addTypeToList(accType, result.types);
464void BroadcastOp::print(OpAsmPrinter &p) {
466 p <<
" " << getSource();
469 p.printOptionalAttrDict((*this)->getAttrs());
472 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
476LogicalResult BroadcastOp::verify() {
478 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
479 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
482 return emitError(
"requires vector type");
484 return emitError(
"requires vector type");
486 if (sourceType != resultType) {
487 return emitError(
"The vector type of source vector "
488 "and result vector must match");
493 if (sourceLanes != resultLanes)
494 return emitError(
"The number of lanes in source vector "
495 "and result vector must match");
498 Type stype = sourceType.getElementType();
499 Type rtype = resultType.getElementType();
501 if (stype != rtype) {
502 return emitError(
"the element type of result vector "
503 "must be the same as source vector");
510ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
511 llvm::SMLoc typesLoc;
512 SmallVector<Type, 2> types;
513 OpAsmParser::UnresolvedOperand source;
516 if (parser.parseOperand(source))
520 if (parser.parseOptionalAttrDict(result.attributes) ||
521 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
524 if (result.attributes.getAttrs().size() != 1)
525 return parser.emitError(typesLoc,
"requires one attribute");
528 if (types.size() != 2)
529 return parser.emitError(typesLoc,
"requires two types");
532 VectorType vecType = llvm::dyn_cast<VectorType>(types[0]);
534 return parser.emitError(typesLoc,
"requires vector type");
536 VectorType resType = llvm::dyn_cast<VectorType>(types[1]);
538 return parser.emitError(typesLoc,
"requires vector type");
541 if (parser.resolveOperand(source, vecType, result.operands))
544 return parser.addTypeToList(resType, result.types);
552void BroadcastScalarOp::print(OpAsmPrinter &p) {
554 p <<
" " << getSource();
557 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
561LogicalResult BroadcastScalarOp::verify() {
563 Type sourceType = getSource().getType();
564 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
567 return emitError(
"requires vector type");
569 if (!isa<IntegerType, FloatType>(sourceType))
570 return emitError(
"requires source type to be integer or float");
572 Type resultElemType = resultType.getElementType();
573 if (sourceType != resultElemType) {
574 return emitError(
"the element type of result vector must be the same as "
582ParseResult BroadcastScalarOp::parse(OpAsmParser &parser,
583 OperationState &result) {
584 llvm::SMLoc typesLoc;
585 SmallVector<Type, 2> types;
586 OpAsmParser::UnresolvedOperand source;
589 if (parser.parseOperand(source))
593 if (parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
596 if (!result.attributes.getAttrs().empty())
597 return parser.emitError(typesLoc,
"do not require attributes");
600 if (types.size() != 2)
601 return parser.emitError(typesLoc,
"requires two types");
604 VectorType resType = llvm::dyn_cast<VectorType>(types[1]);
606 return parser.emitError(typesLoc,
"requires vector type");
608 if (parser.resolveOperand(source, types[0], result.operands))
611 return parser.addTypeToList(resType, result.types);
628 p <<
", " << op.getAcc();
638 SmallVector<StringRef, 4> &elidedAttrs) {
640 elidedAttrs.push_back(op.getSubAttrName());
645 SmallVector<StringRef, 4> &elidedAttrs) {}
649static void printMulFMAElemOp(OpAsmPrinter &p, T op) {
651 p <<
" " << op.getLhs();
653 p <<
", " << op.getRhs();
658 SmallVector<StringRef, 4> elidedAttrs;
659 for (
int idx = 0; idx < 2; ++idx) {
662 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
665 p <<
" : " << op.getLhs().getType() <<
", " << op.getRhs().getType();
666 p <<
", " << op.getResult().getType();
669void MulElemOp::print(OpAsmPrinter &p) {
670 printMulFMAElemOp<aievec::MulElemOp>(p, *
this);
673void aievec::FMAElemOp::print(OpAsmPrinter &p) {
674 printMulFMAElemOp<aievec::FMAElemOp>(p, *
this);
681 auto lhsType = llvm::dyn_cast<VectorType>(op.getLhs().getType());
682 auto rhsType = llvm::dyn_cast<VectorType>(op.getRhs().getType());
684 if (!lhsType || !rhsType)
685 return op.emitError(
"requires vector type");
687 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
690 return op.emitError(
"requires vector type");
694 Type ltype = lhsType.getElementType();
695 Type rtype = rhsType.getElementType();
696 Type atype = resultType.getElementType();
697 unsigned ltypeWidth = ltype.getIntOrFloatBitWidth();
698 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
699 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
706 if (lhsLanes != rhsLanes) {
707 return op.emitError(
"The number of lanes in lhs operand "
708 "must be the same as rhs operand");
713 return op.emitError(
"The element type of lhs and rhs "
714 "operand vectors must match");
717 if (isa<IntegerType>(atype)) {
718 if (!isa<IntegerType>(ltype))
719 return op.emitError(
"Integer result must have integer operands");
721 if (ltypeWidth >= atypeWidth || rtypeWidth >= atypeWidth)
722 return op.emitError(
"the element type of accumulator must have "
723 "wider width than that of the operand vectors");
724 }
else if (isa<FloatType>(atype)) {
725 if (!isa<FloatType>(ltype))
726 return op.emitError(
"Floating point result must have "
727 "floating point operands");
733LogicalResult aievec::MulElemOp::verify() {
734 return verifyMulFMAElemOp<aievec::MulElemOp>(*
this);
737LogicalResult aievec::FMAElemOp::verify() {
738 return verifyMulFMAElemOp<aievec::FMAElemOp>(*
this);
743 bool isFMAElemOp =
true) {
744 llvm::SMLoc typesLoc;
745 SmallVector<Type, 3> types;
746 OpAsmParser::UnresolvedOperand lhs, rhs, acc;
749 if (parser.parseOperand(lhs) || parser.parseComma() ||
750 parser.parseOperand(rhs))
755 if (parser.parseComma() || parser.parseOperand(acc))
760 if (parser.parseOptionalAttrDict(result.attributes) ||
761 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
765 if (types.size() != 3)
766 return parser.emitError(typesLoc,
"requires three types");
769 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
771 return parser.emitError(typesLoc,
"requires vector type");
772 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
774 return parser.emitError(typesLoc,
"requires vector type");
777 VectorType accType = llvm::dyn_cast<VectorType>(types[2]);
779 return parser.emitError(typesLoc,
"requires vector type");
782 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
783 parser.resolveOperand(rhs, rhsType, result.operands))
788 if (parser.resolveOperand(acc, accType, result.operands))
792 return parser.addTypeToList(accType, result.types);
795ParseResult MulElemOp::parse(OpAsmParser &parser, OperationState &result) {
799ParseResult FMAElemOp::parse(OpAsmParser &parser, OperationState &result) {
808void ConcatOp::print(OpAsmPrinter &p) {
810 assert(!getSources().
empty() &&
"concat source empty");
811 p <<
" " << getSources();
814 p.printOptionalAttrDict((*this)->getAttrs());
817 p <<
" : " << getSources().getTypes().front() <<
", "
818 << getResult().getType();
822LogicalResult ConcatOp::verify() {
824 if (getSources().size() < 2)
825 return emitError(
"Must concatenate at least two vectors");
828 VectorType sourceType =
829 llvm::dyn_cast<VectorType>(getSources().getTypes().front());
830 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
831 if (!sourceType || !resultType)
832 return emitError(
"requires vector type");
834 SmallVector<Value, 8>
srcs(getSources().begin(), getSources().end());
836 for (
auto source :
srcs) {
837 VectorType type = llvm::dyn_cast<VectorType>(source.getType());
839 return emitError(
"requires vector type");
840 if (type != sourceType)
841 return emitError(
"All sources must have same type");
845 unsigned totalLanes = 0;
846 for (
auto source :
srcs) {
847 VectorType type = llvm::dyn_cast<VectorType>(source.getType());
852 return emitError(
"mismatch between vector lanes "
853 "and sum of source lanes");
859ParseResult ConcatOp::parse(OpAsmParser &parser, OperationState &result) {
860 llvm::SMLoc typesLoc;
861 SmallVector<Type, 2> types;
862 SmallVector<OpAsmParser::UnresolvedOperand, 8> sources;
865 if (parser.parseOperandList(sources))
869 if (parser.parseOptionalAttrDict(result.attributes) ||
870 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
874 if (!result.attributes.getAttrs().empty())
875 return parser.emitError(typesLoc,
"expects no attribute");
878 if (types.size() != 2)
879 return parser.emitError(typesLoc,
"requires two types");
882 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
883 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
884 if (!sourceType || !resultType)
885 return parser.emitError(typesLoc,
"requires vector type");
888 if (parser.resolveOperands(sources, sourceType, result.operands))
891 return parser.addTypeToList(resultType, result.types);
895ConcatOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
896 ConcatOp::Adaptor adaptor,
897 SmallVectorImpl<Type> &inferredReturnTypes) {
898 SmallVector<Value, 8>
srcs(adaptor.getSources().begin(),
899 adaptor.getSources().end());
900 unsigned totalLength = 0;
901 for (
auto source :
srcs) {
902 VectorType type = llvm::dyn_cast<VectorType>(source.getType());
903 assert(type.getRank() == 1 &&
904 "only rank 1 vectors currently supported by concat");
905 totalLength += type.getDimSize(0);
907 inferredReturnTypes.push_back(VectorType::get(
909 llvm::dyn_cast<VectorType>(srcs[0].getType()).getElementType()));
918void ExtOp::print(OpAsmPrinter &p) {
920 p <<
" " << getSource();
923 p.printOptionalAttrDict((*this)->getAttrs());
926 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
930LogicalResult ExtOp::verify() {
932 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
933 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
934 if (!sourceType || !resultType)
935 return emitError(
"requires vector type");
941 if (sourceLanes / resultLanes <= 1)
942 return emitError(
"lanes in source vector must be at least "
943 "twice that of result vector");
945 if (sourceLanes % resultLanes != 0)
946 return emitError(
"lanes in result vector must be a multiple "
947 "of source vector lanes");
950 unsigned factor = sourceLanes / resultLanes;
951 if (
static_cast<unsigned>(getIndex()) >= factor)
952 return emitError(
"index out of bounds");
955 Type stype = sourceType.getElementType();
956 Type rtype = resultType.getElementType();
958 return emitError(
"source and result element type must be same");
964ParseResult ExtOp::parse(OpAsmParser &parser, OperationState &result) {
965 llvm::SMLoc typesLoc;
966 SmallVector<Type, 2> types;
967 OpAsmParser::UnresolvedOperand source;
970 if (parser.parseOperand(source))
974 if (parser.parseOptionalAttrDict(result.attributes) ||
975 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
978 if (result.attributes.getAttrs().size() != 1)
979 return parser.emitError(typesLoc,
"requires one attribute");
982 if (types.size() != 2)
983 return parser.emitError(typesLoc,
"requires two types");
986 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
987 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
988 if (!sourceType || !resultType)
989 return parser.emitError(typesLoc,
"requires vector type");
992 if (parser.resolveOperand(source, sourceType, result.operands))
995 return parser.addTypeToList(resultType, result.types);
1003template <
typename T>
1004static void printPackUnpackOp(OpAsmPrinter &p, T op) {
1006 p <<
" " << op.getSource();
1009 p.printOptionalAttrDict(op->getAttrs());
1012 p <<
" : " << op.getSource().getType() <<
", " << op.getResult().getType();
1015void PackOp::print(OpAsmPrinter &p) { printPackUnpackOp<PackOp>(p, *
this); }
1017void UnpackOp::print(OpAsmPrinter &p) { printPackUnpackOp<UnpackOp>(p, *
this); }
1020template <
typename T>
1023 auto sourceType = llvm::dyn_cast<VectorType>(op.getSource().getType());
1024 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
1025 if (!sourceType || !resultType)
1026 return op.emitError(
"requires vector type");
1031 if (sourceLanes != resultLanes)
1032 return op.emitError(
"The number of lanes in input and "
1033 "output vector must match");
1035 Type stype = sourceType.getElementType();
1036 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
1037 Type rtype = resultType.getElementType();
1038 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
1040 if (isa<PackOp>(op)) {
1042 if (stypeWidth != 16)
1043 return op.emitError(
"input must be an int16 vector");
1044 if (rtypeWidth != 8)
1045 return op.emitError(
"output must be an int8 vector");
1047 if (stypeWidth != 8)
1048 return op.emitError(
"input must be an int8 vector");
1049 if (rtypeWidth != 16)
1050 return op.emitError(
"output must be an int16 vector");
1056LogicalResult PackOp::verify() {
return verifyPackUnpackOp<PackOp>(*
this); }
1058LogicalResult UnpackOp::verify() {
return verifyPackUnpackOp<UnpackOp>(*
this); }
1062 llvm::SMLoc typesLoc;
1063 SmallVector<Type, 2> types;
1064 OpAsmParser::UnresolvedOperand source;
1067 if (parser.parseOperand(source))
1071 if (parser.parseOptionalAttrDict(result.attributes) ||
1072 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1076 if (!result.attributes.getAttrs().empty())
1077 return parser.emitError(typesLoc,
"expects no attributes");
1080 if (types.size() != 2)
1081 return parser.emitError(typesLoc,
"requires two types");
1084 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
1085 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
1086 if (!sourceType || !resultType)
1087 return parser.emitError(typesLoc,
"requires vector type");
1090 if (parser.resolveOperand(source, sourceType, result.operands))
1093 return parser.addTypeToList(resultType, result.types);
1096ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1100ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1109LogicalResult ExtElemOp::verify() {
1111 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
1114 return emitError(
"source requires vector type");
1117 Type stype = sourceType.getElementType();
1118 Type rtype = getResult().getType();
1120 if (stype != rtype) {
1121 return emitError(
"the type of result must be the same as the element "
1122 "type of source vector");
1133void ShiftOp::print(OpAsmPrinter &p) {
1135 p <<
" " << getLhs() <<
", " << getRhs();
1138 p <<
", " << getShift();
1141 p.printOptionalAttrDict((*this)->getAttrs());
1144 p <<
" : " << getLhs().getType() <<
", " << getLhs().getType() <<
", "
1145 << getShift().getType() <<
", " << getResult().getType();
1149LogicalResult ShiftOp::verify() {
1151 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
1153 return emitError(
"requires vector type");
1156 VectorType lhsType = llvm::dyn_cast<VectorType>(getLhs().getType());
1157 VectorType rhsType = llvm::dyn_cast<VectorType>(getRhs().getType());
1159 if (!lhsType || !rhsType)
1160 return emitError(
"requires vector type");
1161 if (lhsType != resultType || rhsType != resultType)
1162 return emitError(
"All vectors must have same type");
1164 if (!isa<IntegerType>(getShift().getType()))
1165 return emitError(
"requires integer type");
1171ParseResult ShiftOp::parse(OpAsmParser &parser, OperationState &result) {
1172 llvm::SMLoc typesLoc;
1173 SmallVector<Type, 4> types;
1174 OpAsmParser::UnresolvedOperand lhs, rhs, shift;
1177 if (parser.parseOperand(lhs) || parser.parseComma() ||
1178 parser.parseOperand(rhs) || parser.parseComma() ||
1179 parser.parseOperand(shift))
1183 if (parser.parseOptionalAttrDict(result.attributes) ||
1184 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1187 if (result.attributes.getAttrs().size() != 1)
1188 return parser.emitError(typesLoc,
"expects one attribute");
1191 if (types.size() != 4)
1192 return parser.emitError(typesLoc,
"requires four types");
1195 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
1196 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
1197 IntegerType shiftType = llvm::dyn_cast<IntegerType>(types[2]);
1198 VectorType resultType = llvm::dyn_cast<VectorType>(types[3]);
1199 if (!lhsType || !rhsType || !resultType)
1200 return parser.emitError(typesLoc,
"requires vector type");
1203 return parser.emitError(typesLoc,
"requires integer type");
1206 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
1207 parser.resolveOperand(rhs, rhsType, result.operands) ||
1208 parser.resolveOperand(shift, shiftType, result.operands))
1211 return parser.addTypeToList(resultType, result.types);
1220LogicalResult ShuffleOp::verify() {
1221 unsigned modeBitWidth;
1222 bool requireRhs =
true;
1223 auto mode = getMode();
1225 case ShuffleMode::T8_8X8:
1226 case ShuffleMode::T8_16X4:
1227 case ShuffleMode::T8_4X16:
1228 case ShuffleMode::T8_8X4:
1229 case ShuffleMode::T8_4X8:
1232 case ShuffleMode::T8_64X2_LO:
1233 case ShuffleMode::T8_64X2_HI:
1234 case ShuffleMode::T8_2X64_LO:
1235 case ShuffleMode::T8_2X64_HI:
1238 case ShuffleMode::T16_8X4:
1239 case ShuffleMode::T16_4X8:
1240 case ShuffleMode::T16_1X2_flip:
1241 case ShuffleMode::T16_4X4:
1242 case ShuffleMode::T16_4X2:
1243 case ShuffleMode::T16_2X4:
1244 case ShuffleMode::T16_8X2:
1245 case ShuffleMode::T16_2X8:
1246 case ShuffleMode::T16_16X2:
1247 case ShuffleMode::T16_2X16:
1250 case ShuffleMode::T16_32X2_LO:
1251 case ShuffleMode::T16_32X2_HI:
1252 case ShuffleMode::T16_2X32_LO:
1253 case ShuffleMode::T16_2X32_HI:
1254 case ShuffleMode::T16_16X4_LO:
1255 case ShuffleMode::T16_16X4_HI:
1256 case ShuffleMode::T16_4X16_LO:
1257 case ShuffleMode::T16_4X16_HI:
1260 case ShuffleMode::T32_4X4:
1263 case ShuffleMode::T32_16X2_LO:
1264 case ShuffleMode::T32_16X2_HI:
1265 case ShuffleMode::T32_2X16_LO:
1266 case ShuffleMode::T32_2X16_HI:
1267 case ShuffleMode::T32_8X4_LO:
1268 case ShuffleMode::T32_8X4_HI:
1269 case ShuffleMode::T32_4X8_LO:
1270 case ShuffleMode::T32_4X8_HI:
1273 case ShuffleMode::T64_8X2_LO:
1274 case ShuffleMode::T64_8X2_HI:
1275 case ShuffleMode::T64_2X8_LO:
1276 case ShuffleMode::T64_2X8_HI:
1279 case ShuffleMode::T128_4X2_LO:
1280 case ShuffleMode::T128_4X2_HI:
1281 case ShuffleMode::T128_2X4_LO:
1282 case ShuffleMode::T128_2X4_HI:
1283 modeBitWidth = 128u;
1285 case ShuffleMode::T256_2X2_LO:
1286 case ShuffleMode::T256_2X2_HI:
1287 modeBitWidth = 256u;
1289 case ShuffleMode::T512_1X2_LO:
1290 case ShuffleMode::T512_1X2_HI:
1291 modeBitWidth = 512u;
1296 if (requireRhs && !getRhs())
1297 return emitError() <<
"shuffle mode '" << stringifyEnum(mode)
1298 <<
"' requires a second operand";
1300 if (!requireRhs && getRhs())
1301 return emitError() <<
"shuffle mode '" << stringifyEnum(mode)
1302 <<
"' does not admit a second operand";
1306 cast<VectorType>(getLhs().getType()).getElementTypeBitWidth();
1307 if (modeBitWidth != elemBitWidth)
1308 return emitError() <<
"shuffle mode '" << stringifyEnum(mode)
1309 <<
"' requires vectors of " << modeBitWidth
1316void LegacyShuffleOp::print(OpAsmPrinter &p) {
1318 p <<
" " << getSource();
1321 p.printOptionalAttrDict((*this)->getAttrs());
1324 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
1328LogicalResult LegacyShuffleOp::verify() {
1330 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
1331 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
1332 if (!sourceType || !resultType)
1333 return emitError(
"requires vector type");
1338 if (sourceLanes != resultLanes)
1339 return emitError(
"The number of lanes in input and "
1340 "output vector must match");
1342 Type stype = sourceType.getElementType();
1343 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
1344 Type rtype = resultType.getElementType();
1345 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
1347 if (stypeWidth != rtypeWidth)
1348 return emitError(
"The type width in input and "
1349 "output must match");
1355ParseResult LegacyShuffleOp::parse(OpAsmParser &parser,
1356 OperationState &result) {
1357 llvm::SMLoc typesLoc;
1358 SmallVector<Type, 2> types;
1359 OpAsmParser::UnresolvedOperand source;
1362 if (parser.parseOperand(source))
1366 if (parser.parseOptionalAttrDict(result.attributes) ||
1367 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1371 if (result.attributes.getAttrs().size() != 1)
1372 return parser.emitError(typesLoc,
"expects one attribute");
1375 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
1376 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
1377 if (!sourceType || !resultType)
1378 return parser.emitError(typesLoc,
"requires vector type");
1381 if (parser.resolveOperand(source, sourceType, result.operands))
1384 return parser.addTypeToList(resultType, result.types);
1397template <
typename T>
1401 p <<
", " << op.getAcc();
1407template <
typename T>
1408void elideFMSubAttr(T op, SmallVector<StringRef, 4> &elidedAttrs);
1411 SmallVector<StringRef, 4> &elidedAttrs) {
1413 elidedAttrs.push_back(op.getSubAttrName());
1418 SmallVector<StringRef, 4> &elidedAttrs) {}
1421template <
typename T>
1422static void printMulFMAConvOp(OpAsmPrinter &p, T op) {
1424 p <<
" " << op.getLhs();
1426 p <<
", " << op.getRhs();
1431 SmallVector<StringRef, 4> elidedAttrs;
1432 for (
int idx = 0; idx < 2; ++idx) {
1435 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
1438 p <<
" : " << op.getLhs().getType() <<
", " << op.getRhs().getType();
1439 p <<
", " << op.getResult().getType();
1442void MulConvOp::print(OpAsmPrinter &p) {
1443 printMulFMAConvOp<aievec::MulConvOp>(p, *
this);
1446void aievec::FMAConvOp::print(OpAsmPrinter &p) {
1447 printMulFMAConvOp<aievec::FMAConvOp>(p, *
this);
1451template <
typename T>
1454 auto lhsType = llvm::dyn_cast<VectorType>(op.getLhs().getType());
1455 auto rhsType = llvm::dyn_cast<VectorType>(op.getRhs().getType());
1457 if (!lhsType || !rhsType)
1458 return op.emitError(
"requires vector type");
1460 unsigned M = op.getM();
1461 unsigned N = op.getN();
1463 if (M <= 0 || N <= 0 || 2 * M < M + N - 1)
1464 return op.emitError(
1465 "M and N should be larger than 0 and 2*M should be no less than M+N-1");
1467 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
1470 return op.emitError(
"requires vector type");
1474 Type ltype = lhsType.getElementType();
1475 Type rtype = rhsType.getElementType();
1476 Type atype = resultType.getElementType();
1480 return op.emitError(
"The element type of lhs and rhs "
1481 "operand vectors must match");
1483 if (!isa<IntegerType>(ltype) || !isa<IntegerType>(rtype) ||
1484 !isa<IntegerType>(atype)) {
1485 return op.emitError(
"requires integer type");
1488 unsigned ltypeWidth = ltype.getIntOrFloatBitWidth();
1489 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
1490 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
1498 if (accLanes != M || accLanes != (rhsLanes / 2) || lhsLanes != rhsLanes) {
1499 return op.emitError(
1500 "The number of lanes in accumulator "
1501 "must be the same as M and the half as lhs and rhs operand");
1505 if (ltypeWidth >= atypeWidth || rtypeWidth >= atypeWidth)
1506 return op.emitError(
"the element type of accumulator must have "
1507 "wider width than that of the operand vectors");
1512LogicalResult aievec::MulConvOp::verify() {
1513 return verifyMulFMAConvOp<aievec::MulConvOp>(*
this);
1516LogicalResult aievec::FMAConvOp::verify() {
1517 return verifyMulFMAConvOp<aievec::FMAConvOp>(*
this);
1522 bool isFMAConvOp =
true) {
1523 llvm::SMLoc typesLoc;
1524 SmallVector<Type, 3> types;
1525 OpAsmParser::UnresolvedOperand lhs, rhs, acc;
1528 if (parser.parseOperand(lhs) || parser.parseComma() ||
1529 parser.parseOperand(rhs))
1534 if (parser.parseComma() || parser.parseOperand(acc))
1539 if (parser.parseOptionalAttrDict(result.attributes) ||
1540 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1544 if (types.size() != 3)
1545 return parser.emitError(typesLoc,
"requires three types");
1548 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
1550 return parser.emitError(typesLoc,
"requires vector type");
1551 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
1553 return parser.emitError(typesLoc,
"requires vector type");
1556 VectorType accType = llvm::dyn_cast<VectorType>(types[2]);
1558 return parser.emitError(typesLoc,
"requires vector type");
1561 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
1562 parser.resolveOperand(rhs, rhsType, result.operands))
1567 if (parser.resolveOperand(acc, accType, result.operands))
1571 return parser.addTypeToList(accType, result.types);
1574ParseResult MulConvOp::parse(OpAsmParser &parser, OperationState &result) {
1578ParseResult FMAConvOp::parse(OpAsmParser &parser, OperationState &result) {
1582#define GET_ATTRDEF_CLASSES
1583#include "aie/Dialect/AIEVec/IR/AIEVecAttributes.cpp.inc"
1585#define GET_OP_CLASSES
1586#include "aie/Dialect/AIEVec/IR/AIEVecOps.cpp.inc"
ParseResult parsePackUnpackOp(OpAsmParser &parser, OperationState &result)
void printAccumulator(OpAsmPrinter &p, T op)
ParseResult parseMulFMAElemOp(OpAsmParser &parser, OperationState &result, bool isFMAElemOp=true)
ParseResult parseMulFMAConvOp(OpAsmParser &parser, OperationState &result, bool isFMAConvOp=true)
void elideFMSubAttr(T op, SmallVector< StringRef, 4 > &elidedAttrs)
LogicalResult verifyMulFMAElemOp(T op)
LogicalResult verifyMulFMAConvOp(T op)
LogicalResult verifyPackUnpackOp(T op)
bool empty(const std::string &s)
unsigned getVectorLaneSize(mlir::VectorType type)