diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -32,6 +32,9 @@
     the op semantics.
   }];
   let cppNamespace = "::mlir::linalg";
+  let dependentDialects = [
+    "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
+  ];
 }
 
 // Whether a type is a RangeType.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -946,6 +946,60 @@
         return inversePermutation(getLoopsToShapesMap());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the position in the results of the affine map computed
+        by getLoopsToShapesMap() that represents the shape of an
+        operand (input or output) at a dimension. Returns None if
+        `operandIdx` is out of range; `dim` is not checked against the
+        rank of that operand.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getOperandDimPositionInLoopsToShapeMap",
+      /*args=*/(ins "unsigned":$operandIdx, "unsigned":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        unsigned pos = 0;
+        for (auto type : llvm::enumerate(getShapedOperandTypes())) {
+          if (type.index() == operandIdx) return pos + dim;
+          pos += type.value().getRank();
+        }
+        return {};
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the position in the results of the affine map computed
+        by getLoopsToShapesMap() that represents the shape of an
+        input operand at a dimension. Returns None if `inputIdx` is not
+        less than the number of inputs.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getInputValueDimPositionInLoopsToShapeMap",
+      /*args=*/(ins "unsigned":$inputIdx, "unsigned":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if (inputIdx >= getNumInputs()) return {};
+        return getOperandDimPositionInLoopsToShapeMap(inputIdx, dim);
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the position in the results of the affine map computed
+        by getLoopsToShapesMap() that represents the shape of the
+        result value at a dimension. Returns None if `resultIdx` is not
+        less than the number of outputs.
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getResultValueDimPositionInLoopsToShapeMap",
+      /*args=*/(ins "unsigned":$resultIdx, "unsigned":$dim),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        if (resultIdx >= getNumOutputs()) return {};
+        return getOperandDimPositionInLoopsToShapeMap(
+            getNumInputs() + resultIdx, dim);
+      }]
+    >,
 
     //===------------------------------------------------------------------===//
     // Other static interface methods.
@@ -1027,6 +1081,15 @@
       }
       return res;
     }
+
+    /// Returns the value that expresses the shape of result `resultIdx` at
+    /// dimension `dim` in terms of the shapes of the input operands, where
+    /// possible. Returns llvm::None otherwise, including when `resultIdx` is
+    /// out of range.
+    Optional<Value> inferResultDimFromInputShapes(OpBuilder &b, Location loc,
+                                                  unsigned resultIdx,
+                                                  unsigned dim);
+
     //========================================================================//
     // Helper functions to mutate the `operand_segment_sizes` attribute.
     // These are useful when cloning and changing operand types.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -9,6 +9,9 @@
 #ifndef MLIR_DIALECT_LINALG_LINALGTYPES_H_
 #define MLIR_DIALECT_LINALG_LINALGTYPES_H_
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Types.h"
 
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h
--- a/mlir/include/mlir/IR/AffineExprVisitor.h
+++ b/mlir/include/mlir/IR/AffineExprVisitor.h
@@ -159,29 +159,29 @@
 
   // Default visit methods. Note that the default op-specific binary op visit
   // methods call the general visitAffineBinaryOpExpr visit method.
-  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {}
-  void visitAddExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); }
+  RetTy visitAddExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitMulExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitMulExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitModExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitModExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitFloorDivExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitCeilDivExpr(AffineBinaryOpExpr expr) {
-    static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
+  RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) {
+    return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
   }
-  void visitConstantExpr(AffineConstantExpr expr) {}
-  void visitDimExpr(AffineDimExpr expr) {}
-  void visitSymbolExpr(AffineSymbolExpr expr) {}
+  RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); }
+  RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); }
+  RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); }
 
 private:
-  // Walk the operands - each operand is itself walked in post order.
+  // Walk the operands in post order; it produces no value, so it stays void.
   void walkOperandsPostOrder(AffineBinaryOpExpr expr) {
     walkPostOrder(expr.getLHS());
     walkPostOrder(expr.getRHS());
   }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -16,12 +16,14 @@
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
@@ -86,6 +88,82 @@
   return res;
 }
 
