Index: mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
===================================================================
--- mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -49,6 +49,9 @@
   Matmul = 1,
   /// Lower to `vector.outerproduct`.
   OuterProduct = 2,
+  /// Lower contracts whose reduction dimensions are all unrolled to size 1
+  /// to elementwise vector operations.
+  ParallelArith = 3,
 };
 /// Enum to control the splitting of `vector.transfer` operations into
 /// in-bounds and out-of-bounds variants.
Index: mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -144,6 +144,70 @@
       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
 }
 
+/// Helper to create arithmetic operation associated with a kind of contraction.
+static Optional<Value> createContractArithOp(Location loc, Value x, Value y,
+                                             Value acc,
+                                             vector::CombiningKind kind,
+                                             PatternRewriter &rewriter,
+                                             bool isInt) {
+  using vector::CombiningKind;
+  Value mul;
+  if (isInt) {
+    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
+      // Only valid for floating point types.
+      return Optional<Value>();
+    mul = rewriter.create<arith::MulIOp>(loc, x, y);
+  } else {
+    // Float case.
+    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
+        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
+        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
+        kind == CombiningKind::XOR)
+      // Only valid for integer types.
+      return Optional<Value>();
+    // Fused multiply-add; vector.fma only accepts vector operands.
+    if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
+      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
+    }
+    mul = rewriter.create<arith::MulFOp>(loc, x, y);
+  }
+  if (!acc)
+    return Optional<Value>(mul);
+  return makeArithReduction(rewriter, loc, kind, mul, acc);
+}
+
+/// Return the positions of the reductions in the given map.
+static SmallVector<int64_t> getReductionIndex(AffineMap map,
+                                              ArrayAttr iteratorTypes) {
+  SmallVector<int64_t> dimsIdx;
+  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+    auto iteratorTypeName =
+        iteratorTypes[map.getDimPosition(i)].cast<StringAttr>().getValue();
+    if (iteratorTypeName == getReductionIteratorTypeName())
+      dimsIdx.push_back(i);
+  }
+  return dimsIdx;
+}
+
+/// Return the permutation to apply to have only reductions as leading
+/// dimensions.
+static SmallVector<int64_t>
+getPermForLeadingReductions(AffineMap map, int64_t numReduction,
+                            ArrayAttr iteratorTypes) {
+  SmallVector<int64_t> indices;
+  int64_t parallelLoopIdx = 0;
+  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+    auto iteratorTypeName =
+        iteratorTypes[map.getDimPosition(i)].cast<StringAttr>().getValue();
+    if (iteratorTypeName == getReductionIteratorTypeName()) {
+      indices.push_back(i - parallelLoopIdx);
+      continue;
+    }
+    indices.push_back(numReduction + parallelLoopIdx++);
+  }
+  return indices;
+}
+
 namespace {
 
 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
@@ -498,9 +562,8 @@
     if (!rhsType) {
       // Special case: AXPY operation.
       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
-      Optional<Value> mult =
-          isInt ? genMultI(loc, op.getLhs(), b, acc, kind, rewriter)
-                : genMultF(loc, op.getLhs(), b, acc, kind, rewriter);
+      Optional<Value> mult = createContractArithOp(loc, op.getLhs(), b, acc,
+                                                   kind, rewriter, isInt);
       if (!mult.hasValue())
         return failure();
       rewriter.replaceOp(op, mult.getValue());
@@ -518,8 +581,7 @@
       if (acc)
         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
       Optional<Value> m =
-          isInt ? genMultI(loc, a, op.getRhs(), r, kind, rewriter)
-                : genMultF(loc, a, op.getRhs(), r, kind, rewriter);
+          createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
       if (!m.hasValue())
         return failure();
       result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
@@ -528,48 +590,92 @@
     rewriter.replaceOp(op, result);
     return success();
   }
+};
 
-private:
-  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
-                                  vector::CombiningKind kind,
-                                  PatternRewriter &rewriter) {
-    using vector::CombiningKind;
-
-    auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
-    if (!acc)
-      return Optional<Value>(mul);
-
-    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
-      // Only valid for floating point types.
-      return Optional<Value>();
-
-    return makeArithReduction(rewriter, loc, kind, mul, acc);
+/// Lower vector.contract with all size one reduction dimensions to
+/// elementwise ops when possible.
+struct ContractOpToElementwise
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
   }
+  ContractOpToElementwise(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context,
+      const FilterConstraintType &constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
 
-  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
-                                  vector::CombiningKind kind,
-                                  PatternRewriter &rewriter) {
-    using vector::CombiningKind;
-
-    // Special case for fused multiply-add.
-    if (acc && kind == CombiningKind::ADD) {
-      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
-    }
-
-    auto mul = rewriter.create<arith::MulFOp>(loc, x, y);
-
-    if (!acc)
-      return Optional<Value>(mul);
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: implement masks
+    if (llvm::size(contractOp.getMasks()) != 0)
+      return failure();
 
-    if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
-        kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
-        kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
-        kind == CombiningKind::OR || kind == CombiningKind::XOR)
-      // Already handled or only valid for integer types.
-      return Optional<Value>();
+    if (failed(filter(contractOp)))
+      return failure();
 
