diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -105,6 +105,10 @@
 // a sequence of vector.reduction ops.
 void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns that propagate vector.insert_map and
+/// vector.extract_map ops along the SSA use-def chain.
+void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns);
+
 /// An attribute that specifies the combining function for `vector.contract`,
 /// and `vector.reduction`.
 class CombiningKindAttr
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -251,26 +251,6 @@ distributPointwiseVectorOp(OpBuilder &builder, Operation *op, ArrayRef id,
                            ArrayRef multiplicity,
                            const AffineMap &map);
 
-/// Canonicalize an extra element using the result of a pointwise operation.
-/// Transforms:
-/// %v = addf %a, %b : vector32xf32>
-/// %dv = vector.extract_map %v, %id, 32 : vector<32xf32> into vector<1xf32>
-/// to:
-/// %da = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
-/// %db = vector.extract_map %a, %id, 32 : vector<32xf32> into vector<1xf32>
-/// %dv = addf %da, %db : vector<1xf32>
-struct PointwiseExtractPattern : public OpRewritePattern {
-  using FilterConstraintType = std::function;
-  PointwiseExtractPattern(
-      MLIRContext *context, FilterConstraintType constraint =
-                                [](ExtractMapOp op) { return success(); })
-      : OpRewritePattern(context), filter(constraint) {}
-  LogicalResult matchAndRewrite(ExtractMapOp extract,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  FilterConstraintType filter;
-};
 
 /// Implements transfer op write to read forwarding and dead transfer write
 /// optimizations.
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2793,25 +2793,6 @@
   return failure();
 }
 
-LogicalResult mlir::vector::PointwiseExtractPattern::matchAndRewrite(
-    ExtractMapOp extract, PatternRewriter &rewriter) const {
-  Operation *definedOp = extract.vector().getDefiningOp();
-  if (!definedOp || definedOp->getNumResults() != 1)
-    return failure();
-  // TODO: Create an interfaceOp for elementwise operations.
-  if (!isa(definedOp))
-    return failure();
-  Location loc = extract.getLoc();
-  SmallVector extractOperands;
-  for (OpOperand &operand : definedOp->getOpOperands())
-    extractOperands.push_back(rewriter.create(
-        loc, extract.getResultType(), operand.get(), extract.ids()));
-  Operation *newOp = cloneOpWithOperandsAndTypes(
-      rewriter, loc, definedOp, extractOperands, extract.getResult().getType());
-  rewriter.replaceOp(extract, newOp->getResult(0));
-  return success();
-}
-
 Optional mlir::vector::distributPointwiseVectorOp(
     OpBuilder &builder, Operation *op, ArrayRef ids,
     ArrayRef multiplicity, const AffineMap &map) {
@@ -2843,6 +2824,120 @@
   return ops;
 }
 
