Index: mlir/include/mlir/Dialect/Vector/VectorOps.td
===================================================================
--- mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -457,6 +457,43 @@
   }];
 }
 
+def Vector_ExtractMapOp :
+  Vector_Op<"extract_map", [NoSideEffect]>,
+    Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$dimension)>,
+    Results<(outs AnyVector)> {
+  let summary = "vector extract map operation";
+  let description = [{
+    Takes an N-D vector (currently limited to 1D) and extracts a sub-part of
+    the vector starting at id with a size of `vector size / dimension`. This
+    maps a given dimension of the vector to an SPMD id.
+
+    This operation doesn't have trivial lowering and is meant to be used to do
+    incremental lowering. Eventually this should be merged with a
+    vector.transfer_read op.
+
+    Example:
+
+    ```mlir
+    %ev = vector.extract_map %v, %id, 32 : vector<32xf32> to vector<1xf32>
+    ```
+  }];
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, " #
+    "Value vector, Value id, int64_t dimension">];
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+    VectorType getResultType() {
+      return getResult().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = [{
+    $vector `,` $id `,` $dimension attr-dict `:` type($vector) `to`
+    type(results)
+  }];
+}
+
 def Vector_FMAOp :
   Op<Vector_Dialect, "fma", [NoSideEffect,
                              AllTypesMatch<["lhs", "rhs", "acc", "result"]>]>,
@@ -632,6 +669,44 @@
   }];
 }
 
+def Vector_InsertMapOp :
+  Vector_Op<"insert_map", [NoSideEffect]>,
+    Arguments<(ins AnyVector:$vector, Index:$id, I64Attr:$dimension)>,
+    Results<(outs AnyVector)> {
+  let summary = "vector insert map operation";
+  let description = [{
+    Inserts an N-D vector (currently limited to 1D) within a larger vector
+    starting at id. The new vector created will have a size of
+    `vector size * dimension`. This represents how a sub-part of the vector is
+    written for a given SPMD id.
+
+    This operation doesn't have trivial lowering and is meant to be used
+    to do incremental lowering. Eventually this should be merged with a
+    vector.transfer_write op.
+
+    Example:
+
+    ```mlir
+    %v = vector.insert_map %ev, %id, 32 : vector<1xf32> to vector<32xf32>
+    ```
+  }];
+  let builders = [OpBuilder<
+    "OpBuilder &builder, OperationState &result, " #
+    "Value vector, Value id, int64_t dimension">];
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return vector().getType().cast<VectorType>();
+    }
+    VectorType getResultType() {
+      return getResult().getType().cast<VectorType>();
+    }
+  }];
+  let assemblyFormat = [{
+    $vector `,` $id `,` $dimension attr-dict `:` type($vector) `to`
+    type(results)
+  }];
+}
+
 def Vector_InsertStridedSliceOp :
   Vector_Op<"insert_strided_slice", [NoSideEffect,
     PredOpTrait<"operand #0 and result have same element type",
Index: mlir/include/mlir/Dialect/Vector/VectorTransforms.h
===================================================================
--- mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -172,6 +172,78 @@
   FilterConstraintType filter;
 };
 
