MLIR-AIE
DynamicSizeNoImplicitBroadcast.cpp
Go to the documentation of this file.
1//===- DynamicSizeNoImplicitBroadcast.cpp -=====================-*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// (c) Copyright 2023, Advanced Micro Devices, Inc.
8//
9//===----------------------------------------------------------------------===//
// This file contains rewrites to the arith dialect that enable support for
// dynamically sized tensors/memrefs in the auto-vectorization to CPP flow.
// The MLIR-AIE auto-vectorization to CPP flow does not currently support
// implicitly broadcasting a dynamic dimension of size `1`. Hence, we assume
// that dynamic dimensions never have size `1`, since such a size could be
// interpreted as any of several broadcasting scenarios. This rewrite pattern
// is guarded by the attribute `tosa.no_implicit_broadcast_of_dynamic_sizes`.
17//===----------------------------------------------------------------------===//
18
21
22#include "mlir/Dialect/Arith/IR/Arith.h"
23#include "mlir/Dialect/MemRef/IR/MemRef.h"
24#include "mlir/Dialect/Tensor/IR/Tensor.h"
25#include "mlir/Pass/PassManager.h"
26#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27
28#define DEBUG_TYPE "dynamic-size-no-implicit-broadcast"
29
30using namespace llvm;
31using namespace mlir;
32using namespace xilinx;
33using namespace xilinx::aievec;
34
35//============================================================================//
36//=========================== Rewrite Patterns ===============================//
37//============================================================================//
38
// This pattern replaces an arith::CmpIOp with an arith::ConstantOp `false`,
// but only when the CmpIOp tests a dynamic dimension's runtime size for
// equality with the constant 1. The rewrite is guarded by the attribute
// `tosa.no_implicit_broadcast_of_dynamic_sizes`.
45 : RewritePattern(arith::CmpIOp::getOperationName(), /*benefit=*/1,
46 context) {}
47
48 LogicalResult matchAndRewrite(Operation *op,
49 PatternRewriter &rewriter) const override {
51 return failure();
52
53 arith::CmpIOp cmpiOp = cast<arith::CmpIOp>(op);
54
55 if (cmpiOp.getPredicate() != arith::CmpIPredicate::eq)
56 return failure();
57
58 auto lhsOp = cmpiOp.getLhs().getDefiningOp();
59 auto rhsOp = cmpiOp.getRhs().getDefiningOp();
60 if (!((isa<memref::DimOp>(lhsOp) || isa<tensor::DimOp>(lhsOp)) &&
61 isa<arith::ConstantOp>(rhsOp)) &&
62 !((isa<memref::DimOp>(rhsOp) || isa<tensor::DimOp>(rhsOp)) &&
63 isa<arith::ConstantOp>(lhsOp)))
64 return failure();
65
66 // Make sure rhsOp is ConstantOp and lhsOp is DimOp
67 if (isa<memref::DimOp>(rhsOp) || isa<tensor::DimOp>(rhsOp))
68 std::swap(lhsOp, rhsOp);
69
70 // If ConstantOp is 1 for Integer/Index, replace cmpiOp as constant 0
71 auto constantOp = cast<arith::ConstantOp>(rhsOp);
72 if (cast<IntegerAttr>(constantOp.getValue()).getValue().getZExtValue() != 1)
73 return failure();
74
75 // Check the DimOp's input is a dynamic dim from the given index
76 auto constIndexOp = lhsOp->getOperand(1).getDefiningOp<arith::ConstantOp>();
77 if (!constIndexOp)
78 return failure();
79
80 auto index =
81 cast<IntegerAttr>(constIndexOp.getValue()).getValue().getZExtValue();
82 auto inputDimType = dyn_cast<ShapedType>(lhsOp->getOperand(0).getType());
83 if (!inputDimType || !inputDimType.isDynamicDim(index))
84 return failure();
85
86 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
87 cmpiOp, rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
88
89 return success();
90 }
91};
92
93//============================================================================//
94//======================== Canonicalization Passes ===========================//
95//============================================================================//
96
98 : PassWrapper<DynamicSizeNoImplicitBroadcastPass, OperationPass<>> {
99
100 StringRef getArgument() const final {
101 return "test-dynamic-size-no-implicit-broadcast";
102 }
103
104 StringRef getDescription() const final {
105 return "Test rewriting arith operations when assuming no implict "
106 "broadcast of dynamic sizes";
107 }
108
109 void runOnOperation() override {
110 auto op = getOperation();
111 MLIRContext *context = &getContext();
112 RewritePatternSet patterns(context);
113
114 patterns.add<DynamicSizeNoImplicitBroadcastPattern>(patterns.getContext());
115
116 (void)applyPatternsGreedily(op, std::move(patterns));
117 }
118};
119
120std::unique_ptr<::mlir::Pass>
122 return std::make_unique<DynamicSizeNoImplicitBroadcastPass>();
123}
124
125//============================================================================//
126//====================== Main Pipeline Configuration =========================//
127//============================================================================//
128
130 OpPassManager &pm) {
132}
bool isAssumingNoImplicitBroadcastOfDynamicSizes(mlir::Block *block)
std::unique_ptr<::mlir::Pass > createDynamicSizeNoImplicitBroadcastPass()
void buildDynamicSizeNoImplicitBroadcastPass(mlir::OpPassManager &pm)
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override