diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -153,6 +153,7 @@
   let verifier = [{ return ::verify(*this); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }

 def FillOp : LinalgStructured_Op<"fill", [NInputs<0>, NOutputs<1>]> {
@@ -178,6 +179,7 @@
   let verifier = [{ return ::verify(*this); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }

 /// A base class for pooling operation such as conv. The arguments must contain
@@ -358,6 +360,7 @@
   let verifier = [{ return ::verify(*this); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }

 class SingleInputPoolingBase_Op<string mnemonic>
@@ -417,6 +420,7 @@
   let verifier = [{ return ::verify(*this); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }

 def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> {
@@ -659,6 +663,7 @@
   let verifier = [{ return ::verify(*this); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }

 /// GenericOp with Indexing (i.e. multi-for style in which the region is passed
@@ -796,6 +801,7 @@
   let verifier = [{ return ::verify(*this); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -818,6 +824,7 @@
   let printer = [{ return ::printNamedStructuredOp(p, *this); }];
   let verifier = [{ return ::verifyNamedStructuredOp(*this); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }

 // This file is auto-generated from a tc specification.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1153,38 +1153,6 @@
 // TODO: Consider making all this boilerplate easy to autogenerate
 // with Tablegen. This seems a desirable property in the context of OpInterfaces
 // where a Linalg "named" op **isa** LinalgOp.
-LogicalResult ConvOp::fold(ArrayRef<Attribute>,
-                           SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult PoolingMaxOp::fold(ArrayRef<Attribute>,
-                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult PoolingMinOp::fold(ArrayRef<Attribute>,
-                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult PoolingSumOp::fold(ArrayRef<Attribute>,
-                                 SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult CopyOp::fold(ArrayRef<Attribute>,
-                           SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult FillOp::fold(ArrayRef<Attribute>,
-                           SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult GenericOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>,
-                                     SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
 OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
@@ -1299,58 +1267,61 @@
   return verifyGenericOp(op);
 }

+struct EraseDeadLinalgOp : public RewritePattern {
+  EraseDeadLinalgOp(PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    if (!linalgOp)
+      return failure();
+    for (Value v : linalgOp.getInputsAndOutputBuffers()) {
+      auto mt = v.getType().dyn_cast<MemRefType>();
+      for (int64_t s : mt.getShape()) {
+        if (s == 0) {
+          rewriter.eraseOp(linalgOp);
+          return success();
+        }
+      }
+    }
+    return failure();
+  }
+};
+
+#define CANONICALIZERS_AND_FOLDERS(XXX)                                        \
+  void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results,    \
+                                        MLIRContext *context) {               \
+    results.insert<EraseDeadLinalgOp>();                                      \
+  }                                                                            \
+                                                                               \
+  LogicalResult XXX::fold(ArrayRef<Attribute>,                                \
+                          SmallVectorImpl<OpFoldResult> &) {                  \
+    return foldMemRefCast(*this);                                             \
+  }
+
+CANONICALIZERS_AND_FOLDERS(ConvOp);
+CANONICALIZERS_AND_FOLDERS(PoolingMaxOp);
+CANONICALIZERS_AND_FOLDERS(PoolingMinOp);
+CANONICALIZERS_AND_FOLDERS(PoolingSumOp);
+CANONICALIZERS_AND_FOLDERS(CopyOp);
+CANONICALIZERS_AND_FOLDERS(FillOp);
+CANONICALIZERS_AND_FOLDERS(GenericOp);
+CANONICALIZERS_AND_FOLDERS(IndexedGenericOp);
+
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"

 // TODO: Determine whether we can generate the folders and verifiers.
-LogicalResult BatchMatmulOp::fold(ArrayRef<Attribute>,
-                                  SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult DotOp::fold(ArrayRef<Attribute>,
-                          SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult MatmulOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult MatvecOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvWOp::fold(ArrayRef<Attribute>,
-                            SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNWCOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNCWOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvHWOp::fold(ArrayRef<Attribute>,
-                             SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNHWCOp::fold(ArrayRef<Attribute>,
-                               SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNCHWOp::fold(ArrayRef<Attribute>,
-                               SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvDHWOp::fold(ArrayRef<Attribute>,
-                              SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNDHWCOp::fold(ArrayRef<Attribute>,
-                                SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
-LogicalResult ConvNCDHWOp::fold(ArrayRef<Attribute>,
-                                SmallVectorImpl<OpFoldResult> &) {
-  return foldMemRefCast(*this);
-}
+CANONICALIZERS_AND_FOLDERS(BatchMatmulOp);
+CANONICALIZERS_AND_FOLDERS(DotOp);
+CANONICALIZERS_AND_FOLDERS(MatmulOp);
+CANONICALIZERS_AND_FOLDERS(MatvecOp);
+CANONICALIZERS_AND_FOLDERS(ConvWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCWOp);
+CANONICALIZERS_AND_FOLDERS(ConvHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNHWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvDHWOp);
+CANONICALIZERS_AND_FOLDERS(ConvNDHWCOp);
+CANONICALIZERS_AND_FOLDERS(ConvNCDHWOp);
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -732,19 +732,16 @@
 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                 ArrayRef<AffineExpr> exprs,
                                                 MLIRContext *context) {
+  // Size 0 corner case is useful for canonicalizations.
+  if (llvm::any_of(sizes, [](int64_t size) { return size == 0; }))
+    return getAffineConstantExpr(0, context);
+
+  auto maps = AffineMap::inferFromExprList(exprs);
+  assert(!maps.empty() && "Expected one non-empty map");
+  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
+
   AffineExpr expr;
   bool dynamicPoisonBit = false;
-  unsigned numDims = 0;
-  unsigned nSymbols = 0;
-  // Compute the number of symbols and dimensions of the passed exprs.
-  for (AffineExpr expr : exprs) {
-    expr.walk([&numDims, &nSymbols](AffineExpr d) {
-      if (AffineDimExpr dim = d.dyn_cast<AffineDimExpr>())
-        numDims = std::max(numDims, dim.getPosition() + 1);
-      else if (AffineSymbolExpr symbol = d.dyn_cast<AffineSymbolExpr>())
-        nSymbols = std::max(nSymbols, symbol.getPosition() + 1);
-    });
-  }
   int64_t runningSize = 1;
   for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
     int64_t size = std::get<1>(en);
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -172,3 +172,12 @@
 // CHECK-LABEL: @no_fold_memref_reshape
 // CHECK: linalg.reshape
 // CHECK: linalg.reshape
+
+// -----
+
+func @dce_zero_memref(%arg0 : memref<0xf32>) {
+  linalg.copy(%arg0, %arg0): memref<0xf32>, memref<0xf32>
+  return
+}
+// CHECK-LABEL: @dce_zero_memref
+// CHECK-NEXT: return
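Editorial note, not part of the patch: setting hasCanonicalizer = 1 in ODS declares a static getCanonicalizationPatterns hook on each op, and the CANONICALIZERS_AND_FOLDERS macro above supplies both that hook and the fold hook. Expanded by hand for ConvOp (illustrative only, taken directly from the macro body), one invocation becomes:

void ConvOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  // Dead-op erasure on zero-extent memrefs is the only pattern registered so far.
  results.insert<EraseDeadLinalgOp>();
}

LogicalResult ConvOp::fold(ArrayRef<Attribute>,
                           SmallVectorImpl<OpFoldResult> &) {
  // Same memref-cast folding that the removed per-op definitions performed.
  return foldMemRefCast(*this);
}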
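The -canonicalize pass collects these patterns automatically once hasCanonicalizer is set. For a custom pass, a minimal sketch of driving them directly is below; the function name runLinalgDeadOpCleanup is hypothetical, and the greedy driver name applyPatternsAndFoldGreedily is an assumption about the MLIR revision this patch targets (older revisions spell it applyPatternsGreedily):

static void runLinalgDeadOpCleanup(FuncOp func) {
  MLIRContext *ctx = func.getContext();
  OwningRewritePatternList patterns;
  // These hooks exist because of hasCanonicalizer = 1 and are defined by
  // CANONICALIZERS_AND_FOLDERS above; any of the listed Linalg ops would do.
  linalg::CopyOp::getCanonicalizationPatterns(patterns, ctx);
  linalg::FillOp::getCanonicalizationPatterns(patterns, ctx);
  applyPatternsAndFoldGreedily(func, patterns);
}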
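The StandardTypes.cpp hunk makes makeCanonicalStridedLayoutExpr short-circuit to the constant 0 whenever any size is 0, the layout-level counterpart of erasing Linalg ops on zero-extent memrefs. A small usage sketch under the signature shown above; the headers, the helper name zeroSizeLayoutExample, and the concrete sizes are illustrative assumptions:

#include <cassert>
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"

static void zeroSizeLayoutExample() {
  mlir::MLIRContext ctx;
  // Per-dimension layout exprs for a rank-3 shape that contains a zero size.
  llvm::SmallVector<mlir::AffineExpr, 3> exprs = {
      mlir::getAffineDimExpr(0, &ctx), mlir::getAffineDimExpr(1, &ctx),
      mlir::getAffineDimExpr(2, &ctx)};
  llvm::SmallVector<int64_t, 3> sizes = {4, 0, 2};
  mlir::AffineExpr layout =
      mlir::makeCanonicalStridedLayoutExpr(sizes, exprs, &ctx);
  // With any zero size, the canonical strided layout collapses to 0.
  assert(layout == mlir::getAffineConstantExpr(0, &ctx));
}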