diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -897,6 +897,15 @@ VectorType getVectorType() { return vector().getType().cast(); } + // Number of memref dimensions transferred, i.e. permutation map results. + unsigned getTransferRank() { + return permutation_map().getNumResults(); + } + // Number of memref dimensions that are indexed but not transferred; these + // are the leading dimensions when the permutation map is a minor identity. + unsigned getLeadingMemRefRank() { + return getMemRefType().getRank() - getTransferRank(); + } }]; } diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -99,9 +99,10 @@ /// dimensional identifiers. bool isIdentity() const; - /// Returns true if the map is a minor identity map, i.e. an identity affine - /// map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions. - static bool isMinorIdentity(AffineMap map); + /// Returns true if this affine map is a minor identity, i.e. an identity + /// affine map (d0, ..., dn) -> (dp, ..., dn) on the most minor dimensions. + /// No null check is performed; callers must pass a non-null map. + bool isMinorIdentity() const; /// Returns true if this affine map is an empty map, i.e., () -> (). bool isEmpty() const; diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -72,7 +72,7 @@ llvm::size(xferOp.indices()) == 0) return failure(); - if (!AffineMap::isMinorIdentity(xferOp.permutation_map())) + if (!xferOp.permutation_map().isMinorIdentity()) return failure(); // Have it handled in vector->llvm conversion pass.
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -89,7 +89,7 @@ // TODO: when we go to k > 1-D vectors adapt minorRank. minorRank = 1; majorRank = vectorType.getRank() - minorRank; - leadingRank = xferOp.getMemRefType().getRank() - (majorRank + minorRank); + leadingRank = xferOp.getLeadingMemRefRank(); majorVectorType = VectorType::get(vectorType.getShape().take_front(majorRank), vectorType.getElementType()); @@ -538,7 +538,7 @@ using namespace mlir::edsc::op; TransferReadOp transfer = cast(op); - if (AffineMap::isMinorIdentity(transfer.permutation_map())) { + if (transfer.permutation_map().isMinorIdentity()) { // If > 1D, emit a bunch of loops around 1-D vector transfers. if (transfer.getVectorType().getRank() > 1) return NDTransferOpHelper(rewriter, transfer, options) @@ -611,7 +611,7 @@ using namespace edsc::op; TransferWriteOp transfer = cast(op); - if (AffineMap::isMinorIdentity(transfer.permutation_map())) { + if (transfer.permutation_map().isMinorIdentity()) { // If > 1D, emit a bunch of loops around 1-D vector transfers.
if (transfer.getVectorType().getRank() > 1) return NDTransferOpHelper(rewriter, transfer, options) diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -620,7 +620,7 @@ MLIRContext *ctx = extractOp.getContext(); AffineMap permutationMap = AffineMap::getPermutationMap(permutation, ctx); AffineMap minorMap = permutationMap.getMinorSubMap(minorRank); - if (minorMap && !AffineMap::isMinorIdentity(minorMap)) + if (minorMap && !minorMap.isMinorIdentity()) return failure(); // %1 = transpose %0[x, y, z] : vector @@ -730,7 +730,7 @@ unsigned minorRank = permutationMap.getNumResults() - insertedPos.size(); AffineMap minorMap = permutationMap.getMinorSubMap(minorRank); - if (!minorMap || AffineMap::isMinorIdentity(minorMap)) + if (!minorMap || minorMap.isMinorIdentity()) return insertOp.source(); } } @@ -1704,8 +1704,68 @@ return success(folded); } +template +static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { + // TODO: support more aggressive createOrFold on: + // `op.indices()[indicesIdx] + vectorSize <= dim(op.memref(), indicesIdx)` + if (op.getMemRefType().isDynamicDim(indicesIdx)) + return false; + Value index = op.indices()[indicesIdx]; + auto cstOp = index.getDefiningOp(); + if (!cstOp) + return false; + + int64_t memrefSize = op.getMemRefType().getDimSize(indicesIdx); + int64_t vectorSize = op.getVectorType().getDimSize(resultIdx); + return cstOp.getValue() + vectorSize <= memrefSize; +} + +template +static LogicalResult foldTransferMaskAttribute(TransferOp op) { + AffineMap permutationMap = op.permutation_map(); + if (!permutationMap.isMinorIdentity()) + return failure(); + bool changed = false; + SmallVector isMasked; + isMasked.reserve(op.getTransferRank()); + // `permutationMap` results and `op.indices` sizes may not match and may not + // be aligned.
The first `indicesIdx` may just be indexed and not transferred + // from/into the vector. + // For example: + // vector.transfer %0[%i, %j, %k, %c0] : memref, vector<2x4xf32> + // with `permutation_map = (d0, d1, d2, d3) -> (d2, d3)`. + // The `permutationMap` results and `op.indices` are however aligned when + // iterating in reverse until we exhaust `permutationMap` results. + // As a consequence we iterate with 2 running indices: `resultIdx` and + // `indicesIdx`, until `resultIdx` reaches 0. + for (int64_t resultIdx = permutationMap.getNumResults() - 1, + indicesIdx = op.indices().size() - 1; + resultIdx >= 0; --resultIdx, --indicesIdx) { + // Already marked unmasked, nothing to see here. + if (!op.isMaskedDim(resultIdx)) { + isMasked.push_back(false); + continue; + } + // Currently masked, check whether we can statically determine it is + // inBounds. + auto inBounds = isInBounds(op, resultIdx, indicesIdx); + isMasked.push_back(!inBounds); + // We commit the pattern if it is "more inbounds". + changed |= inBounds; + } + if (!changed) + return failure(); + // OpBuilder is only used as a helper to build a BoolArrayAttr.
+ OpBuilder b(op.getContext()); + std::reverse(isMasked.begin(), isMasked.end()); + op.setAttr(TransferOp::getMaskedAttrName(), b.getBoolArrayAttr(isMasked)); + return success(); +} + OpFoldResult TransferReadOp::fold(ArrayRef) { /// transfer_read(memrefcast) -> transfer_read + if (succeeded(foldTransferMaskAttribute(*this))) + return getResult(); if (succeeded(foldMemRefCast(*this))) return getResult(); return OpFoldResult(); @@ -1803,6 +1863,8 @@ LogicalResult TransferWriteOp::fold(ArrayRef, SmallVectorImpl &) { + if (succeeded(foldTransferMaskAttribute(*this))) + return success(); return foldMemRefCast(*this); } diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -104,11 +104,12 @@ return AffineMap::get(dims, 0, id.getResults().take_back(results), context); } -bool AffineMap::isMinorIdentity(AffineMap map) { - if (!map) - return false; - return map == getMinorIdentityMap(map.getNumDims(), map.getNumResults(), - map.getContext()); +bool AffineMap::isMinorIdentity() const { + // getMinorIdentityMap keeps the last `results` results of a `dims`-D + // identity, so maps with more results than dims cannot be minor identities. + return getNumDims() >= getNumResults() && + *this == + getMinorIdentityMap(getNumDims(), getNumResults(), getContext()); } /// Returns an AffineMap representing a permutation.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -168,10 +168,10 @@ %f0 = constant 0.0 : f32 %0 = memref_cast %A : memref<4x8xf32> to memref - // CHECK: vector.transfer_read %{{.*}} : memref<4x8xf32>, vector<4x8xf32> + // CHECK: vector.transfer_read %{{.*}} {masked = [false, false]} : memref<4x8xf32>, vector<4x8xf32> %1 = vector.transfer_read %0[%c0, %c0], %f0 : memref, vector<4x8xf32> - // CHECK: vector.transfer_write %{{.*}} : vector<4x8xf32>, memref<4x8xf32> + // CHECK: vector.transfer_write %{{.*}} {masked = [false, false]} : vector<4x8xf32>, memref<4x8xf32> vector.transfer_write %1, %0[%c0, %c0] : vector<4x8xf32>, memref return %1 : vector<4x8xf32> } @@ -345,3 +345,30 @@ return %1, %3, %5 : vector<6xf32>, vector<6xf32>, vector<6xf32> } + +// ----- + +// CHECK-LABEL: fold_vector_transfers +func @fold_vector_transfers(%A: memref) -> (vector<4x8xf32>, vector<4x9xf32>) { + %c0 = constant 0 : index + %f0 = constant 0.0 : f32 + + // CHECK: vector.transfer_read %{{.*}} {masked = [true, false]} + %1 = vector.transfer_read %A[%c0, %c0], %f0 : memref, vector<4x8xf32> + + // CHECK: vector.transfer_write %{{.*}} {masked = [true, false]} + vector.transfer_write %1, %A[%c0, %c0] : vector<4x8xf32>, memref + + // Neither dim is provably in bounds: both stay masked, attribute is elided. + // CHECK: vector.transfer_read %{{.*}} + // CHECK-NOT: masked + %2 = vector.transfer_read %A[%c0, %c0], %f0 : memref, vector<4x9xf32> + + // Neither dim is provably in bounds: both stay masked, attribute is elided.
+ // CHECK: vector.transfer_write %{{.*}} + // CHECK-NOT: masked + vector.transfer_write %2, %A[%c0, %c0] : vector<4x9xf32>, memref + + // CHECK: return + return %1, %2 : vector<4x8xf32>, vector<4x9xf32> +} diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -248,10 +248,10 @@ // CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> -// CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] : vector<2x2xf32>, memref<4x4xf32> -// CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] : vector<2x2xf32>, memref<4x4xf32> -// CHECK-NEXT: vector.transfer_write %[[R2]], %{{.*}}[%[[C2]], %[[C0]]] : vector<2x2xf32>, memref<4x4xf32> -// CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: vector.transfer_write %[[R2]], %{{.*}}[%[[C2]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: vector.transfer_write %[[R3]], %{{.*}}[%[[C2]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32> // CHECK-NEXT: return func @contraction4x4_ikj_xfer_read(%arg0 : memref<4x2xf32>, @@ -296,8 +296,8 @@ %cst_1 = constant
2.000000e+00 : f32 affine.for %arg2 = 0 to %arg0 step 4 { affine.for %arg3 = 0 to %arg1 step 4 { - %4 = vector.transfer_read %0[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref, vector<4x4xf32> - %5 = vector.transfer_read %1[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref, vector<4x4xf32> + %4 = vector.transfer_read %0[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref, vector<4x4xf32> + %5 = vector.transfer_read %1[%arg2, %arg3], %cst {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : memref, vector<4x4xf32> %6 = addf %4, %5 : vector<4x4xf32> vector.transfer_write %6, %2[%arg2, %arg3] {permutation_map = affine_map<(d0, d1) -> (d0, d1)>} : vector<4x4xf32>, memref } @@ -426,10 +426,10 @@ // CHECK-LABEL: func @vector_transfers_vector_element_type // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C1:.*]] = constant 1 : index -// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32> -// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32> -// CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>> -// CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>> +// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {masked = [false, false]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32> +// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} {masked = [false, false]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32> +// CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {masked = [false, false]} : vector<1x1x2x4xf32>,
memref<6x2x1xvector<2x4xf32>> +// CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] {masked = [false, false]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>> func @vector_transfers_vector_element_type() { %c0 = constant 0 : index