Index: mlir/include/mlir/Dialect/Vector/VectorOps.td
===================================================================
--- mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1029,7 +1029,8 @@
       DeclareOpInterfaceMethods,
       DeclareOpInterfaceMethods
     ]>,
-    Arguments<(ins AnyMemRef:$memref, Variadic<Index>:$indices,
+    Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$memref,
+               Variadic<Index>:$indices,
                AffineMapAttr:$permutation_map, AnyType:$padding,
                OptionalAttr<BoolArrayAttr>:$masked)>,
     Results<(outs AnyVector:$vector)> {
Index: mlir/lib/Dialect/Vector/VectorTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -12,6 +12,7 @@
 
 #include <type_traits>
 
+#include "mlir/Dialect/Affine/EDSC/Builders.h"
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
@@ -2455,6 +2456,105 @@
   return ops;
 }
 
+/// Returns `base + id * chunkSize` as an index Value.
+///
+/// NOTE: assumes vector.extract_map / vector.insert_map split a 1-D vector
+/// into `multiplicity` contiguous chunks of `chunkSize` elements, with `id`
+/// selecting the chunk. The arithmetic goes through the EDSC affine
+/// operators, so an edsc::ScopedContext must be active at the call site.
+static Value getDistributedIndex(PatternRewriter &rewriter, Location loc,
+                                 Value base, Value id, int64_t chunkSize) {
+  using mlir::edsc::op::operator+;
+  using mlir::edsc::op::operator*;
+  Value stride = rewriter.create<ConstantIndexOp>(loc, chunkSize);
+  return base + id * stride;
+}
+
+/// Rewrites `vector.extract_map(vector.transfer_read)` into a narrower
+/// vector.transfer_read that only loads the chunk selected by the
+/// extract_map id.
+///
+/// Only 1-D chunks read through a minor-identity permutation map are handled:
+/// then the vector dimension is indexed by the innermost memref index, which
+/// is the index that gets offset. The new read carries no `masked` attribute.
+struct TransferReadExtractPattern
+    : public OpRewritePattern<vector::ExtractMapOp> {
+  TransferReadExtractPattern(MLIRContext *context)
+      : OpRewritePattern<vector::ExtractMapOp>(context) {}
+  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
+                                PatternRewriter &rewriter) const override {
+    auto read = extract.vector().getDefiningOp<vector::TransferReadOp>();
+    if (!read)
+      return failure();
+    auto chunkType = extract.getType().cast<VectorType>();
+    if (chunkType.getRank() != 1 ||
+        !read.permutation_map().isMinorIdentity())
+      return failure();
+    edsc::ScopedContext scope(rewriter, extract.getLoc());
+    SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
+    // Chunk `id` starts `id * chunkSize` elements past the original index.
+    indices.back() =
+        getDistributedIndex(rewriter, extract.getLoc(), indices.back(),
+                            extract.id(), chunkType.getDimSize(0));
+    Value newRead = rewriter.create<vector::TransferReadOp>(
+        read.getLoc(), chunkType, read.memref(), indices,
+        read.permutation_map(), read.padding(), ArrayAttr());
+    rewriter.replaceOp(extract, newRead);
+    return success();
+  }
+};
+
+/// Rewrites `vector.transfer_write(vector.insert_map)` into a narrower
+/// vector.transfer_write that only stores the chunk fed into the insert_map,
+/// then erases the original write.
+///
+/// Same restrictions as TransferReadExtractPattern (1-D chunk, minor-identity
+/// permutation map). The new write carries no `masked` attribute.
+struct TransferWriteInsertPattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  TransferWriteInsertPattern(MLIRContext *context)
+      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
+    if (!insert)
+      return failure();
+    auto chunkType = insert.vector().getType().cast<VectorType>();
+    if (chunkType.getRank() != 1 ||
+        !write.permutation_map().isMinorIdentity())
+      return failure();
+    edsc::ScopedContext scope(rewriter, write.getLoc());
+    SmallVector<Value, 4> indices(write.indices().begin(),
+                                  write.indices().end());
+    indices.back() =
+        getDistributedIndex(rewriter, write.getLoc(), indices.back(),
+                            insert.id(), chunkType.getDimSize(0));
+    rewriter.create<vector::TransferWriteOp>(
+        write.getLoc(), insert.vector(), write.memref(), indices,
+        write.permutation_map(), ArrayAttr());
+    rewriter.eraseOp(write);
+    return success();
+  }
+};
+
+/// Folds an extract_map whose source is an insert_map with the same id and
+/// multiplicity to the vector that was fed into that insert_map.
+struct ExtractMapFolderOp : public OpRewritePattern<vector::ExtractMapOp> {
+  ExtractMapFolderOp(MLIRContext *context)
+      : OpRewritePattern<vector::ExtractMapOp>(context) {}
+  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
+                                PatternRewriter &rewriter) const override {
+    auto insert = extract.vector().getDefiningOp<vector::InsertMapOp>();
+    if (!insert)
+      return failure();
+    if (extract.multiplicity() != insert.multiplicity() ||
+        extract.id() != insert.id())
+      return failure();
+    rewriter.replaceOp(extract, insert.vector());
+    return success();
+  }
+};
+
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
@@ -2464,7 +2564,10 @@
       ShapeCastOpFolder,
       SplitTransferReadOp,
       SplitTransferWriteOp,
-      TupleGetFolderOp>(context);
+      TupleGetFolderOp,
+      TransferReadExtractPattern,
+      TransferWriteInsertPattern,
+      ExtractMapFolderOp>(context);
   // clang-format on
 }
 
Index: mlir/test/Dialect/Vector/vector-distribution.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-distribution.mlir
+++ mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -11,3 +11,24 @@
   %0 = addf %A, %B : vector<32xf32>
   return %0: vector<32xf32>
 }
+
+// CHECK-LABEL: func @vector_add_read_write
+//  CHECK-SAME: (%[[ID:.*]]: index
+//       CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
+//  CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
+//  CHECK-NEXT: %[[ADD1:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
+//  CHECK-NEXT: %[[EXC:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
+//  CHECK-NEXT: %[[ADD2:.*]] = addf %[[ADD1]], %[[EXC]] : vector<1xf32>
+//  CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%{{.*}}] : vector<1xf32>, memref<32xf32>
+//  CHECK-NEXT: return
+func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %a = vector.transfer_read %A[%c0], %cf0: memref<32xf32>, vector<32xf32>
+  %b = vector.transfer_read %B[%c0], %cf0: memref<32xf32>, vector<32xf32>
+  %acc = addf %a, %b: vector<32xf32>
+  %c = vector.transfer_read %C[%c0], %cf0: memref<32xf32>, vector<32xf32>
+  %d = addf %acc, %c: vector<32xf32>
+  vector.transfer_write %d, %D[%c0]: vector<32xf32>, memref<32xf32>
+  return
+}
Index: mlir/test/lib/Transforms/TestVectorTransforms.cpp
===================================================================
--- mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -129,6 +129,7 @@
     : public PassWrapper {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<VectorDialect>();
+    registry.insert<AffineDialect>();
   }
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
@@ -143,6 +144,7 @@
       op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
     });
     patterns.insert(ctx);
+    populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
   }
 };