diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1717,6 +1717,85 @@
     return success();
   }
 };
+
+/// Replaces std.dim operations that use the result of a LinalgOp (on tensors)
+/// with std.dim operations that use one of the arguments. For example,
+///
+/// %0 = linalg.matmul ins(%arg0, %arg1, ...)
+/// %1 = dim %0, %c0
+///
+/// with
+///
+/// %1 = dim %arg0, %c0
+///
+/// where possible. With this the result of the `linalg.matmul` is not used in
+/// dim operations. If the value produced is replaced with another value (say
+/// by tiling `linalg.matmul`), that will make the `linalg.matmul` truly dead
+/// instead of used in a dim op that would prevent the DCE of this op.
+struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Value dimValue = dimOp.memrefOrTensor();
+    // Only a statically-known dim index can be matched against an indexing
+    // map result position.
+    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+    if (!dimIndex)
+      return failure();
+    auto linalgOp = dimValue.getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      return failure();
+
+    // Get the indexing map of the current result value which is the operand of
+    // the dim op.
+    auto indexingMaps = linalgOp.getIndexingMaps();
+    unsigned resultIndex = dimValue.cast<OpResult>().getResultNumber();
+    // Output indexing maps follow the input indexing maps in the list.
+    AffineMap resultIndexMap =
+        indexingMaps[resultIndex + linalgOp.getNumInputs()];
+    AffineDimExpr resultShapeDimExpr =
+        resultIndexMap.getResult(*dimIndex).dyn_cast<AffineDimExpr>();
+    if (!resultShapeDimExpr)
+      return failure();
+    unsigned resultDimPosition = resultShapeDimExpr.getPosition();
+    // If there is another indexing map from the input that uses the same
+    // AffineDimExpr as the output, then we can use that instead. For example,
+    // in `linalg.matmul`
+    //
+    //   LHS indexing map    : affine_map<(d0, d1, d2) -> (d0, d2)>
+    //   Output indexing map : affine_map<(d0, d1, d2) -> (d0, d1)>
+    //
+    // so
+    //
+    //   dim %{{matmul-output}}, %c0
+    //
+    // is same as
+    //
+    //   dim %{{matmul-LHS}}, %c0
+    //
+    // since
+    //   {{LHS indexing map}}.getResult(0) == {{Output indexing map}}.getResult(0)
+    for (unsigned inputIndex :
+         llvm::seq<unsigned>(0, linalgOp.getNumInputs())) {
+      unsigned operandIndex =
+          linalgOp.getOperandIndexForInputIndex(inputIndex).getValue();
+      for (auto inputMapResultExpr :
+           llvm::enumerate(indexingMaps[operandIndex].getResults())) {
+        AffineDimExpr inputMapResultDimExpr =
+            inputMapResultExpr.value().dyn_cast<AffineDimExpr>();
+        if (!inputMapResultDimExpr)
+          continue;
+        unsigned inputMapResultDimPosition =
+            inputMapResultDimExpr.getPosition();
+        if (inputMapResultDimPosition != resultDimPosition)
+          continue;
+        // The input dimension tracks the same loop dimension as the result
+        // dimension being queried; query the input instead.
+        rewriter.replaceOpWithNewOp<DimOp>(dimOp,
+                                           linalgOp.getInput(inputIndex),
+                                           inputMapResultExpr.index());
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -1825,6 +1904,7 @@
     results.insert<DeduplicateInputs>();                                       \
     results.insert<EraseDeadLinalgOp>();                                       \
     results.insert<FoldTensorCastOp>();                                        \
+    results.insert<ReplaceDimOfLinalgOpResult>(context);                       \
  }                                                                             \
                                                                                \
  LogicalResult XXX::fold(ArrayRef<Attribute>,                                  \
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -351,3 +351,42 @@
     outs(%b : memref<?xf32>)
   return
 }
+
+// -----
+
+func @remove_dim_result_uses
+  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+   %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    init(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = dim %0, %c0 : tensor<?x?xf32>
+  %2 = dim %0, %c1 : tensor<?x?xf32>
+  %3 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
+                      affine_map<(d0, d1, d2) -> (d0, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d2)>],
+     iterator_types = ["parallel", "reduction", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    init(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+    %4 = mulf %arg3, %arg4 : f32
+    %5 = addf %4, %arg5 : f32
+    linalg.yield %5 : f32
+  } -> tensor<?x?xf32>
+  %6 = dim %3, %c0 : tensor<?x?xf32>
+  %7 = dim %3, %c1 : tensor<?x?xf32>
+  return %1, %2, %6, %7 : index, index, index, index
+}
+// CHECK-LABEL: func @remove_dim_result_uses
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[T1:.+]] = dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[T2:.+]] = dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[T3:.+]] = dim %[[ARG1]], %[[C1]]
+//       CHECK:   return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
\ No newline at end of file