diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -929,21 +929,9 @@
       }
     }
 
-    // Lower the only remaining contraction dimensions.
-    // TODO(ajcbik): handle multi-dim reductions
-    auto loc = op.getLoc();
-    Type resType = op.getResultType();
-    if (!resType.isa<VectorType>() && lhsType.getRank() == 1 &&
-        rhsType.getRank() == 1) {
-
-      Value zero = rewriter.create<ConstantOp>(loc, resType,
-                                               rewriter.getZeroAttr(resType));
-      Value splat = rewriter.create<SplatOp>(loc, lhsType, zero);
-      Value fma =
-          rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), splat);
-      StringAttr kind = rewriter.getStringAttr("add");
-      rewriter.replaceOpWithNewOp<vector::ReductionV2Op>(op, resType, kind, fma,
-                                                         op.acc());
+    // Lower the first remaining reduction dimension.
+    if (!contractingDimMap.empty()) {
+      rewriter.replaceOp(op, lowerReduction(op, rewriter));
       return matchSuccess();
     }
 
@@ -981,27 +969,14 @@
     Optional<int64_t> lookup = getResultIndex(iMap[2], iterIndex);
     assert(lookup.hasValue() && "parallel index not listed in reduction");
     int64_t resIndex = lookup.getValue();
-    // Construct new iterator types.
-    ArrayAttr iteratorTypes = op.iterator_types();
-    SmallVector<Attribute, 4> lowIterTypes;
-    for (auto it : llvm::enumerate(iteratorTypes)) {
-      int64_t idx = it.index();
-      if (idx == iterIndex) {
-        assert(it.value().cast<StringAttr>().getValue() ==
-                   getParallelIteratorTypeName() &&
-               "parallel index not marked as such");
-        continue;
-      }
-      lowIterTypes.push_back(it.value());
-    }
-    // Construct new affine map array attribute.
+    // Construct new iterator types and affine map array attribute.
     SmallVector<AffineMap, 4> lowIndexingMaps;
     lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
     lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
     lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
     auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
-    // Construct new iterator types array attribute.
-    auto lowIter = rewriter.getArrayAttr(lowIterTypes);
+    auto lowIter =
+        rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
     // Unroll into a series of lower dimensional vector.contract ops.
     Location loc = op.getLoc();
     Value result = zeroVector(loc, resType, rewriter);
@@ -1017,6 +992,56 @@
     return result;
   }
 
+  // Lower one reduction dimension.
+  Value lowerReduction(vector::ContractionOp op,
+                       PatternRewriter &rewriter) const {
+    auto loc = op.getLoc();
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    Type resType = op.getResultType();
+    assert(!resType.isa<VectorType>());
+    // Use iterator index 0: a scalar result leaves no parallel dims.
+    int64_t iterIndex = 0;
+    SmallVector<AffineMap, 4> iMap = op.getIndexingMaps();
+    Optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
+    Optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
+    assert(lookupLhs.hasValue() && "missing LHS reduction index");
+    assert(lookupRhs.hasValue() && "missing RHS reduction index");
+    int64_t lhsIndex = lookupLhs.getValue();
+    int64_t rhsIndex = lookupRhs.getValue();
+    int64_t dimSize = lhsType.getDimSize(lhsIndex);
+    assert(dimSize == rhsType.getDimSize(rhsIndex) && "corrupt shape");
+    // Base case.
+    if (lhsType.getRank() == 1) {
+      assert(rhsType.getRank() == 1 && "corrupt contraction");
+      Value zero = zeroVector(loc, lhsType, rewriter);
+      Value fma = rewriter.create<vector::FMAOp>(loc, op.lhs(), op.rhs(), zero);
+      StringAttr kind = rewriter.getStringAttr("add");
+      return rewriter.create<vector::ReductionV2Op>(loc, resType, kind, fma,
+                                                    op.acc());
+    }
+    // Construct new iterator types and affine map array attribute.
+    SmallVector<AffineMap, 4> lowIndexingMaps;
+    lowIndexingMaps.push_back(adjustMap(iMap[0], iterIndex, rewriter));
+    lowIndexingMaps.push_back(adjustMap(iMap[1], iterIndex, rewriter));
+    lowIndexingMaps.push_back(adjustMap(iMap[2], iterIndex, rewriter));
+    auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
+    auto lowIter =
+        rewriter.getArrayAttr(adjustIter(op.iterator_types(), iterIndex));
+    // Unroll into a series of lower dimensional vector.contract ops.
+    // By feeding the initial accumulator into the first contraction,
+    // and the result of each contraction into the next, eventually
+    // the sum of all reductions is computed.
+    Value result = op.acc();
+    for (int64_t d = 0; d < dimSize; ++d) {
+      auto lhs = reshapeLoad(loc, op.lhs(), lhsType, lhsIndex, d, rewriter);
+      auto rhs = reshapeLoad(loc, op.rhs(), rhsType, rhsIndex, d, rewriter);
+      result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
+                                                      lowAffine, lowIter);
+    }
+    return result;
+  }
+
   // Helper method to construct a zero vector.
   static Value zeroVector(Location loc, VectorType vType,
                           PatternRewriter &rewriter) {
@@ -1036,6 +1061,20 @@
     return None;
   }
 
