Index: mlir/lib/Dialect/Vector/VectorTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2444,7 +2444,16 @@
   OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPointAfter(op);
   Location loc = op->getLoc();
+  if (op->getNumResults() != 1)
+    return {};
   Value result = op->getResult(0);
+  // Bail out unless the result is a rank-1 vector whose element count is a
+  // positive multiple of `multiplicity`. Rejecting `multiplicity <= 0` before
+  // the modulo also keeps that expression from dividing by zero.
+  VectorType type = result.getType().dyn_cast<VectorType>();
+  if (!type || type.getRank() != 1 || multiplicity <= 0 ||
+      type.getNumElements() % multiplicity != 0)
+    return {};
 
   DistributeOps ops;
   ops.extract = builder.create(loc, result, id, multiplicity);
Index: mlir/test/Dialect/Vector/vector-distribution.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-distribution.mlir
+++ mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-distribute-patterns | FileCheck %s
+// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32 | FileCheck %s
 
 // CHECK-LABEL: func @distribute_vector_add
 // CHECK-SAME: (%[[ID:.*]]: index
@@ -14,12 +14,12 @@
 
 // CHECK-LABEL: func @vector_add_read_write
 // CHECK-SAME: (%[[ID:.*]]: index
-// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
-// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
+// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
+// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
 // CHECK-NEXT: %[[ADD1:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
-// CHECK-NEXT: %[[EXC:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
+// CHECK-NEXT: %[[EXC:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
 // CHECK-NEXT: %[[ADD2:.*]] = addf %[[ADD1]], %[[EXC]] : vector<1xf32>
-// CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%{{.*}}] : vector<1xf32>, memref<32xf32>
+// CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%[[ID]]] : vector<1xf32>, memref<32xf32>
 // CHECK-NEXT: return
 func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) {
   %c0 = constant 0 : index
@@ -32,3 +32,41 @@
   vector.transfer_write %d, %D[%c0]: vector<32xf32>, memref<32xf32>
   return
 }
+
+// CHECK-LABEL: func @vector_add_cycle
+// CHECK-SAME: (%[[ID:.*]]: index
+// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<64xf32>, vector<2xf32>
+// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<64xf32>, vector<2xf32>
+// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2xf32>
+// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[ID]]] : vector<2xf32>, memref<64xf32>
+// CHECK-NEXT: return
+func @vector_add_cycle(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %a = vector.transfer_read %A[%c0], %cf0: memref<64xf32>, vector<64xf32>
+  %b = vector.transfer_read %B[%c0], %cf0: memref<64xf32>, vector<64xf32>
+  %acc = addf %a, %b: vector<64xf32>
+  vector.transfer_write %acc, %C[%c0]: vector<64xf32>, memref<64xf32>
+  return
+}
+
+// Negative test to make sure nothing is done in case the vector size is not a
+// multiple of the distribution multiplicity.
+// CHECK-LABEL: func @vector_negative_test
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %{{.*}} : memref<64xf32>, vector<16xf32>
+// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %{{.*}} : memref<64xf32>, vector<16xf32>
+// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<16xf32>
+// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[C0]]] {{.*}} : vector<16xf32>, memref<64xf32>
+// CHECK-NEXT: return
+func @vector_negative_test(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %a = vector.transfer_read %A[%c0], %cf0: memref<64xf32>, vector<16xf32>
+  %b = vector.transfer_read %B[%c0], %cf0: memref<64xf32>, vector<16xf32>
+  %acc = addf %a, %b: vector<16xf32>
+  vector.transfer_write %acc, %C[%c0]: vector<16xf32>, memref<64xf32>
+  return
+}
+
+
Index: mlir/test/lib/Transforms/TestVectorTransforms.cpp
===================================================================
--- mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -127,10 +127,16 @@
 
 struct TestVectorDistributePatterns
     : public PassWrapper {
+  TestVectorDistributePatterns() = default;
+  TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert();
     registry.insert();
   }
+  Option multiplicity{
+      *this, "distribution-multiplicity",
+      llvm::cl::desc("Set the multiplicity used for distributing vector"),
+      llvm::cl::init(32)};
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
@@ -138,10 +144,11 @@
     func.walk([&](AddFOp op) {
       OpBuilder builder(op);
       Optional ops = distributPointwiseVectorOp(
-          builder, op.getOperation(), func.getArgument(0), 32);
-      assert(ops.hasValue());
-      SmallPtrSet extractOp({ops->extract});
-      op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
+          builder, op.getOperation(), func.getArgument(0), multiplicity);
+      if (ops.hasValue()) {
+        SmallPtrSet extractOp({ops->extract});
+        op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
+      }
     });
     patterns.insert(ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);