411 ConversionPatternRewriter &rewriter)
const override {
412 auto &convMacChainAnalysis =
418 auto loc = srcOp.getLoc();
419 VectorType vecTy = cast<VectorType>(srcOp.getResult().getType());
420 unsigned elemWidth = cast<IntegerType>(vecTy.getElementType()).getWidth();
421 unsigned accWidth = elemWidth <= 8 ? 32 : 64;
422 int32_t M = elemWidth == 8 ? 32 : 16;
423 int32_t N = elemWidth == 8 ? 8 : 4;
425 Type wideElemTy = IntegerType::get(getContext(), accWidth);
426 Type accVecTy = VectorType::get(vecTy.getShape(), wideElemTy);
428 const auto &groups = convMacChainAnalysis.getGroupsInChain();
429 Value grpAcc = (*convMacChain)[groups[0].fromIdx]->acc;
432 .create<aievec::UPSOp>(srcOp.getLoc(), accVecTy, grpAcc,
435 for (
const auto &group : groups) {
436 Value grpLhs = (*convMacChain)[group.fromIdx]->lhs;
437 Value grpRhs = (*convMacChain)[group.fromIdx]->rhs;
438 auto filterVecTy = cast<VectorType>(grpRhs.getType());
439 auto signalVecTy = cast<VectorType>(grpLhs.getType());
443 if (2 * filterVecTy.getShape()[0] == signalVecTy.getShape()[0])
446 .create<aievec::ConcatOp>(
447 loc, signalVecTy, SmallVector<Value, 2>({grpRhs, grpRhs}))
450 if (group.bcastDist == 2)
453 .create<aievec::ShuffleOp>(loc, signalVecTy, grpRhs,
454 grpRhs, ShuffleMode::T8_64X2_LO)
458 if (group.bcastShift) {
461 (3 + group.bcastDist - 1);
464 .create<arith::ConstantOp>(
465 loc, rewriter.getI32IntegerAttr(shiftBytes))
468 .create<aievec::ShiftOp>(grpRhs.getDefiningOp()->getLoc(),
469 signalVecTy, grpRhs, grpRhs,
476 if (group.signalShift) {
481 .create<arith::ConstantOp>(
482 loc, rewriter.getI32IntegerAttr(shiftBytes))
485 .create<aievec::ShiftOp>(loc, signalVecTy, grpLhs, grpLhs,
494 .create<aievec::MulConvOp>(srcOp.getLoc(), accVecTy,
495 grpLhs, grpRhs, M, N)
500 .create<aievec::FMAConvOp>(srcOp.getLoc(), accVecTy, grpLhs,
501 grpRhs, grpAcc, M, N,
false)
505 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
506 srcOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
507 rewriter.replaceOpWithNewOp<aievec::SRSOp>(srcOp, vecTy, grpAcc,
508 shiftParamOp.getResult());