diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -412,8 +412,8 @@
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
         auto range = this->getOperation()->getOperands();
-        return {range.begin() + getNumInputsAndOutputBuffers(),
-                range.begin() + getNumInputsAndOutputs()};
+        auto base = range.begin() + getNumInputsAndOutputBuffers();
+        return {base, base + $_op.getNumInitTensors()};
       }]
     >,
     InterfaceMethod<
@@ -739,7 +739,7 @@
     /// allow transformations like tiling to just use the values when cloning
     /// `linalgOp`.
     SmallVector<Value, 4> getAssumedNonShapedOperands() {
-      unsigned numShapedOperands = getNumInputsAndOutputs();
+      unsigned numShapedOperands = getNumShapedOperands();
       unsigned nExtraOperands =
          getOperation()->getNumOperands() - numShapedOperands;
       SmallVector<Value, 4> res;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -25,6 +25,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Support/LLVM.h"
 
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -1641,11 +1642,112 @@
 };
 } // namespace
 
+namespace {
+// Deduplicate redundant args of a linalg op.
+// An arg is redundant if it has the same Value and indexing map as another.
+struct DeduplicateInputs : public RewritePattern {
+  DeduplicateInputs(PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // This pattern reduces the number of arguments of an op, which breaks
+    // the invariants of semantically charged named ops.
+    if (!isa<GenericOp, IndexedGenericOp>(op))
+      return failure();
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Associate each input with an equivalent "canonical" input that has the
+    // same Value and indexing map.
+    //
+    // In the non-duplicate case, input `i` will have canonical input `i`. But
+    // in the case of duplicated inputs, the canonical input could be some
+    // other input `< i`. That is, a later input will have some earlier input
+    // as its canonical input.
+    llvm::SmallDenseMap<std::pair<Value, AffineMap>, int> canonicalInput;
+    // For later remapping tasks like deduplicating payload block arguments,
+    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
+    // convenient.
+    SmallVector<int, 6> canonicalInputIndices;
+    for (int i = 0, e = linalgOp.getNumInputs(); i != e; i++) {
+      Value input = linalgOp.getInput(i);
+      AffineMap indexingMap = linalgOp.getInputIndexingMap(i);
+      // STL-like maps have a convenient behavior for our use case here. In
+      // the case of duplicate keys, the insertion is rejected, and the
+      // returned iterator gives access to the value already in the map.
+      auto pair = canonicalInput.insert({{input, indexingMap}, i});
+      canonicalInputIndices.push_back(pair.first->second);
+    }
+
+    // If there are no duplicate args, then bail out.
+    if (canonicalInput.size() == linalgOp.getNumInputs())
+      return failure();
+
+    // The operands for the newly canonicalized op.
+    SmallVector<Value, 6> newOperands;
+    for (auto v : llvm::enumerate(linalgOp.getInputs()))
+      if (canonicalInputIndices[v.index()] == static_cast<int>(v.index()))
+        newOperands.push_back(v.value());
+    llvm::append_range(newOperands, linalgOp.getOutputBuffers());
+    llvm::append_range(newOperands, linalgOp.getInitTensors());
+    llvm::append_range(newOperands, linalgOp.getAssumedNonShapedOperands());
+
+    // Clone the old op with new operands.
+    Operation *newOp = linalgOp.clone(rewriter, op->getLoc(),
+                                      op->getResultTypes(), newOperands);
+    auto newLinalgOp = cast<LinalgOp>(newOp);
+
+    // Repair the indexing maps by filtering out the ones that have been
+    // eliminated.
+    SmallVector<AffineMap, 6> newIndexingMaps;
+    for (int i = 0, e = newLinalgOp.getNumInputs(); i != e; i++)
+      if (canonicalInputIndices[i] == i)
+        newIndexingMaps.push_back(newLinalgOp.getIndexingMap(i));
+    for (int i = 0, e = newLinalgOp.getNumOutputs(); i != e; i++)
+      newIndexingMaps.push_back(newLinalgOp.getOutputIndexingMap(i));
+    newOp->setAttr("indexing_maps",
+                   rewriter.getAffineMapArrayAttr(newIndexingMaps));
+
+    // Set the number of inputs to the new value. The `clone` call above kept
+    // the value from the original op.
+    newLinalgOp.setNumInputs(canonicalInput.size());
+
+    // linalg.indexed_generic payloads have additional arguments prepended to
+    // the block arg list. The number of such args is one per dimension of the
+    // iteration space.
+    int bbArgBaseOffset = 0;
+    if (isa<IndexedGenericOp>(op))
+      bbArgBaseOffset = newIndexingMaps[0].getNumInputs();
+
+    // Repair the payload entry block by RAUW'ing redundant arguments and
+    // erasing them.
+    Block &payload = newOp->getRegion(0).front();
+    for (int i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
+      // Iterate in reverse, so that we erase later args first, preventing the
+      // argument list from shifting unexpectedly and invalidating all our
+      // indices.
+      int reversed = e - i - 1;
+      int canonicalIndex = canonicalInputIndices[reversed];
+      if (canonicalInputIndices[reversed] == reversed)
+        continue;
+      payload.getArgument(bbArgBaseOffset + reversed)
+          .replaceAllUsesWith(
+              payload.getArgument(bbArgBaseOffset + canonicalIndex));
+      payload.eraseArgument(bbArgBaseOffset + reversed);
+    }
+
+    rewriter.replaceOp(op, newOp->getResults());
+    return success();
+  }
+};
+} // namespace
+
 #define CANONICALIZERS_AND_FOLDERS(XXX) \
   void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \
                                         MLIRContext *context) { \
     results.insert<EraseDeadLinalgOp>(); \
     results.insert<FoldTensorCastOp>(); \
+    results.insert<DeduplicateInputs>(); \
   } \
   \
   LogicalResult XXX::fold(ArrayRef<Attribute>, \
diff --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
@@ -0,0 +1,104 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+
+// Test case: Most basic case. Adding a vector to itself.
+
+#map = affine_map<(d0) -> (d0)>
+
+// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @basic
+func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: linalg.generic{{.*}}[#[[$MAP]], #[[$MAP]]]
+  // CHECK:   ^bb0(%[[BBARG:.*]]: f32):
+  // CHECK:     addf %[[BBARG]], %[[BBARG]]
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg0 : tensor<?xf32>, tensor<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    %1 = addf %arg1, %arg2 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Test case: Different indexing maps mean that args are not redundant, despite
+// being the same Value.
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL: @distinct_affine_maps
+func @distinct_affine_maps(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: linalg.generic{{.*}}[#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]]
+  %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    %1 = addf %arg1, %arg2 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// Test case: Check rewriting mechanics for mixed redundant and
+// non-redundant args.
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL: @mixed_redundant_non_redundant
+func @mixed_redundant_non_redundant(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: linalg.generic{{.*}}[#[[$MAP0]], #[[$MAP1]], #[[$MAP0]]]
+  // CHECK:   ^bb0(%[[BBARG0:.*]]: f32, %[[BBARG1:.*]]: f32):
+  // CHECK:     "test.elementwise_mappable"(%[[BBARG0]], %[[BBARG1]], %[[BBARG0]])
+  %0 = linalg.generic {indexing_maps = [#map0, #map1, #map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %1 = "test.elementwise_mappable"(%arg1, %arg2, %arg3) : (f32, f32, f32) -> f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// Test case: Check rewriting mechanics for multiple different redundant args.
+
+#map = affine_map<(d0) -> (d0)>
+
+// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @multiple_different_redundant_args
+func @multiple_different_redundant_args(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: linalg.generic{{.*}}[#[[$MAP]], #[[$MAP]], #[[$MAP]]]
+  // CHECK:   ^bb0(%[[BBARG0:.*]]: f32, %[[BBARG1:.*]]: f32):
+  // CHECK:     "test.elementwise_mappable"(%[[BBARG0]], %[[BBARG1]], %[[BBARG0]], %[[BBARG1]])
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1, %arg0, %arg1 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32, %arg4: f32, %arg5: f32):
+    %1 = "test.elementwise_mappable"(%arg2, %arg3, %arg4, %arg5) : (f32, f32, f32, f32) -> f32
+    linalg.yield %1 : f32
+  } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// -----
+
+// Test case: linalg.indexed_generic.
+// Other than the payload argument handling, everything else is the same.
+
+#map = affine_map<(d0) -> (d0)>
+
+// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @indexed_generic
+func @indexed_generic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  // CHECK: linalg.indexed_generic
+  // CHECK:   ^bb0(%{{.*}}: index, %[[BBARG:.*]]: f32):
+  // CHECK:     addf %[[BBARG]], %[[BBARG]]
+  %0 = linalg.indexed_generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg0 : tensor<?xf32>, tensor<?xf32>) {
+  ^bb0(%index: index, %arg1: f32, %arg2: f32):
+    %1 = addf %arg1, %arg2 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
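
For reference, the canonicalized form of the @basic test above is expected to look roughly as follows. This is a sketch inferred from the CHECK lines, not verbatim mlir-opt output, and the exact value names may differ:

  #map = affine_map<(d0) -> (d0)>
  func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
    // The duplicated input has been collapsed: one input, one payload block
    // argument, and only two indexing maps (input and output) remain.
    %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<?xf32>) {
    ^bb0(%arg1: f32):
      %1 = addf %arg1, %arg1 : f32
      linalg.yield %1 : f32
    } -> tensor<?xf32>
    return %0 : tensor<?xf32>
  }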