Index: mlir/include/mlir/Dialect/Vector/VectorOps.td =================================================================== --- mlir/include/mlir/Dialect/Vector/VectorOps.td +++ mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -457,6 +457,71 @@ }]; } +def Vector_ExtractMapOp : + Vector_Op<"extract_map", [NoSideEffect]>, + Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$multiplicity)>, + Results<(outs AnyVector)> { + let summary = "vector extract map operation"; + let description = [{ + Takes a 1-D vector and extracts a sub-part of the vector starting at id with + a size of `vector size / multiplicity`. This maps a given multiplicity of + the vector to a Value such as a loop induction variable or an SPMD id. + + Similarly to vector.tuple_get, this operation is used for progressive + lowering and should be folded away before converting to LLVM. + + + For instance, the following code: + ```mlir + %a = vector.transfer_read %A[%c0]: memref<32xf32>, vector<32xf32> + %b = vector.transfer_read %B[%c0]: memref<32xf32>, vector<32xf32> + %c = addf %a, %b: vector<32xf32> + vector.transfer_write %c, %C[%c0]: memref<32xf32>, vector<32xf32> + ``` + can be rewritten to: + ```mlir + %a = vector.transfer_read %A[%c0]: memref<32xf32>, vector<32xf32> + %b = vector.transfer_read %B[%c0]: memref<32xf32>, vector<32xf32> + %ea = vector.extract_map %a[%id : 32] : vector<32xf32> to vector<1xf32> + %eb = vector.extract_map %b[%id : 32] : vector<32xf32> to vector<1xf32> + %ec = addf %ea, %eb : vector<1xf32> + %c = vector.insert_map %ec, %id, 32 : vector<1xf32> to vector<32xf32> + vector.transfer_write %c, %C[%c0]: memref<32xf32>, vector<32xf32> + ``` + + Where %id can be an induction variable or an SPMD id going from 0 to 31. 
+ + And then be rewritten to: + ```mlir + %a = vector.transfer_read %A[%id]: memref<32xf32>, vector<1xf32> + %b = vector.transfer_read %B[%id]: memref<32xf32>, vector<1xf32> + %c = addf %a, %b: vector<1xf32> + vector.transfer_write %c, %C[%id]: memref<32xf32>, vector<1xf32> + ``` + + Example: + + ```mlir + %ev = vector.extract_map %v[%id:32] : vector<32xf32> to vector<1xf32> + ``` + }]; + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, " # + "Value vector, Value id, int64_t multiplicity">]; + let extraClassDeclaration = [{ + VectorType getSourceVectorType() { + return vector().getType().cast(); + } + VectorType getResultType() { + return getResult().getType().cast(); + } + }]; + let assemblyFormat = [{ + $vector `[` $id `:` $multiplicity `]` attr-dict `:` type($vector) `to` + type(results) + }]; +} + def Vector_FMAOp : Op]>, @@ -632,6 +697,46 @@ }]; } +def Vector_InsertMapOp : + Vector_Op<"insert_map", [NoSideEffect]>, + Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$multiplicity)>, + Results<(outs AnyVector)> { + let summary = "vector insert map operation"; + let description = [{ + Inserts a 1-D vector within a larger vector starting at id. The new + vector created will have a size of `vector size * multiplicity`. This + represents how a sub-part of the vector is written for a given Value such as + a loop induction variable or an SPMD id. + + Similarly to vector.tuple_get, this operation is used for progressive + lowering and should be folded away before converting to LLVM. + + This operation is meant to be used in combination with vector.extract_map. + See the example in the vector.extract_map description. 
+ + Example: + + ```mlir + %v = vector.insert_map %ev, %id, 32 : vector<1xf32> to vector<32xf32> + ``` + }]; + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, " # + "Value vector, Value id, int64_t multiplicity">]; + let extraClassDeclaration = [{ + VectorType getSourceVectorType() { + return vector().getType().cast(); + } + VectorType getResultType() { + return getResult().getType().cast(); + } + }]; + let assemblyFormat = [{ + $vector `,` $id `,` $multiplicity attr-dict `:` type($vector) `to` + type(results) + }]; +} + def Vector_InsertStridedSliceOp : Vector_Op<"insert_strided_slice", [NoSideEffect, PredOpTrait<"operand #0 and result have same element type", Index: mlir/include/mlir/Dialect/Vector/VectorTransforms.h =================================================================== --- mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -172,6 +172,41 @@ FilterConstraintType filter; }; +/// Distribute a 1D vector pointwise operation over a range of given IDs going +/// from 0 to dimension. This transformation only inserts +/// vector.extract_map/vector.insert_map. It is meant to be used with +/// canonicalization patterns to propagate and fold the vector +/// insert_map/extract_map operations. +/// Transforms: +/// %v = addf %a, %b : vector<32xf32> +/// to: +/// %v = addf %a, %b : vector<32xf32> +/// %ev = vector.extract_map %v[%id : 32] : vector<32xf32> to vector<1xf32> +/// %nv = vector.insert_map %ev, %id, 32 : vector<1xf32> to vector<32xf32> +void distributPointwiseVectorOp(OpBuilder &builder, Operation *op, Value id, + int64_t dimension, ExtractMapOp &extract, + InsertMapOp &insert); +/// Canonicalize an extract_map using the result of a pointwise operation. 
+/// Transforms: +/// %v = addf %a, %b : vector<32xf32> +/// %dv = vector.extract_map %v[%id : 32] : vector<32xf32> to vector<1xf32> +/// to: +/// %da = vector.extract_map %a[%id : 32] : vector<32xf32> to vector<1xf32> +/// %db = vector.extract_map %b[%id : 32] : vector<32xf32> to vector<1xf32> +/// %dv = addf %da, %db : vector<1xf32> +struct PointwiseExtractPattern : public OpRewritePattern { + using FilterConstraintType = std::function; + PointwiseExtractPattern( + MLIRContext *context, FilterConstraintType constraint = + [](ExtractMapOp op) { return success(); }) + : OpRewritePattern(context), filter(constraint) {} + LogicalResult matchAndRewrite(ExtractMapOp extract, + PatternRewriter &rewriter) const override; + +private: + FilterConstraintType filter; +}; + } // namespace vector //===----------------------------------------------------------------------===// Index: mlir/lib/Dialect/Vector/VectorOps.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorOps.cpp +++ mlir/lib/Dialect/Vector/VectorOps.cpp @@ -901,6 +901,29 @@ } //===----------------------------------------------------------------------===// +// ExtractMapOp +//===----------------------------------------------------------------------===// + +void ExtractMapOp::build(OpBuilder &builder, OperationState &result, + Value vector, Value id, int64_t dimension) { + VectorType type = vector.getType().cast(); + VectorType resultType = + VectorType::get(type.getNumElements() / dimension, type.getElementType()); + ExtractMapOp::build(builder, result, resultType, vector, id, dimension); +} + +static LogicalResult verify(ExtractMapOp op) { + if (op.getSourceVectorType().getShape().size() != 1 || + op.getResultType().getShape().size() != 1) + return op.emitOpError("expects source and destination vectors of rank 1"); + if (op.getResultType().getNumElements() * (int64_t)op.multiplicity() != + op.getSourceVectorType().getNumElements()) + return 
op.emitOpError("vector sizes mismatch. Source size must be equal " + "to destination size * dimension"); + return success(); +} + +//===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// @@ -1123,6 +1146,30 @@ } //===----------------------------------------------------------------------===// +// InsertMapOp +//===----------------------------------------------------------------------===// + +void InsertMapOp::build(OpBuilder &builder, OperationState &result, + Value vector, Value id, int64_t dimension) { + VectorType type = vector.getType().cast(); + VectorType resultType = + VectorType::get(type.getNumElements() * dimension, type.getElementType()); + InsertMapOp::build(builder, result, resultType, vector, id, dimension); +} + +static LogicalResult verify(InsertMapOp op) { + if (op.getSourceVectorType().getShape().size() != 1 || + op.getResultType().getShape().size() != 1) + return op.emitOpError("expected source and destination vectors of rank 1"); + if ((int64_t)op.multiplicity() * op.getSourceVectorType().getNumElements() != + op.getResultType().getNumElements()) + return op.emitOpError( + "vector sizes mismatch. 
Destination size must be equal " + "to source size * dimension"); + return success(); +} + +//===----------------------------------------------------------------------===// // InsertStridedSliceOp //===----------------------------------------------------------------------===// Index: mlir/lib/Dialect/Vector/VectorTransforms.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2421,6 +2421,37 @@ return failure(); } +LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite( + ExtractMapOp extract, PatternRewriter &rewriter) const { + Operation *definedOp = extract.vector().getDefiningOp(); + if (!definedOp || definedOp->getNumResults() != 1) + return failure(); + // TODO: Create an interfaceOp for elementwise operations. + if (!isa(definedOp)) + return failure(); + Location loc = extract.getLoc(); + SmallVector extractOperands; + for (OpOperand &operand : definedOp->getOpOperands()) + extractOperands.push_back(rewriter.create( + loc, operand.get(), extract.id(), extract.multiplicity())); + Operation *newOp = cloneOpWithOperandsAndTypes( + rewriter, loc, definedOp, extractOperands, extract.getResult().getType()); + rewriter.replaceOp(extract, newOp->getResult(0)); + return success(); +} + +void mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op, + Value id, int64_t dimension, + ExtractMapOp &extract, + InsertMapOp &insert) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointAfter(op); + Location loc = op->getLoc(); + Value result = op->getResult(0); + extract = builder.create(loc, result, id, dimension); + insert = builder.create(loc, extract, id, dimension); +} + // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp). // TODO: Add this as DRR pattern. 
void mlir::vector::populateVectorToVectorTransformationPatterns( Index: mlir/test/Dialect/Vector/invalid.mlir =================================================================== --- mlir/test/Dialect/Vector/invalid.mlir +++ mlir/test/Dialect/Vector/invalid.mlir @@ -1328,3 +1328,31 @@ // expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}} vector.compressstore %base, %mask, %value : memref, vector<17xi1>, vector<16xf32> } + +// ----- + +func @extract_map_rank(%v: vector<2x32xf32>, %id : index) { + // expected-error@+1 {{'vector.extract_map' op expects source and destination vectors of rank 1}} + %0 = vector.extract_map %v[%id : 32] : vector<2x32xf32> to vector<2x1xf32> +} + +// ----- + +func @extract_map_size(%v: vector<63xf32>, %id : index) { + // expected-error@+1 {{'vector.extract_map' op vector sizes mismatch. Source size must be equal to destination size * dimension}} + %0 = vector.extract_map %v[%id : 32] : vector<63xf32> to vector<2xf32> +} + +// ----- + +func @insert_map_rank(%v: vector<2x1xf32>, %id : index) { + // expected-error@+1 {{'vector.insert_map' op expected source and destination vectors of rank 1}} + %0 = vector.insert_map %v, %id, 32 : vector<2x1xf32> to vector<2x32xf32> +} + +// ----- + +func @insert_map_size(%v: vector<1xf32>, %id : index) { + // expected-error@+1 {{'vector.insert_map' op vector sizes mismatch. 
Destination size must be equal to source size * dimension}} + %0 = vector.insert_map %v, %id, 32 : vector<1xf32> to vector<64xf32> +} Index: mlir/test/Dialect/Vector/ops.mlir =================================================================== --- mlir/test/Dialect/Vector/ops.mlir +++ mlir/test/Dialect/Vector/ops.mlir @@ -432,3 +432,14 @@ vector.compressstore %base, %mask, %0 : memref, vector<16xi1>, vector<16xf32> return } + +// CHECK-LABEL: @extract_insert_map +func @extract_insert_map(%v: vector<32xf32>, %id : index) -> vector<32xf32> { + // CHECK: %[[V:.*]] = vector.extract_map %{{.*}}[%{{.*}} : 16] : vector<32xf32> to vector<2xf32> + %vd = vector.extract_map %v[%id : 16] : vector<32xf32> to vector<2xf32> + // CHECK: %[[R:.*]] = vector.insert_map %[[V]], %{{.*}}, 16 : vector<2xf32> to vector<32xf32> + %r = vector.insert_map %vd, %id, 16 : vector<2xf32> to vector<32xf32> + // CHECK: return %[[R]] : vector<32xf32> + return %r : vector<32xf32> +} + Index: mlir/test/Dialect/Vector/vector-distribution.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Vector/vector-distribution.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt %s -test-vector-distribute-patterns | FileCheck %s + +// CHECK-LABEL: func @distribute_vector_add +// CHECK-SAME: (%[[ID:.*]]: index +// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]] : 32] : vector<32xf32> to vector<1xf32> +// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]] : 32] : vector<32xf32> to vector<1xf32> +// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32> +// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ID]], 32 : vector<1xf32> to vector<32xf32> +// CHECK-NEXT: return %[[INS]] : vector<32xf32> +func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> { + %0 = addf %A, %B : vector<32xf32> + return %0: vector<32xf32> +} Index: mlir/test/lib/Transforms/TestVectorTransforms.cpp 
=================================================================== --- mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -125,6 +125,30 @@ } }; +struct TestVectorDistributePatterns + : public PassWrapper { + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + void runOnFunction() override { + MLIRContext *ctx = &getContext(); + OwningRewritePatternList patterns; + FuncOp func = getFunction(); + func.walk([&](AddFOp op) { + OpBuilder builder(op); + vector::ExtractMapOp extract = nullptr; + vector::InsertMapOp insert = nullptr; + distributPointwiseVectorOp(builder, op.getOperation(), + func.getArgument(0), 32, extract, insert); + assert(extract && insert); + SmallPtrSet extractOp({extract}); + op.getResult().replaceAllUsesExcept(insert.getResult(), extractOp); + }); + patterns.insert(ctx); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + struct TestVectorTransferFullPartialSplitPatterns : public PassWrapper { @@ -178,5 +202,9 @@ vectorTransformFullPartialPass("test-vector-transfer-full-partial-split", "Test conversion patterns to split " "transfer ops via scf.if + linalg ops"); + PassRegistration distributePass( + "test-vector-distribute-patterns", + "Test conversion patterns to distribute vector ops in the vector " + "dialect"); } } // namespace mlir