+/// Visitor that checks whether any AffineDimExpr whose position is in the
+/// given set of positions is used within an AffineExpr.
+struct HasAffineDimExprVisitor
+    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
+  explicit HasAffineDimExprVisitor(const llvm::SmallSet<unsigned, 4> &positions)
+      : positions(positions) {}
+
+  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
+    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
+  }
+
+  bool visitDimExpr(AffineDimExpr dimExpr) {
+    return positions.count(dimExpr.getPosition());
+  }
+
+  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
+
+  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
+
+private:
+  llvm::SmallSet<unsigned, 4> positions;
+};
+
+Optional<Value> LinalgOp::inferResultDimFromInputShapes(OpBuilder &b,
+                                                        Location loc,
+                                                        unsigned resultIdx,
+                                                        unsigned dim) {
+  // An example that helps understand the logic below.
+  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
+  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
+  // This is achieved as follows.
+  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
+  //   subMapOfResultDim = (d0, d1, d2) -> (d0 + d1)
+  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d1)
+  //   operandShapesToResultDimMap = subMapOfResultDim.compose(shapesToLoopsMap)
+  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d3)
+  AffineMap loopsToShapesMap = getLoopsToShapesMap();
+
+  // Find the position in the above map that represents the shape of the
+  // result:dim being inferred.
+  Optional<unsigned> resultDimSubMapPos =
+      getResultValueDimPositionInLoopsToShapeMap(resultIdx, dim);
+  if (!resultDimSubMapPos)
+    return {};
+
+  // From loopsToShapesMap extract the submap that represents the shape of the
+  // (resultIdx, dim) needed.
+  AffineMap loopToResultDimShapeMap =
+      loopsToShapesMap.getSubMap(*resultDimSubMapPos);
+  AffineMap operandShapesToResultDimMap =
+      loopToResultDimShapeMap.compose(getShapesToLoopsMap());
+
+  // Check that the result dim map does not contain the positions corresponding
+  // to the outputs. Outputs are laid out after the inputs, so their dims
+  // occupy the half-open range [outputDimPosStart, outputDimPosEnd).
+  llvm::SmallSet<unsigned, 4> outputDims;
+  unsigned outputDimPosStart =
+      getResultValueDimPositionInLoopsToShapeMap(0, 0).getValue();
+  // The end is one past the last dim of the last output. Deriving it from the
+  // rank (rather than from position `rank - 1`) keeps that last dim in the
+  // range and avoids an unsigned underflow for a rank-0 last output.
+  unsigned lastOutputRank =
+      getOutputOpOperands().back().get().getType().cast<ShapedType>().getRank();
+  unsigned outputDimPosEnd =
+      getResultValueDimPositionInLoopsToShapeMap(getNumOutputs() - 1, 0)
+          .getValue() +
+      lastOutputRank;
+  for (unsigned pos : llvm::seq<unsigned>(outputDimPosStart, outputDimPosEnd))
+    outputDims.insert(pos);
+  HasAffineDimExprVisitor checkDimExpr(outputDims);
+  if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0)))
+    return llvm::None;
+  return applyMapToValues(b, loc, operandShapesToResultDimMap,
+                          createFlatListOfOperandDims(b, loc))[0];
+}
+
 /// Forward declarations.
 template <typename NamedStructuredOpType>
 static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
@@ -2022,6 +2100,49 @@
     return success();
   }
 };