+/// Canonicalize an extract_map using the result of a pointwise operation.
+/// Transforms:
+/// %v = addf %a, %b : vector<32xf32>
+/// %dv = vector.extract_map %v[%id] : vector<32xf32> to vector<1xf32>
+/// to:
+/// %da = vector.extract_map %a[%id] : vector<32xf32> to vector<1xf32>
+/// %db = vector.extract_map %b[%id] : vector<32xf32> to vector<1xf32>
+/// %dv = addf %da, %db : vector<1xf32>
+/// Only single-result ops with the ElementwiseMappable traits are rewritten;
+/// their scalar operands are forwarded unchanged.
+struct PointwiseExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
+  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
+                                PatternRewriter &rewriter) const override {
+    Operation *definedOp = extract.vector().getDefiningOp();
+    if (!definedOp || !OpTrait::hasElementwiseMappableTraits(definedOp) ||
+        definedOp->getNumResults() != 1)
+      return failure();
+    Location loc = extract.getLoc();
+    SmallVector<Value, 4> extractOperands;
+    for (OpOperand &operand : definedOp->getOpOperands()) {
+      auto vecType = operand.get().getType().dyn_cast<VectorType>();
+      // Scalar operands are broadcast by elementwise ops and stay valid as-is.
+      if (!vecType) {
+        extractOperands.push_back(operand.get());
+        continue;
+      }
+      // Keep the operand's own element type, which may differ from the
+      // result's (e.g. for comparisons).
+      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
+          loc,
+          VectorType::get(extract.getResultType().getShape(),
+                          vecType.getElementType()),
+          operand.get(), extract.ids()));
+    }
+    Operation *newOp = cloneOpWithOperandsAndTypes(
+        rewriter, loc, definedOp, extractOperands, extract.getResultType());
+    rewriter.replaceOp(extract, newOp->getResult(0));
+    return success();
+  }
+};
+
+/// Canonicalize an extract_map using the result of a contract operation.
+/// This propagates the extract_map to the contract operands: each parallel
+/// dimension distributed in the accumulator is distributed the same way, with
+/// the same id, in every operand using it. Reduction dimensions are left
+/// untouched. Masked contractions are not supported.
+struct ContractExtractPattern : public OpRewritePattern<vector::ExtractMapOp> {
+  using OpRewritePattern<vector::ExtractMapOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ExtractMapOp extract,
+                                PatternRewriter &rewriter) const override {
+    Operation *definedOp = extract.vector().getDefiningOp();
+    auto contract = dyn_cast_or_null<vector::ContractionOp>(definedOp);
+    // Masks are not distributed; rewriting a masked contraction would silently
+    // drop them.
+    if (!contract || contract->getNumOperands() != 3)
+      return failure();
+    Location loc = contract.getLoc();
+    unsigned accIndex = vector::ContractionOp::getAccOperandIndex();
+    AffineMap affineMap = contract.getIndexingMaps()[accIndex];
+    auto sourceType = extract.vector().getType().cast<VectorType>();
+    // Map each loop dimension indexing the acc to its size after distribution.
+    // Only parallel dimensions index the acc, so reduction dimensions are
+    // untouched. extract_map binds its ids, in order, to the dimensions whose
+    // size changes: record which id drives which loop dimension so that every
+    // operand receives exactly the ids of its own distributed dimensions.
+    DenseMap<unsigned, int64_t> distributedSizes;
+    DenseMap<unsigned, Value> distributedIds;
+    unsigned idPos = 0;
+    for (unsigned i : llvm::seq(unsigned(0), affineMap.getNumResults())) {
+      unsigned dim = affineMap.getDimPosition(i);
+      int64_t size = extract.getResultType().getDimSize(i);
+      distributedSizes[dim] = size;
+      if (size != sourceType.getDimSize(i))
+        distributedIds[dim] = extract.ids()[idPos++];
+    }
+    SmallVector<Value, 4> extractOperands;
+    for (auto it : llvm::enumerate(contract.getIndexingMaps())) {
+      // For each operand calculate the new vector type after distribution.
+      Value operand = contract->getOperand(it.index());
+      auto vecType = operand.getType().cast<VectorType>();
+      SmallVector<int64_t> operandShape(vecType.getShape().begin(),
+                                        vecType.getShape().end());
+      SmallVector<Value, 4> operandIds;
+      for (unsigned i : llvm::seq(unsigned(0), it.value().getNumResults())) {
+        unsigned dim = it.value().getDimPosition(i);
+        auto distributedDim = distributedSizes.find(dim);
+        // If the dimension is not in the map it means it is a reduction and
+        // doesn't get distributed.
+        if (distributedDim == distributedSizes.end())
+          continue;
+        operandShape[i] = distributedDim->second;
+        auto distributedId = distributedIds.find(dim);
+        if (distributedId != distributedIds.end())
+          operandIds.push_back(distributedId->second);
+      }
+      // An operand without any distributed dimension is used unchanged.
+      if (operandIds.empty()) {
+        extractOperands.push_back(operand);
+        continue;
+      }
+      VectorType newVecType =
+          VectorType::get(operandShape, vecType.getElementType());
+      extractOperands.push_back(rewriter.create<vector::ExtractMapOp>(
+          loc, newVecType, operand, operandIds));
+    }
+    Operation *newOp =
+        cloneOpWithOperandsAndTypes(rewriter, loc, definedOp, extractOperands,
+                                    extract.getResultType());
+    rewriter.replaceOp(extract, newOp->getResult(0));
+    return success();
+  }
+};
+
 /// Converts TransferRead op used by ExtractMap op into a smaller dimension
 /// TransferRead.
 /// Example:
