Index: google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Vector/VectorOps.td
===================================================================
--- google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ google3/third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -520,6 +520,8 @@
     $vector `[` $id `:` $multiplicity `]` attr-dict `:` type($vector) `to`
     type(results)
   }];
+
+  let hasFolder = 1;
 }
 
 def Vector_FMAOp :
@@ -1029,7 +1031,8 @@
       DeclareOpInterfaceMethods,
       DeclareOpInterfaceMethods
     ]>,
-    Arguments<(ins AnyMemRef:$memref, Variadic:$indices,
+    Arguments<(ins Arg:$memref,
+               Variadic:$indices,
                AffineMapAttr:$permutation_map, AnyType:$padding,
                OptionalAttr:$masked)>,
     Results<(outs AnyVector:$vector)> {
Index: google3/third_party/llvm/llvm-project/mlir/lib/Dialect/Vector/VectorOps.cpp
===================================================================
--- google3/third_party/llvm/llvm-project/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ google3/third_party/llvm/llvm-project/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -923,6 +923,16 @@
   return success();
 }
 
+/// Folds `extract_map(insert_map(%v, %id, m), %id, m)` to `%v`: extracting the
+/// same (id, multiplicity) slice that was just inserted yields the original.
+OpFoldResult ExtractMapOp::fold(ArrayRef<Attribute> operands) {
+  auto insert = vector().getDefiningOp<InsertMapOp>();
+  if (!insert || multiplicity() != insert.multiplicity() ||
+      id() != insert.id())
+    return {};
+  return insert.vector();
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
Index: google3/third_party/llvm/llvm-project/mlir/lib/Dialect/Vector/VectorTransforms.cpp
===================================================================
--- google3/third_party/llvm/llvm-project/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ google3/third_party/llvm/llvm-project/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -12,6 +12,7 @@
 
 #include 
 
+#include "mlir/Dialect/Affine/EDSC/Builders.h"
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
@@ -2455,6 +2456,72 @@
   return ops;
 }
 
+/// Returns true if `type` is a vector type holding exactly one element.
+static bool isSingleElementVector(Type type) {
+  auto vectorType = type.dyn_cast<VectorType>();
+  return vectorType && vectorType.getNumElements() == 1;
+}
+
+/// Propagates distribution through a transfer_read by rewriting
+///   extract_map(transfer_read %m[..., %i], %id)
+/// into a narrower
+///   transfer_read %m[..., %i + %id]
+/// so each id loads only its own slice. Offsetting the most minor index by
+/// `%id` is only the start of that id's slice when each slice holds a single
+/// element and the permutation map is a minor identity (the vector dimension
+/// follows the last memref index); other cases are left unmatched.
+struct TransferReadExtractPattern
+    : public OpRewritePattern<vector::ExtractMapOp> {
+  TransferReadExtractPattern(MLIRContext *context)
+      : OpRewritePattern<vector::ExtractMapOp>(context) {}
+  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
+                                PatternRewriter &rewriter) const override {
+    auto read = extract.vector().getDefiningOp<vector::TransferReadOp>();
+    if (!read || !isSingleElementVector(extract.getType()) ||
+        !read.permutation_map().isMinorIdentity())
+      return failure();
+    edsc::ScopedContext scope(rewriter, extract.getLoc());
+    using mlir::edsc::op::operator+;
+    using namespace mlir::edsc::intrinsics;
+    SmallVector<Value, 4> indices(read.indices().begin(), read.indices().end());
+    indices.back() = indices.back() + extract.id();
+    Value newRead = vector_transfer_read(extract.getType(), read.memref(),
+                                         indices, read.permutation_map(),
+                                         read.padding(), ArrayAttr());
+    rewriter.replaceOp(extract, newRead);
+    return success();
+  }
+};
+
+/// Propagates distribution through a transfer_write by rewriting
+///   transfer_write (insert_map %v, %id), %m[..., %i]
+/// into a narrower
+///   transfer_write %v, %m[..., %i + %id]
+/// Applies under the same single-element / minor-identity restrictions as
+/// TransferReadExtractPattern.
+struct TransferWriteInsertPattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  TransferWriteInsertPattern(MLIRContext *context)
+      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    auto insert = write.vector().getDefiningOp<vector::InsertMapOp>();
+    if (!insert || !isSingleElementVector(insert.vector().getType()) ||
+        !write.permutation_map().isMinorIdentity())
+      return failure();
+    edsc::ScopedContext scope(rewriter, write.getLoc());
+    using mlir::edsc::op::operator+;
+    using namespace mlir::edsc::intrinsics;
+    SmallVector<Value, 4> indices(write.indices().begin(),
+                                  write.indices().end());
+    indices.back() = indices.back() + insert.id();
+    vector_transfer_write(insert.vector(), write.memref(), indices,
+                          write.permutation_map(), ArrayAttr());
+    rewriter.eraseOp(write);
+    return success();
+  }
+};
+
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
@@ -2464,7 +2531,9 @@
                     ShapeCastOpFolder,
                     SplitTransferReadOp,
                     SplitTransferWriteOp,
-                    TupleGetFolderOp>(context);
+                    TupleGetFolderOp,
+                    TransferReadExtractPattern,
+                    TransferWriteInsertPattern>(context);
   // clang-format on
 }
 
Index: google3/third_party/llvm/llvm-project/mlir/test/Dialect/Vector/vector-distribution.mlir
===================================================================
--- google3/third_party/llvm/llvm-project/mlir/test/Dialect/Vector/vector-distribution.mlir
+++ google3/third_party/llvm/llvm-project/mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -11,3 +11,24 @@
   %0 = addf %A, %B : vector<32xf32>
   return %0: vector<32xf32>
 }
+
+// CHECK-LABEL: func @vector_add_read_write
+// CHECK-SAME: (%[[ID:.*]]: index
+// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
+// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
+// CHECK-NEXT: %[[ADD1:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
+// CHECK-NEXT: %[[EXC:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
+// CHECK-NEXT: %[[ADD2:.*]] = addf %[[ADD1]], %[[EXC]] : vector<1xf32>
+// CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%{{.*}}] : vector<1xf32>, memref<32xf32>
+// CHECK-NEXT: return
+func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %a = vector.transfer_read %A[%c0], %cf0: memref<32xf32>, vector<32xf32>
+  %b = vector.transfer_read %B[%c0], %cf0: memref<32xf32>, vector<32xf32>
+  %acc = addf %a, %b: vector<32xf32>
+  %c = vector.transfer_read %C[%c0], %cf0: memref<32xf32>, vector<32xf32>
+  %d = addf %acc, %c: vector<32xf32>
+  vector.transfer_write %d, %D[%c0]: vector<32xf32>, memref<32xf32>
+  return
+}
Index: google3/third_party/llvm/llvm-project/mlir/test/lib/Transforms/TestVectorTransforms.cpp
===================================================================
--- google3/third_party/llvm/llvm-project/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ google3/third_party/llvm/llvm-project/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -129,6 +129,8 @@
     : public PassWrapper {
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<VectorDialect>();
+    // Index offsets in the distribution patterns use Affine EDSC helpers.
+    registry.insert<AffineDialect>();
   }
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
@@ -143,6 +145,7 @@
       op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
     });
     patterns.insert(ctx);
+    populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
   }
 };