+/// Callback that materializes the SPMD id a distributed op is mapped to. It is
+/// called with a builder whose insertion point is right after the op being
+/// distributed.
+using CreateIdCallBack = std::function<Value(OpBuilder &, Location)>;
+
+/// Distribute a 1D vector pointwise operation over a range of given IDs going
+/// from 0 to `dimension` (exclusive).
+/// Transforms:
+/// %v = addf %a, %b : vector<32xf32>
+/// to:
+/// %ad = vector.extract_map %a, %id, 32 : vector<32xf32> to vector<1xf32>
+/// %bd = vector.extract_map %b, %id, 32 : vector<32xf32> to vector<1xf32>
+/// %vd = addf %ad, %bd : vector<1xf32>
+/// %v = vector.insert_map %vd, %id, 32 : vector<1xf32> to vector<32xf32>
+/// Returns the vector.insert_map value that should replace the result of `op`,
+/// or a null Value if `op` cannot be distributed. `op` itself is not erased.
+Value distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
+                                 CreateIdCallBack idFn, int64_t dimension);
+
+/// Canonicalize an extract_map whose source is the result of a pointwise
+/// operation by extracting from the operation's operands instead.
+/// Transforms:
+/// %v = addf %a, %b : vector<32xf32>
+/// %dv = vector.extract_map %v, %id, 32 : vector<32xf32> to vector<1xf32>
+/// to:
+/// %da = vector.extract_map %a, %id, 32 : vector<32xf32> to vector<1xf32>
+/// %db = vector.extract_map %b, %id, 32 : vector<32xf32> to vector<1xf32>
+/// %dv = addf %da, %db : vector<1xf32>
+/// New ops are created at the current insertion point of `builder`. Returns the
+/// distributed value, or a null Value if the source cannot be distributed. The
+/// original extract_map is left in place.
+Value canonicalizePointwiseExtract(OpBuilder &builder,
+                                   vector::ExtractMapOp extract);
+
+/// Pattern to apply `distributPointwiseVectorOp` transformation to ops of type
+/// `OpTy` whose single result is a 1-D vector with exactly `dim` elements. The
+/// optional `constraint` is checked first and can reject an op before any IR
+/// is created.
+template <typename OpTy>
+struct DistributeVectorPattern : public OpRewritePattern<OpTy> {
+  using FilterConstraintType = std::function<LogicalResult(OpTy op)>;
+  DistributeVectorPattern(
+      CreateIdCallBack idFn, int64_t dim, MLIRContext *context,
+      FilterConstraintType constraint = [](OpTy op) { return success(); })
+      : OpRewritePattern<OpTy>(context), idFn(std::move(idFn)), dim(dim),
+        filter(std::move(constraint)) {}
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(filter(op)))
+      return failure();
+    Operation *operation = op.getOperation();
+    if (operation->getNumResults() != 1)
+      return failure();
+    auto vecType = operation->getResult(0).getType().dyn_cast<VectorType>();
+    if (!vecType)
+      return failure();
+    if (vecType.getRank() != 1 || vecType.getNumElements() != dim)
+      return failure();
+    Value newOp =
+        distributPointwiseVectorOp(rewriter, operation, idFn, dim);
+    if (!newOp)
+      return failure();
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+
+private:
+  CreateIdCallBack idFn;
+  int64_t dim;
+  FilterConstraintType filter;
+};
+
 } // namespace vector
 
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Dialect/Vector/VectorOps.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorOps.cpp
+++ mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -901,6 +901,29 @@
 }
 
 //===----------------------------------------------------------------------===//
+// ExtractMapOp
+//===----------------------------------------------------------------------===//
+
+void ExtractMapOp::build(OpBuilder &builder, OperationState &result,
+                         Value vector, Value id, int64_t dimension) {
+  VectorType type = vector.getType().cast<VectorType>();
+  VectorType resultType =
+      VectorType::get(type.getNumElements() / dimension, type.getElementType());
+  ExtractMapOp::build(builder, result, resultType, vector, id, dimension);
+}
+
+static LogicalResult verify(ExtractMapOp op) {
+  if (op.getSourceVectorType().getRank() != 1 ||
+      op.getResultType().getRank() != 1)
+    return op.emitOpError("expects source and destination vectors of rank 1");
+  if (op.getResultType().getNumElements() * (int64_t)op.dimension() !=
+      op.getSourceVectorType().getNumElements())
+    return op.emitOpError("vector sizes mismatch. Source size must be equal "
+                          "to destination size * dimension");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
 
@@ -1123,6 +1146,30 @@
 }
 
 //===----------------------------------------------------------------------===//
+// InsertMapOp
+//===----------------------------------------------------------------------===//
+
+void InsertMapOp::build(OpBuilder &builder, OperationState &result,
+                        Value vector, Value id, int64_t dimension) {
+  VectorType type = vector.getType().cast<VectorType>();
+  VectorType resultType =
+      VectorType::get(type.getNumElements() * dimension, type.getElementType());
+  InsertMapOp::build(builder, result, resultType, vector, id, dimension);
+}
+
+static LogicalResult verify(InsertMapOp op) {
+  if (op.getSourceVectorType().getRank() != 1 ||
+      op.getResultType().getRank() != 1)
+    return op.emitOpError("expected source and destination vectors of rank 1");
+  if ((int64_t)op.dimension() * op.getSourceVectorType().getNumElements() !=
+      op.getResultType().getNumElements())
+    return op.emitOpError(
+        "vector sizes mismatch. Destination size must be equal "
+        "to source size * dimension");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // InsertStridedSliceOp
 //===----------------------------------------------------------------------===//
 
Index: mlir/lib/Dialect/Vector/VectorTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2421,6 +2421,67 @@
   return failure();
 }
 