+
+/// Replaces std.dim operations that use the result of a LinalgOp (on tensors)
+/// with std.dim operations that use one of the arguments. For example,
+///
+///   %0 = linalg.matmul ins(%arg0, %arg1, ...)
+///   %1 = dim %0, %c0
+///
+/// with
+///
+///   %1 = dim %arg0, %c0
+///
+/// where possible (otherwise the dim of the corresponding `outs` operand is
+/// used). Either way the result is no longer used in dim operations, so if it
+/// is replaced with another value (say by tiling `linalg.matmul`), the
+/// `linalg.matmul` becomes truly dead instead of being kept alive by a dim op.
+struct ReplaceDimOfLinalgOpResult : public OpRewritePattern<DimOp> {
+  using OpRewritePattern<DimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DimOp dimOp,
+                                PatternRewriter &rewriter) const override {
+    Value dimValue = dimOp.memrefOrTensor();
+    Optional<int64_t> dimIndex = dimOp.getConstantIndex();
+    // Negative indices would wrap around in the `unsigned` cast below.
+    if (!dimIndex || *dimIndex < 0)
+      return failure();
+    auto linalgOp = dimValue.getDefiningOp<LinalgOp>();
+    if (!linalgOp)
+      return failure();
+
+    unsigned resultIndex = dimValue.cast<OpResult>().getResultNumber();
+    Optional<Value> operandDimValue = linalgOp.inferResultDimFromInputShapes(
+        rewriter, dimOp.getLoc(), resultIndex,
+        static_cast<unsigned>(*dimIndex));
+    if (!operandDimValue) {
+      // It's always possible to fall back to the corresponding `outs` operand.
+      operandDimValue = rewriter.create<DimOp>(
+          dimOp.getLoc(), linalgOp.getOutput(resultIndex), *dimIndex);
+    }
+    rewriter.replaceOp(dimOp, *operandDimValue);
+    return success();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -2166,26 +2287,6 @@
     return success();
   }
 };
-
-/// Canonicalize a `linalgOp` -> `dim` pattern by replacing the `dim` arg
-/// with the corresponding output tensor argument of the linalg op.
-struct ReplaceDimOfLinalgResult : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    Value dimOpArg = dimOp.memrefOrTensor();
-    auto linalgOp = dimOpArg.getDefiningOp<LinalgOp>();
-    if (!linalgOp)
-      return failure();
-
-    auto results = linalgOp.getOperation()->getResults();
-    int64_t id = std::distance(results.begin(), llvm::find(results, dimOpArg));
-    auto outputTensors = linalgOp.getOutputTensors();
-    rewriter.replaceOpWithNewOp<DimOp>(dimOp, outputTensors[id], dimOp.index());
-    return success();
-  }
-};
 } // namespace
 
 #define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
