diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -215,6 +215,10 @@
     RewritePatternSet &patterns,
     VectorTransformsOptions options = VectorTransformsOptions());
 
+/// Collect patterns to convert reduction op to vector.contract and fold
+/// transpose/broadcast ops into the contract.
+void populateVetorReductionToContractPatterns(RewritePatternSet &patterns);
+
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
 
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -240,6 +240,13 @@
   return slicedIndices;
 }
 
+template <typename IntType>
+static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
+  return llvm::to_vector<4>(llvm::map_range(
+      arrayAttr.getAsRange<IntegerAttr>(),
+      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
+}
+
 namespace {
 
 struct UnrollTransferReadPattern
@@ -1114,6 +1121,193 @@
   }
 };
 
+/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
+/// Ex:
+/// ```
+///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
+///   %1 = vector.multi_reduction #vector.kind<add>, %0 [1]
+///     : vector<8x32x16xf32> to vector<8x16xf32>
+/// ```
+/// Gets converted to:
+/// ```
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d2)>],
+///    iterator_types = ["parallel", "reduction", "parallel"],
+///    kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
+///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
+/// ```
+struct MultiReduceToContract
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
+                                PatternRewriter &rewriter) const override {
+    if (reduceOp.kind() != vector::CombiningKind::ADD)
+      return failure();
+    Operation *mulOp = reduceOp.source().getDefiningOp();
+    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
+      return failure();
+    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
+    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
+    SmallVector<AffineExpr> exprs;
+    SmallVector<StringRef> iteratorTypes;
+    for (auto isReduceDim : llvm::enumerate(reductionMask)) {
+      if (!isReduceDim.value()) {
+        iteratorTypes.push_back(getParallelIteratorTypeName());
+        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
+      } else {
+        iteratorTypes.push_back(getReductionIteratorTypeName());
+      }
+    }
+    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
+                                 /*symCount=*/0, exprs, reduceOp.getContext());
+    Value zero = rewriter.create<arith::ConstantOp>(
+        reduceOp.getLoc(), reduceOp.getDestType(),
+        rewriter.getZeroAttr(reduceOp.getDestType()));
+    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
+        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
+        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
+        rewriter.getStrArrayAttr(iteratorTypes));
+    return success();
+  }
+};
+
+/// Merge TransposeOp into ContractionOp user.
+/// Ex:
+/// ```
+///   %0 = vector.transpose %arg0, [2, 0, 1]
+///     : vector<32x16x8xf32> to vector<8x32x16xf32>
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1)>],
+///    iterator_types = ["parallel", "parallel", "reduction"],
+///    kind = #vector.kind<add>} %0, %arg1, %cst_f0
+///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+/// ```
+/// Gets converted to:
+/// ```
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1)>],
+///    iterator_types = ["parallel", "parallel", "reduction"],
+///    kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
+///    : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+/// ```
+struct CombineContractTranspose
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<AffineMap, 4> maps =
+        llvm::to_vector<4>(contractOp.getIndexingMaps());
+    Value lhs = contractOp.lhs();
+    Value rhs = contractOp.rhs();
+    size_t index = 0;
+    bool changed = false;
+    for (Value *operand : {&lhs, &rhs}) {
+      AffineMap &map = maps[index++];
+      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
+      if (!transposeOp)
+        continue;
+      // Fold the transpose into this operand's indexing map (via the inverse
+      // permutation) so the contract can use the transpose source directly.
+      AffineMap permutationMap = AffineMap::getPermutationMap(
+          extractVector<unsigned>(transposeOp.transp()),
+          contractOp.getContext());
+      map = inversePermutation(permutationMap).compose(map);
+      *operand = transposeOp.vector();
+      changed = true;
+    }
+    if (!changed)
+      return failure();
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        contractOp, lhs, rhs, contractOp.acc(),
+        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
+    return success();
+  }
+};
+
+/// Merge BroadcastOp into ContractionOp user.
+/// Ex:
+/// ```
+///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1)>],
+///    iterator_types = ["parallel", "parallel", "reduction"],
+///    kind = #vector.kind<add>} %0, %arg1, %cst_f0
+///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+/// ```
+/// Gets converted to:
+/// ```
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1)>],
+///    iterator_types = ["parallel", "parallel", "reduction"],
+///    kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
+///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+/// ```
+struct CombineContractBroadcast
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<AffineMap, 4> maps =
+        llvm::to_vector<4>(contractOp.getIndexingMaps());
+    Value lhs = contractOp.lhs();
+    Value rhs = contractOp.rhs();
+    size_t index = 0;
+    bool changed = false;
+    for (Value *operand : {&lhs, &rhs}) {
+      AffineMap &map = maps[index++];
+      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
+      if (!broadcast)
+        continue;
+      // contractionOp can only take vector as operands.
+      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
+      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
+        continue;
+      int64_t rankDiff =
+          broadcast.getVectorType().getRank() - srcType.getRank();
+      bool innerDimBroadcast = false;
+      SmallVector<AffineExpr> originalDims;
+      for (auto dim : llvm::enumerate(srcType.getShape())) {
+        if (dim.value() !=
+            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
+          innerDimBroadcast = true;
+          break;
+        }
+        originalDims.push_back(
+            rewriter.getAffineDimExpr(dim.index() + rankDiff));
+      }
+      // Contract doesn't support inner dimension broadcast. Once this is
+      // relaxed we can remove this case.
+      if (innerDimBroadcast)
+        continue;
+      AffineMap broadcastMap =
+          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
+                         contractOp.getContext());
+      map = broadcastMap.compose(map);
+      *operand = broadcast.source();
+      changed = true;
+    }
+    if (!changed)
+      return failure();
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        contractOp, lhs, rhs, contractOp.acc(),
+        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
+    return success();
+  }
+};
+
 } // namespace
 
 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
