Index: mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
===================================================================
--- mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -49,6 +49,9 @@
   Matmul = 1,
   /// Lower to `vector.outerproduct`.
   OuterProduct = 2,
+  /// Lower contract with all reduction dimensions unrolled to 1 to vector
+  /// elementwise operations.
+  ParallelArith = 3,
 };
 /// Enum to control the splitting of `vector.transfer` operations into
 /// in-bounds and out-of-bounds variants.
Index: mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -144,6 +144,63 @@
       [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
 }
 
+/// Helper to create the arithmetic operation associated with a kind of
+/// contraction. Multiplies `x` and `y` and, when `acc` is non-null, combines
+/// the product with `acc` according to `kind`. Returns an empty Optional when
+/// `kind` is not valid for the element type (`isInt` selects the integer or
+/// the floating point path).
+static Optional<Value> createContractArithOp(Location loc, Value x, Value y,
+                                             Value acc,
+                                             vector::CombiningKind kind,
+                                             PatternRewriter &rewriter,
+                                             bool isInt) {
+  using vector::CombiningKind;
+  Value mul;
+  if (isInt) {
+    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
+      // Only valid for floating point types.
+      return Optional<Value>();
+    mul = rewriter.create<arith::MulIOp>(loc, x, y);
+  } else {
+    // Float case.
+    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
+        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
+        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
+        kind == CombiningKind::XOR)
+      // Only valid for integer types.
+      return Optional<Value>();
+    // Special case for fused multiply-add.
+    if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
+      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
+    }
+    mul = rewriter.create<arith::MulFOp>(loc, x, y);
+  }
+  if (!acc)
+    return Optional<Value>(mul);
+  return makeArithReduction(rewriter, loc, kind, mul, acc);
+}
+
+/// Return the positions of the reductions in the given map.
+static SmallVector<int64_t> getReductionIndex(AffineMap map,
+                                              ArrayAttr iteratorTypes) {
+  SmallVector<int64_t> dimsIdx;
+  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
+      dimsIdx.push_back(i);
+  }
+  return dimsIdx;
+}
+
+/// Look for a given dimension in an affine map and return its position. Return
+/// llvm::None if the dimension is not in the map results.
+static llvm::Optional<int64_t> getDimPosition(AffineMap map, unsigned dim) {
+  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
+    if (map.getDimPosition(i) == dim)
+      return i;
+  }
+  return llvm::None;
+}
+
 namespace {
 
 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
@@ -498,9 +555,8 @@
     if (!rhsType) {
       // Special case: AXPY operation.
       Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
-      Optional<Value> mult =
-          isInt ? genMultI(loc, op.getLhs(), b, acc, kind, rewriter)
-                : genMultF(loc, op.getLhs(), b, acc, kind, rewriter);
+      Optional<Value> mult = createContractArithOp(loc, op.getLhs(), b, acc,
+                                                   kind, rewriter, isInt);
       if (!mult.hasValue())
         return failure();
       rewriter.replaceOp(op, mult.getValue());
@@ -518,8 +574,7 @@
       if (acc)
         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
       Optional<Value> m =
-          isInt ? genMultI(loc, a, op.getRhs(), r, kind, rewriter)
-                : genMultF(loc, a, op.getRhs(), r, kind, rewriter);
+          createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
       if (!m.hasValue())
         return failure();
       result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
@@ -528,48 +583,133 @@
     rewriter.replaceOp(op, result);
     return success();
   }
+};
 
-private:
-  static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
-                                  vector::CombiningKind kind,
-                                  PatternRewriter &rewriter) {
-    using vector::CombiningKind;
-
-    auto mul = rewriter.create<arith::MulIOp>(loc, x, y);
-    if (!acc)
-      return Optional<Value>(mul);
-
-    if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
-      // Only valid for floating point types.
-      return Optional<Value>();
-
-    return makeArithReduction(rewriter, loc, kind, mul, acc);
+/// Lower vector.contract with all size one reduction dimensions to
+/// elementwise ops when possible. Each operand is broadcast along the parallel
+/// dimensions it lacks and transposed so that its unit reduction dimensions
+/// come first; those leading dimensions are then extracted away and the
+/// multiply(-accumulate) is emitted on the remaining parallel dimensions.
+struct ContractOpToElementwise
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
   }
