Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Show First 20 Lines • Show All 1,147 Lines • ▼ Show 20 Lines | llvm::interleave( | ||||||||||||||||||
types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); }, | types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); }, | ||||||||||||||||||
[&]() { ss << "_"; }); | [&]() { ss << "_"; }); | ||||||||||||||||||
return ss.str(); | return ss.str(); | ||||||||||||||||||
} | } | ||||||||||||||||||
// TODO: Consider making all this boilerplate easy to autogenerate | // TODO: Consider making all this boilerplate easy to autogenerate | ||||||||||||||||||
// with Tablegen. This seems a desirable property in the context of OpInterfaces | // with Tablegen. This seems a desirable property in the context of OpInterfaces | ||||||||||||||||||
// where a Linalg "named" op **isa** LinalgOp. | // where a Linalg "named" op **isa** LinalgOp. | ||||||||||||||||||
LogicalResult ConvOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult PoolingMaxOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult PoolingMinOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult PoolingSumOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult CopyOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult FillOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult GenericOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult IndexedGenericOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) { | OpFoldResult ReshapeOp::fold(ArrayRef<Attribute>) { | ||||||||||||||||||
if (succeeded(foldMemRefCast(*this))) | if (succeeded(foldMemRefCast(*this))) | ||||||||||||||||||
return getResult(); | return getResult(); | ||||||||||||||||||
return foldReshapeOp(*this); | return foldReshapeOp(*this); | ||||||||||||||||||
} | } | ||||||||||||||||||
OpFoldResult SliceOp::fold(ArrayRef<Attribute>) { | OpFoldResult SliceOp::fold(ArrayRef<Attribute>) { | ||||||||||||||||||
if (succeeded(foldMemRefCast(*this))) | if (succeeded(foldMemRefCast(*this))) | ||||||||||||||||||
return getResult(); | return getResult(); | ||||||||||||||||||
▲ Show 20 Lines • Show All 98 Lines • ▼ Show 20 Lines | return parser.resolveOperands(operandsInfo, operandTypes, | ||||||||||||||||||
parser.getCurrentLocation(), result.operands); | parser.getCurrentLocation(), result.operands); | ||||||||||||||||||
} | } | ||||||||||||||||||
template <typename NamedStructuredOpType> | template <typename NamedStructuredOpType> | ||||||||||||||||||
static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) { | static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) { | ||||||||||||||||||
return verifyGenericOp<NamedStructuredOpType>(op); | return verifyGenericOp<NamedStructuredOpType>(op); | ||||||||||||||||||
} | } | ||||||||||||||||||
struct EraseDeadLinalgOp : public RewritePattern { | |||||||||||||||||||
EraseDeadLinalgOp(PatternBenefit benefit = 1) | |||||||||||||||||||
: RewritePattern(benefit, MatchAnyOpTypeTag()) {} | |||||||||||||||||||
LogicalResult matchAndRewrite(Operation *op, | |||||||||||||||||||
PatternRewriter &rewriter) const override { | |||||||||||||||||||
auto linalgOp = dyn_cast<LinalgOp>(op); | |||||||||||||||||||
rriddle: I wonder if it would be useful to define an `InterfaceRewritePattern`, similarly to… | |||||||||||||||||||
Definitely, I have more and more uses for this. nicolasvasilache: Definitely, I have more and more uses for this. | |||||||||||||||||||
if (!linalgOp) | |||||||||||||||||||
return failure(); | |||||||||||||||||||
for (Value v : linalgOp.getInputsAndOutputBuffers()) { | |||||||||||||||||||
// Linalg "inputs" may be either tensor or memref type. | |||||||||||||||||||
bkramer: | |||||||||||||||||||
Added some comments to explain why dyn_cast. nicolasvasilache: Added some comments to explain why `dyn_cast`. | |||||||||||||||||||
You aren't checking the result of this dyn_cast. rriddle: You aren't checking the result of this dyn_cast. | |||||||||||||||||||
thanks for catching! nicolasvasilache: thanks for catching!
Also added a test for the tensor case. | |||||||||||||||||||
// tensor<0xelt_type> is a convention that may not always mean | |||||||||||||||||||
Can you also use is_contained here? rriddle: Can you also use is_contained here? | |||||||||||||||||||
// "0 iterations". Only erase in cases we see memref<...x0x...>. | |||||||||||||||||||
auto mt = v.getType().dyn_cast<MemRefType>(); | |||||||||||||||||||
if (!mt) | |||||||||||||||||||
continue; | |||||||||||||||||||
if (llvm::is_contained(mt.getShape(), 0)) { | |||||||||||||||||||
bkramer: | |||||||||||||||||||
rewriter.eraseOp(linalgOp); | |||||||||||||||||||
return success(); | |||||||||||||||||||
} | |||||||||||||||||||
} | |||||||||||||||||||
return failure(); | |||||||||||||||||||
} | |||||||||||||||||||
}; | |||||||||||||||||||
#define CANONICALIZERS_AND_FOLDERS(XXX) \ | |||||||||||||||||||
void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \ | |||||||||||||||||||
MLIRContext *context) { \ | |||||||||||||||||||
results.insert<EraseDeadLinalgOp>(); \ | |||||||||||||||||||
} \ | |||||||||||||||||||
\ | |||||||||||||||||||
LogicalResult XXX::fold(ArrayRef<Attribute>, \ | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { \ | |||||||||||||||||||
return foldMemRefCast(*this); \ | |||||||||||||||||||
} | |||||||||||||||||||
CANONICALIZERS_AND_FOLDERS(ConvOp); | |||||||||||||||||||
CANONICALIZERS_AND_FOLDERS(PoolingMaxOp); | |||||||||||||||||||
CANONICALIZERS_AND_FOLDERS(PoolingMinOp); | |||||||||||||||||||
CANONICALIZERS_AND_FOLDERS(PoolingSumOp); | |||||||||||||||||||
CANONICALIZERS_AND_FOLDERS(CopyOp); | |||||||||||||||||||
CANONICALIZERS_AND_FOLDERS(FillOp); | |||||||||||||||||||
CANONICALIZERS_AND_FOLDERS(GenericOp); | |||||||||||||||||||
CANONICALIZERS_AND_FOLDERS(IndexedGenericOp); | |||||||||||||||||||
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc" | #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc" | ||||||||||||||||||
// TODO: Determine whether we can generate the folders and verifiers. | // TODO: Determine whether we can generate the folders and verifiers. | ||||||||||||||||||
LogicalResult BatchMatmulOp::fold(ArrayRef<Attribute>, | CANONICALIZERS_AND_FOLDERS(BatchMatmulOp); | ||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | CANONICALIZERS_AND_FOLDERS(DotOp); | ||||||||||||||||||
return foldMemRefCast(*this); | CANONICALIZERS_AND_FOLDERS(MatmulOp); | ||||||||||||||||||
} | CANONICALIZERS_AND_FOLDERS(MatvecOp); | ||||||||||||||||||
LogicalResult DotOp::fold(ArrayRef<Attribute>, | CANONICALIZERS_AND_FOLDERS(ConvWOp); | ||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | CANONICALIZERS_AND_FOLDERS(ConvNWCOp); | ||||||||||||||||||
return foldMemRefCast(*this); | CANONICALIZERS_AND_FOLDERS(ConvNCWOp); | ||||||||||||||||||
} | CANONICALIZERS_AND_FOLDERS(ConvHWOp); | ||||||||||||||||||
LogicalResult MatmulOp::fold(ArrayRef<Attribute>, | CANONICALIZERS_AND_FOLDERS(ConvNHWCOp); | ||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | CANONICALIZERS_AND_FOLDERS(ConvNCHWOp); | ||||||||||||||||||
return foldMemRefCast(*this); | CANONICALIZERS_AND_FOLDERS(ConvDHWOp); | ||||||||||||||||||
} | CANONICALIZERS_AND_FOLDERS(ConvNDHWCOp); | ||||||||||||||||||
LogicalResult MatvecOp::fold(ArrayRef<Attribute>, | CANONICALIZERS_AND_FOLDERS(ConvNCDHWOp); | ||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult ConvWOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult ConvNWCOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult ConvNCWOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult ConvHWOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult ConvNHWCOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult ConvNCHWOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult ConvDHWOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult ConvNDHWCOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} | |||||||||||||||||||
LogicalResult ConvNCDHWOp::fold(ArrayRef<Attribute>, | |||||||||||||||||||
SmallVectorImpl<OpFoldResult> &) { | |||||||||||||||||||
return foldMemRefCast(*this); | |||||||||||||||||||
} |
I wonder if it would be useful to define an InterfaceRewritePattern, similarly to OpRewritePattern. (I would still enforce that the user provide MatchAnyOpTypeTag)