13#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
14#include "mlir/IR/DialectImplementation.h"
15#include "mlir/IR/OpDefinition.h"
16#include "mlir/IR/TypeUtilities.h"
17#include "mlir/Transforms/FoldUtils.h"
18#include "llvm/ADT/TypeSwitch.h"
26#include "aie/Dialect/AIEVec/AIE1/IR/AIEVecAIE1OpsDialect.cpp.inc"
34void AIEVecAIE1Dialect::initialize() {
37#include "aie/Dialect/AIEVec/AIE1/IR/AIEVecAIE1Ops.cpp.inc"
49 p <<
" " << op.getLhs();
51 p <<
", " << op.getRhs();
54 SmallVector<StringRef, 10> elidedAttrs;
55 for (
int idx = 0; idx < 2; ++idx) {
56 if (op.getStart(idx).empty())
57 elidedAttrs.push_back(op.getStartAttrName(idx));
58 if (op.getOffset(idx).empty())
59 elidedAttrs.push_back(op.getOffsetAttrName(idx));
60 if (op.getOffsetHi(idx).empty())
61 elidedAttrs.push_back(op.getOffsetHiAttrName(idx));
62 if (op.getSquare(idx).empty())
63 elidedAttrs.push_back(op.getSquareAttrName(idx));
65 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
68 p <<
" : " << op.getLhs().getType() <<
", " << op.getRhs().getType();
69 p <<
", " << op.getResult().getType();
72void AddOp::print(OpAsmPrinter &p) { printAddSubOp<AddOp>(p, *
this); }
74void SubOp::print(OpAsmPrinter &p) { printAddSubOp<SubOp>(p, *
this); }
80 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
81 auto lhsType = llvm::dyn_cast<VectorType>(op.getLhs().getType());
82 auto rhsType = llvm::dyn_cast<VectorType>(op.getRhs().getType());
84 if (!lhsType || !rhsType || !resultType)
85 return op.emitError(
"requires vector type");
88 if (lhsType != rhsType || rhsType != resultType)
89 return op.emitError(
"all vectors must be of same type");
94LogicalResult AddOp::verify() {
return verifyAddSubOp<AddOp>(*
this); }
96LogicalResult SubOp::verify() {
return verifyAddSubOp<SubOp>(*
this); }
100 llvm::SMLoc typesLoc;
101 SmallVector<Type, 3> types;
102 OpAsmParser::UnresolvedOperand lhs, rhs;
105 if (parser.parseOperand(lhs) || parser.parseComma() ||
106 parser.parseOperand(rhs))
110 if (parser.parseOptionalAttrDict(result.attributes) ||
111 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
115 if (types.size() != 3)
116 return parser.emitError(typesLoc,
"requires three types");
119 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
121 return parser.emitError(typesLoc,
"requires vector type");
122 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
124 return parser.emitError(typesLoc,
"requires vector type");
125 VectorType resultType = llvm::dyn_cast<VectorType>(types[2]);
127 return parser.emitError(typesLoc,
"requires vector type");
130 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
131 parser.resolveOperand(rhs, rhsType, result.operands))
134 return parser.addTypeToList(resultType, result.types);
137ParseResult AddOp::parse(OpAsmParser &parser, OperationState &result) {
141ParseResult SubOp::parse(OpAsmParser &parser, OperationState &result) {
158 p <<
", " << op.getAcc();
169 elidedAttrs.push_back(op.getSubAttrName());
176static void printMulFMAOp(OpAsmPrinter &p, T op) {
178 p <<
" " << op.getLhs();
180 p <<
", " << op.getRhs();
185 SmallVector<StringRef, 10> elidedAttrs;
186 for (
int idx = 0; idx < 2; ++idx) {
187 if (op.getStart(idx).empty())
188 elidedAttrs.push_back(op.getStartAttrName(idx));
189 if (op.getOffset(idx).empty())
190 elidedAttrs.push_back(op.getOffsetAttrName(idx));
191 if (op.getOffsetHi(idx).empty())
192 elidedAttrs.push_back(op.getOffsetHiAttrName(idx));
193 if (op.getStep(idx).empty())
194 elidedAttrs.push_back(op.getStepAttrName(idx));
195 if (op.getSquare(idx).empty())
196 elidedAttrs.push_back(op.getSquareAttrName(idx));
199 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
202 p <<
" : " << op.getLhs().getType() <<
", " << op.getRhs().getType();
203 p <<
", " << op.getResult().getType();
206void MulOp::print(OpAsmPrinter &p) { printMulFMAOp<MulOp>(p, *
this); }
208void FMAOp::print(OpAsmPrinter &p) { printMulFMAOp<FMAOp>(p, *
this); }
214 auto lhsType = llvm::dyn_cast<VectorType>(op.getLhs().getType());
215 auto rhsType = llvm::dyn_cast<VectorType>(op.getRhs().getType());
217 if (!lhsType || !rhsType)
218 return op.emitError(
"requires vector type");
220 auto resultType = llvm::dyn_cast<VectorType>(op.getResult().getType());
222 return op.emitError(
"requires vector type");
226 Type ltype = lhsType.getElementType();
227 Type rtype = rhsType.getElementType();
228 Type atype = resultType.getElementType();
229 unsigned ltypeWidth = ltype.getIntOrFloatBitWidth();
230 unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
231 unsigned atypeWidth = atype.getIntOrFloatBitWidth();
239 if (accLanes != rhsLanes || accLanes != lhsLanes) {
240 if (rhsLanes != 256 / rtypeWidth)
241 return op.emitError(
"incorrect rhs operand vector lanes");
242 if (lhsLanes < 2 * rhsLanes)
243 return op.emitError(
"The number of lanes in lhs operand "
244 "must be at least twice that of rhs operand");
245 if (accLanes > rhsLanes)
246 return op.emitError(
"The number of lanes in accumulator "
247 "must be less than that of rhs operand");
252 return op.emitError(
"The element type of lhs and rhs "
253 "operand vectors must match");
256 if (isa<IntegerType>(atype)) {
257 if (!isa<IntegerType>(ltype))
258 return op.emitError(
"Integer result must have integer operands");
260 if (ltypeWidth >= atypeWidth || rtypeWidth >= atypeWidth)
261 return op.emitError(
"the element type of accumulator must have "
262 "wider width than that of the operand vectors");
263 }
else if (isa<FloatType>(atype)) {
264 if (!isa<FloatType>(ltype))
265 return op.emitError(
"Floating point result must have "
266 "floating point operands");
268 if (ltypeWidth != atypeWidth || rtypeWidth != atypeWidth)
269 return op.emitError(
"the element type of accumulator must be "
270 "same width as the operand vectors");
276LogicalResult MulOp::verify() {
return verifyMulFMAOp<MulOp>(*
this); }
278LogicalResult FMAOp::verify() {
return verifyMulFMAOp<FMAOp>(*
this); }
282 bool isFMAOp =
true) {
283 llvm::SMLoc typesLoc;
284 SmallVector<Type, 3> types;
285 OpAsmParser::UnresolvedOperand lhs, rhs, acc;
288 if (parser.parseOperand(lhs) || parser.parseComma() ||
289 parser.parseOperand(rhs))
294 if (parser.parseComma() || parser.parseOperand(acc))
299 if (parser.parseOptionalAttrDict(result.attributes) ||
300 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
304 if (types.size() != 3)
305 return parser.emitError(typesLoc,
"requires three types");
308 VectorType lhsType = llvm::dyn_cast<VectorType>(types[0]);
310 return parser.emitError(typesLoc,
"requires vector type");
311 VectorType rhsType = llvm::dyn_cast<VectorType>(types[1]);
313 return parser.emitError(typesLoc,
"requires vector type");
316 VectorType accType = llvm::dyn_cast<VectorType>(types[2]);
318 return parser.emitError(typesLoc,
"requires vector type");
321 if (parser.resolveOperand(lhs, lhsType, result.operands) ||
322 parser.resolveOperand(rhs, rhsType, result.operands))
327 if (parser.resolveOperand(acc, accType, result.operands))
331 return parser.addTypeToList(accType, result.types);
334ParseResult MulOp::parse(OpAsmParser &parser, OperationState &result) {
338ParseResult FMAOp::parse(OpAsmParser &parser, OperationState &result) {
347void SelectOp::print(OpAsmPrinter &p) {
349 p <<
" " << getXbuff();
352 p <<
", " << getYbuff();
355 SmallVector<StringRef, 10> elidedAttrs;
356 for (
int idx = 0; idx < 2; ++idx) {
357 if (getStart(idx).empty())
358 elidedAttrs.push_back(getStartAttrName(idx));
359 if (getOffset(idx).empty())
360 elidedAttrs.push_back(getOffsetAttrName(idx));
361 if (getOffsetHi(idx).empty())
362 elidedAttrs.push_back(getOffsetHiAttrName(idx));
363 if (getSquare(idx).empty())
364 elidedAttrs.push_back(getSquareAttrName(idx));
366 p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
369 p <<
" : " << getXbuff().getType();
371 p <<
", " << getYbuff().getType();
372 p <<
", " << getResult().getType();
376LogicalResult SelectOp::verify() {
378 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
379 VectorType xbuffType = llvm::dyn_cast<VectorType>(getXbuff().getType());
381 if (!resultType || !xbuffType)
382 return emitError(
"requires vector type");
385 Type rtype = resultType.getElementType();
386 Type xtype = xbuffType.getElementType();
388 return emitError(
"types of result and xbuff must match");
392 VectorType ybuffType = llvm::dyn_cast<VectorType>(getYbuff().getType());
393 if (xbuffType != ybuffType)
394 return emitError(
"types of xbuff and ybuff must match");
400 if (sourceLanes < resultLanes)
401 return emitError(
"xbuff cannot be smaller than result");
407ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
408 llvm::SMLoc typesLoc;
409 SmallVector<Type, 3> types;
410 OpAsmParser::UnresolvedOperand xbuff, ybuff;
413 if (parser.parseOperand(xbuff))
417 ParseResult hasYbuff = parser.parseOptionalComma();
418 if (hasYbuff.succeeded() && parser.parseOperand(ybuff))
422 if (parser.parseOptionalAttrDict(result.attributes) ||
423 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
427 if (types.size() < 2)
428 return parser.emitError(typesLoc,
"requires at least two type");
431 VectorType xbuffType = llvm::dyn_cast<VectorType>(types[0]);
433 return parser.emitError(typesLoc,
"requires vector type");
434 VectorType ybuffType;
435 if (hasYbuff.succeeded()) {
436 ybuffType = llvm::dyn_cast<VectorType>(types[1]);
438 return parser.emitError(typesLoc,
"requires vector type");
440 VectorType resultType = llvm::dyn_cast<VectorType>(types.back());
442 return parser.emitError(typesLoc,
"requires vector type");
445 if (parser.resolveOperand(xbuff, xbuffType, result.operands))
448 if (hasYbuff.succeeded())
449 if (parser.resolveOperand(ybuff, ybuffType, result.operands))
452 return parser.addTypeToList(resultType, result.types);
460void ExtOp::print(OpAsmPrinter &p) {
462 p <<
" " << getSource();
465 p.printOptionalAttrDict((*this)->getAttrs());
468 p <<
" : " << getSource().getType() <<
", " << getResult().getType();
472LogicalResult ExtOp::verify() {
474 VectorType sourceType = llvm::dyn_cast<VectorType>(getSource().getType());
475 VectorType resultType = llvm::dyn_cast<VectorType>(getResult().getType());
476 if (!sourceType || !resultType)
477 return emitError(
"requires vector type");
483 if (sourceLanes / resultLanes <= 1)
484 return emitError(
"lanes in source vector must be at least "
485 "twice that of result vector");
487 if (sourceLanes % resultLanes != 0)
488 return emitError(
"lanes in result vector must be a multiple "
489 "of source vector lanes");
492 unsigned factor = sourceLanes / resultLanes;
493 if (
static_cast<unsigned>(getIndex()) >= factor)
494 return emitError(
"index out of bounds");
497 Type stype = sourceType.getElementType();
498 Type rtype = resultType.getElementType();
500 return emitError(
"source and result element type must be same");
506ParseResult ExtOp::parse(OpAsmParser &parser, OperationState &result) {
507 llvm::SMLoc typesLoc;
508 SmallVector<Type, 2> types;
509 OpAsmParser::UnresolvedOperand source;
512 if (parser.parseOperand(source))
516 if (parser.parseOptionalAttrDict(result.attributes) ||
517 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
520 if (result.attributes.getAttrs().size() != 1)
521 return parser.emitError(typesLoc,
"requires one attribute");
524 if (types.size() != 2)
525 return parser.emitError(typesLoc,
"requires two types");
528 VectorType sourceType = llvm::dyn_cast<VectorType>(types[0]);
529 VectorType resultType = llvm::dyn_cast<VectorType>(types[1]);
530 if (!sourceType || !resultType)
531 return parser.emitError(typesLoc,
"requires vector type");
534 if (parser.resolveOperand(source, sourceType, result.operands))
537 return parser.addTypeToList(resultType, result.types);
545#define GET_OP_CLASSES
546#include "aie/Dialect/AIEVec/AIE1/IR/AIEVecAIE1Ops.cpp.inc"
ParseResult parseAddSubOp(OpAsmParser &parser, OperationState &result)
LogicalResult verifyAddSubOp(T op)
ParseResult parseMulFMAOp(OpAsmParser &parser, OperationState &result, bool isFMAOp=true)
LogicalResult verifyMulFMAOp(T op)
void printAddSubOp(OpAsmPrinter &p, T op)
void elideFMSubAttr(T op, SmallVector< StringRef, 10 > &elidedAttrs)
void printAccumulator(OpAsmPrinter &p, T op)
unsigned getVectorLaneSize(mlir::VectorType type)