diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
--- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
+++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td
@@ -238,13 +238,26 @@
     Op<Affine_Dialect, mnemonic, traits> {
   let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$operands);
   let results = (outs Index);
+
+  // Convenience builder; the result type is always `index`.
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState &result, AffineMap affineMap, "
+              "ValueRange mapOperands",
+    [{
+      build(builder, result, builder->getIndexType(), affineMap, mapOperands);
+    }]>
+  ];
+
   let extraClassDeclaration = [{
     static StringRef getMapAttrName() { return "map"; }
+    AffineMap getAffineMap() { return map(); }
+    ValueRange getMapOperands() { return operands(); }
   }];
   let verifier = [{ return ::verifyAffineMinMaxOp(*this); }];
   let printer = [{ return ::printAffineMinMaxOp(p, *this); }];
   let parser = [{ return ::parseAffineMinMaxOp<$cppClass>(parser, result); }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def AffineMinOp : AffineMinMaxOpBase<"min", [NoSideEffect]> {
diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
--- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp
+++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp
@@ -763,7 +763,9 @@
     static_assert(std::is_same<AffineOpTy, AffineLoadOp>::value ||
                       std::is_same<AffineOpTy, AffinePrefetchOp>::value ||
                       std::is_same<AffineOpTy, AffineStoreOp>::value ||
-                      std::is_same<AffineOpTy, AffineApplyOp>::value,
+                      std::is_same<AffineOpTy, AffineApplyOp>::value ||
+                      std::is_same<AffineOpTy, AffineMinOp>::value ||
+                      std::is_same<AffineOpTy, AffineMaxOp>::value,
                   "affine load/store/apply op expected");
     auto map = affineOp.getAffineMap();
     AffineMap oldMap = map;
@@ -804,11 +806,13 @@
   rewriter.replaceOpWithNewOp<AffineStoreOp>(
       store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
 }
-template <>
-void SimplifyAffineOp<AffineApplyOp>::replaceAffineOp(
-    PatternRewriter &rewriter, AffineApplyOp apply, AffineMap map,
+
+// Generic version for ops that don't have extra operands.
+template <typename AffineOpTy>
+void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
+    PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
     ArrayRef<Value> mapOperands) const {
-  rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, mapOperands);
+  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
 }
 } // end anonymous namespace.
 
@@ -2016,6 +2020,11 @@
   return results[minIndex];
 }
 
+void AffineMinOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<SimplifyAffineOp<AffineMinOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // AffineMaxOp
 //===----------------------------------------------------------------------===//
@@ -2046,6 +2055,11 @@
   return results[maxIndex];
 }
 
+void AffineMaxOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<SimplifyAffineOp<AffineMaxOp>>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // AffinePrefetchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/AffineOps/canonicalize.mlir b/mlir/test/AffineOps/canonicalize.mlir
--- a/mlir/test/AffineOps/canonicalize.mlir
+++ b/mlir/test/AffineOps/canonicalize.mlir
@@ -551,4 +551,38 @@
   // CHECK-NEXT: "op0"(%[[CST]]) : (index) -> ()
   // CHECK-NEXT: return
   return
-}
\ No newline at end of file
+}
+
+// -----
+
+// Canonicalization folds constant operands of affine.min into its map.
+// CHECK: #[[map:.*]] = affine_map<(d0, d1) -> (d0, d1 - 2)>
+
+func @affine_min(%arg0: index) {
+  affine.for %i = 0 to %arg0 {
+    affine.for %j = 0 to %arg0 {
+      %c2 = constant 2 : index
+      // CHECK: affine.min #[[map]]
+      %0 = affine.min affine_map<(d0,d1,d2)->(d0, d1 - d2)>(%i, %j, %c2)
+      "consumer"(%0) : (index) -> ()
+    }
+  }
+  return
+}
+
+// -----
+
+// Canonicalization folds constant operands of affine.max into its map.
+// CHECK: #[[map:.*]] = affine_map<(d0, d1) -> (d0, d1 - 2)>
+
+func @affine_max(%arg0: index) {
+  affine.for %i = 0 to %arg0 {
+    affine.for %j = 0 to %arg0 {
+      %c2 = constant 2 : index
+      // CHECK: affine.max #[[map]]
+      %0 = affine.max affine_map<(d0,d1,d2)->(d0, d1 - d2)>(%i, %j, %c2)
+      "consumer"(%0) : (index) -> ()
+    }
+  }
+  return
+}