[MLIR][Affine] Add canonicalization for affine.vector_load / affine.vector_store

* AffineOps.td: set `hasCanonicalizer = 1` on AffineVectorLoadOp and
  AffineVectorStoreOp.
* AffineOps.cpp: allow SimplifyAffineOp to be instantiated for the vector
  load/store ops, specialize replaceAffineOp for both (rebuilding the op with
  the simplified map while passing through the vector type / value to store),
  and register the pattern from each op's getCanonicalizationPatterns.
* SimplifyAffineOp::matchAndRewrite now also calls canonicalizeMapAndOperands
  after composeAffineMapAndOperands. This runs for every op that uses
  SimplifyAffineOp (load/store/apply/prefetch/min/max and the vector ops), not
  only the newly added ones; existing canonicalize tests should be re-checked.
* canonicalize.mlir: new test composing affine.apply into vector_load /
  vector_store, including dropping an unused apply operand.

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -967,6 +967,8 @@
       return result().getType().cast<VectorType>();
     }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
@@ -1029,6 +1031,8 @@
       return value().getType().cast<VectorType>();
     }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 #endif // AFFINE_OPS
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -853,15 +853,18 @@
 
   LogicalResult matchAndRewrite(AffineOpTy affineOp,
                                 PatternRewriter &rewriter) const override {
-    static_assert(llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
-                                  AffineStoreOp, AffineApplyOp, AffineMinOp,
-                                  AffineMaxOp>::value,
-                  "affine load/store/apply/prefetch/min/max op expected");
+    static_assert(
+        llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
+                        AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
+                        AffineVectorStoreOp, AffineVectorLoadOp>::value,
+        "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
+        "expected");
     auto map = affineOp.getAffineMap();
     AffineMap oldMap = map;
     auto oldOperands = affineOp.getMapOperands();
     SmallVector<Value, 8> resultOperands(oldOperands);
     composeAffineMapAndOperands(&map, &resultOperands);
+    canonicalizeMapAndOperands(&map, &resultOperands);
     if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                     resultOperands.begin()))
       return failure();
@@ -895,6 +898,22 @@
   rewriter.replaceOpWithNewOp<AffineStoreOp>(
       store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
 }
+template <>
+void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
+    PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
+    ArrayRef<Value> mapOperands) const {
+  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
+      vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
+      mapOperands);
+}
+template <>
+void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
+    PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
+    ArrayRef<Value> mapOperands) const {
+  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
+      vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
+      mapOperands);
+}
 
 // Generic version for ops that don't have extra operands.
 template <typename AffineOpTy>
@@ -3267,6 +3286,11 @@
   build(builder, result, resultType, memref, map, indices);
 }
 
+void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                     MLIRContext *context) {
+  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
+}
+
 static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
                                            OperationState &result) {
   auto &builder = parser.getBuilder();
@@ -3353,6 +3377,11 @@
       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
   build(builder, result, valueToStore, memref, map, indices);
 }
+
+void AffineVectorStoreOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
+}
 
 static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser,
                                             OperationState &result) {
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -901,3 +901,27 @@
   }
   return
 }
+
+// -----
+
+// Compose maps into affine.vector_load / affine.vector_store
+
+// CHECK-LABEL: func @compose_into_affine_vector_load_vector_store
+// CHECK: affine.for %[[IV:.*]] = 0 to 1024
+// CHECK-NEXT: affine.vector_load %{{.*}}[%[[IV]] + 1]
+// CHECK-NEXT: affine.vector_store %{{.*}}, %{{.*}}[%[[IV]] + 1]
+// CHECK-NEXT: affine.vector_load %{{.*}}[%[[IV]]]
+func @compose_into_affine_vector_load_vector_store(%A : memref<1024xf32>, %u : index) {
+  affine.for %i = 0 to 1024 {
+    // Make sure the unused operand (%u below) gets dropped as well.
+    %idx = affine.apply affine_map<(d0, d1) -> (d0 + 1)> (%i, %u)
+    %0 = affine.vector_load %A[%idx] : memref<1024xf32>, vector<8xf32>
+    affine.vector_store %0, %A[%idx] : memref<1024xf32>, vector<8xf32>
+
+    // Map remains the same, but operand changes on composition.
+    %copy = affine.apply affine_map<(d0) -> (d0)> (%i)
+    %1 = affine.vector_load %A[%copy] : memref<1024xf32>, vector<8xf32>
+    "prevent.dce"(%1) : (vector<8xf32>) -> ()
+  }
+  return
+}