diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -25,13 +25,6 @@ class OwningRewritePatternList; namespace vector { -/// Structure to control the behavior of vector transform patterns. -struct VectorTransformsOptions { - /// Let vector.contract lower to vector.matrix_multiply and LLVM matrix - /// intrinsics. - bool lowerToLLVMMatrixIntrinsics = false; -}; - /// Collect a set of vector-to-vector canonicalization patterns. void populateVectorToVectorCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context); @@ -51,6 +44,20 @@ void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns, MLIRContext *context); +/// Enum to control the lowering of `vector.contract` operations. +enum class VectorContractLowering { + /// Progressively lower to finer grained `vector.contract` and `vector.fma`. + FMA = 0, + /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics. + Matmul = 1, + /// Lower to `vector.outerproduct`. + OuterProduct = 2, +}; +/// Structure to control the behavior of vector transform patterns. +struct VectorTransformsOptions { + VectorContractLowering vectorContractLowering = VectorContractLowering::FMA; +}; + /// Collect a set of transformation patterns that are related to contracting /// or expanding vector operations: /// ContractionOpLowering, diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -685,6 +685,11 @@ return %3: vector<4x8xf32> ``` }]; + let builders = [ + // Build an op without mask, use the type of `acc` as the return type. 
+ OpBuilder< + "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " + "Value acc">]; let extraClassDeclaration = [{ VectorType getOperandVectorTypeLHS() { return lhs().getType().cast(); diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -9,6 +9,7 @@ #ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_ #define DIALECT_VECTOR_VECTORTRANSFORMS_H_ +#include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/PatternMatch.h" namespace mlir { @@ -22,13 +23,6 @@ ArrayRef coarseVectorShape = {}, ArrayRef fineVectorShape = {}); -//////////////////////////////////////////////////////////////////////////////// -// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite -// patterns. As such, they must not call into `rewriter.erase/replace` APIs and -// it is the responsibility of the enclosing PatternRewriter to erase on -// success. -//////////////////////////////////////////////////////////////////////////////// - namespace vector { // Entry point for unrolling declarative pattern rewrites. @@ -69,6 +63,114 @@ ArrayRef targetShape); } // namespace vector + +//===----------------------------------------------------------------------===// +// Finer-grained patterns exposed for more control over individual lowerings. +//===----------------------------------------------------------------------===// + +/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to: +/// ``` +/// %flattened_a = vector.shape_cast %a +/// %flattened_b = vector.shape_cast %b +/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b +/// %d = vector.shape_cast %flattened_d +/// %e = add %c, %d +/// ``` +/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
+/// +/// This only kicks in when VectorTransformsOptions is set to Matmul and +/// the vector.contract op is a row-major matrix multiply. +class ContractionOpToMatmulOpLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + ContractionOpToMatmulOpLowering( + vector::VectorTransformsOptions vectorTransformsOptions, + MLIRContext *context) + : OpRewritePattern(context), + vectorTransformsOptions(vectorTransformsOptions) {} + + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformsOptions; + + LogicalResult match(vector::ContractionOp op) const override; + void rewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; +}; + +/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to a reduction_size-unrolled sequence: +/// ``` +/// %at = vector.transpose %a, [1, 0] +/// %bRow0 = vector.extract %b[0] +/// %atRow0 = vector.extract %at[0] +/// %c0 = vector.outerproduct %atRow0, %bRow0, %c +/// ... +/// %bRowK = vector.extract %b[K] +/// %atRowK = vector.extract %at[K] +/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 +/// ``` +/// +/// This only kicks in when VectorTransformsOptions is set to OuterProduct and +/// the vector.contract op is a row-major matrix multiply. +class ContractionOpToOuterProductOpLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + ContractionOpToOuterProductOpLowering( + vector::VectorTransformsOptions vectorTransformsOptions, + MLIRContext *context) + : OpRewritePattern(context), + vectorTransformsOptions(vectorTransformsOptions) {} + + LogicalResult match(vector::ContractionOp op) const override; + void rewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; + + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformsOptions; +}; + +/// Progressive lowering of ContractionOp.
+/// +/// One: +/// %x = vector.contract with at least one free/batch dimension +/// is replaced by: +/// %a = vector.contract with one less free/batch dimension +/// %b = vector.contract with one less free/batch dimension +/// .. +/// %x = combine %a %b .. +/// until a pure contraction is reached (no free/batch dimensions), +/// which is replaced by a fma/reduction op. +/// +/// This only kicks in when either VectorTransformsOptions is set to FMA or when +/// other contraction patterns fail. +class ContractionOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions, + MLIRContext *context) + : OpRewritePattern(context), + vectorTransformsOptions(vectorTransformsOptions) {} + + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformsOptions; + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; + +private: + // Lower one parallel dimension. + Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex, + int64_t rhsIndex, PatternRewriter &rewriter) const; + // Lower one reduction dimension. + Value lowerReduction(vector::ContractionOp op, + PatternRewriter &rewriter) const; +}; + } // namespace mlir #endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_ diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -957,6 +957,13 @@ // OuterProductOp //===----------------------------------------------------------------------===// +/// Build an op without mask, use the type of `acc` as the return type. 
+void OuterProductOp::build(OpBuilder &builder, OperationState &result, + Value lhs, Value rhs, Value acc) { + result.addOperands({lhs, rhs, acc}); + result.addTypes(acc.getType()); +} + static void print(OpAsmPrinter &p, OuterProductOp op) { p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs(); if (!op.acc().empty()) diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1252,343 +1252,6 @@ } }; -/// Progressive lowering of ContractionOp. -/// One: -/// %x = vector.contract with at least one free/batch dimension -/// is replaced by: -/// %a = vector.contract with one less free/batch dimension -/// %b = vector.contract with one less free/batch dimension -/// .. -/// %x = combine %a %b .. -/// until a pure contraction is reached (no free/batch dimensions), -/// which is replaced by a fma/reduction op. -/// -/// TODO(ajcbik): break down into transpose/reshape/cast ops -/// when they become available to avoid code dup -/// TODO(ajcbik): investigate lowering order impact on performance -class ContractionOpLowering : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions, - MLIRContext *context) - : OpRewritePattern(context), - vectorTransformsOptions(vectorTransformsOptions) {} - - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override { - // TODO(ajcbik): implement masks - if (llvm::size(op.masks()) != 0) - return failure(); - - // TODO(ntv, ajcbik): implement benefits, cost models, separate this out in - // a new pattern. 
- if (vectorTransformsOptions.lowerToLLVMMatrixIntrinsics && - isRowMajorMatmul(op.indexing_maps())) { - VectorType lhsType = op.getLhsType(); - VectorType rhsType = op.getRhsType(); - unsigned lhsRows = op.getLhsType().getShape()[0]; - unsigned lhsColumns = op.getLhsType().getShape()[1]; - unsigned rhsColumns = op.getRhsType().getShape()[1]; - - Type flattenedLHSType = - VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); - Type flattenedRHSType = - VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); - auto lhs = rewriter.create( - op.getLoc(), flattenedLHSType, op.lhs()); - auto rhs = rewriter.create( - op.getLoc(), flattenedRHSType, op.rhs()); - - Value mul = rewriter.create( - op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns); - mul = rewriter.create(op.getLoc(), - op.acc().getType(), mul); - Type elementType = op.getLhsType().getElementType(); - assert(elementType.isIntOrFloat()); - if (elementType.isa()) - rewriter.replaceOpWithNewOp(op, op.acc(), mul); - else - rewriter.replaceOpWithNewOp(op, op.acc(), mul); - return success(); - } - - // Find first batch dimension in LHS/RHS, and lower when found. - std::vector> batchDimMap = op.getBatchDimMap(); - if (!batchDimMap.empty()) { - int64_t lhsIndex = batchDimMap[0].first; - int64_t rhsIndex = batchDimMap[0].second; - rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter)); - return success(); - } - - // Collect contracting dimensions. - std::vector> contractingDimMap = - op.getContractingDimMap(); - DenseSet lhsContractingDimSet; - DenseSet rhsContractingDimSet; - for (auto &dimPair : contractingDimMap) { - lhsContractingDimSet.insert(dimPair.first); - rhsContractingDimSet.insert(dimPair.second); - } - - // Find first free dimension in LHS, and lower when found. 
- VectorType lhsType = op.getLhsType(); - for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; - ++lhsIndex) { - if (lhsContractingDimSet.count(lhsIndex) == 0) { - rewriter.replaceOp( - op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter)); - return success(); - } - } - - // Find first free dimension in RHS, and lower when found. - VectorType rhsType = op.getRhsType(); - for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; - ++rhsIndex) { - if (rhsContractingDimSet.count(rhsIndex) == 0) { - rewriter.replaceOp( - op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter)); - return success(); - } - } - - // Lower the first remaining reduction dimension. - if (!contractingDimMap.empty()) { - rewriter.replaceOp(op, lowerReduction(op, rewriter)); - return success(); - } - - return failure(); - } - -private: - // Lower one parallel dimension. - // TODO(ajcbik): consider reusing existing contract unrolling - Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex, - int64_t rhsIndex, PatternRewriter &rewriter) const { - VectorType lhsType = op.getLhsType(); - VectorType rhsType = op.getRhsType(); - VectorType resType = op.getResultType().cast(); - // Find the iterator type index and result index. 
- SmallVector iMap = op.getIndexingMaps(); - int64_t iterIndex = -1; - int64_t dimSize = -1; - if (lhsIndex >= 0) { - iterIndex = - iMap[0].getResult(lhsIndex).cast().getPosition(); - assert((rhsIndex < 0 || iterIndex == iMap[1] - .getResult(rhsIndex) - .cast() - .getPosition()) && - "parallel index should be free in LHS or batch in LHS/RHS"); - dimSize = lhsType.getDimSize(lhsIndex); - } else { - assert(rhsIndex >= 0 && "missing parallel index"); - iterIndex = - iMap[1].getResult(rhsIndex).cast().getPosition(); - dimSize = rhsType.getDimSize(rhsIndex); - } - assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); - Optional lookup = getResultIndex(iMap[2], iterIndex); - assert(lookup.hasValue() && "parallel index not listed in reduction"); - int64_t resIndex = lookup.getValue(); - // Construct new iterator types and affine map array attribute. - SmallVector lowIndexingMaps; - lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter)); - lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter)); - lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter)); - auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); - auto lowIter = - rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); - // Unroll into a series of lower dimensional vector.contract ops. - Location loc = op.getLoc(); - Value result = rewriter.create(loc, resType, - rewriter.getZeroAttr(resType)); - for (int64_t d = 0; d < dimSize; ++d) { - auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); - auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); - auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter); - Value lowContract = rewriter.create( - loc, lhs, rhs, acc, lowAffine, lowIter); - result = reshapeStore(loc, lowContract, result, resType, resIndex, d, - rewriter); - } - return result; - } - - // Lower one reduction dimension. 
- Value lowerReduction(vector::ContractionOp op, - PatternRewriter &rewriter) const { - auto loc = op.getLoc(); - VectorType lhsType = op.getLhsType(); - VectorType rhsType = op.getRhsType(); - Type resType = op.getResultType(); - assert(!resType.isa()); - // Use iterator index 0. - int64_t iterIndex = 0; - SmallVector iMap = op.getIndexingMaps(); - Optional lookupLhs = getResultIndex(iMap[0], iterIndex); - Optional lookupRhs = getResultIndex(iMap[1], iterIndex); - assert(lookupLhs.hasValue() && "missing LHS parallel index"); - assert(lookupRhs.hasValue() && "missing RHS parallel index"); - int64_t lhsIndex = lookupLhs.getValue(); - int64_t rhsIndex = lookupRhs.getValue(); - int64_t dimSize = lhsType.getDimSize(lhsIndex); - assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape"); - // Base case. - if (lhsType.getRank() == 1) { - assert(rhsType.getRank() == 1 && "corrupt contraction"); - Value zero = rewriter.create(loc, lhsType, - rewriter.getZeroAttr(lhsType)); - Value fma = rewriter.create(loc, op.lhs(), op.rhs(), zero); - StringAttr kind = rewriter.getStringAttr("add"); - return rewriter.create(loc, resType, kind, fma, - op.acc()); - } - // Construct new iterator types and affine map array attribute. - SmallVector lowIndexingMaps; - lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter)); - lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter)); - lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter)); - auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); - auto lowIter = - rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); - // Unroll into a series of lower dimensional vector.contract ops. - // By feeding the initial accumulator into the first contraction, - // and the result of each contraction into the next, eventually - // the sum of all reductions is computed. 
- Value result = op.acc(); - for (int64_t d = 0; d < dimSize; ++d) { - auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); - auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); - result = rewriter.create(loc, lhs, rhs, result, - lowAffine, lowIter); - } - return result; - } - - // Helper to find an index in an affine map. - static Optional getResultIndex(AffineMap map, int64_t index) { - for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { - int64_t idx = map.getResult(i).cast().getPosition(); - if (idx == index) - return i; - } - return None; - } - - // Helper to construct iterator types with one index removed. - static SmallVector adjustIter(ArrayAttr iteratorTypes, - int64_t index) { - SmallVector results; - for (auto it : llvm::enumerate(iteratorTypes)) { - int64_t idx = it.index(); - if (idx == index) - continue; - results.push_back(it.value()); - } - return results; - } - - // Helper to construct an affine map with one index removed. - static AffineMap adjustMap(AffineMap map, int64_t index, - PatternRewriter &rewriter) { - auto *ctx = rewriter.getContext(); - SmallVector results; - for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { - int64_t idx = map.getResult(i).cast().getPosition(); - if (idx == index) - continue; - // Re-insert remaining indices, but renamed when occurring - // after the removed index. - auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx); - results.push_back(targetExpr); - } - return AffineMap::get(map.getNumDims() - 1, 0, results, ctx); - } - - // Helper to drop dimension from vector type. - static Type adjustType(VectorType tp, int64_t index) { - int64_t rank = tp.getRank(); - Type eltType = tp.getElementType(); - if (rank == 1) { - assert(index == 0 && "index for scalar result out of bounds"); - return eltType; - } - SmallVector adjustedShape; - for (int64_t i = 0; i < rank; ++i) { - // Omit dimension at the given index. 
- if (i == index) - continue; - // Otherwise, add dimension back. - adjustedShape.push_back(tp.getDimSize(i)); - } - return VectorType::get(adjustedShape, eltType); - } - - // Helper method to possibly drop a dimension in a load. - // TODO(ajcbik): use a reshaping vector load (and share lowering code) - static Value reshapeLoad(Location loc, Value val, VectorType type, - int64_t index, int64_t pos, - PatternRewriter &rewriter) { - if (index == -1) - return val; - Type lowType = adjustType(type, 0); - // At extraction dimension? - if (index == 0) { - auto posAttr = rewriter.getI64ArrayAttr(pos); - return rewriter.create(loc, lowType, val, posAttr); - } - // Unroll leading dimensions. - VectorType vType = lowType.cast(); - VectorType resType = adjustType(type, index).cast(); - Value result = rewriter.create(loc, resType, - rewriter.getZeroAttr(resType)); - for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) { - auto posAttr = rewriter.getI64ArrayAttr(d); - Value ext = rewriter.create(loc, vType, val, posAttr); - Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, resType, load, result, - posAttr); - } - return result; - } - - // Helper method to possibly drop a dimension in a store. - // TODO(ajcbik): use a reshaping vector store (and share lowering code) - static Value reshapeStore(Location loc, Value val, Value result, - VectorType type, int64_t index, int64_t pos, - PatternRewriter &rewriter) { - // Unmodified? - if (index == -1) - return val; - // At insertion dimension? - if (index == 0) { - auto posAttr = rewriter.getI64ArrayAttr(pos); - return rewriter.create(loc, type, val, result, posAttr); - } - // Unroll leading dimensions. 
- Type lowType = adjustType(type, 0); - VectorType vType = lowType.cast(); - Type insType = adjustType(vType, 0); - for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { - auto posAttr = rewriter.getI64ArrayAttr(d); - Value ext = - rewriter.create(loc, vType, result, posAttr); - Value ins = - rewriter.create(loc, insType, val, posAttr); - Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); - result = - rewriter.create(loc, type, sto, result, posAttr); - } - return result; - } - - vector::VectorTransformsOptions vectorTransformsOptions; -}; - /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D /// vectors progressively on the way to target llvm.matrix intrinsics. /// This iterates over the most major dimension of the 2-D vector and performs @@ -1656,6 +1319,416 @@ } // namespace +// Helper to find an index in an affine map. +static Optional getResultIndex(AffineMap map, int64_t index) { + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + int64_t idx = map.getResult(i).cast().getPosition(); + if (idx == index) + return i; + } + return None; +} + +// Helper to construct iterator types with one index removed. +static SmallVector adjustIter(ArrayAttr iteratorTypes, + int64_t index) { + SmallVector results; + for (auto it : llvm::enumerate(iteratorTypes)) { + int64_t idx = it.index(); + if (idx == index) + continue; + results.push_back(it.value()); + } + return results; +} + +// Helper to construct an affine map with one index removed. +static AffineMap adjustMap(AffineMap map, int64_t index, + PatternRewriter &rewriter) { + auto *ctx = rewriter.getContext(); + SmallVector results; + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + int64_t idx = map.getResult(i).cast().getPosition(); + if (idx == index) + continue; + // Re-insert remaining indices, but renamed when occurring + // after the removed index. + auto targetExpr = getAffineDimExpr(idx < index ? 
idx : idx - 1, ctx); + results.push_back(targetExpr); + } + return AffineMap::get(map.getNumDims() - 1, 0, results, ctx); +} + +// Helper to drop dimension from vector type. +static Type adjustType(VectorType tp, int64_t index) { + int64_t rank = tp.getRank(); + Type eltType = tp.getElementType(); + if (rank == 1) { + assert(index == 0 && "index for scalar result out of bounds"); + return eltType; + } + SmallVector adjustedShape; + for (int64_t i = 0; i < rank; ++i) { + // Omit dimension at the given index. + if (i == index) + continue; + // Otherwise, add dimension back. + adjustedShape.push_back(tp.getDimSize(i)); + } + return VectorType::get(adjustedShape, eltType); +} + +// Helper method to possibly drop a dimension in a load. +// TODO(ajcbik): use a reshaping vector load (and share lowering code) +static Value reshapeLoad(Location loc, Value val, VectorType type, + int64_t index, int64_t pos, + PatternRewriter &rewriter) { + if (index == -1) + return val; + Type lowType = adjustType(type, 0); + // At extraction dimension? + if (index == 0) { + auto posAttr = rewriter.getI64ArrayAttr(pos); + return rewriter.create(loc, lowType, val, posAttr); + } + // Unroll leading dimensions. + VectorType vType = lowType.cast(); + VectorType resType = adjustType(type, index).cast(); + Value result = + rewriter.create(loc, resType, rewriter.getZeroAttr(resType)); + for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) { + auto posAttr = rewriter.getI64ArrayAttr(d); + Value ext = rewriter.create(loc, vType, val, posAttr); + Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); + result = + rewriter.create(loc, resType, load, result, posAttr); + } + return result; +} + +// Helper method to possibly drop a dimension in a store. +// TODO(ajcbik): use a reshaping vector store (and share lowering code) +static Value reshapeStore(Location loc, Value val, Value result, + VectorType type, int64_t index, int64_t pos, + PatternRewriter &rewriter) { + // Unmodified? 
+ if (index == -1) + return val; + // At insertion dimension? + if (index == 0) { + auto posAttr = rewriter.getI64ArrayAttr(pos); + return rewriter.create(loc, type, val, result, posAttr); + } + // Unroll leading dimensions. + Type lowType = adjustType(type, 0); + VectorType vType = lowType.cast(); + Type insType = adjustType(vType, 0); + for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { + auto posAttr = rewriter.getI64ArrayAttr(d); + Value ext = rewriter.create(loc, vType, result, posAttr); + Value ins = rewriter.create(loc, insType, val, posAttr); + Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); + result = rewriter.create(loc, type, sto, result, posAttr); + } + return result; +} + +namespace mlir { + +/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to: +/// ``` +/// %flattened_a = vector.shape_cast %a +/// %flattened_b = vector.shape_cast %b +/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b +/// %d = vector.shape_cast %flattened_d +/// %e = add %c, %d +/// ``` +/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`. +/// +/// This only kicks in when VectorTransformsOptions is set to Matmul and +/// the vector.contract op is a row-major matrix multiply.
+LogicalResult +ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const { + // TODO(ajcbik): implement masks + if (llvm::size(op.masks()) != 0) + return failure(); + + if (vectorTransformsOptions.vectorContractLowering != + vector::VectorContractLowering::Matmul || + !isRowMajorMatmul(op.indexing_maps())) + return failure(); + return success(); +} + +void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const { + VectorType lhsType = op.getLhsType(); + VectorType rhsType = op.getRhsType(); + unsigned lhsRows = op.getLhsType().getShape()[0]; + unsigned lhsColumns = op.getLhsType().getShape()[1]; + unsigned rhsColumns = op.getRhsType().getShape()[1]; + + Type flattenedLHSType = + VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); + Type flattenedRHSType = + VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); + auto lhs = rewriter.create(op.getLoc(), flattenedLHSType, + op.lhs()); + auto rhs = rewriter.create(op.getLoc(), flattenedRHSType, + op.rhs()); + + Value mul = rewriter.create(op.getLoc(), lhs, rhs, lhsRows, + lhsColumns, rhsColumns); + mul = rewriter.create(op.getLoc(), op.acc().getType(), + mul); + Type elementType = op.getLhsType().getElementType(); + assert(elementType.isIntOrFloat()); + if (elementType.isa()) + rewriter.replaceOpWithNewOp(op, op.acc(), mul); + else + rewriter.replaceOpWithNewOp(op, op.acc(), mul); +} + +/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to a reduction_size-unrolled sequence: +/// ``` +/// %at = vector.transpose %a, [1, 0] +/// %bRow0 = vector.extract %b[0] +/// %atRow0 = vector.extract %at[0] +/// %c0 = vector.outerproduct %atRow0, %bRow0, %c +/// ... 
+/// %bRowK = vector.extract %b[K] +/// %atRowK = vector.extract %at[K] +/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 +/// ``` +/// +/// This only kicks in when VectorTransformsOptions is set to OuterProduct and +/// the vector.contract op is a row-major matrix multiply. +void ContractionOpToOuterProductOpLowering::rewrite( + vector::ContractionOp op, PatternRewriter &rewriter) const { + VectorType lhsType = op.getLhsType(); + // TODO(ntv) other modes. + // We know we are in row-major. + bool transposeLhs = false; + unsigned reductionSize = + transposeLhs ? lhsType.getShape()[0] : lhsType.getShape()[1]; + + // If transposeLhs == false (i.e. lhs(m, reductionSize)), we need to + // transpose it to extract the proper vector. Otherwise, just take + // the lhs. + Value lhs = transposeLhs + ? op.lhs() + : rewriter.create( + op.getLoc(), op.lhs(), ArrayRef{1, 0}); + Value res = op.acc(); + // ExtractOp does not allow dynamic indexing, we must unroll explicitly. + for (unsigned k = 0; k < reductionSize; ++k) { + Value a = rewriter.create(op.getLoc(), lhs, k); + Value b = rewriter.create(op.getLoc(), op.rhs(), k); + res = rewriter.create(op.getLoc(), a, b, res); + } + rewriter.replaceOp(op, res); +} + +LogicalResult +ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const { + // TODO(ajcbik): implement masks + if (llvm::size(op.masks()) != 0) + return failure(); + + if (vectorTransformsOptions.vectorContractLowering != + vector::VectorContractLowering::OuterProduct || + !isRowMajorMatmul(op.indexing_maps())) + return failure(); + return success(); +} + +/// Progressive lowering of ContractionOp. +/// One: +/// %x = vector.contract with at least one free/batch dimension +/// is replaced by: +/// %a = vector.contract with one less free/batch dimension +/// %b = vector.contract with one less free/batch dimension +/// .. +/// %x = combine %a %b .. 
+/// until a pure contraction is reached (no free/batch dimensions), +/// which is replaced by a fma/reduction op. +/// +/// TODO(ajcbik): break down into transpose/reshape/cast ops +/// when they become available to avoid code dup +/// TODO(ajcbik): investigate lowering order impact on performance +LogicalResult +ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const { + + // TODO(ajcbik): implement masks. + if (llvm::size(op.masks()) != 0) + return failure(); + + // TODO(ntv, ajcbik): implement benefits, cost models. + MLIRContext *ctx = op.getContext(); + ContractionOpToMatmulOpLowering pat1(vectorTransformsOptions, ctx); + if (succeeded(pat1.match(op))) + return failure(); + ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx); + if (succeeded(pat2.match(op))) + return failure(); + + // Find first batch dimension in LHS/RHS, and lower when found. + std::vector> batchDimMap = op.getBatchDimMap(); + if (!batchDimMap.empty()) { + int64_t lhsIndex = batchDimMap[0].first; + int64_t rhsIndex = batchDimMap[0].second; + rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter)); + return success(); + } + + // Collect contracting dimensions. + std::vector> contractingDimMap = + op.getContractingDimMap(); + DenseSet lhsContractingDimSet; + DenseSet rhsContractingDimSet; + for (auto &dimPair : contractingDimMap) { + lhsContractingDimSet.insert(dimPair.first); + rhsContractingDimSet.insert(dimPair.second); + } + + // Find first free dimension in LHS, and lower when found. + VectorType lhsType = op.getLhsType(); + for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { + if (lhsContractingDimSet.count(lhsIndex) == 0) { + rewriter.replaceOp( + op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter)); + return success(); + } + } + + // Find first free dimension in RHS, and lower when found. 
+ VectorType rhsType = op.getRhsType(); + for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { + if (rhsContractingDimSet.count(rhsIndex) == 0) { + rewriter.replaceOp( + op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter)); + return success(); + } + } + + // Lower the first remaining reduction dimension. + if (!contractingDimMap.empty()) { + rewriter.replaceOp(op, lowerReduction(op, rewriter)); + return success(); + } + + return failure(); +} + +// Lower one parallel dimension. +// TODO(ajcbik): consider reusing existing contract unrolling +Value ContractionOpLowering::lowerParallel(vector::ContractionOp op, + int64_t lhsIndex, int64_t rhsIndex, + PatternRewriter &rewriter) const { + VectorType lhsType = op.getLhsType(); + VectorType rhsType = op.getRhsType(); + VectorType resType = op.getResultType().cast(); + // Find the iterator type index and result index. + SmallVector iMap = op.getIndexingMaps(); + int64_t iterIndex = -1; + int64_t dimSize = -1; + if (lhsIndex >= 0) { + iterIndex = iMap[0].getResult(lhsIndex).cast().getPosition(); + assert( + (rhsIndex < 0 || + iterIndex == + iMap[1].getResult(rhsIndex).cast().getPosition()) && + "parallel index should be free in LHS or batch in LHS/RHS"); + dimSize = lhsType.getDimSize(lhsIndex); + } else { + assert(rhsIndex >= 0 && "missing parallel index"); + iterIndex = iMap[1].getResult(rhsIndex).cast().getPosition(); + dimSize = rhsType.getDimSize(rhsIndex); + } + assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); + Optional lookup = getResultIndex(iMap[2], iterIndex); + assert(lookup.hasValue() && "parallel index not listed in reduction"); + int64_t resIndex = lookup.getValue(); + // Construct new iterator types and affine map array attribute. 
+ SmallVector lowIndexingMaps; + lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter)); + lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter)); + lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter)); + auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); + auto lowIter = + rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); + // Unroll into a series of lower dimensional vector.contract ops. + Location loc = op.getLoc(); + Value result = + rewriter.create(loc, resType, rewriter.getZeroAttr(resType)); + for (int64_t d = 0; d < dimSize; ++d) { + auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); + auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); + auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter); + Value lowContract = rewriter.create( + loc, lhs, rhs, acc, lowAffine, lowIter); + result = + reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter); + } + return result; +} + +// Lower one reduction dimension. +Value ContractionOpLowering::lowerReduction(vector::ContractionOp op, + PatternRewriter &rewriter) const { + auto loc = op.getLoc(); + VectorType lhsType = op.getLhsType(); + VectorType rhsType = op.getRhsType(); + Type resType = op.getResultType(); + assert(!resType.isa()); + // Use iterator index 0. + int64_t iterIndex = 0; + SmallVector iMap = op.getIndexingMaps(); + Optional lookupLhs = getResultIndex(iMap[0], iterIndex); + Optional lookupRhs = getResultIndex(iMap[1], iterIndex); + assert(lookupLhs.hasValue() && "missing LHS parallel index"); + assert(lookupRhs.hasValue() && "missing RHS parallel index"); + int64_t lhsIndex = lookupLhs.getValue(); + int64_t rhsIndex = lookupRhs.getValue(); + int64_t dimSize = lhsType.getDimSize(lhsIndex); + assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape"); + // Base case. 
+ if (lhsType.getRank() == 1) { + assert(rhsType.getRank() == 1 && "corrupt contraction"); + Value zero = rewriter.create(loc, lhsType, + rewriter.getZeroAttr(lhsType)); + Value fma = rewriter.create(loc, op.lhs(), op.rhs(), zero); + StringAttr kind = rewriter.getStringAttr("add"); + return rewriter.create(loc, resType, kind, fma, + op.acc()); + } + // Construct new iterator types and affine map array attribute. + SmallVector lowIndexingMaps; + lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter)); + lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter)); + lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter)); + auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); + auto lowIter = + rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex)); + // Unroll into a series of lower dimensional vector.contract ops. + // By feeding the initial accumulator into the first contraction, + // and the result of each contraction into the next, eventually + // the sum of all reductions is computed. + Value result = op.acc(); + for (int64_t d = 0; d < dimSize; ++d) { + auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); + auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); + result = rewriter.create(loc, lhs, rhs, result, + lowAffine, lowIter); + } + return result; +} + +} // namespace mlir + // TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp). // TODO(andydavis) Add this as DRR pattern. 
void mlir::vector::populateVectorToVectorTransformationPatterns( @@ -1685,6 +1758,8 @@ ShapeCastOp2DDownCastRewritePattern, ShapeCastOp2DUpCastRewritePattern, TransposeOpLowering>(context); + patterns.insert(parameters, context); // clang-format on - patterns.insert(parameters, context); } diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -1,5 +1,6 @@ -// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s -// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX +// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s --dump-input-on-failure +// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX --dump-input-on-failure +// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT --dump-input-on-failure #dotp_accesses = [ affine_map<(i) -> (i)>, @@ -382,6 +383,35 @@ // MATRIX: %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> // MATRIX: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32> // MATRIX: %[[mm6:.*]] = addf %[[C]], %[[mm5]] : vector<2x3xf32> + +// OUTERPRODUCT-LABEL: func @matmul +// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, +// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, +// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// OUTERPRODUCT-SAME: : vector<2x4xf32> to vector<4x2xf32> +// +// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<4x2xf32> +// OUTERPRODUCT: %[[b0:.*]] = vector.extract 
%[[B]][0] : vector<4x3xf32> +// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> +// +// OUTERPRODUCT: %[[a1:.*]] = vector.extract %[[At]][1] : vector<4x2xf32> +// OUTERPRODUCT: %[[b1:.*]] = vector.extract %[[B]][1] : vector<4x3xf32> +// OUTERPRODUCT: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]] +// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> +// +// OUTERPRODUCT: %[[a2:.*]] = vector.extract %[[At]][2] : vector<4x2xf32> +// OUTERPRODUCT: %[[b2:.*]] = vector.extract %[[B]][2] : vector<4x3xf32> +// OUTERPRODUCT: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]] +// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> +// +// OUTERPRODUCT: %[[a3:.*]] = vector.extract %[[At]][3] : vector<4x2xf32> +// OUTERPRODUCT: %[[b3:.*]] = vector.extract %[[B]][3] : vector<4x3xf32> +// OUTERPRODUCT: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]] +// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> +// +// OUTERPRODUCT: return %[[c3]] : vector<2x3xf32> func @matmul(%arg0: vector<2x4xf32>, %arg1: vector<4x3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> { diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -51,11 +51,26 @@ *this, "vector-lower-matrix-intrinsics", llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), llvm::cl::init(false)}; + Option lowerToOuterProduct{ + *this, "vector-outerproduct", + llvm::cl::desc("Lower vector.contract to vector.outerproduct"), + llvm::cl::init(false)}; void runOnFunction() override { OwningRewritePatternList patterns; - VectorTransformsOptions options{ - /*lowerToLLVMMatrixIntrinsics=*/lowerToLLVMMatrixIntrinsics}; + if (lowerToOuterProduct) { + VectorContractLowering lowering = VectorContractLowering::OuterProduct; + 
VectorTransformsOptions options{lowering}; + patterns.insert(options, + &getContext()); + applyPatternsAndFoldGreedily(getFunction(), patterns); + return; + } + + VectorContractLowering lowering = VectorContractLowering::FMA; + if (lowerToLLVMMatrixIntrinsics) + lowering = VectorContractLowering::Matmul; + VectorTransformsOptions options{lowering}; populateVectorContractLoweringPatterns(patterns, &getContext(), options); applyPatternsAndFoldGreedily(getFunction(), patterns); }