diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1153,8 +1153,8 @@ /// Returns the value that expresses the shape of the output in terms of /// shape of the input operands where possible - Optional inferResultDimFromInputShapes - (OpBuilder &b, Location loc, unsigned resultIdx, unsigned im); + Value reifyReturnTypeShapeForResult(OpBuilder &b, unsigned resultIdx, + int64_t dim); //========================================================================// // Helper functions to mutate the `operand_segment_sizes` attribute. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -22,6 +22,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/LLVM.h" @@ -107,13 +108,6 @@ void getDimsOfType(Operation *op, StringRef iteratorTypeName, SmallVectorImpl &res); -/// For reshape operation, compute the shape of the output based on the result -/// type and shape of the input. -SmallVector -getReshapeOutputShapeFromInputShape(OpBuilder &b, Location loc, Value src, - ArrayRef dstStaticShape, - ArrayRef reassociation); - namespace detail { LogicalResult verifyStructuredOpInterface(Operation *op); } // namespace detail diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -15,6 +15,7 @@ include "mlir/Dialect/Linalg/IR/LinalgBase.td" include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -33,7 +34,10 @@ let parser = [{ return ::parse$cppClass(parser, result); }]; } -def Linalg_InitTensorOp : Linalg_Op<"init_tensor", [NoSideEffect]> { +def Linalg_InitTensorOp : Linalg_Op<"init_tensor", + [NoSideEffect, + DeclareOpInterfaceMethods]> { let summary = "operation to define a tensor of particular value"; let description = [{ @@ -126,7 +130,10 @@ } def Linalg_PadTensorOp : Linalg_Op<"pad_tensor", - [AttrSizedOperandSegments, NoSideEffect]> { + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + NoSideEffect]> { let summary = "tensor pad operation"; let description = [{ `linalg.pad_tensor` is an operation that pads the `source` tensor @@ -348,11 +355,6 @@ a.cast().getValue().getResults()); })); } - SmallVector getOutputShape(OpBuilder &b, Location loc) { - return getReshapeOutputShapeFromInputShape( - b, loc, src(), getResultType().getShape(), - getReassociationMaps()); - } }]; let assemblyFormat = [{ $src $reassociation attr-dict `:` type($src) `into` type(results) @@ -417,7 +419,10 @@ let hasCanonicalizer = 1; } -def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<"tensor_reshape">, +def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp< + "tensor_reshape", + [DeclareOpInterfaceMethods]>, Arguments<(ins AnyTensor:$src, AffineMapArrayAttr:$reassociation)>, Results<(outs AnyTensor:$result)> { diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -17,6 +17,7 @@ include "mlir/Dialect/Linalg/IR/LinalgBase.td" include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" // Base Tablegen class for Linalg ops. @@ -25,7 +26,7 @@ // depending on the specific Linalg op. class LinalgStructuredBase_Op props> : Op { + LinalgStructuredInterface, InferShapedTypeOpInterface])> { code structuredOpsBaseDecls = [{ // Return the number of induction variables in the basic block. This should // always be 0 for index-free linalg ops. For IndexedGeneric, this must be @@ -33,6 +34,12 @@ unsigned getNumPayloadInductionVariables() { return isa(this->getOperation()) ? getNumLoops() : 0; } + + Value reifyReturnTypeShapeForResult(OpBuilder &b, unsigned resultIdx, + int64_t dim) { + return cast(getOperation()). + reifyReturnTypeShapeForResult(b, resultIdx, dim); + } }]; } diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -16,6 +16,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Support/LLVM.h" diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -98,20 +98,80 @@ "::mlir::DictionaryAttr":$attributes, "::mlir::RegionRange":$regions, "::mlir::SmallVectorImpl<::mlir::ShapedTypeComponents>&": - $inferredReturnShapes) + $inferredReturnShapes), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ return failure(); }] >, InterfaceMethod< /*desc=*/[{Reify the shape computation for the operation. Insert operations using the given OpBuilder that computes the result - shape. + shape along a particular dimension of result at `resultIndex`. + }], + /*retTy=*/"::mlir::Value", + /*methodName=*/"reifyReturnTypeShapeForResult", + /*args=*/(ins "::mlir::OpBuilder&":$builder, + "unsigned":$resultIndex, + "int64_t":$dim), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + if (resultIndex == 0) return $_op.reifyReturnTypeShape(builder, dim); + return nullptr; }] + >, + InterfaceMethod< + /*desc=*/[{Reify the shape computation for the operation. + + Insert operations using the given OpBuilder that computes the + result shape along a particular dimension. Valid only when the + operation has a single result value. + }], + /*retTy=*/"::mlir::Value", + /*methodName=*/"reifyReturnTypeShape", + /*args=*/(ins "::mlir::OpBuilder&":$builder, + "int64_t":$dim), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ return nullptr; }] + >, + InterfaceMethod< + /*desc=*/[{Reify the shape computation for the operation. + + Insert operations using the given OpBuilder that computes the result + shape for a result at `resultIndex`. + }], + /*retTy=*/"::mlir::LogicalResult", + /*methodName=*/"reifyReturnTypeShapesForResult", + /*args=*/(ins "::mlir::OpBuilder&":$builder, + "unsigned":$resultIndex, + "::mlir::SmallVectorImpl<::mlir::Value>&":$reifiedReturnShapes), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + if (resultIndex == 0) + return $_op.reifyReturnTypeShapes(builder, reifiedReturnShapes); + return ::mlir::failure(); + }] + >, + InterfaceMethod< + /*desc=*/[{Reify the shape computation for the operation. + + Insert operations using the given OpBuilder that computes the result + shape. Valid only when the operation has a single result value. }], /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"reifyReturnTypeShapes", /*args=*/(ins "::mlir::OpBuilder&":$builder, "::mlir::SmallVectorImpl<::mlir::Value>&":$reifiedReturnShapes), /*methodBody=*/[{}], - /*defaultImplementation=*/[{ return ::mlir::failure(); }] + /*defaultImplementation=*/[{ + ShapedType type = + $_op.getOperation()->getResult(0).getType().template dyn_cast(); + if (!type) return failure(); + for (auto dim : llvm::seq(0, type.getRank())) { + Value shape = $_op.reifyReturnTypeShape(builder, dim); + if (!shape) return failure(); + reifiedReturnShapes.push_back(shape); + } + return success(); + }] >, ]; } diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -14,6 +14,7 @@ LINK_LIBS PUBLIC MLIRAffine MLIRDialectUtils + MLIRInferTypeOpInterface MLIRIR MLIRParser MLIRSideEffectInterfaces diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -233,10 +233,8 @@ llvm::SmallSet positions; }; -Optional LinalgOp::inferResultDimFromInputShapes(OpBuilder &b, - Location loc, - unsigned resultIdx, - unsigned dim) { +Value LinalgOp::reifyReturnTypeShapeForResult(OpBuilder &b, unsigned resultIdx, + int64_t dim) { // An example that helps understand the logic below. // Consider the following expression O(i+j, j) += A(i,k) * B(k, j) // We want to express the shape of dim 0 of O in terms of shape of the inputs. @@ -280,8 +278,9 @@ llvm::for_each(llvm::seq(outputDimPosStart, outputDimPosEnd), [&outputDims](unsigned dim) { outputDims.insert(dim); }); HasAffineDimExprVisitor checkDimExpr(outputDims); + Location loc = getOperation()->getLoc(); if (checkDimExpr.visit(operandShapesToResultDimMap.getResult(0))) - return llvm::None; + return b.createOrFold(loc, getOutput(resultIdx), dim); return applyMapToValues(b, loc, operandShapesToResultDimMap, createFlatListOfOperandDims(b, loc))[0]; } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser.h" #include "llvm/ADT/DenseMap.h" @@ -770,33 +771,6 @@ return success(); } }; - -/// Canonicalize a `linalg.init_tensor` -> `dim` pattern by replacing the `dim` -/// with -/// - A constant value if the size is static along the dimension. -/// - The dynamic value that defines the size of the result of -/// `linalg.init_tensor` op. -struct ReplaceDimOfInitTensorOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DimOp dimOp, - PatternRewriter &rewriter) const override { - auto initTensorOp = dimOp.memrefOrTensor().getDefiningOp(); - if (!initTensorOp) - return failure(); - auto dimIndex = dimOp.index().getDefiningOp(); - if (!dimIndex) - return failure(); - int64_t index = dimIndex.getValue(); - if (!initTensorOp.isDynamicSize(index)) { - rewriter.replaceOpWithNewOp( - dimOp, initTensorOp.getStaticSize(index)); - } else { - rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(index)); - } - return success(); - } -}; } // namespace namespace { @@ -830,8 +804,10 @@ if (!reshapeOp.src().getDefiningOp()) return failure(); Location loc = reshapeOp.getLoc(); - SmallVector resultShapeValues = - reshapeOp.getOutputShape(rewriter, loc); + SmallVector resultShapeValues; + if (failed(reshapeOp.reifyReturnTypeShapesForResult(rewriter, 0, + resultShapeValues))) + return failure(); Value initTensor = rewriter.create( loc, resultShapeValues, reshapeOp.getResultType().getElementType()); rewriter.replaceOpWithNewOp( @@ -843,9 +819,15 @@ void InitTensorOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results - .insert(context); + results.insert( + context); +} + +Value InitTensorOp::reifyReturnTypeShape(OpBuilder &builder, int64_t dim) { + if (isDynamicSize(dim)) + return getDynamicSize(dim); + return builder.create(getLoc(), getStaticSize(dim)); } //===----------------------------------------------------------------------===// @@ -997,6 +979,22 @@ builder); } +Value PadTensorOp::reifyReturnTypeShape(OpBuilder &b, int64_t dim) { + Location loc = getLoc(); + auto getAsValue = [&](OpFoldResult valueOrAttr) -> Value { + if (Attribute attr = valueOrAttr.dyn_cast()) + return b.create(loc, attr.cast().getInt()); + return valueOrAttr.get(); + }; + auto lowPad = getAsValue(getMixedLowPad()[dim]); + auto highPad = getAsValue(getMixedHighPad()[dim]); + Value sourceDim = b.create(loc, source(), dim); + AffineExpr expr = b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0) + + b.getAffineSymbolExpr(1); + return applyMapToValues(b, loc, AffineMap::get(1, 2, expr), + {sourceDim, lowPad, highPad})[0]; +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// @@ -1384,9 +1382,10 @@ })); } -SmallVector mlir::linalg::getReshapeOutputShapeFromInputShape( - OpBuilder &builder, Location loc, Value src, - ArrayRef dstStaticShape, ArrayRef reassocation) { +static SmallVector +getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src, + ArrayRef dstStaticShape, + ArrayRef reassocation) { return dstStaticShape.size() > static_cast(src.getType().cast().getRank()) ? getExpandedOutputShapeFromInputShape( @@ -1635,35 +1634,26 @@ return success(); } }; - -/// Canonicalize dim ops that use the output shape with dim of the input. -struct ReplaceDimOfReshapeOpResult : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(DimOp dimOp, - PatternRewriter &rewriter) const override { - Value dimValue = dimOp.memrefOrTensor(); - Optional dimIndex = dimOp.getConstantIndex(); - if (!dimIndex) - return failure(); - - auto reshapeOp = dimValue.getDefiningOp(); - if (!reshapeOp) - return failure(); - - rewriter.replaceOp(dimOp, - getReshapeOutputDimFromInputShape( - rewriter, dimOp.getLoc(), *dimIndex, reshapeOp.src(), - reshapeOp.getResultType().getShape(), - reshapeOp.getReassociationMaps())); - return success(); - } -}; } // namespace void TensorReshapeOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - results.insert, FoldReshapeWithConstant, - ReplaceDimOfReshapeOpResult>(context); + results.insert, FoldReshapeWithConstant>( + context); +} + +Value TensorReshapeOp::reifyReturnTypeShape(OpBuilder &b, int64_t dim) { + return getReshapeOutputDimFromInputShape(b, getLoc(), dim, src(), + getResultType().getShape(), + getReassociationMaps()); +} + +LogicalResult TensorReshapeOp::reifyReturnTypeShapes( + OpBuilder &b, SmallVectorImpl &reifiedReturnShapes) { + auto resultShape = getReshapeOutputShapeFromInputShape( + b, getLoc(), src(), getResultType().getShape(), getReassociationMaps()); + reifiedReturnShapes.append(resultShape.begin(), resultShape.end()); + return success(); } //===----------------------------------------------------------------------===// @@ -2432,49 +2422,6 @@ return success(); } }; - -/// Replaces std.dim operations that use the result of a LinalgOp (on tensors) -/// with std.dim operations that use one of the arguments. For example, -/// -/// %0 = linalg.matmul ins(%arg0, %arg1, ...) -/// %1 = dim %0, %c0 -/// -/// with -/// -/// %1 = dim %arg0, %c0 -/// -/// where possible. With this the result of the `linalg.matmul` is not used in -/// dim operations. If the value produced is replaced with another value (say by -/// tiling `linalg.matmul`) will make the `linalg.matmul` truly dead instead of -/// used in a dim op that would prevent the DCE of this op. -struct ReplaceDimOfLinalgOpResult : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DimOp dimOp, - PatternRewriter &rewriter) const override { - Value dimValue = dimOp.memrefOrTensor(); - Optional dimIndex = dimOp.getConstantIndex(); - if (!dimIndex) - return failure(); - auto linalgOp = dimValue.getDefiningOp(); - if (!linalgOp) - return failure(); - - unsigned resultIndex = dimValue.cast().getResultNumber(); - Optional operandDimValue = linalgOp.inferResultDimFromInputShapes( - rewriter, dimOp.getLoc(), resultIndex, - static_cast(*dimIndex)); - if (!operandDimValue) { - // Its always possible to replace using the corresponding `outs` - // parameter. - operandDimValue = rewriter.create( - dimOp.getLoc(), linalgOp.getOutput(resultIndex), *dimIndex); - } - rewriter.replaceOp(dimOp, *operandDimValue); - return success(); - } -}; - } // namespace namespace { @@ -2628,7 +2575,6 @@ MLIRContext *context) { \ results.insert(); \ - results.insert(context); \ } \ \ LogicalResult XXX::fold(ArrayRef, \ diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt --- a/mlir/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt @@ -15,6 +15,7 @@ MLIRCastInterfaces MLIRControlFlowInterfaces MLIREDSC + MLIRInferTypeOpInterface MLIRIR MLIRSideEffectInterfaces MLIRTensor diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/STLExtras.h" @@ -1556,12 +1557,41 @@ return success(); } }; + +/// Fold dim of an operation that implements the InferShapedTypeOpInterface +struct DimOfShapedTypeOpInterface : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DimOp dimOp, + PatternRewriter &rewriter) const override { + OpResult dimValue = dimOp.memrefOrTensor().dyn_cast(); + if (!dimValue) + return failure(); + auto shapedTypeOp = + dyn_cast(dimValue.getOwner()); + if (!shapedTypeOp) + return failure(); + + Optional dimIndex = dimOp.getConstantIndex(); + if (!dimIndex) + return failure(); + + Value dimReplacement = shapedTypeOp.reifyReturnTypeShapeForResult( + rewriter, dimValue.getResultNumber(), *dimIndex); + if (!dimReplacement) + return failure(); + + rewriter.replaceOp(dimOp, dimReplacement); + return success(); + } +}; } // end anonymous namespace. void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert, - DimOfCastOp>(context); + DimOfCastOp, DimOfShapedTypeOpInterface>( + context); } // --------------------------------------------------------------------------- diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -13,8 +13,6 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/IR/BuiltinTypes.h" - using namespace mlir; namespace mlir { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -499,13 +499,14 @@ } def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [ - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods]> { let arguments = (ins AnyTensor, AnyTensor); let results = (outs AnyTensor); } def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_type_if", - InferTensorTypeWithReify.traits> { + InferTensorType<["inferReturnTypeComponents", "reifyReturnTypeShapes"]>.traits> { let arguments = (ins AnyTensor, AnyTensor); let results = (outs AnyTensor); }