+/// Returns true if `op` can be distributed by slicing each of its operands:
+/// `op` must be non-null and produce exactly one result, and every operand
+/// must be a vector with the same shape as that result (which is what makes
+/// building a vector.extract_map of each operand valid).
+static bool isDistributablePointwiseOp(Operation *op) {
+  if (!op || op->getNumResults() != 1)
+    return false;
+  auto resultType = op->getResult(0).getType().dyn_cast<VectorType>();
+  if (!resultType)
+    return false;
+  return llvm::all_of(op->getOperandTypes(), [&](Type type) {
+    auto vecType = type.dyn_cast<VectorType>();
+    return vecType && vecType.getShape() == resultType.getShape();
+  });
+}
+
+Value mlir::vector::canonicalizePointwiseExtract(OpBuilder &builder,
+                                                 vector::ExtractMapOp extract) {
+  // The source may be a block argument (no defining op) or produced by an op
+  // whose operands cannot be sliced; bail out before creating any IR.
+  Operation *op = extract.vector().getDefiningOp();
+  if (!isDistributablePointwiseOp(op))
+    return nullptr;
+  Location loc = extract.getLoc();
+  SmallVector<Value, 4> extractOperands;
+  extractOperands.reserve(op->getNumOperands());
+  for (Value operand : op->getOperands())
+    extractOperands.push_back(builder.create<vector::ExtractMapOp>(
+        loc, operand, extract.id(), extract.dimension()));
+  Operation *newOp = cloneOpWithOperandsAndTypes(
+      builder, loc, op, extractOperands, extract.getResult().getType());
+  return newOp->getResult(0);
+}
+
+Value mlir::vector::distributPointwiseVectorOp(OpBuilder &builder,
+                                               Operation *op,
+                                               CreateIdCallBack idFn,
+                                               int64_t dimension) {
+  // Check before creating anything so that a failure leaves the IR untouched
+  // rather than leaving behind id/extract_map/insert_map ops.
+  if (!isDistributablePointwiseOp(op))
+    return nullptr;
+  Location loc = op->getLoc();
+  builder.setInsertionPointAfter(op);
+  Value id = idFn(builder, loc);
+  Value result = op->getResult(0);
+  // Wrap the result as insert_map(extract_map(result)), then rewrite the
+  // extract_map so that `op` is recomputed on the extracted slice only.
+  auto extract =
+      builder.create<vector::ExtractMapOp>(loc, result, id, dimension);
+  Value newVec =
+      builder.create<vector::InsertMapOp>(loc, extract, id, dimension);
+  builder.setInsertionPoint(extract);
+  Value distributedOp = canonicalizePointwiseExtract(builder, extract);
+  if (!distributedOp)
+    return nullptr;
+  // After this, `extract` has no uses; it is not erased here.
+  extract.replaceAllUsesWith(distributedOp);
+  return newVec;
+}
+
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
Index: mlir/test/Dialect/Vector/invalid.mlir
===================================================================
--- mlir/test/Dialect/Vector/invalid.mlir
+++ mlir/test/Dialect/Vector/invalid.mlir
@@ -1328,3 +1328,31 @@
   // expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}}
   vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
 }
