Index: mlir/include/mlir/Dialect/Vector/VectorOps.td
===================================================================
--- mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -373,6 +373,7 @@
   let assemblyFormat =
     "$kind `,` $source attr-dict $reduction_dims `:` type($source) `to` type($dest)";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_BroadcastOp :
Index: mlir/lib/Dialect/Vector/VectorOps.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorOps.cpp
+++ mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -118,6 +118,14 @@
   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
 }
 
+/// Returns the integer values of an ArrayAttr of IntegerAttrs, cast to IntType.
+template <typename IntType>
+static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
+  return llvm::to_vector<4>(llvm::map_range(
+      arrayAttr.getAsRange<IntegerAttr>(),
+      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
+}
+
 //===----------------------------------------------------------------------===//
 // CombiningKindAttr
 //===----------------------------------------------------------------------===//
@@ -295,6 +303,66 @@
   return {};
 }
 
+/// Convert MulIOp/MulFOp + MultiDimReductionOp<add> into ContractionOp.
+/// Ex:
+/// ```
+///   %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
+///   %1 = vector.multi_reduction #vector.kind<add>, %0 [1]
+///     : vector<8x32x16xf32> to vector<8x16xf32>
+/// ```
+/// Gets converted to (%cst_f0 being a zero splat of the result type):
+/// ```
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d2)>],
+///      iterator_types = ["parallel", "reduction", "parallel"],
+///      kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
+///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
+/// ```
+struct MultiReduceToContract : public OpRewritePattern<MultiDimReductionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MultiDimReductionOp reduceOp,
+                                PatternRewriter &rewriter) const override {
+    if (reduceOp.kind() != CombiningKind::ADD)
+      return failure();
+    Operation *mulOp = reduceOp.source().getDefiningOp();
+    if (!mulOp || !isa<arith::MulIOp, arith::MulFOp>(mulOp))
+      return failure();
+    SmallVector<bool> reductionMask = reduceOp.getReductionMask();
+    // Both mul operands are indexed by the full iteration space.
+    auto srcMap = rewriter.getMultiDimIdentityMap(reductionMask.size());
+    // Reduced dims become reduction iterators; the others are parallel and
+    // index the result in their original order.
+    SmallVector<AffineExpr> exprs;
+    SmallVector<StringRef> iteratorTypes;
+    for (auto isReduceDim : llvm::enumerate(reductionMask)) {
+      if (!isReduceDim.value()) {
+        iteratorTypes.push_back(getParallelIteratorTypeName());
+        exprs.push_back(rewriter.getAffineDimExpr(isReduceDim.index()));
+      } else {
+        iteratorTypes.push_back(getReductionIteratorTypeName());
+      }
+    }
+    auto dstMap = AffineMap::get(/*dimCount=*/reductionMask.size(),
+                                 /*symCount=*/0, exprs, reduceOp.getContext());
+    Value zero = rewriter.create<arith::ConstantOp>(
+        reduceOp.getLoc(), reduceOp.getDestType(),
+        rewriter.getZeroAttr(reduceOp.getDestType()));
+    rewriter.replaceOpWithNewOp<mlir::vector::ContractionOp>(
+        reduceOp, mulOp->getOperand(0), mulOp->getOperand(1), zero,
+        rewriter.getAffineMapArrayAttr({srcMap, srcMap, dstMap}),
+        rewriter.getStrArrayAttr(iteratorTypes));
+    return success();
+  }
+};
+
+void MultiDimReductionOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<MultiReduceToContract>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ReductionOp
 //===----------------------------------------------------------------------===//
@@ -810,10 +878,170 @@
   }
 };
 
