11#include "mlir/Interfaces/CopyOpInterface.h"
12#include "mlir/Interfaces/SideEffectInterfaces.h"
13#include "mlir/Pass/Pass.h"
16using namespace MemoryEffects;
53class CopyRemovalPass :
public PassWrapper<CopyRemovalPass, OperationPass<>> {
55 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CopyRemovalPass)
57 void runOnOperation()
override {
58 getOperation()->walk([&](CopyOpInterface copyOp) {
59 reuseCopySourceAsTarget(copyOp);
60 reuseCopyTargetAsSource(copyOp);
62 for (std::pair<Value, Value> &pair : replaceList)
63 pair.first.replaceAllUsesWith(pair.second);
64 for (Operation *op : eraseList)
/// Operations scheduled for erasure. They are erased only after the walk in
/// runOnOperation has finished; membership is also used to skip copies that
/// an earlier rewrite already consumed.
70 llvm::SmallPtrSet<Operation *, 4> eraseList;
/// (value, replacement) pairs: once the walk has finished, every use of
/// `first` is replaced with `second`.
73 llvm::SmallDenseSet<std::pair<Value, Value>, 4> replaceList;
77 Operation *getAllocationOpInBlock(Value value, Block *block) {
78 assert(block &&
"Block cannot be null");
79 Operation *op = value.getDefiningOp();
80 if (op && op->getBlock() == block) {
81 auto effects = dyn_cast<MemoryEffectOpInterface>(op);
82 if (effects && effects.hasEffect<Allocate>())
90 Operation *getDeallocationOpInBlock(Value value, Block *block) {
91 assert(block &&
"Block cannot be null");
92 auto valueUsers = value.getUsers();
93 auto it = llvm::find_if(valueUsers, [&](Operation *op) {
94 auto effects = dyn_cast<MemoryEffectOpInterface>(op);
95 return effects && op->getBlock() == block && effects.hasEffect<Free>();
97 return (it == valueUsers.end() ? nullptr : *it);
102 bool hasMemoryEffectOpBetween(Operation *start, Operation *end) {
103 assert((start || end) &&
"Start and end operations cannot be null");
104 assert(start->getBlock() == end->getBlock() &&
105 "Start and end operations should be in the same block.");
106 Operation *op = start->getNextNode();
107 while (op->isBeforeInBlock(end)) {
108 if (isa<MemoryEffectOpInterface>(op))
110 op = op->getNextNode();
117 bool hasUsersBetween(Value val, Operation *start, Operation *end) {
118 assert((start || end) &&
"Start and end operations cannot be null");
119 Block *block = start->getBlock();
120 assert(block == end->getBlock() &&
121 "Start and end operations should be in the same block.");
122 return llvm::any_of(val.getUsers(), [&](Operation *op) {
123 return op->getBlock() == block && start->isBeforeInBlock(op) &&
124 op->isBeforeInBlock(end);
128 bool areOpsInTheSameBlock(ArrayRef<Operation *> operations) {
129 assert(!operations.empty() &&
130 "The operations list should contain at least a single operation");
131 Block *block = operations.front()->getBlock();
132 return llvm::none_of(
133 operations, [&](Operation *op) {
return block != op->getBlock(); });
162 void reuseCopySourceAsTarget(CopyOpInterface copyOp) {
163 if (eraseList.count(copyOp))
166 Value from = copyOp.getSource();
167 Value to = copyOp.getTarget();
169 Operation *copy = copyOp.getOperation();
170 Block *copyBlock = copy->getBlock();
171 Operation *fromDefiningOp = from.getDefiningOp();
172 Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock);
173 Operation *toDefiningOp = getAllocationOpInBlock(to, copyBlock);
174 if (!fromDefiningOp || !fromFreeingOp || !toDefiningOp ||
175 !areOpsInTheSameBlock({fromFreeingOp, toDefiningOp, copy}) ||
176 hasUsersBetween(to, toDefiningOp, copy) ||
177 hasUsersBetween(from, copy, fromFreeingOp) ||
178 hasMemoryEffectOpBetween(copy, fromFreeingOp))
181 replaceList.insert({to, from});
182 eraseList.insert(copy);
183 eraseList.insert(toDefiningOp);
184 eraseList.insert(fromFreeingOp);
213 void reuseCopyTargetAsSource(CopyOpInterface copyOp) {
214 if (eraseList.count(copyOp))
217 Value from = copyOp.getSource();
218 Value to = copyOp.getTarget();
220 Operation *copy = copyOp.getOperation();
221 Block *copyBlock = copy->getBlock();
222 Operation *fromDefiningOp = getAllocationOpInBlock(from, copyBlock);
223 Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock);
224 if (!fromDefiningOp || !fromFreeingOp ||
225 !areOpsInTheSameBlock({fromFreeingOp, fromDefiningOp, copy}) ||
226 hasUsersBetween(to, fromDefiningOp, copy) ||
227 hasUsersBetween(from, copy, fromFreeingOp) ||
228 hasMemoryEffectOpBetween(copy, fromFreeingOp))
231 replaceList.insert({from, to});
232 eraseList.insert(copy);
233 eraseList.insert(fromDefiningOp);
234 eraseList.insert(fromFreeingOp);
245 return std::make_unique<CopyRemovalPass>();
std::unique_ptr<::mlir::Pass > createCopyRemovalPass()
Create a pass that removes unnecessary Copy operations.