diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -32,6 +32,7 @@
     the op semantics.
   }];
   let cppNamespace = "::mlir::linalg";
+  let dependentDialects = ["AffineDialect", "StandardOpsDialect"];
 }
 
 // Whether a type is a RangeType.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -542,9 +542,12 @@
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        SmallVector<Type, 4> inputOutputTypes(
-            this->getOperation()->operand_type_begin(),
-            this->getOperation()->operand_type_end());
+        SmallVector<Type, 4> inputOutputTypes =
+            llvm::to_vector<4>(this->getOperation()->getOperandTypes());
+        if ($_op.getNumInitTensors()) {
+          inputOutputTypes.resize(
+              inputOutputTypes.size() - $_op.getNumInitTensors());
+        }
         inputOutputTypes.append(this->getOperation()->result_type_begin(),
                                 this->getOperation()->result_type_end());
         return llvm::to_vector<4>(
@@ -898,6 +901,12 @@
        }
        return res;
      }
+
+     /// Returns the value that expresses the shape of the output in terms of
+     /// the shapes of the input operands, where possible.
+     Optional<Value> inferResultDimFromInputShapes(
+         OpBuilder &b, Location loc, unsigned resultIdx, unsigned dim);
+
      //========================================================================//
      // Helper functions to mutate the `operand_segment_sizes` attribute.
      // These are useful when cloning and changing operand types.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_LINALG_LINALGTYPES_H_
 #define MLIR_DIALECT_LINALG_LINALGTYPES_H_
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Types.h"
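Note on the `dependentDialects` entry above: the canonicalization added below materializes `affine.apply` and `std.dim` ops from within Linalg, so the Affine and Standard dialects must be loaded whenever the Linalg dialect is. A rough sketch of what ODS emits for that declaration (an assumed shape; the exact generated code may differ):

    LinalgDialect::LinalgDialect(MLIRContext *context)
        : Dialect(getDialectNamespace(), context,
                  TypeID::get<LinalgDialect>()) {
      // Dependent dialects are loaded while this dialect is constructed, so
      // patterns that run later can create their ops safely.
      getContext()->loadDialect<AffineDialect, StandardOpsDialect>();
      initialize();
    }

This is also why the explicit getOrLoadDialect("std") call in LinalgTypes.cpp is removed further down.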
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -159,29 +159,29 @@
   // Default visit methods. Note that the default op-specific binary op visit
   // methods call the general visitAffineBinaryOpExpr visit method.
-  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {}
-  void visitAddExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
+  RetTy visitAddExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitMulExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitMulExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitModExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitModExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitFloorDivExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitCeilDivExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitConstantExpr(AffineConstantExpr expr) {}
-  void visitDimExpr(AffineDimExpr expr) {}
-  void visitSymbolExpr(AffineSymbolExpr expr) {}
+  RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
+  RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
+  RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
 
 private:
   // Walk the operands - each operand is itself walked in post order.
-  void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
+  RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) {
     walkPostOrder(expr.getLHS());
     walkPostOrder(expr.getRHS());
   }
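The AffineExprVisitor change threads the visitor's `RetTy` template parameter through the default visit methods, so a visitor can return a value instead of only mutating state. A minimal sketch of the kind of visitor this enables (an assumed example, not part of the patch):

    #include "mlir/IR/AffineExprVisitor.h"

    // Counts binary nodes in an affine expression. Leaves fall through to the
    // default visit methods, which now return RetTy(), i.e. 0 here.
    struct CountBinaryOps
        : public mlir::AffineExprVisitor<CountBinaryOps, unsigned> {
      unsigned visitAffineBinaryOpExpr(mlir::AffineBinaryOpExpr expr) {
        return 1 + visit(expr.getLHS()) + visit(expr.getRHS());
      }
    };

    // Usage: unsigned numBinOps = CountBinaryOps().visit(expr);

The `HasAffineDimExprVisitor` added below relies on exactly this: it instantiates the visitor with `bool` as the return type.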
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -16,12 +16,14 @@
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
@@ -86,6 +88,95 @@
   return res;
 }
 
+/// Visitor to check if any of the given set of positions from AffineDimExprs
+/// are used within an AffineExpr.
+struct HasAffineDimExprVisitor
+    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
+  HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions)
+      : positions(positions) {}
+
+  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
+    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
+  }
+
+  bool visitDimExpr(AffineDimExpr dimExpr) {
+    return positions.count(dimExpr.getPosition());
+  }
+
+  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
+
+  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
+
+private:
+  llvm::SmallSet<unsigned, 4> positions;
+};
+
+Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
+                                                        Location loc,
+                                                        unsigned resultIdx,
+                                                        unsigned dim) {
+  // An example that helps understand the logic below. Consider the expression
+  // O(i+j, j) += A(i, k) * B(k, j). We want to express the shape of dim 0 of
+  // O in terms of the shapes of the inputs. This is achieved as follows:
+  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
+  //   subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
+  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
+  //   resultFromInputDim = subMapOfResultDim.compose(shapesToLoopsMap)
+  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d3)
+  AffineMap loopsToShapesMap = getLoopsToShapesMap();
+
+  // Find the position in the above map that represents the shape of the
+  // result:dim being inferred. It is the sum of the ranks of all operands
+  // preceding the result operand, plus `dim`.
+  SmallVector<ShapedType, 4> inputOutputShapedTypes =
+      getInputOutputShapedTypes();
+  Optional<unsigned> resultOperandIndex =
+      getOperandIndexForOutputIndex(resultIdx);
+  if (!resultOperandIndex ||
+      *resultOperandIndex >= inputOutputShapedTypes.size())
+    return llvm::None;
+  unsigned resultDimSubMapPos = 0;
+  for (unsigned idx : llvm::seq<unsigned>(0, *resultOperandIndex)) {
+    resultDimSubMapPos += inputOutputShapedTypes[idx].getRank();
+  }
+  resultDimSubMapPos += dim;
+
+  // From loopsToShapesMap, extract the submap that represents the shape of
+  // the (resultIdx, dim) pair of interest.
+  if (resultDimSubMapPos >= loopsToShapesMap.getNumResults())
+    return llvm::None;
+  AffineMap loopToResultDimShapeMap =
+      loopsToShapesMap.getSubMap(resultDimSubMapPos);
+  AffineMap operandShapesToResultDimMap =
+      loopToResultDimShapeMap.compose(getShapesToLoopsMap());
+
+  // Check that the result dim map does not contain the positions
+  // corresponding to the outputs.
+  unsigned numInputDims = 0;
+  for (unsigned idx : llvm::seq<unsigned>(0, getNumInputs())) {
+    numInputDims +=
+        inputOutputShapedTypes[getOperandIndexForInputIndex(idx).getValue()]
+            .getRank();
+  }
+  llvm::SmallSet<unsigned, 4> outputDims;
+  llvm::for_each(llvm::seq<unsigned>(numInputDims,
+                                     operandShapesToResultDimMap.getNumDims()),
+                 [&outputDims](unsigned dim) { outputDims.insert(dim); });
+  HasAffineDimExprVisitor checkDimExpr(outputDims);
+  if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
+    return llvm::None;
+
+  // `createFlatListOfOperandDims` covers the output dims only when init
+  // tensors are present. Without init tensors (but with tensor results), the
+  // map still has dims for the output shapes, so project those dims out of
+  // the map before applying it to the operand dims.
+  if (getNumInitTensors() == 0 && getOperation()->getNumResults() != 0) {
+    operandShapesToResultDimMap = getProjectedMap(
+        operandShapesToResultDimMap, llvm::to_vector<4>(outputDims));
+  }
+  return applyMapToValues(b, loc, operandShapesToResultDimMap,
+                          createFlatListOfOperandDims(b, loc))[0];
+}
+
 /// Forward declarations.
 template <typename NamedStructuredOpType>
 static void buildNamedStructuredOpRegionAndAttributes(
@@ -1717,6 +1808,45 @@
     return success();
   }
 };
+
+/// Replaces std.dim operations that use the result of a LinalgOp (on tensors)
+/// with std.dim operations that use one of the arguments. For example,
+///
+///   %0 = linalg.matmul ins(%arg0, %arg1, ...)
+///   %1 = dim %0, %c0
+///
+/// is replaced with
+///
+///   %1 = dim %arg0, %c0
+///
+/// where possible. With this, the result of the `linalg.matmul` is no longer
+/// used in dim operations. If the result is later replaced with another value
+/// (say by tiling `linalg.matmul`), the `linalg.matmul` becomes truly dead
+/// instead of being kept alive by a dim op that would otherwise prevent its
+/// DCE.
+struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Value dimValue = dimOp.memrefOrTensor();
+    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+    if (!dimIndex)
+      return failure();
+    auto linalgOp = dimValue.getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      return failure();
+
+    unsigned resultIndex = dimValue.cast<OpResult>().getResultNumber();
+    Optional<Value> operandDimValue = linalgOp.inferResultDimFromInputShapes(
+        rewriter, dimOp.getLoc(), resultIndex,
+        static_cast<unsigned>(*dimIndex));
+    if (!operandDimValue)
+      return failure();
+    rewriter.replaceOp(dimOp, *operandDimValue);
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -1825,6 +1955,7 @@
     results.insert<DeduplicateInputs>();                                      \
     results.insert<EraseDeadLinalgOp>();                                      \
     results.insert<FoldTensorCastOp>();                                       \
+    results.insert<ReplaceDimOfLinalgOpResult>(context);                      \
   }                                                                           \
                                                                               \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                \
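The map algebra in `inferResultDimFromInputShapes` can be sanity-checked standalone against the worked example in its comment. A sketch using the public AffineMap API (a hypothetical helper, not part of the patch):

    #include "mlir/IR/AffineExpr.h"
    #include "mlir/IR/AffineMap.h"
    #include "mlir/IR/MLIRContext.h"

    // Reproduces the comment's example O(i+j, j) += A(i, k) * B(k, j).
    void checkMatmulLikeExample() {
      mlir::MLIRContext ctx;
      mlir::AffineExpr d0, d1, d2;
      mlir::bindDims(&ctx, d0, d1, d2);
      // loopsToShapesMap: (i, j, k) -> (A0, A1, B0, B1, O0, O1).
      auto loopsToShapes =
          mlir::AffineMap::get(3, 0, {d0, d2, d2, d1, d0 + d1, d1}, &ctx);
      // O's dim 0 sits at shape position 4; extract that single result:
      // (d0, d1, d2) -> (d0 + d1).
      mlir::AffineMap subMap = loopsToShapes.getSubMap({4});
      // A right inverse of loopsToShapes: i = s0, j = s3, k = s2.
      auto shapesToLoops = mlir::AffineMap::get(
          6, 0,
          {mlir::getAffineDimExpr(0, &ctx), mlir::getAffineDimExpr(3, &ctx),
           mlir::getAffineDimExpr(2, &ctx)},
          &ctx);
      // Composition yields (d0, ..., d5) -> (d0 + d3): O's dim 0 equals
      // A's dim 0 plus B's dim 1, using only input shape dimensions.
      mlir::AffineMap result = subMap.compose(shapesToLoops);
      (void)result;
    }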
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -58,7 +58,6 @@
 //===----------------------------------------------------------------------===//
 
 void mlir::linalg::LinalgDialect::initialize() {
-  getContext()->getOrLoadDialect("std");
   addTypes<RangeType>();
   addOperations<
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -351,3 +351,101 @@
     outs(%b : memref<?xf32>)
   return
 }
+
+// -----
+
+func @remove_dim_result_uses
+  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+   %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+                     init(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = dim %0, %c0 : tensor<?x?xf32>
+  %2 = dim %0, %c1 : tensor<?x?xf32>
+  %3 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
+                      affine_map<(d0, d1, d2) -> (d0, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d2)>],
+     iterator_types = ["parallel", "reduction", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    init(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+    %4 = mulf %arg3, %arg4 : f32
+    %5 = addf %4, %arg5 : f32
+    linalg.yield %5 : f32
+  } -> tensor<?x?xf32>
+  %6 = dim %3, %c0 : tensor<?x?xf32>
+  %7 = dim %3, %c1 : tensor<?x?xf32>
+  return %1, %2, %6, %7 : index, index, index, index
+}
+// CHECK-LABEL: func @remove_dim_result_uses
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[T1:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[T2:.+]] = dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[T3:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK: return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
+
+// -----
+
+func @remove_dim_result_uses2
+  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+   %arg2 : tensor<?x?xf32>) -> (index) {
+  %c0 = constant 0 : index
+  %0 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                      affine_map<(d0, d1, d2) -> (d2, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0 + d1, d1)>],
+     iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    init(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+    %1 = mulf %arg3, %arg4 : f32
+    %2 = addf %1, %arg5 : f32
+    linalg.yield %2 : f32
+  } -> tensor<?x?xf32>
+  %3 = dim %0, %c0 : tensor<?x?xf32>
+  return %3 : index
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @remove_dim_result_uses2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[T1:.+]] = dim %[[ARG1]], %[[C1]]
+// CHECK: %[[T2:.+]] = affine.apply #[[MAP]]()[%[[T0]], %[[T1]]]
+// CHECK: return %[[T2]]
+
+// -----
+
+func @keep_result_dim_uses(%arg0 : tensor<?xf32>) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<?xf32>) {
+  ^bb0(%arg1: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<?x?xf32>
+  %1 = dim %0, %c0 : tensor<?x?xf32>
+  %2 = dim %0, %c1 : tensor<?x?xf32>
+  return %1, %2 : index, index
+}
+// CHECK: func @keep_result_dim_uses
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[OP:.+]] = linalg.generic
+// CHECK-DAG: %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[T1:.+]] = dim %[[OP]], %[[C1]]
+// CHECK: return %[[T0]], %[[T1]]
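In @keep_result_dim_uses above, only dim 0 of the result folds to a dim of the input: loop d1 appears in no input indexing map, so the shape of the result along dim 1 can only be recovered from the output itself. Conceptually, inside inferResultDimFromInputShapes the check reduces to the following (illustrative values, assuming this sits next to the visitor in LinalgOps.cpp and an MLIRContext ctx is in scope):

    // Flattened shape dims: s0 = input dim 0, s1/s2 = output dims 0/1, so the
    // forbidden positions (owned by the output) are {1, 2}.
    llvm::SmallSet<unsigned, 4> outputDims;
    outputDims.insert(1);
    outputDims.insert(2);
    HasAffineDimExprVisitor hasOutputDim(outputDims);
    // Result dim 0 composes to s0: not an output dim, so the dim op is
    // rewritten to `dim %arg0, %c0`.
    bool keep0 = hasOutputDim.visit(mlir::getAffineDimExpr(0, &ctx)); // false
    // Result dim 1 composes to s2: an output dim, so `dim %0, %c1` is kept.
    bool keep1 = hasOutputDim.visit(mlir::getAffineDimExpr(2, &ctx)); // true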