+/// Merge TransposeOp into ContractionOp user.
+/// Ex:
+/// ```
+///   %0 = vector.transpose %arg0, [2, 0, 1]
+///     : vector<32x16x8xf32> to vector<8x32x16xf32>
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1)>],
+///      iterator_types = ["parallel", "parallel", "reduction"],
+///      kind = #vector.kind<add>} %0, %arg1, %cst_f0
+///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+/// ```
+/// Gets converted to:
+/// ```
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d1, d2, d0)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1)>],
+///      iterator_types = ["parallel", "parallel", "reduction"],
+///      kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
+///    : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+/// ```
+struct CanonicalizeContractTranspose : public OpRewritePattern<ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<AffineMap, 4> maps =
+        llvm::to_vector<4>(contractOp.getIndexingMaps());
+    Value lhs = contractOp.lhs();
+    Value rhs = contractOp.rhs();
+    size_t index = 0;
+    bool changed = false;
+    for (Value *operand : {&lhs, &rhs}) {
+      // Advance once per operand, whether or not it is folded, so that the
+      // rhs is always paired with maps[1].
+      AffineMap &map = maps[index++];
+      auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
+      if (!transposeOp)
+        continue;
+      // Compose the inverse transpose permutation after the operand's map so
+      // that the new map indexes the transpose source directly.
+      AffineMap permutationMap = AffineMap::getPermutationMap(
+          extractVector<unsigned>(transposeOp.transp()),
+          contractOp.getContext());
+      map = inversePermutation(permutationMap).compose(map);
+      *operand = transposeOp.vector();
+      changed = true;
+    }
+    if (!changed)
+      return failure();
+    rewriter.replaceOpWithNewOp<ContractionOp>(
+        contractOp, lhs, rhs, contractOp.acc(),
+        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
+    return success();
+  }
+};
+
+/// Merge BroadcastOp into ContractionOp user.
+/// Ex:
+/// ```
+///   %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1)>],
+///      iterator_types = ["parallel", "parallel", "reduction"],
+///      kind = #vector.kind<add>} %0, %arg1, %cst_f0
+///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+/// ```
+/// Gets converted to:
+/// ```
+///   %1 = vector.contract {indexing_maps = [
+///         affine_map<(d0, d1, d2) -> (d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+///         affine_map<(d0, d1, d2) -> (d0, d1)>],
+///      iterator_types = ["parallel", "parallel", "reduction"],
+///      kind = #vector.kind<add>} %arg0, %arg1, %cst_f0
+///    : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+/// ```
+/// Only leading (rank-extending) broadcast dimensions are folded, and only
+/// when they are parallel iterators and every iteration dimension is still
+/// indexed by lhs or rhs afterwards.
+struct CanonicalizeContractBroadcast : public OpRewritePattern<ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<AffineMap, 4> maps =
+        llvm::to_vector<4>(contractOp.getIndexingMaps());
+    Value lhs = contractOp.lhs();
+    Value rhs = contractOp.rhs();
+    ArrayRef<Attribute> iteratorTypes = contractOp.iterator_types().getValue();
+    size_t index = 0;
+    bool changed = false;
+    for (Value *operand : {&lhs, &rhs}) {
+      // Advance once per operand, whether or not it is folded, so that the
+      // rhs is always paired with maps[1].
+      AffineMap &map = maps[index++];
+      auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
+      if (!broadcast)
+        continue;
+      // contractionOp can only take vector as operands.
+      auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
+      if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
+        continue;
+      int64_t rankDiff =
+          broadcast.getVectorType().getRank() - srcType.getRank();
+      bool innerDimBroadcast = false;
+      SmallVector<AffineExpr> originalDims;
+      for (auto dim : llvm::enumerate(srcType.getShape())) {
+        if (dim.value() !=
+            broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
+          innerDimBroadcast = true;
+          break;
+        }
+        originalDims.push_back(
+            rewriter.getAffineDimExpr(dim.index() + rankDiff));
+      }
+      // Contract doesn't support inner dimension broadcast. Once this is
+      // relaxed we can remove this case.
+      if (innerDimBroadcast)
+        continue;
+      // Only drop parallel dimensions: removing a reduction dimension from
+      // this operand would leave it indexed by at most one side of the
+      // contraction instead of being contracted between lhs and rhs.
+      bool reductionDimBroadcast = false;
+      for (int64_t i = 0; i < rankDiff; ++i) {
+        if (isReductionIterator(iteratorTypes[map.getDimPosition(i)])) {
+          reductionDimBroadcast = true;
+          break;
+        }
+      }
+      if (reductionDimBroadcast)
+        continue;
+      AffineMap broadcastMap =
+          AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
+                         contractOp.getContext());
+      map = broadcastMap.compose(map);
+      *operand = broadcast.source();
+      changed = true;
+    }
+    if (!changed)
+      return failure();
+    // vector.contract infers each iteration dimension's extent from lhs/rhs,
+    // so bail out if a dimension is no longer indexed by either operand
+    // (e.g. both operands were broadcast along it).
+    for (unsigned dim = 0, e = maps[0].getNumDims(); dim < e; ++dim)
+      if (!maps[0].isFunctionOfDim(dim) && !maps[1].isFunctionOfDim(dim))
+        return failure();
+    rewriter.replaceOpWithNewOp<ContractionOp>(
+        contractOp, lhs, rhs, contractOp.acc(),
+        rewriter.getAffineMapArrayAttr(maps), contractOp.iterator_types());
+    return success();
+  }
+};
+
 void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<CanonicalizeContractAdd<arith::AddIOp>,
-              CanonicalizeContractAdd<arith::AddFOp>>(context);
+              CanonicalizeContractAdd<arith::AddFOp>,
+              CanonicalizeContractTranspose, CanonicalizeContractBroadcast>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -923,13 +1151,6 @@
   return success();
 }
 
