21#include "mlir/Dialect/Arith/IR/Arith.h"
22#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
23#include "mlir/Dialect/EmitC/IR/EmitC.h"
24#include "mlir/Dialect/Func/IR/FuncOps.h"
25#include "mlir/Dialect/Index/IR/IndexOps.h"
26#include "mlir/Dialect/MemRef/IR/MemRef.h"
27#include "mlir/Dialect/SCF/IR/SCF.h"
28#include "mlir/Dialect/Vector/IR/VectorOps.h"
29#include "mlir/IR/BuiltinOps.h"
30#include "mlir/IR/BuiltinTypes.h"
31#include "mlir/IR/Operation.h"
32#include "mlir/Support/IndentedOstream.h"
34#include "llvm/ADT/ScopedHashTable.h"
35#include "llvm/ADT/SmallSet.h"
36#include "llvm/ADT/StringRef.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/CommandLine.h"
39#include "llvm/Support/Debug.h"
40#include "llvm/Support/FormatVariadic.h"
41#include "llvm/Support/MathExtras.h"
49#define DEBUG_TYPE "aievec-to-cpp"
// Three function templates; their name/signature lines and most of their
// bodies are not visible in this excerpt.
//  - The first is parameterised on an iterator type plus a per-element functor
//    (`eachFn`) and a separator functor (`betweenFn`). It applies `eachFn` to
//    `*begin`, then loops over the remaining range, checking `failed(...)` on
//    every call.
//  - The second and third are parameterised on a container instead; the third
//    takes no separator functor.
// NOTE(review): what happens on failure and where `betweenFn` is invoked is on
// elided lines — confirm against the full file.
59template <
typename ForwardIterator,
typename UnaryFunctor,
60 typename NullaryFunctor>
63 NullaryFunctor betweenFn) {
66 if (failed(eachFn(*begin)))
69 for (; begin != end; ++begin) {
71 if (failed(eachFn(*begin)))
77template <
typename Container,
typename UnaryFunctor,
typename NullaryFunctor>
79 NullaryFunctor betweenFn) {
83template <
typename Container,
typename UnaryFunctor>
85 UnaryFunctor eachFn) {
// Member declarations of the emitter (the enclosing class header is on an
// elided line). Only signatures are visible here, so the notes restate them.

// Constructor: takes the output stream and two boolean flags.
92 explicit CppEmitter(raw_ostream &os,
bool declareVariablesAtTop,
bool aie2);
// Emitters returning LogicalResult so callers can test them with failed().
95 LogicalResult emitAttribute(Location loc, Attribute attr);
98 LogicalResult emitOperation(Operation &op,
bool trailingSemicolon);
// Returns an optional type-name string; `stdintType` defaults to true. A
// further parameter is declared on an elided line.
104 std::optional<std::string> genCppTypeName(Type type,
bool stdintType =
true,
// Same leading parameters as genCppTypeName, plus a Location.
109 LogicalResult emitType(Location loc, Type type,
bool stdintType =
true,
116 LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
120 LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
123 LogicalResult emitVariableAssignment(OpResult result);
126 LogicalResult emitVariableDeclaration(OpResult result,
bool trailingSemicolon,
// `isAcc` defaults to false; several callers in this file pass true.
135 LogicalResult emitAssignPrefix(Operation &op,
bool isAcc =
false);
138 LogicalResult emitLabel(Block &block);
142 LogicalResult emitOperandsAndAttributes(Operation &op,
143 ArrayRef<StringRef> exclude = {});
146 LogicalResult emitOperands(Operation &op);
// Name management for SSA values; the default name prefix is "v".
149 StringRef getOrCreateName(Value val, std::string prefix =
"v");
152 void setName(Value val, StringRef name);
155 std::string getNewName(std::string prefix =
"v");
// Associates the (memref, dimension index) pair with a parameter name.
// The last parameter is a const reference; the `&parameter` token had been
// mis-encoded as a pilcrow sign and is restored here.
158 void setMemRefDimParam(Value memref,
unsigned index,
159 const std::string &parameter);
// Lookup / existence query keyed by a memref value and a dimension index.
162 StringRef getMemRefDimParam(Value memref,
unsigned index);
165 bool isMemRefDimParam(Value memref,
unsigned index);
// Block overload of getOrCreateName; the default prefix is "label".
168 StringRef getOrCreateName(Block &block, std::string prefix =
"label");
// Predicate over an integer signedness kind.
171 bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
// Scope helper (its struct header and destructor signature are on elided
// lines). The constructor opens a new scope on both of the emitter's scoped
// hash tables and pushes a copy of the current top of each counter stack.
175 Scope(CppEmitter &emitter)
176 : valueMapperScope(emitter.valueMapper),
177 blockMapperScope(emitter.blockMapper), emitter(emitter) {
178 emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
179 emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
// Counterpart of the two pushes above.
// NOTE(review): presumably the destructor body — confirm.
182 emitter.valueInScopeCount.pop();
183 emitter.labelInScopeCount.pop();
// Scope objects for the value-name and block-label tables.
187 llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
188 llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
// Queries used by the printers in this file before they reference an operand.
193 bool hasValueInScope(Value val);
196 bool hasBlockLabel(Block &block);
// Returns the indented output stream member.
199 raw_indented_ostream &ostream() {
return os; }
// Returns the declareVariablesAtTop flag.
203 bool shouldDeclareVariablesAtTop() {
return declareVariablesAtTop; }
// Returns the aie2_ flag (the member itself is declared on an elided line).
205 bool aie2() {
return aie2_; }
// Scoped tables mapping values / blocks to name strings.
208 using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
209 using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
212 raw_indented_ostream
os;
217 bool declareVariablesAtTop;
220 ValueMapper valueMapper;
223 BlockMapper blockMapper;
// Keyed by (value, dimension index); stores a parameter name string.
226 DenseMap<std::pair<Value, unsigned>, std::string> paramIndexMapper;
// Counter stacks that Scope pushes to and pops from.
230 std::stack<int64_t> valueInScopeCount;
231 std::stack<int64_t> labelInScopeCount;
233 llvm::SmallSet<StringRef, 16> includeNames;
// Decides via a TypeSwitch whether `op` is skipped. Several cases alias the
// op's result to the name of its source value through emitter.setName, so
// later uses of the result print the source name.
//   op                  - operation to classify
//   emitter             - provides the aie2() flag and the value-name table
//   checkStrongLiveness - defaults to true; it is and-ed with "op is an
//                         arith.constant" and the recursive call over users
//                         passes false
// Several interior lines (returns, closing braces) are not visible here.
249static bool skippedOp(Operation *op, CppEmitter &emitter,
250 bool checkStrongLiveness =
true) {
253 TypeSwitch<Operation *, bool>(op)
// memref.dim and memref.assume_alignment: the callback returns true.
255 .Case<memref::DimOp, memref::AssumeAlignmentOp>(
256 [](
auto op) {
return true; })
// SRS: non-AIE2 target, float accumulator element type, and a source whose
// defining op has exactly one use -> the result reuses the source name.
258 .Case<aievec::SRSOp>([&](
auto srsOp) {
260 auto accType = cast<VectorType>(srsOp.getSource().getType());
261 Type eltType = accType.getElementType();
264 Value source = srsOp.getSource();
265 if (!emitter.aie2() && llvm::isa<FloatType>(eltType) &&
266 source.getDefiningOp()->hasOneUse()) {
267 StringRef srcName = emitter.getOrCreateName(source);
268 emitter.setName(srsOp->getResult(0), srcName);
// UPS: same condition, with the element type taken from the result.
274 .Case<aievec::UPSOp>([&](
auto upsOp) {
276 auto accType = cast<VectorType>(upsOp.getResult().getType());
277 Type eltType = accType.getElementType();
280 Value source = upsOp.getSource();
281 if (!emitter.aie2() && llvm::isa<FloatType>(eltType) &&
282 source.getDefiningOp()->hasOneUse()) {
283 StringRef srcName = emitter.getOrCreateName(source);
284 emitter.setName(upsOp->getResult(0), srcName);
// aievec.cast between identical element types that are 64-bit integers.
292 .Case<aievec::CastOp>([&](
auto castOp) {
293 Value source = castOp.getSource();
294 auto srcVTy = cast<VectorType>(source.getType());
295 auto resVTy = cast<VectorType>(castOp.getResult().getType());
296 if (srcVTy.getElementType() == resVTy.getElementType()) {
297 auto iElTy = dyn_cast<IntegerType>(srcVTy.getElementType());
298 if (iElTy && iElTy.getWidth() == 64) {
299 StringRef srcName = emitter.getOrCreateName(source);
300 emitter.setName(castOp->getResult(0), srcName);
// Index casts: the result takes the name of operand 0.
307 .Case<arith::IndexCastOp, arith::IndexCastUIOp, index::CastSOp,
308 index::CastUOp>([&](
auto idxCastOp) {
309 Value source = idxCastOp->getOperand(0);
310 StringRef srcName = emitter.getOrCreateName(source);
311 emitter.setName(idxCastOp->getResult(0), srcName);
// vector.shape_cast: the result takes the source name.
315 .Case<vector::ShapeCastOp>([&](
auto castOp) {
316 Value source = castOp.getSource();
317 StringRef srcName = emitter.getOrCreateName(source);
318 emitter.setName(castOp.getResult(), srcName);
// Unrealized conversion cast: inputs[0] and outputs[0] are indexed below, so
// both ranges must hold exactly one element. The previous condition tested
// inputs.size() twice and never looked at the outputs.
323 .Case<UnrealizedConversionCastOp>([&](
auto uccOp) {
324 auto inputs = uccOp.getInputs();
325 auto outputs = uccOp.getOutputs();
326 if (inputs.size() != 1 || outputs.size() != 1)
328 StringRef inputName = emitter.getOrCreateName(inputs[0]);
329 emitter.setName(outputs[0], inputName);
332 .Default([&](Operation *) {
return false; });
// The liveness check only stays enabled for arith.constant ops.
335 checkStrongLiveness &= isa<arith::ConstantOp>(op);
339 if (skip || !checkStrongLiveness)
// Re-run the classification on every user with the liveness check disabled.
345 for (
auto user : op->getUsers()) {
346 if (!skippedOp(user, emitter,
false))
// Records a parameter name for memref dimensions of a function.
// First walk: for each memref.dim op, name its result (prefix "m") and
// register that name for (source memref, constant index).
// Second loop: for each function argument, every dynamic dimension that has no
// registered parameter yet gets a fresh "m"-prefixed name.
// Part of the signature and the return statements are on elided lines.
353static LogicalResult parseMemRefDynamicDims(CppEmitter &emitter,
356 func.walk([&](Operation *Op) {
357 if (
auto op = dyn_cast<memref::DimOp>(Op)) {
359 Value source = op.getSource();
360 Value result = op.getResult();
// getDefiningOp<OpTy>() yields a null op both when the index has no defining
// op (e.g. a block argument) and when the definer is not a constant, so the
// assert below fires instead of dyn_cast being handed a null pointer.
361 auto indexOp = op.getIndex().getDefiningOp<arith::ConstantOp>();
362 assert(indexOp &&
"Failed to get the index value of dimOp");
364 APInt idxVal = llvm::cast<IntegerAttr>(indexOp.getValue()).getValue();
365 unsigned index = idxVal.getZExtValue();
367 StringRef name = emitter.getOrCreateName(result,
"m");
368 emitter.setMemRefDimParam(source, index, name.str());
374 for (BlockArgument arg : func.getArguments()) {
375 auto argType = llvm::dyn_cast<MemRefType>(arg.getType());
378 for (
unsigned dim = 0; dim < argType.getRank(); ++dim) {
379 if (argType.isDynamicDim(dim)) {
// Only dimensions that the walk above did not register get a new name.
381 if (!emitter.isMemRefDimParam(arg, dim)) {
382 std::string name = emitter.getNewName(
"m");
383 emitter.setMemRefDimParam(arg, dim, name);
// For a memref-typed block argument, writes ", size_t <param>" to the
// emitter's stream for every dynamic dimension, where <param> is the name
// registered for (arg, dim). Non-memref arguments write nothing in the visible
// code.
392static LogicalResult printMemRefDims(CppEmitter &emitter, BlockArgument arg) {
393 raw_indented_ostream &
os = emitter.ostream();
394 if (
auto argType = llvm::dyn_cast<MemRefType>(arg.getType())) {
395 for (
unsigned dim = 0; dim < argType.getRank(); ++dim) {
396 if (argType.isDynamicDim(dim)) {
397 StringRef param = emitter.getMemRefDimParam(arg, dim);
398 os <<
", size_t " << param;
// Builds a textual index expression for accessing `source` at `indices` and
// appends it to `access`.
//   source  - value whose type is dyn_cast to MemRefType
//   indices - one index value per memref dimension (sizes are compared below)
//   access  - output string; terms are joined with "+"
// Dimensions are visited from the innermost (rank-1) outwards. Each term is
// the index name prefixed by the accumulated multiplier: `paramPart` collects
// names of dynamic-dimension parameters joined by "*", and `numPart` is
// multiplied by the static extents.
// NOTE(review): the declarations of `cur`/`numPart`, the conditions guarding
// some statements, and the returns are on elided lines.
406static LogicalResult createLinearizedAccess(CppEmitter &emitter, Value source,
407 SmallVector<Value, 4> indices,
408 std::string &access) {
409 auto memRefType = llvm::dyn_cast<MemRefType>(source.getType());
411 "cannot creating linearized expression for non-memref type");
// Despite the name, `stride` holds the memref shape (per-dimension extents).
412 ArrayRef<int64_t> stride = memRefType.getShape();
415 if (stride.size() != indices.size() ||
416 static_cast<int64_t
>(stride.size()) != memRefType.getRank())
421 std::string paramPart;
423 SmallVector<std::string, 4> accessVec;
424 for (
int dim = memRefType.getRank() - 1; dim >= 0; --dim) {
426 if (!emitter.hasValueInScope(indices[dim]))
431 if (!paramPart.empty())
432 cur = paramPart +
"*";
434 cur += std::to_string(numPart) +
"*";
435 cur += emitter.getOrCreateName(indices[dim]);
436 accessVec.push_back(cur);
// Fold this dimension's extent into the multiplier used by outer dimensions.
440 if (memRefType.isDynamicDim(dim)) {
441 StringRef param = emitter.getMemRefDimParam(source, dim);
442 paramPart = param.str() + (paramPart.empty() ?
"" :
"*" + paramPart);
444 numPart *= stride[dim];
// Terms were pushed innermost-first; popping from the back emits the
// outermost term first.
447 while (!accessVec.empty()) {
448 access += (access.empty() ?
"" :
"+") + accessVec.back();
449 accessVec.pop_back();
// Examines the users of `read` with a predicate that matches
// vector.transfer_write ops.
// NOTE(review): the algorithm applied to the range is named on an elided line;
// judging by the function name it presumably checks that no user matches —
// confirm.
459static bool isReadOnly(Value read) {
461 read.getUsers().begin(), read.getUsers().end(),
462 [](
auto *user) { return isa<vector::TransferWriteOp>(user); });
// Returns {true, upper - lower} using the sign-extended values of the loop's
// constant bounds, or {false, 0} otherwise. Both bounds are looked up as
// arith.constant defining ops.
// NOTE(review): the result is the bound difference and is not divided by the
// step in the visible code. The remainder of the `if` condition is elided.
470static std::pair<bool, int64_t> getTripCount(scf::ForOp forOp) {
472 auto lb = forOp.getLowerBound().getDefiningOp<arith::ConstantOp>();
473 if (
auto ub = forOp.getUpperBound().getDefiningOp<arith::ConstantOp>();
475 APInt ubValue = llvm::cast<IntegerAttr>(ub.getValue()).getValue();
476 APInt lbValue = llvm::cast<IntegerAttr>(lb.getValue()).getValue();
477 return std::make_pair(
true,
478 ubValue.getSExtValue() - lbValue.getSExtValue());
480 return std::make_pair(
false, 0);
// Returns {true, step} when the loop step is defined by an arith.constant
// (value read as a sign-extended integer), otherwise {false, 0}.
484static std::pair<bool, int64_t> getStep(scf::ForOp forOp) {
485 if (
auto step = forOp.getStep().getDefiningOp<arith::ConstantOp>()) {
486 APInt stepValue = llvm::cast<IntegerAttr>(step.getValue()).getValue();
487 return std::make_pair(
true, stepValue.getSExtValue());
489 return std::make_pair(
false, 0);
// Maps an arith binary op to operator text (the template header and the
// returned strings are on elided lines). Integer and float variants of
// add/mul/sub are tested together, as are the three division ops; only the
// signed remainder op is tested. For arith.cmpi the predicate is switched on,
// with the signed and unsigned form of each ordering as adjacent case labels.
// Any other op reaches llvm_unreachable.
494static StringRef getOperator(T binOp) {
495 if (isa<arith::AddIOp>(binOp) || isa<arith::AddFOp>(binOp))
497 if (isa<arith::MulIOp>(binOp) || isa<arith::MulFOp>(binOp))
499 if (isa<arith::SubIOp>(binOp) || isa<arith::SubFOp>(binOp))
501 if (isa<arith::DivFOp>(binOp) || isa<arith::DivUIOp>(binOp) ||
502 isa<arith::DivSIOp>(binOp))
504 if (isa<arith::RemSIOp>(binOp))
506 if (isa<arith::CmpIOp>(binOp)) {
507 auto cmpOp = cast<arith::CmpIOp>(binOp);
508 switch (cmpOp.getPredicate()) {
509 case arith::CmpIPredicate::eq:
511 case arith::CmpIPredicate::ne:
513 case arith::CmpIPredicate::sge:
514 case arith::CmpIPredicate::uge:
516 case arith::CmpIPredicate::sgt:
517 case arith::CmpIPredicate::ugt:
519 case arith::CmpIPredicate::sle:
520 case arith::CmpIPredicate::ule:
522 case arith::CmpIPredicate::slt:
523 case arith::CmpIPredicate::ult:
527 llvm_unreachable(
"Cannot print the operation of binary operator");
// Generic printer for binary ops: emits the assignment prefix, then the lhs
// name, the operator text from getOperator, and the rhs name. Each operand is
// checked with hasValueInScope before its name is written.
532static LogicalResult printOperation(CppEmitter &emitter, T binOp) {
533 if (failed(emitter.emitAssignPrefix(*binOp)))
535 raw_indented_ostream &
os = emitter.ostream();
536 auto lhs = binOp.getLhs();
537 if (!emitter.hasValueInScope(lhs))
539 os << emitter.getOrCreateName(lhs);
540 os << getOperator(binOp);
541 auto rhs = binOp.getRhs();
542 if (!emitter.hasValueInScope(rhs))
544 os << emitter.getOrCreateName(rhs);
// Prints arith.select as a conditional expression: "<cond> ? <true> : <false>"
// after the assignment prefix. All three operands are checked with
// hasValueInScope before anything is written.
550static LogicalResult printOperation(CppEmitter &emitter,
551 arith::SelectOp selectOp) {
552 if (failed(emitter.emitAssignPrefix(*selectOp)))
555 auto cond = selectOp.getCondition();
556 if (!emitter.hasValueInScope(cond))
558 auto tVal = selectOp.getTrueValue();
559 if (!emitter.hasValueInScope(tVal))
561 auto fVal = selectOp.getFalseValue();
562 if (!emitter.hasValueInScope(fVal))
565 raw_indented_ostream &
os = emitter.ostream();
566 os << emitter.getOrCreateName(cond) <<
" ? " << emitter.getOrCreateName(tVal)
567 <<
" : " << emitter.getOrCreateName(fVal);
// Prints an aievec.upd op. The visible code has two paths chosen by the result
// vector's size in bits:
//  - at most 1024 bits on AIE2 / 256 bits otherwise: assignment prefix, the
//    result type, then the source name plus the linearized access;
//  - otherwise: an "upd_v"/"upd_w"/... call selected by `granularity`, using
//    the half-lane vector type `updType` and the op's index.
// A non-zero op offset is converted to an element offset and appended to the
// access string with its sign.
// NOTE(review): declarations of `access`, `elementSizeInBits`, `vecSizeInBits`
// and `lanes`, and most punctuation output, are on elided lines.
577static LogicalResult printOperation(CppEmitter &emitter, aievec::UPDOp updOp) {
578 Value source = updOp.getSource();
580 if (!emitter.hasValueInScope(source))
584 auto indices = updOp.getIndices();
586 if (failed(createLinearizedAccess(emitter, source, indices, access)))
589 raw_indented_ostream &
os = emitter.ostream();
590 Value result = updOp.getResult();
591 auto resultType = llvm::cast<VectorType>(result.getType());
// The offset is tested for being a whole multiple of the element size.
596 if (updOp.getOffset() != 0) {
597 if (std::abs(updOp.getOffset()) % elementSizeInBits)
599 int32_t updOffset = updOp.getOffset() / elementSizeInBits;
600 access += updOffset > 0 ?
" + " :
" - ";
601 access += std::to_string(std::abs(updOffset));
607 if (vecSizeInBits <= (emitter.aie2() ? 1024 : 256)) {
609 if (failed(emitter.emitAssignPrefix(*updOp)))
612 if (failed(emitter.emitType(updOp->getLoc(), resultType)))
616 os << emitter.getOrCreateName(source);
618 os <<
" + " << access;
621 Value vector = updOp.getVector();
625 if (!emitter.shouldDeclareVariablesAtTop()) {
626 if (failed(emitter.emitVariableDeclaration(updOp->getResult(0),
true)))
// The result is given the same name as the incoming vector operand.
630 if (!emitter.hasValueInScope(vector))
632 emitter.setName(updOp->getResult(0), emitter.getOrCreateName(vector));
636 int32_t granularity = vecSizeInBits == 256 ? 128
637 : vecSizeInBits == 512 ? 256
641 assert(lanes % 2 == 0 &&
642 "The number of vector lanes of UPD result is not even");
643 SmallVector<int64_t, 4> updShape = {lanes / 2};
644 VectorType updType = VectorType::get(updShape, resultType.getElementType());
646 if (!emitter.hasValueInScope(result))
// When the source has no transfer_write user per isReadOnly, names are
// prefixed with "r_<result>_".
649 bool readOnly = isReadOnly(source);
650 std::string restrictPrefix =
651 readOnly ?
"r_" + emitter.getOrCreateName(result).str() +
"_" :
"";
653 if (readOnly && !vector) {
654 if (failed(emitter.emitType(updOp->getLoc(), source.getType())))
656 os <<
" " << restrictPrefix << emitter.getOrCreateName(source);
658 os << emitter.getOrCreateName(source);
661 os << emitter.getOrCreateName(result);
663 os << (granularity == 128 ?
"upd_v"
664 : granularity == 256 ?
"upd_w"
667 os << emitter.getOrCreateName(result);
669 os << std::to_string(updOp.getIndex());
672 if (failed(emitter.emitType(updOp->getLoc(), updType)))
676 os << restrictPrefix << emitter.getOrCreateName(source);
678 os <<
" + " << access;
// Prints an aievec.ups op. The assignment prefix is emitted with isAcc=true.
// On non-AIE2 targets with a float element type only the source name is
// written. On AIE2 the call name is "ups_to_v<lanes>acc<width>" for integer
// element types and "ups_to_v16accfloat" for float ones. The shift amount is
// written unless the type is float on AIE2.
// NOTE(review): `iType` comes from dyn_cast and can be null; the guard for the
// getWidth() == 80 check, if any, is on an elided line — confirm.
687static LogicalResult printOperation(CppEmitter &emitter, aievec::UPSOp upsOp) {
688 Value source = upsOp.getSource();
689 int32_t shift = upsOp.getShift();
691 raw_indented_ostream &
os = emitter.ostream();
694 if (failed(emitter.emitAssignPrefix(*upsOp,
true)))
698 if (!emitter.hasValueInScope(source))
701 auto accType = llvm::cast<VectorType>(upsOp.getResult().getType());
703 Type eltType = accType.getElementType();
707 if (!emitter.aie2() && llvm::isa<FloatType>(eltType)) {
708 os << emitter.getOrCreateName(source);
713 auto iType = llvm::dyn_cast<IntegerType>(eltType);
714 auto fType = llvm::dyn_cast<FloatType>(eltType);
716 if (iType.getWidth() == 80)
720 if (iType && emitter.aie2()) {
721 os <<
"ups_to_v" << lanes <<
"acc" << iType.getWidth();
722 }
else if (fType && emitter.aie2()) {
723 os <<
"ups_to_v16accfloat";
729 os << emitter.getOrCreateName(source);
730 if (!(fType && emitter.aie2())) {
732 os << std::to_string(shift);
// Prints an aievec.cast op. The op's isResAcc flag is forwarded to
// emitAssignPrefix. The emitted type token is "v<lanes>accfloat",
// "v<lanes>acc<width>" or "v<lanes>int<width>", followed by the source name.
// NOTE(review): the non-AIE2 branch body, the conditions choosing between the
// tokens, and the surrounding call text are on elided lines.
740static LogicalResult printOperation(CppEmitter &emitter,
741 aievec::CastOp castOp) {
742 if (!emitter.aie2()) {
747 Value source = castOp.getSource();
748 if (!emitter.hasValueInScope(source))
751 bool isResAcc = castOp.getIsResAcc();
754 if (failed(emitter.emitAssignPrefix(*castOp, isResAcc)))
758 auto resType = llvm::cast<VectorType>(castOp->getResult(0).getType());
759 Type eltType = resType.getElementType();
762 raw_indented_ostream &
os = emitter.ostream();
766 if (llvm::isa<FloatType>(eltType))
767 os <<
"v" << lanes <<
"accfloat";
770 os <<
"v" << lanes <<
"acc" << width;
772 }
else if (llvm::isa<FloatType>(eltType)) {
773 width = llvm::cast<FloatType>(eltType).getWidth();
781 os <<
"v" << lanes <<
"int" << width;
784 os << emitter.getOrCreateName(source);
// Prints an aievec.unpack op: checks the source is in scope, emits the
// assignment prefix with isAcc=false, then writes the source name. The call
// text around the operand is on elided lines.
790static LogicalResult printOperation(CppEmitter &emitter,
791 aievec::UnpackOp unpackOp) {
794 Value source = unpackOp.getSource();
795 if (!emitter.hasValueInScope(source))
799 if (failed(emitter.emitAssignPrefix(*unpackOp,
false)))
802 raw_indented_ostream &
os = emitter.ostream();
805 os << emitter.getOrCreateName(source);
// Prints an aievec.srs op. For float accumulator element types on AIE2 a
// conversion name such as "to_v16bfloat16" is written before the source name.
// For integer types the call is "srs_to_v<lanes>int<resWidth>", taking the
// source and the shift operand. Some (srcWidth, resultWidth) pairs are
// special-cased on elided lines.
// NOTE(review): `width`, `lanes`, `resultWidth` and `resWidth` are declared on
// elided lines, as is the handling of a shift whose width is not 32.
811static LogicalResult printOperation(CppEmitter &emitter, aievec::SRSOp srsOp) {
812 Value source = srsOp.getSource();
813 Value shift = srsOp.getShift();
816 auto accType = llvm::cast<VectorType>(srsOp.getSource().getType());
817 auto resType = llvm::cast<VectorType>(srsOp->getResult(0).getType());
818 Type eltType = accType.getElementType();
821 raw_indented_ostream &
os = emitter.ostream();
824 if (failed(emitter.emitAssignPrefix(*srsOp)))
828 if (!emitter.hasValueInScope(source))
833 if (llvm::isa<FloatType>(eltType)) {
834 if (emitter.aie2()) {
837 else if (width == 16)
838 os <<
"to_v16bfloat16";
840 os << emitter.getOrCreateName(source);
843 os << emitter.getOrCreateName(source);
// srcWidth stays 0 when the accumulator element type is not an integer.
852 unsigned srcWidth = 0;
853 if (
auto iType = llvm::dyn_cast<IntegerType>(eltType))
854 srcWidth = iType.getWidth();
857 if ((srcWidth == 80 && resultWidth == 64) ||
858 (srcWidth == 48 && resultWidth == 32))
860 else if (srcWidth == 48 && resultWidth == 8)
864 os <<
"srs_to_v" << std::to_string(lanes) <<
"int"
865 << std::to_string(resWidth);
870 os << emitter.getOrCreateName(source);
872 if (llvm::cast<IntegerType>(srsOp.getShift().getType()).getWidth() != 32)
874 os << emitter.getOrCreateName(shift);
// Prints an aievec.broadcast op as a "broadcast_elem" call taking the source
// name and the element index. The index is read into an int8_t and written via
// std::to_string.
881static LogicalResult printOperation(CppEmitter &emitter,
882 aievec::BroadcastOp broadcastOp) {
883 Value source = broadcastOp.getSource();
884 int8_t idx = broadcastOp.getIdx();
886 raw_indented_ostream &
os = emitter.ostream();
889 if (failed(emitter.emitAssignPrefix(*broadcastOp)))
893 if (!emitter.hasValueInScope(source))
896 os <<
"broadcast_elem";
898 os << emitter.getOrCreateName(source);
900 os << std::to_string(idx);
// Prints an aievec.broadcast_scalar op as
// "broadcast_to_v<lanes><kind>(<source>)", where <kind> is "int..." for
// integer element types, "bfloat16" when `width` is 16, and "float" otherwise.
// NOTE(review): the return-type line, the `resType`/`lanes`/`width`
// declarations and the integer width suffix are on elided lines.
908printOperation(CppEmitter &emitter,
909 aievec::BroadcastScalarOp broadcastScalarOp) {
910 auto source = broadcastScalarOp.getSource();
912 llvm::cast<VectorType>(broadcastScalarOp.getResult().getType());
915 raw_indented_ostream &
os = emitter.ostream();
918 if (failed(emitter.emitAssignPrefix(*broadcastScalarOp)))
921 Type eltType = resType.getElementType();
922 os <<
"broadcast_to_v";
923 if (llvm::isa<IntegerType>(eltType)) {
924 os << lanes <<
"int";
926 }
else if (width == 16)
927 os << lanes <<
"bfloat16";
929 os << lanes <<
"float";
930 os <<
"(" << emitter.getOrCreateName(source) <<
")";
// Shared printer for the two ext op flavours (the template header is on an
// elided line). On AIE2 the call name starts with "extract_v<lanes>" followed
// by a type suffix. Otherwise the name is picked from "ext_v"/"ext_w"/... by
// the result size in bits, which is asserted to be 128, 256 or 512. The source
// name and the index are then written.
937static LogicalResult printExtOperation(CppEmitter &emitter, T extOp) {
938 Value source = extOp.getSource();
939 int8_t index = extOp.getIndex();
941 raw_indented_ostream &
os = emitter.ostream();
944 if (failed(emitter.emitAssignPrefix(*extOp)))
947 if (!emitter.hasValueInScope(source))
950 auto resType = llvm::cast<VectorType>(extOp.getResult().getType());
951 Type eltType = resType.getElementType();
956 if (emitter.aie2()) {
957 os <<
"extract_v" << std::to_string(lanes);
958 if (llvm::isa<IntegerType>(eltType))
959 os <<
"int" << std::to_string(resWidth);
960 else if (resWidth == 16)
967 assert(vecSizeInBits == 128 || vecSizeInBits == 256 ||
968 vecSizeInBits == 512);
969 os << (vecSizeInBits == 128 ?
"ext_v"
970 : vecSizeInBits == 256 ?
"ext_w"
975 os << emitter.getOrCreateName(source);
977 os << std::to_string(index);
// aievec.ext overload: forwards to the shared printExtOperation template.
984static LogicalResult printOperation(CppEmitter &emitter, aievec::ExtOp extOp) {
987 return printExtOperation<aievec::ExtOp>(emitter, extOp);
// aievec.aie1 ext overload: forwards to the shared printExtOperation template.
991static LogicalResult printOperation(CppEmitter &emitter,
992 aievec::aie1::ExtOp extOp) {
995 return printExtOperation<aievec::aie1::ExtOp>(emitter, extOp);
// Prints an aievec.concat op: after the assignment prefix, each source operand
// is checked for being in scope and its name is written. The call name and the
// separators are on elided lines.
999static LogicalResult printOperation(CppEmitter &emitter,
1000 aievec::ConcatOp concatOp) {
1001 SmallVector<Value> sources = concatOp.getSources();
1003 raw_indented_ostream &
os = emitter.ostream();
1006 if (failed(emitter.emitAssignPrefix(*concatOp)))
1013 for (
auto source : sources) {
1015 if (!emitter.hasValueInScope(source))
1019 os << emitter.getOrCreateName(source);
// Prints an aievec.shift op as a "shift_bytes" call taking lhs, rhs and the
// shift operand, the latter wrapped in static_cast<uint32_t>(...). The op's
// isAcc flag is forwarded to emitAssignPrefix.
1028static LogicalResult printOperation(CppEmitter &emitter,
1029 aievec::ShiftOp shiftOp) {
1030 Value lhs = shiftOp.getLhs();
1031 Value rhs = shiftOp.getRhs();
1032 Value shift = shiftOp.getShift();
1033 bool isAcc = shiftOp.getIsAcc();
1035 raw_indented_ostream &
os = emitter.ostream();
1038 if (failed(emitter.emitAssignPrefix(*shiftOp, isAcc)))
1041 os <<
"shift_bytes";
1044 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1046 os << emitter.getOrCreateName(lhs);
1048 os << emitter.getOrCreateName(rhs);
1049 os <<
", static_cast<uint32_t>(";
1051 if (!emitter.hasValueInScope(shift))
1053 os << emitter.getOrCreateName(shift);
// Prints an aievec.shuffle op: lhs name, rhs name, then the mode written as
// "eShuffleMode::shuffle_T" followed by the enum's string form with its first
// character dropped (substr(1)).
1060static LogicalResult printOperation(CppEmitter &emitter,
1061 aievec::ShuffleOp shuffleOp) {
1062 Value lhs = shuffleOp.getLhs();
1063 Value rhs = shuffleOp.getRhs();
1064 aievec::ShuffleMode mode = shuffleOp.getMode();
1066 raw_indented_ostream &
os = emitter.ostream();
1069 if (failed(emitter.emitAssignPrefix(*shuffleOp)))
1074 if (!emitter.hasValueInScope(lhs))
1076 os << emitter.getOrCreateName(lhs);
1079 if (!emitter.hasValueInScope(rhs))
1081 os << emitter.getOrCreateName(rhs);
1084 os <<
"eShuffleMode::shuffle_T" << stringifyEnum(mode).substr(1);
// Prints the legacy shuffle op: the source name followed by the numeric mode.
// The call name itself is on an elided line.
1091static LogicalResult printOperation(CppEmitter &emitter,
1092 aievec::LegacyShuffleOp shuffleOp) {
1093 Value source = shuffleOp.getSource();
1094 unsigned mode = shuffleOp.getMode();
1096 raw_indented_ostream &
os = emitter.ostream();
1099 if (failed(emitter.emitAssignPrefix(*shuffleOp)))
1106 if (!emitter.hasValueInScope(source))
1108 os << emitter.getOrCreateName(source);
1110 os << std::to_string(mode);
// Prints the aie1 select op. The intrinsic name is chosen from the xbuff
// element size, the select attribute is written, then xbuff with its optional
// start/offsets/offsetsHi/square strings. If a ybuff is present, it and its
// corresponding optional strings follow. Empty attribute strings are omitted.
1117static LogicalResult printOperation(CppEmitter &emitter,
1118 aievec::aie1::SelectOp selectOp) {
1119 Value xbuff = selectOp.getXbuff();
1120 assert(xbuff &&
"xbuff empty in select op");
1122 raw_indented_ostream &
os = emitter.ostream();
1125 if (failed(emitter.emitAssignPrefix(*selectOp)))
1129 auto xbuffType = llvm::cast<VectorType>(selectOp.getXbuff().getType());
1131 assert(elementSizeInBits == 16 || elementSizeInBits == 32 ||
1132 elementSizeInBits == 64);
// The numeric suffix is not the element width: 16-bit elements map to
// "select32" and 32-bit elements to "select16".
// NOTE(review): presumably the suffix is a lane count — confirm.
1134 os << (elementSizeInBits == 16 ?
"select32"
1135 : elementSizeInBits == 32 ?
"select16"
1139 assert(!selectOp.getSelect().empty());
1140 os << selectOp.getSelect();
1142 if (!emitter.hasValueInScope(xbuff))
1146 os << emitter.getOrCreateName(xbuff);
1148 if (!selectOp.getXstart().empty())
1149 os <<
", " << selectOp.getXstart();
1150 if (!selectOp.getXoffsets().empty())
1151 os <<
", " << selectOp.getXoffsets();
1152 if (!selectOp.getXoffsetsHi().empty())
1153 os <<
", " << selectOp.getXoffsetsHi();
1154 if (!selectOp.getXsquare().empty())
1155 os <<
", " << selectOp.getXsquare();
// The second buffer is optional.
1157 if (selectOp.getYbuff()) {
1158 Value ybuff = selectOp.getYbuff();
1160 if (!emitter.hasValueInScope(ybuff))
1164 os << emitter.getOrCreateName(ybuff);
1167 if (!selectOp.getYstart().empty())
1168 os <<
", " << selectOp.getYstart();
1169 if (!selectOp.getYoffsets().empty())
1170 os <<
", " << selectOp.getYoffsets();
1171 if (!selectOp.getYoffsetsHi().empty())
1172 os <<
", " << selectOp.getYoffsetsHi();
1173 if (!selectOp.getYsquare().empty())
1174 os <<
", " << selectOp.getYsquare();
// Prints an aievec.pack op: writes "upack" when the source element type is an
// unsigned integer and "pack" otherwise, followed by the source name.
1181static LogicalResult printOperation(CppEmitter &emitter,
1182 aievec::PackOp packOp) {
1183 Value source = packOp.getSource();
1185 raw_indented_ostream &
os = emitter.ostream();
1188 if (failed(emitter.emitAssignPrefix(*packOp)))
1192 auto sourceType = llvm::cast<VectorType>(packOp.getSource().getType());
1193 Type scalarType = sourceType.getElementType();
1194 os << (scalarType.isUnsignedInteger() ?
"upack" :
"pack");
1197 if (!emitter.hasValueInScope(source))
1199 os << emitter.getOrCreateName(source);
// Writes one operand of an add/sub op: opNum 0 selects lhs, any other value
// selects rhs. After the operand name, the per-operand start, offset, offsetHi
// and square attribute strings are appended as ", <value>"; the last three are
// skipped when empty. The `opNum` parameter is declared on an elided line.
1206template <
typename T>
1207static LogicalResult printAddOrSubOperand(CppEmitter &emitter, T op,
1214 Value operand = opNum == 0 ? op.getLhs() : op.getRhs();
1215 if (!emitter.hasValueInScope(operand))
1218 raw_indented_ostream &
os = emitter.ostream();
1220 StringRef start = op.getStart(opNum);
1221 StringRef offset = op.getOffset(opNum);
1222 StringRef offsetHi = op.getOffsetHi(opNum);
1223 StringRef square = op.getSquare(opNum);
1225 os << emitter.getOrCreateName(operand);
1227 os <<
", " << start;
1228 if (!offset.empty())
1229 os <<
", " << offset;
1230 if (!offsetHi.empty())
1231 os <<
", " << offsetHi;
1232 if (!square.empty())
1233 os <<
", " << square;
// Writes the name of one operand of a min/max op (opNum 0 selects lhs,
// otherwise rhs) after checking that it is in scope.
1239template <
typename T>
1240static LogicalResult printMinMaxOperand(CppEmitter &emitter, T op,
1247 Value operand = opNum == 0 ? op.getLhs() : op.getRhs();
1248 if (!emitter.hasValueInScope(operand))
1251 raw_indented_ostream &
os = emitter.ostream();
1252 os << emitter.getOrCreateName(operand);
// Writes the name of one operand of an add_elem/sub_elem op (opNum 0 selects
// lhs, otherwise rhs) after checking that it is in scope.
1258template <
typename T>
1259static LogicalResult printAddElemOrSubElemOperand(CppEmitter &emitter, T op,
1266 Value operand = opNum == 0 ? op.getLhs() : op.getRhs();
1267 if (!emitter.hasValueInScope(operand))
1270 raw_indented_ostream &
os = emitter.ostream();
1271 os << emitter.getOrCreateName(operand);
// Writes one operand of an FMA/mul op (opNum 0 selects lhs, otherwise rhs)
// followed by its start, offset, offsetHi and square attribute strings; the
// last three are skipped when empty.
// NOTE(review): `step` is read here, but the code that prints it is on elided
// lines — confirm.
1277template <
typename T>
1278static LogicalResult printFMAOrMulOperand(CppEmitter &emitter, T op,
1285 Value operand = opNum == 0 ? op.getLhs() : op.getRhs();
1286 if (!emitter.hasValueInScope(operand))
1289 raw_indented_ostream &
os = emitter.ostream();
1291 StringRef start = op.getStart(opNum);
1292 StringRef offset = op.getOffset(opNum);
1293 StringRef offsetHi = op.getOffsetHi(opNum);
1294 StringRef step = op.getStep(opNum);
1295 StringRef square = op.getSquare(opNum);
1297 os << emitter.getOrCreateName(operand);
1299 os <<
", " << start;
1300 if (!offset.empty())
1301 os <<
", " << offset;
1302 if (!offsetHi.empty())
1303 os <<
", " << offsetHi;
1306 if (!square.empty())
1307 os <<
", " << square;
// Writes one operand of an element-wise FMA/mul op. When `size` is 32 and
// `iType` is non-null, an extra argument follows the operand name:
// "undef_v16int32()" for operand 0 and "broadcast_zero_s32()" for the other.
1313template <
typename T>
1314static LogicalResult printFMAOrMulElemOperand(CppEmitter &emitter, T op,
1315 Type iType, int32_t size,
1322 Value operand = opNum == 0 ? op.getLhs() : op.getRhs();
1323 if (!emitter.hasValueInScope(operand))
1326 raw_indented_ostream &
os = emitter.ostream();
1327 os << emitter.getOrCreateName(operand);
1328 if (size == 32 && iType)
1329 os <<
", " << (opNum == 0 ?
"undef_v16int32()" :
"broadcast_zero_s32()");
// Writes the name of one operand of an FMA/mul convolution op (opNum 0 selects
// lhs, otherwise rhs) after checking that it is in scope.
1335template <
typename T>
1336static LogicalResult printFMAOrMulConvOperand(CppEmitter &emitter, T op,
1343 Value operand = opNum == 0 ? op.getLhs() : op.getRhs();
1344 if (!emitter.hasValueInScope(operand))
1347 raw_indented_ostream &
os = emitter.ostream();
1348 os << emitter.getOrCreateName(operand);
// Prints the aie1 mul op. `simpleScheme` is true when the start attribute of
// operand 0 is empty. In the non-simple scheme the intrinsic name depends on
// the result element type (80-bit integers and floats are tested explicitly).
// Both operands are written through printFMAOrMulOperand.
// NOTE(review): the intrinsic-name construction and its output are on elided
// lines.
1354static LogicalResult printOperation(CppEmitter &emitter,
1355 aievec::aie1::MulOp mulOp) {
1356 auto lhs = mulOp.getLhs();
1357 auto rhs = mulOp.getRhs();
1360 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1364 bool simpleScheme = mulOp.getStart(0).empty();
1368 auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
1369 Type eltType = resType.getElementType();
1370 if (!simpleScheme) {
1371 if (
auto iType = llvm::dyn_cast<IntegerType>(eltType)) {
1372 if (iType.getWidth() == 80)
1374 }
else if (llvm::isa<FloatType>(eltType))
1379 if (!simpleScheme && !llvm::isa<FloatType>(eltType))
1382 raw_indented_ostream &
os = emitter.ostream();
1385 if (failed(emitter.emitAssignPrefix(*mulOp)))
1390 if (failed(printFMAOrMulOperand<aievec::aie1::MulOp>(emitter, mulOp, 0)))
1393 if (failed(printFMAOrMulOperand<aievec::aie1::MulOp>(emitter, mulOp, 1)))
// Emits a declaration that widens vector value `v` to 512 bits by
// concatenating it with undefined vectors, and returns the name to use for the
// widened value. Values that are already 512 bits or wider take the
// `vBitWidth >= 512` branch (its body is on an elided line). When the widened
// vector has four times as many elements as `v`, a nested concat is written
// first.
//   emitter - supplies names, C++ type names and the output stream
//   v       - vector-typed value (its type is cast to VectorType)
// NOTE(review): the `newTyName`/`oldTyName` declarations and the return
// statements are on elided lines.
1400static std::string printConversionTo512bit(CppEmitter &emitter, Value v) {
1401 std::string vName = emitter.getOrCreateName(v).str();
1402 auto vTy = cast<VectorType>(v.getType());
1403 auto vShape = vTy.getShape();
1404 int64_t elemBitWidth = vTy.getElementTypeBitWidth();
// std::accumulate accumulates in the type of its init argument, so the seed is
// an int64_t to match the element type of the shape and the result variable.
1405 int64_t numElems = std::accumulate(vShape.begin(), vShape.end(), int64_t{1},
1406 std::multiplies<int64_t>());
1407 int64_t vBitWidth = numElems * elemBitWidth;
1408 if (vBitWidth >= 512)
1411 int64_t newNumElems = 512 / elemBitWidth;
1413 std::string vNewName = emitter.getNewName();
1414 raw_indented_ostream &
os = emitter.ostream();
1415 auto newVecTy = VectorType::get({512 / elemBitWidth}, vTy.getElementType());
1417 emitter.genCppTypeName(newVecTy,
false,
false));
1419 *(emitter.genCppTypeName(vTy,
false,
false));
1421 os << newTyName <<
" " << vNewName <<
" = concat(";
// Quarter-width input: concat once to half width, then switch the undef type
// name to the 256-bit vector type for the outer concat.
1422 if (newNumElems / numElems == 4) {
1423 os <<
"concat(" << vName <<
", undef_" << oldTyName <<
"())";
1424 oldTyName = *(emitter.genCppTypeName(
1425 VectorType::get({256 / elemBitWidth}, vTy.getElementType())));
1429 os <<
", undef_" << oldTyName <<
"());\n";
// Prints an aievec.mul_elem op. Both operands are first passed through
// printConversionTo512bit and the returned names are used in the call. The
// base call name is "mul_elem"; variations by element type and size are chosen
// on elided lines. For 32-bit integer elements "undef_v16int32()" follows the
// lhs and "broadcast_zero_s32()" follows the rhs. The assignment prefix is
// emitted with isAcc=true.
1434static LogicalResult printOperation(CppEmitter &emitter,
1435 aievec::MulElemOp mulElemOp) {
1436 auto lhs = mulElemOp.getLhs();
1437 auto rhs = mulElemOp.getRhs();
1440 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1443 auto lhsName = printConversionTo512bit(emitter, lhs);
1444 auto rhsName = printConversionTo512bit(emitter, rhs);
1446 std::string opname =
"mul_elem";
1449 auto lhsType = llvm::cast<VectorType>(mulElemOp.getLhs().getType());
1450 Type eltType = lhsType.getElementType();
1452 auto iType = llvm::dyn_cast<IntegerType>(eltType);
1457 else if (lsize == 16)
1459 else if (lsize == 8)
1461 }
else if (llvm::isa<FloatType>(eltType)) {
1464 else if (lsize == 16)
1468 raw_indented_ostream &
os = emitter.ostream();
1471 if (failed(emitter.emitAssignPrefix(*mulElemOp,
true )))
1475 os <<
"(" << lhsName;
1476 if ((lsize == 32) && iType)
1478 <<
"undef_v16int32()";
1479 os <<
" ," << rhsName;
1480 if ((lsize == 32) && iType)
1482 <<
"broadcast_zero_s32()";
// Prints an aievec.mul_conv op. The call name is "mul_conv_<M>x<N>", built
// from the op's M and N. A guard tests for an element type that is not an
// integer of size 16 or 8; its body is on an elided line. The assignment
// prefix is emitted with isAcc=true and both operands go through
// printFMAOrMulConvOperand.
1488static LogicalResult printOperation(CppEmitter &emitter,
1489 aievec::MulConvOp mulConvOp) {
1490 auto lhs = mulConvOp.getLhs();
1491 auto rhs = mulConvOp.getRhs();
1494 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1498 auto lhsType = llvm::cast<VectorType>(mulConvOp.getLhs().getType());
1499 Type eltType = lhsType.getElementType();
1501 auto iType = llvm::dyn_cast<IntegerType>(eltType);
1504 if (!iType || !(lsize == 16 || lsize == 8)) {
1508 int32_t M = mulConvOp.getM();
1509 int32_t N = mulConvOp.getN();
1510 std::string opname =
1511 "mul_conv_" + std::to_string(M) +
"x" + std::to_string(N);
1513 raw_indented_ostream &
os = emitter.ostream();
1516 if (failed(emitter.emitAssignPrefix(*mulConvOp,
true )))
1523 printFMAOrMulConvOperand<aievec::MulConvOp>(emitter, mulConvOp, 0)))
1527 printFMAOrMulConvOperand<aievec::MulConvOp>(emitter, mulConvOp, 1)))
// Prints the aie1 add op. When the start attribute of operand 0 is empty, lhs
// and rhs names are written directly; two such lhs/rhs pairs are visible, with
// the selecting condition on elided lines. Otherwise the call name is "fpadd"
// for float element types or "add<lanes>" for others, and the operands are
// written through printAddOrSubOperand.
1535static LogicalResult printOperation(CppEmitter &emitter,
1536 aievec::aie1::AddOp addOp) {
1537 auto lhs = addOp.getLhs();
1538 auto rhs = addOp.getRhs();
1541 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1544 raw_indented_ostream &
os = emitter.ostream();
1547 if (failed(emitter.emitAssignPrefix(*addOp)))
1551 auto resultType = llvm::cast<VectorType>(addOp.getResult().getType());
1553 Type elementType = resultType.getElementType();
1554 bool floatType = llvm::isa<FloatType>(elementType);
1558 if (addOp.getStart(0).empty()) {
1563 os << emitter.getOrCreateName(lhs);
1565 os << emitter.getOrCreateName(rhs);
1570 os << emitter.getOrCreateName(lhs);
1572 os << emitter.getOrCreateName(rhs);
1577 os << (floatType ?
"fpadd" :
"add" + std::to_string(lanes));
1579 if (failed(printAddOrSubOperand<aievec::aie1::AddOp>(emitter, addOp, 0)))
1582 if (failed(printAddOrSubOperand<aievec::aie1::AddOp>(emitter, addOp, 1)))
// Prints the aie1 sub op; the structure mirrors the aie1 add printer. When the
// start attribute of operand 0 is empty, lhs and rhs names are written
// directly. Otherwise the call name is "fpsub" for float element types or
// "sub<lanes>" for others, and the operands are written through
// printAddOrSubOperand.
1590static LogicalResult printOperation(CppEmitter &emitter,
1591 aievec::aie1::SubOp subOp) {
1592 auto lhs = subOp.getLhs();
1593 auto rhs = subOp.getRhs();
1596 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1599 raw_indented_ostream &
os = emitter.ostream();
1602 if (failed(emitter.emitAssignPrefix(*subOp)))
1606 auto resultType = llvm::cast<VectorType>(subOp.getResult().getType());
1608 Type elementType = resultType.getElementType();
1609 bool floatType = llvm::isa<FloatType>(elementType);
1613 if (subOp.getStart(0).empty()) {
1618 os << emitter.getOrCreateName(lhs);
1620 os << emitter.getOrCreateName(rhs);
1625 os << emitter.getOrCreateName(lhs);
1627 os << emitter.getOrCreateName(rhs);
1632 os << (floatType ?
"fpsub" :
"sub" + std::to_string(lanes));
1634 if (failed(printAddOrSubOperand<aievec::aie1::SubOp>(emitter, subOp, 0)))
1637 if (failed(printAddOrSubOperand<aievec::aie1::SubOp>(emitter, subOp, 1)))
/// Prints the C++ text for an aievec::MinOp.
/// Checks that both operands are in scope, emits the assignment prefix, and
/// prints the two operands through printMinMaxOperand.
1645static LogicalResult printOperation(CppEmitter &emitter, aievec::MinOp minOp) {
1646 auto lhs = minOp.getLhs();
1647 auto rhs = minOp.getRhs();
// Guard: both operands must already be known to the emitter.
1650 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1653 raw_indented_ostream &
os = emitter.ostream();
1656 if (failed(emitter.emitAssignPrefix(*minOp)))
// Operand 0 first, then operand 1.
1660 if (failed(printMinMaxOperand<aievec::MinOp>(emitter, minOp, 0)))
1663 if (failed(printMinMaxOperand<aievec::MinOp>(emitter, minOp, 1)))
/// Prints the C++ text for an aievec::MaxOp.
/// Checks that both operands are in scope, emits the assignment prefix, and
/// prints the two operands through printMinMaxOperand.
1671static LogicalResult printOperation(CppEmitter &emitter, aievec::MaxOp maxOp) {
1672 auto lhs = maxOp.getLhs();
1673 auto rhs = maxOp.getRhs();
// Guard: both operands must already be known to the emitter.
1676 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1679 raw_indented_ostream &
os = emitter.ostream();
1682 if (failed(emitter.emitAssignPrefix(*maxOp)))
// Operand 0 first, then operand 1.
1686 if (failed(printMinMaxOperand<aievec::MaxOp>(emitter, maxOp, 0)))
1689 if (failed(printMinMaxOperand<aievec::MaxOp>(emitter, maxOp, 1)))
/// Prints the C++ text for an aievec::NegOp, using the name of its source
/// operand as the argument.
1697static LogicalResult printOperation(CppEmitter &emitter, aievec::NegOp negOp) {
1698 auto src = negOp.getSource();
// Guard: the source value must already be known to the emitter.
1701 if (!emitter.hasValueInScope(src))
1704 raw_indented_ostream &
os = emitter.ostream();
// The second argument is emitAssignPrefix's `isAcc` flag, set to true here.
1707 if (failed(emitter.emitAssignPrefix(*negOp,
true )))
1711 os << emitter.getOrCreateName(src);
/// Prints the C++ text for an aievec::BnegOp, using the name of its source
/// operand as the argument.
1718static LogicalResult printOperation(CppEmitter &emitter,
1719 aievec::BnegOp bnegOp) {
1720 auto src = bnegOp.getSource();
// Guard: the source value must already be known to the emitter.
1723 if (!emitter.hasValueInScope(src))
1726 raw_indented_ostream &
os = emitter.ostream();
// Uses the default (non-accumulator) assignment prefix.
1729 if (failed(emitter.emitAssignPrefix(*bnegOp)))
1733 os << emitter.getOrCreateName(src);
/// Prints the C++ text for an aievec::BxorOp: assignment prefix followed by
/// the names of the left and right operands, in that order.
1740static LogicalResult printOperation(CppEmitter &emitter, aievec::BxorOp xorOp) {
1741 auto lhs = xorOp.getLhs();
1742 auto rhs = xorOp.getRhs();
// Guard: both operands must already be known to the emitter.
1745 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1748 raw_indented_ostream &
os = emitter.ostream();
1751 if (failed(emitter.emitAssignPrefix(*xorOp)))
1755 os << emitter.getOrCreateName(lhs);
1757 os << emitter.getOrCreateName(rhs);
/// Prints the C++ text for an aievec::BandOp: assignment prefix followed by
/// the names of the left and right operands, in that order.
1764static LogicalResult printOperation(CppEmitter &emitter, aievec::BandOp andOp) {
1765 auto lhs = andOp.getLhs();
1766 auto rhs = andOp.getRhs();
// Guard: both operands must already be known to the emitter.
1769 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1772 raw_indented_ostream &
os = emitter.ostream();
1775 if (failed(emitter.emitAssignPrefix(*andOp)))
1779 os << emitter.getOrCreateName(lhs);
1781 os << emitter.getOrCreateName(rhs);
/// Prints the C++ text for an aievec::BorOp: assignment prefix followed by
/// the names of the left and right operands, in that order.
1788static LogicalResult printOperation(CppEmitter &emitter, aievec::BorOp orOp) {
1789 auto lhs = orOp.getLhs();
1790 auto rhs = orOp.getRhs();
// Guard: both operands must already be known to the emitter.
1793 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1796 raw_indented_ostream &
os = emitter.ostream();
1799 if (failed(emitter.emitAssignPrefix(*orOp)))
1803 os << emitter.getOrCreateName(lhs);
1805 os << emitter.getOrCreateName(rhs);
/// Prints the C++ text for an aievec::AddElemOp.
/// The assignment prefix is emitted with an `isAcc` flag and the operands are
/// printed through printAddElemOrSubElemOperand.
1812static LogicalResult printOperation(CppEmitter &emitter,
1813 aievec::AddElemOp addElemOp) {
1814 auto lhs = addElemOp.getLhs();
1815 auto rhs = addElemOp.getRhs();
// Guard: both operands must already be known to the emitter.
1818 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1821 raw_indented_ostream &
os = emitter.ostream();
1826 auto resType = cast<VectorType>(addElemOp.getResult().getType());
1827 auto resElemType = resType.getElementType();
1828 unsigned resBitWidth = resElemType.getIntOrFloatBitWidth();
// Condition holds for float results, or when element bits times
// `resLaneSize` equals 1024.
// NOTE(review): presumably this sets `isAcc`; the declarations of
// `resLaneSize` and `isAcc` are not visible in this excerpt.
1830 if (isa<FloatType>(resElemType) || resBitWidth * resLaneSize == 1024)
1833 if (failed(emitter.emitAssignPrefix(*addElemOp, isAcc)))
// Two operand-printing calls follow; their trailing arguments are not shown.
1837 if (failed(printAddElemOrSubElemOperand<aievec::AddElemOp>(emitter, addElemOp,
1841 if (failed(printAddElemOrSubElemOperand<aievec::AddElemOp>(emitter, addElemOp,
/// Prints the C++ text for an aievec::SubElemOp.
/// The assignment prefix is emitted with an `isAcc` flag and the operands are
/// printed through printAddElemOrSubElemOperand.
1850static LogicalResult printOperation(CppEmitter &emitter,
1851 aievec::SubElemOp subElemOp) {
1852 auto lhs = subElemOp.getLhs();
1853 auto rhs = subElemOp.getRhs();
// Guard: both operands must already be known to the emitter.
1856 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
1859 raw_indented_ostream &
os = emitter.ostream();
1864 auto resType = cast<VectorType>(subElemOp.getResult().getType());
1865 auto resElemType = resType.getElementType();
1866 unsigned resBitWidth = resElemType.getIntOrFloatBitWidth();
// Condition holds for float results, or when element bits times
// `resLaneSize` equals 1024.
// NOTE(review): presumably this sets `isAcc`; the declarations of
// `resLaneSize` and `isAcc` are not visible in this excerpt.
1868 if (isa<FloatType>(resElemType) || resBitWidth * resLaneSize == 1024)
1871 if (failed(emitter.emitAssignPrefix(*subElemOp, isAcc)))
// Two operand-printing calls follow; their trailing arguments are not shown.
1875 if (failed(printAddElemOrSubElemOperand<aievec::SubElemOp>(emitter, subElemOp,
1879 if (failed(printAddElemOrSubElemOperand<aievec::SubElemOp>(emitter, subElemOp,
/// Prints the C++ text for an aievec::aie1::FMAOp.
/// The intrinsic name is built up in `opname` ("msc" when the fmsub flag is
/// set, "mac" otherwise, plus type-dependent pieces), and the op result is
/// registered under the accumulator's name.
1888static LogicalResult printOperation(CppEmitter &emitter,
1889 aievec::aie1::FMAOp fmaOp) {
1890 auto acc = fmaOp.getAcc();
1891 auto lhs = fmaOp.getLhs();
1892 auto rhs = fmaOp.getRhs();
// Guard: accumulator and both multiplicands must be known to the emitter.
1895 if (!emitter.hasValueInScope(acc) || !emitter.hasValueInScope(lhs) ||
1896 !emitter.hasValueInScope(rhs))
// The "simple scheme" is the one where operand 0 has an empty start attribute.
1900 bool simpleScheme = fmaOp.getStart(0).empty();
1904 auto resType = llvm::cast<VectorType>(fmaOp.getResult().getType());
1905 Type eltType = resType.getElementType();
// In the non-simple scheme the element type (80-bit integer, or float) picks
// a name piece; the pieces themselves are on lines not visible here.
1906 if (!simpleScheme) {
1907 if (
auto iType = llvm::dyn_cast<IntegerType>(eltType)) {
1908 if (iType.getWidth() == 80)
1910 }
else if (llvm::isa<FloatType>(eltType))
// fmsub selects multiply-subtract ("msc") over multiply-accumulate ("mac").
1914 opname += fmaOp.getFmsub() ?
"msc" :
"mac";
1915 if (!simpleScheme && !llvm::isa<FloatType>(eltType))
1918 raw_indented_ostream &
os = emitter.ostream();
1920 StringRef accName = emitter.getOrCreateName(acc);
// Print operand 0 and then operand 1 via the shared FMA/Mul helper.
1927 if (failed(printFMAOrMulOperand<aievec::aie1::FMAOp>(emitter, fmaOp, 0)))
1930 if (failed(printFMAOrMulOperand<aievec::aie1::FMAOp>(emitter, fmaOp, 1)))
// The result gets the accumulator's existing name rather than a fresh one.
1935 emitter.setName(fmaOp->getResult(0), accName);
/// Prints the C++ text for an aievec::FMAElemOp.
/// The intrinsic name starts as "msc_elem" or "mac_elem" depending on the
/// fmsub flag; the op result is registered under the accumulator's name.
1941static LogicalResult printOperation(CppEmitter &emitter,
1942 aievec::FMAElemOp fmaElemOp) {
1943 auto acc = fmaElemOp.getAcc();
1944 auto lhs = fmaElemOp.getLhs();
1945 auto rhs = fmaElemOp.getRhs();
// Guard: accumulator and both multiplicands must be known to the emitter.
1948 if (!emitter.hasValueInScope(acc) || !emitter.hasValueInScope(lhs) ||
1949 !emitter.hasValueInScope(rhs))
1952 std::string opname = fmaElemOp.getFmsub() ?
"msc_elem" :
"mac_elem";
// The left operand's element type drives the branches below.
1954 auto lhsType = llvm::cast<VectorType>(fmaElemOp.getLhs().getType());
1955 Type eltType = lhsType.getElementType();
1957 auto iType = llvm::dyn_cast<IntegerType>(eltType);
// Integer case: branches on `lsize` values 16 and 8 are visible.
// NOTE(review): `lsize` is declared on a line not visible in this excerpt.
1962 else if (lsize == 16)
1964 else if (lsize == 8)
1966 }
// Float case: a branch on `lsize` == 16 is visible.
else if (llvm::isa<FloatType>(eltType)) {
1969 else if (lsize == 16)
1973 raw_indented_ostream &
os = emitter.ostream();
1975 StringRef accName = emitter.getOrCreateName(acc);
// Two operand-printing calls follow; their trailing arguments are not shown.
1980 if (failed(printFMAOrMulElemOperand<aievec::FMAElemOp>(emitter, fmaElemOp,
1984 if (failed(printFMAOrMulElemOperand<aievec::FMAElemOp>(emitter, fmaElemOp,
// The result gets the accumulator's existing name rather than a fresh one.
1992 emitter.setName(fmaElemOp->getResult(0), accName);
/// Prints the C++ text for an aievec::FMAConvOp.
/// The intrinsic name is "msc_conv" or "mac_conv" (by the fmsub flag) with an
/// "_<M>x<N>" suffix; the op result is registered under the accumulator's name.
1998static LogicalResult printOperation(CppEmitter &emitter,
1999 aievec::FMAConvOp fmaConvOp) {
2000 auto acc = fmaConvOp.getAcc();
2001 auto lhs = fmaConvOp.getLhs();
2002 auto rhs = fmaConvOp.getRhs();
// Guard: accumulator and both multiplicands must be known to the emitter.
2005 if (!emitter.hasValueInScope(acc) || !emitter.hasValueInScope(lhs) ||
2006 !emitter.hasValueInScope(rhs))
2009 std::string opname = fmaConvOp.getFmsub() ?
"msc_conv" :
"mac_conv";
2011 auto lhsType = llvm::cast<VectorType>(fmaConvOp.getLhs().getType());
2012 Type eltType = lhsType.getElementType();
2014 auto iType = llvm::dyn_cast<IntegerType>(eltType);
// Condition is true for a non-integer element type or an `lsize` other than
// 16 or 8.
// NOTE(review): `lsize` is declared on a line not visible in this excerpt.
2017 if (!iType || !(lsize == 16 || lsize == 8))
// Append the convolution shape to the intrinsic name as "_<M>x<N>".
2020 int32_t M = fmaConvOp.getM();
2021 int32_t N = fmaConvOp.getN();
2022 opname +=
"_" + std::to_string(M) +
"x" + std::to_string(N);
2024 raw_indented_ostream &
os = emitter.ostream();
2026 StringRef accName = emitter.getOrCreateName(acc);
// Print operand 0 and then operand 1 via the shared conv helper.
2032 printFMAOrMulConvOperand<aievec::FMAConvOp>(emitter, fmaConvOp, 0)))
2036 printFMAOrMulConvOperand<aievec::FMAConvOp>(emitter, fmaConvOp, 1)))
// The result gets the accumulator's existing name rather than a fresh one.
2043 emitter.setName(fmaConvOp->getResult(0), accName);
/// Prints the C++ text for an aievec::CmpOp.
/// The predicate string selects the comparison spelling; for integer vectors
/// with an unsigned predicate both operands are wrapped in a
/// "v<lanes>uint<width>" spelling before their names are printed.
2049static LogicalResult printOperation(CppEmitter &emitter, aievec::CmpOp cmpOp) {
// The first check is on the emitter's aie2 flag being unset.
2050 if (!emitter.aie2())
2054 Value lhs = cmpOp.getLhs();
2055 Value rhs = cmpOp.getRhs();
// Guard: both operands must already be known to the emitter.
2057 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs))
2061 if (failed(emitter.emitAssignPrefix(*cmpOp)))
2064 raw_indented_ostream &
os = emitter.ostream();
// Dispatch on the predicate; signed and unsigned variants share a branch.
2066 StringRef pred = cmpOp.getPred();
2069 else if (pred ==
"ne")
2071 else if (pred ==
"slt" || pred ==
"ult")
2073 else if (pred ==
"sle" || pred ==
"ule")
2075 else if (pred ==
"sgt" || pred ==
"ugt")
2077 else if (pred ==
"sge" || pred ==
"uge")
2083 auto vType = llvm::cast<VectorType>(lhs.getType());
// Unsigned predicates on integer element types take the branch that prefixes
// each operand with an unsigned vector type name.
// NOTE(review): `lanes` and `width` are declared on lines not visible here.
2085 if (Type eltType = vType.getElementType();
2086 llvm::isa<IntegerType>(eltType) &&
2087 (pred ==
"ult" || pred ==
"ule" || pred ==
"ugt" || pred ==
"uge")) {
2090 os <<
"v" << std::to_string(lanes) <<
"uint" << std::to_string(width);
2092 os << emitter.getOrCreateName(lhs);
2094 os <<
"v" << std::to_string(lanes) <<
"uint" << std::to_string(width);
2096 os << emitter.getOrCreateName(rhs);
// Otherwise the operand names are printed without that prefix.
2099 os << emitter.getOrCreateName(lhs);
2101 os << emitter.getOrCreateName(rhs);
/// Prints the C++ text for an aievec::SelOp.
/// Note the argument order in the output: rhs, then lhs, then the selector.
2109static LogicalResult printOperation(CppEmitter &emitter, aievec::SelOp selOp) {
// The first check is on the emitter's aie2 flag being unset.
2110 if (!emitter.aie2())
2114 Value lhs = selOp.getLhs();
2115 Value rhs = selOp.getRhs();
2116 Value sel = selOp.getSel();
// Guard: both operands and the selector must be known to the emitter.
2118 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs) ||
2119 !emitter.hasValueInScope(sel))
2123 if (failed(emitter.emitAssignPrefix(*selOp)))
2126 raw_indented_ostream &
os = emitter.ostream();
// rhs is printed before lhs.
2129 os << emitter.getOrCreateName(rhs);
2131 os << emitter.getOrCreateName(lhs);
2133 os << emitter.getOrCreateName(sel);
/// Prints the C++ text for an aievec::ExtElemOp as an "extract_elem" call on
/// the source vector name and the index name.
2140static LogicalResult printOperation(CppEmitter &emitter,
2141 aievec::ExtElemOp extElemOp) {
2142 Value source = extElemOp.getSource();
2143 Value index = extElemOp.getIndex();
2145 raw_indented_ostream &
os = emitter.ostream();
// Here the assignment prefix is emitted before the scope check on `source`.
2148 if (failed(emitter.emitAssignPrefix(*extElemOp)))
// Only `source` is scope-checked in the visible code; `index` is not.
2152 if (!emitter.hasValueInScope(source))
2155 os <<
"extract_elem";
2158 os << emitter.getOrCreateName(source);
2160 os << emitter.getOrCreateName(index);
/// Prints the C++ text for a vector::TransferWriteOp.
/// The destination is the base memref name plus an optional linearized
/// access expression; the stored value is the vector's name.
2167static LogicalResult printOperation(CppEmitter &emitter,
2168 vector::TransferWriteOp writeOp) {
2169 Value source = writeOp.getBase();
2170 Value vector = writeOp.getVector();
// Guard: both the base memref and the vector must be known to the emitter.
2174 if (!emitter.hasValueInScope(source) || !emitter.hasValueInScope(vector))
// Collapse the write indices into a single access expression in `access`.
// NOTE(review): `access` is declared on a line not visible in this excerpt.
2179 auto indices = writeOp.getIndices();
2180 if (failed(createLinearizedAccess(emitter, source, indices, access)))
2183 raw_indented_ostream &
os = emitter.ostream();
// The vector's type is emitted as part of the destination expression.
2186 if (failed(emitter.emitType(writeOp->getLoc(), vector.getType())))
2190 os << emitter.getOrCreateName(source);
// An empty access expression means no offset is appended to the base name.
2191 if (!access.empty())
2192 os <<
" + " << access;
2195 os << emitter.getOrCreateName(vector);
/// Prints the C++ text for a memref::StoreOp, using the memref's element
/// type, the memref name, and the stored value's name.
2201static LogicalResult printOperation(CppEmitter &emitter,
2202 memref::StoreOp storeOp) {
2203 Value
value = storeOp.getValue();
2204 Value memref = storeOp.getMemref();
// Guard: both the stored value and the memref must be known to the emitter.
2208 if (!emitter.hasValueInScope(value) || !emitter.hasValueInScope(memref))
2211 raw_indented_ostream &
os = emitter.ostream();
// Emit the memref's element type (not the memref type itself).
2214 if (failed(emitter.emitType(
2216 cast<MemRefType>(memref.getType()).getElementType())))
2219 os << emitter.getOrCreateName(memref);
2221 os << emitter.getOrCreateName(value);
/// Prints an operation as a plain forward of its source value: the assignment
/// prefix followed by the source's name. `OpTy` must provide `getSrc()`.
2227template <
typename OpTy>
2228static LogicalResult printValueForwardOperation(CppEmitter &emitter, OpTy op) {
2229 Value source = op.getSrc();
// Guard: the forwarded value must already be known to the emitter.
2233 if (!emitter.hasValueInScope(source))
2236 if (failed(emitter.emitAssignPrefix(*op)))
2239 raw_indented_ostream &
os = emitter.ostream();
2240 os << emitter.getOrCreateName(source);
/// Prints a memref::ExpandShapeOp by delegating to the generic
/// value-forwarding printer.
2246static LogicalResult printOperation(CppEmitter &emitter,
2247 memref::ExpandShapeOp expandShapeOp) {
2248 return printValueForwardOperation<memref::ExpandShapeOp>(emitter,
/// Prints a memref::CollapseShapeOp by delegating to the generic
/// value-forwarding printer.
2253static LogicalResult printOperation(CppEmitter &emitter,
2254 memref::CollapseShapeOp collapseShapeOp) {
2255 return printValueForwardOperation<memref::CollapseShapeOp>(emitter,
/// Shared printer for constant operations: emits the constant `value` as the
/// initializer or assignment of the operation's first result.
/// Two paths exist, chosen by whether variables are declared at the top of
/// the function; each special-cases an emitc::OpaqueAttr with an empty value.
2259static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
2261 OpResult result = operation->getResult(0);
// Variables declared at top: the result already exists, so only an
// assignment followed by the attribute is emitted.
2265 if (emitter.shouldDeclareVariablesAtTop()) {
// Empty opaque attribute: handled by a branch whose body is not shown here.
2267 if (
auto oAttr = llvm::dyn_cast<emitc::OpaqueAttr>(value))
2268 if (oAttr.getValue().empty())
2271 if (failed(emitter.emitVariableAssignment(result)))
2273 return emitter.emitAttribute(operation->getLoc(), value);
// Otherwise an empty opaque attribute results in a bare variable declaration.
2277 if (
auto oAttr = llvm::dyn_cast<emitc::OpaqueAttr>(value))
2278 if (oAttr.getValue().empty())
2280 return emitter.emitVariableDeclaration(result,
// General case: assignment prefix followed by the attribute.
2284 if (failed(emitter.emitAssignPrefix(*operation)))
2286 return emitter.emitAttribute(operation->getLoc(), value);
2289static LogicalResult printOperation(CppEmitter &emitter,
2290 emitc::ConstantOp constantOp) {
2291 Operation *operation = constantOp.getOperation();
2292 Attribute
value = constantOp.getValue();
2293 return printConstantOp(emitter, operation, value);
2296static LogicalResult printOperation(CppEmitter &emitter,
2297 arith::ConstantOp constantOp) {
2298 Operation *operation = constantOp.getOperation();
2299 Attribute
value = constantOp.getValue();
2300 return printConstantOp(emitter, operation, value);
/// Prints a cf::BranchOp: first assigns each branch operand to the matching
/// block argument of the successor, then prints the successor's label.
2303static LogicalResult printOperation(CppEmitter &emitter,
2304 cf::BranchOp branchOp) {
2305 raw_ostream &
os = emitter.ostream();
2306 Block &successor = *branchOp.getSuccessor();
// Pair each passed operand with the successor argument it feeds and emit
// "<argument> = <operand>;".
2308 for (
auto pair : zip(branchOp.getOperands(), successor.getArguments())) {
2309 Value &operand = std::get<0>(pair);
2310 BlockArgument &argument = std::get<1>(pair);
2311 os << emitter.getOrCreateName(argument) <<
" = "
2312 << emitter.getOrCreateName(operand) <<
";\n";
// The successor must already have a label registered with the emitter.
2316 if (!emitter.hasBlockLabel(successor))
2317 return branchOp.emitOpError(
"unable to find label for successor block");
2318 os << emitter.getOrCreateName(successor);
/// Prints a cf::CondBranchOp as an if/else statement.
/// Each arm assigns the corresponding operands to the destination block's
/// arguments and then prints that block's label.
2322static LogicalResult printOperation(CppEmitter &emitter,
2323 cf::CondBranchOp condBranchOp) {
2324 raw_indented_ostream &
os = emitter.ostream();
2325 Block &trueSuccessor = *condBranchOp.getTrueDest();
2326 Block &falseSuccessor = *condBranchOp.getFalseDest();
// The condition is printed by name inside "if (".
2328 os <<
"if (" << emitter.getOrCreateName(condBranchOp.getCondition())
// True arm: "<argument> = <operand>;" for each true-operand/argument pair.
2335 zip(condBranchOp.getTrueOperands(), trueSuccessor.getArguments())) {
2336 Value &operand = std::get<0>(pair);
2337 BlockArgument &argument = std::get<1>(pair);
2338 os << emitter.getOrCreateName(argument) <<
" = "
2339 << emitter.getOrCreateName(operand) <<
";\n";
// The true destination must already have a label.
2343 if (!emitter.hasBlockLabel(trueSuccessor))
2344 return condBranchOp.emitOpError(
"unable to find label for successor block");
2345 os << emitter.getOrCreateName(trueSuccessor) <<
";\n";
2346 os.unindent() <<
"} else {\n";
// False arm: same pattern with the false operands and destination.
2350 zip(condBranchOp.getFalseOperands(), falseSuccessor.getArguments())) {
2351 Value &operand = std::get<0>(pair);
2352 BlockArgument &argument = std::get<1>(pair);
2353 os << emitter.getOrCreateName(argument) <<
" = "
2354 << emitter.getOrCreateName(operand) <<
";\n";
// The false destination must already have a label.
2358 if (!emitter.hasBlockLabel(falseSuccessor))
2359 return condBranchOp.emitOpError()
2360 <<
"unable to find label for successor block";
2361 os << emitter.getOrCreateName(falseSuccessor) <<
";\n";
2362 os.unindent() <<
"}";
/// Prints a func::CallOp: assignment prefix, the callee name, an opening
/// parenthesis, and the operand list.
2367static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
2368 if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
2371 raw_ostream &
os = emitter.ostream();
2372 os << callOp.getCallee() <<
"(";
// Operands are printed by the emitter.
2373 if (failed(emitter.emitOperands(*callOp.getOperation())))
/// Prints an emitc::CallOpaqueOp.
/// A fixed list of callee names gets the assignment prefix with isAcc=false;
/// every other callee gets isAcc=true. Arguments come either from an
/// attribute list (index attributes refer to operands by position) or from
/// the operation's operands.
2380static LogicalResult printOperation(CppEmitter &emitter,
2381 emitc::CallOpaqueOp callOp) {
2382 raw_ostream &
os = emitter.ostream();
2383 Operation &op = *callOp.getOperation();
// Callees in this list use the non-accumulator assignment prefix.
2384 if (callOp.getCallee() ==
"getTanhBf16" ||
2385 callOp.getCallee() ==
"getSqrtBf16" ||
2386 callOp.getCallee() ==
"getRsqrtBf16" ||
2387 callOp.getCallee() ==
"getErfBf16" || callOp.getCallee() ==
"getAbs" ||
2388 callOp.getCallee() ==
"getSigmoidBf16" ||
2389 callOp.getCallee() ==
"getCeilBf16" ||
2390 callOp.getCallee() ==
"getFloorBf16") {
2391 if (failed(emitter.emitAssignPrefix(op,
false)))
2393 }
// All other callees use the accumulator assignment prefix.
else if (failed(emitter.emitAssignPrefix(op,
true)))
2396 os << callOp.getCallee();
// Emits one argument attribute. An integer attribute of index type is an
// operand position: it is range-checked, scope-checked, and replaced by the
// operand's name. Other attributes are printed through emitAttribute.
2398 auto emitArgs = [&](Attribute attr) -> LogicalResult {
2400 if (
auto t = llvm::dyn_cast<IntegerAttr>(attr))
2401 if (t.getType().isIndex()) {
2402 int64_t idx = t.getInt();
2403 if (idx < 0 || idx >= op.getNumOperands())
2404 return op.emitOpError(
"invalid operand index");
2405 if (!emitter.hasValueInScope(op.getOperand(idx)))
2406 return op.emitOpError(
"operand ")
2407 << idx <<
"'s value not defined in scope";
2408 os << emitter.getOrCreateName(op.getOperand(idx));
2411 if (failed(emitter.emitAttribute(op.getLoc(), attr)))
// Template arguments, when present, are handled in this branch.
2417 if (callOp.getTemplateArgs()) {
// The alternative visible here prints the op's operands via the emitter.
2427 LogicalResult emittedArgs =
2430 : emitter.emitOperands(op);
2431 if (failed(emittedArgs))
/// Prints an emitc::ApplyOp: assignment prefix, then the applicable operator
/// immediately followed by the operand's name.
2438static LogicalResult printOperation(CppEmitter &emitter,
2439 emitc::ApplyOp applyOp) {
2440 raw_ostream &
os = emitter.ostream();
2442 if (Operation &op = *applyOp.getOperation();
2443 failed(emitter.emitAssignPrefix(op)))
2445 os << applyOp.getApplicableOperator();
2446 os << emitter.getOrCreateName(applyOp.getOperand());
/// Prints an emitc::IncludeOp. Standard includes are wrapped in angle
/// brackets; all others are wrapped in double quotes.
2451static LogicalResult printOperation(CppEmitter &emitter,
2452 emitc::IncludeOp includeOp) {
2453 raw_ostream &
os = emitter.ostream();
2456 if (includeOp.getIsStandardInclude())
2457 os <<
"<" << includeOp.getInclude() <<
">";
// Non-standard include: quoted form.
2459 os <<
"\"" << includeOp.getInclude() <<
"\"";
/// Prints an scf::ForOp as a C++ for loop.
/// Iteration arguments become variables initialized from the init operands,
/// updated from the yield operands at the end of each iteration, and copied
/// into the op's results after the loop. Chess pipelining and loop-range
/// annotations are emitted as well.
2464static LogicalResult printOperation(CppEmitter &emitter, scf::ForOp forOp) {
2465 raw_indented_ostream &
os = emitter.ostream();
2467 OperandRange operands = forOp.getInitArgs();
2468 Block::BlockArgListType iterArgs = forOp.getRegionIterArgs();
2469 Operation::result_range results = forOp.getResults();
// Unless variables are declared at the top, declare each loop result here.
2471 if (!emitter.shouldDeclareVariablesAtTop())
2472 for (OpResult result : results)
2473 if (failed(emitter.emitVariableDeclaration(result,
// Declare each iteration argument, initialized from its init operand:
// "<type> <iterArg> = <operand>;".
2477 for (
auto pair : zip(iterArgs, operands)) {
2478 if (failed(emitter.emitType(forOp.getLoc(), std::get<0>(pair).getType())))
2480 os <<
" " << emitter.getOrCreateName(std::get<0>(pair)) <<
" = ";
2481 os << emitter.getOrCreateName(std::get<1>(pair)) <<
";";
// Loop header: the induction variable's type and name, then the lower bound,
// upper bound, and step, each printed by name.
2487 emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
2491 os << emitter.getOrCreateName(forOp.getInductionVar());
2493 os << emitter.getOrCreateName(forOp.getLowerBound());
2495 os << emitter.getOrCreateName(forOp.getInductionVar());
2497 os << emitter.getOrCreateName(forOp.getUpperBound());
2499 os << emitter.getOrCreateName(forOp.getInductionVar());
2501 os << emitter.getOrCreateName(forOp.getStep());
2503 os <<
"chess_prepare_for_pipelining\n";
// With a constant trip count, emit chess_loop_range(lb[, ub]). With a
// constant positive step the bounds are floor/ceil of tripCount / step;
// otherwise they fall back to 1 and 0 and the upper bound is not printed.
2506 if (
auto [constantLoopBound, tripCount] = getTripCount(forOp);
2507 constantLoopBound) {
2508 auto [constantStep, step] = getStep(forOp);
2510 constantStep && step > 0 ? llvm::divideFloorSigned(tripCount, step) : 1;
2512 constantStep && step > 0 ? llvm::divideCeilSigned(tripCount, step) : 0;
2513 os <<
"chess_loop_range(";
2514 os << std::to_string(lb);
2516 if (constantStep && step > 0)
2517 os << std::to_string(ub);
2523 Region &forRegion = forOp.getRegion();
2524 auto regionOps = forRegion.getOps();
// Emit every op except the last one (the loop stops when the next iterator
// is the end). if/for/cond_br ops are emitted without a trailing semicolon.
2530 for (
auto it = regionOps.begin(); std::next(it) != regionOps.end(); ++it) {
2531 if (
bool trailingSemicolon =
2532 !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(*it);
2533 failed(emitter.emitOperation(*it, trailingSemicolon)))
// The terminator of the first block supplies the next-iteration values.
2537 Operation *yieldOp = forRegion.getBlocks().front().getTerminator();
// Copy each yielded value back into its iteration argument.
2539 for (
auto pair : zip(iterArgs, yieldOp->getOperands())) {
2540 BlockArgument iterArg = std::get<0>(pair);
2541 Value operand = std::get<1>(pair);
2542 os << emitter.getOrCreateName(iterArg) <<
" = "
2543 << emitter.getOrCreateName(operand) <<
";\n";
2546 os.unindent() <<
"}";
// After the loop, copy the iteration arguments into the op results.
2549 for (
auto pair : zip(results, iterArgs)) {
2550 OpResult result = std::get<0>(pair);
2551 BlockArgument iterArg = std::get<1>(pair);
2553 << emitter.getOrCreateName(result) <<
" = "
2554 << emitter.getOrCreateName(iterArg) <<
";";
/// Prints an scf::IfOp: optional result declarations, the condition operand,
/// the ops of the then-region, and the ops of the else-region when it is
/// not empty.
2560static LogicalResult printOperation(CppEmitter &emitter, scf::IfOp ifOp) {
2561 raw_indented_ostream &
os = emitter.ostream();
// Unless variables are declared at the top, declare each result here.
2563 if (!emitter.shouldDeclareVariablesAtTop())
2564 for (OpResult result : ifOp.getResults())
2565 if (failed(emitter.emitVariableDeclaration(result,
// The op's operands are printed by the emitter.
2570 if (failed(emitter.emitOperands(*ifOp.getOperation())))
2575 Region &thenRegion = ifOp.getThenRegion();
// Every then-region op is emitted with a trailing semicolon.
2578 for (Operation &op : thenRegion.getOps())
2579 if (failed(emitter.emitOperation(op, true)))
2582 os.unindent() <<
"}";
// The else arm is printed only when the else-region has content.
2584 if (Region &elseRegion = ifOp.getElseRegion(); !elseRegion.empty()) {
2590 for (Operation &op : elseRegion.getOps())
2591 if (failed(emitter.emitOperation(op, true)))
2594 os.unindent() <<
"}";
/// Prints an scf::YieldOp as assignments of each yielded operand to the
/// matching result of the parent operation, separated by ";\n".
2600static LogicalResult printOperation(CppEmitter &emitter, scf::YieldOp yieldOp) {
2601 raw_ostream &
os = emitter.ostream();
2602 Operation &parentOp = *yieldOp.getOperation()->getParentOp();
// The yield must supply exactly one operand per parent result.
2604 if (yieldOp.getNumOperands() != parentOp.getNumResults())
2605 return yieldOp.emitError(
"number of operands does not to match the number "
2606 "of the parent op's results");
// For each (result, operand) pair emit "<result> = <operand>".
2609 llvm::zip(parentOp.getResults(), yieldOp.getOperands()),
2610 [&](
auto pair) -> LogicalResult {
2611 auto result = std::get<0>(pair);
2612 auto operand = std::get<1>(pair);
2613 os << emitter.getOrCreateName(result) <<
" = ";
// The left-hand side is written before the operand's scope check, so a
// failure here leaves a partial assignment in the output stream.
2615 if (!emitter.hasValueInScope(operand))
2616 return yieldOp.emitError(
"operand value not in scope");
2617 os << emitter.getOrCreateName(operand);
// Separator written between consecutive assignments.
2620 [&] { os <<
";\n"; })))
/// Prints a func::ReturnOp. The form depends on the operand count: a single
/// operand is printed by name, and the remaining visible case wraps the
/// operands in std::make_tuple(.
2626static LogicalResult printOperation(CppEmitter &emitter,
2627 func::ReturnOp returnOp) {
2628 raw_ostream &
os = emitter.ostream();
2630 switch (returnOp.getNumOperands()) {
// Single operand: the name is written first, and success is reported only
// if that operand is in scope.
2634 os <<
" " << emitter.getOrCreateName(returnOp.getOperand(0));
2635 return success(emitter.hasValueInScope(returnOp.getOperand(0)));
// Tuple form: operands (and attributes) are printed by the emitter.
2637 os <<
" std::make_tuple(";
2638 if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
/// Prints a ModuleOp by emitting each contained operation, without a trailing
/// semicolon, inside a fresh emitter scope.
2646static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
// Scope object tied to the emitter for the duration of this function.
2647 CppEmitter::Scope scope(emitter);
2649 for (Operation &op : moduleOp)
2650 if (failed(emitter.emitOperation(op, false)))
/// Prints an AIE::DeviceOp as an "aie.device(<device>) {" header, the
/// operations of its body region, and a closing brace.
/// Terminator operations in the body are special-cased (see below); all other
/// operations are emitted without a trailing semicolon.
2656static LogicalResult printOperation(CppEmitter &emitter,
2657 AIE::DeviceOp deviceOp) {
// Scope object tied to the emitter for the duration of this function.
2658 CppEmitter::Scope scope(emitter);
2659 raw_indented_ostream &
os = emitter.ostream();
2662 os <<
"aie.device(" << deviceOp.getDevice() <<
") {\n";
// Bind the body region by reference (the declarator had been corrupted into
// a registered-trademark character by an HTML-entity decode of "&reg").
2666 Region &region = deviceOp.getBodyRegion();
2667 for (Block &block : region.getBlocks()) {
2668 for (Operation &op : block.getOperations()) {
// Terminators take a separate branch whose body is not shown here.
2670 if (op.hasTrait<OpTrait::IsTerminator>())
2673 if (failed(emitter.emitOperation(op,
false)))
2678 os.unindent() <<
"}\n";
/// Prints a func::FuncOp: return types, name, parameter list (types only for
/// a declaration, types and names otherwise), optional up-front variable
/// declarations, block-argument declarations, and finally every block with
/// its operations.
2682static LogicalResult printOperation(CppEmitter &emitter,
2683 func::FuncOp functionOp) {
// Multi-block functions are only supported when variables are declared at
// the top of the function.
2685 if (!emitter.shouldDeclareVariablesAtTop() &&
2686 functionOp.getBlocks().size() > 1)
2687 return functionOp.emitOpError(
2688 "with multiple blocks needs variables declared at top");
// Scope object tied to the emitter for the duration of this function.
2690 CppEmitter::Scope scope(emitter);
2694 if (failed(parseMemRefDynamicDims(emitter, functionOp)))
2697 raw_indented_ostream &
os = emitter.ostream();
// Signature: result types, a space, then the function name.
2698 if (failed(emitter.emitTypes(functionOp.getLoc(),
2699 functionOp.getFunctionType().getResults())))
2701 os <<
" " << functionOp.getName();
// Declaration: parameters are printed from the argument types alone, with
// extra handling for each dynamic dimension of a memref argument.
2704 if (functionOp.isDeclaration()) {
2706 functionOp.getArgumentTypes(), os, [&](Type type) -> LogicalResult {
2707 if (failed(emitter.emitType(functionOp.getLoc(), type)))
2711 if (auto argType = dyn_cast<MemRefType>(type))
2712 for (unsigned dim = 0; dim < argType.getRank(); ++dim)
2713 if (argType.isDynamicDim(dim))
// Definition: each parameter is "<type> <name>", followed by its memref
// dimension parameters.
2723 functionOp.getArguments(), os,
2724 [&](BlockArgument arg) -> LogicalResult {
2725 if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
2727 os <<
" " << emitter.getOrCreateName(arg);
2730 if (failed(printMemRefDims(emitter, arg)))
// When declaring at the top, walk all nested ops in pre-order and declare a
// variable for each of their results.
2738 if (emitter.shouldDeclareVariablesAtTop()) {
2742 functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
2743 for (OpResult result : op->getResults()) {
2744 if (failed(emitter.emitVariableDeclaration(
2747 op->emitError(
"unable to declare result variable for op")};
2749 return WalkResult::advance();
2751 if (result.wasInterrupted())
2755 Region::BlockListType &blocks = functionOp.getBlocks();
// Create a label name for every block up front.
2757 for (Block &block : blocks)
2758 emitter.getOrCreateName(block);
// Declare the arguments of every block except the first one.
2761 for (
auto it = std::next(blocks.begin()); it != blocks.end(); ++it) {
2763 for (BlockArgument &arg : block.getArguments()) {
// An argument that already has a name at this point is reported as an error.
2764 if (emitter.hasValueInScope(arg))
2765 return functionOp.emitOpError(
" block argument #")
2766 << arg.getArgNumber() <<
" is out of scope";
2768 emitter.emitType(block.getParentOp()->getLoc(), arg.getType())))
2770 os <<
" " << emitter.getOrCreateName(arg) <<
";\n";
// Emit each block; labels are only printed when there is more than one block.
2774 for (Block &block : blocks) {
2776 if (blocks.size() > 1)
2777 if (failed(emitter.emitLabel(block)))
2779 for (Operation &op : block.getOperations()) {
// if/for/cond_br ops are emitted without a trailing semicolon.
2784 if (
bool trailingSemicolon =
2785 !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op);
2786 failed(emitter.emitOperation(
2787 op, trailingSemicolon)))
2791 os.unindent() <<
"}\n";
/// Prints an aievec::MatMulOp as
/// "<acc> = mac_<l0>x<l1>_<r0>x<r1>(<lhs>, <rhs>, <acc>)", where the numbers
/// are the first two dimensions of the lhs and rhs vector shapes. The op
/// result is registered under the accumulator's name.
2796static LogicalResult printOperation(CppEmitter &emitter,
2797 aievec::MatMulOp matmulOp) {
2798 auto lhs = matmulOp.getLhs();
2799 auto rhs = matmulOp.getRhs();
2800 auto acc = matmulOp.getAcc();
// Guard: both operands and the accumulator must be known to the emitter.
2803 if (!emitter.hasValueInScope(lhs) || !emitter.hasValueInScope(rhs) ||
2804 !emitter.hasValueInScope(acc))
// The operand names used in the call come from printConversionTo512bit.
2807 auto lhsName = printConversionTo512bit(emitter, lhs);
2808 auto rhsName = printConversionTo512bit(emitter, rhs);
2810 raw_indented_ostream &
os = emitter.ostream();
2812 StringRef accName = emitter.getOrCreateName(acc);
// Indexing [0] and [1] requires both vector types to have at least rank 2.
2814 auto lhsShape = cast<VectorType>(lhs.getType()).getShape();
2815 auto rhsShape = cast<VectorType>(rhs.getType()).getShape();
2816 os << accName <<
" = mac_" << lhsShape[0] <<
"x" << lhsShape[1] <<
"_"
2817 << rhsShape[0] <<
"x" << rhsShape[1] <<
"(";
2818 os << lhsName <<
", " << rhsName <<
", " << accName <<
")";
// The result gets the accumulator's existing name rather than a fresh one.
2821 emitter.setName(matmulOp.getResult(), accName);
2826CppEmitter::CppEmitter(raw_ostream &os,
bool declareVariablesAtTop,
bool aie2)
2827 :
os(
os), declareVariablesAtTop(declareVariablesAtTop), aie2_(aie2) {
2828 valueInScopeCount.push(0);
2829 labelInScopeCount.push(0);
2833StringRef CppEmitter::getOrCreateName(Value val, std::string prefix) {
2834 if (!valueMapper.count(val))
2835 valueMapper.insert(val,
2836 formatv(
"{0}{1}", prefix, ++valueInScopeCount.top()));
2837 return *valueMapper.begin(val);
2841void CppEmitter::setName(Value val, StringRef name) {
2842 valueMapper.insert(val, name.str());
/// Builds a fresh name from `prefix` followed by the pre-incremented counter
/// on top of the value-counter stack. No value is mapped to the name here.
2846std::string CppEmitter::getNewName(std::string prefix) {
2847 std::string ret = formatv(
"{0}{1}", prefix, ++valueInScopeCount.top());
2853void CppEmitter::setMemRefDimParam(Value memref,
unsigned index,
2854 const std::string ¶meter) {
2855 auto p = std::make_pair(memref, index);
2856 assert(!paramIndexMapper.count(p) &&
"memref dimension already set");
2857 paramIndexMapper[p] = parameter;
2861StringRef CppEmitter::getMemRefDimParam(Value memref,
unsigned index) {
2862 auto p = std::make_pair(memref, index);
2863 assert(paramIndexMapper.count(p) &&
"memref dimension not found");
2864 return paramIndexMapper[p];
/// Reports whether a parameter name has been recorded for dimension `index`
/// of `memref`.
2869bool CppEmitter::isMemRefDimParam(Value memref,
unsigned index) {
// A value that is not a memref, or whose dimension `index` is static, takes
// the branch that prints a message to stdout.
2871 auto type = llvm::dyn_cast<MemRefType>(memref.getType());
2872 if (!(type && type.isDynamicDim(index))) {
2873 printf(
"the dimension size at index is not dynamic\n");
// Otherwise look the (memref, dimension) pair up in the parameter map.
2879 auto p = std::make_pair(memref, index);
2880 return paramIndexMapper.count(p);
2884StringRef CppEmitter::getOrCreateName(Block &block, std::string prefix) {
2885 if (!blockMapper.count(&block))
2886 blockMapper.insert(&block,
2887 formatv(
"{0}{1}", prefix, ++labelInScopeCount.top()));
2888 return *blockMapper.begin(&block);
/// Decides whether an integer type with the given signedness semantics is
/// emitted as an unsigned C++ type. Signless and Signed share one case,
/// Unsigned has its own; any other value is a fatal error.
2891bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
2893 case IntegerType::Signless:
2894 case IntegerType::Signed:
2896 case IntegerType::Unsigned:
// Reached only for a value outside the three enumerators above.
2899 llvm::report_fatal_error(
"Unexpected IntegerType::SignednessSemantics");
2902bool CppEmitter::hasValueInScope(Value val) {
return valueMapper.count(val); }
2904bool CppEmitter::hasBlockLabel(Block &block) {
2905 return blockMapper.count(&block);
2910template <
typename ElTy>
2911static std::string getSplatValueOfIntDense(DenseIntElementsAttr dense) {
2912 ElTy splatVal = dense.getSplatValue<ElTy>();
2913 return std::to_string(splatVal);
/// Returns the splat value of a floating-point dense attribute as a string.
/// Infinities are replaced by finite stand-ins, and zero is special-cased.
/// \param isBFloat selects between the two stand-in constants used for each
///        infinity. NOTE(review): the selecting conditions are on lines not
///        visible here; presumably 0x1.FEp+127f is the bfloat16 variant.
2917static std::string getSplatValueOfFloatDense(DenseFPElementsAttr dense,
2918 bool isBFloat =
false) {
2919 auto apFloat = dense.getSplatValue<APFloat>();
2920 float splatVal = apFloat.convertToFloat();
// Default: the decimal rendering of the value converted to float.
2921 std::string firstValue = std::to_string(splatVal);
// +inf: replaced by 0x1.FEp+127f or the largest finite float.
2923 if (apFloat.isPosInfinity())
2928 firstValue = std::to_string(0x1.FEp+127f);
2930 firstValue = std::to_string(std::numeric_limits<float>::max());
// -inf: replaced by -0x1.FEp+127f or the lowest finite float.
2931 else if (apFloat.isNegInfinity())
2933 firstValue = std::to_string(-0x1.FEp+127f);
2935 firstValue = std::to_string(std::numeric_limits<float>::lowest());
// Zero takes its own branch.
2936 else if (!apFloat.isNonZero())
/// Emits the C++ text for `attr`, reporting errors at `loc`.
/// Handled kinds, in order: FloatAttr, DenseFPElementsAttr, IntegerAttr,
/// DenseIntElementsAttr (tensor and vector typed), emitc::OpaqueAttr,
/// SymbolRefAttr, and TypeAttr. Any other attribute produces an error.
2942LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
// Helper that prints one integer; 1-bit values are treated as booleans and
// everything else is printed in base 10, signed unless `isUnsigned`.
2943 auto printInt = [&](
const APInt &val,
bool isUnsigned) {
2944 if (val.getBitWidth() == 1)
2945 if (val.getBoolValue())
2950 SmallString<128> strValue;
2951 val.toString(strValue, 10, !isUnsigned,
false);
// Helper that prints one float; finite values are formatted and then
// dispatched on single vs. double semantics, while NaN and the two
// infinities have dedicated branches.
2956 auto printFloat = [&](
const APFloat &val) {
2957 if (val.isFinite()) {
2958 SmallString<128> strValue;
2960 val.toString(strValue, 0, 0,
false);
2961 switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
2962 case llvm::APFloatBase::S_IEEEsingle:
2965 case llvm::APFloatBase::S_IEEEdouble:
2972 }
else if (val.isNaN())
2974 else if (val.isInfinity()) {
2975 if (val.isNegative())
// Scalar floating-point attribute.
2982 if (
auto fAttr = llvm::dyn_cast<FloatAttr>(attr)) {
2983 printFloat(fAttr.getValue());
// Dense floating-point attribute. On AIE2, a splat of a float vector is
// emitted as a broadcast call rather than an element list.
2987 if (
auto dense = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2988 if (aie2() && dense.isSplat()) {
2989 if (
auto vType = llvm::dyn_cast<VectorType>(dense.getType()))
2990 if (
auto fType = llvm::dyn_cast<FloatType>(vType.getElementType())) {
2991 unsigned width = fType.getWidth();
2992 std::string splatValue;
// 16-bit floats request the bfloat handling of the splat helper.
2994 splatValue = getSplatValueOfFloatDense(dense);
2995 else if (width == 16)
2996 splatValue = getSplatValueOfFloatDense(dense,
true);
// A zero splat uses broadcast_zero_<type>; otherwise broadcast_to_<vtype>.
2999 if (splatValue ==
"0") {
3000 os <<
"broadcast_zero_";
3001 if (failed(emitType(loc, fType)))
3005 os <<
"broadcast_to_";
3006 if (failed(emitType(loc, vType)))
3009 if (failed(emitType(loc, fType)))
// Variant that extracts a v16bfloat16 from a wider broadcast.
3016 os <<
"extract_v16bfloat16(";
3017 if (splatValue ==
"0")
3018 os <<
"broadcast_zero_bfloat16()";
3020 os <<
"broadcast_to_v32bfloat16";
3022 if (failed(emitType(loc, fType)))
// Fallback: comma-separated list of the individual float elements.
3034 interleaveComma(dense, os, [&](
const APFloat &val) { printFloat(val); });
// Scalar integer attribute: integer types honour their signedness, index
// types are printed as signed.
3041 if (
auto iAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
3042 if (
auto iType = llvm::dyn_cast<IntegerType>(iAttr.getType())) {
3043 printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
3046 if (llvm::dyn_cast<IndexType>(iAttr.getType())) {
3047 printInt(iAttr.getValue(),
false);
// Dense integer attribute.
3052 if (
auto dense = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
// Tensor-typed: comma-separated element list.
3053 if (
auto tType = llvm::dyn_cast<TensorType>(dense.getType())) {
3054 if (
auto iType = llvm::dyn_cast<IntegerType>(tType.getElementType())) {
3056 interleaveComma(dense, os, [&](
const APInt &val) {
3057 printInt(val, shouldMapToUnsigned(iType.getSignedness()));
3062 if (llvm::dyn_cast<IndexType>(tType.getElementType())) {
3064 interleaveComma(dense, os,
3065 [&](
const APInt &val) { printInt(val,
false); });
// Vector-typed: all-zero vectors and AIE2 splats use broadcast calls;
// everything else falls back to an element list.
3071 if (
auto vType = llvm::dyn_cast<VectorType>(dense.getType())) {
3072 if (
auto iType = llvm::dyn_cast<IntegerType>(vType.getElementType())) {
3073 unsigned width = iType.getWidth();
3074 if (llvm::all_of(dense, [](
const APInt &val) {
return val == 0; })) {
3077 os <<
"concat(broadcast_zero_s" << width <<
"(), broadcast_zero_s"
3081 os <<
"broadcast_zero_s";
3085 if (failed(emitType(loc, vType)))
// The splat value is read with a C++ type matching the element width.
3092 if (aie2() && dense.isSplat()) {
3093 std::string splatValue;
3095 splatValue = getSplatValueOfIntDense<int32_t>(dense);
3096 else if (width == 16)
3097 splatValue = getSplatValueOfIntDense<int16_t>(dense);
3098 else if (width == 8)
3099 splatValue = getSplatValueOfIntDense<int8_t>(dense);
3100 os <<
"broadcast_to_";
3101 if (failed(emitType(loc, vType)))
3104 if (failed(emitType(loc, iType)))
3112 interleaveComma(dense, os, [&](
const APInt &val) {
3113 printInt(val, shouldMapToUnsigned(iType.getSignedness()));
3119 if (llvm::dyn_cast<IndexType>(vType.getElementType())) {
3121 interleaveComma(dense, os,
3122 [&](
const APInt &val) { printInt(val,
false); });
// Opaque attribute: its string value is written verbatim.
3130 if (
auto oAttr = llvm::dyn_cast<emitc::OpaqueAttr>(attr)) {
3131 os << oAttr.getValue();
// Symbol reference: only the root reference is printed, and more than one
// nested reference is rejected.
3136 if (
auto sAttr = llvm::dyn_cast<SymbolRefAttr>(attr)) {
3137 if (sAttr.getNestedReferences().size() > 1)
3138 return emitError(loc,
"attribute has more than 1 nested reference");
3139 os << sAttr.getRootReference().getValue();
// Type attribute: delegate to emitType.
3144 if (
auto type = llvm::dyn_cast<TypeAttr>(attr))
3145 return emitType(loc, type.getValue());
// Anything else is unsupported.
3147 return emitError(loc,
"cannot emit attribute of type ") << attr;
3150LogicalResult CppEmitter::emitOperands(Operation &op) {
3151 auto emitOperandName = [&](Value result) -> LogicalResult {
3152 if (!hasValueInScope(result))
3153 return op.emitOpError() <<
"operand value not in scope";
3154 os << getOrCreateName(result);
3161CppEmitter::emitOperandsAndAttributes(Operation &op,
3162 ArrayRef<StringRef> exclude) {
3163 if (failed(emitOperands(op)))
3166 if (op.getNumOperands() > 0)
3167 for (NamedAttribute attr : op.getAttrs())
3168 if (!is_contained(exclude, attr.getName().strref())) {
3173 auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
3174 if (is_contained(exclude, attr.getName().strref()))
3176 os <<
"/* " << attr.getName().getValue() <<
" */";
3177 if (failed(emitAttribute(op.getLoc(), attr.getValue())))
3185LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
3186 if (!hasValueInScope(result)) {
3187 return result.getDefiningOp()->emitOpError(
3188 "result variable for the operation has not been declared");
3190 os << getOrCreateName(result) <<
" = ";
3195LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
3196 bool trailingSemicolon,
3198 if (hasValueInScope(result))
3199 return result.getDefiningOp()->emitError(
3200 "result variable for the operation already declared");
3202 emitType(result.getOwner()->getLoc(), result.getType(),
true, isAcc)))
3204 os <<
" " << getOrCreateName(result);
3205 if (trailingSemicolon)
3211LogicalResult CppEmitter::emitAssignPrefix(Operation &op,
bool isAcc) {
3212 switch (op.getNumResults()) {
3216 OpResult result = op.getResult(0);
3217 if (shouldDeclareVariablesAtTop()) {
3218 if (failed(emitVariableAssignment(result)))
3221 if (failed(emitVariableDeclaration(result,
false,
3229 if (!shouldDeclareVariablesAtTop())
3230 for (OpResult result : op.getResults())
3231 if (failed(emitVariableDeclaration(result, true)))
3235 interleaveComma(op.getResults(), os,
3236 [&](Value result) { os << getOrCreateName(result); });
3242LogicalResult CppEmitter::emitLabel(Block &block) {
3243 if (!hasBlockLabel(block))
3244 return block.getParentOp()->emitError(
"label for block not found");
3247 os.getOStream() << getOrCreateName(block) <<
":\n";
3251LogicalResult CppEmitter::emitOperation(Operation &op,
bool trailingSemicolon) {
3254 if (skippedOp(&op, *
this))
3257 LogicalResult status =
3258 TypeSwitch<Operation *, LogicalResult>(&op)
3260 .Case<emitc::ApplyOp, emitc::CallOpaqueOp, emitc::ConstantOp>(
3261 [&](
auto op) {
return printOperation(*
this, op); })
3262 .Case<emitc::IncludeOp>([&](
auto op) {
3263 if (StringRef name = op.getInclude(); !includeNames.count(name)) {
3264 includeNames.insert(name);
3265 return printOperation(*
this, op);
3270 .Case<scf::ForOp, scf::IfOp, scf::YieldOp>(
3271 [&](
auto op) {
return printOperation(*
this, op); })
3273 .Case<cf::BranchOp, func::CallOp, cf::CondBranchOp, func::FuncOp,
3274 ModuleOp, func::ReturnOp>(
3275 [&](
auto op) {
return printOperation(*
this, op); })
3277 .Case<arith::ConstantOp>(
3278 [&](
auto op) {
return printOperation(*
this, op); })
3281 .Case<arith::AddIOp>(
3282 [&](
auto op) {
return printOperation<arith::AddIOp>(*
this, op); })
3283 .Case<arith::AddFOp>(
3284 [&](
auto op) {
return printOperation<arith::AddFOp>(*
this, op); })
3285 .Case<arith::MulIOp>(
3286 [&](
auto op) {
return printOperation<arith::MulIOp>(*
this, op); })
3287 .Case<arith::MulFOp>(
3288 [&](
auto op) {
return printOperation<arith::MulFOp>(*
this, op); })
3289 .Case<arith::SubIOp>(
3290 [&](
auto op) {
return printOperation<arith::SubIOp>(*
this, op); })
3291 .Case<arith::SubFOp>(
3292 [&](
auto op) {
return printOperation<arith::SubFOp>(*
this, op); })
3293 .Case<arith::DivSIOp>([&](
auto op) {
3294 return printOperation<arith::DivSIOp>(*
this, op);
3296 .Case<arith::DivUIOp>([&](
auto op) {
3297 return printOperation<arith::DivUIOp>(*
this, op);
3299 .Case<arith::DivFOp>(
3300 [&](
auto op) {
return printOperation<arith::DivFOp>(*
this, op); })
3301 .Case<arith::RemSIOp>([&](
auto op) {
3302 return printOperation<arith::RemSIOp>(*
this, op);
3304 .Case<arith::CmpIOp>(
3305 [&](
auto op) {
return printOperation<arith::CmpIOp>(*
this, op); })
3306 .Case<arith::SelectOp>(
3307 [&](
auto op) {
return printOperation(*
this, op); })
3309 .Case<vector::TransferWriteOp>(
3310 [&](
auto op) {
return printOperation(*
this, op); })
3312 .Case<memref::StoreOp, memref::ExpandShapeOp,
3313 memref::CollapseShapeOp>(
3314 [&](
auto op) {
return printOperation(*
this, op); })
3316 .Case<aievec::aie1::AddOp, aievec::aie1::SubOp, aievec::aie1::FMAOp,
3317 aievec::aie1::MulOp, aievec::aie1::SelectOp,
3318 aievec::aie1::ExtOp>(
3319 [&](
auto op) {
return printOperation(*
this, op); })
3321 .Case<AddElemOp, ConcatOp, ExtOp, PackOp, SRSOp, SubElemOp, UPDOp,
3322 UPSOp, FMAElemOp, MulElemOp, BroadcastOp, BroadcastScalarOp,
3323 MulConvOp, FMAConvOp, ShiftOp, ShuffleOp, CastOp, MinOp, MaxOp,
3324 NegOp, CmpOp, SelOp, ExtElemOp, BxorOp, BnegOp, BandOp, BorOp,
3325 UnpackOp, MatMulOp, LegacyShuffleOp>(
3326 [&](
auto op) {
return printOperation(*
this, op); })
3328 .Case<AIE::DeviceOp>(
3329 [&](
auto op) {
return printOperation(*
this, op); })
3330 .Default([&](Operation *) {
3331 return op.emitOpError(
"unable to find printer for op");
3336 os << (trailingSemicolon ?
";\n" :
"\n");
// Builds the C++ spelling of `type` in a string stream and returns it as an
// optional string. `stdintType` selects the "_t" suffix on integer names;
// `isAcc` takes part in the vector branch's accumulator check.
//
// NOTE(review): this copy of the function comes from a rendered listing. Its
// line numbers are fused into the code and several lines (case labels,
// returns, closing braces, some literals) are absent. The comments below
// describe only what the surviving lines show.
3341std::optional<std::string>
3342CppEmitter::genCppTypeName(Type type,
bool stdintType,
bool isAcc) {
3343 std::stringstream ss;
// Integer types: dispatch on bit width. The surviving lines produce
// "uint<W>"/"int<W>" (plus "_t" when `stdintType` is set) depending on
// `shouldMapToUnsigned`, and "acc<W>"; which widths reach which spelling is
// not visible here.
3344 if (
auto iType = dyn_cast<IntegerType>(type)) {
3345 switch (iType.getWidth()) {
3352 if (shouldMapToUnsigned(iType.getSignedness()))
3353 ss <<
"uint" << iType.getWidth() << (stdintType ?
"_t" :
"");
3355 ss <<
"int" << iType.getWidth() << (stdintType ?
"_t" :
"");
3359 ss <<
"acc" << iType.getWidth();
// Float types: dispatch on bit width (the per-width spellings are not
// visible in this listing).
3365 if (
auto fType = dyn_cast<FloatType>(type)) {
3366 switch (fType.getWidth()) {
// Index type (the returned spelling is not visible in this listing).
3377 if (
auto iType = dyn_cast<IndexType>(type))
// Tensor types: rank and static shape are checked first, then the element
// type name is generated recursively and the shape is walked dimension by
// dimension.
3380 if (
auto tType = dyn_cast<TensorType>(type)) {
3381 if (!tType.hasRank())
3383 if (!tType.hasStaticShape())
3386 auto nestedTypeName = genCppTypeName(tType.getElementType());
3387 if (!nestedTypeName)
3389 ss << *nestedTypeName;
3390 auto shape = tType.getShape();
3391 for (
auto dimSize : shape) {
// Tuple types: "std::tuple<" followed by the recursively generated member
// names separated by ", ". `itrleaveFailed` records a member whose name could
// not be generated.
3398 if (
auto tType = dyn_cast<TupleType>(type)) {
3399 ss <<
"std::tuple<";
3400 bool itrleaveFailed =
false;
3404 auto optTyNameStr = genCppTypeName(type);
3406 ss << *optTyNameStr;
3408 itrleaveFailed = true;
3410 [&]() { ss <<
", "; });
3412 if (!itrleaveFailed)
// EmitC opaque types: the stored string value is written as is.
3416 if (
auto oType = dyn_cast<emitc::OpaqueType>(type)) {
3417 ss << oType.getValue().str();
// MemRef types: element type name followed by " * restrict".
3422 if (
auto tType = dyn_cast<MemRefType>(type)) {
3423 auto elemTyStrOpt = genCppTypeName(tType.getElementType());
3426 ss << *elemTyStrOpt <<
" * restrict";
// Vector types: the name starts with "v" and the product of all dimensions,
// i.e. the shape is treated as a flat element count.
3430 if (
auto tType = dyn_cast<VectorType>(type)) {
3431 Type eltType = tType.getElementType();
3433 auto vShape = tType.getShape();
// NOTE(review): the init value `1` is an `int`, so std::accumulate keeps its
// running product in `int` even though the functor is
// std::multiplies<int64_t>; an `int64_t` init would keep the product 64-bit.
3434 int64_t numElems = std::accumulate(vShape.begin(), vShape.end(), 1,
3435 std::multiplies<int64_t>());
3436 ss <<
"v" << std::to_string(numElems);
3438 int64_t iElTyBitWidth = 0;
3439 auto iElTy = dyn_cast<IntegerType>(eltType);
3441 iElTyBitWidth = iElTy.getWidth();
// On aie2, accumulator vectors (requested via `isAcc`, or any vector with
// 64-bit integer elements) get an "acc<W>" element spelling, but only for the
// (lane count, width) pairs 16x64, 32x32 and 16x32 checked below.
3442 if (aie2() && (isAcc || iElTyBitWidth == 64)) {
3446 if ((numElems == 16 && iElTyBitWidth == 64) ||
3447 (numElems == 32 && iElTyBitWidth == 32) ||
3448 (numElems == 16 && iElTyBitWidth == 32)) {
3449 ss <<
"acc" << iElTyBitWidth;
3454 if (isa<FloatType>(eltType)) {
// Otherwise the element name is generated with `stdintType` = false.
3460 auto elTyNameOpt = genCppTypeName(eltType,
false);
3469LogicalResult CppEmitter::emitType(Location loc, Type type,
bool stdintType,
3471 auto typeName = genCppTypeName(type, stdintType, isAcc);
3473 return emitError(loc,
"cannot emit type ") << type;
3478LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
3479 switch (types.size()) {
3484 return emitType(loc, types.front());
3486 return emitTupleType(loc, types);
3490LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
3491 os <<
"std::tuple<";
3493 types, os, [&](Type type) {
return emitType(loc, type); })))
// Body of the translation entry point (its signature line is absent from this
// listing). Builds a CppEmitter on `os` with `false` as its second
// constructor argument and the caller's `aie2` flag, then emits `op` with
// `false` as the trailing-semicolon argument.
3501 CppEmitter emitter(os,
false, aie2);
3502 return emitter.emitOperation(*op,
false);
LogicalResult interleaveCommaWithError(const Container &c, raw_ostream &os, UnaryFunctor eachFn)
LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, UnaryFunctor eachFn, NullaryFunctor betweenFn)
Convenience functions to produce interleaved output with functions returning a LogicalResult.
std::shared_ptr< Value > value()
mlir::LogicalResult translateAIEVecToCpp(mlir::Operation *op, bool aie2, mlir::raw_ostream &os)
Translates the AIE vector dialect MLIR to C++ code.
int32_t getVectorSizeInBits(mlir::VectorType type)
unsigned getVectorLaneSize(mlir::VectorType type)
int32_t getElementSizeInBits(mlir::VectorType type)