diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -23,7 +23,6 @@
 class Operation;
 class Value;
-
 
 //===--------------------------------------------------------------------===//
 // OperationFolder
 //===--------------------------------------------------------------------===//
@@ -34,6 +33,11 @@
 public:
   OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
 
+  /// Scan the specified region for constants that can be used in folding,
+  /// moving them to the entry block and adding them to our known-constants
+  /// table.
+  void processExistingConstants(Region &region);
+
   /// Tries to perform folding on the given `op`, including unifying
   /// deduplicated constants. If successful, replaces `op`'s uses with
   /// folded results, and returns success. `preReplaceAction` is invoked on `op`
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -84,6 +84,81 @@
 // OperationFolder
 //===----------------------------------------------------------------------===//
 
+/// Scan the specified region for constants that can be used in folding,
+/// moving them to the entry block and adding them to our known-constants
+/// table.
+void OperationFolder::processExistingConstants(Region &region) {
+  if (region.empty())
+    return;
+
+  // March the constant insertion point forward, moving all constants to the
+  // top of the block, but keeping them in their order of discovery.
+  Region *insertRegion = getInsertionRegion(interfaces, &region.front());
+  auto &uniquedConstants = foldScopes[insertRegion];
+
+  Block &insertBlock = insertRegion->front();
+  Block::iterator constantIterator = insertBlock.begin();
+
+  // Process each constant that we discover in this region.
+  auto processConstant = [&](Operation *op, Attribute value) {
+    // Check to see if we already have an instance of this constant.
+    Operation *&constOp = uniquedConstants[std::make_tuple(
+        op->getDialect(), value, op->getResult(0).getType())];
+
+    // If we already have an instance of this constant, CSE/delete this one as
+    // we go.
+    if (constOp) {
+      if (constantIterator == Block::iterator(op))
+        ++constantIterator; // Don't invalidate our iterator when scanning.
+      op->getResult(0).replaceAllUsesWith(constOp->getResult(0));
+      op->erase();
+      return;
+    }
+
+    // Otherwise, remember that we have this constant.
+    constOp = op;
+    referencedDialects[op].push_back(op->getDialect());
+
+    // If the constant isn't already at the insertion point then move it up.
+    if (constantIterator == insertBlock.end() || &*constantIterator != op)
+      op->moveBefore(&insertBlock, constantIterator);
+    else
+      ++constantIterator; // It was pointing at the constant.
+  };
+
+  SmallVector<Operation *> isolatedOps;
+  region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+    // If this is a constant, process it.
+    Attribute value;
+    if (matchPattern(op, m_Constant(&value))) {
+      processConstant(op, value);
+      // We may have deleted the operation, don't check it for regions.
+      return WalkResult::skip();
+    }
+
+    // If the operation has regions and is isolated, don't recurse into it.
+    if (op->getNumRegions() != 0) {
+      auto hasDifferentInsertRegion = [&](Region &region) {
+        return !region.empty() &&
+               getInsertionRegion(interfaces, &region.front()) != insertRegion;
+      };
+      if (llvm::any_of(op->getRegions(), hasDifferentInsertRegion)) {
+        isolatedOps.push_back(op);
+        return WalkResult::skip();
+      }
+    }
+
+    // Otherwise keep going.
+    return WalkResult::advance();
+  });
+
+  // Process regions in any isolated ops separately.
+  for (Operation *isolated : isolatedOps) {
+    for (Region &region : isolated->getRegions())
+      processExistingConstants(region);
+  }
+}
+
 LogicalResult OperationFolder::tryToFold(
     Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
     function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
@@ -262,19 +337,19 @@
                                                    Attribute value, Type type, Location loc) {
   // Check if an existing mapping already exists.
   auto constKey = std::make_tuple(dialect, value, type);
-  auto *&constInst = uniquedConstants[constKey];
-  if (constInst)
-    return constInst;
+  auto *&constOp = uniquedConstants[constKey];
+  if (constOp)
+    return constOp;
 
   // If one doesn't exist, try to materialize one.
-  if (!(constInst = materializeConstant(dialect, builder, value, type, loc)))
+  if (!(constOp = materializeConstant(dialect, builder, value, type, loc)))
     return nullptr;
 
   // Check to see if the generated constant is in the expected dialect.
-  auto *newDialect = constInst->getDialect();
+  auto *newDialect = constOp->getDialect();
   if (newDialect == dialect) {
-    referencedDialects[constInst].push_back(dialect);
-    return constInst;
+    referencedDialects[constOp].push_back(dialect);
+    return constOp;
   }
 
   // If it isn't, then we also need to make sure that the mapping for the new
@@ -284,13 +359,13 @@
   // If an existing operation in the new dialect already exists, delete the
   // materialized operation in favor of the existing one.
   if (auto *existingOp = uniquedConstants.lookup(newKey)) {
-    constInst->erase();
+    constOp->erase();
     referencedDialects[existingOp].push_back(dialect);
-    return constInst = existingOp;
+    return constOp = existingOp;
   }
 
   // Otherwise, update the new dialect to the materialized operation.
-  referencedDialects[constInst].assign({dialect, newDialect});
-  auto newIt = uniquedConstants.insert({newKey, constInst});
+  referencedDialects[constOp].assign({dialect, newDialect});
+  auto newIt = uniquedConstants.insert({newKey, constOp});
   return newIt.first->second;
 }
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -107,7 +107,8 @@
   // be re-added to the worklist. This function should be called when an
   // operation is modified or removed, as it may trigger further
   // simplifications.
-  template <typename Operands> void addToWorklist(Operands &&operands) {
+  template <typename Operands>
+  void addToWorklist(Operands &&operands) {
     for (Value operand : operands) {
       // If the use count of this operand is now < 2, we re-add the defining
       // operation to the worklist.
@@ -140,15 +141,26 @@
 /// if the rewrite converges in `maxIterations`.
 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
                                           int maxIterations) {
-  // Add the given operation to the worklist.
-  auto collectOps = [this](Operation *op) { addToWorklist(op); };
+  // Perform a prepass over the IR to discover constants.
+  for (auto &region : regions)
+    folder.processExistingConstants(region);
 
   bool changed = false;
-  int i = 0;
+  int iteration = 0;
   do {
-    // Add all nested operations to the worklist.
+    worklist.clear();
+    worklistMap.clear();
+
+    // Add all nested operations to the worklist in preorder.
     for (auto &region : regions)
-      region.walk(collectOps);
+      region.walk<WalkOrder::PreOrder>(
+          [this](Operation *op) { worklist.push_back(op); });
+
+    // Reverse the list so our pop-back loop processes them in-order.
+    std::reverse(worklist.begin(), worklist.end());
+    // Remember the reverse index.
+    for (unsigned i = 0, e = worklist.size(); i != e; ++i)
+      worklistMap[worklist[i]] = i;
 
     // These are scratch vectors used in the folding loop below.
     SmallVector<Value, 8> originalOperands, resultValues;
@@ -186,6 +198,9 @@
       notifyOperationRemoved(op);
     };
+    // Add the given operation to the worklist.
+    auto collectOps = [this](Operation *op) { addToWorklist(op); };
+
     // Try to fold this op.
     bool inPlaceUpdate;
     if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
@@ -203,7 +218,8 @@
     // After applying patterns, make sure that the CFG of each of the regions is
     // kept up to date.
     changed |= succeeded(simplifyRegions(*this, regions));
-  } while (changed && ++i < maxIterations);
+  } while (changed && ++iteration < maxIterations);
 
+  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
   return !changed;
 }
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
--- a/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-loops.mlir
@@ -204,12 +204,13 @@
   // CHECK-DAG: %[[C0:.*]] = constant 0 : index
   // CHECK-DAG: %[[splat:.*]] = constant dense<7.000000e+00> : vector<15xf32>
   // CHECK-DAG: %[[alloc:.*]] = memref.alloca() : memref<3xvector<15xf32>>
+  // CHECK-DAG: [[CST:%.*]] = constant 7.000000e+00 : f32
   // CHECK-DAG: %[[dim:.*]] = memref.dim %[[A]], %[[C0]] : memref<?x?xf32>
   // CHECK: affine.for %[[I:.*]] = 0 to 3 {
   // CHECK: %[[add:.*]] = affine.apply #[[$MAP0]](%[[I]])[%[[base]]]
   // CHECK: %[[cond1:.*]] = cmpi slt, %[[add]], %[[dim]] : index
   // CHECK: scf.if %[[cond1]] {
-  // CHECK: %[[vec_1d:.*]] = vector.transfer_read %[[A]][%[[add]], %[[base]]], %cst : memref<?x?xf32>, vector<15xf32>
+  // CHECK: %[[vec_1d:.*]] = vector.transfer_read %[[A]][%[[add]], %[[base]]], [[CST]] : memref<?x?xf32>, vector<15xf32>
   // CHECK: store %[[vec_1d]], %[[alloc]][%[[I]]] : memref<3xvector<15xf32>>
   // CHECK: } else {
   // CHECK: store %[[splat]], %[[alloc]][%[[I]]] : memref<3xvector<15xf32>>
@@ -217,13 +218,14 @@
   // CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<3xvector<15xf32>> to memref<vector<3x15xf32>>
   // CHECK: %[[cst:.*]] = memref.load %[[vmemref]][] : memref<vector<3x15xf32>>
 
-  // FULL-UNROLL: %[[VEC0:.*]] = constant dense<7.000000e+00> : vector<3x15xf32>
-  // FULL-UNROLL: %[[C0:.*]] = constant 0 : index
-  // FULL-UNROLL: %[[SPLAT:.*]] = constant dense<7.000000e+00> : vector<15xf32>
+  // FULL-UNROLL-DAG: %[[VEC0:.*]] = constant dense<7.000000e+00> : vector<3x15xf32>
+  // FULL-UNROLL-DAG: %[[C0:.*]] = constant 0 : index
+  // FULL-UNROLL-DAG: %[[SPLAT:.*]] = constant dense<7.000000e+00> : vector<15xf32>
+  // FULL-UNROLL-DAG: [[CST:%.*]] = constant 7.000000e+00 : f32
   // FULL-UNROLL: %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref<?x?xf32>
   // FULL-UNROLL: cmpi slt, %[[base]], %[[DIM]] : index
   // FULL-UNROLL: %[[VEC1:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) {
-  // FULL-UNROLL: vector.transfer_read %[[A]][%[[base]], %[[base]]], %cst : memref<?x?xf32>, vector<15xf32>
+  // FULL-UNROLL: vector.transfer_read %[[A]][%[[base]], %[[base]]], [[CST]] : memref<?x?xf32>, vector<15xf32>
   // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC0]] [0] : vector<15xf32> into vector<3x15xf32>
   // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32>
   // FULL-UNROLL: } else {
@@ -233,7 +235,7 @@
   // FULL-UNROLL: affine.apply #[[$MAP1]]()[%[[base]]]
   // FULL-UNROLL: cmpi slt, %{{.*}}, %[[DIM]] : index
   // FULL-UNROLL: %[[VEC2:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) {
-  // FULL-UNROLL: vector.transfer_read %[[A]][%{{.*}}, %[[base]]], %cst : memref<?x?xf32>, vector<15xf32>
+  // FULL-UNROLL: vector.transfer_read %[[A]][%{{.*}}, %[[base]]], [[CST]] : memref<?x?xf32>, vector<15xf32>
   // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC1]] [1] : vector<15xf32> into vector<3x15xf32>
   // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32>
   // FULL-UNROLL: } else {
@@ -243,7 +245,7 @@
   // FULL-UNROLL: affine.apply #[[$MAP2]]()[%[[base]]]
   // FULL-UNROLL: cmpi slt, %{{.*}}, %[[DIM]] : index
   // FULL-UNROLL: %[[VEC3:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) {
-  // FULL-UNROLL: vector.transfer_read %[[A]][%{{.*}}, %[[base]]], %cst : memref<?x?xf32>, vector<15xf32>
+  // FULL-UNROLL: vector.transfer_read %[[A]][%{{.*}}, %[[base]]], [[CST]] : memref<?x?xf32>, vector<15xf32>
   // FULL-UNROLL: vector.insert %{{.*}}, %[[VEC2]] [2] : vector<15xf32> into vector<3x15xf32>
   // FULL-UNROLL: scf.yield %{{.*}} : vector<3x15xf32>
   // FULL-UNROLL: } else {
@@ -377,16 +379,16 @@
 
 // CHECK-LABEL: transfer_read_minor_identity(
 // CHECK-SAME: %[[A:.*]]: memref<?x?x?x?xf32>) -> vector<3x3xf32>
-// CHECK-DAG: %[[c0:.*]] = constant 0 : index
-// CHECK-DAG: %cst = constant 0.000000e+00 : f32
 // CHECK-DAG: %[[c2:.*]] = constant 2 : index
 // CHECK-DAG: %[[cst0:.*]] = constant dense<0.000000e+00> : vector<3xf32>
 // CHECK: %[[m:.*]] = memref.alloca() : memref<3xvector<3xf32>>
+// CHECK-DAG: %[[cst:.*]] = constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK: %[[d:.*]] = memref.dim %[[A]], %[[c2]] : memref<?x?x?x?xf32>
 // CHECK: affine.for %[[arg1:.*]] = 0 to 3 {
 // CHECK: %[[cmp:.*]] = cmpi slt, %[[arg1]], %[[d]] : index
 // CHECK: scf.if %[[cmp]] {
-// CHECK: %[[tr:.*]] = vector.transfer_read %[[A]][%c0, %c0, %[[arg1]], %c0], %cst : memref<?x?x?x?xf32>, vector<3xf32>
+// CHECK: %[[tr:.*]] = vector.transfer_read %[[A]][%c0, %c0, %[[arg1]], %c0], %[[cst]] : memref<?x?x?x?xf32>, vector<3xf32>
 // CHECK: store %[[tr]], %[[m]][%[[arg1]]] : memref<3xvector<3xf32>>
 // CHECK: } else {
 // CHECK: store %[[cst0]], %[[m]][%[[arg1]]] : memref<3xvector<3xf32>>
@@ -409,8 +411,8 @@
 // CHECK-SAME: %[[A:.*]]: vector<3x3xf32>,
 // CHECK-SAME: %[[B:.*]]: memref<?x?x?x?xf32>)
 // CHECK-DAG: %[[c2:.*]] = constant 2 : index
-// CHECK-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK: %[[m:.*]] = memref.alloca() : memref<3xvector<3xf32>>
+// CHECK-DAG: %[[c0:.*]] = constant 0 : index
 // CHECK: %[[cast:.*]] = vector.type_cast %[[m]] : memref<3xvector<3xf32>> to memref<vector<3x3xf32>>
 // CHECK: store %[[A]], %[[cast]][] : memref<vector<3x3xf32>>
 // CHECK: %[[d:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?x?xf32>
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -207,7 +207,7 @@
 
 // -----
 
-// CHECK-DAG: #[[$MAP14:.*]] = affine_map<()[s0, s1] -> (((s1 + s0) * 4) floordiv s0)>
+// CHECK-DAG: #[[$MAP14:.*]] = affine_map<()[s0, s1] -> ((s0 * 4 + s1 * 4) floordiv s0)>
 
 // CHECK-LABEL: func @compose_affine_maps_multiple_symbols
 func @compose_affine_maps_multiple_symbols(%arg0: index, %arg1: index) -> index {
@@ -312,7 +312,7 @@
 
 // -----
 
-// CHECK-DAG: #[[$MAP_symbolic_composition_d:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 3)>
+// CHECK-DAG: #[[$MAP_symbolic_composition_d:.*]] = affine_map<()[s0, s1] -> (s0 * 3 + s1)>
 
 // CHECK-LABEL: func @symbolic_composition_d(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: index
@@ -321,7 +321,7 @@
   %0 = affine.apply affine_map<(d0) -> (d0)>(%arg0)
   %1 = affine.apply affine_map<()[s0] -> (s0)>()[%arg1]
   %2 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s1 + s2 + s3)>()[%0, %0, %0, %1]
-  // CHECK: %{{.*}} = affine.apply #[[$MAP_symbolic_composition_d]]()[%[[ARG1]], %[[ARG0]]]
+  // CHECK: %{{.*}} = affine.apply #[[$MAP_symbolic_composition_d]]()[%[[ARG0]], %[[ARG1]]]
   return %2 : index
 }
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -336,7 +336,7 @@
   return
 }
 // CHECK-LABEL: func @aligned_promote_fill
-// CHECK: %[[cf:.*]] = constant {{.*}} : f32
+// CHECK: %[[cf:.*]] = constant 1.0{{.*}} : f32
 // CHECK: %[[s0:.*]] = memref.subview {{%.*}}[{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] [{{%.*}}, {{%.*}}] : memref to memref
 // CHECK: %[[a0:.*]] = memref.alloc({{%.*}}) {alignment = 32 : i64} : memref
 // CHECK: %[[v0:.*]] = memref.view %[[a0]][{{.*}}][{{%.*}}, {{%.*}}] : memref to memref
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -234,10 +234,10 @@
   // CHECK: [[T0:%.*]] = vector.transpose [[ARG]], [2, 1, 0]
   %0 = vector.transpose %arg, [1, 2, 0] : vector<4x3x2xf32> to vector<3x2x4xf32>
   %1 = vector.transpose %0, [1, 0, 2] : vector<3x2x4xf32> to vector<2x3x4xf32>
-  // CHECK-NOT: transpose
+  // CHECK: [[T1:%.*]] = vector.transpose [[ARG]], [2, 1, 0]
   %2 = vector.transpose %1, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
   %3 = vector.transpose %2, [2, 1, 0] : vector<4x3x2xf32> to vector<2x3x4xf32>
-  // CHECK: [[MUL:%.*]] = mulf [[T0]], [[T0]]
+  // CHECK: [[MUL:%.*]] = mulf [[T0]], [[T1]]
   %4 = mulf %1, %3 : vector<2x3x4xf32>
   // CHECK: [[T5:%.*]] = vector.transpose [[MUL]], [2, 1, 0]
   %5 = vector.transpose %4, [2, 1, 0] : vector<2x3x4xf32> to vector<4x3x2xf32>
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -630,7 +630,7 @@
 //
 // CHECK-LABEL: func @lowered_affine_ceildiv
 func @lowered_affine_ceildiv() -> (index, index) {
-// CHECK-NEXT: %c-1 = constant -1 : index
+// CHECK-DAG: %c-1 = constant -1 : index
   %c-43 = constant -43 : index
   %c42 = constant 42 : index
   %c0 = constant 0 : index
@@ -643,7 +643,7 @@
   %5 = subi %c0, %4 : index
   %6 = addi %4, %c1 : index
   %7 = select %0, %5, %6 : index
-// CHECK-NEXT: %c2 = constant 2 : index
+// CHECK-DAG: %c2 = constant 2 : index
   %c43 = constant 43 : index
   %c42_0 = constant 42 : index
   %c0_1 = constant 0 : index
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -5,8 +5,8 @@
   %0 = "test.op_a"(%arg0) {attr = 10 : i32} : (i32) -> i32 loc("a")
   %result = "test.op_a"(%0) {attr = 20 : i32} : (i32) -> i32 loc("b")
 
-  // CHECK: "test.op_b"(%arg0) {attr = 10 : i32} : (i32) -> i32 loc("a")
-  // CHECK: "test.op_b"(%arg0) {attr = 20 : i32} : (i32) -> i32 loc(fused["b", "a"])
+  // CHECK: %0 = "test.op_b"(%arg0) {attr = 10 : i32} : (i32) -> i32 loc("a")
+  // CHECK: %1 = "test.op_b"(%0) {attr = 20 : i32} : (i32) -> i32 loc("b")
   return %result : i32
 }
 
@@ -67,7 +67,7 @@
   %2 = "test.op_g"(%1) : (i32) -> i32
 
   // CHECK: "test.op_f"(%arg0)
-  // CHECK: "test.op_b"(%arg0) {attr = 34 : i32}
+  // CHECK: "test.op_b"(%arg0) {attr = 20 : i32}
   return %0 : i32
 }
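
Usage note (editorial addition, not part of the patch): the sketch below shows one way a client pass might drive the new OperationFolder::processExistingConstants entry point before attempting folds, mirroring the prepass the greedy driver now performs. The pass itself (its name, FunctionPass boilerplate, and the decision to fold every region-free op afterwards) is an illustrative assumption; only the OperationFolder calls correspond to the API touched above.

// Hypothetical example pass; not part of this change.
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"

using namespace mlir;

namespace {
/// Uniques the constants already present in a function, then tries to fold
/// every remaining operation, reusing those constants where possible.
struct FoldWithExistingConstantsPass
    : public PassWrapper<FoldWithExistingConstantsPass, FunctionPass> {
  void runOnFunction() override {
    FuncOp func = getFunction();
    OperationFolder folder(func.getContext());

    // Prepass: move constants to the entry block and unique them, so later
    // folds can reuse them instead of materializing duplicates.
    folder.processExistingConstants(func.getBody());

    // Collect the ops to visit first: tryToFold may erase the op it is given,
    // so we do not fold while walking. Ops with regions are skipped to keep
    // the sketch simple (erasing one would invalidate nested pointers).
    SmallVector<Operation *, 16> ops;
    func.walk([&](Operation *op) {
      if (op->getNumRegions() == 0)
        ops.push_back(op);
    });
    for (Operation *op : ops)
      (void)folder.tryToFold(op);
  }
};
} // namespace

The ordering is the design point this patch leans on: uniquing existing constants before the folding loop lets later tryToFold calls reference them rather than materializing fresh duplicates at the top of the entry block.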