+  // Helper to construct iterator types with one index removed.
+  static SmallVector<Attribute, 4> adjustIter(ArrayAttr iteratorTypes,
+                                              int64_t index) {
+    SmallVector<Attribute, 4> results;
+    for (auto it : llvm::enumerate(iteratorTypes)) {
+      int64_t idx = it.index();
+      if (idx == index) {
+        continue;
+      }
+      results.push_back(it.value());
+    }
+    return results;
+  }
+
   // Helper to construct an affine map with one index removed.
   static AffineMap adjustMap(AffineMap map, int64_t index,
                              PatternRewriter &rewriter) {
diff --git a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-contract-transforms.mlir
@@ -169,3 +169,84 @@
   return %0 : vector<2x2xf32>
 }
 
+#contraction2d_accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> ()>
+]
+#contraction2d_trait = {
+  indexing_maps = #contraction2d_accesses,
+  iterator_types = ["reduction", "reduction"]
+}
+
+// CHECK-LABEL: func @full_contract1
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<2x3xf32>,
+// CHECK-SAME: %[[C:.*2]]: f32
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<2x3xf32>
+// CHECK: %[[T2:.*]] = vector.fma %[[T0]], %[[T1]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.reductionv2 "add", %[[T2]], %[[C]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32>
+// CHECK: %[[T6:.*]] = vector.fma %[[T4]], %[[T5]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T7:.*]] = vector.reductionv2 "add", %[[T6]], %[[T3]] : vector<3xf32>, f32 into f32
+// CHECK: return %[[T7]] : f32
+
+func @full_contract1(%arg0: vector<2x3xf32>,
+                     %arg1: vector<2x3xf32>,
+                     %arg2: f32) -> f32 {
+  %0 = vector.contract #contraction2d_trait %arg0, %arg1, %arg2
+    : vector<2x3xf32>, vector<2x3xf32> into f32
+  return %0 : f32
+}
+
+#contraction2d_trans_accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (j, i)>,
+  affine_map<(i, j) -> ()>
+]
+#contraction2d_trans_trait = {
+  indexing_maps = #contraction2d_trans_accesses,
+  iterator_types = ["reduction", "reduction"]
+}
+
+// CHECK-LABEL: func @full_contract2
+// CHECK-SAME: %[[A:.*0]]: vector<2x3xf32>,
+// CHECK-SAME: %[[B:.*1]]: vector<3x2xf32>,
+// CHECK-SAME: %[[C:.*2]]: f32
+// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<3xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32>
+// CHECK: %[[T1:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[Z]] [0] : f32 into vector<3xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][1] : vector<3x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[T4]][0] : vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T3]] [1] : f32 into vector<3xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[T7]][0] : vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T6]] [2] : f32 into vector<3xf32>
+// CHECK: %[[T10:.*]] = vector.fma %[[T0]], %[[T9]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T11:.*]] = vector.reductionv2 "add", %[[T10]], %[[C]] : vector<3xf32>, f32 into f32
+// CHECK: %[[T12:.*]] = vector.extract %[[A]][1] : vector<2x3xf32>
+// CHECK: %[[T13:.*]] = vector.extract %[[B]][0] : vector<3x2xf32>
+// CHECK: %[[T14:.*]] = vector.extract %[[T13]][1] : vector<2xf32>
+// CHECK: %[[T15:.*]] = vector.insert %[[T14]], %[[Z]] [0] : f32 into vector<3xf32>
+// CHECK: %[[T16:.*]] = vector.extract %[[B]][1] : vector<3x2xf32>
+// CHECK: %[[T17:.*]] = vector.extract %[[T16]][1] : vector<2xf32>
+// CHECK: %[[T18:.*]] = vector.insert %[[T17]], %[[T15]] [1] : f32 into vector<3xf32>
+// CHECK: %[[T19:.*]] = vector.extract %[[B]][2] : vector<3x2xf32>
+// CHECK: %[[T20:.*]] = vector.extract %[[T19]][1] : vector<2xf32>
+// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T18]] [2] : f32 into vector<3xf32>
+// CHECK: %[[T22:.*]] = vector.fma %[[T12]], %[[T21]], %[[Z]] : vector<3xf32>
+// CHECK: %[[T23:.*]] = vector.reductionv2 "add", %[[T22]], %[[T11]] : vector<3xf32>, f32 into f32
+// CHECK: return %[[T23]] : f32
+
+func @full_contract2(%arg0: vector<2x3xf32>,
+                     %arg1: vector<3x2xf32>,
+                     %arg2: f32) -> f32 {
+  %0 = vector.contract #contraction2d_trans_trait %arg0, %arg1, %arg2
+    : vector<2x3xf32>, vector<3x2xf32> into f32
+  return %0 : f32
+}