@@ -2193,7 +2294,7 @@
                                         MLIRContext *context) {                \
     results.insert<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp>();  \
-    results.insert<ReplaceDimOfLinalgResult>(context);                         \
+    results.insert<ReplaceDimOfLinalgOpResult>(context);                       \
   }                                                                            \
                                                                                \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -58,9 +58,6 @@
 //===----------------------------------------------------------------------===//
 
 void mlir::linalg::LinalgDialect::initialize() {
-  getContext()->getOrLoadDialect("std");
-  getContext()->getOrLoadDialect("tensor");
-
   addTypes<RangeType>();
   addOperations<
 #define GET_OP_LIST
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -390,10 +390,147 @@
 
 // -----
 
+func @init_tensor_dynamic_dim2(%arg0 : index, %arg1 : index) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32>
+  %1 = dim %0, %c0 : tensor<?x?xf32>
+  %2 = dim %0, %c1 : tensor<?x?xf32>
+  return %1, %2 : index, index
+}
+//      CHECK: func @init_tensor_dynamic_dim2
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//      CHECK:   return %[[ARG0]], %[[ARG1]]
+
+// -----
+
+func @remove_dim_result_uses
+  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+   %arg2 : tensor<?x?xf32>) -> (index) {
+  %c0 = constant 0 : index
+  %0 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                      affine_map<(d0, d1, d2) -> (d2, d1)>,
+                      affine_map<(d0, d1, d2) -> (d0 + d1, d1)>],
+     iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%arg2 : tensor<?x?xf32>) {
+    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+      %1 = mulf %arg3, %arg4 : f32
+      %2 = addf %1, %arg5 : f32
+      linalg.yield %2 : f32
+    } -> tensor<?x?xf32>
+  %3 = dim %0, %c0 : tensor<?x?xf32>
+  return %3 : index
+}
+//       CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+//       CHECK: func @remove_dim_result_uses
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[T1:.+]] = dim %[[ARG1]], %[[C1]]
+//       CHECK:   %[[T2:.+]] = affine.apply #[[MAP]]()[%[[T0]], %[[T1]]]
+//       CHECK:   return %[[T2]]
+
+// -----
+
+func @remove_dim_result_uses_outs
+  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = dim %arg0, %c0 : tensor<?xf32>
+  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32) :
+      linalg.yield %arg2 : f32
+    } -> tensor<?x?xf32>
+  %2 = dim %1, %c1 : tensor<?x?xf32>
+  return %2 : index
+}
+//       CHECK: func @remove_dim_result_uses_outs
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//       CHECK:   return %[[ARG1]]
+
+// -----
+
+func @remove_dim_result_uses_sequence
+  (%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+   %arg2 : tensor<?x?xf32>) -> (index, index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = dim %0, %c0 : tensor<?x?xf32>
+  %2 = dim %0, %c1 : tensor<?x?xf32>
+  %3 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0)>,
+                      affine_map<(d0, d1, d2) -> (d0, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d2)>],
+     iterator_types = ["parallel", "reduction", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%0 : tensor<?x?xf32>) {
+    ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+      %4 = mulf %arg3, %arg4 : f32
+      %5 = addf %4, %arg5 : f32
+      linalg.yield %5 : f32
+    } -> tensor<?x?xf32>
+  %6 = dim %3, %c0 : tensor<?x?xf32>
+  %7 = dim %3, %c1 : tensor<?x?xf32>
+  return %1, %2, %6, %7 : index, index, index, index
+}
+// CHECK-LABEL: func @remove_dim_result_uses_sequence
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[T1:.+]] = dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[T2:.+]] = dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[T3:.+]] = dim %[[ARG1]], %[[C1]]
+//       CHECK:   return %[[T0]], %[[T1]], %[[T2]], %[[T3]]
+
+// -----
+
+func @remove_dim_result_uses_sequence2
+  (%arg0 : tensor<?xf32>, %arg1 : index) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = dim %arg0, %c0 : tensor<?xf32>
+  %0 = linalg.init_tensor [%d0, %arg1] : tensor<?x?xf32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<?xf32>) outs(%0 : tensor<?x?xf32>) {
+    ^bb0(%arg2: f32, %arg3 : f32):
+      linalg.yield %arg2 : f32
+    } -> tensor<?x?xf32>
+  %2 = dim %1, %c0 : tensor<?x?xf32>
+  %3 = dim %1, %c1 : tensor<?x?xf32>
+  return %2, %3 : index, index
+}
+//       CHECK: func @remove_dim_result_uses_sequence2
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[T0:.+]] = dim %[[ARG0]], %[[C0]]
+//       CHECK:   return %[[T0]], %[[ARG1]]
+
+// -----
+
 #map = affine_map<(d0) -> (d0)>
 
 func @init_tensor_dim_of_linalg_result(%arg_0 : tensor<?xf32>,
-    %arg_1: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
+    %arg_1: tensor<?xf32>) -> (index, index) {
   %0, %1 = linalg.generic {
     indexing_maps = [#map, #map, #map],
     iterator_types = ["parallel"]
@@ -405,16 +542,16 @@
 
   %c0 = constant 0 : index
   %num_elem_0 = dim %0, %c0 : tensor<?xf32>
-  %result_0 = linalg.init_tensor [%num_elem_0] : tensor<?xf32>
   %num_elem_1 = dim %1, %c0 : tensor<?xf32>
-  %result_1 = linalg.init_tensor [%num_elem_1] : tensor<?xf32>
-  return %result_0, %result_1 : tensor<?xf32>, tensor<?xf32>
+  return %num_elem_0, %num_elem_1 : index, index
 }
 
-// CHECK-LABEL: func @init_tensor_dim_of_linalg_result(
-//  CHECK-SAME:   [[ARG_0:%.*]]: tensor<?xf32>, [[ARG_1:%.*]]: tensor<?xf32>)
-//       CHECK:   dim [[ARG_0]]
-//       CHECK:   dim [[ARG_1]]
+//      CHECK: func @init_tensor_dim_of_linalg_result(
+// CHECK-SAME:   %[[ARG_0:[a-zA-Z0-9_]+]]: tensor<?xf32>
+// CHECK-SAME:   %[[ARG_1:[a-zA-Z0-9_]+]]: tensor<?xf32>)
+//      CHECK:   %[[R0:.+]] = dim %[[ARG_0]]
+//      CHECK:   %[[R1:.+]] = dim %[[ARG_0]]
+//      CHECK:   return %[[R0]], %[[R1]]
 
 // -----
 