Index: mlir/include/mlir/Dialect/Vector/VectorOps.h
===================================================================
--- mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -85,6 +85,12 @@
 void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
                                           MLIRContext *context);
 
+/// Collect a set of transfer read/write lowering patterns.
+///
+/// These patterns lower transfer ops to simpler ops like `vector.load`.
+void populateVectorTransferLoweringPatterns(OwningRewritePatternList &patterns,
+                                            MLIRContext *context);
+
 /// An attribute that specifies the combining function for `vector.contract`,
 /// and `vector.reduction`.
 class CombiningKindAttr
Index: mlir/lib/Dialect/Vector/VectorTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -37,6 +37,7 @@
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -586,6 +587,39 @@
   return true;
 }
 
+/// Returns true if `map` is a suffix of an identity affine map, except for
+/// broadcasted dimensions, which are indicated by 0 values in the result. If
+/// `broadcastedDims` is not null, it is populated with the indices of the
+/// broadcasted dimensions.
+/// Example: affine_map<(d0, d1, d2, d3, d4) -> (0, d2, 0, d4)>
+/// `broadcastedDims` will contain [0, 2] (the result indices equal to 0).
+static bool isMinorIdentityWithBroadcasting(
+    AffineMap map, SmallVectorImpl<unsigned> *broadcastedDims = nullptr) {
+  if (broadcastedDims)
+    broadcastedDims->clear();
+  if (map.getNumDims() < map.getNumResults())
+    return false;
+  unsigned suffixStart = map.getNumDims() - map.getNumResults();
+  for (auto idxAndExpr : llvm::enumerate(map.getResults())) {
+    unsigned resIdx = idxAndExpr.index();
+    AffineExpr expr = idxAndExpr.value();
+    if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
+      // Each result must be either a constant 0 (broadcasted dimension)...
+      if (constExpr.getValue() != 0)
+        return false;
+      if (broadcastedDims)
+        broadcastedDims->push_back(resIdx);
+    } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
+      // ...or the input dimension corresponding to this result position.
+      if (dimExpr.getPosition() != suffixStart + resIdx)
+        return false;
+    } else {
+      return false;
+    }
+  }
+  return true;
+}
+
 /// Unroll transfer_read ops to the given shape and create an aggregate with
 /// all the chunks.
 static Value unrollTransferReadOp(vector::TransferReadOp readOp,
@@ -2729,6 +2763,91 @@
   }
 };
 
+/// Progressive lowering of transfer_read.
+struct TransferReadToVectorLoadLowering
+    : public OpRewritePattern<vector::TransferReadOp> {
+  TransferReadToVectorLoadLowering(MLIRContext *context)
+      : OpRewritePattern<vector::TransferReadOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferReadOp read,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<unsigned, 4> broadcastedDims;
+    // TODO: Support permutations.
+    if (!isMinorIdentityWithBroadcasting(read.permutation_map(),
+                                         &broadcastedDims))
+      return failure();
+    // TODO: Support tensors.
+    auto memRefType = read.getShapedType().dyn_cast<MemRefType>();
+    if (!memRefType)
+      return failure();
+
+    // If there is broadcasting involved, we first load the unbroadcasted
+    // vector and then broadcast it with `vector.broadcast`.
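+    // For example (illustrative): a read with permutation map
+    // (d0, d1) -> (0, d1) into vector<4x8xf32> broadcasts result dimension 0,
+    // so we load vector<1x8xf32> and broadcast it to vector<4x8xf32>.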
+    ArrayRef<int64_t> vectorShape = read.getVectorType().getShape();
+    SmallVector<int64_t, 4> unbroadcastedVectorShape(vectorShape.begin(),
+                                                     vectorShape.end());
+    for (unsigned i : broadcastedDims)
+      unbroadcastedVectorShape[i] = 1;
+    VectorType unbroadcastedVectorType = VectorType::get(
+        unbroadcastedVectorShape, read.getVectorType().getElementType());
+
+    // `vector.load` supports vector types as memref elements only when the
+    // resulting vector type is the same as the element type.
+    if (memRefType.getElementType().isa<VectorType>() &&
+        memRefType.getElementType() != unbroadcastedVectorType)
+      return failure();
+    // Only the default layout is supported by `vector.load`.
+    // TODO: Support non-default layouts.
+    if (!memRefType.getAffineMaps().empty())
+      return failure();
+    // TODO: When masking is required, we can create a MaskedLoadOp.
+    if (read.hasMaskedDim())
+      return failure();
+    auto loadOp = rewriter.create<vector::LoadOp>(
+        read.getLoc(), unbroadcastedVectorType, read.source(), read.indices());
+    // Insert a broadcasting op if required.
+    if (!broadcastedDims.empty()) {
+      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+          read, read.getVectorType(), loadOp.getResult());
+    } else {
+      rewriter.replaceOp(read, loadOp.getResult());
+    }
+
+    return success();
+  }
+};
+
+/// Progressive lowering of transfer_write.
+struct TransferWriteToVectorStoreLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  TransferWriteToVectorStoreLowering(MLIRContext *context)
+      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    // TODO: Support non-minor-identity maps.
+    if (!write.permutation_map().isMinorIdentity())
+      return failure();
+    // TODO: Support tensors.
+    auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
+    if (!memRefType)
+      return failure();
+    // `vector.store` supports vector types as memref elements only when the
+    // type of the vector value being written is the same as the element type.
+    if (memRefType.getElementType().isa<VectorType>() &&
+        memRefType.getElementType() != write.getVectorType())
+      return failure();
+    // Only the default layout is supported by `vector.store`.
+    // TODO: Support non-default layouts.
+    if (!memRefType.getAffineMaps().empty())
+      return failure();
+    // TODO: When masking is required, we can create a MaskedStoreOp.
+    if (write.hasMaskedDim())
+      return failure();
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        write, write.vector(), write.source(), write.indices());
+    return success();
+  }
+};
+
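+// Example of the rewrite performed by TransferWriteToVectorStoreLowering
+// (illustrative; it mirrors the @transfer_to_load test added below):
+//
+//   vector.transfer_write %v, %mem[%i, %i] {masked = [false]}
+//     : vector<4xf32>, memref<8x8xf32>
+//   ==>
+//   vector.store %v, %mem[%i, %i] : memref<8x8xf32>, vector<4xf32>
+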
 // Trims leading one dimensions from `oldType` and returns the result type.
 // Returns `vector<1xT>` if `oldType` only has one element.
 static VectorType trimLeadingOneDims(VectorType oldType) {
@@ -3201,3 +3320,9 @@
                   ContractionOpToOuterProductOpLowering>(parameters, context);
   // clang-format on
 }
+
+void mlir::vector::populateVectorTransferLoweringPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<TransferReadToVectorLoadLowering,
+                  TransferWriteToVectorStoreLowering>(context);
+}
Index: mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -0,0 +1,188 @@
+// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -split-input-file | FileCheck %s
+
+// transfer_read/write are lowered to vector.load/store.
+// CHECK-LABEL: func @transfer_to_load(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT: vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+func @transfer_to_load(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32>, vector<4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// n-D results are also supported.
+// CHECK-LABEL: func @transfer_2D(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
+// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
+// CHECK-NEXT: vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<2x4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
+// CHECK-NEXT: }
+
+func @transfer_2D(%mem : memref<8x8xf32>, %i : index) -> vector<2x4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, false]} : memref<8x8xf32>, vector<2x4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false, false]} : vector<2x4xf32>, memref<8x8xf32>
+  return %res : vector<2x4xf32>
+}
+
+// -----
+
+// Vector element types are supported when the result has the same type.
+// CHECK-LABEL: func @transfer_vector_element(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
+// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
+// CHECK-NEXT: vector.store %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
+// CHECK-NEXT: }
+
+func @transfer_vector_element(%mem : memref<8x8xvector<2x4xf32>>, %i : index) -> vector<2x4xf32> {
+  %cf0 = constant dense<0.0> : vector<2x4xf32>
+  %res = vector.transfer_read %mem[%i, %i], %cf0 : memref<8x8xvector<2x4xf32>>, vector<2x4xf32>
+  vector.transfer_write %res, %mem[%i, %i] : vector<2x4xf32>, memref<8x8xvector<2x4xf32>>
+  return %res : vector<2x4xf32>
+}
+
+// -----
+
+// Vector element types are not supported yet when the result has a different
+// type.
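+// (`vector.load`/`vector.store` require the transferred vector type to match
+// the memref's vector element type in that case, so the pattern bails out and
+// the transfer ops below are left unchanged.)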
+// CHECK-LABEL: func @transfer_vector_element_different_types(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xvector<2x4xf32>>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<1x2x4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant dense<0.000000e+00> : vector<2x4xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false]} : memref<8x8xvector<2x4xf32>>, vector<1x2x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false]} : vector<1x2x4xf32>, memref<8x8xvector<2x4xf32>>
+// CHECK-NEXT: return %[[RES]] : vector<1x2x4xf32>
+// CHECK-NEXT: }
+
+func @transfer_vector_element_different_types(%mem : memref<8x8xvector<2x4xf32>>, %i : index) -> vector<1x2x4xf32> {
+  %cf0 = constant dense<0.0> : vector<2x4xf32>
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xvector<2x4xf32>>, vector<1x2x4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<1x2x4xf32>, memref<8x8xvector<2x4xf32>>
+  return %res : vector<1x2x4xf32>
+}
+
+// -----
+
+// transfer_read/write cannot be lowered because one of the dimensions is
+// masked.
+// CHECK-LABEL: func @transfer_2D_masked(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<2x4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false, true]} : memref<8x8xf32>, vector<2x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [true, false]} : vector<2x4xf32>, memref<8x8xf32>
+// CHECK-NEXT: return %[[RES]] : vector<2x4xf32>
+// CHECK-NEXT: }
+
+func @transfer_2D_masked(%mem : memref<8x8xf32>, %i : index) -> vector<2x4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false, true]} : memref<8x8xf32>, vector<2x4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [true, false]} : vector<2x4xf32>, memref<8x8xf32>
+  return %res : vector<2x4xf32>
+}
+
+// -----
+
+// transfer_read/write cannot be lowered because they are masked; all
+// dimensions are masked by default when no `masked` attribute is present.
+// CHECK-LABEL: func @transfer_masked(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] : vector<4xf32>, memref<8x8xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+func @transfer_masked(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 : memref<8x8xf32>, vector<4xf32>
+  vector.transfer_write %res, %mem[%i, %i] : vector<4xf32>, memref<8x8xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// transfer_read/write cannot be lowered to vector.load/store because the
+// memref has a non-default layout.
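+// (`vector.load`/`vector.store` only accept memrefs with the default identity
+// layout; other layouts are a TODO in the lowering patterns.)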
+// CHECK-LABEL: func @transfer_nondefault_layout(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32, #{{.*}}>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false]} : memref<8x8xf32, #{{.*}}>, vector<4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false]} : vector<4xf32>, memref<8x8xf32, #{{.*}}>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+#layout = affine_map<(d0, d1) -> (d0*16 + d1)>
+func @transfer_nondefault_layout(%mem : memref<8x8xf32, #layout>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false]} : memref<8x8xf32, #layout>, vector<4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false]} : vector<4xf32>, memref<8x8xf32, #layout>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// transfer_read/write cannot be lowered to vector.load/store yet when the
+// permutation map is not the minor identity map (up to broadcasting).
+// CHECK-LABEL: func @transfer_perm_map(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[CF0:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {masked = [false], permutation_map = #{{.*}}} : memref<8x8xf32>, vector<4xf32>
+// CHECK-NEXT: vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {masked = [false], permutation_map = #{{.*}}} : vector<4xf32>, memref<8x8xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+func @transfer_perm_map(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<8x8xf32>, vector<4xf32>
+  vector.transfer_write %res, %mem[%i, %i] {masked = [false], permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<4xf32>, memref<8x8xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// Lowering of transfer_read with broadcasting is supported.
+// CHECK-LABEL: func @transfer_broadcasting(
+// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<1xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<1xf32> to vector<4xf32>
+// CHECK-NEXT: return %[[RES]] : vector<4xf32>
+// CHECK-NEXT: }
+
+#broadcast0 = affine_map<(d0, d1) -> (0)>
+func @transfer_broadcasting(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {masked = [false], permutation_map = #broadcast0} : memref<8x8xf32>, vector<4xf32>
+  return %res : vector<4xf32>
+}
+
+// -----
+
+// More complex broadcasting case.
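+// With permutation map (d0, d1, d2, d3, d4) -> (d1, 0, 0, d4), result
+// dimensions 1 and 2 are broadcast, so the pattern loads vector<3x1x1x5xf32>
+// and broadcasts it to vector<3x2x4x5xf32>.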
+// CHECK-LABEL: func @transfer_broadcasting_1(
+// CHECK-SAME: %[[MEM:.*]]: memref<10x20x30x8x8xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<3x2x4x5xf32> {
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]] : memref<10x20x30x8x8xf32>, vector<3x1x1x5xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<3x1x1x5xf32> to vector<3x2x4x5xf32>
+// CHECK-NEXT: return %[[RES]] : vector<3x2x4x5xf32>
+// CHECK-NEXT: }
+
+#broadcast1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, 0, 0, d4)>
+func @transfer_broadcasting_1(%mem : memref<10x20x30x8x8xf32>, %i : index) -> vector<3x2x4x5xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i, %i, %i, %i], %cf0 {masked = [false, false, false, false], permutation_map = #broadcast1} : memref<10x20x30x8x8xf32>, vector<3x2x4x5xf32>
+  return %res : vector<3x2x4x5xf32>
+}
Index: mlir/test/lib/Transforms/TestVectorTransforms.cpp
===================================================================
--- mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -361,6 +361,15 @@
   void runOnFunction() override { transferOpflowOpt(getFunction()); }
 };
 
+struct TestVectorTransferLoweringPatterns
+    : public PassWrapper<TestVectorTransferLoweringPatterns, FunctionPass> {
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    populateVectorTransferLoweringPatterns(patterns, &getContext());
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 } // end anonymous namespace
 
 namespace mlir {
@@ -400,6 +409,9 @@
   PassRegistration<TestVectorTransferOpt> transferOpOpt(
       "test-vector-transferop-opt",
      "Test optimization transformations for transfer ops");
+  PassRegistration<TestVectorTransferLoweringPatterns> transferOpLoweringPass(
+      "test-vector-transfer-lowering-patterns",
+      "Test conversion patterns to lower transfer ops to other vector ops");
 }
 } // namespace test
 } // namespace mlir