+
+// -----
+
+func @extract_map_rank(%v: vector<2x32xf32>, %id : index) {
+  // expected-error@+1 {{'vector.extract_map' op expects source and destination vectors of rank 1}}
+  %0 = vector.extract_map %v, %id, 32 : vector<2x32xf32> to vector<2x1xf32>
+}
+
+// -----
+
+func @extract_map_size(%v: vector<63xf32>, %id : index) {
+  // expected-error@+1 {{'vector.extract_map' op vector sizes mismatch. Source size must be equal to destination size * dimension}}
+  %0 = vector.extract_map %v, %id, 32 : vector<63xf32> to vector<2xf32>
+}
+
+// -----
+
+func @insert_map_rank(%v: vector<2x1xf32>, %id : index) {
+  // expected-error@+1 {{'vector.insert_map' op expected source and destination vectors of rank 1}}
+  %0 = vector.insert_map %v, %id, 32 : vector<2x1xf32> to vector<2x32xf32>
+}
+
+// -----
+
+func @insert_map_size(%v: vector<1xf32>, %id : index) {
+  // expected-error@+1 {{'vector.insert_map' op vector sizes mismatch. Destination size must be equal to source size * dimension}}
+  %0 = vector.insert_map %v, %id, 32 : vector<1xf32> to vector<64xf32>
+}
Index: mlir/test/Dialect/Vector/ops.mlir
===================================================================
--- mlir/test/Dialect/Vector/ops.mlir
+++ mlir/test/Dialect/Vector/ops.mlir
@@ -432,3 +432,14 @@
   vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
   return
 }
+
+// CHECK-LABEL: @extract_insert_map
+func @extract_insert_map(%v: vector<32xf32>, %id : index) -> vector<32xf32> {
+  // CHECK: %[[V:.*]] = vector.extract_map %{{.*}}, %{{.*}}, 16 : vector<32xf32> to vector<2xf32>
+  %vd = vector.extract_map %v, %id, 16 : vector<32xf32> to vector<2xf32>
+  // CHECK: %[[R:.*]] = vector.insert_map %[[V]], %{{.*}}, 16 : vector<2xf32> to vector<32xf32>
+  %r = vector.insert_map %vd, %id, 16 : vector<2xf32> to vector<32xf32>
+  // CHECK: return %[[R]] : vector<32xf32>
+  return %r : vector<32xf32>
+}
+
Index: mlir/test/Dialect/Vector/vector-distribution.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -test-vector-distribute-patterns -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: func @distribute_vector_add
+// CHECK-SAME: (%[[A:.*]]: vector<32xf32>, %[[B:.*]]: vector<32xf32>)
+// CHECK: %[[ID:.*]] = "getID"() : () -> index
+// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %[[A]], %[[ID]], 32 : vector<32xf32> to vector<1xf32>
+// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %[[B]], %[[ID]], 32 : vector<32xf32> to vector<1xf32>
+// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
+// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ID]], 32 : vector<1xf32> to vector<32xf32>
+// CHECK-NEXT: return %[[INS]] : vector<32xf32>
+func @distribute_vector_add(%A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
+  %0 = addf %A, %B : vector<32xf32>
+  return %0: vector<32xf32>
+}
Index: mlir/test/lib/Transforms/TestVectorTransforms.cpp
===================================================================
--- mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -125,6 +125,27 @@
   }
 };
 
+struct TestVectorDistributePatterns
+    : public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<VectorDialect>();
+  }
+  void runOnFunction() override {
+    MLIRContext *ctx = &getContext();
+    OwningRewritePatternList patterns;
+    // The id is produced by an unregistered "getID" op, which requires
+    // running mlir-opt with -allow-unregistered-dialect.
+    auto getId = [](OpBuilder &b, Location loc) {
+      OperationState opState(loc, "getID");
+      opState.addTypes(b.getIndexType());
+      Operation *op = b.createOperation(opState);
+      return op->getResult(0);
+    };
+    patterns.insert<vector::DistributeVectorPattern<AddFOp>>(getId, 32, ctx);
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
+  }
+};
+
 struct TestVectorTransferFullPartialSplitPatterns
     : public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
                          FunctionPass> {
@@ -178,5 +199,9 @@
   vectorTransformFullPartialPass("test-vector-transfer-full-partial-split",
                                  "Test conversion patterns to split "
                                  "transfer ops via scf.if + linalg ops");
+  PassRegistration<TestVectorDistributePatterns> distributePass(
+      "test-vector-distribute-patterns",
+      "Test conversion patterns to distribute vector ops in the vector "
+      "dialect");
 }
 } // namespace mlir