diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp --- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/Dialect/VectorOps/VectorTransforms.h" #include "mlir/Dialect/VectorOps/VectorUtils.h" @@ -864,6 +865,19 @@ }; /// Progressive lowering of ConstractionOp. +/// One: +/// %x = vector.contract with at least one free/batch dimension +/// is replaced by: +/// %a = vector.contract with one less free/batch dimension +/// %b = vector.contract with one less free/batch dimension +/// .. +/// %x = combine %a %b .. +/// until a pure contraction is reached (no free/batch dimensions), +/// which is replaced by a fma/reduction op. +/// +/// TODO(ajcbik): break down into transpose/reshape/cast ops +/// when they become available to avoid code dup +/// TODO(ajcbik): investigate lowering order impact on performance class ContractionOpLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -874,16 +888,13 @@ if (llvm::size(op.masks()) != 0) return matchFailure(); - auto loc = op.getLoc(); - VectorType lhsType = op.getLhsType(); - VectorType rhsType = op.getRhsType(); - Type resType = op.getResultType(); - - // Find first batch dimension in lhs/rhs, and lower when found. + // Find first batch dimension in LHS/RHS, and lower when found. std::vector> batchDimMap = op.getBatchDimMap(); if (!batchDimMap.empty()) { - // TODO(ajcbik): implement batch - return matchFailure(); + int64_t lhsIndex = batchDimMap[0].first; + int64_t rhsIndex = batchDimMap[0].second; + rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter)); + return matchSuccess(); } // Collect contracting dimensions. @@ -896,24 +907,35 @@ rhsContractingDimSet.insert(dimPair.second); } - // Find free dimension in lhs/rhs, and lower first when found. - for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) { - if (lhsContractingDimSet.count(i) == 0) { - // TODO(ajcbik): implement free - return matchFailure(); + // Find first free dimension in LHS, and lower when found. + VectorType lhsType = op.getLhsType(); + for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; + ++lhsIndex) { + if (lhsContractingDimSet.count(lhsIndex) == 0) { + rewriter.replaceOp( + op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter)); + return matchSuccess(); } } - for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) { - if (rhsContractingDimSet.count(i) == 0) { - // TODO(ajcbik): implement free - return matchFailure(); + + // Find first free dimension in RHS, and lower when found. + VectorType rhsType = op.getRhsType(); + for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; + ++rhsIndex) { + if (rhsContractingDimSet.count(rhsIndex) == 0) { + rewriter.replaceOp( + op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter)); + return matchSuccess(); } } - // Only contraction dimensions remain. + // Lower the only remaining contraction dimensions. + // TODO(ajcbik): handle multi-dim reductions + auto loc = op.getLoc(); + Type resType = op.getResultType(); if (!resType.isa() && lhsType.getRank() == 1 && rhsType.getRank() == 1) { - // Handle reduction into scalar. + Value zero = rewriter.create(loc, resType, rewriter.getZeroAttr(resType)); Value splat = rewriter.create(loc, lhsType, zero); @@ -924,9 +946,191 @@ op.acc()); return matchSuccess(); } - // TODO(ajcbik): implement more contraction + return matchFailure(); } + +private: + // Lower one parallel dimension. + // TODO(ajcbik): consider reusing existing contract unrolling + Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex, + int64_t rhsIndex, PatternRewriter &rewriter) const { + VectorType lhsType = op.getLhsType(); + VectorType rhsType = op.getRhsType(); + VectorType resType = op.getResultType().cast(); + // Find the iterator type index and result index. + SmallVector iMap = op.getIndexingMaps(); + int64_t iterIndex = -1; + int64_t dimSize = -1; + if (lhsIndex >= 0) { + iterIndex = + iMap[0].getResult(lhsIndex).cast().getPosition(); + assert((rhsIndex < 0 || iterIndex == iMap[1] + .getResult(rhsIndex) + .cast() + .getPosition()) && + "parallel index should be free in LHS or batch in LHS/RHS"); + dimSize = lhsType.getDimSize(lhsIndex); + } else { + assert(rhsIndex >= 0 && "missing parallel index"); + iterIndex = + iMap[1].getResult(rhsIndex).cast().getPosition(); + dimSize = rhsType.getDimSize(rhsIndex); + } + assert(iterIndex >= 0 && "parallel index not listed in operand mapping"); + Optional lookup = getResultIndex(iMap[2], iterIndex); + assert(lookup.hasValue() && "parallel index not listed in reduction"); + int64_t resIndex = lookup.getValue(); + // Construct new iterator types. + ArrayAttr iteratorTypes = op.iterator_types(); + SmallVector lowIterTypes; + for (auto it : llvm::enumerate(iteratorTypes)) { + int64_t idx = it.index(); + if (idx == iterIndex) { + assert(it.value().cast().getValue() == + getParallelIteratorTypeName() && + "parallel index not marked as such"); + continue; + } + lowIterTypes.push_back(it.value()); + } + // Construct new affine map array attribute. + SmallVector lowIndexingMaps; + lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter)); + lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter)); + lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter)); + auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); + // Construct new iterator types array attribute. + auto lowIter = rewriter.getArrayAttr(lowIterTypes); + // Unroll into a series of lower dimensional vector.contract ops. + Location loc = op.getLoc(); + Value result = zeroVector(loc, resType, rewriter); + for (int64_t d = 0; d < dimSize; ++d) { + auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter); + auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter); + auto acc = reshapeLoad(loc, op.acc(), resType, resIndex, d, rewriter); + Value lowContract = rewriter.create( + loc, lhs, rhs, acc, lowAffine, lowIter); + result = reshapeStore(loc, lowContract, result, resType, resIndex, d, + rewriter); + } + return result; + } + + // Helper method to construct a zero vector. + static Value zeroVector(Location loc, VectorType vType, + PatternRewriter &rewriter) { + Type eltType = vType.getElementType(); + Value zero = rewriter.create(loc, eltType, + rewriter.getZeroAttr(eltType)); + return rewriter.create(loc, vType, zero); + } + + // Helper to find an index in an affine map. + static Optional getResultIndex(AffineMap map, int64_t index) { + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + int64_t idx = map.getResult(i).cast().getPosition(); + if (idx == index) + return i; + } + return None; + } + + // Helper to construct an affine map with one index removed. + static AffineMap adjustMap(AffineMap map, int64_t index, + PatternRewriter &rewriter) { + SmallVector results; + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + int64_t idx = map.getResult(i).cast().getPosition(); + if (idx == index) + continue; + // Re-insert remaining indices, but renamed when occurring + // after the removed index. + auto targetExpr = + getAffineDimExpr(idx < index ? idx : idx - 1, rewriter.getContext()); + results.push_back(targetExpr); + } + // Since (...) -> () cannot be represented properly, + // we resort to an empty map when this situation happens. + return results.empty() ? AffineMap::get(rewriter.getContext()) + : AffineMap::get(map.getNumDims() - 1, 0, results); + } + + // Helper to drop dimension from vector type. + static Type adjustType(VectorType tp, int64_t index) { + int64_t rank = tp.getRank(); + Type eltType = tp.getElementType(); + if (rank == 1) { + assert(index == 0 && "index for scalar result out of bounds"); + return eltType; + } + SmallVector adjustedShape; + for (int64_t i = 0; i < rank; ++i) { + // Omit dimension at the given index. + if (i == index) + continue; + // Otherwise, add dimension back. + adjustedShape.push_back(tp.getDimSize(i)); + } + return VectorType::get(adjustedShape, eltType); + } + + // Helper method to possibly drop a dimension in a load. + // TODO(ajcbik): use a reshaping vector load (and share lowering code) + static Value reshapeLoad(Location loc, Value val, VectorType type, + int64_t index, int64_t pos, + PatternRewriter &rewriter) { + if (index == -1) + return val; + Type lowType = adjustType(type, 0); + // At extraction dimension? + if (index == 0) { + auto posAttr = rewriter.getI64ArrayAttr(pos); + return rewriter.create(loc, lowType, val, posAttr); + } + // Unroll leading dimensions. + VectorType vType = lowType.cast(); + VectorType resType = adjustType(type, index).cast(); + Value result = zeroVector(loc, resType, rewriter); + for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) { + auto posAttr = rewriter.getI64ArrayAttr(d); + Value ext = rewriter.create(loc, vType, val, posAttr); + Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); + result = rewriter.create(loc, resType, load, result, + posAttr); + } + return result; + } + + // Helper method to possibly drop a dimension in a store. + // TODO(ajcbik): use a reshaping vector store (and share lowering code) + static Value reshapeStore(Location loc, Value val, Value result, + VectorType type, int64_t index, int64_t pos, + PatternRewriter &rewriter) { + // Unmodified? + if (index == -1) + return val; + // At insertion dimension? + if (index == 0) { + auto posAttr = rewriter.getI64ArrayAttr(pos); + return rewriter.create(loc, type, val, result, posAttr); + } + // Unroll leading dimensions. + Type lowType = adjustType(type, 0); + VectorType vType = lowType.cast(); + Type insType = adjustType(vType, 0); + for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { + auto posAttr = rewriter.getI64ArrayAttr(d); + Value ext = + rewriter.create(loc, vType, result, posAttr); + Value ins = + rewriter.create(loc, insType, val, posAttr); + Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); + result = + rewriter.create(loc, type, sto, result, posAttr); + } + return result; + } }; } // namespace diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir --- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir @@ -14,7 +14,7 @@ // CHECK-SAME: %[[A:.*0]]: vector<4xf32>, // CHECK-SAME: %[[B:.*1]]: vector<4xf32>, // CHECK-SAME: %[[C:.*2]]: f32 -// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> +// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<4xf32> // CHECK: %[[F:.*]] = vector.fma %[[A]], %[[B]], %[[Z]] : vector<4xf32> // CHECK: %[[R:.*]] = vector.reductionv2 "add", %[[F]], %[[C]] // CHECK: return %[[R]] : f32 @@ -24,3 +24,148 @@ : vector<4xf32>, vector<4xf32> into f32 return %0 : f32 } + +#matvec_accesses = [ + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (j)>, + affine_map<(i, j) -> (i)> +] +#matvec_trait = { + indexing_maps = #matvec_accesses, + iterator_types = ["parallel", "reduction"] +} + +// CHECK-LABEL: func @extract_contract2 +// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<3xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<2xf32> +// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32> +// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> +// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32> +// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[B]], %[[Z]] : vector<3xf32> +// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32 +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32> +// CHECK: %[[T7:.*]] = vector.fma %[[T5]], %[[B]], %[[Z]] : vector<3xf32> +// CHECK: %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32 +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> +// CHECK: return %[[T9]] : vector<2xf32> + +func @extract_contract2(%arg0: vector<2x3xf32>, + %arg1: vector<3xf32>, + %arg2: vector<2xf32>) -> vector<2xf32> { + %0 = vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xf32>, vector<3xf32> into vector<2xf32> + return %0 : vector<2xf32> +} + +#vecmat_accesses = [ + affine_map<(i, j) -> (j)>, + affine_map<(i, j) -> (i, j)>, + affine_map<(i, j) -> (i)> +] +#vecmat_trait = { + indexing_maps = #vecmat_accesses, + iterator_types = ["parallel", "reduction"] +} + +// CHECK-LABEL: func @extract_contract3 +// CHECK-SAME: %[[A:.*0]]: vector<3xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<2xf32> +// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32> +// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> +// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32> +// CHECK: %[[T2:.*]] = vector.fma %[[A]], %[[T0]], %[[Z]] : vector<3xf32> +// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[T1]] : vector<3xf32>, f32 into f32 +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32> +// CHECK: %[[T7:.*]] = vector.fma %[[A]], %[[T5]], %[[Z]] : vector<3xf32> +// CHECK: %[[T8:.*]] = vector.reductionv2 "add", %[[T7]], %[[T6]] : vector<3xf32>, f32 into f32 +// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> +// CHECK: return %[[T9]] : vector<2xf32> + +func @extract_contract3(%arg0: vector<3xf32>, + %arg1: vector<2x3xf32>, + %arg2: vector<2xf32>) -> vector<2xf32> { + %0 = vector.contract #vecmat_trait %arg0, %arg1, %arg2 + : vector<3xf32>, vector<2x3xf32> into vector<2xf32> + return %0 : vector<2xf32> +} + +#matmat_accesses = [ + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)> +] +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func @extract_contract4 +// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>, +// CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>, +// CHECK-SAME: %[[C:.*2]]: vector<2x2xf32> +// CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2x2xf32> +// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<2xf32> +// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32> +// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2x2xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[B]][0] : vector<2x2xf32> +// CHECK: %[[T3:.*]] = vector.extract %[[T2]][0] : vector<2xf32> +// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[Z]] [0] : f32 into vector<2xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x2xf32> +// CHECK: %[[T6:.*]] = vector.extract %[[T5]][0] : vector<2xf32> +// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T4]] [1] : f32 into vector<2xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[T1]][0] : vector<2xf32> +// CHECK: %[[T9:.*]] = vector.fma %[[T0]], %[[T7]], %[[Z]] : vector<2xf32> +// CHECK: %[[T10:.*]] = vector.reductionv2 "add", %[[T9]], %[[T8]] : vector<2xf32>, f32 into f32 +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32> +// CHECK: %[[T12:.*]] = vector.extract %[[B]][0] : vector<2x2xf32> +// CHECK: %[[T13:.*]] = vector.extract %[[T12]][1] : vector<2xf32> +// CHECK: %[[T14:.*]] = vector.insert %[[T13]], %[[Z]] [0] : f32 into vector<2xf32> +// CHECK: %[[T15:.*]] = vector.extract %[[B]][1] : vector<2x2xf32> +// CHECK: %[[T16:.*]] = vector.extract %[[T15]][1] : vector<2xf32> +// CHECK: %[[T17:.*]] = vector.insert %[[T16]], %[[T14]] [1] : f32 into vector<2xf32> +// CHECK: %[[T18:.*]] = vector.extract %[[T1]][1] : vector<2xf32> +// CHECK: %[[T19:.*]] = vector.fma %[[T0]], %[[T17]], %[[Z]] : vector<2xf32> +// CHECK: %[[T20:.*]] = vector.reductionv2 "add", %[[T19]], %[[T18]] : vector<2xf32>, f32 into f32 +// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32> +// CHECK: %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32> +// CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32> +// CHECK: %[[T24:.*]] = vector.extract %[[C]][1] : vector<2x2xf32> +// CHECK: %[[T25:.*]] = vector.extract %[[B]][0] : vector<2x2xf32> +// CHECK: %[[T26:.*]] = vector.extract %[[T25]][0] : vector<2xf32> +// CHECK: %[[T27:.*]] = vector.insert %[[T26]], %[[Z]] [0] : f32 into vector<2xf32> +// CHECK: %[[T28:.*]] = vector.extract %[[B]][1] : vector<2x2xf32> +// CHECK: %[[T29:.*]] = vector.extract %[[T28]][0] : vector<2xf32> +// CHECK: %[[T30:.*]] = vector.insert %[[T29]], %[[T27]] [1] : f32 into vector<2xf32> +// CHECK: %[[T31:.*]] = vector.extract %[[T24]][0] : vector<2xf32> +// CHECK: %[[T32:.*]] = vector.fma %[[T23]], %[[T30]], %[[Z]] : vector<2xf32> +// CHECK: %[[T33:.*]] = vector.reductionv2 "add", %[[T32]], %[[T31]] : vector<2xf32>, f32 into f32 +// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32> +// CHECK: %[[T35:.*]] = vector.extract %[[B]][0] : vector<2x2xf32> +// CHECK: %[[T36:.*]] = vector.extract %[[T35]][1] : vector<2xf32> +// CHECK: %[[T37:.*]] = vector.insert %[[T36]], %[[Z]] [0] : f32 into vector<2xf32> +// CHECK: %[[T38:.*]] = vector.extract %[[B]][1] : vector<2x2xf32> +// CHECK: %[[T39:.*]] = vector.extract %[[T38]][1] : vector<2xf32> +// CHECK: %[[T40:.*]] = vector.insert %[[T39]], %[[T37]] [1] : f32 into vector<2xf32> +// CHECK: %[[T41:.*]] = vector.extract %[[T24]][1] : vector<2xf32> +// CHECK: %[[T42:.*]] = vector.fma %[[T23]], %[[T40]], %[[Z]] : vector<2xf32> +// CHECK: %[[T43:.*]] = vector.reductionv2 "add", %[[T42]], %[[T41]] : vector<2xf32>, f32 into f32 +// CHECK: %[[T44:.*]] = vector.insert %[[T43]], %[[T34]] [1] : f32 into vector<2xf32> +// CHECK: %[[T45:.*]] = vector.insert %[[T44]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32> +// CHECK: return %[[T45]] : vector<2x2xf32> + +func @extract_contract4(%arg0: vector<2x2xf32>, + %arg1: vector<2x2xf32>, + %arg2: vector<2x2xf32>) -> vector<2x2xf32> { + %0 = vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> + return %0 : vector<2x2xf32> +} +