-template <typename IntType>
-static SmallVector<IntType, 4> extractVector(ArrayAttr arrayAttr) {
-  return llvm::to_vector<4>(llvm::map_range(
-      arrayAttr.getAsRange<IntegerAttr>(),
-      [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
-}
-
 /// Fold the result of chains of ExtractOp in place by simply concatenating the
 /// positions.
 static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
Index: mlir/test/Dialect/Vector/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Vector/canonicalize.mlir
+++ mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1082,3 +1082,140 @@
     vector<16x4xf16> to vector<2x4xf16>
   return %1 : vector<2x4xf16>
 }
+
+// -----
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK-LABEL: multidimreduction_contract
+//  CHECK-NEXT:   %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x16xf32>
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x16xf32>
+//  CHECK-NEXT:   return %[[R]] : vector<8x16xf32>
+func @multidimreduction_contract(
+  %arg0: vector<8x32x16xf32>,%arg1: vector<8x32x16xf32>) -> vector<8x16xf32> {
+  %0 = arith.mulf %arg0, %arg1 : vector<8x32x16xf32>
+  %1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xf32> to vector<8x16xf32>
+  return %1 : vector<8x16xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK-LABEL: multidimreduction_contract_int
+//  CHECK-NEXT:   %[[C0:.+]] = arith.constant dense<0> : vector<8x16xi32>
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xi32>, vector<8x32x16xi32> into vector<8x16xi32>
+//  CHECK-NEXT:   return %[[R]] : vector<8x16xi32>
+func @multidimreduction_contract_int(
+  %arg0: vector<8x32x16xi32>,%arg1: vector<8x32x16xi32>) -> vector<8x16xi32> {
+  %0 = arith.muli %arg0, %arg1 : vector<8x32x16xi32>
+  %1 = vector.multi_reduction #vector.kind<add>, %0 [1] : vector<8x32x16xi32> to vector<8x16xi32>
+  return %1 : vector<8x16xi32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_transpose
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<32x16x8xf32>,
+//  CHECK-NEXT:   %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16x8xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+//  CHECK-NEXT:   return %[[R]] : vector<8x32xf32>
+func @contract_transpose(
+    %arg0: vector<32x16x8xf32>, %arg1: vector<8x32x16xf32>) -> vector<8x32xf32> {
+  %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+  %0 = vector.transpose %arg0, [2, 0, 1] : vector<32x16x8xf32> to vector<8x32x16xf32>
+  %1 = vector.contract {indexing_maps = [#map0, #map0, #map1],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+  return %1 : vector<8x32xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: contract_broadcast
+//  CHECK-SAME: (%[[ARG0:.+]]: vector<32x16xf32>,
+//  CHECK-NEXT:   %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %[[ARG0]], %{{.*}}, %[[C0]] : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+//  CHECK-NEXT:   return %[[R]] : vector<8x32xf32>
+func @contract_broadcast(
+    %arg0: vector<32x16xf32>, %arg1: vector<8x32x16xf32>) -> vector<8x32xf32> {
+  %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+  %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
+  %1 = vector.contract {indexing_maps = [#map0, #map0, #map1],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+  return %1 : vector<8x32xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// Only the rhs is transposed: the rhs map (not the lhs map) must be updated.
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: rhs_transpose_contract
+//  CHECK-NEXT:   %[[C0:.+]] = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+//  CHECK-NEXT:   %[[R:.+]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]],
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+//  CHECK-SAME:   %{{.*}}, %{{.*}}, %[[C0]] : vector<8x32x16xf32>, vector<32x16x8xf32> into vector<8x32xf32>
+//  CHECK-NEXT:   return %[[R]] : vector<8x32xf32>
+func @rhs_transpose_contract(
+    %arg0: vector<8x32x16xf32>, %arg1: vector<32x16x8xf32>) -> vector<8x32xf32> {
+  %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+  %0 = vector.transpose %arg1, [2, 0, 1] : vector<32x16x8xf32> to vector<8x32x16xf32>
+  %1 = vector.contract {indexing_maps = [#map0, #map0, #map1],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %arg0, %0, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+  return %1 : vector<8x32xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// Both operands are broadcast along d0, which only the accumulator indexes:
+// folding both would leave d0 without an extent, so nothing is folded.
+// CHECK-LABEL: broadcast_both_contract
+//       CHECK:   vector.broadcast
+//       CHECK:   vector.broadcast
+//       CHECK:   vector.contract
+//  CHECK-SAME:   vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+func @broadcast_both_contract(
+    %arg0: vector<32x16xf32>, %arg1: vector<32x16xf32>) -> vector<8x32xf32> {
+  %cst = arith.constant dense<0.000000e+00> : vector<8x32xf32>
+  %0 = vector.broadcast %arg0 : vector<32x16xf32> to vector<8x32x16xf32>
+  %1 = vector.broadcast %arg1 : vector<32x16xf32> to vector<8x32x16xf32>
+  %2 = vector.contract {indexing_maps = [#map0, #map0, #map1],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %0, %1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
+  return %2 : vector<8x32xf32>
+}