13#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/OpDefinition.h"
16#include "mlir/IR/TypeUtilities.h"
17#include "mlir/Transforms/FoldUtils.h"
18#include "llvm/ADT/TypeSwitch.h"
28#include "aie/Dialect/AIEVec/IR/AIEVecEnums.cpp.inc"
29#include "aie/Dialect/AIEVec/IR/AIEVecOpsDialect.cpp.inc"
// Dialect initialization hook. The GET_ATTRDEF_LIST include below expands to
// the TableGen-generated attribute list; the AIEVecOps.cpp.inc include
// presumably expands to the generated op list.
// NOTE(review): this definition looks truncated in this copy -- the
// surrounding `addAttributes<...>()` / `addOperations<...>()` wrappers, a
// GET_OP_LIST define and the closing brace are not visible here; restore them
// from the canonical source (and check whether a type-registration call was
// also present) before building.
35void AIEVecDialect::initialize() {
38#define GET_ATTRDEF_LIST
39#include "aie/Dialect/AIEVec/IR/AIEVecAttributes.cpp.inc"
43#include "aie/Dialect/AIEVec/IR/AIEVecOps.cpp.inc"
52void UPDOp::print(OpAsmPrinter &p) {
54 p <<
" " << getSource() <<
"[" << getIndices() <<
"]";
57 p <<
", " << getVector();
60 SmallVector<StringRef, 3> elidedAttrs;
61 elidedAttrs.push_back(UPDOp::getOperandSegmentSizeAttr());
62 p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
65 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
69LogicalResult UPDOp::verify() {
71 MemRefType sourceType = llvm::dyn_cast<MemRefType>(getSource().getType());
72 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
74 return emitError(
"requires memref type");
76 return emitError(
"requires vector type");
77 if (getIndices().empty())
78 return emitError(
"upd source cannot come from scalar value");
83 Type vecType = llvm::dyn_cast<VectorType>(getVector().getType());
84 if (vecType != resultType)
85 return emitError(
"result types of linked UPD ops do not match");
91ParseResult UPDOp::parse(OpAsmParser &parser, OperationState &result) {
92 auto &builder = parser.getBuilder();
94 SmallVector<Type, 2> types;
95 OpAsmParser::UnresolvedOperand source, vector;
96 SmallVector<OpAsmParser::UnresolvedOperand, 8> indices;
99 if (parser.parseOperand(source) ||
100 parser.parseOperandList(indices, OpAsmParser::Delimiter::Square))
102 ParseResult hasVector = parser.parseOptionalComma();
103 if (hasVector.succeeded() && parser.parseOperand(vector))
107 if (parser.parseOptionalAttrDict(result.attributes) ||
108 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
111 if (result.attributes.getAttrs().size() != 2)
112 return parser.emitError(typesLoc,
"requires two attributes");
115 if (types.size() != 2)
116 return parser.emitError(typesLoc,
"requires two types");
119 auto memrefType = llvm::dyn_cast<MemRefType>(types[0]);
121 return parser.emitError(typesLoc,
"requires memref type");
122 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
124 return parser.emitError(typesLoc,
"requires vector type");
125 auto indicesType = builder.getIndexType();
128 if (parser.resolveOperand(source, memrefType, result.operands) ||
129 parser.resolveOperands(indices, indicesType, result.operands))
132 if (hasVector.succeeded())
133 if (parser.resolveOperand(vector, vectorType, result.operands))
137 result.addAttribute(UPDOp::getOperandSegmentSizeAttr(),
138 builder.getDenseI32ArrayAttr(
139 {1, static_cast<int32_t>(indices.size()),
140 static_cast<int32_t>(hasVector.succeeded())}));
142 return parser.addTypeToList(vectorType, result.types);
150void CastOp::print(OpAsmPrinter &p) {
152 p <<
" " << getSource();
155 p.printOptionalAttrDict((*this)->getAttrs());
158 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
162LogicalResult CastOp::verify() {
164 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
165 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
167 return emitError(
"requires source vector type");
169 return emitError(
"requires result vector type");
171 if (sourceType.getElementType().getIntOrFloatBitWidth() !=
172 resultType.getElementType().getIntOrFloatBitWidth()) {
173 return emitError(
"the bitwidth of resource and result should be equal");
180ParseResult CastOp::parse(OpAsmParser &parser, OperationState &result) {
181 llvm::SMLoc typesLoc;
182 SmallVector<Type, 2> types;
183 OpAsmParser::UnresolvedOperand source;
186 if (parser.parseOperand(source))
190 if (parser.parseOptionalAttrDict(result.attributes) ||
191 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
194 if (result.attributes.getAttrs().size() != 1)
195 return parser.emitError(typesLoc,
"requires one attribute");
198 if (types.size() != 2)
199 return parser.emitError(typesLoc,
"requires two types");
202 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
204 return parser.emitError(typesLoc,
"requires vector type");
205 VectorType vectorType = llvm::dyn_cast<VectorType>(types[1]);
207 return parser.emitError(typesLoc,
"requires vector type");
210 if (parser.resolveOperand(source, sourceType, result.operands))
213 return parser.addTypeToList(vectorType, result.types);
217OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
218 auto srcCastOp = getSource().getDefiningOp<aievec::CastOp>();
222 if (srcCastOp.getIsResAcc() == getIsResAcc())
223 return srcCastOp.getResult();
225 return srcCastOp.getSource();
233OpFoldResult SRSOp::fold(FoldAdaptor adaptor) {
234 auto srcDefOp = getSource().getDefiningOp();
238 auto upsOp = dyn_cast<UPSOp>(srcDefOp);
242 auto shiftDefOp = getShift().getDefiningOp();
246 auto constOp = dyn_cast<arith::ConstantOp>(shiftDefOp);
250 if (upsOp.getSource().getType() != getResult().getType())
253 return upsOp.getSource();
257void SRSOp::print(OpAsmPrinter &p) {
259 p <<
" " << getSource() <<
", ";
265 p <<
" : " << getSource().getType() <<
", " << getShift().getType() <<
", "
266 << getResult().getType();
270LogicalResult SRSOp::verify() {
272 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
273 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
275 return emitError(
"requires accumulator type");
277 return emitError(
"requires vector type");
282 if (accLanes != vecLanes)
283 return emitError(
"The number of lanes in result vector "
284 "and source accumulator must match");
287 Type stype = resultType.getElementType();
288 Type atype = sourceType.getElementType();
289 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
290 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
292 if (isa<IntegerType>(atype) && stypeWidth >= atypeWidth)
293 return emitError(
"the element type of source accumulator must be "
294 "wider than that of the result vector");
295 else if (isa<FloatType>(atype) && stypeWidth != 16 &&
296 stypeWidth != atypeWidth)
297 return emitError(
"the element type of source accumulator must be "
298 "same as the result vector");
304ParseResult SRSOp::parse(OpAsmParser &parser, OperationState &result) {
305 llvm::SMLoc typesLoc;
306 SmallVector<Type, 3> types;
307 OpAsmParser::UnresolvedOperand source, shift;
310 if (parser.parseOperand(source) || parser.parseComma() ||
311 parser.parseOperand(shift))
315 if (parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
319 if (types.size() != 3)
320 return parser.emitError(typesLoc,
"requires three types");
323 VectorType accType = llvm::dyn_cast<VectorType>(types[0]);
325 return parser.emitError(typesLoc,
"requires vector type");
327 IntegerType shiftType = llvm::dyn_cast<IntegerType>(types[1]);
329 return parser.emitError(typesLoc,
"requires integer type");
331 VectorType vectorType = llvm::dyn_cast<VectorType>(types[2]);
333 return parser.emitError(typesLoc,
"requires vector type");
336 if (parser.resolveOperand(source, accType, result.operands) ||
337 parser.resolveOperand(shift, shiftType, result.operands))
340 return parser.addTypeToList(vectorType, result.types);
348OpFoldResult UPSOp::fold(FoldAdaptor adaptor) {
353 auto srcDefOp = getSource().getDefiningOp();
356 auto srsOp = llvm::dyn_cast<SRSOp>(srcDefOp);
359 return srsOp.getSource();
363void UPSOp::print(OpAsmPrinter &p) {
365 p <<
" " << getSource();
368 p.printOptionalAttrDict((*this)->getAttrs());
371 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
375LogicalResult UPSOp::verify() {
377 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
378 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
380 return emitError(
"requires vector type");
382 return emitError(
"requires vector type");
387 if (vecLanes != accLanes)
388 return emitError(
"The number of lanes in source vector "
389 "and result accumulator must match");
392 Type stype = sourceType.getElementType();
393 Type atype = resultType.getElementType();
394 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
395 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
397 if (stypeWidth >= atypeWidth)
398 return emitError(
"the element type of result accumulator "
399 "must be wider than that of the source vector");
405ParseResult UPSOp::parse(OpAsmParser &parser, OperationState &result) {
406 llvm::SMLoc typesLoc;
407 SmallVector<Type, 2> types;
408 OpAsmParser::UnresolvedOperand source;
411 if (parser.parseOperand(source))
415 if (parser.parseOptionalAttrDict(result.attributes) ||
416 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
419 if (result.attributes.getAttrs().size() != 1)
420 return parser.emitError(typesLoc,
"requires one attribute");
423 if (types.size() != 2)
424 return parser.emitError(typesLoc,
"requires two types");
427 VectorType vectorType = llvm::dyn_cast<VectorType>(types[0]);
429 return parser.emitError(typesLoc,
"requires vector type");
430 VectorType accType = llvm::dyn_cast<VectorType>(types[1]);
432 return parser.emitError(typesLoc,
"requires vector type");
435 if (parser.resolveOperand(source, vectorType, result.operands))
438 return parser.addTypeToList(accType, result.types);
446void BroadcastOp::print(OpAsmPrinter &p) {
448 p <<
" " << getSource();
451 p.printOptionalAttrDict((*this)->getAttrs());
454 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
458LogicalResult BroadcastOp::verify() {
460 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
461 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
464 return emitError(
"requires vector type");
466 return emitError(
"requires vector type");
468 if (sourceType != resultType) {
469 return emitError(
"The vector type of source vector "
470 "and result vector must match");
475 if (sourceLanes != resultLanes)
476 return emitError(
"The number of lanes in source vector "
477 "and result vector must match");
480 Type stype = sourceType.getElementType();
481 Type rtype = resultType.getElementType();
483 if (stype != rtype) {
484 return emitError(
"the element type of result vector "
485 "must be the same as source vector");
492ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
493 llvm::SMLoc typesLoc;
494 SmallVector<Type, 2> types;
495 OpAsmParser::UnresolvedOperand source;
498 if (parser.parseOperand(source))
502 if (parser.parseOptionalAttrDict(result.attributes) ||
503 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
506 if (result.attributes.getAttrs().size() != 1)
507 return parser.emitError(typesLoc,
"requires one attribute");
510 if (types.size() != 2)
511 return parser.emitError(typesLoc,
"requires two types");
514 VectorType vecType = llvm::dyn_cast<VectorType>(types[0]);
516 return parser.emitError(typesLoc,
"requires vector type");
518 VectorType resType = llvm::dyn_cast<VectorType>(types[1]);
520 return parser.emitError(typesLoc,
"requires vector type");
523 if (parser.resolveOperand(source, vecType, result.operands))
526 return parser.addTypeToList(resType, result.types);
534void BroadcastScalarOp::print(OpAsmPrinter &p) {
536 p <<
" " << getSource();
539 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
543LogicalResult BroadcastScalarOp::verify() {
545 Type sourceType = getSource().getType();
546 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
549 return emitError(
"requires vector type");
551 if (!isa<IntegerType, FloatType>(sourceType))
552 return emitError(
"requires source type to be integer or float");
554 Type resultElemType = resultType.getElementType();
555 if (sourceType != resultElemType) {
556 return emitError(
"the element type of result vector must be the same as "
564ParseResult BroadcastScalarOp::parse(OpAsmParser &parser,
565 OperationState &result) {
566 llvm::SMLoc typesLoc;
567 SmallVector<Type, 2> types;
568 OpAsmParser::UnresolvedOperand source;
571 if (parser.parseOperand(source))
575 if (parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
578 if (!result.attributes.getAttrs().empty())
579 return parser.emitError(typesLoc,
"do not require attributes");
582 if (types.size() != 2)
583 return parser.emitError(typesLoc,
"requires two types");
586 VectorType resType = llvm::dyn_cast<VectorType>(types[1]);
588 return parser.emitError(typesLoc,
"requires vector type");
590 if (parser.resolveOperand(source, types[0], result.operands))
593 return parser.addTypeToList(resType, result.types);
610 p <<
", " << op.getAcc();
620 SmallVector<StringRef, 4> &elidedAttrs) {
622 elidedAttrs.push_back(op.getSubAttrName());
627 SmallVector<StringRef, 4> &elidedAttrs) {}
631static void printMulFMAElemOp(OpAsmPrinter &p, T op) {
633 p <<
" " << op.getLhs();
635 p <<
", " << op.getRhs();
640 SmallVector<StringRef, 4> elidedAttrs;
641 for (
int idx = 0; idx < 2; ++idx) {
644 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
647 p <<
" : " << op.getLhs().getType() <<
", " << op.getRhs().getType();
648 p <<
", " << op.getResult().getType();
651void MulElemOp::print(OpAsmPrinter &p) {
652 printMulFMAElemOp<aievec::MulElemOp>(p, *
this);
655void aievec::FMAElemOp::print(OpAsmPrinter &p) {
656 printMulFMAElemOp<aievec::FMAElemOp>(p, *
this);
663 auto lhsType = llvm::dyn_cast<VectorType>(op.getLhs().getType());
664 auto rhsType = llvm::dyn_cast<VectorType>(op.getRhs().getType());
666 if (!lhsType || !rhsType)
667 return op.emitError(
"requires vector type");
669 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
672 return op.emitError(
"requires vector type");
676 Type ltype = lhsType.getElementType();
677 Type rtype = rhsType.getElementType();
678 Type atype = resultType.getElementType();
679 unsigned ltypeWidth = ltype.getIntOrFloatBitWidth();
680 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
681 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
688 if (lhsLanes != rhsLanes) {
689 return op.emitError(
"The number of lanes in lhs operand "
690 "must be the same as rhs operand");
695 return op.emitError(
"The element type of lhs and rhs "
696 "operand vectors must match");
699 if (isa<IntegerType>(atype)) {
700 if (!isa<IntegerType>(ltype))
701 return op.emitError(
"Integer result must have integer operands");
703 if (ltypeWidth >= atypeWidth || rtypeWidth >= atypeWidth)
704 return op.emitError(
"the element type of accumulator must have "
705 "wider width than that of the operand vectors");
706 }
else if (isa<FloatType>(atype)) {
707 if (!isa<FloatType>(ltype))
708 return op.emitError(
"Floating point result must have "
709 "floating point operands");
715LogicalResult aievec::MulElemOp::verify() {
716 return verifyMulFMAElemOp<aievec::MulElemOp>(*
this);
719LogicalResult aievec::FMAElemOp::verify() {
720 return verifyMulFMAElemOp<aievec::FMAElemOp>(*
this);
725 bool isFMAElemOp =
true) {
726 llvm::SMLoc typesLoc;
727 SmallVector<Type, 3> types;
728 OpAsmParser::UnresolvedOperand lhs, rhs, acc;
731 if (parser.parseOperand(lhs) || parser.parseComma() ||
732 parser.parseOperand(rhs))
737 if (parser.parseComma() || parser.parseOperand(acc))
742 if (parser.parseOptionalAttrDict(result.attributes) ||
743 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
747 if (types.size() != 3)
748 return parser.emitError(typesLoc,
"requires three types");
751 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
753 return parser.emitError(typesLoc,
"requires vector type");
754 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
756 return parser.emitError(typesLoc,
"requires vector type");
759 VectorType accType = llvm::dyn_cast<VectorType>(types[2]);
761 return parser.emitError(typesLoc,
"requires vector type");
764 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
765 parser.resolveOperand(rhs, rhsType, result.operands))
770 if (parser.resolveOperand(acc, accType, result.operands))
774 return parser.addTypeToList(accType, result.types);
777ParseResult MulElemOp::parse(OpAsmParser &parser, OperationState &result) {
781ParseResult FMAElemOp::parse(OpAsmParser &parser, OperationState &result) {
790void ConcatOp::print(OpAsmPrinter &p) {
792 assert(!getSources().empty() &&
"concat source empty");
793 p <<
" " << getSources();
796 p.printOptionalAttrDict((*this)->getAttrs());
799 p <<
" : " << getSources().getTypes().front() <<
", "
800 << getResult().getType();
804LogicalResult ConcatOp::verify() {
806 if (getSources().size() < 2)
807 return emitError(
"Must concatenate at least two vectors");
810 VectorType sourceType =
811 llvm::dyn_cast<VectorType>(getSources().getTypes().front());
812 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
813 if (!sourceType || !resultType)
814 return emitError(
"requires vector type");
816 SmallVector<Value, 8>
srcs(getSources().begin(), getSources().end());
818 for (
auto source :
srcs) {
819 VectorType type = llvm::dyn_cast<VectorType>(source.getType());
821 return emitError(
"requires vector type");
822 if (type != sourceType)
823 return emitError(
"All sources must have same type");
827 unsigned totalLanes = 0;
828 for (
auto source :
srcs) {
829 VectorType type = llvm::dyn_cast<VectorType>(source.getType());
834 return emitError(
"mismatch between vector lanes "
835 "and sum of source lanes");
841ParseResult ConcatOp::parse(OpAsmParser &parser, OperationState &result) {
842 llvm::SMLoc typesLoc;
843 SmallVector<Type, 2> types;
844 SmallVector<OpAsmParser::UnresolvedOperand, 8> sources;
847 if (parser.parseOperandList(sources))
851 if (parser.parseOptionalAttrDict(result.attributes) ||
852 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
856 if (!result.attributes.getAttrs().empty())
857 return parser.emitError(typesLoc,
"expects no attribute");
860 if (types.size() != 2)
861 return parser.emitError(typesLoc,
"requires two types");
864 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
865 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
866 if (!sourceType || !resultType)
867 return parser.emitError(typesLoc,
"requires vector type");
870 if (parser.resolveOperands(sources, sourceType, result.operands))
873 return parser.addTypeToList(resultType, result.types);
877ConcatOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
878 ConcatOp::Adaptor adaptor,
879 SmallVectorImpl<Type> &inferredReturnTypes) {
880 SmallVector<Value, 8>
srcs(adaptor.getSources().begin(),
881 adaptor.getSources().end());
882 unsigned totalLength = 0;
883 for (
auto source :
srcs) {
884 VectorType type = llvm::dyn_cast<VectorType>(source.getType());
885 assert(type.getRank() == 1 &&
886 "only rank 1 vectors currently supported by concat");
887 totalLength += type.getDimSize(0);
889 inferredReturnTypes.push_back(VectorType::get(
891 llvm::dyn_cast<VectorType>(srcs[0].getType()).getElementType()));
900void ExtOp::print(OpAsmPrinter &p) {
902 p <<
" " << getSource();
905 p.printOptionalAttrDict((*this)->getAttrs());
908 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
912LogicalResult ExtOp::verify() {
914 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
915 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
916 if (!sourceType || !resultType)
917 return emitError(
"requires vector type");
923 if (sourceLanes / resultLanes <= 1)
924 return emitError(
"lanes in source vector must be at least "
925 "twice that of result vector");
927 if (sourceLanes % resultLanes != 0)
928 return emitError(
"lanes in result vector must be a multiple "
929 "of source vector lanes");
932 unsigned factor = sourceLanes / resultLanes;
933 if (
static_cast<unsigned>(getIndex()) >= factor)
934 return emitError(
"index out of bounds");
937 Type stype = sourceType.getElementType();
938 Type rtype = resultType.getElementType();
940 return emitError(
"source and result element type must be same");
946ParseResult ExtOp::parse(OpAsmParser &parser, OperationState &result) {
947 llvm::SMLoc typesLoc;
948 SmallVector<Type, 2> types;
949 OpAsmParser::UnresolvedOperand source;
952 if (parser.parseOperand(source))
956 if (parser.parseOptionalAttrDict(result.attributes) ||
957 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
960 if (result.attributes.getAttrs().size() != 1)
961 return parser.emitError(typesLoc,
"requires one attribute");
964 if (types.size() != 2)
965 return parser.emitError(typesLoc,
"requires two types");
968 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
969 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
970 if (!sourceType || !resultType)
971 return parser.emitError(typesLoc,
"requires vector type");
974 if (parser.resolveOperand(source, sourceType, result.operands))
977 return parser.addTypeToList(resultType, result.types);
986static void printPackUnpackOp(OpAsmPrinter &p, T op) {
988 p <<
" " << op.getSource();
991 p.printOptionalAttrDict(op->getAttrs());
994 p <<
" : " << op.getSource().getType() <<
", " << op.getResult().getType();
997void PackOp::print(OpAsmPrinter &p) { printPackUnpackOp<PackOp>(p, *
this); }
999void UnpackOp::print(OpAsmPrinter &p) { printPackUnpackOp<UnpackOp>(p, *
this); }
1002template <
typename T>
1005 auto sourceType = llvm::dyn_cast<VectorType>(op.getSource().getType());
1006 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
1007 if (!sourceType || !resultType)
1008 return op.emitError(
"requires vector type");
1013 if (sourceLanes != resultLanes)
1014 return op.emitError(
"The number of lanes in input and "
1015 "output vector must match");
1017 Type stype = sourceType.getElementType();
1018 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
1019 Type rtype = resultType.getElementType();
1020 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
1022 if (isa<PackOp>(op)) {
1024 if (stypeWidth != 16)
1025 return op.emitError(
"input must be an int16 vector");
1026 if (rtypeWidth != 8)
1027 return op.emitError(
"output must be an int8 vector");
1029 if (stypeWidth != 8)
1030 return op.emitError(
"input must be an int8 vector");
1031 if (rtypeWidth != 16)
1032 return op.emitError(
"output must be an int16 vector");
1038LogicalResult PackOp::verify() {
return verifyPackUnpackOp<PackOp>(*
this); }
1040LogicalResult UnpackOp::verify() {
return verifyPackUnpackOp<UnpackOp>(*
this); }
1044 llvm::SMLoc typesLoc;
1045 SmallVector<Type, 2> types;
1046 OpAsmParser::UnresolvedOperand source;
1049 if (parser.parseOperand(source))
1053 if (parser.parseOptionalAttrDict(result.attributes) ||
1054 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1058 if (!result.attributes.getAttrs().empty())
1059 return parser.emitError(typesLoc,
"expects no attributes");
1062 if (types.size() != 2)
1063 return parser.emitError(typesLoc,
"requires two types");
1066 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
1067 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
1068 if (!sourceType || !resultType)
1069 return parser.emitError(typesLoc,
"requires vector type");
1072 if (parser.resolveOperand(source, sourceType, result.operands))
1075 return parser.addTypeToList(resultType, result.types);
1078ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
1082ParseResult UnpackOp::parse(OpAsmParser &parser, OperationState &result) {
1091LogicalResult ExtElemOp::verify() {
1093 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
1096 return emitError(
"source requires vector type");
1099 Type stype = sourceType.getElementType();
1100 Type rtype = getResult().getType();
1102 if (stype != rtype) {
1103 return emitError(
"the type of result must be the same as the element "
1104 "type of source vector");
1115void ShiftOp::print(OpAsmPrinter &p) {
1117 p <<
" " << getLhs() <<
", " << getRhs();
1120 p <<
", " << getShift();
1123 p.printOptionalAttrDict((*this)->getAttrs());
1126 p <<
" : " << getLhs().getType() <<
", " << getLhs().getType() <<
", "
1127 << getShift().getType() <<
", " << getResult().getType();
1131LogicalResult ShiftOp::verify() {
1133 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
1135 return emitError(
"requires vector type");
1138 VectorType lhsType = llvm::dyn_cast<VectorType>(getLhs().getType());
1139 VectorType rhsType = llvm::dyn_cast<VectorType>(getRhs().getType());
1141 if (!lhsType || !rhsType)
1142 return emitError(
"requires vector type");
1143 if (lhsType != resultType || rhsType != resultType)
1144 return emitError(
"All vectors must have same type");
1146 if (!isa<IntegerType>(getShift().getType()))
1147 return emitError(
"requires integer type");
1153ParseResult ShiftOp::parse(OpAsmParser &parser, OperationState &result) {
1154 llvm::SMLoc typesLoc;
1155 SmallVector<Type, 4> types;
1156 OpAsmParser::UnresolvedOperand lhs, rhs, shift;
1159 if (parser.parseOperand(lhs) || parser.parseComma() ||
1160 parser.parseOperand(rhs) || parser.parseComma() ||
1161 parser.parseOperand(shift))
1165 if (parser.parseOptionalAttrDict(result.attributes) ||
1166 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1169 if (result.attributes.getAttrs().size() != 1)
1170 return parser.emitError(typesLoc,
"expects one attribute");
1173 if (types.size() != 4)
1174 return parser.emitError(typesLoc,
"requires four types");
1177 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
1178 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
1179 IntegerType shiftType = llvm::dyn_cast<IntegerType>(types[2]);
1180 VectorType resultType = llvm::dyn_cast<VectorType>(types[3]);
1181 if (!lhsType || !rhsType || !resultType)
1182 return parser.emitError(typesLoc,
"requires vector type");
1185 return parser.emitError(typesLoc,
"requires integer type");
1188 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
1189 parser.resolveOperand(rhs, rhsType, result.operands) ||
1190 parser.resolveOperand(shift, shiftType, result.operands))
1193 return parser.addTypeToList(resultType, result.types);
1202LogicalResult ShuffleOp::verify() {
1203 unsigned modeBitWidth;
1204 bool requireRhs =
true;
1205 auto mode = getMode();
1207 case ShuffleMode::T8_8X8:
1208 case ShuffleMode::T8_16X4:
1209 case ShuffleMode::T8_4X16:
1210 case ShuffleMode::T8_8X4:
1211 case ShuffleMode::T8_4X8:
1214 case ShuffleMode::T8_64X2_LO:
1215 case ShuffleMode::T8_64X2_HI:
1216 case ShuffleMode::T8_2X64_LO:
1217 case ShuffleMode::T8_2X64_HI:
1220 case ShuffleMode::T16_8X4:
1221 case ShuffleMode::T16_4X8:
1222 case ShuffleMode::T16_1X2_flip:
1223 case ShuffleMode::T16_4X4:
1224 case ShuffleMode::T16_4X2:
1225 case ShuffleMode::T16_2X4:
1226 case ShuffleMode::T16_8X2:
1227 case ShuffleMode::T16_2X8:
1228 case ShuffleMode::T16_16X2:
1229 case ShuffleMode::T16_2X16:
1232 case ShuffleMode::T16_32X2_LO:
1233 case ShuffleMode::T16_32X2_HI:
1234 case ShuffleMode::T16_2X32_LO:
1235 case ShuffleMode::T16_2X32_HI:
1236 case ShuffleMode::T16_16X4_LO:
1237 case ShuffleMode::T16_16X4_HI:
1238 case ShuffleMode::T16_4X16_LO:
1239 case ShuffleMode::T16_4X16_HI:
1242 case ShuffleMode::T32_4X4:
1245 case ShuffleMode::T32_16X2_LO:
1246 case ShuffleMode::T32_16X2_HI:
1247 case ShuffleMode::T32_2X16_LO:
1248 case ShuffleMode::T32_2X16_HI:
1249 case ShuffleMode::T32_8X4_LO:
1250 case ShuffleMode::T32_8X4_HI:
1251 case ShuffleMode::T32_4X8_LO:
1252 case ShuffleMode::T32_4X8_HI:
1255 case ShuffleMode::T64_8X2_LO:
1256 case ShuffleMode::T64_8X2_HI:
1257 case ShuffleMode::T64_2X8_LO:
1258 case ShuffleMode::T64_2X8_HI:
1261 case ShuffleMode::T128_4X2_LO:
1262 case ShuffleMode::T128_4X2_HI:
1263 case ShuffleMode::T128_2X4_LO:
1264 case ShuffleMode::T128_2X4_HI:
1265 modeBitWidth = 128u;
1267 case ShuffleMode::T256_2X2_LO:
1268 case ShuffleMode::T256_2X2_HI:
1269 modeBitWidth = 256u;
1271 case ShuffleMode::T512_1X2_LO:
1272 case ShuffleMode::T512_1X2_HI:
1273 modeBitWidth = 512u;
1278 if (requireRhs && !getRhs())
1279 return emitError() <<
"shuffle mode '" << stringifyEnum(mode)
1280 <<
"' requires a second operand";
1282 if (!requireRhs && getRhs())
1283 return emitError() <<
"shuffle mode '" << stringifyEnum(mode)
1284 <<
"' does not admit a second operand";
1288 cast<VectorType>(getLhs().getType()).getElementTypeBitWidth();
1289 if (modeBitWidth != elemBitWidth)
1290 return emitError() <<
"shuffle mode '" << stringifyEnum(mode)
1291 <<
"' requires vectors of " << modeBitWidth
1298void LegacyShuffleOp::print(OpAsmPrinter &p) {
1300 p <<
" " << getSource();
1303 p.printOptionalAttrDict((*this)->getAttrs());
1306 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
1310LogicalResult LegacyShuffleOp::verify() {
1312 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
1313 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
1314 if (!sourceType || !resultType)
1315 return emitError(
"requires vector type");
1320 if (sourceLanes != resultLanes)
1321 return emitError(
"The number of lanes in input and "
1322 "output vector must match");
1324 Type stype = sourceType.getElementType();
1325 unsigned stypeWidth = stype.getIntOrFloatBitWidth();
1326 Type rtype = resultType.getElementType();
1327 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
1329 if (stypeWidth != rtypeWidth)
1330 return emitError(
"The type width in input and "
1331 "output must match");
1337ParseResult LegacyShuffleOp::parse(OpAsmParser &parser,
1338 OperationState &result) {
1339 llvm::SMLoc typesLoc;
1340 SmallVector<Type, 2> types;
1341 OpAsmParser::UnresolvedOperand source;
1344 if (parser.parseOperand(source))
1348 if (parser.parseOptionalAttrDict(result.attributes) ||
1349 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1353 if (result.attributes.getAttrs().size() != 1)
1354 return parser.emitError(typesLoc,
"expects one attribute");
1357 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
1358 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
1359 if (!sourceType || !resultType)
1360 return parser.emitError(typesLoc,
"requires vector type");
1363 if (parser.resolveOperand(source, sourceType, result.operands))
1366 return parser.addTypeToList(resultType, result.types);
1379template <
typename T>
1383 p <<
", " << op.getAcc();
1389template <
typename T>
1390void elideFMSubAttr(T op, SmallVector<StringRef, 4> &elidedAttrs);
1393 SmallVector<StringRef, 4> &elidedAttrs) {
1395 elidedAttrs.push_back(op.getSubAttrName());
1400 SmallVector<StringRef, 4> &elidedAttrs) {}
1403template <
typename T>
1404static void printMulFMAConvOp(OpAsmPrinter &p, T op) {
1406 p <<
" " << op.getLhs();
1408 p <<
", " << op.getRhs();
1413 SmallVector<StringRef, 4> elidedAttrs;
1414 for (
int idx = 0; idx < 2; ++idx) {
1417 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
1420 p <<
" : " << op.getLhs().getType() <<
", " << op.getRhs().getType();
1421 p <<
", " << op.getResult().getType();
1424void MulConvOp::print(OpAsmPrinter &p) {
1425 printMulFMAConvOp<aievec::MulConvOp>(p, *
this);
1428void aievec::FMAConvOp::print(OpAsmPrinter &p) {
1429 printMulFMAConvOp<aievec::FMAConvOp>(p, *
this);
1433template <
typename T>
1436 auto lhsType = llvm::dyn_cast<VectorType>(op.getLhs().getType());
1437 auto rhsType = llvm::dyn_cast<VectorType>(op.getRhs().getType());
1439 if (!lhsType || !rhsType)
1440 return op.emitError(
"requires vector type");
1442 unsigned M = op.getM();
1443 unsigned N = op.getN();
1445 if (M <= 0 || N <= 0 || 2 * M < M + N - 1)
1446 return op.emitError(
1447 "M and N should be larger than 0 and 2*M should be no less than M+N-1");
1449 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
1452 return op.emitError(
"requires vector type");
1456 Type ltype = lhsType.getElementType();
1457 Type rtype = rhsType.getElementType();
1458 Type atype = resultType.getElementType();
1462 return op.emitError(
"The element type of lhs and rhs "
1463 "operand vectors must match");
1465 if (!isa<IntegerType>(ltype) || !isa<IntegerType>(rtype) ||
1466 !isa<IntegerType>(atype)) {
1467 return op.emitError(
"requires integer type");
1470 unsigned ltypeWidth = ltype.getIntOrFloatBitWidth();
1471 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
1472 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
1480 if (accLanes != M || accLanes != (rhsLanes / 2) || lhsLanes != rhsLanes) {
1481 return op.emitError(
1482 "The number of lanes in accumulator "
1483 "must be the same as M and the half as lhs and rhs operand");
1487 if (ltypeWidth >= atypeWidth || rtypeWidth >= atypeWidth)
1488 return op.emitError(
"the element type of accumulator must have "
1489 "wider width than that of the operand vectors");
1494LogicalResult aievec::MulConvOp::verify() {
1495 return verifyMulFMAConvOp<aievec::MulConvOp>(*
this);
1498LogicalResult aievec::FMAConvOp::verify() {
1499 return verifyMulFMAConvOp<aievec::FMAConvOp>(*
this);
1504 bool isFMAConvOp =
true) {
1505 llvm::SMLoc typesLoc;
1506 SmallVector<Type, 3> types;
1507 OpAsmParser::UnresolvedOperand lhs, rhs, acc;
1510 if (parser.parseOperand(lhs) || parser.parseComma() ||
1511 parser.parseOperand(rhs))
1516 if (parser.parseComma() || parser.parseOperand(acc))
1521 if (parser.parseOptionalAttrDict(result.attributes) ||
1522 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
1526 if (types.size() != 3)
1527 return parser.emitError(typesLoc,
"requires three types");
1530 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
1532 return parser.emitError(typesLoc,
"requires vector type");
1533 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
1535 return parser.emitError(typesLoc,
"requires vector type");
1538 VectorType accType = llvm::dyn_cast<VectorType>(types[2]);
1540 return parser.emitError(typesLoc,
"requires vector type");
1543 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
1544 parser.resolveOperand(rhs, rhsType, result.operands))
1549 if (parser.resolveOperand(acc, accType, result.operands))
1553 return parser.addTypeToList(accType, result.types);
1556ParseResult MulConvOp::parse(OpAsmParser &parser, OperationState &result) {
1560ParseResult FMAConvOp::parse(OpAsmParser &parser, OperationState &result) {
1564#define GET_ATTRDEF_CLASSES
1565#include "aie/Dialect/AIEVec/IR/AIEVecAttributes.cpp.inc"
1567#define GET_OP_CLASSES
1568#include "aie/Dialect/AIEVec/IR/AIEVecOps.cpp.inc"
// Index of helper functions defined or referenced in this file:
//   ParseResult parsePackUnpackOp(OpAsmParser &parser, OperationState &result)
//   void printAccumulator(OpAsmPrinter &p, T op)
//   ParseResult parseMulFMAElemOp(OpAsmParser &parser, OperationState &result, bool isFMAElemOp = true)
//   ParseResult parseMulFMAConvOp(OpAsmParser &parser, OperationState &result, bool isFMAConvOp = true)
//   void elideFMSubAttr(T op, SmallVector<StringRef, 4> &elidedAttrs)
//   LogicalResult verifyMulFMAElemOp(T op)
//   LogicalResult verifyMulFMAConvOp(T op)
//   LogicalResult verifyPackUnpackOp(T op)
//   unsigned getVectorLaneSize(mlir::VectorType type)