diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -105,6 +105,34 @@
   return success(folded);
 }
 
+/// Helper function to find if there is at least one dimension in the
+/// AffineMap at position testMapLocation of Maps that is not contained
+/// in any of the other maps.
+static bool hasaUniqueDim(SmallVector<AffineMap> Maps,
+                          unsigned testMapLocation) {
+  AffineMap testMap = Maps[testMapLocation];
+  llvm::SmallDenseSet<unsigned> dimsToCheck;
+  for (auto result : testMap.getResults()) {
+    auto expr = result.dyn_cast<AffineDimExpr>();
+    if (expr != nullptr)
+      dimsToCheck.insert(expr.getPosition());
+  }
+  for (auto It : llvm::enumerate(Maps)) {
+    if (It.index() == testMapLocation)
+      continue;
+    auto map = It.value();
+    for (auto result : map.getResults()) {
+      auto expr = result.dyn_cast<AffineDimExpr>();
+      if (expr != nullptr) {
+        dimsToCheck.erase(expr.getPosition());
+      }
+      if (dimsToCheck.empty())
+        return false;
+    }
+  }
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // Region builder helper.
 // TODO: Move this to a utility library.
@@ -870,11 +898,95 @@
     return success();
   }
 };
+
+/// Drop dead args of a linalg generic op.
+/// An arg is dead if it has zero uses in the op region.
+struct DeadArgsGenericOpInputs : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<AffineMap> oldIndexingMaps = genericOp.getIndexingMaps();
+    // Maps must be projected permutations.
+    if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) {
+          return !map.isProjectedPermutation();
+        }))
+      return failure();
+    Block &payload = genericOp.region().front();
+    SmallVector<Value> newInputOperands;
+    SmallVector<AffineMap> newIndexingMaps;
+    bool deadArgFound = false;
+    int inputSize = genericOp.getInputOperands().size();
+    // Iterate in reverse, so that we erase later args first, preventing the
+    // argument list from shifting unexpectedly and invalidating all our
+    // indices.
+    for (int i = inputSize - 1; i >= 0; i--) {
+      OpOperand *opOperand = genericOp.getInputOperand(i);
+      if (payload.getArgument(i).use_empty() &&
+          !hasaUniqueDim(oldIndexingMaps, i)) {
+        payload.eraseArgument(i);
+        deadArgFound = true;
+        // Remove this indexing map from consideration for the hasaUniqueDim
+        // check.
+        oldIndexingMaps.erase(oldIndexingMaps.begin() + i);
+      } else {
+        newInputOperands.insert(newInputOperands.begin(), opOperand->get());
+        newIndexingMaps.insert(newIndexingMaps.begin(),
+                               genericOp.getTiedIndexingMap(opOperand));
+      }
+    }
+    // Bail out if there are no dead args.
+    if (!deadArgFound)
+      return failure();
+    for (OpOperand *opOperand : genericOp.getOutputOperands())
+      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+
+    auto newOp = rewriter.create<GenericOp>(
+        genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
+        outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        genericOp.iterator_types(), genericOp.docAttr(),
+        genericOp.library_callAttr());
+    // Copy over unknown attributes. They might be load-bearing for some flow.
+    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
+    for (NamedAttribute kv : genericOp->getAttrs()) {
+      if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
+        newOp->setAttr(kv.getName(), kv.getValue());
+      }
+    }
+    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
+                                newOp.region().begin());
+    rewriter.replaceOp(genericOp, newOp->getResults());
+    return success();
+  }
+};
+
+/// Fold linalg.fill into linalg.generic.
+struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    bool fillFound = false;
+    Block &payload = genericOp.region().front();
+    for (OpOperand *opOperand : genericOp.getInputOperands()) {
+      FillOp fillOp = opOperand->get().getDefiningOp<FillOp>();
+      if (fillOp) {
+        fillFound = true;
+        payload.getArgument(opOperand->getOperandNumber())
+            .replaceAllUsesWith(fillOp.value());
+      }
+    }
+    // Fail if there are no FillOps to fold.
+    return success(fillFound);
+  }
+};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
+  results.add<DeadArgsGenericOpInputs, DeduplicateGenericOpInputs,
+              EraseIdentityGenericOp, FoldFillWithGenericOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -325,6 +325,106 @@
 
 // -----
 
+// CHECK-LABEL: func @fold_fill_generic_basic
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK-NOT: linalg.fill
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func @fold_fill_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = linalg.init_tensor [%0] : tensor<?xf32>
+  %2 = linalg.fill(%cst, %1) : f32, tensor<?xf32> -> tensor<?xf32>
+  %3 = linalg.init_tensor [%0] : tensor<?xf32>
+  %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor<?xf32>, tensor<?xf32>) outs (%3:tensor<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %5 = arith.addf %arg1, %arg2 : f32
+    linalg.yield %5 : f32
+  } -> tensor<?xf32>
+  return %4 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_fill_generic_mixedaccess
+// CHECK-NOT: linalg.fill
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-NOT: ins
+// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @fold_fill_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst1 = arith.constant 7.0 : f32
+  %cst2 = arith.constant 6.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %3 = linalg.fill(%cst1, %2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+  %4 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
+  %5 = linalg.fill(%cst2, %4) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+  %6 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%6:tensor<?x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %8 = arith.divf %arg1, %arg2 : f32
+    linalg.yield %8 : f32
+  } -> tensor<?x?xf32>
+  return %7 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_deadargs_generic_basic
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func @remove_deadargs_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = linalg.init_tensor [%0] : tensor<?xf32>
+  %2 = linalg.init_tensor [%0] : tensor<?xf32>
+  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %1 : tensor<?xf32>, tensor<?xf32>) outs (%2:tensor<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %4 = arith.addf %arg1, %cst : f32
+    linalg.yield %4 : f32
+  } -> tensor<?xf32>
+  return %3 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_deadargs_generic_mixedaccess
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-NOT: ins
+// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func @remove_deadargs_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst1 = arith.constant 7.0 : f32
+  %cst2 = arith.constant 6.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %3 = linalg.init_tensor [%1, %0] : tensor<?x?xf32>
+  %4 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %5 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%2, %3 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%4:tensor<?x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %6 = arith.divf %cst1, %cst2 : f32
+    linalg.yield %6 : f32
+  } -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+
+// -----
 // CHECK-LABEL: func @fold_fill_reshape()
 func @fold_fill_reshape() -> tensor<6x4xf32> {
   %zero = arith.constant 0.0 : f32
diff --git a/mlir/test/Dialect/Linalg/fusion-indexed.mlir b/mlir/test/Dialect/Linalg/fusion-indexed.mlir
--- a/mlir/test/Dialect/Linalg/fusion-indexed.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-indexed.mlir
@@ -46,7 +46,8 @@
       %10 = arith.index_cast %7 : index to i32
       %11 = arith.sitofp %10 : i32 to f32
      %12 = arith.addf %9, %11 : f32
-      linalg.yield %12 : f32
+      %13 = arith.addf %12, %arg4 : f32
+      linalg.yield %13 : f32
     }
   }
 }
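
Note: beyond the FileCheck tests above, the following is a minimal sketch (not part of the patch) of how the new rewrites could be exercised programmatically. GenericOp::getCanonicalizationPatterns is the hook this patch extends; the pass name and flag here are hypothetical, and the include paths assume the MLIR source layout of this patch's era:

  // Hypothetical test pass (illustration only): collect GenericOp's
  // canonicalization patterns -- which, after this patch, include
  // DeadArgsGenericOpInputs and FoldFillWithGenericOp -- and apply them
  // to a module with the standard greedy rewrite driver.
  #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
  #include "mlir/Pass/Pass.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  using namespace mlir;

  namespace {
  struct TestGenericCanonicalizePass
      : public PassWrapper<TestGenericCanonicalizePass,
                           OperationPass<ModuleOp>> {
    StringRef getArgument() const final { return "test-generic-canonicalize"; }
    void runOnOperation() override {
      MLIRContext *context = &getContext();
      RewritePatternSet patterns(context);
      // Pulls in every pattern registered by GenericOp, including the two
      // added in this patch.
      linalg::GenericOp::getCanonicalizationPatterns(patterns, context);
      if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                              std::move(patterns))))
        signalPassFailure();
    }
  };
  } // namespace

In an ordinary workflow the same rewrites fire as part of mlir-opt -canonicalize, which is how the tests above drive them.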