49 PatternRewriter &rewriter)
const override {
53 arith::CmpIOp cmpiOp = cast<arith::CmpIOp>(op);
55 if (cmpiOp.getPredicate() != arith::CmpIPredicate::eq)
58 auto lhsOp = cmpiOp.getLhs().getDefiningOp();
59 auto rhsOp = cmpiOp.getRhs().getDefiningOp();
60 if (!((isa<memref::DimOp>(lhsOp) || isa<tensor::DimOp>(lhsOp)) &&
61 isa<arith::ConstantOp>(rhsOp)) &&
62 !((isa<memref::DimOp>(rhsOp) || isa<tensor::DimOp>(rhsOp)) &&
63 isa<arith::ConstantOp>(lhsOp)))
67 if (isa<memref::DimOp>(rhsOp) || isa<tensor::DimOp>(rhsOp))
68 std::swap(lhsOp, rhsOp);
71 auto constantOp = cast<arith::ConstantOp>(rhsOp);
72 if (cast<IntegerAttr>(constantOp.getValue()).getValue().getZExtValue() != 1)
76 auto constIndexOp = lhsOp->getOperand(1).getDefiningOp<arith::ConstantOp>();
81 cast<IntegerAttr>(constIndexOp.getValue()).getValue().getZExtValue();
82 auto inputDimType = dyn_cast<ShapedType>(lhsOp->getOperand(0).getType());
83 if (!inputDimType || !inputDimType.isDynamicDim(index))
86 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
87 cmpiOp, rewriter.getIntegerAttr(rewriter.getI1Type(), 0));