-    return makeArithReduction(rewriter, loc, kind, mul, acc);
+    if (vectorTransformOptions.vectorContractLowering !=
+        vector::VectorContractLowering::ParallelArith)
+      return failure();
+    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
+    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
+    AffineMap lhsMap = contractOp.getIndexingMaps()[0];
+    AffineMap rhsMap = contractOp.getIndexingMaps()[1];
+    SmallVector<int64_t> lhsReductionDims =
+        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
+    SmallVector<int64_t> rhsReductionDims =
+        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
+    // All the reduction dimensions must be of size 1.
+    for (int64_t dim : lhsReductionDims)
+      if (lhsShape[dim] != 1)
+        return failure();
+    for (int64_t dim : rhsReductionDims)
+      if (rhsShape[dim] != 1)
+        return failure();
+    SmallVector<int64_t> lhsTranspose = getPermForLeadingReductions(
+        lhsMap, lhsReductionDims.size(), contractOp.getIteratorTypes());
+    SmallVector<int64_t> rhsTranspose = getPermForLeadingReductions(
+        rhsMap, rhsReductionDims.size(), contractOp.getIteratorTypes());
+
+    Location loc = contractOp.getLoc();
+    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
+    Value newLhs = rewriter.create<vector::TransposeOp>(
+        loc, contractOp.getLhs(), lhsTranspose);
+    Value newRhs = rewriter.create<vector::TransposeOp>(
+        loc, contractOp.getRhs(), rhsTranspose);
+    int64_t dstRank =
+        contractOp.getResultType().isa<VectorType>()
+            ? contractOp.getResultType().cast<VectorType>().getRank()
+            : 0;
+    SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
+    SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
+    newLhs = rewriter.create<vector::ExtractOp>(
+        loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
+    if (contractOp.getLhsType().getRank() - lhsOffsets.size() < dstRank)
+      newLhs = rewriter.create<vector::BroadcastOp>(
+          loc, contractOp.getResultType(), newLhs);
+    newRhs = rewriter.create<vector::ExtractOp>(
+        loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
+    if (contractOp.getRhsType().getRank() - rhsOffsets.size() < dstRank)
+      newRhs = rewriter.create<vector::BroadcastOp>(
+          loc, contractOp.getResultType(), newRhs);
+    Optional<Value> result =
+        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
+                              contractOp.getKind(), rewriter, isInt);
+    if (!result.hasValue())
+      return failure();
+    rewriter.replaceOp(contractOp, {*result});
+    return success();
   }
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  FilterConstraintType filter;
 };
 
 /// Progressive lowering of ConstantMaskOp.
@@ -1594,6 +1700,9 @@
   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
     return success();
+  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
+  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
+    return success();
 
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
Index: mlir/test/Dialect/Vector/vector-contract-transforms.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -2,6 +2,7 @@
 // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
 // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
 // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-parallel-arith=1 | FileCheck %s --check-prefix=PARALLEL
 
 #dotp_accesses = [
   affine_map<(i) -> (i)>,
@@ -1104,3 +1105,36 @@
     : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
   return %0 : vector<3x4xf32>
 }
+
+// PARALLEL-LABEL: func @parallel_contract_lowering
+// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
+// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
+// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32>
+// PARALLEL: return %[[F]] : vector<4xf32>
+func.func @parallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// PARALLEL-LABEL: func @parallel_contract_lowering_broadcast
+// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32>
+// PARALLEL: %[[B:.*]] = vector.broadcast %[[E0]] : f32 to vector<4xf32>
+// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
+// PARALLEL: %[[F:.*]] = vector.fma %[[B]], %[[E1]], %{{.*}} : vector<4xf32>
+// PARALLEL: return %[[F]] : vector<4xf32>
+func.func @parallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// PARALLEL-LABEL: func @parallel_contract_lowering_transpose
+// PARALLEL: %[[T:.*]] = vector.transpose %{{.*}}, [0, 2, 1] : vector<1x4x1xf32> to vector<1x1x4xf32>
+// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32>
+// PARALLEL: %[[B:.*]] = vector.broadcast %[[E0]] : f32 to vector<4xf32>
+// PARALLEL: %[[E1:.*]] = vector.extract %[[T]][0, 0] : vector<1x1x4xf32>
+// PARALLEL: %[[F:.*]] = vector.fma %[[B]], %[[E1]], %arg2 : vector<4xf32>
+// PARALLEL: return %[[F]] : vector<4xf32>
+func.func @parallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
Index: mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
===================================================================
--- mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -135,6 +135,10 @@
       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
                      "vectors of size 4."),
       llvm::cl::init(false)};
+  Option<bool> lowerToParallelArith{
+      *this, "vector-parallel-arith",
+      llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
+      llvm::cl::init(false)};
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
@@ -165,6 +169,15 @@
       return;
     }
 
+    if (lowerToParallelArith) {
+      vector::populateVectorContractLoweringPatterns(
+          patterns,
+          vector::VectorTransformsOptions().setVectorTransformsOptions(
+              vector::VectorContractLowering::ParallelArith));
+      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+      return;
+    }
+
     // Test on all contract lowering patterns.
     VectorContractLowering contractLowering = VectorContractLowering::Dot;
     if (lowerToFlatMatrix)