diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -113,6 +113,22 @@ bool isMinorIdentityWithBroadcasting( SmallVectorImpl *broadcastedDims = nullptr) const; + /// Return true if this affine map can be converted to a minor identity with + /// broadcast by doing a permute. Return a permutation (there may be + /// several) to apply to get to a minor identity with broadcasts. + /// Ex: + /// * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with + /// perm = [1, 0] and broadcast d2 + /// * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by + /// permutation + broadcast + /// * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3) + /// with perm = [1, 0, 2] and broadcast d2 + /// * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra + /// leading broadcast dimensions. The map returned would be (0, 0, d0, d1) + /// with perm = [3, 0, 1, 2] + bool isPermutationOfMinorIdentityWithBroadcasting( + SmallVectorImpl &permutedDims) const; + /// Returns true if this affine map is an empty map, i.e., () -> (). bool isEmpty() const; diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2842,6 +2842,113 @@ } }; +/// Lower transfer_read op with permutation into a transfer_read with a +/// permutation map composed of leading zeros followed by a minor identity + +/// vector.transpose op. +/// Ex: +/// vector.transfer_read ... +/// permutation_map: (d0, d1, d2) -> (0, d1) +/// into: +/// %v = vector.transfer_read ... +/// permutation_map: (d0, d1, d2) -> (d1, 0) +/// vector.transpose %v, [1, 0] +/// +/// vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3) +/// into: +/// %v = vector.transfer_read ... 
+/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3) +/// vector.transpose %v, [0, 1, 3, 2, 4] +/// Note that an alternative is to transform it to linalg.transpose + +/// vector.transfer_read to do the transpose in memory instead. +struct TransferReadPermutationLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + SmallVector permutation; + AffineMap map = op.permutation_map(); + if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) + return failure(); + + AffineMap permutationMap = + map.getPermutationMap(permutation, op.getContext()); + if (permutationMap.isIdentity()) + return failure(); + // Calculate the map of the new read by applying the inverse permutation. + permutationMap = inversePermutation(permutationMap); + AffineMap newMap = permutationMap.compose(map); + // Apply the reverse transpose to deduce the type of the transfer_read. NOTE(review): the `masked` attribute below is forwarded unchanged while the vector shape is permuted; confirm this is intended. + ArrayRef originalShape = op.getVectorType().getShape(); + SmallVector newVectorShape(originalShape.size()); + for (auto pos : llvm::enumerate(permutation)) { + newVectorShape[pos.value()] = originalShape[pos.index()]; + } + VectorType newReadType = + VectorType::get(newVectorShape, op.getVectorType().getElementType()); + Value newRead = rewriter.create( + op.getLoc(), newReadType, op.source(), op.indices(), newMap, + op.padding(), op.masked() ? *op.masked() : ArrayAttr()); + SmallVector transposePerm(permutation.begin(), permutation.end()); + rewriter.replaceOpWithNewOp(op, newRead, + transposePerm); + return success(); + } +}; + +/// Lower transfer_read op with broadcast in the leading dimensions into +/// transfer_read of lower rank + vector.broadcast. +/// Ex: vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3) +/// into: +/// %v = vector.transfer_read ... 
+/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) +/// vector.broadcast %v +struct TransferOpReduceRank : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + AffineMap map = op.permutation_map(); + unsigned numLeadingBroadcast = 0; + for (auto expr : map.getResults()) { + auto dimExpr = expr.dyn_cast(); + if (!dimExpr || dimExpr.getValue() != 0) + break; + numLeadingBroadcast++; + } + // If there are no leading zeros in the map there is nothing to do. + if (numLeadingBroadcast == 0) + return failure(); + VectorType originalVecType = op.getVectorType(); + unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast; + // Calculate new map, vector type and masks without the leading zeros. + AffineMap newMap = AffineMap::get( + map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank), + op.getContext()); + // Only remove the leading zeros if the rest of the map is a minor identity + // with broadcasting. Otherwise we first want to permute the map. + if (!newMap.isMinorIdentityWithBroadcasting()) + return failure(); + SmallVector newShape = llvm::to_vector<4>( + originalVecType.getShape().take_back(reducedShapeRank)); + VectorType newReadType = + VectorType::get(newShape, originalVecType.getElementType()); + ArrayAttr newMask = + op.masked() + ? rewriter.getArrayAttr( + op.maskedAttr().getValue().take_back(reducedShapeRank)) + : ArrayAttr(); + Value newRead = rewriter.create( + op.getLoc(), newReadType, op.source(), op.indices(), newMap, + op.padding(), newMask); + rewriter.replaceOpWithNewOp(op, originalVecType, + newRead); + return success(); + } +}; + // Trims leading one dimensions from `oldType` and returns the result type. // Returns `vector<1xT>` if `oldType` only has one element. 
static VectorType trimLeadingOneDims(VectorType oldType) { @@ -3317,6 +3424,8 @@ void mlir::vector::populateVectorTransferLoweringPatterns( RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + patterns + .add( + patterns.getContext()); } diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/MathExtras.h" +#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" @@ -140,6 +141,66 @@ return true; } +/// Return true if this affine map can be converted to a minor identity with +/// broadcast by doing a permute. Return a permutation (there may be +/// several) to apply to get to a minor identity with broadcasts. +/// Ex: +/// * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with +/// perm = [1, 0] and broadcast d2 +/// * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by +/// permutation + broadcast +/// * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3) +/// with perm = [1, 0, 2] and broadcast d2 +/// * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra +/// leading broadcast dimensions. The map returned would be (0, 0, d0, d1) with +/// perm = [3, 0, 1, 2] +bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting( + SmallVectorImpl &permutedDims) const { + unsigned projectionStart = + getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0; + permutedDims.clear(); + SmallVector broadcastDims; + permutedDims.resize(getNumResults(), 0); + // If there are more results than input dimensions we want the new map to + // start with broadcast dimensions in order to be a minor identity with + // broadcasting. + unsigned leadingBroadcast = + getNumResults() > getNumInputs() ? 
getNumResults() - getNumInputs() : 0; + llvm::SmallBitVector dimFound(std::max(getNumInputs(), getNumResults()), + false); + for (auto idxAndExpr : llvm::enumerate(getResults())) { + unsigned resIdx = idxAndExpr.index(); + AffineExpr expr = idxAndExpr.value(); + // Each result may be either a constant 0 (broadcast dimension) or a + // dimension. + if (auto constExpr = expr.dyn_cast()) { + if (constExpr.getValue() != 0) + return false; + broadcastDims.push_back(resIdx); + } else if (auto dimExpr = expr.dyn_cast()) { + if (dimExpr.getPosition() < projectionStart) + return false; + unsigned newPosition = + dimExpr.getPosition() - projectionStart + leadingBroadcast; + permutedDims[resIdx] = newPosition; + dimFound[newPosition] = true; + } else { + return false; + } + } + // Find a permutation for the broadcast dimension. Since they are broadcasted + // any valid permutation is acceptable. We just permute the dim into a slot + // without an existing dimension. + unsigned pos = 0; + for (auto dim : broadcastDims) { + while (pos < dimFound.size() && dimFound[pos]) { + pos++; + } + permutedDims[dim] = pos++; + } + return true; +} + /// Returns an AffineMap representing a permutation. 
AffineMap AffineMap::getPermutationMap(ArrayRef permutation, MLIRContext *context) { diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir @@ -206,3 +206,56 @@ %res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0 {masked = [false, false, false, false], permutation_map = #broadcast} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32> return %res : vector<3x2x4x5xf32> } + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, 0, 0)> +#map1 = affine_map<(d0, d1, d2, d3) -> (0, d1, 0, d0)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d3, d1, 0, 0)> +#map3 = affine_map<(d0, d1) -> (d1, d0, 0, 0)> +#map4 = affine_map<(d0, d1) -> (0, d1, 0, d0)> +#map5 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)> + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, 0, 0)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)> + +// CHECK-LABEL: func @transfer_read_permutations +func @transfer_read_permutations(%arg0 : memref, %arg1 : memref) + -> (vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, + vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>) { +// CHECK-DAG: %[[CF0:.*]] = constant 0.000000e+00 : f32 +// CHECK-DAG: %[[C0:.*]] = constant 0 : index + %cst = constant 0.000000e+00 : f32 // passed after the indices (padding operand) in every transfer_read below + %c0 = constant 0 : index // used as the index for every memref dimension below + + %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map0} : memref, vector<7x14x8x16xf32> +// CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$MAP0]]} : memref, vector<14x7x8x16xf32> +// CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32> + + %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map1} : memref, vector<7x14x8x16xf32> +// CHECK: vector.transfer_read {{.*}} {permutation_map = 
#[[$MAP0]]} : memref, vector<16x14x7x8xf32> +// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32> + + %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {masked = [false, false, true, false], permutation_map = #map2} : memref, vector<7x14x8x16xf32> +// CHECK: vector.transfer_read {{.*}} {masked = [false, true, false], permutation_map = #[[$MAP1]]} : memref, vector<14x16x7xf32> +// CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32> +// CHECK: vector.transpose %{{.*}}, [3, 1, 0, 2] : vector<8x14x16x7xf32> to vector<7x14x8x16xf32> + + %3 = vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map3} : memref, vector<7x14x8x16xf32> +// CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF0]] : memref, vector<14x7xf32> +// CHECK: vector.broadcast %{{.*}} : vector<14x7xf32> to vector<8x16x14x7xf32> +// CHECK: vector.transpose %{{.*}}, [3, 2, 0, 1] : vector<8x16x14x7xf32> to vector<7x14x8x16xf32> + + %4 = vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map4} : memref, vector<7x14x8x16xf32> +// CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %[[CF0]] : memref, vector<16x14xf32> +// CHECK: vector.broadcast %{{.*}} : vector<16x14xf32> to vector<7x8x16x14xf32> +// CHECK: vector.transpose %{{.*}}, [0, 3, 1, 2] : vector<7x8x16x14xf32> to vector<7x14x8x16xf32> + + %5 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map5} : memref, vector<7x14x8x16xf32> +// CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[CF0]] : memref, vector<16x14x7x8xf32> +// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32> + + return %0, %1, %2, %3, %4, %5 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, + vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, + vector<7x14x8x16xf32> +}