diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -44,6 +44,7 @@ } // namespace arith namespace vector { +class ContractionOp; class TransferReadOp; class TransferWriteOp; class VectorDialect; @@ -76,6 +77,11 @@ void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Cast away the leading unit dim of `contractOp`'s acc. Return success if +/// applied, failure otherwise. Ops are created at `rewriter`'s insertion point. +LogicalResult castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, + RewriterBase &rewriter); + /// Collect a set of leading one dimension removal patterns. /// /// These patterns insert vector.shape_cast to remove leading one dimensions diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -279,6 +279,121 @@ } }; +} // namespace + +LogicalResult +mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, + RewriterBase &rewriter) { + VectorType oldAccType = contractOp.getAccType().dyn_cast(); + if (oldAccType == nullptr) + return failure(); + if (oldAccType.getRank() < 2) + return failure(); + if (oldAccType.getShape()[0] != 1) + return failure(); + // Currently we support only dropping one dim, but the pattern can be applied + // greedily to drop more. 
+ int64_t dropDim = 1; + + auto oldIndexingMaps = contractOp.getIndexingMapsArray(); + SmallVector newIndexingMaps; + + auto oldIteratorTypes = contractOp.getIteratorTypes(); + SmallVector newIteratorTypes; + // The loop dim indexing the acc's leading (unit) result is the one dropped. + int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0); + + if (!isParallelIterator(oldIteratorTypes[dimToDrop])) + // Only parallel type iterators can be dropped. + return failure(); + + for (const auto &it : llvm::enumerate(oldIteratorTypes)) { + int64_t currDim = it.index(); + if (currDim == dimToDrop) + continue; + newIteratorTypes.push_back(it.value()); + } + // Operands in indexing-map order: lhs, rhs, acc. + SmallVector operands = {contractOp.getLhs(), contractOp.getRhs(), + contractOp.getAcc()}; + SmallVector newOperands; + + for (const auto &it : llvm::enumerate(oldIndexingMaps)) { + // Check if the dim to be dropped exists as a leading dim in the operand; + // if it does, then we use vector.extract to drop it. + bool validExtract = false; + SmallVector results; + auto map = it.value(); + int64_t orginalZeroDim = it.value().getDimPosition(0); + if (orginalZeroDim != dimToDrop) { + // There are two reasons to be in this path: 1. We need to + // transpose the operand to make the dim to be dropped + // leading. 2. The dim to be dropped does not exist, and in + // that case we don't want to add a unit transpose, but we must + // check all the indices to make sure this is the case. + bool tranposeNeeded = false; + SmallVector perm; + SmallVector transposeResults; + // Build a permutation moving dimToDrop (if present) to the front. + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + int64_t currDim = map.getDimPosition(i); + if (currDim == dimToDrop) { + tranposeNeeded = true; + perm.insert(perm.begin(), i); + auto targetExpr = rewriter.getAffineDimExpr(currDim); + transposeResults.insert(transposeResults.begin(), targetExpr); + } else { + perm.push_back(i); + auto targetExpr = rewriter.getAffineDimExpr(currDim); + transposeResults.push_back(targetExpr); + } + } + // Do the transpose now if needed so that we can drop the + // correct dim using extract later. 
+ if (tranposeNeeded) { + map = AffineMap::get(map.getNumDims(), 0, transposeResults, + contractOp.getContext()); + operands[it.index()] = rewriter.create( + contractOp.getLoc(), operands[it.index()], perm); + } + } + // We have taken care to have the dim to be dropped be + // the leading dim. If it's still not leading, that means it + // does not exist in this operand and hence we do not need + // an extract. + if (map.getDimPosition(0) == dimToDrop) + validExtract = true; + // Rebuild the map without dimToDrop, shifting higher dims down by one. + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + int64_t currDim = map.getDimPosition(i); + if (currDim == dimToDrop) + // This is the dim we are dropping. + continue; + auto targetExpr = rewriter.getAffineDimExpr( + currDim < dimToDrop ? currDim : currDim - 1); + results.push_back(targetExpr); + } + newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results, + contractOp.getContext())); + // Extract if it's a valid extraction; otherwise use the operand + // without extraction. + newOperands.push_back( + validExtract ? rewriter.create(contractOp.getLoc(), + operands[it.index()], + splatZero(dropDim)) + : operands[it.index()]); + } + auto newContractOp = rewriter.create( + contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2], + rewriter.getAffineMapArrayAttr(newIndexingMaps), + rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind()); + rewriter.replaceOpWithNewOp( + contractOp, contractOp->getResultTypes()[0], newContractOp); + return success(); +} + +namespace { + /// Turns vector.contract on vector with leading 1 dimensions into /// vector.extract followed by vector.contract on vector without leading /// 1 dimensions. 
Also performs tranpose of lhs and rhs operands if required @@ -289,112 +404,7 @@ LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { - VectorType oldAccType = contractOp.getAccType().dyn_cast(); - if (oldAccType == nullptr) - return failure(); - if (oldAccType.getRank() < 2) - return failure(); - if (oldAccType.getShape()[0] != 1) - return failure(); - // currently we support only dropping one dim but the pattern can be applied - // greedily to drop more. - int64_t dropDim = 1; - - auto oldIndexingMaps = contractOp.getIndexingMapsArray(); - SmallVector newIndexingMaps; - - auto oldIteratorTypes = contractOp.getIteratorTypes(); - SmallVector newIteratorTypes; - - int64_t dimToDrop = oldIndexingMaps[2].getDimPosition(0); - - if (!isParallelIterator(oldIteratorTypes[dimToDrop])) - // only parallel type iterators can be dropped. - return failure(); - - for (const auto &it : llvm::enumerate(oldIteratorTypes)) { - int64_t currDim = it.index(); - if (currDim == dimToDrop) - continue; - newIteratorTypes.push_back(it.value()); - } - - SmallVector operands = {contractOp.getLhs(), contractOp.getRhs(), - contractOp.getAcc()}; - SmallVector newOperands; - - for (const auto &it : llvm::enumerate(oldIndexingMaps)) { - // Check if the dim to be dropped exists as a leading dim in the operand - // if it does then we use vector.extract to drop it. - bool validExtract = false; - SmallVector results; - auto map = it.value(); - int64_t orginalZeroDim = it.value().getDimPosition(0); - if (orginalZeroDim != dimToDrop) { - // There are two reasons to be in this path, 1. We need to - // tranpose the operand to make the dim to be dropped - // leading. 2. The dim to be dropped does not exist and in - // that case we dont want to add a unit tranpose but we must - // check all the indices to make sure this is the case. 
- bool tranposeNeeded = false; - SmallVector perm; - SmallVector transposeResults; - - for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { - int64_t currDim = map.getDimPosition(i); - if (currDim == dimToDrop) { - tranposeNeeded = true; - perm.insert(perm.begin(), i); - auto targetExpr = rewriter.getAffineDimExpr(currDim); - transposeResults.insert(transposeResults.begin(), targetExpr); - } else { - perm.push_back(i); - auto targetExpr = rewriter.getAffineDimExpr(currDim); - transposeResults.push_back(targetExpr); - } - } - // Do the tranpose now if needed so that we can drop the - // correct dim using extract later. - if (tranposeNeeded) { - map = AffineMap::get(map.getNumDims(), 0, transposeResults, - contractOp.getContext()); - operands[it.index()] = rewriter.create( - contractOp.getLoc(), operands[it.index()], perm); - } - } - // We have taken care to have the dim to be dropped be - // the leading dim. If its still not leading that means it - // does not exist in this operand and hence we do not need - // an extract. - if (map.getDimPosition(0) == dimToDrop) - validExtract = true; - - for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { - int64_t currDim = map.getDimPosition(i); - if (currDim == dimToDrop) - // This is the dim we are dropping. - continue; - auto targetExpr = rewriter.getAffineDimExpr( - currDim < dimToDrop ? currDim : currDim - 1); - results.push_back(targetExpr); - } - newIndexingMaps.push_back(AffineMap::get(map.getNumDims() - 1, 0, results, - contractOp.getContext())); - // Extract if its a valid extraction, otherwise use the operand - // without extraction. - newOperands.push_back(validExtract - ? 
rewriter.create( - contractOp.getLoc(), operands[it.index()], - splatZero(dropDim)) - : operands[it.index()]); - } - auto newContractOp = rewriter.create( - contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2], - rewriter.getAffineMapArrayAttr(newIndexingMaps), - rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind()); - rewriter.replaceOpWithNewOp( - contractOp, contractOp->getResultTypes()[0], newContractOp); - return success(); + return castAwayContractionLeadingOneDim(contractOp, rewriter); } };