Index: mlir/include/mlir/Dialect/Vector/VectorOps.td =================================================================== --- mlir/include/mlir/Dialect/Vector/VectorOps.td +++ mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -457,11 +457,18 @@ def Vector_ExtractMapOp : Vector_Op<"extract_map", [NoSideEffect]>, - Arguments<(ins AnyVector:$vector, Index:$id)>, + Arguments<(ins AnyVector:$vector, Variadic:$ids)>, Results<(outs AnyVector)> { let summary = "vector extract map operation"; let description = [{ - Takes an 1-D vector and extracts a sub-part of the vector starting at id. + Takes an N-D vector and extracts a sub-part of the vector starting at id + along each dimension. + + The dimension associated to each element of `ids` used to extract is + implicitly deduced from the destination type. For example if the source + type is `vector<64x4x32xf32>` and the destination type is + `vector<4x4x2xf32>`, the first id maps to dimension 0 and the second id to + dimension 2. Similarly to vector.tuple_get, this operation is used for progressive lowering and should be folded away before converting to LLVM. 
@@ -488,10 +495,14 @@ ```mlir %ev = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32> + %ev1 = vector.extract_map %v1[%id1, %id2] : vector<64x4x32xf32> + to vector<4x4x2xf32> ``` }]; let builders = [ - OpBuilderDAG<(ins "Value":$vector, "Value":$id, "int64_t":$multiplicity)>]; + OpBuilderDAG<(ins "Value":$vector, "ValueRange":$ids, + "ArrayRef":$multiplicity, + "AffineMap":$map)>]; let extraClassDeclaration = [{ VectorType getSourceVectorType() { return vector().getType().cast(); @@ -499,13 +510,11 @@ VectorType getResultType() { return getResult().getType().cast(); } - int64_t multiplicity() { - return getSourceVectorType().getNumElements() / - getResultType().getNumElements(); - } + void getMultiplicity(SmallVectorImpl &multiplicity); + AffineMap map(); }]; let assemblyFormat = [{ - $vector `[` $id `]` attr-dict `:` type($vector) `to` type(results) + $vector `[` $ids `]` attr-dict `:` type($vector) `to` type(results) }]; let hasFolder = 1; @@ -686,13 +695,18 @@ def Vector_InsertMapOp : Vector_Op<"insert_map", [NoSideEffect, AllTypesMatch<["dest", "result"]>]>, - Arguments<(ins AnyVector:$vector, AnyVector:$dest, Index:$id)>, + Arguments<(ins AnyVector:$vector, AnyVector:$dest, Variadic:$ids)>, Results<(outs AnyVector:$result)> { let summary = "vector insert map operation"; let description = [{ - Inserts a 1-D vector and within a larger vector starting at id. The new + Inserts an N-D vector within a larger vector starting at id. The new vector created will have the same size as the destination operand vector. + The dimension associated to each element of `ids` used to insert is + implicitly deduced from the source type. For example if source type is + `vector<4x4x2xf32>` and the destination type is `vector<64x4x32xf32>`, + the first id maps to dimension 0 and the second id to dimension 2. + Similarly to vector.tuple_get, this operation is used for progressive lowering and should be folded away before converting to LLVM. 
@@ -723,10 +737,12 @@ ```mlir %v = vector.insert_map %ev %v[%id] : vector<1xf32> into vector<32xf32> + %v1 = vector.insert_map %ev1, %v1[%arg0, %arg1] : vector<2x4x1xf32> + into vector<64x4x32xf32> ``` }]; let builders = [OpBuilderDAG<(ins "Value":$vector, "Value":$dest, - "Value":$id, "int64_t":$multiplicity)>]; + "ValueRange":$ids)>]; let extraClassDeclaration = [{ VectorType getSourceVectorType() { return vector().getType().cast(); @@ -734,13 +750,11 @@ VectorType getResultType() { return getResult().getType().cast(); } - int64_t multiplicity() { - return getResultType().getNumElements() / - getSourceVectorType().getNumElements(); - } + // Return a map indicating the dimension mapping to the given Ids. + AffineMap map(); }]; let assemblyFormat = [{ - $vector `,` $dest `[` $id `]` attr-dict + $vector `,` $dest `[` $ids `]` attr-dict `:` type($vector) `into` type($result) }]; } Index: mlir/include/mlir/Dialect/Vector/VectorTransforms.h =================================================================== --- mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -231,7 +231,7 @@ InsertMapOp insert; }; -/// Distribute a 1D vector pointwise operation over a range of given IDs taking +/// Distribute an N-D vector pointwise operation over a range of given IDs taking /// *all* values in [0 .. multiplicity - 1] (e.g. loop induction variable or /// SPMD id). This transformation only inserts /// vector.extract_map/vector.insert_map. 
It is meant to be used with @@ -243,9 +243,10 @@ /// %v = addf %a, %b : vector<32xf32> /// %ev = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32> /// %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> into vector<32xf32> -Optional distributPointwiseVectorOp(OpBuilder &builder, - Operation *op, Value id, - int64_t multiplicity); +Optional +distributPointwiseVectorOp(OpBuilder &builder, Operation *op, + ArrayRef id, ArrayRef multiplicity, + const AffineMap &map); /// Canonicalize an extra element using the result of a pointwise operation. /// Transforms: /// %v = addf %a, %b : vector32xf32> Index: mlir/lib/Dialect/Vector/VectorOps.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorOps.cpp +++ mlir/lib/Dialect/Vector/VectorOps.cpp @@ -999,33 +999,78 @@ //===----------------------------------------------------------------------===// void ExtractMapOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value id, int64_t multiplicity) { + Value vector, ValueRange ids, + ArrayRef multiplicity, + AffineMap permutationMap) { + assert(ids.size() == multiplicity.size() && + ids.size() == permutationMap.getNumResults()); + assert(permutationMap.isProjectedPermutation()); VectorType type = vector.getType().cast(); - VectorType resultType = VectorType::get(type.getNumElements() / multiplicity, - type.getElementType()); - ExtractMapOp::build(builder, result, resultType, vector, id); + SmallVector newShape(type.getShape().begin(), + type.getShape().end()); + for (unsigned i = 0, e = permutationMap.getNumResults(); i < e; i++) { + AffineExpr expr = permutationMap.getResult(i); + auto dim = expr.cast(); + newShape[dim.getPosition()] = newShape[dim.getPosition()] / multiplicity[i]; + } + VectorType resultType = VectorType::get(newShape, type.getElementType()); + ExtractMapOp::build(builder, result, resultType, vector, ids); } static LogicalResult verify(ExtractMapOp op) { - if 
(op.getSourceVectorType().getShape().size() != 1 || - op.getResultType().getShape().size() != 1) - return op.emitOpError("expects source and destination vectors of rank 1"); - if (op.getSourceVectorType().getNumElements() % - op.getResultType().getNumElements() != - 0) + if (op.getSourceVectorType().getRank() != op.getResultType().getRank()) return op.emitOpError( - "source vector size must be a multiple of destination vector size"); + "expected source and destination vectors of same rank"); + unsigned numId = 0; + for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; ++i) { + if (op.getSourceVectorType().getDimSize(i) % + op.getResultType().getDimSize(i) != + 0) + return op.emitOpError("source vector dimensions must be a multiple of " + "destination vector dimensions"); + if (op.getSourceVectorType().getDimSize(i) != + op.getResultType().getDimSize(i)) + numId++; + } + if (numId != op.ids().size()) + return op.emitOpError("expected number of ids must match the number of " + "dimensions distributed"); return success(); } OpFoldResult ExtractMapOp::fold(ArrayRef operands) { auto insert = vector().getDefiningOp(); - if (insert == nullptr || multiplicity() != insert.multiplicity() || - id() != insert.id()) + if (insert == nullptr || getType() != insert.vector().getType() || + ids() != insert.ids()) return {}; return insert.vector(); } +void ExtractMapOp::getMultiplicity(SmallVectorImpl &multiplicity) { + assert(multiplicity.empty()); + for (unsigned i = 0, e = getSourceVectorType().getRank(); i < e; i++) { + if (getSourceVectorType().getDimSize(i) != getResultType().getDimSize(i)) + multiplicity.push_back(getSourceVectorType().getDimSize(i) / + getResultType().getDimSize(i)); + } +} + +template AffineMap calculateImplicitMap(MapOp op) { + SmallVector perm; + // Check which dimensions have a multiplicity greater than 1 and associate + // them with the IDs in order. 
+ for (unsigned i = 0, e = op.getSourceVectorType().getRank(); i < e; i++) { + if (op.getSourceVectorType().getDimSize(i) != + op.getResultType().getDimSize(i)) + perm.push_back(getAffineDimExpr(i, op.getContext())); + } + auto map = AffineMap::get(op.getSourceVectorType().getRank(), 0, perm, + op.getContext()); + return map; +} + +AffineMap ExtractMapOp::map() { return calculateImplicitMap(*this); } + //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// @@ -1253,26 +1298,33 @@ //===----------------------------------------------------------------------===// void InsertMapOp::build(OpBuilder &builder, OperationState &result, - Value vector, Value dest, Value id, - int64_t multiplicity) { - VectorType type = vector.getType().cast(); - VectorType resultType = VectorType::get(type.getNumElements() * multiplicity, - type.getElementType()); - InsertMapOp::build(builder, result, resultType, vector, dest, id); + Value vector, Value dest, ValueRange ids) { + InsertMapOp::build(builder, result, dest.getType(), vector, dest, ids); } static LogicalResult verify(InsertMapOp op) { - if (op.getSourceVectorType().getShape().size() != 1 || - op.getResultType().getShape().size() != 1) - return op.emitOpError("expected source and destination vectors of rank 1"); - if (op.getResultType().getNumElements() % - op.getSourceVectorType().getNumElements() != - 0) + if (op.getSourceVectorType().getRank() != op.getResultType().getRank()) return op.emitOpError( - "destination vector size must be a multiple of source vector size"); + "expected source and destination vectors of same rank"); + unsigned numId = 0; + for (unsigned i = 0, e = op.getResultType().getRank(); i < e; i++) { + if (op.getResultType().getDimSize(i) % + op.getSourceVectorType().getDimSize(i) != + 0) + return op.emitOpError( + "destination vector size must be a multiple of source vector size"); + if 
(op.getResultType().getDimSize(i) != + op.getSourceVectorType().getDimSize(i)) + numId++; + } + if (numId != op.ids().size()) + return op.emitOpError("expected number of ids must match the number of " + "dimensions distributed"); return success(); } +AffineMap InsertMapOp::map() { return calculateImplicitMap(*this); } + //===----------------------------------------------------------------------===// // InsertStridedSliceOp //===----------------------------------------------------------------------===// Index: mlir/lib/Dialect/Vector/VectorTransforms.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2483,16 +2483,16 @@ SmallVector extractOperands; for (OpOperand &operand : definedOp->getOpOperands()) extractOperands.push_back(rewriter.create( - loc, operand.get(), extract.id(), extract.multiplicity())); + loc, extract.getResultType(), operand.get(), extract.ids())); Operation *newOp = cloneOpWithOperandsAndTypes( rewriter, loc, definedOp, extractOperands, extract.getResult().getType()); rewriter.replaceOp(extract, newOp->getResult(0)); return success(); } -Optional -mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op, - Value id, int64_t multiplicity) { +Optional mlir::vector::distributPointwiseVectorOp( + OpBuilder &builder, Operation *op, ArrayRef ids, + ArrayRef multiplicity, const AffineMap &map) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointAfter(op); Location loc = op->getLoc(); @@ -2500,15 +2500,24 @@ return {}; Value result = op->getResult(0); VectorType type = op->getResult(0).getType().dyn_cast(); - // Currently only support distributing 1-D vectors of size multiple of the - // given multiplicty. To handle more sizes we would need to support masking. 
- if (!type || type.getRank() != 1 || type.getNumElements() % multiplicity != 0) + if (!type || map.getNumResults() != multiplicity.size()) return {}; + // For each dimension being distributed check that the size is a multiple of + // the multiplicity. To handle more sizes we would need to support masking. + unsigned multiplictyCount = 0; + for (auto exp : map.getResults()) { + auto affinExp = exp.dyn_cast(); + if (!affinExp || affinExp.getPosition() >= type.getRank() || + type.getDimSize(affinExp.getPosition()) % + multiplicity[multiplictyCount++] != + 0) + return {}; + } DistributeOps ops; ops.extract = - builder.create(loc, result, id, multiplicity); - ops.insert = builder.create(loc, ops.extract, result, id, - multiplicity); + builder.create(loc, result, ids, multiplicity, map); + ops.insert = + builder.create(loc, ops.extract, result, ids); return ops; } @@ -2529,17 +2538,22 @@ using mlir::edsc::op::operator*; using namespace mlir::edsc::intrinsics; SmallVector indices(read.indices().begin(), read.indices().end()); - indices.back() = - indices.back() + - (extract.id() * - std_constant_index(extract.getResultType().getDimSize(0))); + AffineMap map = extract.map(); + unsigned idCount = 0; + for (auto expr : map.getResults()) { + unsigned pos = expr.cast().getPosition(); + indices[pos] = + indices[pos] + + extract.ids()[idCount++] * + std_constant_index(extract.getResultType().getDimSize(pos)); + } Value newRead = vector_transfer_read(extract.getType(), read.memref(), indices, read.permutation_map(), - read.padding(), ArrayAttr()); + read.padding(), read.maskedAttr()); Value dest = rewriter.create( read.getLoc(), read.getType(), rewriter.getZeroAttr(read.getType())); - newRead = rewriter.create( - read.getLoc(), newRead, dest, extract.id(), extract.multiplicity()); + newRead = rewriter.create(read.getLoc(), newRead, dest, + extract.ids()); rewriter.replaceOp(read, newRead); return success(); } @@ -2560,12 +2574,17 @@ using namespace mlir::edsc::intrinsics; 
SmallVector indices(write.indices().begin(), write.indices().end()); - indices.back() = - indices.back() + - (insert.id() * - std_constant_index(insert.getSourceVectorType().getDimSize(0))); + AffineMap map = insert.map(); + unsigned idCount = 0; + for (auto expr : map.getResults()) { + unsigned pos = expr.cast().getPosition(); + indices[pos] = + indices[pos] + + insert.ids()[idCount++] * + std_constant_index(insert.getSourceVectorType().getDimSize(pos)); + } vector_transfer_write(insert.vector(), write.memref(), indices, - write.permutation_map(), ArrayAttr()); + write.permutation_map(), write.maskedAttr()); rewriter.eraseOp(write); return success(); } Index: mlir/test/Dialect/Vector/invalid.mlir =================================================================== --- mlir/test/Dialect/Vector/invalid.mlir +++ mlir/test/Dialect/Vector/invalid.mlir @@ -1331,23 +1331,30 @@ // ----- -func @extract_map_rank(%v: vector<2x32xf32>, %id : index) { - // expected-error@+1 {{'vector.extract_map' op expects source and destination vectors of rank 1}} - %0 = vector.extract_map %v[%id] : vector<2x32xf32> to vector<2x1xf32> +func @extract_map_rank(%v: vector<32xf32>, %id : index) { + // expected-error@+1 {{'vector.extract_map' op expected source and destination vectors of same rank}} + %0 = vector.extract_map %v[%id] : vector<32xf32> to vector<2x1xf32> } // ----- func @extract_map_size(%v: vector<63xf32>, %id : index) { - // expected-error@+1 {{'vector.extract_map' op source vector size must be a multiple of destination vector size}} + // expected-error@+1 {{'vector.extract_map' op source vector dimensions must be a multiple of destination vector dimensions}} %0 = vector.extract_map %v[%id] : vector<63xf32> to vector<2xf32> } // ----- -func @insert_map_rank(%v: vector<2x1xf32>, %v1: vector<2x32xf32>, %id : index) { - // expected-error@+1 {{'vector.insert_map' op expected source and destination vectors of rank 1}} - %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into 
vector<2x32xf32> +func @extract_map_id(%v: vector<2x32xf32>, %id : index) { + // expected-error@+1 {{'vector.extract_map' op expected number of ids must match the number of dimensions distributed}} + %0 = vector.extract_map %v[%id] : vector<2x32xf32> to vector<1x1xf32> +} + +// ----- + +func @insert_map_rank(%v: vector<2x1xf32>, %v1: vector<32xf32>, %id : index) { + // expected-error@+1 {{'vector.insert_map' op expected source and destination vectors of same rank}} + %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<32xf32> } // ----- @@ -1356,3 +1363,10 @@ // expected-error@+1 {{'vector.insert_map' op destination vector size must be a multiple of source vector size}} %0 = vector.insert_map %v, %v1[%id] : vector<3xf32> into vector<64xf32> } + +// ----- + +func @insert_map_id(%v: vector<2x1xf32>, %v1: vector<4x32xf32>, %id : index) { + // expected-error@+1 {{'vector.insert_map' op expected number of ids must match the number of dimensions distributed}} + %0 = vector.insert_map %v, %v1[%id] : vector<2x1xf32> into vector<4x32xf32> +} Index: mlir/test/Dialect/Vector/ops.mlir =================================================================== --- mlir/test/Dialect/Vector/ops.mlir +++ mlir/test/Dialect/Vector/ops.mlir @@ -434,12 +434,17 @@ } // CHECK-LABEL: @extract_insert_map -func @extract_insert_map(%v: vector<32xf32>, %id : index) -> vector<32xf32> { +func @extract_insert_map(%v: vector<32xf32>, %v2: vector<16x32xf32>, + %id0 : index, %id1 : index) -> (vector<32xf32>, vector<16x32xf32>) { // CHECK: %[[V:.*]] = vector.extract_map %{{.*}}[%{{.*}}] : vector<32xf32> to vector<2xf32> - %vd = vector.extract_map %v[%id] : vector<32xf32> to vector<2xf32> + %vd = vector.extract_map %v[%id0] : vector<32xf32> to vector<2xf32> + // CHECK: %[[V1:.*]] = vector.extract_map %{{.*}}[%{{.*}}, %{{.*}}] : vector<16x32xf32> to vector<4x2xf32> + %vd2 = vector.extract_map %v2[%id0, %id1] : vector<16x32xf32> to vector<4x2xf32> // CHECK: %[[R:.*]] = vector.insert_map %[[V]], 
%{{.*}}[%{{.*}}] : vector<2xf32> into vector<32xf32> - %r = vector.insert_map %vd, %v[%id] : vector<2xf32> into vector<32xf32> - // CHECK: return %[[R]] : vector<32xf32> - return %r : vector<32xf32> + %r = vector.insert_map %vd, %v[%id0] : vector<2xf32> into vector<32xf32> + // CHECK: %[[R1:.*]] = vector.insert_map %[[V1]], %{{.*}}[%{{.*}}, %{{.*}}] : vector<4x2xf32> into vector<16x32xf32> + %r2 = vector.insert_map %vd2, %v2[%id0, %id1] : vector<4x2xf32> into vector<16x32xf32> + // CHECK: return %[[R]], %[[R1]] : vector<32xf32>, vector<16x32xf32> + return %r, %r2 : vector<32xf32>, vector<16x32xf32> } Index: mlir/test/Dialect/Vector/vector-distribution.mlir =================================================================== --- mlir/test/Dialect/Vector/vector-distribution.mlir +++ mlir/test/Dialect/Vector/vector-distribution.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32 -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32,1,32 -split-input-file | FileCheck %s // CHECK-LABEL: func @distribute_vector_add // CHECK-SAME: (%[[ID:.*]]: index @@ -22,7 +22,7 @@ // CHECK-NEXT: %[[ADD1:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32> // CHECK-NEXT: %[[EXC:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32> // CHECK-NEXT: %[[ADD2:.*]] = addf %[[ADD1]], %[[EXC]] : vector<1xf32> -// CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%[[ID]]] : vector<1xf32>, memref<32xf32> +// CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%[[ID]]] {{.*}} : vector<1xf32>, memref<32xf32> // CHECK-NEXT: return func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) { %c0 = constant 0 : index @@ -48,7 +48,7 @@ // CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID2]]], %{{.*}} : memref<64xf32>, vector<2xf32> // CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : 
vector<2xf32> // CHECK-NEXT: %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID]]] -// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]]] : vector<2xf32>, memref<64xf32> +// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]]] {{.*}} : vector<2xf32>, memref<64xf32> // CHECK-NEXT: return func @vector_add_cycle(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) { %c0 = constant 0 : index @@ -81,4 +81,46 @@ return } +// ----- + +// CHECK-LABEL: func @distribute_vector_add_3d +// CHECK-SAME: (%[[ID0:.*]]: index, %[[ID1:.*]]: index +// CHECK-NEXT: %[[ADDV:.*]] = addf %{{.*}}, %{{.*}} : vector<64x4x32xf32> +// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID0]], %[[ID1]]] : vector<64x4x32xf32> to vector<2x4x1xf32> +// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID0]], %[[ID1]]] : vector<64x4x32xf32> to vector<2x4x1xf32> +// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32> +// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID0]], %[[ID1]]] : vector<2x4x1xf32> into vector<64x4x32xf32> +// CHECK-NEXT: return %[[INS]] : vector<64x4x32xf32> +func @distribute_vector_add_3d(%id0 : index, %id1 : index, + %A: vector<64x4x32xf32>, %B: vector<64x4x32xf32>) -> vector<64x4x32xf32> { + %0 = addf %A, %B : vector<64x4x32xf32> + return %0: vector<64x4x32xf32> +} + +// ----- + +// CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<()[s0] -> (s0 * 2)> + +// CHECK: func @vector_add_transfer_3d +// CHECK-SAME: (%[[ID_0:.*]]: index, %[[ID_1:.*]]: index +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[ID1:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]] +// CHECK-NEXT: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID1]], %[[C0]], %[[ID_1]]], %{{.*}} : memref<64x64x64xf32>, vector<2x4x1xf32> +// CHECK-NEXT: %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]] +// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID2]], %[[C0]], %[[ID_1]]], %{{.*}} : memref<64x64x64xf32>, vector<2x4x1xf32> +// 
CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32> +// CHECK-NEXT: %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]] +// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[ID3]], %[[C0]], %[[ID_1]]] {{.*}} : vector<2x4x1xf32>, memref<64x64x64xf32> +// CHECK-NEXT: return +func @vector_add_transfer_3d(%id0 : index, %id1 : index, %A: memref<64x64x64xf32>, + %B: memref<64x64x64xf32>, %C: memref<64x64x64xf32>) { + %c0 = constant 0 : index + %cf0 = constant 0.0 : f32 + %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0: memref<64x64x64xf32>, vector<64x4x32xf32> + %b = vector.transfer_read %B[%c0, %c0, %c0], %cf0: memref<64x64x64xf32>, vector<64x4x32xf32> + %acc = addf %a, %b: vector<64x4x32xf32> + vector.transfer_write %acc, %C[%c0, %c0, %c0]: vector<64x4x32xf32>, memref<64x64x64xf32> + return +} + Index: mlir/test/lib/Transforms/TestVectorTransforms.cpp =================================================================== --- mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -163,21 +163,40 @@ registry.insert(); registry.insert(); } - Option multiplicity{ - *this, "distribution-multiplicity", - llvm::cl::desc("Set the multiplicity used for distributing vector"), - llvm::cl::init(32)}; + ListOption multiplicity{ + *this, "distribution-multiplicity", llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::desc("Set the multiplicity used for distributing vector")}; + void runOnFunction() override { MLIRContext *ctx = &getContext(); OwningRewritePatternList patterns; FuncOp func = getFunction(); func.walk([&](AddFOp op) { OpBuilder builder(op); - Optional ops = distributPointwiseVectorOp( - builder, op.getOperation(), func.getArgument(0), multiplicity); - if (ops.hasValue()) { - SmallPtrSet extractOp({ops->extract, ops->insert}); - op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp); + if (auto vecType = op.getType().dyn_cast()) { + SmallVector mul; + SmallVector perm; + SmallVector 
ids; + unsigned count = 0; + // Remove the multiplicity of 1 and calculate the affine map based on + // the multiplicity. + SmallVector m(multiplicity.begin(), multiplicity.end()); + for (unsigned i = 0, e = vecType.getRank(); i < e; i++) { + if (i < m.size() && m[i] != 1 && vecType.getDimSize(i) % m[i] == 0) { + mul.push_back(m[i]); + ids.push_back(func.getArgument(count++)); + perm.push_back(getAffineDimExpr(i, ctx)); + } + } + auto map = AffineMap::get(op.getType().cast().getRank(), 0, + perm, ctx); + Optional ops = distributPointwiseVectorOp( + builder, op.getOperation(), ids, mul, map); + if (ops.hasValue()) { + SmallPtrSet extractOp({ops->extract, ops->insert}); + op.getResult().replaceAllUsesExcept(ops->insert.getResult(), + extractOp); + } } }); patterns.insert(ctx); @@ -229,9 +248,11 @@ for (Operation *it : dependentOps) { it->moveBefore(forOp.getBody()->getTerminator()); } + auto map = AffineMap::getMultiDimIdentityMap(1, ctx); // break up the original op and let the patterns propagate. Optional ops = distributPointwiseVectorOp( - builder, op.getOperation(), forOp.getInductionVar(), multiplicity); + builder, op.getOperation(), {forOp.getInductionVar()}, {multiplicity}, + map); if (ops.hasValue()) { SmallPtrSet extractOp({ops->extract, ops->insert}); op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);