+  ContractOpToElementwise(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context,
+      const FilterConstraintType &constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformOptions(vectorTransformOptions), filter(constraint) {}
 
-  static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
-                                  vector::CombiningKind kind,
-                                  PatternRewriter &rewriter) {
-    using vector::CombiningKind;
-
-    // Special case for fused multiply-add.
-    if (acc && kind == CombiningKind::ADD) {
-      return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
-    }
-
-    auto mul = rewriter.create<arith::MulFOp>(loc, x, y);
-
-    if (!acc)
-      return Optional<Value>(mul);
+  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: implement masks
+    if (llvm::size(contractOp.getMasks()) != 0)
+      return failure();
 
-    if (kind == CombiningKind::ADD || kind == CombiningKind::AND ||
-        kind == CombiningKind::MINUI || kind == CombiningKind::MINSI ||
-        kind == CombiningKind::MAXUI || kind == CombiningKind::MAXSI ||
-        kind == CombiningKind::OR || kind == CombiningKind::XOR)
-      // Already handled or only valid for integer types.
-      return Optional<Value>();
+    if (failed(filter(contractOp)))
+      return failure();
 
-    return makeArithReduction(rewriter, loc, kind, mul, acc);
+    if (vectorTransformOptions.vectorContractLowering !=
+        vector::VectorContractLowering::ParallelArith)
+      return failure();
+    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
+    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
+    AffineMap lhsMap = contractOp.getIndexingMaps()[0];
+    AffineMap rhsMap = contractOp.getIndexingMaps()[1];
+    SmallVector<int64_t> lhsReductionDims =
+        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
+    SmallVector<int64_t> rhsReductionDims =
+        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
+    // All the reduction dimensions must be of size 1.
+    for (int64_t dim : lhsReductionDims) {
+      if (lhsShape[dim] != 1)
+        return failure();
+    }
+    for (int64_t dim : rhsReductionDims) {
+      if (rhsShape[dim] != 1)
+        return failure();
+    }
+    AffineMap accMap = contractOp.getIndexingMaps()[2];
+    unsigned numParallelDims = accMap.getNumResults();
+    unsigned numLhsDimToBroadcast =
+        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
+    unsigned numRhsDimToBroadcast =
+        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
+    SmallVector<int64_t> lhsDims;
+    SmallVector<int64_t> lhsTranspose;
+    SmallVector<int64_t> rhsDims;
+    SmallVector<int64_t> rhsTranspose;
+    for (int64_t dim : lhsReductionDims)
+      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
+    for (int64_t dim : rhsReductionDims)
+      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
+    // Loop through the parallel dimensions to calculate the dimensions to
+    // broadcast and to permute in order to extract only parallel dimensions.
+    for (unsigned i = 0; i < numParallelDims; i++) {
+      llvm::Optional<int64_t> lhsDim =
+          getDimPosition(lhsMap, accMap.getDimPosition(i));
+      if (lhsDim) {
+        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
+      } else {
+        // If the parallel dimension doesn't exist we will have to broadcast it.
+        lhsDims.push_back(
+            contractOp.getResultType().cast<VectorType>().getDimSize(i));
+        lhsTranspose.push_back(lhsDims.size() - 1);
+      }
+      llvm::Optional<int64_t> rhsDim =
+          getDimPosition(rhsMap, accMap.getDimPosition(i));
+      if (rhsDim) {
+        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
+      } else {
+        // If the parallel dimension doesn't exist we will have to broadcast it.
+        rhsDims.push_back(
+            contractOp.getResultType().cast<VectorType>().getDimSize(i));
+        rhsTranspose.push_back(rhsDims.size() - 1);
+      }
+    }
+    Value newLhs = contractOp.getLhs();
+    Value newRhs = contractOp.getRhs();
+    Location loc = contractOp.getLoc();
+    if (!lhsDims.empty()) {
+      lhsDims.append(lhsShape.begin(), lhsShape.end());
+      auto expandedType =
+          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
+      newLhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newLhs);
+    }
+    if (!rhsDims.empty()) {
+      rhsDims.append(rhsShape.begin(), rhsShape.end());
+      auto expandedType =
+          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
+      newRhs = rewriter.create<vector::BroadcastOp>(loc, expandedType, newRhs);
+    }
+    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
+    newLhs = rewriter.create<vector::TransposeOp>(loc, newLhs, lhsTranspose);
+    newRhs = rewriter.create<vector::TransposeOp>(loc, newRhs, rhsTranspose);
+    SmallVector<int64_t, 4> lhsOffsets(lhsReductionDims.size(), 0);
+    SmallVector<int64_t, 4> rhsOffsets(rhsReductionDims.size(), 0);
+    newLhs = rewriter.create<vector::ExtractOp>(
+        loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets));
+    newRhs = rewriter.create<vector::ExtractOp>(
+        loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets));
+    Optional<Value> result =
+        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
+                              contractOp.getKind(), rewriter, isInt);
+    // The combining kind may not be valid for the element type.
+    if (!result.hasValue())
+      return failure();
+    rewriter.replaceOp(contractOp, result.getValue());
+    return success();
   }
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+  FilterConstraintType filter;
 };
 
 /// Progressive lowering of ConstantMaskOp.
