410 ConversionPatternRewriter &rewriter)
const override {
411 auto &convMacChainAnalysis =
417 auto loc = srcOp.getLoc();
418 VectorType vecTy = cast<VectorType>(srcOp.getResult().getType());
419 unsigned elemWidth = cast<IntegerType>(vecTy.getElementType()).getWidth();
420 unsigned accWidth = elemWidth <= 8 ? 32 : 64;
421 int32_t M = elemWidth == 8 ? 32 : 16;
422 int32_t N = elemWidth == 8 ? 8 : 4;
424 Type wideElemTy = IntegerType::get(getContext(), accWidth);
425 Type accVecTy = VectorType::get(vecTy.getShape(), wideElemTy);
427 const auto &groups = convMacChainAnalysis.getGroupsInChain();
428 Value grpAcc = (*convMacChain)[groups[0].fromIdx]->acc;
431 .create<aievec::UPSOp>(srcOp.getLoc(), accVecTy, grpAcc,
434 for (
const auto &group : groups) {
435 Value grpLhs = (*convMacChain)[group.fromIdx]->lhs;
436 Value grpRhs = (*convMacChain)[group.fromIdx]->rhs;
437 auto filterVecTy = cast<VectorType>(grpRhs.getType());
438 auto signalVecTy = cast<VectorType>(grpLhs.getType());
442 if (2 * filterVecTy.getShape()[0] == signalVecTy.getShape()[0])
445 .create<aievec::ConcatOp>(
446 loc, signalVecTy, SmallVector<Value, 2>({grpRhs, grpRhs}))
449 if (group.bcastDist == 2)
452 .create<aievec::ShuffleOp>(loc, signalVecTy, grpRhs,
453 grpRhs, ShuffleMode::T8_64X2_LO)
457 if (group.bcastShift) {
460 (3 + group.bcastDist - 1);
463 .create<arith::ConstantOp>(
464 loc, rewriter.getI32IntegerAttr(shiftBytes))
467 .create<aievec::ShiftOp>(grpRhs.getDefiningOp()->getLoc(),
468 signalVecTy, grpRhs, grpRhs,
475 if (group.signalShift) {
480 .create<arith::ConstantOp>(
481 loc, rewriter.getI32IntegerAttr(shiftBytes))
484 .create<aievec::ShiftOp>(loc, signalVecTy, grpLhs, grpLhs,
493 .create<aievec::MulConvOp>(srcOp.getLoc(), accVecTy,
494 grpLhs, grpRhs, M, N)
499 .create<aievec::FMAConvOp>(srcOp.getLoc(), accVecTy, grpLhs,
500 grpRhs, grpAcc, M, N,
false)
504 auto shiftParamOp = rewriter.create<arith::ConstantOp>(
505 srcOp.getLoc(), rewriter.getI32IntegerAttr(
shiftParam));
506 rewriter.replaceOpWithNewOp<aievec::SRSOp>(srcOp, vecTy, grpAcc,
507 shiftParamOp.getResult());