411 ConversionPatternRewriter &rewriter)
const override {
412 auto &convMacChainAnalysis =
418 auto loc = srcOp.getLoc();
419 VectorType vecTy = cast<VectorType>(srcOp.getResult().getType());
420 unsigned elemWidth = cast<IntegerType>(vecTy.getElementType()).getWidth();
421 unsigned accWidth = elemWidth <= 8 ? 32 : 64;
422 int32_t M = elemWidth == 8 ? 32 : 16;
423 int32_t N = elemWidth == 8 ? 8 : 4;
425 Type wideElemTy = IntegerType::get(getContext(), accWidth);
426 Type accVecTy = VectorType::get(vecTy.getShape(), wideElemTy);
428 const auto &groups = convMacChainAnalysis.getGroupsInChain();
429 Value grpAcc = (*convMacChain)[groups[0].fromIdx]->acc;
431 grpAcc = aievec::UPSOp::create(rewriter, srcOp.getLoc(), accVecTy, grpAcc,
434 for (
const auto &group : groups) {
435 Value grpLhs = (*convMacChain)[group.fromIdx]->lhs;
436 Value grpRhs = (*convMacChain)[group.fromIdx]->rhs;
437 auto filterVecTy = cast<VectorType>(grpRhs.getType());
438 auto signalVecTy = cast<VectorType>(grpLhs.getType());
442 if (2 * filterVecTy.getShape()[0] == signalVecTy.getShape()[0])
444 aievec::ConcatOp::create(rewriter, loc, signalVecTy,
445 SmallVector<Value, 2>({grpRhs, grpRhs}))
448 if (group.bcastDist == 2)
450 grpRhs = aievec::ShuffleOp::create(rewriter, loc, signalVecTy, grpRhs,
451 grpRhs, ShuffleMode::T8_64X2_LO)
455 if (group.bcastShift) {
458 (3 + group.bcastDist - 1);
460 arith::ConstantOp::create(rewriter, loc,
461 rewriter.getI32IntegerAttr(shiftBytes))
464 aievec::ShiftOp::create(rewriter, grpRhs.getDefiningOp()->getLoc(),
465 signalVecTy, grpRhs, grpRhs, shiftBytesCst)
471 if (group.signalShift) {
475 arith::ConstantOp::create(rewriter, loc,
476 rewriter.getI32IntegerAttr(shiftBytes))
478 grpLhs = aievec::ShiftOp::create(rewriter, loc, signalVecTy, grpLhs,
479 grpLhs, shiftBytesCst)
486 grpAcc = aievec::MulConvOp::create(rewriter, srcOp.getLoc(), accVecTy,
487 grpLhs, grpRhs, M, N)
490 grpAcc = aievec::FMAConvOp::create(rewriter, srcOp.getLoc(), accVecTy,
491 grpLhs, grpRhs, grpAcc, M, N,
false)
495 auto shiftParamOp = arith::ConstantOp::create(
496 rewriter, srcOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
497 rewriter.replaceOpWithNewOp<aievec::SRSOp>(srcOp, vecTy, grpAcc,
498 shiftParamOp.getResult());