20#include "mlir/Dialect/Arith/IR/Arith.h"
21#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
22#include "mlir/Dialect/EmitC/IR/EmitC.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/Dialect/Index/IR/IndexOps.h"
25#include "mlir/Dialect/MemRef/IR/MemRef.h"
26#include "mlir/Dialect/SCF/IR/SCF.h"
27#include "mlir/Dialect/Vector/IR/VectorOps.h"
28#include "mlir/IR/BuiltinOps.h"
29#include "mlir/IR/BuiltinTypes.h"
30#include "mlir/IR/Operation.h"
31#include "mlir/Support/IndentedOstream.h"
33#include "llvm/ADT/ScopedHashTable.h"
34#include "llvm/ADT/SmallSet.h"
35#include "llvm/ADT/StringRef.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/CommandLine.h"
38#include "llvm/Support/Debug.h"
39#include "llvm/Support/FormatVariadic.h"
40#include "llvm/Support/MathExtras.h"
48#define DEBUG_TYPE "aievec-to-cpp"
58template <
typename ForwardIterator,
typename UnaryFunctor,
59 typename NullaryFunctor>
62 NullaryFunctor betweenFn) {
65 if (failed(eachFn(*begin)))
68 for (; begin != end; ++begin) {
70 if (failed(eachFn(*begin)))
76template <
typename Container,
typename UnaryFunctor,
typename NullaryFunctor>
78 NullaryFunctor betweenFn) {
82template <
typename Container,
typename UnaryFunctor>
84 UnaryFunctor eachFn) {
91 explicit CppEmitter(raw_ostream &os,
bool declareVariablesAtTop,
bool aie2);
94 LogicalResult emitAttribute(Location loc, Attribute attr);
97 LogicalResult emitOperation(Operation &op,
bool trailingSemicolon);
103 std::optional<std::string> genCppTypeName(Type type,
bool stdintType =
true,
108 LogicalResult emitType(Location loc, Type type,
bool stdintType =
true,
115 LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
119 LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
122 LogicalResult emitVariableAssignment(OpResult result);
125 LogicalResult emitVariableDeclaration(OpResult result,
bool trailingSemicolon,
134 LogicalResult emitAssignPrefix(Operation &op,
bool isAcc =
false);
137 LogicalResult emitLabel(Block &block);
141 LogicalResult emitOperandsAndAttributes(Operation &op,
142 ArrayRef<StringRef> exclude = {});
145 LogicalResult emitOperands(Operation &op);
148 StringRef getOrCreateName(Value val, std::string prefix =
"v");
151 void setName(Value val, StringRef name);
154 std::string getNewName(std::string prefix =
"v");
157 void setMemRefDimParam(Value memref,
unsigned index,
158 const std::string ¶meter);
161 StringRef getMemRefDimParam(Value memref,
unsigned index);
164 bool isMemRefDimParam(Value memref,
unsigned index);
167 StringRef getOrCreateName(Block &block, std::string prefix =
"label");
170 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
174 Scope(CppEmitter &emitter)
175 : valueMapperScope(emitter.valueMapper),
176 blockMapperScope(emitter.blockMapper), emitter(emitter) {
177 emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
178 emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
181 emitter.valueInScopeCount.pop();
182 emitter.labelInScopeCount.pop();
186 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
187 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
192 bool hasValueInScope(Value val);
195 bool hasBlockLabel(Block &block);
198 raw_indented_ostream &ostream() {
return os; }
202 bool shouldDeclareVariablesAtTop() {
return declareVariablesAtTop; }
204 bool aie2() {
return aie2_; }
207 using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
208 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
211 raw_indented_ostream
os;
216 bool declareVariablesAtTop;
219 ValueMapper valueMapper;
222 BlockMapper blockMapper;
225 DenseMap<std::pair<Value, unsigned>, std::string> paramIndexMapper;
229 std::stack<int64_t> valueInScopeCount;
230 std::stack<int64_t> labelInScopeCount;
232 llvm::SmallSet<StringRef, 16> includeNames;
248static bool skippedOp(Operation *op, CppEmitter &emitter,
249 bool checkStrongLiveness =
true) {
252 TypeSwitch<Operation *, bool>(op)
254 .Case<memref::DimOp, memref::AssumeAlignmentOp>(
255 [](
auto op) {
return true; })
257 .Case<aievec::SRSOp>([&](
auto srsOp) {
259 auto accType = cast<VectorType>(srsOp.getSource().getType());
260 Type eltType = accType.getElementType();
263 Value source = srsOp.getSource();
264 if (!emitter.aie2() && llvm::isa<FloatType>(eltType) &&
265 source.getDefiningOp()->hasOneUse()) {
266 StringRef srcName = emitter.getOrCreateName(source);
267 emitter.setName(srsOp->getResult(0), srcName);
273 .Case<aievec::UPSOp>([&](
auto upsOp) {
275 auto accType = cast<VectorType>(upsOp.getResult().getType());
276 Type eltType = accType.getElementType();
279 Value source = upsOp.getSource();
280 if (!emitter.aie2() && llvm::isa<FloatType>(eltType) &&
281 source.getDefiningOp()->hasOneUse()) {
282 StringRef srcName = emitter.getOrCreateName(source);
283 emitter.setName(upsOp->getResult(0), srcName);
291 .Case<aievec::CastOp>([&](
auto castOp) {
292 Value source = castOp.getSource();
293 auto srcVTy = cast<VectorType>(source.getType());
294 auto resVTy = cast<VectorType>(castOp.getResult().getType());
295 if (srcVTy.getElementType() == resVTy.getElementType()) {
296 auto iElTy = dyn_cast<IntegerType>(srcVTy.getElementType());
297 if (iElTy && iElTy.getWidth() == 64) {
298 StringRef srcName = emitter.getOrCreateName(source);
299 emitter.setName(castOp->getResult(0), srcName);
306 .Case<arith::IndexCastOp, arith::IndexCastUIOp, index::CastSOp,
307 index::CastUOp>([&](
auto idxCastOp) {
308 Value source = idxCastOp->getOperand(0);
309 StringRef srcName = emitter.getOrCreateName(source);
310 emitter.setName(idxCastOp->getResult(0), srcName);
314 .Case<vector::ShapeCastOp>([&](
auto castOp) {
315 Value source = castOp.getSource();
316 StringRef srcName = emitter.getOrCreateName(source);
317 emitter.setName(castOp.getResult(), srcName);
322 .Case<UnrealizedConversionCastOp>([&](
auto uccOp) {
323 auto inputs = uccOp.getInputs();
324 auto outputs = uccOp.getOutputs();
325 if (inputs.size() > 1 || inputs.size() > 1)
327 StringRef inputName = emitter.getOrCreateName(inputs[0]);
328 emitter.setName(outputs[0], inputName);
331 .Default([&](Operation *) {
return false; });
334 checkStrongLiveness &= isa<arith::ConstantOp>(op);
338 if (skip || !checkStrongLiveness)
344 for (
auto user : op->getUsers()) {
345 if (!skippedOp(user, emitter,
false))
352static LogicalResult parseMemRefDynamicDims(CppEmitter &emitter,
355 func.walk([&](Operation *Op) {
356 if (
auto op = dyn_cast<memref::DimOp>(Op)) {
358 Value source = op.getSource();
359 Value result = op.getResult();
360 auto indexOp = dyn_cast<arith::ConstantOp>(op.getIndex().getDefiningOp());
361 assert(indexOp &&
"Failed to get the index value of dimOp");
363 APInt idxVal = llvm::cast<IntegerAttr>(indexOp.getValue()).getValue();
364 unsigned index = idxVal.getZExtValue();
366 StringRef name = emitter.getOrCreateName(result,
"m");
367 emitter.setMemRefDimParam(source, index, name.str());
373 for (BlockArgument arg : func.getArguments()) {
374 auto argType = llvm::dyn_cast<MemRefType>(arg.getType());
377 for (
unsigned dim = 0; dim < argType.getRank(); ++dim) {
378 if (argType.isDynamicDim(dim)) {
380 if (!emitter.isMemRefDimParam(arg, dim)) {
381 std::string name = emitter.getNewName(
"m");
382 emitter.setMemRefDimParam(arg, dim, name);
391static LogicalResult printMemRefDims(CppEmitter &emitter, BlockArgument arg) {
392 raw_indented_ostream &
os = emitter.ostream();
393 if (
auto argType = llvm::dyn_cast<MemRefType>(arg.getType())) {
394 for (
unsigned dim = 0; dim < argType.getRank(); ++dim) {
395 if (argType.isDynamicDim(dim)) {
396 StringRef param = emitter.getMemRefDimParam(arg, dim);
397 os <<
", size_t " << param;
405static LogicalResult createLinearizedAccess(CppEmitter &emitter, Value source,
406 SmallVector<Value, 4> indices,
407 std::string &access) {
408 auto memRefType = llvm::dyn_cast<MemRefType>(source.getType());
410 "cannot creating linearized expression for non-memref type");
411 ArrayRef<int64_t> stride = memRefType.getShape();
414 if (stride.size() != indices.size() ||
415 static_cast<int64_t
>(stride.size()) != memRefType.getRank())
420 std::string paramPart;
422 SmallVector<std::string, 4> accessVec;
423 for (
int dim = memRefType.getRank() - 1; dim >= 0; --dim) {
425 if (!emitter.hasValueInScope(indices[dim]))
430 if (!paramPart.empty())
431 cur = paramPart +
"*";
433 cur += std::to_string(numPart) +
"*";
434 cur += emitter.getOrCreateName(indices[dim]);
435 accessVec.push_back(cur);
439 if (memRefType.isDynamicDim(dim)) {
440 StringRef param = emitter.getMemRefDimParam(source, dim);
441 paramPart = param.str() + (paramPart.empty() ?
"" :
"*" + paramPart);
443 numPart *= stride[dim];
446 while (!accessVec.empty()) {
447 access += (access.empty() ?
"" :
"+") + accessVec.back();
448 accessVec.pop_back();
458static bool isReadOnly(Value read) {
460 read.getUsers().begin(), read.getUsers().end(),
461 [](
auto *user) { return isa<vector::TransferWriteOp>(user); });
469static std::pair<bool, int64_t> getTripCount(scf::ForOp forOp) {
471 auto lb = forOp.getLowerBound().getDefiningOp<arith::ConstantOp>();
472 if (
auto ub = forOp.getUpperBound().getDefiningOp<arith::ConstantOp>();
474 APInt ubValue = llvm::cast<IntegerAttr>(ub.getValue()).getValue();
475 APInt lbValue = llvm::cast<IntegerAttr>(lb.getValue()).getValue();
476 return std::make_pair(
true,
477 ubValue.getSExtValue() - lbValue.getSExtValue());
479 return std::make_pair(
false, 0);
483static std::pair<bool, int64_t> getStep(scf::ForOp forOp) {
484 if (
auto step = forOp.getStep().getDefiningOp<arith::ConstantOp>()) {
485 APInt stepValue = llvm::cast<IntegerAttr>(step.getValue()).getValue();
486 return std::make_pair(
true, stepValue.getSExtValue());
488 return std::make_pair(
false, 0);
493static StringRef getOperator(T binOp) {
494 if (isa<arith::AddIOp>(binOp) || isa<arith::AddFOp>(binOp))
496 if (isa<arith::MulIOp>(binOp) || isa<arith::MulFOp>(binOp))
498 if (isa<arith::SubIOp>(binOp) || isa<arith::SubFOp>(binOp))
500 if (isa<arith::DivFOp>(binOp) || isa<arith::DivUIOp>(binOp) ||
501 isa<arith::DivSIOp>(binOp))
503 if (isa<arith::RemSIOp>(binOp))
505 if (isa<arith::CmpIOp>(binOp)) {
506 auto cmpOp = cast<arith::CmpIOp>(binOp);
507 switch (cmpOp.getPredicate()) {
508 case arith::CmpIPredicate::eq:
510 case arith::CmpIPredicate::ne:
512 case arith::CmpIPredicate::sge:
513 case arith::CmpIPredicate::uge:
515 case arith::CmpIPredicate::sgt:
516 case arith::CmpIPredicate::ugt:
518 case arith::CmpIPredicate::sle:
519 case arith::CmpIPredicate::ule:
521 case arith::CmpIPredicate::slt:
522 case arith::CmpIPredicate::ult:
526 llvm_unreachable(
"Cannot print the operation of binary operator");
531static LogicalResult printOperation(CppEmitter &emitter, T binOp) {
532 if (failed(emitter.emitAssignPrefix(*binOp)))
534 raw_indented_ostream &
os = emitter.ostream();
535 auto lhs = binOp.getLhs();
536 if (!emitter.hasValueInScope(lhs))
538 os << emitter.getOrCreateName(lhs);
539 os << getOperator(binOp);
540 auto rhs = binOp.getRhs();
541 if (!emitter.hasValueInScope(rhs))
543 os << emitter.getOrCreateName(rhs);
549static LogicalResult printOperation(CppEmitter &emitter,
550 arith::SelectOp selectOp) {
551 if (failed(emitter.emitAssignPrefix(*selectOp)))
554 auto cond = selectOp.getCondition();
555 if (!emitter.hasValueInScope(cond))
557 auto tVal = selectOp.getTrueValue();
558 if (!emitter.hasValueInScope(tVal))
560 auto fVal = selectOp.getFalseValue();
561 if (!emitter.hasValueInScope(fVal))
564 raw_indented_ostream &
os = emitter.ostream();
565 os << emitter.getOrCreateName(cond) <<
" ? " << emitter.getOrCreateName(tVal)
566 <<
" : " << emitter.getOrCreateName(fVal);
576static LogicalResult printOperation(CppEmitter &emitter, aievec::UPDOp updOp) {
577 Value source = updOp.getSource();
579 if (!emitter.hasValueInScope(source))
583 auto indices = updOp.getIndices();
585 if (failed(createLinearizedAccess(emitter, source, indices, access)))
588 raw_indented_ostream &
os = emitter.ostream();
589 Value result = updOp.getResult();
590 auto resultType = llvm::cast<VectorType>(result.getType());
595 if (updOp.getOffset() != 0) {
596 if (std::abs(updOp.getOffset()) % elementSizeInBits)
598 int32_t updOffset = updOp.getOffset() / elementSizeInBits;
599 access += updOffset > 0 ?
" + " :
" - ";
600 access += std::to_string(std::abs(updOffset));
606 if (vecSizeInBits <= (emitter.aie2() ? 1024 : 256)) {
608 if (failed(emitter.emitAssignPrefix(*updOp)))
611 if (failed(emitter.emitType(updOp->getLoc(), resultType)))
615 os << emitter.getOrCreateName(source);
617 os <<
" + " << access;
620 Value vector = updOp.getVector();
624 if (!emitter.shouldDeclareVariablesAtTop()) {
625 if (failed(emitter.emitVariableDeclaration(updOp->getResult(0),
true)))
629 if (!emitter.hasValueInScope(vector))
631 emitter.setName(updOp->getResult(0), emitter.getOrCreateName(vector));
635 int32_t granularity = vecSizeInBits == 256 ? 128
636 : vecSizeInBits == 512 ? 256
640 assert(lanes % 2 == 0 &&
641 "The number of vector lanes of UPD result is not even");
642 SmallVector<int64_t, 4> updShape = {lanes / 2};
643 VectorType updType = VectorType::get(updShape, resultType.getElementType());
645 if (!emitter.hasValueInScope(result))
648 bool readOnly = isReadOnly(source);
649 std::string restrictPrefix =
650 readOnly ?
"r_" + emitter.getOrCreateName(result).str() +
"_" :
"";
652 if (readOnly && !vector) {
653 if (failed(emitter.emitType(updOp->getLoc(), source.getType())))
655 os <<
" " << restrictPrefix << emitter.getOrCreateName(source);
657 os << emitter.getOrCreateName(source);
660 os << emitter.getOrCreateName(result);
662 os << (granularity == 128 ?
"upd_v"
663 : granularity == 256 ?
"upd_w"
666 os << emitter.getOrCreateName(result);
668 os << std::to_string(updOp.getIndex());
671 if (failed(emitter.emitType(updOp->getLoc(), updType)))
675 os << restrictPrefix << emitter.getOrCreateName(source);
677 os <<
" + " << access;
686static LogicalResult printOperation(CppEmitter &emitter, aievec::UPSOp upsOp) {
687 Value source = upsOp.getSource();
688 int32_t shift = upsOp.getShift();
690 raw_indented_ostream &
os = emitter.ostream();
693 if (failed(emitter.emitAssignPrefix(*upsOp,
true)))
697 if (!emitter.hasValueInScope(source))
700 auto accType = llvm::cast<VectorType>(upsOp.getResult().getType());
702 Type eltType = accType.getElementType();
706 if (!emitter.aie2() && llvm::isa<FloatType>(eltType)) {
707 os << emitter.getOrCreateName(source);
712 auto iType = llvm::dyn_cast<IntegerType>(eltType);
713 auto fType = llvm::dyn_cast<FloatType>(eltType);
715 if (iType.getWidth() == 80)
719 if (iType && emitter.aie2()) {
720 os <<
"ups_to_v" << lanes <<
"acc" << iType.getWidth();
721 }
else if (fType && emitter.aie2()) {
722 os <<
"ups_to_v16accfloat";
728 os << emitter.getOrCreateName(source);
729 if (!(fType && emitter.aie2())) {
731 os << std::to_string(shift);
739static LogicalResult printOperation(CppEmitter &emitter,
740 aievec::CastOp castOp) {
741 if (!emitter.aie2()) {
746 Value source = castOp.getSource();
747 if (!emitter.hasValueInScope(source))
750 bool isResAcc = castOp.getIsResAcc();
753 if (failed(emitter.emitAssignPrefix(*castOp, isResAcc)))
757 auto resType = llvm::cast<VectorType>(castOp->getResult(0).getType());
758 Type eltType = resType.getElementType();
761 raw_indented_ostream &
os = emitter.ostream();
765 if (llvm::isa<FloatType>(eltType))
766 os <<
"v" << lanes <<
"accfloat";
769 os <<
"v" << lanes <<
"acc" << width;
771 }
else if (llvm::isa<FloatType>(eltType)) {
772 width = llvm::cast<FloatType>(eltType).getWidth();
780 os <<
"v" << lanes <<
"int" << width;
783 os << emitter.getOrCreateName(source);
789static LogicalResult printOperation(CppEmitter &emitter,
790 aievec::UnpackOp unpackOp) {
793 Value source = unpackOp.getSource();
794 if (!emitter.hasValueInScope(source))
798 if (failed(emitter.emitAssignPrefix(*unpackOp,
false)))
801 raw_indented_ostream &
os = emitter.ostream();
804 os << emitter.getOrCreateName(source);
810static LogicalResult printOperation(CppEmitter &emitter, aievec::SRSOp srsOp) {
811 Value source = srsOp.getSource();
812 Value shift = srsOp.getShift();
815 auto accType = llvm::cast<VectorType>(srsOp.getSource().getType());
816 auto resType = llvm::cast<VectorType>(srsOp->getResult(0).getType());
817 Type eltType = accType.getElementType();
820 raw_indented_ostream &
os = emitter.ostream();
823 if (failed(emitter.emitAssignPrefix(*srsOp)))
827 if (!emitter.hasValueInScope(source))
832 if (llvm::isa<FloatType>(eltType)) {
833 if (emitter.aie2()) {
836 else if (width == 16)
837 os <<
"to_v16bfloat16";
839 os << emitter.getOrCreateName(source);
842 os << emitter.getOrCreateName(source);
851 unsigned srcWidth = 0;
852 if (
auto iType = llvm::dyn_cast<IntegerType>(eltType))
853 srcWidth = iType.getWidth();
856 if ((srcWidth == 80 && resultWidth == 64) ||
857 (srcWidth == 48 && resultWidth == 32))
859 else if (srcWidth == 48 && resultWidth == 8)
863 os <<
"srs_to_v" << std::to_string(lanes) <<
"int"
864 << std::to_string(resWidth);
869 os << emitter.getOrCreateName(source);
871 if (llvm::cast<IntegerType>(srsOp.getShift().getType()).getWidth() != 32)
873 os << emitter.getOrCreateName(shift);
880static LogicalResult printOperation(CppEmitter &emitter,
881 aievec::BroadcastOp broadcastOp) {
882 Value source = broadcastOp.getSource();
883 int8_t idx = broadcastOp.getIdx();
885 raw_indented_ostream &
os = emitter.ostream();
888 if (failed(emitter.emitAssignPrefix(*broadcastOp)))
892 if (!emitter.hasValueInScope(source))
895 os <<
"broadcast_elem";
897 os << emitter.getOrCreateName(source);
899 os << std::to_string(idx);
907printOperation(CppEmitter &emitter,
908 aievec::BroadcastScalarOp broadcastScalarOp) {
909 auto source = broadcastScalarOp.getSource();
911 llvm::cast<VectorType>(broadcastScalarOp.getResult().getType());
914 raw_indented_ostream &
os = emitter.ostream();
917 if (failed(emitter.emitAssignPrefix(*broadcastScalarOp)))
920 Type eltType = resType.getElementType();
921 os <<
"broadcast_to_v";
922 if (llvm::isa<IntegerType>(eltType)) {
923 os << lanes <<
"int";
925 }
else if (width == 16)
926 os << lanes <<
"bfloat16";
928 os << lanes <<
"float";
929 os <<
"(" << emitter.getOrCreateName(source) <<
")";
936static LogicalResult printExtOperation(CppEmitter &emitter, T extOp) {
937 Value source = extOp.getSource();
938 int8_t index = extOp.getIndex();
940 raw_indented_ostream &
os = emitter.ostream();
943 if (failed(emitter.emitAssignPrefix(*extOp)))
946 if (!emitter.hasValueInScope(source))
949 auto resType = llvm::cast<VectorType>(extOp.getResult().getType());
950 Type eltType = resType.getElementType();
955 if (emitter.aie2()) {
956 os <<
"extract_v" << std::to_string(lanes);
957 if (llvm::isa<IntegerType>(eltType))
958 os <<
"int" << std::to_string(resWidth);
959 else if (resWidth == 16)
966 assert(vecSizeInBits == 128 || vecSizeInBits == 256 ||
967 vecSizeInBits == 512);
968 os << (vecSizeInBits == 128 ?
"ext_v"
969 : vecSizeInBits == 256 ?
"ext_w"
974 os << emitter.getOrCreateName(source);
976 os << std::to_string(index);
983static LogicalResult printOperation(CppEmitter &emitter, aievec::ExtOp extOp) {
986 return printExtOperation<aievec::ExtOp>(emitter, extOp);
990static LogicalResult printOperation(CppEmitter &emitter,
991 aievec::aie1::ExtOp extOp) {
994 return printExtOperation<aievec::aie1::ExtOp>(emitter, extOp);
998static LogicalResult printOperation(CppEmitter &emitter,
999 aievec::ConcatOp concatOp) {
1000 SmallVector<Value> sources = concatOp.getSources();
1002 raw_indented_ostream &
os = emitter.ostream();
1005 if (failed(emitter.emitAssignPrefix(*concatOp)))
1012 for (
auto source : sources) {
1014 if (!emitter.hasValueInScope(source))
1018 os << emitter.getOrCreateName(source);
1027static LogicalResult printOperation(CppEmitter &emitter,
1028 aievec::ShiftOp shiftOp) {
1029 Value lhs = shiftOp.getLhs();
1030 Value rhs = shiftOp.getRhs();
1031 Value shift = shiftOp.getShift();
1032 bool isAcc = shiftOp.getIsAcc();
1034 raw_indented_ostream &
os = emitter.ostream();
1037 if (failed(emitter.emitAssignPrefix(*shiftOp, isAcc)))
1040 os <<
"shift_bytes";
1043 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1045 os << emitter.getOrCreateName(lhs);
1047 os << emitter.getOrCreateName(rhs);
1048 os <<
", static_cast<uint32_t>(";
1050 if (!emitter.hasValueInScope(shift))
1052 os << emitter.getOrCreateName(shift);
1059static LogicalResult printOperation(CppEmitter &emitter,
1060 aievec::ShuffleOp shuffleOp) {
1061 Value lhs = shuffleOp.getLhs();
1062 Value rhs = shuffleOp.getRhs();
1063 aievec::ShuffleMode mode = shuffleOp.getMode();
1065 raw_indented_ostream &
os = emitter.ostream();
1068 if (failed(emitter.emitAssignPrefix(*shuffleOp)))
1073 if (!emitter.hasValueInScope(lhs))
1075 os << emitter.getOrCreateName(lhs);
1078 if (!emitter.hasValueInScope(rhs))
1080 os << emitter.getOrCreateName(rhs);
1083 os <<
"eShuffleMode::shuffle_T" << stringifyEnum(mode).substr(1);
1090static LogicalResult printOperation(CppEmitter &emitter,
1091 aievec::LegacyShuffleOp shuffleOp) {
1092 Value source = shuffleOp.getSource();
1093 unsigned mode = shuffleOp.getMode();
1095 raw_indented_ostream &
os = emitter.ostream();
1098 if (failed(emitter.emitAssignPrefix(*shuffleOp)))
1105 if (!emitter.hasValueInScope(source))
1107 os << emitter.getOrCreateName(source);
1109 os << std::to_string(mode);
1116static LogicalResult printOperation(CppEmitter &emitter,
1117 aievec::aie1::SelectOp selectOp) {
1118 Value xbuff = selectOp.getXbuff();
1119 assert(xbuff &&
"xbuff empty in select op");
1121 raw_indented_ostream &
os = emitter.ostream();
1124 if (failed(emitter.emitAssignPrefix(*selectOp)))
1128 auto xbuffType = llvm::cast<VectorType>(selectOp.getXbuff().getType());
1130 assert(elementSizeInBits == 16 || elementSizeInBits == 32 ||
1131 elementSizeInBits == 64);
1133 os << (elementSizeInBits == 16 ?
"select32"
1134 : elementSizeInBits == 32 ?
"select16"
1138 assert(!selectOp.getSelect().empty());
1139 os << selectOp.getSelect();
1141 if (!emitter.hasValueInScope(xbuff))
1145 os << emitter.getOrCreateName(xbuff);
1147 if (!selectOp.getXstart().empty())
1148 os <<
", " << selectOp.getXstart();
1149 if (!selectOp.getXoffsets().empty())
1150 os <<
", " << selectOp.getXoffsets();
1151 if (!selectOp.getXoffsetsHi().empty())
1152 os <<
", " << selectOp.getXoffsetsHi();
1153 if (!selectOp.getXsquare().empty())
1154 os <<
", " << selectOp.getXsquare();
1156 if (selectOp.getYbuff()) {
1157 Value ybuff = selectOp.getYbuff();
1159 if (!emitter.hasValueInScope(ybuff))
1163 os << emitter.getOrCreateName(ybuff);
1166 if (!selectOp.getYstart().empty())
1167 os <<
", " << selectOp.getYstart();
1168 if (!selectOp.getYoffsets().empty())
1169 os <<
", " << selectOp.getYoffsets();
1170 if (!selectOp.getYoffsetsHi().empty())
1171 os <<
", " << selectOp.getYoffsetsHi();
1172 if (!selectOp.getYsquare().empty())
1173 os <<
", " << selectOp.getYsquare();
1180static LogicalResult printOperation(CppEmitter &emitter,
1181 aievec::PackOp packOp) {
1182 Value source = packOp.getSource();
1184 raw_indented_ostream &
os = emitter.ostream();
1187 if (failed(emitter.emitAssignPrefix(*packOp)))
1191 auto sourceType = llvm::cast<VectorType>(packOp.getSource().getType());
1192 Type scalarType = sourceType.getElementType();
1193 os << (scalarType.isUnsignedInteger() ?
"upack" :
"pack");
1196 if (!emitter.hasValueInScope(source))
1198 os << emitter.getOrCreateName(source);
1205template <
typename T>
1206static LogicalResult printAddOrSubOperand(CppEmitter &emitter, T op,
1213 Value operand = opNum == 0 ? op.getLhs() : op.getRhs();
1214 if (!emitter.hasValueInScope(operand))
1217 raw_indented_ostream &
os = emitter.ostream();
1219 StringRef start = op.getStart(opNum);
1220 StringRef offset = op.getOffset(opNum);
1221 StringRef offsetHi = op.getOffsetHi(opNum);
1222 StringRef square = op.getSquare(opNum);
1224 os << emitter.getOrCreateName(operand);
1226 os <<
", " << start;
1227 if (!offset.empty())
1228 os <<
", " << offset;
1229 if (!offsetHi.empty())
1230 os <<
", " << offsetHi;
1231 if (!square.empty())
1232 os <<
", " << square;
1238template <
typename T>
1239static LogicalResult printMinMaxOperand(CppEmitter &emitter, T op,
1246 Value operand = opNum == 0 ? op.getLhs() : op.getRhs();
1247 if (!emitter.hasValueInScope(operand))
1250 raw_indented_ostream &
os = emitter.ostream();
1251 os << emitter.getOrCreateName(operand);
1257template <
typename T>
1258static LogicalResult printAddElemOrSubElemOperand(CppEmitter &emitter, T op,
1265 Value operand = opNum == 0 ? op.getLhs() : op.getRhs();
1266 if (!emitter.hasValueInScope(operand))
1269 raw_indented_ostream &
os = emitter.ostream();
1270 os << emitter.getOrCreateName(operand);
1276template <
typename T>
1277static LogicalResult printFMAOrMulOperand(CppEmitter &emitter, T op,
1284 Value operand = opNum == 0 ? op.getLhs() : op.getRhs();
1285 if (!emitter.hasValueInScope(operand))
1288 raw_indented_ostream &
os = emitter.ostream();
1290 StringRef start = op.getStart(opNum);
1291 StringRef offset = op.getOffset(opNum);
1292 StringRef offsetHi = op.getOffsetHi(opNum);
1293 StringRef step = op.getStep(opNum);
1294 StringRef square = op.getSquare(opNum);
1296 os << emitter.getOrCreateName(operand);
1298 os <<
", " << start;
1299 if (!offset.empty())
1300 os <<
", " << offset;
1301 if (!offsetHi.empty())
1302 os <<
", " << offsetHi;
1305 if (!square.empty())
1306 os <<
", " << square;
1312template <
typename T>
1313static LogicalResult printFMAOrMulElemOperand(CppEmitter &emitter, T op,
1314 Type iType, int32_t size,
1321 Value operand = opNum == 0 ? op.getLhs() : op.getRhs();
1322 if (!emitter.hasValueInScope(operand))
1325 raw_indented_ostream &
os = emitter.ostream();
1326 os << emitter.getOrCreateName(operand);
1327 if (size == 32 && iType)
1328 os <<
", " << (opNum == 0 ?
"undef_v16int32()" :
"broadcast_zero_s32()");
1334template <
typename T>
1335static LogicalResult printFMAOrMulConvOperand(CppEmitter &emitter, T op,
1342 Value operand = opNum == 0 ? op.getLhs() : op.getRhs();
1343 if (!emitter.hasValueInScope(operand))
1346 raw_indented_ostream &
os = emitter.ostream();
1347 os << emitter.getOrCreateName(operand);
1353static LogicalResult printOperation(CppEmitter &emitter,
1354 aievec::aie1::MulOp mulOp) {
1355 auto lhs = mulOp.getLhs();
1356 auto rhs = mulOp.getRhs();
1359 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1363 bool simpleScheme = mulOp.getStart(0).empty();
1367 auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
1368 Type eltType = resType.getElementType();
1369 if (!simpleScheme) {
1370 if (
auto iType = llvm::dyn_cast<IntegerType>(eltType)) {
1371 if (iType.getWidth() == 80)
1373 }
else if (llvm::isa<FloatType>(eltType))
1378 if (!simpleScheme && !llvm::isa<FloatType>(eltType))
1381 raw_indented_ostream &
os = emitter.ostream();
1384 if (failed(emitter.emitAssignPrefix(*mulOp)))
1389 if (failed(printFMAOrMulOperand<aievec::aie1::MulOp>(emitter, mulOp, 0)))
1392 if (failed(printFMAOrMulOperand<aievec::aie1::MulOp>(emitter, mulOp, 1)))
1399static std::string printConversionTo512bit(CppEmitter &emitter, Value v) {
1400 std::string vName = emitter.getOrCreateName(v).str();
1401 auto vTy = cast<VectorType>(v.getType());
1402 auto vShape = vTy.getShape();
1403 int64_t elemBitWidth = vTy.getElementTypeBitWidth();
1404 int64_t numElems = std::accumulate(vShape.begin(), vShape.end(), 1,
1405 std::multiplies<int64_t>());
1406 int64_t vBitWidth = numElems * elemBitWidth;
1407 if (vBitWidth >= 512)
1410 int64_t newNumElems = 512 / elemBitWidth;
1412 std::string vNewName = emitter.getNewName();
1413 raw_indented_ostream &
os = emitter.ostream();
1414 auto newVecTy = VectorType::get({512 / elemBitWidth}, vTy.getElementType());
1416 emitter.genCppTypeName(newVecTy,
false,
false));
1418 *(emitter.genCppTypeName(vTy,
false,
false));
1420 os << newTyName <<
" " << vNewName <<
" = concat(";
1421 if (newNumElems / numElems == 4) {
1422 os <<
"concat(" << vName <<
", undef_" << oldTyName <<
"())";
1423 oldTyName = *(emitter.genCppTypeName(
1424 VectorType::get({256 / elemBitWidth}, vTy.getElementType())));
1428 os <<
", undef_" << oldTyName <<
"());\n";
1433static LogicalResult printOperation(CppEmitter &emitter,
1434 aievec::MulElemOp mulElemOp) {
1435 auto lhs = mulElemOp.getLhs();
1436 auto rhs = mulElemOp.getRhs();
1439 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1442 auto lhsName = printConversionTo512bit(emitter, lhs);
1443 auto rhsName = printConversionTo512bit(emitter, rhs);
1445 std::string opname =
"mul_elem";
1448 auto lhsType = llvm::cast<VectorType>(mulElemOp.getLhs().getType());
1449 Type eltType = lhsType.getElementType();
1451 auto iType = llvm::dyn_cast<IntegerType>(eltType);
1456 else if (lsize == 16)
1458 else if (lsize == 8)
1460 }
else if (llvm::isa<FloatType>(eltType)) {
1463 else if (lsize == 16)
1467 raw_indented_ostream &
os = emitter.ostream();
1470 if (failed(emitter.emitAssignPrefix(*mulElemOp,
true )))
1474 os <<
"(" << lhsName;
1475 if ((lsize == 32) && iType)
1477 <<
"undef_v16int32()";
1478 os <<
" ," << rhsName;
1479 if ((lsize == 32) && iType)
1481 <<
"broadcast_zero_s32()";
1487static LogicalResult printOperation(CppEmitter &emitter,
1488 aievec::MulConvOp mulConvOp) {
1489 auto lhs = mulConvOp.getLhs();
1490 auto rhs = mulConvOp.getRhs();
1493 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1497 auto lhsType = llvm::cast<VectorType>(mulConvOp.getLhs().getType());
1498 Type eltType = lhsType.getElementType();
1500 auto iType = llvm::dyn_cast<IntegerType>(eltType);
1503 if (!iType || !(lsize == 16 || lsize == 8)) {
1507 int32_t M = mulConvOp.getM();
1508 int32_t N = mulConvOp.getN();
1509 std::string opname =
1510 "mul_conv_" + std::to_string(M) +
"x" + std::to_string(N);
1512 raw_indented_ostream &
os = emitter.ostream();
1515 if (failed(emitter.emitAssignPrefix(*mulConvOp,
true )))
1522 printFMAOrMulConvOperand<aievec::MulConvOp>(emitter, mulConvOp, 0)))
1526 printFMAOrMulConvOperand<aievec::MulConvOp>(emitter, mulConvOp, 1)))
1534static LogicalResult printOperation(CppEmitter &emitter,
1535 aievec::aie1::AddOp addOp) {
1536 auto lhs = addOp.getLhs();
1537 auto rhs = addOp.getRhs();
1540 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1543 raw_indented_ostream &
os = emitter.ostream();
1546 if (failed(emitter.emitAssignPrefix(*addOp)))
1550 auto resultType = llvm::cast<VectorType>(addOp.getResult().getType());
1552 Type elementType = resultType.getElementType();
1553 bool floatType = llvm::isa<FloatType>(elementType);
1557 if (addOp.getStart(0).empty()) {
1562 os << emitter.getOrCreateName(lhs);
1564 os << emitter.getOrCreateName(rhs);
1569 os << emitter.getOrCreateName(lhs);
1571 os << emitter.getOrCreateName(rhs);
1576 os << (floatType ?
"fpadd" :
"add" + std::to_string(lanes));
1578 if (failed(printAddOrSubOperand<aievec::aie1::AddOp>(emitter, addOp, 0)))
1581 if (failed(printAddOrSubOperand<aievec::aie1::AddOp>(emitter, addOp, 1)))
1589static LogicalResult printOperation(CppEmitter &emitter,
1590 aievec::aie1::SubOp subOp) {
1591 auto lhs = subOp.getLhs();
1592 auto rhs = subOp.getRhs();
1595 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1598 raw_indented_ostream &
os = emitter.ostream();
1601 if (failed(emitter.emitAssignPrefix(*subOp)))
1605 auto resultType = llvm::cast<VectorType>(subOp.getResult().getType());
1607 Type elementType = resultType.getElementType();
1608 bool floatType = llvm::isa<FloatType>(elementType);
1612 if (subOp.getStart(0).empty()) {
1617 os << emitter.getOrCreateName(lhs);
1619 os << emitter.getOrCreateName(rhs);
1624 os << emitter.getOrCreateName(lhs);
1626 os << emitter.getOrCreateName(rhs);
1631 os << (floatType ?
"fpsub" :
"sub" + std::to_string(lanes));
1633 if (failed(printAddOrSubOperand<aievec::aie1::SubOp>(emitter, subOp, 0)))
1636 if (failed(printAddOrSubOperand<aievec::aie1::SubOp>(emitter, subOp, 1)))
1644static LogicalResult printOperation(CppEmitter &emitter, aievec::MinOp minOp) {
1645 auto lhs = minOp.getLhs();
1646 auto rhs = minOp.getRhs();
1649 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1652 raw_indented_ostream &
os = emitter.ostream();
1655 if (failed(emitter.emitAssignPrefix(*minOp)))
1659 if (failed(printMinMaxOperand<aievec::MinOp>(emitter, minOp, 0)))
1662 if (failed(printMinMaxOperand<aievec::MinOp>(emitter, minOp, 1)))
1670static LogicalResult printOperation(CppEmitter &emitter, aievec::MaxOp maxOp) {
1671 auto lhs = maxOp.getLhs();
1672 auto rhs = maxOp.getRhs();
1675 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1678 raw_indented_ostream &
os = emitter.ostream();
1681 if (failed(emitter.emitAssignPrefix(*maxOp)))
1685 if (failed(printMinMaxOperand<aievec::MaxOp>(emitter, maxOp, 0)))
1688 if (failed(printMinMaxOperand<aievec::MaxOp>(emitter, maxOp, 1)))
1696static LogicalResult printOperation(CppEmitter &emitter, aievec::NegOp negOp) {
1697 auto src = negOp.getSource();
1700 if (!emitter.hasValueInScope(src))
1703 raw_indented_ostream &
os = emitter.ostream();
1706 if (failed(emitter.emitAssignPrefix(*negOp,
true )))
1710 os << emitter.getOrCreateName(src);
1717static LogicalResult printOperation(CppEmitter &emitter,
1718 aievec::BnegOp bnegOp) {
1719 auto src = bnegOp.getSource();
1722 if (!emitter.hasValueInScope(src))
1725 raw_indented_ostream &
os = emitter.ostream();
1728 if (failed(emitter.emitAssignPrefix(*bnegOp)))
1732 os << emitter.getOrCreateName(src);
1739static LogicalResult printOperation(CppEmitter &emitter, aievec::BxorOp xorOp) {
1740 auto lhs = xorOp.getLhs();
1741 auto rhs = xorOp.getRhs();
1744 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1747 raw_indented_ostream &
os = emitter.ostream();
1750 if (failed(emitter.emitAssignPrefix(*xorOp)))
1754 os << emitter.getOrCreateName(lhs);
1756 os << emitter.getOrCreateName(rhs);
1763static LogicalResult printOperation(CppEmitter &emitter, aievec::BandOp andOp) {
1764 auto lhs = andOp.getLhs();
1765 auto rhs = andOp.getRhs();
1768 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1771 raw_indented_ostream &
os = emitter.ostream();
1774 if (failed(emitter.emitAssignPrefix(*andOp)))
1778 os << emitter.getOrCreateName(lhs);
1780 os << emitter.getOrCreateName(rhs);
1787static LogicalResult printOperation(CppEmitter &emitter, aievec::BorOp orOp) {
1788 auto lhs = orOp.getLhs();
1789 auto rhs = orOp.getRhs();
1792 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1795 raw_indented_ostream &
os = emitter.ostream();
1798 if (failed(emitter.emitAssignPrefix(*orOp)))
1802 os << emitter.getOrCreateName(lhs);
1804 os << emitter.getOrCreateName(rhs);
1811static LogicalResult printOperation(CppEmitter &emitter,
1812 aievec::AddElemOp addElemOp) {
1813 auto lhs = addElemOp.getLhs();
1814 auto rhs = addElemOp.getRhs();
1817 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1820 raw_indented_ostream &
os = emitter.ostream();
1825 auto resType = cast<VectorType>(addElemOp.getResult().getType());
1826 auto resElemType = resType.getElementType();
1827 unsigned resBitWidth = resElemType.getIntOrFloatBitWidth();
1829 if (isa<FloatType>(resElemType) || resBitWidth * resLaneSize == 1024)
1832 if (failed(emitter.emitAssignPrefix(*addElemOp, isAcc)))
1836 if (failed(printAddElemOrSubElemOperand<aievec::AddElemOp>(emitter, addElemOp,
1840 if (failed(printAddElemOrSubElemOperand<aievec::AddElemOp>(emitter, addElemOp,
1849static LogicalResult printOperation(CppEmitter &emitter,
1850 aievec::SubElemOp subElemOp) {
1851 auto lhs = subElemOp.getLhs();
1852 auto rhs = subElemOp.getRhs();
1855 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1858 raw_indented_ostream &
os = emitter.ostream();
1863 auto resType = cast<VectorType>(subElemOp.getResult().getType());
1864 auto resElemType = resType.getElementType();
1865 unsigned resBitWidth = resElemType.getIntOrFloatBitWidth();
1867 if (isa<FloatType>(resElemType) || resBitWidth * resLaneSize == 1024)
1870 if (failed(emitter.emitAssignPrefix(*subElemOp, isAcc)))
1874 if (failed(printAddElemOrSubElemOperand<aievec::SubElemOp>(emitter, subElemOp,
1878 if (failed(printAddElemOrSubElemOperand<aievec::SubElemOp>(emitter, subElemOp,
1887static LogicalResult printOperation(CppEmitter &emitter,
1888 aievec::aie1::FMAOp fmaOp) {
1889 auto acc = fmaOp.getAcc();
1890 auto lhs = fmaOp.getLhs();
1891 auto rhs = fmaOp.getRhs();
1894 if (!emitter.hasValueInScope(acc) || !emitter.hasValueInScope(lhs) ||
1895 !emitter.hasValueInScope(rhs))
1899 bool simpleScheme = fmaOp.getStart(0).empty();
1903 auto resType = llvm::cast<VectorType>(fmaOp.getResult().getType());
1904 Type eltType = resType.getElementType();
1905 if (!simpleScheme) {
1906 if (
auto iType = llvm::dyn_cast<IntegerType>(eltType)) {
1907 if (iType.getWidth() == 80)
1909 }
else if (llvm::isa<FloatType>(eltType))
1913 opname += fmaOp.getFmsub() ?
"msc" :
"mac";
1914 if (!simpleScheme && !llvm::isa<FloatType>(eltType))
1917 raw_indented_ostream &
os = emitter.ostream();
1919 StringRef accName = emitter.getOrCreateName(acc);
1926 if (failed(printFMAOrMulOperand<aievec::aie1::FMAOp>(emitter, fmaOp, 0)))
1929 if (failed(printFMAOrMulOperand<aievec::aie1::FMAOp>(emitter, fmaOp, 1)))
1934 emitter.setName(fmaOp->getResult(0), accName);
1940static LogicalResult printOperation(CppEmitter &emitter,
1941 aievec::FMAElemOp fmaElemOp) {
1942 auto acc = fmaElemOp.getAcc();
1943 auto lhs = fmaElemOp.getLhs();
1944 auto rhs = fmaElemOp.getRhs();
1947 if (!emitter.hasValueInScope(acc) || !emitter.hasValueInScope(lhs) ||
1948 !emitter.hasValueInScope(rhs))
1951 std::string opname = fmaElemOp.getFmsub() ?
"msc_elem" :
"mac_elem";
1953 auto lhsType = llvm::cast<VectorType>(fmaElemOp.getLhs().getType());
1954 Type eltType = lhsType.getElementType();
1956 auto iType = llvm::dyn_cast<IntegerType>(eltType);
1961 else if (lsize == 16)
1963 else if (lsize == 8)
1965 }
else if (llvm::isa<FloatType>(eltType)) {
1968 else if (lsize == 16)
1972 raw_indented_ostream &
os = emitter.ostream();
1974 StringRef accName = emitter.getOrCreateName(acc);
1979 if (failed(printFMAOrMulElemOperand<aievec::FMAElemOp>(emitter, fmaElemOp,
1983 if (failed(printFMAOrMulElemOperand<aievec::FMAElemOp>(emitter, fmaElemOp,
1991 emitter.setName(fmaElemOp->getResult(0), accName);
1997static LogicalResult printOperation(CppEmitter &emitter,
1998 aievec::FMAConvOp fmaConvOp) {
1999 auto acc = fmaConvOp.getAcc();
2000 auto lhs = fmaConvOp.getLhs();
2001 auto rhs = fmaConvOp.getRhs();
2004 if (!emitter.hasValueInScope(acc) || !emitter.hasValueInScope(lhs) ||
2005 !emitter.hasValueInScope(rhs))
2008 std::string opname = fmaConvOp.getFmsub() ?
"msc_conv" :
"mac_conv";
2010 auto lhsType = llvm::cast<VectorType>(fmaConvOp.getLhs().getType());
2011 Type eltType = lhsType.getElementType();
2013 auto iType = llvm::dyn_cast<IntegerType>(eltType);
2016 if (!iType || !(lsize == 16 || lsize == 8))
2019 int32_t M = fmaConvOp.getM();
2020 int32_t N = fmaConvOp.getN();
2021 opname +=
"_" + std::to_string(M) +
"x" + std::to_string(N);
2023 raw_indented_ostream &
os = emitter.ostream();
2025 StringRef accName = emitter.getOrCreateName(acc);
2031 printFMAOrMulConvOperand<aievec::FMAConvOp>(emitter, fmaConvOp, 0)))
2035 printFMAOrMulConvOperand<aievec::FMAConvOp>(emitter, fmaConvOp, 1)))
2042 emitter.setName(fmaConvOp->getResult(0), accName);
2048static LogicalResult printOperation(CppEmitter &emitter, aievec::CmpOp cmpOp) {
2049 if (!emitter.aie2())
2053 Value lhs = cmpOp.getLhs();
2054 Value rhs = cmpOp.getRhs();
2056 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
2060 if (failed(emitter.emitAssignPrefix(*cmpOp)))
2063 raw_indented_ostream &
os = emitter.ostream();
2065 StringRef pred = cmpOp.getPred();
2068 else if (pred ==
"ne")
2070 else if (pred ==
"slt" || pred ==
"ult")
2072 else if (pred ==
"sle" || pred ==
"ule")
2074 else if (pred ==
"sgt" || pred ==
"ugt")
2076 else if (pred ==
"sge" || pred ==
"uge")
2082 auto vType = llvm::cast<VectorType>(lhs.getType());
2084 if (Type eltType = vType.getElementType();
2085 llvm::isa<IntegerType>(eltType) &&
2086 (pred ==
"ult" || pred ==
"ule" || pred ==
"ugt" || pred ==
"uge")) {
2089 os <<
"v" << std::to_string(lanes) <<
"uint" << std::to_string(width);
2091 os << emitter.getOrCreateName(lhs);
2093 os <<
"v" << std::to_string(lanes) <<
"uint" << std::to_string(width);
2095 os << emitter.getOrCreateName(rhs);
2098 os << emitter.getOrCreateName(lhs);
2100 os << emitter.getOrCreateName(rhs);
2108static LogicalResult printOperation(CppEmitter &emitter, aievec::SelOp selOp) {
2109 if (!emitter.aie2())
2113 Value lhs = selOp.getLhs();
2114 Value rhs = selOp.getRhs();
2115 Value sel = selOp.getSel();
2117 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs) ||
2118 !emitter.hasValueInScope(sel))
2122 if (failed(emitter.emitAssignPrefix(*selOp)))
2125 raw_indented_ostream &
os = emitter.ostream();
2128 os << emitter.getOrCreateName(rhs);
2130 os << emitter.getOrCreateName(lhs);
2132 os << emitter.getOrCreateName(sel);
2139static LogicalResult printOperation(CppEmitter &emitter,
2140 aievec::ExtElemOp extElemOp) {
2141 Value source = extElemOp.getSource();
2142 Value index = extElemOp.getIndex();
2144 raw_indented_ostream &
os = emitter.ostream();
2147 if (failed(emitter.emitAssignPrefix(*extElemOp)))
2151 if (!emitter.hasValueInScope(source))
2154 os <<
"extract_elem";
2157 os << emitter.getOrCreateName(source);
2159 os << emitter.getOrCreateName(index);
2166static LogicalResult printOperation(CppEmitter &emitter,
2167 vector::TransferWriteOp writeOp) {
2168 Value source = writeOp.getSource();
2169 Value vector = writeOp.getVector();
2173 if (!emitter.hasValueInScope(source) || !emitter.hasValueInScope(vector))
2178 auto indices = writeOp.getIndices();
2179 if (failed(createLinearizedAccess(emitter, source, indices, access)))
2182 raw_indented_ostream &
os = emitter.ostream();
2185 if (failed(emitter.emitType(writeOp->getLoc(), vector.getType())))
2189 os << emitter.getOrCreateName(source);
2190 if (!access.empty())
2191 os <<
" + " << access;
2194 os << emitter.getOrCreateName(vector);
2200static LogicalResult printOperation(CppEmitter &emitter,
2201 memref::StoreOp storeOp) {
2202 Value value = storeOp.getValue();
2203 Value memref = storeOp.getMemref();
2207 if (!emitter.hasValueInScope(value) || !emitter.hasValueInScope(memref))
2210 raw_indented_ostream &
os = emitter.ostream();
2213 if (failed(emitter.emitType(
2215 cast<MemRefType>(memref.getType()).getElementType())))
2218 os << emitter.getOrCreateName(memref);
2220 os << emitter.getOrCreateName(value);
2226template <
typename OpTy>
2227static LogicalResult printValueForwardOperation(CppEmitter &emitter, OpTy op) {
2228 Value source = op.getSrc();
2232 if (!emitter.hasValueInScope(source))
2235 if (failed(emitter.emitAssignPrefix(*op)))
2238 raw_indented_ostream &
os = emitter.ostream();
2239 os << emitter.getOrCreateName(source);
2245static LogicalResult printOperation(CppEmitter &emitter,
2246 memref::ExpandShapeOp expandShapeOp) {
2247 return printValueForwardOperation<memref::ExpandShapeOp>(emitter,
2252static LogicalResult printOperation(CppEmitter &emitter,
2253 memref::CollapseShapeOp collapseShapeOp) {
2254 return printValueForwardOperation<memref::CollapseShapeOp>(emitter,
2258static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
2260 OpResult result = operation->getResult(0);
2264 if (emitter.shouldDeclareVariablesAtTop()) {
2266 if (
auto oAttr = llvm::dyn_cast<emitc::OpaqueAttr>(value))
2267 if (oAttr.getValue().empty())
2270 if (failed(emitter.emitVariableAssignment(result)))
2272 return emitter.emitAttribute(operation->getLoc(), value);
2276 if (
auto oAttr = llvm::dyn_cast<emitc::OpaqueAttr>(value))
2277 if (oAttr.getValue().empty())
2279 return emitter.emitVariableDeclaration(result,
2283 if (failed(emitter.emitAssignPrefix(*operation)))
2285 return emitter.emitAttribute(operation->getLoc(), value);
2288static LogicalResult printOperation(CppEmitter &emitter,
2289 emitc::ConstantOp constantOp) {
2290 Operation *operation = constantOp.getOperation();
2291 Attribute value = constantOp.getValue();
2292 return printConstantOp(emitter, operation, value);
2295static LogicalResult printOperation(CppEmitter &emitter,
2296 arith::ConstantOp constantOp) {
2297 Operation *operation = constantOp.getOperation();
2298 Attribute value = constantOp.getValue();
2299 return printConstantOp(emitter, operation, value);
2302static LogicalResult printOperation(CppEmitter &emitter,
2303 cf::BranchOp branchOp) {
2304 raw_ostream &
os = emitter.ostream();
2305 Block &successor = *branchOp.getSuccessor();
2307 for (
auto pair : zip(branchOp.getOperands(), successor.getArguments())) {
2308 Value &operand = std::get<0>(pair);
2309 BlockArgument &argument = std::get<1>(pair);
2310 os << emitter.getOrCreateName(argument) <<
" = "
2311 << emitter.getOrCreateName(operand) <<
";\n";
2315 if (!emitter.hasBlockLabel(successor))
2316 return branchOp.emitOpError(
"unable to find label for successor block");
2317 os << emitter.getOrCreateName(successor);
2321static LogicalResult printOperation(CppEmitter &emitter,
2322 cf::CondBranchOp condBranchOp) {
2323 raw_indented_ostream &
os = emitter.ostream();
2324 Block &trueSuccessor = *condBranchOp.getTrueDest();
2325 Block &falseSuccessor = *condBranchOp.getFalseDest();
2327 os <<
"if (" << emitter.getOrCreateName(condBranchOp.getCondition())
2334 zip(condBranchOp.getTrueOperands(), trueSuccessor.getArguments())) {
2335 Value &operand = std::get<0>(pair);
2336 BlockArgument &argument = std::get<1>(pair);
2337 os << emitter.getOrCreateName(argument) <<
" = "
2338 << emitter.getOrCreateName(operand) <<
";\n";
2342 if (!emitter.hasBlockLabel(trueSuccessor))
2343 return condBranchOp.emitOpError(
"unable to find label for successor block");
2344 os << emitter.getOrCreateName(trueSuccessor) <<
";\n";
2345 os.unindent() <<
"} else {\n";
2349 zip(condBranchOp.getFalseOperands(), falseSuccessor.getArguments())) {
2350 Value &operand = std::get<0>(pair);
2351 BlockArgument &argument = std::get<1>(pair);
2352 os << emitter.getOrCreateName(argument) <<
" = "
2353 << emitter.getOrCreateName(operand) <<
";\n";
2357 if (!emitter.hasBlockLabel(falseSuccessor))
2358 return condBranchOp.emitOpError()
2359 <<
"unable to find label for successor block";
2360 os << emitter.getOrCreateName(falseSuccessor) <<
";\n";
2361 os.unindent() <<
"}";
2366static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
2367 if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
2370 raw_ostream &
os = emitter.ostream();
2371 os << callOp.getCallee() <<
"(";
2372 if (failed(emitter.emitOperands(*callOp.getOperation())))
2379static LogicalResult printOperation(CppEmitter &emitter,
2380 emitc::CallOpaqueOp callOp) {
2381 raw_ostream &
os = emitter.ostream();
2382 Operation &op = *callOp.getOperation();
2383 if (callOp.getCallee() ==
"getTanhBf16" ||
2384 callOp.getCallee() ==
"getSqrtBf16" ||
2385 callOp.getCallee() ==
"getRsqrtBf16" ||
2386 callOp.getCallee() ==
"getErfBf16" || callOp.getCallee() ==
"getAbs" ||
2387 callOp.getCallee() ==
"getSigmoidBf16" ||
2388 callOp.getCallee() ==
"getCeilBf16" ||
2389 callOp.getCallee() ==
"getFloorBf16") {
2390 if (failed(emitter.emitAssignPrefix(op,
false)))
2392 }
else if (failed(emitter.emitAssignPrefix(op,
true)))
2395 os << callOp.getCallee();
2397 auto emitArgs = [&](Attribute attr) -> LogicalResult {
2399 if (
auto t = llvm::dyn_cast<IntegerAttr>(attr))
2400 if (t.getType().isIndex()) {
2401 int64_t idx = t.getInt();
2402 if (idx < 0 || idx >= op.getNumOperands())
2403 return op.emitOpError(
"invalid operand index");
2404 if (!emitter.hasValueInScope(op.getOperand(idx)))
2405 return op.emitOpError(
"operand ")
2406 << idx <<
"'s value not defined in scope";
2407 os << emitter.getOrCreateName(op.getOperand(idx));
2410 if (failed(emitter.emitAttribute(op.getLoc(), attr)))
2416 if (callOp.getTemplateArgs()) {
2426 LogicalResult emittedArgs =
2429 : emitter.emitOperands(op);
2430 if (failed(emittedArgs))
2437static LogicalResult printOperation(CppEmitter &emitter,
2438 emitc::ApplyOp applyOp) {
2439 raw_ostream &
os = emitter.ostream();
2441 if (Operation &op = *applyOp.getOperation();
2442 failed(emitter.emitAssignPrefix(op)))
2444 os << applyOp.getApplicableOperator();
2445 os << emitter.getOrCreateName(applyOp.getOperand());
2450static LogicalResult printOperation(CppEmitter &emitter,
2451 emitc::IncludeOp includeOp) {
2452 raw_ostream &
os = emitter.ostream();
2455 if (includeOp.getIsStandardInclude())
2456 os <<
"<" << includeOp.getInclude() <<
">";
2458 os <<
"\"" << includeOp.getInclude() <<
"\"";
2463static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
2464 raw_indented_ostream &
os = emitter.ostream();
2466 OperandRange operands = forOp.getInitArgs();
2467 Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
2468 Operation::result_range results = forOp.getResults();
2470 if (!emitter.shouldDeclareVariablesAtTop())
2471 for (OpResult result : results)
2472 if (failed(emitter.emitVariableDeclaration(result,
2476 for (
auto pair : zip(iterArgs, operands)) {
2477 if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
2479 os <<
" " << emitter.getOrCreateName(std::get<0>(pair)) <<
" = ";
2480 os << emitter.getOrCreateName(std::get<1>(pair)) <<
";";
2486 emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
2490 os << emitter.getOrCreateName(forOp.getInductionVar());
2492 os << emitter.getOrCreateName(forOp.getLowerBound());
2494 os << emitter.getOrCreateName(forOp.getInductionVar());
2496 os << emitter.getOrCreateName(forOp.getUpperBound());
2498 os << emitter.getOrCreateName(forOp.getInductionVar());
2500 os << emitter.getOrCreateName(forOp.getStep());
2502 os <<
"chess_prepare_for_pipelining\n";
2505 if (
auto [constantLoopBound, tripCount] = getTripCount(forOp);
2506 constantLoopBound) {
2507 auto [constantStep, step] = getStep(forOp);
2509 constantStep && step > 0 ? llvm::divideFloorSigned(tripCount, step) : 1;
2511 constantStep && step > 0 ? llvm::divideCeilSigned(tripCount, step) : 0;
2512 os <<
"chess_loop_range(";
2513 os << std::to_string(lb);
2515 if (constantStep && step > 0)
2516 os << std::to_string(ub);
2522 Region &forRegion = forOp.getRegion();
2523 auto regionOps = forRegion.getOps();
2529 for (
auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
2530 if (
bool trailingSemicolon =
2531 !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(*it);
2532 failed(emitter.emitOperation(*it, trailingSemicolon)))
2536 Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
2538 for (
auto pair : zip(iterArgs, yieldOp->getOperands())) {
2539 BlockArgument iterArg = std::get<0>(pair);
2540 Value operand = std::get<1>(pair);
2541 os << emitter.getOrCreateName(iterArg) <<
" = "
2542 << emitter.getOrCreateName(operand) <<
";\n";
2545 os.unindent() <<
"}";
2548 for (
auto pair : zip(results, iterArgs)) {
2549 OpResult result = std::get<0>(pair);
2550 BlockArgument iterArg = std::get<1>(pair);
2552 << emitter.getOrCreateName(result) <<
" = "
2553 << emitter.getOrCreateName(iterArg) <<
";";
2559static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
2560 raw_indented_ostream &
os = emitter.ostream();
2562 if (!emitter.shouldDeclareVariablesAtTop())
2563 for (OpResult result : ifOp.getResults())
2564 if (failed(emitter.emitVariableDeclaration(result,
2569 if (failed(emitter.emitOperands(*ifOp.getOperation())))
2574 Region &thenRegion = ifOp.getThenRegion();
2577 for (Operation &op : thenRegion.getOps())
2578 if (failed(emitter.emitOperation(op, true)))
2581 os.unindent() <<
"}";
2583 if (Region &elseRegion = ifOp.getElseRegion(); !elseRegion.empty()) {
2589 for (Operation &op : elseRegion.getOps())
2590 if (failed(emitter.emitOperation(op, true)))
2593 os.unindent() <<
"}";
2599static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
2600 raw_ostream &
os = emitter.ostream();
2601 Operation &parentOp = *yieldOp.getOperation()->getParentOp();
2603 if (yieldOp.getNumOperands() != parentOp.getNumResults())
2604 return yieldOp.emitError(
"number of operands does not to match the number "
2605 "of the parent op's results");
2608 llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
2609 [&](
auto pair) -> LogicalResult {
2610 auto result = std::get<0>(pair);
2611 auto operand = std::get<1>(pair);
2612 os << emitter.getOrCreateName(result) <<
" = ";
2614 if (!emitter.hasValueInScope(operand))
2615 return yieldOp.emitError(
"operand value not in scope");
2616 os << emitter.getOrCreateName(operand);
2619 [&] { os <<
";\n"; })))
2625static LogicalResult printOperation(CppEmitter &emitter,
2626 func::ReturnOp returnOp) {
2627 raw_ostream &
os = emitter.ostream();
2629 switch (returnOp.getNumOperands()) {
2633 os <<
" " << emitter.getOrCreateName(returnOp.getOperand(0));
2634 return success(emitter.hasValueInScope(returnOp.getOperand(0)));
2636 os <<
" std::make_tuple(";
2637 if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
2645static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
2646 CppEmitter::Scope scope(emitter);
2648 for (Operation &op : moduleOp)
2649 if (failed(emitter.emitOperation(op, false)))
2655static LogicalResult printOperation(CppEmitter &emitter,
2656 func::FuncOp functionOp) {
2658 if (!emitter.shouldDeclareVariablesAtTop() &&
2659 functionOp.getBlocks().size() > 1)
2660 return functionOp.emitOpError(
2661 "with multiple blocks needs variables declared at top");
2663 CppEmitter::Scope scope(emitter);
2667 if (failed(parseMemRefDynamicDims(emitter, functionOp)))
2670 raw_indented_ostream &
os = emitter.ostream();
2671 if (failed(emitter.emitTypes(functionOp.getLoc(),
2672 functionOp.getFunctionType().getResults())))
2674 os <<
" " << functionOp.getName();
2677 if (functionOp.isDeclaration()) {
2679 functionOp.getArgumentTypes(), os, [&](Type type) -> LogicalResult {
2680 if (failed(emitter.emitType(functionOp.getLoc(), type)))
2684 if (auto argType = dyn_cast<MemRefType>(type))
2685 for (unsigned dim = 0; dim < argType.getRank(); ++dim)
2686 if (argType.isDynamicDim(dim))
2696 functionOp.getArguments(), os,
2697 [&](BlockArgument arg) -> LogicalResult {
2698 if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
2700 os <<
" " << emitter.getOrCreateName(arg);
2703 if (failed(printMemRefDims(emitter, arg)))
2711 if (emitter.shouldDeclareVariablesAtTop()) {
2715 functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
2716 for (OpResult result : op->getResults()) {
2717 if (failed(emitter.emitVariableDeclaration(
2720 op->emitError(
"unable to declare result variable for op")};
2722 return WalkResult::advance();
2724 if (result.wasInterrupted())
2728 Region::BlockListType &blocks = functionOp.getBlocks();
2730 for (Block &block : blocks)
2731 emitter.getOrCreateName(block);
2734 for (
auto it = std::next(blocks.begin()); it != blocks.end(); ++it) {
2736 for (BlockArgument &arg : block.getArguments()) {
2737 if (emitter.hasValueInScope(arg))
2738 return functionOp.emitOpError(
" block argument #")
2739 << arg.getArgNumber() <<
" is out of scope";
2741 emitter.emitType(block.getParentOp()->getLoc(), arg.getType())))
2743 os <<
" " << emitter.getOrCreateName(arg) <<
";\n";
2747 for (Block &block : blocks) {
2749 if (blocks.size() > 1)
2750 if (failed(emitter.emitLabel(block)))
2752 for (Operation &op : block.getOperations()) {
2757 if (
bool trailingSemicolon =
2758 !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op);
2759 failed(emitter.emitOperation(
2760 op, trailingSemicolon)))
2764 os.unindent() <<
"}\n";
2769static LogicalResult printOperation(CppEmitter &emitter,
2770 aievec::MatMulOp matmulOp) {
2771 auto lhs = matmulOp.getLhs();
2772 auto rhs = matmulOp.getRhs();
2773 auto acc = matmulOp.getAcc();
2776 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs) ||
2777 !emitter.hasValueInScope(acc))
2780 auto lhsName = printConversionTo512bit(emitter, lhs);
2781 auto rhsName = printConversionTo512bit(emitter, rhs);
2783 raw_indented_ostream &
os = emitter.ostream();
2785 StringRef accName = emitter.getOrCreateName(acc);
2787 auto lhsShape = cast<VectorType>(lhs.getType()).getShape();
2788 auto rhsShape = cast<VectorType>(rhs.getType()).getShape();
2789 os << accName <<
" = mac_" << lhsShape[0] <<
"x" << lhsShape[1] <<
"_"
2790 << rhsShape[0] <<
"x" << rhsShape[1] <<
"(";
2791 os << lhsName <<
", " << rhsName <<
", " << accName <<
")";
2794 emitter.setName(matmulOp.getResult(), accName);
2799CppEmitter::CppEmitter(raw_ostream &os,
bool declareVariablesAtTop,
bool aie2)
2800 :
os(
os), declareVariablesAtTop(declareVariablesAtTop), aie2_(aie2) {
2801 valueInScopeCount.push(0);
2802 labelInScopeCount.push(0);
2806StringRef CppEmitter::getOrCreateName(Value val, std::string prefix) {
2807 if (!valueMapper.count(val))
2808 valueMapper.insert(val,
2809 formatv(
"{0}{1}", prefix, ++valueInScopeCount.top()));
2810 return *valueMapper.begin(val);
2814void CppEmitter::setName(Value val, StringRef name) {
2815 valueMapper.insert(val, name.str());
2819std::string CppEmitter::getNewName(std::string prefix) {
2820 std::string ret = formatv(
"{0}{1}", prefix, ++valueInScopeCount.top());
2826void CppEmitter::setMemRefDimParam(Value memref,
unsigned index,
2827 const std::string ¶meter) {
2828 auto p = std::make_pair(memref, index);
2829 assert(!paramIndexMapper.count(p) &&
"memref dimension already set");
2830 paramIndexMapper[p] = parameter;
2834StringRef CppEmitter::getMemRefDimParam(Value memref,
unsigned index) {
2835 auto p = std::make_pair(memref, index);
2836 assert(paramIndexMapper.count(p) &&
"memref dimension not found");
2837 return paramIndexMapper[p];
2842bool CppEmitter::isMemRefDimParam(Value memref,
unsigned index) {
2844 auto type = llvm::dyn_cast<MemRefType>(memref.getType());
2845 if (!(type && type.isDynamicDim(index))) {
2846 printf(
"the dimension size at index is not dynamic\n");
2852 auto p = std::make_pair(memref, index);
2853 return paramIndexMapper.count(p);
2857StringRef CppEmitter::getOrCreateName(Block &block, std::string prefix) {
2858 if (!blockMapper.count(&block))
2859 blockMapper.insert(&block,
2860 formatv(
"{0}{1}", prefix, ++labelInScopeCount.top()));
2861 return *blockMapper.begin(&block);
2864bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
2866 case IntegerType::Signless:
2867 case IntegerType::Signed:
2869 case IntegerType::Unsigned:
2872 llvm::report_fatal_error(
"Unexpected IntegerType::SignednessSemantics");
2875bool CppEmitter::hasValueInScope(Value val) {
return valueMapper.count(val); }
2877bool CppEmitter::hasBlockLabel(Block &block) {
2878 return blockMapper.count(&block);
2883template <
typename ElTy>
2884static std::string getSplatValueOfIntDense(DenseIntElementsAttr dense) {
2885 ElTy splatVal = dense.getSplatValue<ElTy>();
2886 return std::to_string(splatVal);
2890static std::string getSplatValueOfFloatDense(DenseFPElementsAttr dense,
2891 bool isBFloat =
false) {
2892 auto apFloat = dense.getSplatValue<APFloat>();
2893 float splatVal = apFloat.convertToFloat();
2894 std::string firstValue = std::to_string(splatVal);
2896 if (apFloat.isPosInfinity())
2901 firstValue = std::to_string(0x1.FEp+127f);
2903 firstValue = std::to_string(std::numeric_limits<float>::max());
2904 else if (apFloat.isNegInfinity())
2906 firstValue = std::to_string(-0x1.FEp+127f);
2908 firstValue = std::to_string(std::numeric_limits<float>::lowest());
2909 else if (!apFloat.isNonZero())
2915LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
2916 auto printInt = [&](
const APInt &val,
bool isUnsigned) {
2917 if (val.getBitWidth() == 1)
2918 if (val.getBoolValue())
2923 SmallString<128> strValue;
2924 val.toString(strValue, 10, !isUnsigned,
false);
2929 auto printFloat = [&](
const APFloat &val) {
2930 if (val.isFinite()) {
2931 SmallString<128> strValue;
2933 val.toString(strValue, 0, 0,
false);
2934 switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
2935 case llvm::APFloatBase::S_IEEEsingle:
2938 case llvm::APFloatBase::S_IEEEdouble:
2945 }
else if (val.isNaN())
2947 else if (val.isInfinity()) {
2948 if (val.isNegative())
2955 if (
auto fAttr = llvm::dyn_cast<FloatAttr>(attr)) {
2956 printFloat(fAttr.getValue());
2960 if (
auto dense = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2961 if (aie2() && dense.isSplat()) {
2962 if (
auto vType = llvm::dyn_cast<VectorType>(dense.getType()))
2963 if (
auto fType = llvm::dyn_cast<FloatType>(vType.getElementType())) {
2964 unsigned width = fType.getWidth();
2965 std::string splatValue;
2967 splatValue = getSplatValueOfFloatDense(dense);
2968 else if (width == 16)
2969 splatValue = getSplatValueOfFloatDense(dense,
true);
2972 if (splatValue ==
"0") {
2973 os <<
"broadcast_zero_";
2974 if (failed(emitType(loc, fType)))
2978 os <<
"broadcast_to_";
2979 if (failed(emitType(loc, vType)))
2982 if (failed(emitType(loc, fType)))
2989 os <<
"extract_v16bfloat16(";
2990 if (splatValue ==
"0")
2991 os <<
"broadcast_zero_bfloat16()";
2993 os <<
"broadcast_to_v32bfloat16";
2995 if (failed(emitType(loc, fType)))
3007 interleaveComma(dense, os, [&](
const APFloat &val) { printFloat(val); });
3014 if (
auto iAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
3015 if (
auto iType = llvm::dyn_cast<IntegerType>(iAttr.getType())) {
3016 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
3019 if (llvm::dyn_cast<IndexType>(iAttr.getType())) {
3020 printInt(iAttr.getValue(),
false);
3025 if (
auto dense = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
3026 if (
auto tType = llvm::dyn_cast<TensorType>(dense.getType())) {
3027 if (
auto iType = llvm::dyn_cast<IntegerType>(tType.getElementType())) {
3029 interleaveComma(dense, os, [&](
const APInt &val) {
3030 printInt(val, shouldMapToUnsigned(iType.getSignedness()));
3035 if (llvm::dyn_cast<IndexType>(tType.getElementType())) {
3037 interleaveComma(dense, os,
3038 [&](
const APInt &val) { printInt(val,
false); });
3044 if (
auto vType = llvm::dyn_cast<VectorType>(dense.getType())) {
3045 if (
auto iType = llvm::dyn_cast<IntegerType>(vType.getElementType())) {
3046 unsigned width = iType.getWidth();
3047 if (llvm::all_of(dense, [](
const APInt &val) {
return val == 0; })) {
3050 os <<
"concat(broadcast_zero_s" << width <<
"(), broadcast_zero_s"
3054 os <<
"broadcast_zero_s";
3058 if (failed(emitType(loc, vType)))
3065 if (aie2() && dense.isSplat()) {
3066 std::string splatValue;
3068 splatValue = getSplatValueOfIntDense<int32_t>(dense);
3069 else if (width == 16)
3070 splatValue = getSplatValueOfIntDense<int16_t>(dense);
3071 else if (width == 8)
3072 splatValue = getSplatValueOfIntDense<int8_t>(dense);
3073 os <<
"broadcast_to_";
3074 if (failed(emitType(loc, vType)))
3077 if (failed(emitType(loc, iType)))
3085 interleaveComma(dense, os, [&](
const APInt &val) {
3086 printInt(val, shouldMapToUnsigned(iType.getSignedness()));
3092 if (llvm::dyn_cast<IndexType>(vType.getElementType())) {
3094 interleaveComma(dense, os,
3095 [&](
const APInt &val) { printInt(val,
false); });
3103 if (
auto oAttr = llvm::dyn_cast<emitc::OpaqueAttr>(attr)) {
3104 os << oAttr.getValue();
3109 if (
auto sAttr = llvm::dyn_cast<SymbolRefAttr>(attr)) {
3110 if (sAttr.getNestedReferences().size() > 1)
3111 return emitError(loc,
"attribute has more than 1 nested reference");
3112 os << sAttr.getRootReference().getValue();
3117 if (
auto type = llvm::dyn_cast<TypeAttr>(attr))
3118 return emitType(loc, type.getValue());
3120 return emitError(loc,
"cannot emit attribute of type ") << attr;
// Emits the operands of `op` as their C++ variable names.
// The per-operand callback below prints the name assigned to the value and
// reports an op error ("operand value not in scope") if the value has no
// name in the current scope.
// NOTE(review): how the callback is applied (presumably comma-interleaved
// over op.getOperands()) is not visible in this listing — confirm.
3123LogicalResult CppEmitter::emitOperands(Operation &op) {
3124 auto emitOperandName = [&](Value result) -> LogicalResult {
// Refuse to print a value that was never given a name in this scope.
3125 if (!hasValueInScope(result))
3126 return op.emitOpError() <<
"operand value not in scope";
3127 os << getOrCreateName(result);
// Emits `op`'s operands, then its attributes, skipping any attribute whose
// name is listed in `exclude`. Each printed attribute is preceded by its name
// in a C-style comment, i.e. `/* name */` followed by the attribute value.
3134CppEmitter::emitOperandsAndAttributes(Operation &op,
3135 ArrayRef<StringRef> exclude) {
3136 if (failed(emitOperands(op)))
// When operands were printed, look for an attribute that survives the
// `exclude` filter; presumably a separator between the operand list and the
// attribute list is emitted in that case (branch body not visible here).
3139 if (op.getNumOperands() > 0)
3140 for (NamedAttribute attr : op.getAttrs())
3141 if (!is_contained(exclude, attr.getName().strref())) {
// Prints a single attribute; excluded names are filtered out here as well.
// NOTE(review): if this callback is driven by a comma-interleaving helper,
// an excluded attribute still gets a separator emitted for it, which can
// leave stray ", " in the output — confirm against the (not visible) call.
3146 auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
3147 if (is_contained(exclude, attr.getName().strref()))
3149 os <<
"/* " << attr.getName().getValue() <<
" */";
3150 if (failed(emitAttribute(op.getLoc(), attr.getValue())))
// Emits `<name> = ` for `result`. The result must already have been declared,
// i.e. have a name in the current value scope; otherwise an op error is
// reported on its defining operation.
3158LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
3159 if (!hasValueInScope(result)) {
3160 return result.getDefiningOp()->emitOpError(
3161 "result variable for the operation has not been declared");
3163 os << getOrCreateName(result) <<
" = ";
// Emits a declaration `<type> <name>` for `result`.
// Fails if `result` already has a name in the current value scope. The type
// is printed with emitType using stdintType = true and forwarding `isAcc`
// (on AIE2, genCppTypeName uses `isAcc` to pick `acc<W>` vector type names).
// When `trailingSemicolon` is set, the declaration is presumably terminated
// with `;` (that branch body is not visible here).
3168LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
3169 bool trailingSemicolon,
3171 if (hasValueInScope(result))
3172 return result.getDefiningOp()->emitError(
3173 "result variable for the operation already declared");
3175 emitType(result.getOwner()->getLoc(), result.getType(),
true, isAcc)))
3177 os <<
" " << getOrCreateName(result);
3178 if (trailingSemicolon)
// Emits the left-hand side of an assignment for the results of `op`, chosen
// by the number of results. `isAcc` controls whether a declared result is
// typed as an accumulator.
3184LogicalResult CppEmitter::emitAssignPrefix(Operation &op,
bool isAcc) {
3185 switch (op.getNumResults()) {
// Single result (presumably `case 1`): assign to the pre-declared variable
// when variables are declared at the top, otherwise declare it inline
// without a trailing semicolon (remaining arguments not visible here).
3189 OpResult result = op.getResult(0);
3190 if (shouldDeclareVariablesAtTop()) {
3191 if (failed(emitVariableAssignment(result)))
3194 if (failed(emitVariableDeclaration(result,
false,
// Multiple results (presumably the default case): unless variables are
// declared at the top, declare each result as its own statement (with the
// default `isAcc`), then emit the result names comma-separated — presumably
// wrapped as `std::tie(...) = `; confirm.
3202 if (!shouldDeclareVariablesAtTop())
3203 for (OpResult result : op.getResults())
3204 if (failed(emitVariableDeclaration(result, true)))
3208 interleaveComma(op.getResults(), os,
3209 [&](Value result) { os << getOrCreateName(result); });
// Emits `<label>:` followed by a newline for `block`. Reports an error on the
// parent op if the block was never assigned a label.
3215LogicalResult CppEmitter::emitLabel(Block &block) {
3216 if (!hasBlockLabel(block))
3217 return block.getParentOp()->emitError(
"label for block not found");
// Written through getOStream(), i.e. the stream underneath the indented
// wrapper — presumably so labels are not indented like statements.
3220 os.getOStream() << getOrCreateName(block) <<
":\n";
// Emits C++ for a single operation by dispatching on its concrete op type to
// the matching printOperation overload. Ops for which skippedOp() holds are
// handled before dispatch (presumably returning without output). An op with
// no matching case reports "unable to find printer for op". The statement is
// then terminated with ";\n" or "\n" according to `trailingSemicolon`
// (presumably only after a successful print — the status check is not
// visible in this listing).
3224LogicalResult CppEmitter::emitOperation(Operation &op,
bool trailingSemicolon) {
// Some ops are filtered out before any printing happens.
3227 if (skippedOp(&op, *
this))
// Dispatch on the op's concrete type.
3230 LogicalResult status =
3231 TypeSwitch<Operation *, LogicalResult>(&op)
// EmitC ops.
3233 .Case<emitc::ApplyOp, emitc::CallOpaqueOp, emitc::ConstantOp>(
3234 [&](
auto op) {
return printOperation(*
this, op); })
// emitc.include: printed only the first time a given include name is seen;
// includeNames records names already emitted (the path taken for repeats
// is not visible here).
3235 .Case<emitc::IncludeOp>([&](
auto op) {
3236 if (StringRef name = op.getInclude(); !includeNames.count(name)) {
3237 includeNames.insert(name);
3238 return printOperation(*
this, op);
// SCF ops.
3243 .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
3244 [&](
auto op) {
return printOperation(*
this, op); })
// ControlFlow, Func and builtin module ops.
3246 .Case<cf::BranchOp, func::CallOp, cf::CondBranchOp, func::FuncOp,
3247 ModuleOp, func::ReturnOp>(
3248 [&](
auto op) {
return printOperation(*
this, op); })
// Arith ops.
3250 .Case<arith::ConstantOp>(
3251 [&](
auto op) {
return printOperation(*
this, op); })
// Binary arith ops (and cmpi): each case names its op type explicitly as
// the printOperation template argument.
3254 .Case<arith::AddIOp>(
3255 [&](
auto op) {
return printOperation<arith::AddIOp>(*
this, op); })
3256 .Case<arith::AddFOp>(
3257 [&](
auto op) {
return printOperation<arith::AddFOp>(*
this, op); })
3258 .Case<arith::MulIOp>(
3259 [&](
auto op) {
return printOperation<arith::MulIOp>(*
this, op); })
3260 .Case<arith::MulFOp>(
3261 [&](
auto op) {
return printOperation<arith::MulFOp>(*
this, op); })
3262 .Case<arith::SubIOp>(
3263 [&](
auto op) {
return printOperation<arith::SubIOp>(*
this, op); })
3264 .Case<arith::SubFOp>(
3265 [&](
auto op) {
return printOperation<arith::SubFOp>(*
this, op); })
3266 .Case<arith::DivSIOp>([&](
auto op) {
3267 return printOperation<arith::DivSIOp>(*
this, op);
3269 .Case<arith::DivUIOp>([&](
auto op) {
3270 return printOperation<arith::DivUIOp>(*
this, op);
3272 .Case<arith::DivFOp>(
3273 [&](
auto op) {
return printOperation<arith::DivFOp>(*
this, op); })
3274 .Case<arith::RemSIOp>([&](
auto op) {
3275 return printOperation<arith::RemSIOp>(*
this, op);
3277 .Case<arith::CmpIOp>(
3278 [&](
auto op) {
return printOperation<arith::CmpIOp>(*
this, op); })
3279 .Case<arith::SelectOp>(
3280 [&](
auto op) {
return printOperation(*
this, op); })
// Vector ops.
3282 .Case<vector::TransferWriteOp>(
3283 [&](
auto op) {
return printOperation(*
this, op); })
// MemRef ops.
3285 .Case<memref::StoreOp, memref::ExpandShapeOp,
3286 memref::CollapseShapeOp>(
3287 [&](
auto op) {
return printOperation(*
this, op); })
// AIE1-specific aievec ops.
3289 .Case<aievec::aie1::AddOp, aievec::aie1::SubOp, aievec::aie1::FMAOp,
3290 aievec::aie1::MulOp, aievec::aie1::SelectOp,
3291 aievec::aie1::ExtOp>(
3292 [&](
auto op) {
return printOperation(*
this, op); })
// Remaining aievec ops, named unqualified — presumably brought into scope
// by a using-directive for the aievec namespace; verify.
3294 .Case<AddElemOp, ConcatOp, ExtOp, PackOp, SRSOp, SubElemOp, UPDOp,
3295 UPSOp, FMAElemOp, MulElemOp, BroadcastOp, BroadcastScalarOp,
3296 MulConvOp, FMAConvOp, ShiftOp, ShuffleOp, CastOp, MinOp, MaxOp,
3297 NegOp, CmpOp, SelOp, ExtElemOp, BxorOp, BnegOp, BandOp, BorOp,
3298 UnpackOp, MatMulOp, LegacyShuffleOp>(
3299 [&](
auto op) {
return printOperation(*
this, op); })
// No printer registered for this op type.
3300 .Default([&](Operation *) {
3301 return op.emitOpError(
"unable to find printer for op");
// Terminate the emitted statement.
3306 os << (trailingSemicolon ?
";\n" :
"\n");
// Returns the C++ spelling of `type`, or an empty optional when the type
// cannot be spelled (those early returns are not visible in this listing).
//  - stdintType: for integers, append the `_t` suffix (`int32_t` vs `int32`).
//  - isAcc: on AIE2, request accumulator vector names (`v<N>acc<W>`).
3311std::optional<std::string>
3312CppEmitter::genCppTypeName(Type type,
bool stdintType,
bool isAcc) {
3313 std::stringstream ss;
// Integers: `uint<W>` when shouldMapToUnsigned() holds for the signedness,
// else `int<W>`, each with `_t` when stdintType is set. An `acc<W>` form is
// also produced; its guarding condition and the width switch cases are not
// visible here.
3314 if (
auto iType = dyn_cast<IntegerType>(type)) {
3315 switch (iType.getWidth()) {
3322 if (shouldMapToUnsigned(iType.getSignedness()))
3323 ss <<
"uint" << iType.getWidth() << (stdintType ?
"_t" :
"");
3325 ss <<
"int" << iType.getWidth() << (stdintType ?
"_t" :
"");
3329 ss <<
"acc" << iType.getWidth();
// Floats: spelled according to bit width (cases not visible here).
3335 if (
auto fType = dyn_cast<FloatType>(type)) {
3336 switch (fType.getWidth()) {
// index: handling not visible here.
3347 if (
auto iType = dyn_cast<IndexType>(type))
// Tensors: must be ranked with a static shape. Spelled as the element type
// name followed by per-dimension text (loop body not visible here).
3350 if (
auto tType = dyn_cast<TensorType>(type)) {
3351 if (!tType.hasRank())
3353 if (!tType.hasStaticShape())
3356 auto nestedTypeName = genCppTypeName(tType.getElementType());
3357 if (!nestedTypeName)
3359 ss << *nestedTypeName;
3360 auto shape = tType.getShape();
3361 for (
auto dimSize : shape) {
// Tuples: `std::tuple<T0, T1, ...>` with element names separated by ", ".
// If any element cannot be spelled, itrleaveFailed is set and (presumably)
// no name is returned for the tuple.
3368 if (
auto tType = dyn_cast<TupleType>(type)) {
3369 ss <<
"std::tuple<";
3370 bool itrleaveFailed =
false;
3374 auto optTyNameStr = genCppTypeName(type);
3376 ss << *optTyNameStr;
3378 itrleaveFailed = true;
3380 [&]() { ss <<
", "; });
3382 if (!itrleaveFailed)
// emitc.opaque types: the opaque string verbatim.
3386 if (
auto oType = dyn_cast<emitc::OpaqueType>(type)) {
3387 ss << oType.getValue().str();
// MemRefs: a restrict-qualified pointer to the element type,
// `<elem> * restrict`.
3392 if (
auto tType = dyn_cast<MemRefType>(type)) {
3393 auto elemTyStrOpt = genCppTypeName(tType.getElementType());
3396 ss << *elemTyStrOpt <<
" * restrict";
// Vectors: `v<N>`, where N is the product of all shape dimensions, followed
// by either `acc<W>` or the element type name spelled without `_t`.
// `acc<W>` is used on AIE2 when isAcc is set or the elements are 64-bit
// integers, for (N, W) in {(16, 64), (32, 32), (16, 32)}. Float elements
// take a separate path (not visible here).
3400 if (
auto tType = dyn_cast<VectorType>(type)) {
3401 Type eltType = tType.getElementType();
3403 auto vShape = tType.getShape();
// NOTE(review): std::accumulate's running value has the type of its init
// argument, here `int` (literal 1), so each int64_t product is narrowed back
// to int; an init of `int64_t{1}` would keep the full width.
3404 int64_t numElems = std::accumulate(vShape.begin(), vShape.end(), 1,
3405 std::multiplies<int64_t>());
3406 ss <<
"v" << std::to_string(numElems);
3408 int64_t iElTyBitWidth = 0;
3409 auto iElTy = dyn_cast<IntegerType>(eltType);
3411 iElTyBitWidth = iElTy.getWidth();
3412 if (aie2() && (isAcc || iElTyBitWidth == 64)) {
3416 if ((numElems == 16 && iElTyBitWidth == 64) ||
3417 (numElems == 32 && iElTyBitWidth == 32) ||
3418 (numElems == 16 && iElTyBitWidth == 32)) {
3419 ss <<
"acc" << iElTyBitWidth;
3424 if (isa<FloatType>(eltType)) {
// Element name without the `_t` suffix (stdintType = false).
3430 auto elTyNameOpt = genCppTypeName(eltType,
false);
// Emits the C++ spelling of `type` as computed by genCppTypeName (forwarding
// stdintType and isAcc). Reports "cannot emit type <type>" at `loc` when
// no spelling is available.
3439LogicalResult CppEmitter::emitType(Location loc, Type type,
bool stdintType,
3441 auto typeName = genCppTypeName(type, stdintType, isAcc);
3443 return emitError(loc,
"cannot emit type ") << type;
// Emits the C++ type for a list of types, switching on its size. A single
// type (presumably `case 1`) is emitted directly. Other sizes (presumably
// the default case) become a std::tuple via emitTupleType. The empty-list
// case is not visible here.
3448LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
3449 switch (types.size()) {
3454 return emitType(loc, types.front());
3456 return emitTupleType(loc, types);
// Emits `std::tuple<...>` whose element types are each printed with emitType.
// The elements are presumably comma-separated via interleaveCommaWithError;
// the head of that call and the closing `>` are not visible in this listing.
3460LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
3461 os <<
"std::tuple<";
3463 types, os, [&](Type type) {
return emitType(loc, type); })))
// Entry point body: creates an emitter on `os` with
// declareVariablesAtTop = false and the caller's `aie2` flag, then emits the
// top-level op without a trailing semicolon.
3471 CppEmitter emitter(os,
false, aie2);
3472 return emitter.emitOperation(*op,
false);
LogicalResult interleaveCommaWithError(const Container &c, raw_ostream &os, UnaryFunctor eachFn)
LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, UnaryFunctor eachFn, NullaryFunctor betweenFn)
Convenience functions to produce interleaved output with functions returning a LogicalResult.
mlir::LogicalResult translateAIEVecToCpp(mlir::Operation *op, bool aie2, mlir::raw_ostream &os)
Translates the AIE vector dialect MLIR to C++ code.
int32_t getVectorSizeInBits(mlir::VectorType type)
unsigned getVectorLaneSize(mlir::VectorType type)
int32_t getElementSizeInBits(mlir::VectorType type)