@@ -4100,8 +4195,7 @@
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
     RewritePatternSet &patterns) {
-  patterns.add(
+  patterns.add(
       patterns.getContext());
 }
 
@@ -4112,6 +4206,13 @@
       ignoreFilter);
 }
 
+void mlir::vector::populatePropagateVectorDistributionPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add(
+      patterns.getContext());
+}
+
 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
     RewritePatternSet &patterns) {
   patterns.add
+// CHECK-NEXT: %[[ADDV:.*]] = addf %[[EXPV]], %{{.*}} : vector<32xf32>
+// CHECK-NEXT: %[[EXA:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
+// CHECK-NEXT: %[[EXC:.*]] = math.exp %[[EXA]] : vector<1xf32>
+// CHECK-NEXT: %[[EXB:.*]] = vector.extract_map %{{.*}}[%[[ID]]] : vector<32xf32> to vector<1xf32>
+// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXC]], %[[EXB]] : vector<1xf32>
+// CHECK-NEXT: %[[INS:.*]] = vector.insert_map %[[ADD]], %[[ADDV]][%[[ID]]] : vector<1xf32> into vector<32xf32>
+// CHECK-NEXT: return %[[INS]] : vector<32xf32>
+func @distribute_vector_add_exp(%id : index, %A: vector<32xf32>, %B: vector<32xf32>) -> vector<32xf32> {
+  %C = math.exp %A : vector<32xf32>
+  %0 = addf %C, %B : vector<32xf32>
+  return %0: vector<32xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_add_read_write
 // CHECK-SAME: (%[[ID:.*]]: index
 // CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
@@ -154,3 +173,32 @@
   vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref
   return
 }
+
+// -----
+
+// CHECK2D-LABEL: vector_add_contract
+// CHECK2D: %[[A:.+]] = vector.transfer_read %arg2[%0, %c0], %cst : memref, vector<2x4xf32>
+// CHECK2D: %[[B:.+]] = vector.transfer_read %arg3[%2, %c0], %cst : memref, vector<16x4xf32>
+// CHECK2D: %[[C:.+]] = vector.transfer_read %arg4[%4, %5], %cst : memref, vector<2x16xf32>
+// CHECK2D: %[[E:.+]] = vector.transfer_read %arg5[%7, %8], %cst : memref, vector<2x16xf32>
+// CHECK2D: %[[D:.+]] = vector.contract {{.*}} %[[A]], %[[B]], %[[C]] : vector<2x4xf32>, vector<16x4xf32> into vector<2x16xf32>
+// CHECK2D: %[[R:.+]] = addf %[[D]], %[[E]] : vector<2x16xf32>
+// CHECK2D: vector.transfer_write %[[R]], {{.*}} : vector<2x16xf32>, memref
+func @vector_add_contract(%id0 : index, %id1 : index, %A: memref,
+  %B: memref, %C: memref, %D: memref) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %a = vector.transfer_read %A[%c0, %c0], %cf0 : memref, vector<64x4xf32>
+  %b = vector.transfer_read %B[%c0, %c0], %cf0 : memref, vector<64x4xf32>
+  %c = vector.transfer_read %C[%c0, %c0], %cf0 : memref, vector<64x64xf32>
+  %d = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                         affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                         affine_map<(d0, d1, d2) -> (d0, d1)>],
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind}
+                        %a, %b, %c : vector<64x4xf32>, vector<64x4xf32> into vector<64x64xf32>
+  %e = vector.transfer_read %D[%c0, %c0], %cf0 : memref, vector<64x64xf32>
+  %r = addf %d, %e : vector<64x64xf32>
+  vector.transfer_write %r, %C[%c0, %c0] : vector<64x64xf32>, memref
+  return
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -275,8 +275,7 @@
         }
       }
     });
-    patterns.add(ctx);
-    populateVectorToVectorTransformationPatterns(patterns);
+    populatePropagateVectorDistributionPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };
@@ -339,8 +338,7 @@
       }
       return mlir::WalkResult::interrupt();
     });
-    patterns.add(ctx);
-    populateVectorToVectorTransformationPatterns(patterns);
+    populatePropagateVectorDistributionPatterns(patterns);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 };