@@ -1594,6 +1734,9 @@
   ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
   if (succeeded(pat3.matchAndRewrite(op, rewriter)))
     return success();
+  ContractOpToElementwise pat4(vectorTransformOptions, ctx);
+  if (succeeded(pat4.matchAndRewrite(op, rewriter)))
+    return success();
 
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
Index: mlir/test/Dialect/Vector/vector-contract-transforms.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -2,6 +2,7 @@
 // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
 // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
 // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
+// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-parallel-arith=1 | FileCheck %s --check-prefix=PARALLEL
 
 #dotp_accesses = [
   affine_map<(i) -> (i)>,
@@ -1104,3 +1105,54 @@
       : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
   return %0 : vector<3x4xf32>
 }
+
+// PARALLEL-LABEL: func @parrallel_contract_lowering
+// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
+// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
+// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32>
+// PARALLEL: return %[[F]] : vector<4xf32>
+func.func @parrallel_contract_lowering(%arg0: vector<1x1x4xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1x4xf32>, vector<1x1x4xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// PARALLEL-LABEL: func @parrallel_contract_lowering_broadcast
+// PARALLEL: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
+// PARALLEL: %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
+// PARALLEL: %[[E0:.*]] = vector.extract %[[T]][0, 0] : vector<1x1x4xf32>
+// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>
+// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %{{.*}} : vector<4xf32>
+// PARALLEL: return %[[F]] : vector<4xf32>
+func.func @parrallel_contract_lowering_broadcast(%arg0: vector<1x1xf32>, %arg1: vector<1x1x4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d2, d0)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1x4xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// PARALLEL-LABEL: func @parrallel_contract_lowering_transpose
+// PARALLEL: %[[B:.*]] = vector.broadcast %{{.*}} : vector<1x1xf32> to vector<4x1x1xf32>
+// PARALLEL: %[[T0:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<4x1x1xf32> to vector<1x1x4xf32>
+// PARALLEL: %[[T1:.*]] = vector.transpose %{{.*}}, [0, 2, 1] : vector<1x4x1xf32> to vector<1x1x4xf32>
+// PARALLEL: %[[E0:.*]] = vector.extract %[[T0]][0, 0] : vector<1x1x4xf32>
+// PARALLEL: %[[E1:.*]] = vector.extract %[[T1]][0, 0] : vector<1x1x4xf32>
+// PARALLEL: %[[F:.*]] = vector.fma %[[E0]], %[[E1]], %arg2 : vector<4xf32>
+// PARALLEL: return %[[F]] : vector<4xf32>
+func.func @parrallel_contract_lowering_transpose(%arg0: vector<1x1xf32>, %arg1: vector<1x4x1xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d1, d0, d2)>, affine_map<(d0, d1, d2) -> (d0)>], iterator_types = ["parallel", "reduction", "reduction"], kind = #vector.kind<add>} %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x4x1xf32> into vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// PARALLEL-LABEL: func @parrallel_contract_lowering_scalar
+// PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32>
+// PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1xf32>
+// PARALLEL: %[[M:.*]] = arith.mulf %[[E0]], %[[E1]] : f32
+// PARALLEL: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32
+// PARALLEL: return %[[A]] : f32
+func.func @parrallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vector<1x1xf32>, %arg2: f32) -> f32 {
+  %0 = vector.contract {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> ()>],
+    iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
+    %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32
+  return %0 : f32
+}
Index: mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
===================================================================
--- mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -135,6 +135,10 @@
       llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
                      "vectors of size 4."),
       llvm::cl::init(false)};
+  Option<bool> lowerToParallelArith{
+      *this, "vector-parallel-arith",
+      llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
+      llvm::cl::init(false)};
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
@@ -165,6 +169,15 @@
       return;
     }
 
+    if (lowerToParallelArith) {
+      vector::populateVectorContractLoweringPatterns(
+          patterns,
+          vector::VectorTransformsOptions().setVectorTransformsOptions(
+              vector::VectorContractLowering::ParallelArith));
+      (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+      return;
+    }
+
     // Test on all contract lowering patterns.
     VectorContractLowering contractLowering = VectorContractLowering::Dot;
     if (lowerToFlatMatrix)