@@ -3668,6 +3862,12 @@
   patterns.add<TransposeOpLowering>(options, patterns.getContext());
 }
 
+void mlir::vector::populateVetorReductionToContractPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<MultiReduceToContract, CombineContractBroadcast,
+               CombineContractTranspose>(patterns.getContext());
+}
+
 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TransferReadPermutationLowering,
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -0,0 +1,87 @@
+// RUN: mlir-opt %s -test-vector-reduction-to-contract-patterns -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK-LABEL: multidimreduction_contract
+// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
+// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
+// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
+// CHECK-NEXT: return %[[R]] : vector<8x16xf32>
+func @multidimreduction_contract(
+  %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
+  %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
+  %1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
+  return %1 : vector<8x16xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK-LABEL: multidimreduction_contract_int
+// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0> : vector<8x16xi32>
+// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
+// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
+// CHECK-SAME: %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
+// CHECK-NEXT: return %[[R]] : vector<8x16xi32>
+func @multidimreduction_contract_int(
+  %arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> {
+  %0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32>
+  %1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
+  return %1 : vector<8x16xi32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_transpose
+// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16x8xf32>,
+// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
+func @contract_transpose(
+  %arg0: vector<32x16x8xf32>, %arg1: vector<8x32x16xf32>) -> vector<8x32xf32> {
+  %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+  %0 = vector.transpose %arg0, [2, 0, 1] : vector<32x16x8xf32> to vector<8x32x16xf32>
+  %1 = vector.contract {indexing_maps = [#map0, #map0, #map1],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+  return %1 : vector<8x32xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast
+// CHECK-SAME: (%[[ARG0:.+]]: vector<32x16xf32>,
+// CHECK-NEXT: %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+// CHECK-NEXT: %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+// CHECK-NEXT: return %[[R]] : vector<8x32xf32>
+func @contract_broadcast(
+  %arg0: vector<32x16xf32>, %arg1: vector<8x32x16xf32>) -> vector<8x32xf32> {
+  %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+  %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
+  %1 = vector.contract {indexing_maps = [#map0, #map0, #map1],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+  return %1 : vector<8x32xf32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -493,6 +493,23 @@
   }
 };
 
+struct TestVectorReduceToContractPatternsPatterns
+    : public PassWrapper<TestVectorReduceToContractPatternsPatterns,
+                         FunctionPass> {
+  StringRef getArgument() const final {
+    return "test-vector-reduction-to-contract-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns to convert multireduce op to contract and combine "
+           "broadcast/transpose to contract";
+  }
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    populateVetorReductionToContractPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 } // end anonymous namespace
 
 namespace mlir {
@@ -519,6 +536,8 @@
   PassRegistration<TestVectorTransferLoweringPatterns>();
 
   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
+
+  PassRegistration<TestVectorReduceToContractPatternsPatterns>();
 }
 } // namespace test
 } // namespace mlir