summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFelix Schneider <fx.schn@gmail.com>2024-05-12 18:11:42 +0200
committerGitHub <noreply@github.com>2024-05-12 18:11:42 +0200
commit78b3a00418ce6da0426a261a64a77608d0264fe5 (patch)
treebcf193d3751283676162f212946d710cf6ced911
parent502e77df1fc4aa859db6709e14e93af6207e4dc4 (diff)
downloadllvm-libc-78b3a00418ce6da0426a261a64a77608d0264fe5.tar.gz
[mlir] `int-range-optmizations`: Fix referencing of deleted ops (#91807)
The pass runs a `DataFlowSolver` and collects state information on the input IR. Then, the rewrite driver and folding is applied. During pattern application and folding it can happen that an Op from the input IR is deleted and a new Op is created at the same address. When the newly created Ops is looked up in the `DataFlowSolver` state memory, the state of the original Op is returned. This patch adds a method to `DataFlowSolver` which removes all state related to a `ProgramPoint`. It also adds a listener to the Pass which clears the state information of deleted Ops from the `DataFlowSolver`. Fix https://github.com/llvm/llvm-project/issues/81228
-rw-r--r--mlir/include/mlir/Analysis/DataFlowFramework.h11
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp25
2 files changed, 35 insertions, 1 deletions
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index c76cfac07fc7..2580ec28b519 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -242,6 +242,17 @@ public:
return static_cast<const StateT *>(it->second.get());
}
+ /// Erase any analysis state associated with the given program point.
+ template <typename PointT>
+ void eraseState(PointT point) {
+ ProgramPoint pp(point);
+
+ for (auto it = analysisStates.begin(); it != analysisStates.end(); ++it) {
+ if (it->first.first == pp)
+ analysisStates.erase(it);
+ }
+ }
+
/// Get a uniqued program point instance. If one is not present, it is
/// created with the provided arguments.
template <typename PointT, typename... Args>
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 92cad7cd1ef2..2473169962b9 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -102,6 +102,24 @@ static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
}
namespace {
+/// This class listens on IR transformations performed during a pass relying on
+/// information from a `DataflowSolver`. It erases state associated with the
+/// erased operation and its results from the `DataFlowSolver` so that Patterns
+/// do not accidentally query old state information for newly created Ops.
+class DataFlowListener : public RewriterBase::Listener {
+public:
+ DataFlowListener(DataFlowSolver &s) : s(s) {}
+
+protected:
+ void notifyOperationErased(Operation *op) override {
+ s.eraseState(op);
+ for (Value res : op->getResults())
+ s.eraseState(res);
+ }
+
+ DataFlowSolver &s;
+};
+
struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
@@ -167,10 +185,15 @@ struct IntRangeOptimizationsPass
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
+ DataFlowListener listener(solver);
+
RewritePatternSet patterns(ctx);
populateIntRangeOptimizationsPatterns(patterns, solver);
- if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+ GreedyRewriteConfig config;
+ config.listener = &listener;
+
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
signalPassFailure();
}
};